Bases: BaseModel, ABC
PyTorch Trainer Mixin.
Source code in src/fed_rag/trainers/pytorch/mixin.py
class PyTorchTrainerMixin(BaseModel, ABC):
    """PyTorch Trainer Mixin.

    Holds the training dataset, the dataloader that iterates it, and optional
    training arguments, and guarantees the dataset and the dataloader's
    dataset are the same object.

    Attributes:
        train_dataset: Dataset used for training. After construction this is
            always the very object referenced by ``train_dataloader.dataset``.
        train_dataloader: DataLoader that yields the training batches.
        training_arguments: Optional ``TrainingArgs``; ``None`` when not given.
    """

    # Dataset / DataLoader are presumably plain torch.utils.data classes that
    # pydantic cannot build a schema for, so arbitrary types must be allowed.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    train_dataset: Dataset
    train_dataloader: DataLoader
    training_arguments: TrainingArgs | None = None

    def __init__(
        self,
        train_dataloader: DataLoader,
        train_dataset: Dataset | None = None,
        training_arguments: TrainingArgs | None = None,
        **kwargs: Any,
    ):
        """Validate dataset consistency and initialize the model fields.

        Args:
            train_dataloader: DataLoader yielding training batches.
            train_dataset: Optional dataset. When omitted, it is taken from
                ``train_dataloader.dataset``. When supplied, it must be the
                same object (identity, not equality) as
                ``train_dataloader.dataset``.
            training_arguments: Optional training arguments.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``super().__init__``.

        Raises:
            InconsistentDatasetError: If ``train_dataset`` is supplied and is
                not the object held by ``train_dataloader``.
        """
        if train_dataset is None:
            # Default to the loader's own dataset so the two cannot diverge.
            train_dataset = train_dataloader.dataset
        elif train_dataset is not train_dataloader.dataset:
            # Identity (not equality) is required: the loader must iterate
            # exactly the dataset object this trainer exposes.
            raise InconsistentDatasetError(
                "Inconsistent datasets detected between supplied `train_dataset` and that "
                "associated with the `train_dataloader`. These two datasets must be the same."
            )

        super().__init__(
            train_dataset=train_dataset,
            train_dataloader=train_dataloader,
            training_arguments=training_arguments,
            **kwargs,
        )