Skip to content

Base Trainer

Base Trainer

BaseTrainer

Bases: BaseModel, ABC

Base Trainer Class.

Source code in src/fed_rag/base/trainer.py
class BaseTrainer(BaseModel, ABC):
    """Abstract base for trainers that operate on a component of a RAG system.

    Subclasses implement ``_get_model_from_rag_system`` to choose which model
    inside ``rag_system`` is trained. After pydantic validation, that model is
    cached and exposed through the read/write ``model`` property.
    """

    # Permit field types that pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    rag_system: RAGSystem
    train_dataset: Any
    # Populated by the ``set_model`` validator once the instance is built.
    _model = PrivateAttr()

    @abstractmethod
    def train(self) -> TrainResult:
        """Run the training loop and return a ``TrainResult``."""

    @abstractmethod
    def evaluate(self) -> TestResult:
        """Run evaluation and return a ``TestResult``."""

    @abstractmethod
    def _get_model_from_rag_system(self) -> Any:
        """Return the model within ``self.rag_system`` that this trainer updates."""

    @model_validator(mode="after")
    def set_model(self) -> "BaseTrainer":
        """Cache the trainable model after field validation has completed."""
        selected = self._get_model_from_rag_system()
        self._model = selected
        return self

    @property
    def model(self) -> Any:
        """The model to be trained."""
        return self._model

    @model.setter
    def model(self, v: Any) -> None:
        """Replace the model to be trained."""
        self._model = v

model property writable

model

Return the model to be trained.

train abstractmethod

train()

Train loop.

Source code in src/fed_rag/base/trainer.py
@abstractmethod
def train(self) -> TrainResult:
    """Train loop.

    Abstract: subclasses implement the training procedure and return a
    ``TrainResult``.
    """

evaluate abstractmethod

evaluate()

Evaluation step; subclasses implement it and return a TestResult.

Source code in src/fed_rag/base/trainer.py
@abstractmethod
def evaluate(self) -> TestResult:
    """Evaluation step.

    Abstract: subclasses implement evaluation and return a ``TestResult``.
    """

BaseRetrieverTrainer

Bases: BaseTrainer, ABC

Base Retriever Trainer Class.

Source code in src/fed_rag/base/trainer.py
class BaseRetrieverTrainer(BaseTrainer, ABC):
    """Base Retriever Trainer Class.

    Trains the retriever of the RAG system. If the retriever exposes a shared
    ``encoder``, that encoder is trained; otherwise only the retriever's
    ``query_encoder`` is trained.
    """

    def _get_model_from_rag_system(self) -> Any:
        """Return the retriever encoder to be trained.

        Returns:
            ``rag_system.retriever.encoder`` when it is set; otherwise
            ``rag_system.retriever.query_encoder``.
        """
        retriever = self.rag_system.retriever
        # Compare against None explicitly rather than relying on truthiness:
        # model objects that define ``__len__`` (e.g. ``torch.nn.Sequential``
        # subclasses) can evaluate as falsy even though they are present.
        if retriever.encoder is not None:
            return retriever.encoder
        # No shared encoder: only the query encoder is updated.
        return retriever.query_encoder

BaseGeneratorTrainer

Bases: BaseTrainer, ABC

Base Generator Trainer Class.

Source code in src/fed_rag/base/trainer.py
class BaseGeneratorTrainer(BaseTrainer, ABC):
    """Base Generator Trainer Class.

    Trains the generator of the RAG system. The model to be trained is
    ``rag_system.generator.model``.
    """

    def _get_model_from_rag_system(self) -> Any:
        """Return the generator's underlying model to be trained.

        Returns:
            ``rag_system.generator.model``.
        """
        return self.rag_system.generator.model