## Standard Usage
The standard usage pattern for fine-tuning a RAG system with FedRAG follows the steps listed below:
- Build a `train_dataset` that contains examples of (query, response) pairs.
- Specify a retriever trainer as well as a generator trainer.
- Construct a RAG trainer manager and invoke the `train()` method.
- (Optional) Get the associated `FLTask` via `RAGTrainerManager.get_federated_task()`.
Info
These steps assume that you have already constructed the `RAGSystem` that
you intend to fine-tune.
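For reference, assembling such a `RAGSystem` might look roughly like the sketch below. The import path, the `RAGConfig` option, and the `retriever`, `generator`, and `knowledge_store` objects are assumptions for illustration; adapt them to your own components.

```python
from fed_rag import RAGConfig, RAGSystem  # assumed import path

# `retriever`, `generator`, and `knowledge_store` are assumed to have been
# constructed already for your particular setup
rag_system = RAGSystem(
    retriever=retriever,
    generator=generator,
    knowledge_store=knowledge_store,
    rag_config=RAGConfig(top_k=2),  # assumed config option
)
```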
Info
The code snippets below require the `huggingface` extra to be installed, which
can be done via `pip install fed-rag[huggingface]`.
### Build a `train_dataset`
For now, all FedRAG trainers work with datasets whose examples consist of (query, response) pairs.
```python
from datasets import Dataset

train_dataset = Dataset.from_dict(
    {
        "query": ["a query", "another query", ...],
        "response": [
            "response to a query",
            "another response to another query",
            ...,
        ],
    }
)
```
### Specify a retriever and generator trainer
FedRAG trainer classes bear the responsibility of training the associated retriever
or generator on the training dataset. Each trainer has an attached data collator that takes
a batch of the training dataset, applies the "forward" pass of the RAG system
(i.e., retrieval from the knowledge store and, if required, the subsequent generation
step), and returns the `torch.Tensor` objects required for computing the desired loss.
These trainer classes take your `RAGSystem` as input, amongst possibly other
parameters.
```python
from fed_rag.trainers.huggingface import (
    HuggingFaceTrainerForRALT,
    HuggingFaceTrainerForLSR,
)

# trainer for the retriever
retriever_trainer = HuggingFaceTrainerForLSR(rag_system)
# trainer for the generator
generator_trainer = HuggingFaceTrainerForRALT(rag_system)
```
### Create a `RAGTrainerManager`
The trainer manager class is responsible for orchestrating the training of the RAG system.
```python
from fed_rag.trainer_managers.huggingface import HuggingFaceRAGTrainerManager

trainer_manager = HuggingFaceRAGTrainerManager(
    mode="retriever",
    retriever_trainer=retriever_trainer,
    generator_trainer=generator_trainer,
)

# train
result = trainer_manager.train()
print(result.loss)
```
Note
Alternating training of the retriever and generator can be done by modifying
the `mode` attribute of the manager and calling `train()`, as sketched below. In the future, the
trainer manager will be able to orchestrate retriever and generator
fine-tuning within a single epoch.
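A minimal sketch of this alternating pattern, assuming the `trainer_manager` defined above and that `"generator"` is the complementary `mode` value to `"retriever"`:

```python
# fine-tune the retriever first
trainer_manager.mode = "retriever"
retriever_result = trainer_manager.train()

# switch modes, then fine-tune the generator
trainer_manager.mode = "generator"  # assumed mode string, by symmetry with "retriever"
generator_result = trainer_manager.train()
```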
### (Optional) Get the `FLTask` for federated training
FedRAG trainer managers offer a simple way to get the associated `FLTask` for
federated fine-tuning.
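For example, a minimal sketch using the `trainer_manager` from above:

```python
# returns the FLTask associated with the manager's current mode
fl_task = trainer_manager.get_federated_task()
```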
- This will return an `FLTask` for either the retriever trainer or the generator trainer task, depending on the `mode` that the trainer manager is currently set to.
### Spin up FL servers and clients
With an `FLTask`, we can obtain an FL server as well as clients. Starting a server
and the required number of clients will commence the federated training.
```python
import flwr as fl  # (1)!

# federate generator fine-tuning
model = rag_system.generator.model

# server
server = fl_task.server(model, ...)  # (2)!

# client
client = fl_task.client(...)  # (3)!

# the below commands are blocking and would need to be run in separate processes
fl.server.start_server(server=server, server_address="[::]:8080")
fl.client.start_client(client=client, server_address="[::]:8080")
```
1. `flwr` is the backend federated learning framework for FedRAG and comes included with the installation of `fed-rag`.
2. Can pass in an FL aggregation strategy; otherwise defaults to federated averaging.
3. Requires the same arguments as the centralized `training_loop`.
Note
Under the hood, `FLTask.server()` and `FLTask.client()` build `flwr.Server`
and `flwr.Client` objects, respectively.