How to create a custom Retriever
Overview
Many LLM applications involve retrieving information from external data sources using a Retriever.
A retriever is responsible for returning a list of Documents relevant to a given user query.
The retrieved documents are often formatted into prompts that are fed into an LLM, allowing the LLM to use the information in those documents to generate an appropriate response (e.g., answering a user question based on a knowledge base).
Interface
To create your own retriever, you need to extend the BaseRetriever class and implement the following methods:
| Method | Description | Required/Optional |
|---|---|---|
| _get_relevant_documents | Get documents relevant to a query. | Required |
| _aget_relevant_documents | Implement to provide async native support. | Optional |
The logic inside of _get_relevant_documents can involve arbitrary calls to a database or to the web using requests.
By inheriting from BaseRetriever, your retriever automatically becomes a LangChain Runnable and gains the standard Runnable functionality out of the box!
You can use a RunnableLambda or RunnableGenerator to implement a retriever.
The main benefit of implementing a retriever as a BaseRetriever vs. a RunnableLambda (a custom runnable function) is that a BaseRetriever is a well-known LangChain entity, so some monitoring tooling may implement specialized behavior for retrievers. Another difference is that a BaseRetriever behaves slightly differently from a RunnableLambda in some APIs; e.g., the start event in the astream_events API will be on_retriever_start instead of on_chain_start.
Example
Let's implement a toy retriever that returns all documents whose text contains the text in the user query.
from typing import List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class ToyRetriever(BaseRetriever):
    """A toy retriever that returns the top k documents that contain the user query.

    This retriever only implements the sync method _get_relevant_documents.

    If the retriever were to involve file access or network access, it could benefit
    from a native async implementation of `_aget_relevant_documents`.

    As usual, with Runnables, there's a default async implementation that's provided
    that delegates to the sync implementation running on another thread.
    """

    documents: List[Document]
    """List of documents to retrieve from."""
    k: int
    """Number of top results to return"""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Sync implementation for the retriever."""
        matching_documents = []
        for document in self.documents:
            # Stop as soon as k matches have been collected
            # (using > here would return k + 1 documents).
            if len(matching_documents) >= self.k:
                return matching_documents

            # Case-insensitive substring match against the document text.
            if query.lower() in document.page_content.lower():
                matching_documents.append(document)
        return matching_documents

    # Optional: Provide a more efficient native implementation by overriding
    # _aget_relevant_documents. Uncommenting this also requires importing
    # AsyncCallbackManagerForRetrieverRun from langchain_core.callbacks.
    # async def _aget_relevant_documents(
    #     self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    # ) -> List[Document]:
    #     """Asynchronously get documents relevant to a query.
    #
    #     Args:
    #         query: String to find relevant documents for
    #         run_manager: The callbacks handler to use
    #
    #     Returns:
    #         List of relevant documents
    #     """