Skip to content

Huggingface

HuggingFace RAG Trainer

HuggingFaceRAGTrainerManager

Bases: BaseRAGTrainerManager

HuggingFace RAG Trainer Manager

Source code in src/fed_rag/trainer_managers/huggingface.py
class HuggingFaceRAGTrainerManager(BaseRAGTrainerManager):
    """HuggingFace RAG Trainer Manager.

    Trains either the retriever or the generator of a RAG system, selected by
    ``mode``, by delegating to the corresponding trainer. It can also package
    the selected trainer as a HuggingFace federated-learning task.
    """

    def __init__(
        self,
        mode: RAGTrainMode,
        retriever_trainer: BaseTrainer | None = None,
        generator_trainer: BaseTrainer | None = None,
        **kwargs: Any,
    ):
        """Create the manager.

        Args:
            mode: Which component to train (``"retriever"`` or ``"generator"``).
            retriever_trainer: Trainer used when ``mode == "retriever"``.
            generator_trainer: Trainer used when ``mode == "generator"``.
            **kwargs: Forwarded unchanged to the base class initializer.

        Raises:
            MissingExtraError: If the ``huggingface`` extra is not installed.
        """
        if not _has_huggingface:
            msg = (
                f"`{self.__class__.__name__}` requires `huggingface` extra to be installed. "
                "To fix please run `pip install fed-rag[huggingface]`."
            )
            raise MissingExtraError(msg)
        super().__init__(
            mode=mode,
            retriever_trainer=retriever_trainer,
            generator_trainer=generator_trainer,
            **kwargs,
        )

    def _prepare_generator_for_training(self, **kwargs: Any) -> None:
        """Put the generator model in train mode and the retriever (if any) in eval mode.

        Assumes ``self.generator_trainer`` is set; callers check this first.
        """
        self.generator_trainer.model.train()

        # Put the retriever in eval mode so only the generator is in training mode.
        if self.retriever_trainer is not None:
            self.retriever_trainer.model.eval()

    def _prepare_retriever_for_training(
        self, freeze_context_encoder: bool = True, **kwargs: Any
    ) -> None:
        """Put the retriever model in train mode and the generator (if any) in eval mode.

        Assumes ``self.retriever_trainer`` is set; callers check this first.

        Args:
            freeze_context_encoder: Currently unused.
                NOTE(review): accepted for interface compatibility only.
        """
        self.retriever_trainer.model.train()

        # Put the generator in eval mode so only the retriever is in training mode.
        if self.generator_trainer is not None:
            self.generator_trainer.model.eval()

    def _train_retriever(self, **kwargs: Any) -> TrainResult:
        """Train the retriever and return its trainer's result.

        Raises:
            UnspecifiedRetrieverTrainer: If no retriever trainer was provided.
        """
        # Validate before preparing: preparation dereferences the trainer, so a
        # missing trainer would otherwise surface as an AttributeError.
        if self.retriever_trainer is None:
            raise UnspecifiedRetrieverTrainer(
                "Attempted to perform retriever trainer with an unspecified trainer function."
            )
        self._prepare_retriever_for_training()
        return self.retriever_trainer.train()

    def _train_generator(self, **kwargs: Any) -> TrainResult:
        """Train the generator and return its trainer's result.

        Raises:
            UnspecifiedGeneratorTrainer: If no generator trainer was provided.
        """
        # Validate before preparing: preparation dereferences the trainer, so a
        # missing trainer would otherwise surface as an AttributeError.
        if self.generator_trainer is None:
            raise UnspecifiedGeneratorTrainer(
                "Attempted to perform generator trainer with an unspecified trainer function."
            )
        self._prepare_generator_for_training()
        return self.generator_trainer.train()

    def train(self, **kwargs: Any) -> TrainResult:
        """Train the component selected by ``self.mode``.

        Extra keyword arguments are accepted but currently ignored.
        """
        if self.mode == "retriever":
            return self._train_retriever()
        elif self.mode == "generator":
            return self._train_generator()
        else:
            assert_never(self.mode)  # pragma: no cover

    @staticmethod
    def _make_federated_train_fn(train_fn: Callable[[], TrainResult]) -> Callable:
        """Adapt a zero-argument ``train_fn`` to the federated trainer signature.

        Wraps ``train_fn`` in the ``(model, train_dataset, val_dataset)``
        signature and decorates it with ``federate.trainer.huggingface``.
        """

        # Parameter names and annotations are kept exactly as before, since the
        # federate decorator inspects this function's signature.
        def train_wrapper(
            _mdl: "HFModelType",  # TODO: handle union types in inspector
            _train_dataset: "Dataset",
            _val_dataset: "Dataset",
        ) -> TrainResult:
            # The wrapped trainer call takes no arguments, so the values passed
            # in by the federated task are ignored here.
            return train_fn()

        return federate.trainer.huggingface(train_wrapper)

    def _get_federated_trainer(self) -> tuple[Callable, "HFModelType"]:
        """Return the federated train function and the model it trains.

        The model returned is the selected trainer's own ``model`` object.

        Raises:
            UnspecifiedRetrieverTrainer: If ``mode`` is ``"retriever"`` and no
                retriever trainer was provided.
            UnspecifiedGeneratorTrainer: If ``mode`` is ``"generator"`` and no
                generator trainer was provided.
        """
        if self.mode == "retriever":
            if self.retriever_trainer is None:
                raise UnspecifiedRetrieverTrainer(
                    "Cannot federate an unspecified retriever trainer function."
                )
            retriever_module = cast(
                "SentenceTransformer", self.retriever_trainer.model
            )
            return (
                self._make_federated_train_fn(self.retriever_trainer.train),
                retriever_module,
            )

        elif self.mode == "generator":
            if self.generator_trainer is None:
                raise UnspecifiedGeneratorTrainer(
                    "Cannot federate an unspecified generator trainer function."
                )
            return (
                self._make_federated_train_fn(self.generator_trainer.train),
                self.generator_trainer.model,
            )
        else:
            assert_never(self.mode)  # pragma: no cover

    def get_federated_task(self) -> "HuggingFaceFLTask":
        """Build a ``HuggingFaceFLTask`` from the federated trainer and a tester.

        The tester is currently a placeholder that always reports
        ``loss=0.42`` with no metrics.
        """
        from fed_rag.fl_tasks.huggingface import HuggingFaceFLTask

        federated_trainer, _ = self._get_federated_trainer()

        # TODO: add logic for getting evaluator/tester and then federate it as well
        # federated_tester = self.get_federated_tester(tester_decorator)
        # For now, using a simple placeholder test function
        def test_fn(_mdl: "HFModelType", _dataset: "Dataset") -> TestResult:
            # Placeholder result until a real evaluator is wired in.
            return TestResult(loss=0.42, metrics={})  # pragma: no cover

        federated_tester = federate.tester.huggingface(test_fn)

        return HuggingFaceFLTask.from_trainer_and_tester(
            trainer=federated_trainer,
            tester=federated_tester,
        )