
HuggingFace Trainer Mixin

HuggingFaceTrainerProtocol

Bases: Protocol

Source code in src/fed_rag/trainers/huggingface/mixin.py
@runtime_checkable
class HuggingFaceTrainerProtocol(Protocol):
    train_dataset: "Dataset"
    training_arguments: Optional["TrainingArguments"]

    def model(
        self,
    ) -> Union["SentenceTransformer", "PreTrainedModel", "PeftModel"]:
        pass  # pragma: no cover
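
Because the protocol is decorated with @runtime_checkable, conformance can be verified with isinstance at runtime; note that runtime-checkable protocols only check that the members exist, not their types or signatures. A minimal sketch, assuming the protocol is importable from the module path shown above; DummyTrainer and its attributes are hypothetical:

from typing import Any, Optional

from fed_rag.trainers.huggingface.mixin import HuggingFaceTrainerProtocol


class DummyTrainer:
    """Hypothetical object that structurally satisfies the protocol."""

    def __init__(self, train_dataset: Any) -> None:
        self.train_dataset = train_dataset
        self.training_arguments: Optional[Any] = None

    def model(self) -> Any:
        raise NotImplementedError


# passes: all three protocol members are present on the instance
assert isinstance(DummyTrainer(train_dataset=[]), HuggingFaceTrainerProtocol)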

HuggingFaceTrainerMixin

Bases: BaseModel, ABC

HuggingFace Trainer Mixin.

Source code in src/fed_rag/trainers/huggingface/mixin.py
class HuggingFaceTrainerMixin(BaseModel, ABC):
    """HuggingFace Trainer Mixin."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )
    train_dataset: "Dataset"
    training_arguments: Optional["TrainingArguments"] = None

    def __init__(
        self,
        train_dataset: "Dataset",
        training_arguments: Optional["TrainingArguments"] = None,
        **kwargs: Any,
    ):
        if not _has_huggingface:
            msg = (
                f"`{self.__class__.__name__}` requires `huggingface` extra to be installed. "
                "To fix please run `pip install fed-rag[huggingface]`."
            )
            raise MissingExtraError(msg)
        super().__init__(
            train_dataset=train_dataset,
            training_arguments=training_arguments,
            **kwargs,
        )

    @property
    @abstractmethod
    def hf_trainer_obj(self) -> "Trainer":
        """A ~transformers.Trainer object."""

hf_trainer_obj abstractmethod property

A `~transformers.Trainer` object.
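
Concrete trainers combine this mixin with one of the base trainer classes and implement hf_trainer_obj. A minimal sketch of the pattern, assuming the huggingface extra is installed; MyHFTrainer and its wiring are illustrative, not part of the library:

from typing import Any, Optional

from pydantic import PrivateAttr
from transformers import Trainer

from fed_rag.trainers.huggingface.mixin import HuggingFaceTrainerMixin


class MyHFTrainer(HuggingFaceTrainerMixin):
    # cache the wrapped transformers.Trainer as a pydantic private attribute
    _hf_trainer: Optional[Trainer] = PrivateAttr(default=None)

    def __init__(self, model: Any, **kwargs: Any) -> None:
        super().__init__(**kwargs)  # kwargs must include train_dataset
        self._hf_trainer = Trainer(
            model=model,
            args=self.training_arguments,
            train_dataset=self.train_dataset,
        )

    @property
    def hf_trainer_obj(self) -> Trainer:
        return self._hf_trainer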

HuggingFace LM-Supervised Retriever Trainer

HuggingFaceTrainerForLSR

Bases: HuggingFaceTrainerMixin, BaseRetrieverTrainer

HuggingFace LM-Supervised Retriever Trainer.

Source code in src/fed_rag/trainers/huggingface/lsr.py
class HuggingFaceTrainerForLSR(HuggingFaceTrainerMixin, BaseRetrieverTrainer):
    """HuggingFace LM-Supervised Retriever Trainer."""

    _hf_trainer: Optional["SentenceTransformerTrainer"] = PrivateAttr(
        default=None
    )

    def __init__(
        self,
        rag_system: RAGSystem,
        train_dataset: "Dataset",
        training_arguments: Optional["TrainingArguments"] = None,
        **kwargs: Any,
    ):
        super().__init__(
            train_dataset=train_dataset,
            rag_system=rag_system,
            training_arguments=training_arguments,
            **kwargs,
        )

    @model_validator(mode="after")
    def set_private_attributes(self) -> "HuggingFaceTrainerForLSR":
        # the huggingface extra is installed, so this import is available
        from sentence_transformers import SentenceTransformer

        # validate rag system
        _validate_rag_system(self.rag_system)

        # validate model
        if not isinstance(self.model, SentenceTransformer):
            raise TrainerError(
                "For `HuggingFaceTrainerForLSR`, attribute `model` must be of type "
                "`~sentence_transformers.SentenceTransformer`."
            )

        self._hf_trainer = LSRSentenceTransformerTrainer(
            model=self.model,
            args=self.training_arguments,
            data_collator=DataCollatorForLSR(rag_system=self.rag_system),
            train_dataset=self.train_dataset,
        )

        return self

    def train(self) -> TrainResult:
        output: TrainOutput = self.hf_trainer_obj.train()
        return TrainResult(loss=output.training_loss)

    def evaluate(self) -> TestResult:
        # TODO: implement this
        raise NotImplementedError

    @property
    def hf_trainer_obj(self) -> "SentenceTransformerTrainer":
        return self._hf_trainer
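
Typical usage is construct-then-train. A sketch assuming an existing RAGSystem whose retriever model is a SentenceTransformer; the dataset columns shown are illustrative, see DataCollatorForLSR for the expected schema:

from datasets import Dataset

from fed_rag.trainers.huggingface.lsr import HuggingFaceTrainerForLSR

# `rag_system` is assumed to be a previously built fed_rag RAGSystem
train_dataset = Dataset.from_dict(
    {
        "query": ["What is LM-supervised retrieval?"],
        "response": ["Training a retriever with LM feedback."],
    }
)

# training_arguments is optional; omitting it defers to the
# underlying trainer's defaults
trainer = HuggingFaceTrainerForLSR(
    rag_system=rag_system,
    train_dataset=train_dataset,
)
result = trainer.train()  # returns TrainResult(loss=...)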

LSRSentenceTransformerTrainer

Bases: SentenceTransformerTrainer

Source code in src/fed_rag/trainers/huggingface/lsr.py
class LSRSentenceTransformerTrainer(SentenceTransformerTrainer):
    def __init__(
        self,
        *args: Any,
        data_collator: DataCollatorForLSR,
        loss: Optional[LSRLoss] = None,
        **kwargs: Any,
    ):
        if not _has_huggingface:
            msg = (
                f"`{self.__class__.__name__}` requires `huggingface` extra to be installed. "
                "To fix please run `pip install fed-rag[huggingface]`."
            )
            raise MissingExtraError(msg)

        # set the loss, defaulting to a fresh LSRLoss
        if loss is None:
            loss = LSRLoss()
        elif not isinstance(loss, LSRLoss):
            raise InvalidLossError(
                "`LSRSentenceTransformerTrainer` must use `~fed_rag.loss.LSRLoss`."
            )

        if not isinstance(data_collator, DataCollatorForLSR):
            raise InvalidDataCollatorError(
                "`LSRSentenceTransformerTrainer` must use ~fed_rag.data_collators.DataCollatorForLSR`."
            )

        super().__init__(
            *args, loss=loss, data_collator=data_collator, **kwargs
        )

    def collect_scores(
        self, inputs: dict[str, torch.Tensor | Any]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if "retrieval_scores" not in inputs:
            raise MissingInputTensor(
                "Collated `inputs` are missing key `retrieval_scores`"
            )

        if "lm_scores" not in inputs:
            raise MissingInputTensor(
                "Collated `inputs` are missing key `lm_scores`"
            )

        retrieval_scores = inputs.get("retrieval_scores")
        lm_scores = inputs.get("lm_scores")

        return retrieval_scores, lm_scores

    def compute_loss(
        self,
        model: "SentenceTransformer",
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
        num_items_in_batch: Any | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]:
        """Compute LSR loss.

        NOTE: the forward pass of the model is taken care of in the DataCollatorForLSR.

        Args:
            model (SentenceTransformer): _description_
            inputs (dict[str, torch.Tensor  |  Any]): _description_
            return_outputs (bool, optional): _description_. Defaults to False.
            num_items_in_batch (Any | None, optional): _description_. Defaults to None.

        Raises:
            NotImplementedError: _description_

        Returns:
            torch.Tensor | tuple[torch.Tensor, dict[str, Any]]: _description_
        """
        retrieval_scores, lm_scores = self.collect_scores(inputs)
        loss = self.loss(retrieval_scores, lm_scores)

        # inputs are actually the outputs of RAGSystem's "forward" pass
        return (loss, inputs) if return_outputs else loss

compute_loss

compute_loss(
    model,
    inputs,
    return_outputs=False,
    num_items_in_batch=None,
)

Compute LSR loss.

NOTE: the forward pass of the model is taken care of in the DataCollatorForLSR.

Parameters:

model (SentenceTransformer), required:
    the retriever under training; unused here since the collator has already run the forward pass.

inputs (dict[str, Tensor | Any]), required:
    collated batch holding the `retrieval_scores` and `lm_scores` tensors.

return_outputs (bool), defaults to False:
    if True, also return the collated inputs alongside the loss.

num_items_in_batch (Any | None), defaults to None:
    accepted for compatibility with the `transformers.Trainer` API.

Raises:

MissingInputTensor:
    if `retrieval_scores` or `lm_scores` is missing from `inputs`.

Returns:

Tensor | tuple[Tensor, dict[str, Any]]:
    the LSR loss, or a `(loss, inputs)` tuple when `return_outputs` is True.

Source code in src/fed_rag/trainers/huggingface/lsr.py
def compute_loss(
    self,
    model: "SentenceTransformer",
    inputs: dict[str, torch.Tensor | Any],
    return_outputs: bool = False,
    num_items_in_batch: Any | None = None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]:
    """Compute LSR loss.

    NOTE: the forward pass of the model is taken care of in the DataCollatorForLSR.

    Args:
        model (SentenceTransformer): _description_
        inputs (dict[str, torch.Tensor  |  Any]): _description_
        return_outputs (bool, optional): _description_. Defaults to False.
        num_items_in_batch (Any | None, optional): _description_. Defaults to None.

    Raises:
        NotImplementedError: _description_

    Returns:
        torch.Tensor | tuple[torch.Tensor, dict[str, Any]]: _description_
    """
    retrieval_scores, lm_scores = self.collect_scores(inputs)
    loss = self.loss(retrieval_scores, lm_scores)

    # inputs are actually the outputs of RAGSystem's "forward" pass
    return (loss, inputs) if return_outputs else loss
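
Since the collator performs the RAG forward pass, compute_loss reduces to extracting the two score tensors and applying the loss. A standalone sketch of that contract with dummy tensors; the (batch, num_chunks) shape and the fed_rag.loss import path are assumptions:

import torch

from fed_rag.loss import LSRLoss

# dummy collated batch: 2 queries, 4 retrieved chunks each (shape assumed)
inputs = {
    "retrieval_scores": torch.randn(2, 4),
    "lm_scores": torch.randn(2, 4),
}

# what collect_scores pulls out before the loss is applied
retrieval_scores, lm_scores = inputs["retrieval_scores"], inputs["lm_scores"]

loss = LSRLoss()(retrieval_scores, lm_scores)  # scalar torch.Tensor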

HuggingFace Retrieval-Augmented Generator Trainer

HuggingFaceTrainerForRALT

Bases: HuggingFaceTrainerMixin, BaseGeneratorTrainer

HuggingFace Trainer for Retrieval-Augmented LM Training/Fine-Tuning.

Source code in src/fed_rag/trainers/huggingface/ralt.py
class HuggingFaceTrainerForRALT(HuggingFaceTrainerMixin, BaseGeneratorTrainer):
    """HuggingFace Trainer for Retrieval-Augmented LM Training/Fine-Tuning."""

    _hf_trainer: Optional["Trainer"] = PrivateAttr(default=None)

    def __init__(
        self,
        rag_system: RAGSystem,
        train_dataset: "Dataset",
        training_arguments: Optional["TrainingArguments"] = None,
        **kwargs: Any,
    ):
        if not _has_huggingface:
            msg = (
                f"`{self.__class__.__name__}` requires `huggingface` extra to be installed. "
                "To fix please run `pip install fed-rag[huggingface]`."
            )
            raise MissingExtraError(msg)

        # the RALT data collator needs the raw dataset columns, so
        # `remove_unused_columns` must stay disabled on the HF Trainer
        if training_arguments is None:
            training_arguments = _get_default_training_args()
        else:
            training_arguments.remove_unused_columns = (
                False  # pragma: no cover
            )

        super().__init__(
            train_dataset=train_dataset,
            rag_system=rag_system,
            training_arguments=training_arguments,
            **kwargs,
        )

    @model_validator(mode="after")
    def set_private_attributes(self) -> "HuggingFaceTrainerForRALT":
        # the huggingface extra is installed, so this import is available
        from transformers import Trainer

        # validate rag system
        _validate_rag_system(self.rag_system)

        self._hf_trainer = Trainer(
            model=self.model,
            args=self.training_arguments,
            data_collator=DataCollatorForRALT(rag_system=self.rag_system),
            train_dataset=self.train_dataset,
        )

        return self

    def train(self) -> TrainResult:
        output: TrainOutput = self.hf_trainer_obj.train()
        return TrainResult(loss=output.training_loss)

    def evaluate(self) -> TestResult:
        # TODO: implement this
        raise NotImplementedError

    @property
    def hf_trainer_obj(self) -> "Trainer":
        return self._hf_trainer
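
Usage mirrors the LSR trainer. A sketch assuming an existing RAGSystem with a HuggingFace generator; the dataset columns shown are illustrative, see DataCollatorForRALT for the expected schema:

from datasets import Dataset

from fed_rag.trainers.huggingface.ralt import HuggingFaceTrainerForRALT

# `rag_system` is assumed to be a previously built fed_rag RAGSystem
train_dataset = Dataset.from_dict(
    {
        "query": ["What is retrieval-augmented training?"],
        "response": ["Fine-tuning the generator on retrieved context."],
    }
)

# omitting training_arguments falls back to _get_default_training_args()
trainer = HuggingFaceTrainerForRALT(
    rag_system=rag_system,
    train_dataset=train_dataset,
)
result = trainer.train()  # returns TrainResult(loss=...)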