Using LangChain for Inference¶
Introduction¶
After fine-tuning your RAG system to achieve the desired performance, you'll want to
deploy it for inference. While FedRAG's RAGSystem
provides complete inference
capabilities out of the box, you may need additional features for production deployments
or want to leverage the ecosystem of existing RAG frameworks.
FedRAG offers seamless integration with LangChain through our bridges system, giving you the best of both worlds: FedRAG's fine-tuning capabilities combined with LangChain's extensive inference features.
In this example, we demonstrate how you can convert a RAGSystem
into a tuple consisting of a ~langchain_core.vectorstores.VectorStore
and a ~langchain_core.language_models.BaseLLM.
The former can then be transformed into a ~langchain_core.vectorstores.VectorStoreRetriever
using the as_retriever()
method, enabling the creation of a complete QA pipeline with LangChain's LCEL.
NOTE: Streaming and async functionalities are not yet supported.
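At a glance, the bridge hands you the two LangChain objects and lets you compose them with LCEL. The snippet below is only a minimal preview; every object it uses is constructed step by step in the rest of this example:
vector_store, llm = rag_system.to_langchain()  # bridge the fine-tuned RAG system
retriever = vector_store.as_retriever()  # a standard LangChain VectorStoreRetriever
# retriever and llm can then be piped together into an LCEL QA chain, as shown below.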
Install dependencies¶
# If running in a Google Colab, the first attempt at installing fed-rag may fail,
# though for reasons unknown to me yet, if you try a second time, it magically works...
!pip install fed-rag[huggingface,langchain] -q
Setup — The RAG System¶
import torch
from transformers.generation.configuration_utils import GenerationConfig
from fed_rag import RAGSystem, RAGConfig
from fed_rag.generators.huggingface import HFPretrainedModelGenerator
from fed_rag.retrievers.huggingface import (
    HFSentenceTransformerRetriever,
)
from fed_rag.knowledge_stores import InMemoryKnowledgeStore
from fed_rag.data_structures import KnowledgeNode
QUERY_ENCODER_NAME = "nthakur/dragon-plus-query-encoder"
CONTEXT_ENCODER_NAME = "nthakur/dragon-plus-context-encoder"
PRETRAINED_MODEL_NAME = "Qwen/Qwen3-0.6B"
# Retriever
retriever = HFSentenceTransformerRetriever(
    query_model_name=QUERY_ENCODER_NAME,
    context_model_name=CONTEXT_ENCODER_NAME,
    load_model_at_init=False,
)
# Generator
generation_cfg = GenerationConfig(
    do_sample=True,
    eos_token_id=151643,
    bos_token_id=151643,
    max_new_tokens=2048,
    top_p=0.9,
    temperature=0.6,
    cache_implementation="offloaded",
    stop_strings="</response>",
)
generator = HFPretrainedModelGenerator(
    model_name=PRETRAINED_MODEL_NAME,
    load_model_at_init=False,
    load_model_kwargs={"device_map": "auto", "torch_dtype": torch.float16},
    generation_config=generation_cfg,
)
# Knowledge store
knowledge_store = InMemoryKnowledgeStore()
# Create the RAG system
rag_system = RAGSystem(
    retriever=retriever,
    generator=generator,
    knowledge_store=knowledge_store,
    rag_config=RAGConfig(top_k=1),
)
Add some knowledge¶
text_chunks = [
    "Retrieval-Augmented Generation (RAG) combines retrieval with generation.",
    "LLMs can hallucinate information when they lack context.",
]
knowledge_nodes = [
    KnowledgeNode(
        node_type="text",
        embedding=retriever.encode_context(ct).tolist(),
        text_content=ct,
    )
    for ct in text_chunks
]
knowledge_store.load_nodes(knowledge_nodes)
rag_system.knowledge_store.count
2
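Before bridging, you can sanity-check the assembled system directly. The snippet below assumes your fed-rag version exposes RAGSystem.query(), as in FedRAG's quick-start; if it doesn't, skip straight to the bridge:
# optional sanity check -- assumes RAGSystem.query() is available in your fed-rag version
response = rag_system.query("What happens if LLMs lack context?")
print(response)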
Using the Bridge¶
Converting your RAG system to LangChain objects is seamless since the bridge
functionality is already built into the RAGSystem
class. The RAGSystem
inherits from LangChainBridgeMixin,
which provides the to_langchain()
method for effortless conversion.
NOTE: The to_langchain()
method returns a tuple consisting of FedRAGVectorStore
and FedRAGLLM
objects, which are custom implementations of the ~langchain_core.vectorstores.VectorStore
and ~langchain_core.language_models.BaseLLM
classes.
# Create the LangChain objects
vector_store, llm = rag_system.to_langchain()
# Search the vector store directly
query = "What happens if LLMs lack context?"
results = vector_store.similarity_search_with_score(query, k=2)
for doc, score in results:
    print(f"Content: {doc.page_content}, Score: {score}")
print("-" * 80)
# Or, convert it to a retriever
retriever = vector_store.as_retriever()
results = retriever.invoke(query)
for doc in results:
    print(f"Content: {doc.page_content}")
print("-" * 80)
# Or, create a complete RAG chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from fed_rag.base.generator import DEFAULT_PROMPT_TEMPLATE
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
qa_chain = (
    {
        "context": vector_store.as_retriever() | format_docs,
        "query": RunnablePassthrough(),
    }
    | PromptTemplate.from_template(DEFAULT_PROMPT_TEMPLATE)
    | llm
    | StrOutputParser()
)
response = qa_chain.invoke("What are autonomous agents?")
print(response)
Content: LLMs can hallucinate information when they lack context., Score: 0.5453173113645673
Content: Retrieval-Augmented Generation (RAG) combines retrieval with generation., Score: 0.5065647593667755
--------------------------------------------------------------------------------
Content: LLMs can hallucinate information when they lack context.
Content: Retrieval-Augmented Generation (RAG) combines retrieval with generation.
--------------------------------------------------------------------------------
An autonomous agent is a system designed to perform specific tasks without direct human intervention. It uses AI to process information and make decisions based on that information. Autonomous agents can be used in various applications, such as healthcare, finance, and logistics, to improve efficiency and accuracy. </response>
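Because the chain is a standard LCEL Runnable, you can also push several questions through it in one call with batch(); its default implementation simply runs invoke() per input, so it works even though streaming and async are not yet supported. A quick sketch:
# run multiple queries through the same chain
responses = qa_chain.batch(["What is RAG?", "Why do LLMs hallucinate?"])
for r in responses:
    print(r)
    print("-" * 80)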
Modifying Knowledge¶
In addition to querying the bridged vector store, you can also make changes to the underlying KnowledgeStore using LangChain's API:
ids = vector_store.add_texts(
    texts=["some arbitrary text", "some other arbitrary text"],
    metadatas=[{"source": "fed-rag"}, {"source": "fed-rag"}],
)
ids
['6797e6f0-5501-447b-ad8a-f256f97bde9b', 'aa0af869-44fa-4c50-b151-eb4c46d311e3']
# confirm that what we added above is indeed in the knowledge store
rag_system.knowledge_store.count
4
# you can also delete nodes
vector_store.delete(ids)
True
# confirm that what we deleted above is indeed removed from the knowledge store
rag_system.knowledge_store.count
2
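LangChain's Document interface works here too. As a sketch, add_documents() (the generic VectorStore counterpart of add_texts(), which the bridge should inherit by default) accepts metadata-bearing documents:
from langchain_core.documents import Document

doc_ids = vector_store.add_documents(
    [Document(page_content="FedRAG bridges into LangChain.", metadata={"source": "fed-rag"})]
)
rag_system.knowledge_store.count  # expect 3 if the add succeeded
vector_store.delete(doc_ids)  # clean up again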
Bridge Metadata¶
To view the metadata of the LangChain bridge, you can access the class attribute
bridge
of the RAGSystem
class, which is a dictionary object that contains the BridgeMetadata
for all of the installed bridges.
# see available bridges
print(RAGSystem.bridges)
# see the LangChain bridge metadata
print(RAGSystem.bridges["langchain-core"])
{'llama-index-core': {'bridge_version': '0.1.0', 'framework': 'llama-index-core', 'compatible_versions': {'min': '0.12.35'}, 'method_name': 'to_llamaindex'}, 'langchain-core': {'bridge_version': '0.1.0', 'framework': 'langchain-core', 'compatible_versions': {'min': '0.3.62'}, 'method_name': 'to_langchain'}}
{'bridge_version': '0.1.0', 'framework': 'langchain-core', 'compatible_versions': {'min': '0.3.62'}, 'method_name': 'to_langchain'}