mmlearn.tasks.base module
Base class for all tasks in mmlearn that require training.
- class TrainingTask(optimizer=None, lr_scheduler=None, loss_fn=None, compute_validation_loss=True, compute_test_loss=True)
Bases: LightningModule
Base class for all tasks in mmlearn that require training.
- Parameters:
optimizer (Optional[partial[torch.optim.Optimizer]], optional, default=None) – The optimizer to use for training. This is expected to be a partial function, created using functools.partial, that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.
lr_scheduler (Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None) – The learning rate scheduler to use for training. This can be either a partial function that takes the optimizer as the only required argument, or a dictionary with a "scheduler" key that specifies the scheduler and an optional "extras" key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.
loss_fn (Optional[torch.nn.Module], optional, default=None) – The loss function to use for training. It is also used to compute the validation and test losses.
compute_validation_loss (bool, optional, default=True) – Whether to compute the validation loss if a validation dataloader is provided. The loss function must be provided to compute the validation loss.
compute_test_loss (bool, optional, default=True) – Whether to compute the test loss if a test dataloader is provided. The loss function must be provided to compute the test loss.
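Both `optimizer` and `lr_scheduler` are factories rather than instances: the model parameters are supplied to the optimizer partial, and the resulting optimizer is supplied to the scheduler. The sketch below illustrates that calling convention with plain PyTorch. `resolve_lr_scheduler` is a hypothetical helper written for illustration, not part of mmlearn; it passes the `"extras"` entry through unchanged because how the task consumes those values is not specified here, and the `{"interval": "epoch"}` contents are an assumed placeholder.

```python
from functools import partial
from typing import Any, Union

import torch

# Optimizer factory: hyperparameters are bound up front, leaving the model
# parameters as the only required argument.
optimizer_factory = partial(torch.optim.AdamW, lr=1e-3, weight_decay=0.01)

# Scheduler form 1: a partial whose only required argument is the optimizer.
scheduler_factory = partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5)

# Scheduler form 2: a dict with a required "scheduler" key and an optional
# "extras" key. The extras shown here are a placeholder (assumption).
scheduler_spec = {"scheduler": scheduler_factory, "extras": {"interval": "epoch"}}


def resolve_lr_scheduler(
    spec: Union[partial, dict[str, Any]],
    optimizer: torch.optim.Optimizer,
) -> tuple[Any, dict[str, Any]]:
    """Hypothetical helper: build a scheduler from either accepted form.

    Returns the scheduler together with any "extras", passed through unchanged.
    """
    if isinstance(spec, dict):
        if "scheduler" not in spec:
            raise ValueError("Expected a 'scheduler' key in the scheduler dict.")
        return spec["scheduler"](optimizer), dict(spec.get("extras") or {})
    # Plain partial: call it with the optimizer; there are no extras.
    return spec(optimizer), {}


model = torch.nn.Linear(4, 2)
optimizer = optimizer_factory(model.parameters())
scheduler, extras = resolve_lr_scheduler(scheduler_spec, optimizer)
```

Deferring construction this way lets the optimizer and scheduler be described, for example in a configuration file, before the model and its parameters exist.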
- Raises:
ValueError – If loss_fn is not provided and either compute_validation_loss or compute_test_loss is True.
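Because both `compute_validation_loss` and `compute_test_loss` default to `True`, a task configured with the default flags and no `loss_fn` triggers this error. The sketch below restates the documented condition as a standalone check; `check_loss_config` is a hypothetical name, and this is an illustration of the rule rather than mmlearn's own code.

```python
from typing import Optional

import torch


def check_loss_config(
    loss_fn: Optional[torch.nn.Module],
    compute_validation_loss: bool = True,
    compute_test_loss: bool = True,
) -> None:
    """Hypothetical re-statement of the documented ValueError condition."""
    if loss_fn is None and (compute_validation_loss or compute_test_loss):
        raise ValueError(
            "A loss function must be provided to compute the validation or test loss."
        )


# Accepted: a loss function is supplied.
check_loss_config(torch.nn.CrossEntropyLoss())

# Accepted: no loss function, but both evaluation losses are disabled.
check_loss_config(None, compute_validation_loss=False, compute_test_loss=False)

# Rejected: the default flags request evaluation losses without a loss function.
try:
    check_loss_config(None)
except ValueError as err:
    print(f"ValueError: {err}")
```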