🦥 Using Unsloth FastModels as your RAG Generator Model¶
Introduction¶
As of v0.0.20, the fed-rag library includes seamless integration with Unsloth.ai, a popular open-source library that dramatically accelerates fine-tuning workflows. This integration allows you to use ~unsloth.FastLanguageModel instances as generator models in your RAG system while fully leveraging Unsloth's efficient fine-tuning capabilities.
In this notebook, we demonstrate how to define a UnslothFastModelGenerator, integrate it into a RAG system, and fine-tune it using our GeneratorTrainers.
NOTE: This notebook takes inspiration from Unsloth's Gemma3 (4B) cookbook notebook for fine-tuning Gemma3 4B; we'll use that exact same model as our generator in our RAG system. The key difference is that we're fine-tuning the model specifically for retrieval-augmented generation tasks using our fed-rag framework.
!uv pip install fed-rag[huggingface,unsloth] -q
!uv pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" -q
!uv pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes -q
import unsloth
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning. 🦥 Unsloth Zoo will now patch everything to make training faster!
Creating an UnslothFastModelGenerator¶
from fed_rag.generators import UnslothFastModelGenerator
from transformers.generation.utils import GenerationConfig
generation_cfg = GenerationConfig(
    do_sample=True,
    eos_token_id=[1, 106],  # Gemma 3 <eos> and <end_of_turn> token ids
    bos_token_id=2,  # Gemma 3 <bos> token id
    max_new_tokens=2048,
    pad_token_id=0,  # Gemma 3 <pad> token id
    top_p=0.95,
    top_k=64,
    temperature=0.6,
    cache_implementation="offloaded",  # offload the KV cache to CPU to save GPU memory
)
unsloth_load_kwargs = {
"max_seq_length": 2048, # Choose any for long context!
"load_in_4bit": True,
"load_in_8bit": False, # [NEW!] A bit more accurate, uses 2x memory
"full_finetuning": False, # [NEW!] We have full finetuning now!
}
generator = UnslothFastModelGenerator(
model_name="unsloth/gemma-3-4b-it",
load_model_kwargs=unsloth_load_kwargs,
generation_config=generation_cfg,
)
==((====))==  Unsloth 2025.5.7: Fast Gemma3 patching. Transformers: 4.51.3.
   \\   /|    NVIDIA A40. Num GPUs = 1. Max memory: 44.448 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
generator.model.dtype
torch.bfloat16
Give it a spin¶
response = generator.generate(query="What is a Tulip?", context="")
print(response)
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
You are a helpful assistant. Given the user's query, provide a succinct and accurate response. If context is provided, use it in your answer if it helps you to create the most accurate response. <query> What is a Tulip? </query> <context> </context> <response> A tulip is a flowering plant in the genus *Tulipa*, native to Central Asia and Turkey. They are known for their cup-shaped flowers and are often associated with spring. </response>
In Unsloth's Gemma 3 (4B) cookbook notebook, they demonstrate how to use ~transformers.TextStreamer to stream generation output in real time rather than waiting for completion. We can apply the same technique here.
from transformers import TextStreamer
generator.generate(
query="What is a Porshe?",
context="",
streamer=TextStreamer(generator.tokenizer.unwrapped, skip_prompt=True),
)
Porsche is a German automobile manufacturer known for its high-performance sports cars, luxury vehicles, and SUVs. The company was founded in 1931 by Ferdinand Porsche. </response> <end_of_turn>
"\nYou are a helpful assistant. Given the user's query, provide a succinct\nand accurate response. If context is provided, use it in your answer if it helps\nyou to create the most accurate response.\n\n<query>\nWhat is a Porshe?\n</query>\n\n<context>\n\n</context>\n\n<response>\n\nPorsche is a German automobile manufacturer known for its high-performance sports cars, luxury vehicles, and SUVs. The company was founded in 1931 by Ferdinand Porsche.\n\n</response>\n"
Let's Build the Rest of our RAG System¶
Define our Retriever and Knowledge Store¶
import torch
from fed_rag import RAGSystem, RAGConfig
from fed_rag.retrievers.huggingface import (
HFSentenceTransformerRetriever,
)
from fed_rag.knowledge_stores import InMemoryKnowledgeStore
from fed_rag.data_structures import KnowledgeNode, NodeType
QUERY_ENCODER_NAME = "nthakur/dragon-plus-query-encoder"
CONTEXT_ENCODER_NAME = "nthakur/dragon-plus-context-encoder"
# Retriever
retriever = HFSentenceTransformerRetriever(
query_model_name=QUERY_ENCODER_NAME,
context_model_name=CONTEXT_ENCODER_NAME,
load_model_at_init=False,
)
# Knowledge store
knowledge_store = InMemoryKnowledgeStore()
Add some knowledge to the store¶
text_chunks = [
"Retrieval-Augmented Generation (RAG) combines retrieval with generation.",
"LLMs can hallucinate information when they lack context.",
]
knowledge_nodes = [
KnowledgeNode(
node_type="text",
embedding=retriever.encode_context(ct).tolist(),
text_content=ct,
)
for ct in text_chunks
]
knowledge_store.load_nodes(knowledge_nodes)
knowledge_store.count
2
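Before assembling the full system, we can sanity-check retrieval over the nodes we just loaded. The snippet below is only an illustrative sketch: it assumes HFSentenceTransformerRetriever exposes an encode_query counterpart to the encode_context method used above, and it scores the query embedding against each node's stored embedding with a simple dot product.
# Illustrative sanity check (assumes `encode_query` mirrors `encode_context` above)
query_emb = torch.tensor(retriever.encode_query("Why do LLMs hallucinate?").tolist())
for node in knowledge_nodes:
    score = torch.dot(query_emb, torch.tensor(node.embedding))
    print(f"{score.item():.4f} | {node.text_content}")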
Assemble the RAG system¶
# Create the RAG system
rag_system = RAGSystem(
retriever=retriever,
generator=generator,
knowledge_store=knowledge_store,
rag_config=RAGConfig(top_k=1),
)
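With the pieces wired together, we can already run an end-to-end retrieval-augmented query. This is a minimal sketch assuming the RAGSystem.query() interface shown in the fed-rag quick-start examples; it retrieves the top-1 node from the knowledge store and passes it as context to the Unsloth generator.
# Minimal end-to-end sketch (assumes `RAGSystem.query()` as in the fed-rag quick start)
rag_response = rag_system.query("What is Retrieval-Augmented Generation?")
print(rag_response)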
Let's first add our LoRA adapters¶
In order to do so, we use the to_peft() method, which under the hood calls FastModel.get_peft_model() to build the PeftModel and then sets it as this generator's model. In other words, the underlying model is currently a PreTrainedModel, but after executing the next cell, it will be a PeftModel.
generator.to_peft(
    finetune_vision_layers=False,  # Turn off for just text!
    finetune_language_layers=True,  # Should leave on!
    finetune_attention_modules=True,  # Attention good for GRPO
    finetune_mlp_modules=True,  # Should leave on always!
    r=8,  # Larger = higher accuracy, but might overfit
    lora_alpha=8,  # Recommended alpha == r at least
    lora_dropout=0,
    bias="none",
    random_state=3407,
)
Unsloth: Making `base_model.model.vision_tower.vision_model` require gradients
UnslothFastModelGenerator(model_name='unsloth/gemma-3-4b-it', generation_config=GenerationConfig { "bos_token_id": 2, "cache_implementation": "hybrid", "do_sample": true, "eos_token_id": [ 1, 106 ], "max_new_tokens": 2048, "pad_token_id": 0, "temperature": 0.6, "top_k": 64, "top_p": 0.95 } , load_model_kwargs={'max_seq_length': 2048, 'load_in_4bit': True, 'load_in_8bit': False, 'full_finetuning': False})
generator.model.dtype
torch.bfloat16
from peft import PeftModel
isinstance(generator.model, PeftModel)
True
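Because the underlying model is now a ~peft.PeftModel, the standard PEFT utilities apply. For instance, we can confirm that only the LoRA adapter weights are trainable, which should line up with the trainable-parameter count Unsloth reports during training below.
# Reports trainable vs. total parameter counts for the LoRA-wrapped model
generator.model.print_trainable_parameters()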
The Train Dataset¶
from datasets import Dataset
train_dataset = Dataset.from_dict(
# examples from Commonsense QA
{
"query": [
"The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?",
"Sammy wanted to go to where the people were. Where might he go?",
"To locate a choker not located in a jewelry box or boutique where would you go?",
"Google Maps and other highway and street GPS services have replaced what?",
"The fox walked from the city into the forest, what was it looking for?",
],
"response": [
"ignore",
"populated areas",
"jewelry store",
"atlas",
"natural habitat",
],
}
)
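Each training example is simply a query paired with its target response; the RALT trainer takes care of building the retrieval-augmented inputs from the RAG system. A quick peek at the first record:
# Inspect the first query/response pair
train_dataset[0]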
Since Unsloth essentially applies its efficiencies to the training processes of ~transformers.PreTrainedModel as well as ~peft.PeftModel instances, we can make full use of our HuggingFace generator trainer classes.
from fed_rag.trainers.huggingface.ralt import HuggingFaceTrainerForRALT
# the trainer object
generator_trainer = HuggingFaceTrainerForRALT(
rag_system=rag_system,
train_dataset=train_dataset,
# training_arguments=... # Optional ~transformers.TrainingArguments
)
result = generator_trainer.train()
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 5 | Num Epochs = 3 | Total steps = 3
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8
 "-____-"     Trainable parameters = 16,394,240/4,000,000,000 (0.41% trained)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
| Step | Training Loss |
|------|---------------|
Unsloth: Will smartly offload gradients to save VRAM!
result
TrainResult(loss=3.530837059020996)
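From here, the LoRA adapter weights can be persisted with the standard PEFT save_pretrained API so they can be reloaded or merged later. This is a minimal sketch; the output directory name is just an example.
# Save only the LoRA adapter weights plus the tokenizer (example directory name)
generator.model.save_pretrained("gemma-3-4b-it-ralt-lora")
generator.tokenizer.unwrapped.save_pretrained("gemma-3-4b-it-ralt-lora")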