Basic Federated Fine-tuning of RAG Systems¶
In this notebook, we demonstrate how to perform federated RAG fine-tuning with FedRAG. Specifically, we'll apply federated learning to fine-tune the generator of a RAG system in a federated setting comprising two clients.
HARDWARE REQUIREMENTS: This notebook requires a setup with at least two GPUs, each with at least 12GB of VRAM.
Install dependencies¶
!uv pip install "fed-rag[huggingface,qdrant]" docker -q
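To quickly confirm the installation, you can report the installed package versions with the standard library's `importlib.metadata`. This is a minimal sketch; it assumes `qdrant-client` is pulled in by the `qdrant` extra:

```python
from importlib.metadata import version

# verify the key packages are installed and report their versions
for pkg in ("fed-rag", "docker", "qdrant-client"):
    print(f"{pkg}: {version(pkg)}")
```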
Setup¶
Running this notebook requires two high-level steps:
- Running the Qdrant knowledge store service via Docker
- Downloading the associated example Python script, which defines the `RAGSystem` as well as the `FLTask` that we use to launch the federated learning task
Running the Qdrant knowledge store service¶
We have previously prepared a Qdrant knowledge store service that comes pre-populated with knowledge artifacts from the December 2021 Wikipedia dump (i.e., the corpus used in Izacard, Gautier, et al. "Few-shot learning with retrieval augmented language models." arXiv preprint arXiv:2208.03299 (2022)).
Executing the command below will run this Docker image on the host machine.
import docker
import os
import time
client = docker.from_env()
image_name = "vectorinstitute/qdrant-atlas-dec-wiki-2021:latest"
# first see if we need to pull the docker image
try:
client.images.get(image_name)
print(f"Image '{image_name}' already exists locally")
except docker.errors.ImageNotFound:
print(f"Image '{image_name}' not found locally. Pulling...")
# Pull with progress information
for line in client.api.pull(image_name, stream=True, decode=True):
if "progress" in line:
print(f"\r{line['status']}: {line['progress']}", end="")
elif "status" in line:
print(f"\r{line['status']}", end="")
print("\nPull complete!")
# run the Qdrant container
container = client.containers.run(
"vectorinstitute/qdrant-atlas-dec-wiki-2021:latest",
detach=True, # -d flag
name="tiny-wiki-dec2021-ks", # --name
ports={"6333/tcp": 6333, "6334/tcp": 6334}, # -p 6333:6333 # -p 6334:6334
volumes={
"qdrant_data": { # -v qdrant_data:/qdrant_storage
"bind": "/qdrant_storage",
"mode": "rw",
}
},
environment={"SAMPLE_SIZE": "tiny"}, # -e SAMPLE_SIZE=tiny
device_requests=[
docker.types.DeviceRequest(
count=-1, capabilities=[["gpu"]]
) # --gpus all
],
remove=False, # Don't auto-remove when stopped
)
print(f"Container started with ID: {container.id}")
# wait a moment for the container to initialize
time.sleep(3)
# Check container status
container.reload() # Refresh container data
print(f"Container status: {container.status}")
print(f"Container logs:")
print(container.logs().decode("utf-8"))
Image 'vectorinstitute/qdrant-atlas-dec-wiki-2021:latest' already exists locally Container started with ID: 8e58b42f14a508109055e20dc5f0d066fce8b4775f4b3a9b98b758239ce19b6e Container status: running Container logs: Starting Qdrant Atlas Knowledge Store container Running database initialization check... Using tiny sample mode... Creating tiny sample file for testing... Using tiny sample file: tiny-sample.jsonl Verifying sample file creation... ✅ Sample file successfully created at: /app/data/atlas/enwiki-dec2021/tiny-sample.jsonl File details: -rw-r--r-- 1 root root 6785 Jun 8 03:35 /app/data/atlas/enwiki-dec2021/tiny-sample.jsonl File content (first 3 lines): {"id": "140", "title": "History of marine biology", "section": "James Cook", "text": " James Cook is well known for his voyages of exploration for the British Navy in which he mapped out a significant amount of the world's uncharted waters. Cook's explorations took him around the world twice and led to countless descriptions of previously unknown plants and animals. Cook's explorations influenced many others and led to a number of scientists examining marine life more closely. Among those influenced was Charles Darwin who went on to make many contributions of his own. "} {"id": "141", "title": "History of marine biology", "section": "Charles Darwin", "text": " Charles Darwin, best known for his theory of evolution, made many significant contributions to the early study of marine biology. He spent much of his time from 1831 to 1836 on the voyage of HMS Beagle collecting and studying specimens from a variety of marine organisms. It was also on this expedition where Darwin began to study coral reefs and their formation. He came up with the theory that the overall growth of corals is a balance between the growth of corals upward and the sinking of the sea floor. He then came up with the idea that wherever coral atolls would be found, the central island where the coral had started to grow would be gradually subsiding"} {"id": "142", "title": "History of marine biology", "section": "Charles Wyville Thomson", "text": " Another influential expedition was the voyage of HMS Challenger from 1872 to 1876, organized and later led by Charles Wyville Thomson. It was the first expedition purely devoted to marine science. The expedition collected and analyzed thousands of marine specimens, laying the foundation for present knowledge about life near the deep-sea floor. The findings from the expedition were a summary of the known natural, physical and chemical ocean science to that time."} Directory listing: total 16 drwxr-xr-x 2 root root 4096 Jun 8 03:35 . drwxr-xr-x 3 root root 4096 Jun 8 03:35 .. -rw-r--r-- 1 root root 6785 Jun 8 03:35 tiny-sample.jsonl Starting Qdrant service for initialization... Waiting for Qdrant service to be ready... Waiting for Qdrant to start... (Attempt 1/30) _ _ __ _ __| |_ __ __ _ _ __ | |_ / _` |/ _` | '__/ _` | '_ \| __| | (_| | (_| | | | (_| | | | | |_ \__, |\__,_|_| \__,_|_| |_|\__| |_| Version: 1.14.0, build: 3617a011 Access web UI at http://localhost:6333/dashboard 2025-06-08T03:35:27.950637Z INFO storage::content_manager::consensus::persistent: Initializing new raft state at ./storage/raft_state.json 2025-06-08T03:35:27.986913Z INFO qdrant: Distributed mode disabled 2025-06-08T03:35:27.986975Z INFO qdrant: Telemetry reporting enabled, id: 625917b8-775f-4afa-b0e3-b10768d40bb4 2025-06-08T03:35:27.987060Z INFO qdrant: Inference service is not configured. 
2025-06-08T03:35:27.989443Z INFO qdrant::actix: TLS disabled for REST API 2025-06-08T03:35:27.989564Z INFO qdrant::actix: Qdrant HTTP listening on 6333 2025-06-08T03:35:27.989603Z INFO actix_server::builder: Starting 11 workers 2025-06-08T03:35:27.989615Z INFO actix_server::server: Actix runtime found; starting in Actix runtime 2025-06-08T03:35:27.999766Z INFO qdrant::tonic: Qdrant gRPC listening on 6334 2025-06-08T03:35:27.999792Z INFO qdrant::tonic: TLS disabled for gRPC API
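The `time.sleep(3)` above is a fixed wait, which may be too short on slower machines. As an alternative, you could poll Qdrant's HTTP health endpoint until it responds. A minimal sketch, assuming a recent Qdrant version that exposes `/healthz` on the REST port we mapped (6333):

```python
import time

import requests

def wait_for_qdrant(url: str = "http://localhost:6333/healthz", timeout: float = 60.0) -> None:
    # poll Qdrant's liveness endpoint until it returns HTTP 200 or we time out
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(url, timeout=2).status_code == 200:
                print("Qdrant is up!")
                return
        except requests.exceptions.ConnectionError:
            pass  # service not accepting connections yet
        time.sleep(1)
    raise TimeoutError(f"Qdrant did not become ready within {timeout}s")

wait_for_qdrant()
```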
Check if the service is ready¶
To check if the knowledge store service is ready to be used, we can create a `QdrantKnowledgeStore` with the correct collection name and check whether the collection exists. If it does, then we're ready to carry on with the rest of the notebook.
from fed_rag.knowledge_stores import QdrantKnowledgeStore
ks = QdrantKnowledgeStore(
collection_name="nthakur.dragon-plus-context-encoder"
)
# If the collection exists, this should return an int.
# Otherwise, it will raise an error
ks.count
13
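The count of 13 corresponds to the tiny sample loaded by the container. If you'd like to peek at a stored knowledge artifact directly, you can also talk to the service with `qdrant-client`. A minimal sketch, assuming the payload layout used by the image:

```python
from qdrant_client import QdrantClient

qdrant = QdrantClient(host="localhost", port=6333)
# scroll returns a (points, next_page_offset) tuple
points, _ = qdrant.scroll(
    collection_name="nthakur.dragon-plus-context-encoder",
    limit=1,
    with_payload=True,
)
print(points[0].payload)  # inspect one stored artifact
```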
Download the example Python script which builds the RAG System and FL Task¶
This script can be found in the main GitHub repo for fed-rag, within the `example_scripts` subdirectory. More specifically:
https://github.com/VectorInstitute/fed-rag/blob/main/example_scripts/cookbook_script-basic_fl.py
The commands below will download the script's text, display it here for convenience and then write it to a local file that we can execute.
SCRIPT_URL = "https://raw.githubusercontent.com/VectorInstitute/fed-rag/refs/heads/main/example_scripts/cookbook_script-basic_fl.py"
import requests
response = requests.get(SCRIPT_URL)
rag_code = response.text
from IPython.display import Code, display
display(Code(rag_code, language="python"))
from logging import INFO
from typing import Literal
import torch
from datasets import Dataset
from flwr.common.logger import log
from transformers.generation.utils import GenerationConfig
from fed_rag import RAGConfig, RAGSystem
from fed_rag.fl_tasks.huggingface import (
HuggingFaceFlowerClient,
HuggingFaceFlowerServer,
)
from fed_rag.generators import HFPretrainedModelGenerator
from fed_rag.knowledge_stores import QdrantKnowledgeStore
from fed_rag.retrievers import HFSentenceTransformerRetriever
from fed_rag.trainer_managers.huggingface import HuggingFaceRAGTrainerManager
from fed_rag.trainers.huggingface.ralt import HuggingFaceTrainerForRALT
GRPC_MAX_MESSAGE_LENGTH = int(512 * 1024 * 1024 * 3.75)
PEFT_MODEL_NAME = "Styxxxx/llama2_7b_lora-quac"
BASE_MODEL_NAME = "meta-llama/Llama-2-7b-hf"
TRAIN_DATASET = Dataset.from_dict(
# examples from Commonsense QA
{
"query": [
"The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?",
"Sammy wanted to go to where the people were. Where might he go?",
"To locate a choker not located in a jewelry box or boutique where would you go?",
"Google Maps and other highway and street GPS services have replaced what?",
],
"response": [
"ignore",
"populated areas",
"jewelry store",
"atlas",
],
}
)
VAL_DATASET = Dataset.from_dict(
{
"query": [
"The fox walked from the city into the forest, what was it looking for?"
],
"response": [
"natural habitat",
],
}
)
def get_trainer_manager(server: bool) -> HuggingFaceRAGTrainerManager:
# use the knowledge store in image: vectorinstitute/qdrant-atlas-dec-wiki-2021:latest
knowledge_store = QdrantKnowledgeStore(
collection_name="nthakur.dragon-plus-context-encoder",
timeout=10,
)
retriever = HFSentenceTransformerRetriever(
query_model_name="nthakur/dragon-plus-query-encoder",
context_model_name="nthakur/dragon-plus-context-encoder",
load_model_at_init=False,
)
# LLM generator
generation_cfg = GenerationConfig(
do_sample=True,
eos_token_id=151643,
bos_token_id=151643,
max_new_tokens=2048,
top_p=0.9,
temperature=0.6,
cache_implementation="offloaded",
stop_strings="</response>",
)
if server:
load_model_kwargs = {"device_map": "cpu", "torch_dtype": torch.float16}
else:
load_model_kwargs = {
"device_map": "auto",
"torch_dtype": torch.float16,
}
generator = HFPretrainedModelGenerator(
model_name="Qwen/Qwen2.5-0.5B",
load_model_at_init=False,
load_model_kwargs=load_model_kwargs,
generation_config=generation_cfg,
)
# assemble rag system
rag_config = RAGConfig(top_k=2)
rag_system = RAGSystem(
knowledge_store=knowledge_store, # the Qdrant knowledge store defined above
generator=generator,
retriever=retriever,
rag_config=rag_config,
)
# the trainer object
generator_trainer = HuggingFaceTrainerForRALT(
rag_system=rag_system,
train_dataset=TRAIN_DATASET,
)
# trainer manager object
manager = HuggingFaceRAGTrainerManager(
mode="generator",
generator_trainer=generator_trainer,
)
return manager
def build_client(
train_manager: HuggingFaceRAGTrainerManager,
) -> HuggingFaceFlowerClient:
fl_task = train_manager.get_federated_task()
model = train_manager.model
log(INFO, f"loaded generator is on: {model.device}")
return fl_task.client(
model=model, train_dataset=TRAIN_DATASET, val_dataset=VAL_DATASET
)
def build_server(
train_manager: HuggingFaceRAGTrainerManager,
) -> HuggingFaceFlowerServer:
fl_task = train_manager.get_federated_task()
model = train_manager.model
return fl_task.server(model=model)
def main(
component: Literal["server", "client_0", "client_1"],
) -> None:
"""For starting any of the FL Task components.
IMPORTANT NOTE: This script requires the Dec. 2021 Wikipedia Qdrant knowledge
store to be up and running. Use the shell command below to run the Docker image.
```sh
docker run --gpus all -d \
--name qdrant-ra-dit \
-p 6333:6333 \
-p 6334:6334 \
-v qdrant_data:/qdrant_storage \
-e SAMPLE_SIZE=tiny \
vectorinstitute/qdrant-atlas-dec-wiki-2021
```
MAIN USAGE:
## server
`uv run python example_scripts/cookbook_script-basic_fl.py --component server`
## client 0
`uv run python example_scripts/cookbook_script-basic_fl.py --component client_0`
## client 1
`uv run python example_scripts/cookbook_script-basic_fl.py --component client_1`
"""
import flwr as fl
if component == "server":
manager = get_trainer_manager(server=True)
server = build_server(manager)
fl.server.start_server(
server=server,
server_address="[::]:8080",
grpc_max_message_length=GRPC_MAX_MESSAGE_LENGTH,
)
elif component in ["client_0", "client_1"]:
manager = get_trainer_manager(server=False)
client = build_client(manager)
fl.client.start_client(
client=client,
server_address="[::]:8080",
grpc_max_message_length=GRPC_MAX_MESSAGE_LENGTH,
)
else:
raise ValueError("Unrecognized component.")
if __name__ == "__main__":
import fire
fire.Fire(main)
Federated fine-tuning¶
The script displayed above shows the `RAGSystem` and the generator trainer task that we will federate next. To do this we will:
- Write the script text to a file
- Launch the server and two clients in their own separate subprocesses
# write the script's code to a Python file on disk
with open("rag_federated_learning.py", "w") as f:
f.write(rag_code)
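Optionally, you can smoke-test the assembled RAG system in-process before launching any FL components. The sketch below is an assumption-laden convenience: it relies on the `generator_trainer` attribute of the trainer manager and on `RAGSystem.query` behaving as in current fed-rag, and it will load the generator onto a local GPU:

```python
# hypothetical smoke test: build the RAG system from the script and ask one question
from rag_federated_learning import get_trainer_manager

manager = get_trainer_manager(server=False)
rag_system = manager.generator_trainer.rag_system
# retrieves from the Qdrant knowledge store, then generates with Qwen2.5-0.5B
print(rag_system.query("Who led the voyage of HMS Challenger?"))
```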
With the file written to our local disk, we can run the script to launch the FL server and clients. We will use a notebook utility class called `ProcessMonitor` to do so.
from fed_rag.utils.notebook import ProcessMonitor
monitor = ProcessMonitor()
# launch server command
server_command = "python rag_federated_learning.py --component server"
# launch client command template
# each of the two clients will exclusively use one of the two available GPUs
client_command = "export CUDA_VISIBLE_DEVICES={client_id} && python rag_federated_learning.py --component client_{client_id}"
# start server process
monitor.start_process("server", server_command)
# give the server time to stand up
time.sleep(2)
✅ Started server (PID: 85559)
# start client processes
monitor.start_process(
name="client_0", command=client_command.format(client_id="0")
)
monitor.start_process(
name="client_1", command=client_command.format(client_id="1")
)
✅ Started client_0 (PID: 85585) ✅ Started client_1 (PID: 85588)
# this cell will run until the subprocesses complete or the kernel is interrupted
monitor.monitor_live(["server", "client_0", "client_1"])
🖥️ PROCESS MONITOR ============================================================ server 🔴 STOPPED ------------------------------ [23:36:29] INFO : Evaluation returned no results (`None`) [23:36:29] INFO : [23:36:29] INFO : [ROUND 1] [23:36:31] INFO : configure_fit: strategy sampled 2 clients (out of 2) [23:37:20] INFO : aggregate_fit: received 2 results and 0 failures [23:37:26] WARNING : No fit_metrics_aggregation_fn provided [23:37:26] INFO : configure_evaluate: strategy sampled 2 clients (out of 2) [23:37:41] INFO : aggregate_evaluate: received 2 results and 0 failures [23:37:41] WARNING : No evaluate_metrics_aggregation_fn provided [23:37:41] INFO : [23:37:41] INFO : [SUMMARY] [23:37:41] INFO : Run finished 1 round(s) in 72.02s [23:37:41] INFO : History (loss, distributed): [23:37:41] INFO : round 1: 0.41999998688697815 [23:37:41] INFO : client_0 🔴 STOPPED ------------------------------ [23:37:09] [23:37:09] [23:37:09] 100%|██████████| 3/3 [00:22<00:00, 2.18s/it] [23:37:09] 100%|██████████| 3/3 [00:22<00:00, 7.65s/it] [23:37:09] /home/nerdai/Projects/fed-rag/src/fed_rag/fl_tasks/huggingface.py:116: PydanticDeprecatedSince211: Accessing the 'model_fields' attribute on the instance is deprecated. Instead, you should access this attribute from the model class. Deprecated in Pydantic V2.11 to be removed in V3.0. [23:37:09] if name in self.task_bundle.model_fields: [23:37:16] INFO : Sent reply [23:37:40] INFO : [23:37:40] INFO : Received: evaluate message 5b749201-b76e-4b83-b8fd-485b59021f3f [23:37:40] WARNING : Deprecation Warning: The `client_fn` function must return an instance of `Client`, but an instance of `NumpyClient` was returned. Please use `NumPyClient.to_client()` method to convert it to `Client`. [23:37:41] INFO : Sent reply [23:37:41] INFO : [23:37:41] INFO : Received: reconnect message d60bda79-f1e1-4bd1-bb2e-95e8d7496ab6 [23:37:41] INFO : Disconnect and shut down [23:37:41] {'train_runtime': 22.963, 'train_samples_per_second': 0.523, 'train_steps_per_second': 0.131, 'train_loss': 2.1859957377115884, 'epoch': 3.0} client_1 🔴 STOPPED ------------------------------ [23:37:09] [23:37:09] [23:37:09] 100%|██████████| 3/3 [00:22<00:00, 2.03s/it] [23:37:09] 100%|██████████| 3/3 [00:22<00:00, 7.49s/it] [23:37:09] /home/nerdai/Projects/fed-rag/src/fed_rag/fl_tasks/huggingface.py:116: PydanticDeprecatedSince211: Accessing the 'model_fields' attribute on the instance is deprecated. Instead, you should access this attribute from the model class. Deprecated in Pydantic V2.11 to be removed in V3.0. [23:37:09] if name in self.task_bundle.model_fields: [23:37:17] INFO : Sent reply [23:37:40] INFO : [23:37:40] INFO : Received: evaluate message 0d91065a-9fd8-44e2-b744-c477a3d0cdbc [23:37:40] WARNING : Deprecation Warning: The `client_fn` function must return an instance of `Client`, but an instance of `NumpyClient` was returned. Please use `NumPyClient.to_client()` method to convert it to `Client`. [23:37:41] INFO : Sent reply [23:37:41] INFO : [23:37:41] INFO : Received: reconnect message 46311e1d-d99c-4a0f-9c3a-2931384263cd [23:37:41] INFO : Disconnect and shut down [23:37:41] {'train_runtime': 22.4633, 'train_samples_per_second': 0.534, 'train_steps_per_second': 0.134, 'train_loss': 2.1859957377115884, 'epoch': 3.0} 🔄 Last updated: 23:37:44 Press Ctrl+C to stop monitoring
Cleanup¶
monitor.stop_all()
🛑 Stopped server 🛑 Stopped client_0 🛑 Stopped client_1 🛑 All processes stopped
# stop and remove container
container.stop()
container.remove()
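To confirm the cleanup, you can list any containers still based on the knowledge store image; docker-py's `ancestor` filter matches containers by image:

```python
# confirm no containers based on the knowledge store image remain
leftover = client.containers.list(all=True, filters={"ancestor": image_name})
print(leftover or "No knowledge store containers remain")
```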