Base Trainer

Base trainer classes for RAG system components.

BaseTrainer

Bases: BaseModel, ABC

Base Trainer Class.

This abstract class provides the interface for creating Trainer objects that implement different training strategies.

Attributes:

rag_system (RAGSystem): The RAG system to be trained.

train_dataset (Any): Dataset used for training.

Source code in src/fed_rag/base/trainer.py
class BaseTrainer(BaseModel, ABC):
    """Base Trainer Class.

    This abstract class provides the interface for creating Trainer objects that
    implement different training strategies.

    Attributes:
        rag_system: The RAG system to be trained.
        train_dataset: Dataset used for training.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    rag_system: RAGSystem
    train_dataset: Any
    _model = PrivateAttr()

    @abstractmethod
    def train(self) -> TrainResult:
        """Trains the model.

        Returns:
            TrainResult: The result of model training.
        """

    @abstractmethod
    def evaluate(self) -> TestResult:
        """Evaluates the model.

        Returns:
            TestResult: The result of model evaluation.
        """

    @abstractmethod
    def _get_model_from_rag_system(self) -> Any:
        """Get the model from the RAG system."""

    @model_validator(mode="after")
    def set_model(self) -> "BaseTrainer":
        self._model = self._get_model_from_rag_system()
        return self

    @property
    def model(self) -> Any:
        """Return the model to be trained."""
        return self._model

    @model.setter
    def model(self, v: Any) -> None:
        """Set the model to be trained."""
        self._model = v

model property writable

model

Return the model to be trained.

train abstractmethod

train()

Trains the model.

Returns:

TrainResult: The result of model training.

Source code in src/fed_rag/base/trainer.py
@abstractmethod
def train(self) -> TrainResult:
    """Trains the model.

    Returns:
        TrainResult: The result of model training.
    """

evaluate abstractmethod

evaluate()

Evaluates the model.

Returns:

TestResult: The result of model evaluation.

Source code in src/fed_rag/base/trainer.py
@abstractmethod
def evaluate(self) -> TestResult:
    """Evaluates the model.

    Returns:
        TestResult: The result of model evaluation.
    """
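
The contract above (three abstract methods, a model resolved once after construction, and a writable `model` property) can be sketched without any dependencies. This is a minimal illustration only: it uses the standard library instead of pydantic, so a plain `__init__` stands in for the `set_model` validator. `SketchTrainer`, `EchoTrainer`, the dict-based "RAG system", and the returned dicts are all hypothetical stand-ins, not part of fed_rag.

```python
from abc import ABC, abstractmethod
from typing import Any


class SketchTrainer(ABC):
    """Stdlib-only stand-in for BaseTrainer (hypothetical, not the fed_rag class)."""

    def __init__(self, rag_system: Any, train_dataset: Any) -> None:
        self.rag_system = rag_system
        self.train_dataset = train_dataset
        # Mirrors the `set_model` validator: resolve the model once,
        # after the public fields have been assigned.
        self._model = self._get_model_from_rag_system()

    @abstractmethod
    def train(self) -> Any:
        """Train the model."""

    @abstractmethod
    def evaluate(self) -> Any:
        """Evaluate the model."""

    @abstractmethod
    def _get_model_from_rag_system(self) -> Any:
        """Pick the trainable model out of the RAG system."""

    @property
    def model(self) -> Any:
        return self._model

    @model.setter
    def model(self, v: Any) -> None:
        self._model = v


class EchoTrainer(SketchTrainer):
    """Hypothetical concrete trainer; results are placeholder dicts."""

    def train(self) -> dict:
        return {"examples": len(self.train_dataset)}

    def evaluate(self) -> dict:
        return {"model": self.model}

    def _get_model_from_rag_system(self) -> Any:
        return self.rag_system["model"]


trainer = EchoTrainer(rag_system={"model": "model-v1"}, train_dataset=[1, 2, 3])
initial_model = trainer.model  # resolved during construction
trainer.model = "model-v2"     # the property is writable
```

A subclass that leaves any of the three abstract methods unimplemented cannot be instantiated, which is the same guarantee `ABC` gives the real `BaseTrainer`.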

BaseRetrieverTrainer

Bases: BaseTrainer, ABC

Base trainer for retriever components of RAG systems.

This trainer focuses specifically on training the retriever's encoder components, either the full encoder or just the query encoder depending on the retriever configuration.

Source code in src/fed_rag/base/trainer.py
class BaseRetrieverTrainer(BaseTrainer, ABC):
    """Base trainer for retriever components of RAG systems.

    This trainer focuses specifically on training the retriever's encoder
    components, either the full encoder or just the query encoder depending
    on the retriever configuration.
    """

    def _get_model_from_rag_system(self) -> Any:
        if self.rag_system.retriever.encoder:
            return self.rag_system.retriever.encoder
        else:
            return (
                self.rag_system.retriever.query_encoder
            )  # only update query encoder
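
The branching in `_get_model_from_rag_system` can be exercised on its own with stand-in objects. The sketch below assumes that a retriever without a shared encoder exposes `encoder` as `None`; `select_trainable_encoder` and the `SimpleNamespace` retrievers are hypothetical helpers for illustration, and strings take the place of real encoder models.

```python
from types import SimpleNamespace
from typing import Any


def select_trainable_encoder(retriever: Any) -> Any:
    """Same branching as BaseRetrieverTrainer._get_model_from_rag_system."""
    if retriever.encoder:
        return retriever.encoder
    return retriever.query_encoder  # only the query encoder is updated


# Hypothetical stand-in retrievers (assumption: unused slots are None).
single_encoder = SimpleNamespace(encoder="shared-encoder", query_encoder=None)
dual_encoder = SimpleNamespace(encoder=None, query_encoder="query-encoder")

picked_single = select_trainable_encoder(single_encoder)
picked_dual = select_trainable_encoder(dual_encoder)
```

Note that the check is a truthiness test rather than an explicit `is not None` comparison, so an encoder object that evaluates as falsy would fall through to the query-encoder branch.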

BaseGeneratorTrainer

Bases: BaseTrainer, ABC

Base trainer for the generator component of RAG systems.

This trainer focuses specifically on training the generator model.

Attributes:

rag_system (RAGSystem | NoEncodeRAGSystem): The RAG system to be trained. Can also be a NoEncodeRAGSystem.

Source code in src/fed_rag/base/trainer.py
class BaseGeneratorTrainer(BaseTrainer, ABC):
    """Base trainer for the generator component of RAG systems.

    This trainer focuses specifically on training the generator model.

    Attributes:
        rag_system: The RAG system to be trained. Can also be a `NoEncodeRAGSystem`.
    """

    rag_system: RAGSystem | NoEncodeRAGSystem

    def _get_model_from_rag_system(self) -> Any:
        return self.rag_system.generator.model
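
The generator lookup reads only `rag_system.generator.model`, which is why it does not depend on anything retriever-related. A minimal sketch with a stand-in object follows; `select_generator_model` and the `SimpleNamespace` system are hypothetical, and the real class still validates `rag_system` against `RAGSystem | NoEncodeRAGSystem` through pydantic.

```python
from types import SimpleNamespace
from typing import Any


def select_generator_model(rag_system: Any) -> Any:
    """Same lookup as BaseGeneratorTrainer._get_model_from_rag_system."""
    return rag_system.generator.model


# Hypothetical stand-in system that exposes a generator and nothing else.
generator_only_system = SimpleNamespace(
    generator=SimpleNamespace(model="generator-model")
)

picked_generator = select_generator_model(generator_only_system)
```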