Bases: BaseRAGTrainerManager
PyTorch-native RAG Trainer Manager. Depending on its `mode`, it trains either the generator or the retriever while keeping the other component frozen, and it can wrap the active trainer into a federated learning task.
Source code in src/fed_rag/trainer_managers/pytorch.py
```python
class PyTorchRAGTrainerManager(BaseRAGTrainerManager):
    """PyTorch native RAG Trainer Manager"""

    def _prepare_generator_for_training(self, **kwargs: Any) -> None:
        self.generator_trainer.model.train()
        # Freeze the retriever while the generator trains.
        if self.retriever_trainer:
            self.retriever_trainer.model.eval()

    def _prepare_retriever_for_training(
        self, freeze_context_encoder: bool = True, **kwargs: Any
    ) -> None:
        self.retriever_trainer.model.train()
        # Freeze the generator while the retriever trains.
        if self.generator_trainer:
            self.generator_trainer.model.eval()

    def _train_retriever(self, **kwargs: Any) -> None:
        # Check for a trainer before preparing, so a missing trainer raises
        # the dedicated exception rather than an AttributeError.
        if self.retriever_trainer:
            self._prepare_retriever_for_training()
            self.retriever_trainer.train()
        else:
            raise UnspecifiedRetrieverTrainer(
                "Attempted to perform retriever training with an unspecified trainer."
            )

    def _train_generator(self, **kwargs: Any) -> None:
        if self.generator_trainer:
            self._prepare_generator_for_training()
            self.generator_trainer.train()
        else:
            raise UnspecifiedGeneratorTrainer(
                "Attempted to perform generator training with an unspecified trainer."
            )

    def train(self, **kwargs: Any) -> None:
        if self.mode == "retriever":
            self._train_retriever()
        elif self.mode == "generator":
            self._train_generator()
        else:
            assert_never(self.mode)  # pragma: no cover

    def _get_federated_trainer(self) -> tuple[Callable, nn.Module]:
        if self.mode == "retriever":
            if self.retriever_trainer is None:
                raise UnspecifiedRetrieverTrainer(
                    "Cannot federate an unspecified retriever trainer."
                )
            retriever_train_fn = self.retriever_trainer.train
            retriever_module = self.retriever_trainer.model

            # Create a standalone function for federation.
            def train_wrapper(
                _mdl: nn.Module,
                _train_dataloader: DataLoader,
                _val_dataloader: DataLoader,
            ) -> TrainResult:
                _ = retriever_train_fn()
                return TrainResult(loss=0)

            return federate.trainer.pytorch(train_wrapper), retriever_module
        elif self.mode == "generator":
            if self.generator_trainer is None:
                raise UnspecifiedGeneratorTrainer(
                    "Cannot federate an unspecified generator trainer."
                )
            generator_train_fn = self.generator_trainer.train
            generator_module = self.generator_trainer.model

            # Create a standalone function for federation.
            def train_wrapper(
                _mdl: nn.Module,
                _train_dataloader: DataLoader,
                _val_dataloader: DataLoader,
            ) -> TrainResult:
                _ = generator_train_fn()
                # TODO: get loss from out
                return TrainResult(loss=0)

            return federate.trainer.pytorch(train_wrapper), generator_module
        else:
            assert_never(self.mode)  # pragma: no cover

    def get_federated_task(self) -> PyTorchFLTask:
        federated_trainer, _module = self._get_federated_trainer()

        # TODO: add logic for getting evaluator/tester and then federate it as well
        # federated_tester = self.get_federated_tester(tester_decorator)
        # For now, use a simple placeholder test function.
        def test_fn(_mdl: nn.Module, _dataloader: DataLoader) -> TestResult:
            # Implement simple testing or return a placeholder.
            return TestResult(loss=0.42, metrics={})  # pragma: no cover

        federated_tester = federate.tester.pytorch(test_fn)

        return PyTorchFLTask.from_trainer_and_tester(
            trainer=federated_trainer,
            tester=federated_tester,
        )
```
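For orientation, here is a minimal usage sketch of the local training path. The trainer objects are hypothetical stand-ins for concrete generator/retriever trainer implementations, and the keyword-argument constructor shown is an assumption inferred from the attributes the class reads (`mode`, `generator_trainer`, `retriever_trainer`).

```python
from fed_rag.trainer_managers.pytorch import PyTorchRAGTrainerManager

# `my_generator_trainer` and `my_retriever_trainer` are hypothetical
# placeholders for concrete trainer implementations; each is assumed to
# expose a `.model` (an nn.Module) and a `.train()` method, as the
# manager source above requires.
manager = PyTorchRAGTrainerManager(
    mode="generator",  # train the generator; the retriever stays frozen
    generator_trainer=my_generator_trainer,
    retriever_trainer=my_retriever_trainer,
)

# Dispatches to _train_generator(); raises UnspecifiedGeneratorTrainer
# if no generator trainer was supplied.
manager.train()
```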
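To run the same trainer under federated learning, the manager converts itself into a `PyTorchFLTask`:

```python
# Wraps the active trainer (selected by `mode`) and a placeholder tester
# into a federated learning task.
fl_task = manager.get_federated_task()
```

Note from the source above that the federated `train_wrapper` currently ignores its dataloader arguments and reports a placeholder loss of 0 (see the TODO), and the tester is likewise a placeholder, so reported metrics from the federated task are not yet meaningful.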