mmlearn.tasks.contrastive_pretraining module¶
Contrastive pretraining task.
- class AuxiliaryTaskSpec(modality, task, loss_weight=1.0)[source]¶
Bases: object
Specification for an auxiliary task to run alongside the main task.
- task: Any¶
The auxiliary task module. This is expected to be a partially-initialized instance of a LightningModule created using functools.partial(), such that an initialized encoder can be passed as the only argument.
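The partial-initialization pattern described for `task` can be sketched with the standard library alone. `ToyAuxiliaryTask` and its `mask_ratio` argument are hypothetical stand-ins (a real auxiliary task would be a LightningModule subclass); only `functools.partial` is real here.

```python
from functools import partial


class ToyAuxiliaryTask:
    """Hypothetical stand-in for a LightningModule-based auxiliary task."""

    def __init__(self, encoder, mask_ratio=0.5):
        self.encoder = encoder
        self.mask_ratio = mask_ratio


# Bind every argument except the encoder ahead of time.
partial_task = partial(ToyAuxiliaryTask, mask_ratio=0.75)

# Later, an initialized encoder is passed as the only argument.
encoder = object()  # placeholder for an initialized encoder module
aux_task = partial_task(encoder)
```

Deferring construction this way lets the main task hand its own encoder instance to the auxiliary task, so both train the same weights.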
- class ContrastivePretraining(encoders, heads=None, postprocessors=None, modality_module_mapping=None, optimizer=None, lr_scheduler=None, init_logit_scale=14.285714285714285, max_logit_scale=100, learnable_logit_scale=True, loss=None, modality_loss_pairs=None, auxiliary_tasks=None, log_auxiliary_tasks_loss=False, compute_validation_loss=True, compute_test_loss=True, evaluation_tasks=None)[source]¶
Bases: TrainingTask
Contrastive pretraining task.
This class supports contrastive pretraining with N modalities of data. It allows the sharing of encoders, heads, and postprocessors across modalities. It also supports computing the contrastive loss between specified pairs of modalities, as well as training auxiliary tasks alongside the main contrastive pretraining task.
- Parameters:
encoders (dict[str, torch.nn.Module]) – A dictionary of encoders. The keys can be any string, including the names of any supported modalities. If the keys are not supported modalities, the modality_module_mapping parameter must be provided to map the encoders to specific modalities. The encoders are expected to take a dictionary of input values and return a list-like object with the first element being the encoded values. This first element is passed on to the heads or postprocessors and the remaining elements are ignored.
heads (Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None) – A dictionary of modules that process the encoder outputs, usually projection heads. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a torch.nn.Sequential module. All head modules are expected to take a single input tensor and return a single output tensor.
postprocessors (Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None) – A dictionary of modules that process the head outputs. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a torch.nn.Sequential module. All postprocessor modules are expected to take a single input tensor and return a single output tensor.
modality_module_mapping (Optional[dict[str, ModuleKeySpec]], optional, default=None) – A dictionary mapping modalities to encoders, heads, and postprocessors. Useful for reusing the same instance of a module across multiple modalities.
optimizer (Optional[partial[torch.optim.Optimizer]], optional, default=None) – The optimizer to use for training. This is expected to be a partial() function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.
lr_scheduler (Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None) – The learning rate scheduler to use for training. This can be a partial() function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.
init_logit_scale (float, optional, default=1 / 0.07) – The initial value of the logit scale parameter. This is the log of the scale factor applied to the logits before computing the contrastive loss.
max_logit_scale (float, optional, default=100) – The maximum value of the logit scale parameter. The logit scale parameter is clamped to the range [0, log(max_logit_scale)].
learnable_logit_scale (bool, optional, default=True) – Whether the logit scale parameter is learnable. If set to False, the logit scale parameter is treated as a constant.
loss (Optional[torch.nn.Module], optional, default=None) – The loss function to use.
modality_loss_pairs (Optional[list[LossPairSpec]], optional, default=None) – A list of pairs of modalities to compute the contrastive loss between and the weight to apply to each pair.
auxiliary_tasks (dict[str, AuxiliaryTaskSpec], optional, default=None) – Auxiliary tasks to run alongside the main contrastive pretraining task.
The auxiliary task module is expected to be a partially-initialized instance of a LightningModule created using functools.partial(), such that an initialized encoder can be passed as the only argument.
The modality parameter specifies the modality of the encoder to use for the auxiliary task. The loss_weight parameter specifies the weight to apply to the auxiliary task loss.
log_auxiliary_tasks_loss (bool, optional, default=False) – Whether to log the loss of auxiliary tasks to the main logger.
compute_validation_loss (bool, optional, default=True) – Whether to compute the validation loss if a validation dataloader is provided. The loss function must be provided to compute the validation loss.
compute_test_loss (bool, optional, default=True) – Whether to compute the test loss if a test dataloader is provided. The loss function must be provided to compute the test loss.
evaluation_tasks (Optional[dict[str, EvaluationSpec]], optional, default=None) – Evaluation tasks to run during validation, while training, and during testing.
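The relationship between `init_logit_scale` and `max_logit_scale` can be checked with plain arithmetic. This sketch assumes the parameter is kept in log space, as the clamping range `[0, log(max_logit_scale)]` suggests; it is an illustration of the numbers involved, not the library's implementation, and the `clamp` helper is hypothetical.

```python
import math

init_logit_scale = 1 / 0.07  # default from the signature, about 14.29
max_logit_scale = 100        # default from the signature


def clamp(value, lower, upper):
    """Restrict value to the closed interval [lower, upper]."""
    return max(lower, min(value, upper))


# Assumption: the stored parameter is the log of the scale factor.
log_scale = math.log(init_logit_scale)
upper_bound = math.log(max_logit_scale)

# Clamp in log space, then exponentiate to recover the multiplier
# that would be applied to the logits.
clamped_log_scale = clamp(log_scale, 0.0, upper_bound)
scale = math.exp(clamped_log_scale)
```

With the defaults, the initial value already lies inside the clamping range, so the effective scale equals `1 / 0.07`; a learned value that drifts above `log(100)` would be capped at a multiplier of 100.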
- Raises:
- If the loss function is not provided and either the validation or test loss needs to be computed.
- If the given modality is not supported.
- If the encoder, head, or postprocessor is not mapped to a modality.
- If an unsupported modality is found in the loss pair specification.
- If an unsupported modality is found in the auxiliary tasks.
- If the auxiliary task is not a partial function.
- If the evaluation task is not an instance of EvaluationHooks.
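The module contract described in the parameters (the encoder returns a list-like object, only its first element is forwarded, and heads and postprocessors map one input to one output) can be sketched without any framework. All names below are hypothetical stand-ins operating on plain lists rather than tensors, and the ordering encoder, then head, then postprocessor follows the parameter descriptions.

```python
def toy_encoder(inputs):
    # Takes a dictionary of inputs and returns a list-like object;
    # only the first element is used downstream.
    features = [2.0 * x for x in inputs["text"]]
    return (features, "extra output that is ignored")


def toy_head(x):
    # Single input, single output (e.g. a projection head).
    return [v + 1.0 for v in x]


def toy_postprocessor(x):
    # Single input, single output.
    return [round(v, 2) for v in x]


def run_pipeline(inputs, encoder, head=None, postprocessor=None):
    """Apply encoder -> optional head -> optional postprocessor."""
    out = encoder(inputs)[0]
    if head is not None:
        out = head(out)
    if postprocessor is not None:
        out = postprocessor(out)
    return out


result = run_pipeline(
    {"text": [0.5, 1.0]}, toy_encoder, toy_head, toy_postprocessor
)
```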
- encode(inputs, modality, normalize=False)[source]¶
Encode the input values for the given modality.
- Parameters:
- Returns:
The encoded values for the specified modality.
- Return type:
- encoders¶
A ModuleDict, where the keys are the names of the modalities and the values are the encoder modules.
- evaluation_tasks¶
A dictionary of evaluation tasks to run during validation, while training, or during testing.
- forward(inputs)[source]¶
Run the forward pass.
- Parameters:
- Returns:
The encodings for each modality.
- Return type:
- heads¶
A ModuleDict, where the keys are the names of the modalities and the values are the projection head modules. This can be None if no head modules are provided.
- modality_loss_pairs¶
A list of LossPairSpec instances specifying the pairs of modalities to compute the contrastive loss between and the weight to apply to each pair.
- on_load_checkpoint(checkpoint)[source]¶
Modify the model checkpoint after loading.
The on_load_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint after loading.
- on_save_checkpoint(checkpoint)[source]¶
Modify the checkpoint before saving.
The on_save_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint before saving.
- on_test_epoch_end()[source]¶
Compute and log epoch-level metrics at the end of the test epoch.
- Return type:
- on_train_epoch_start()[source]¶
Prepare for the training epoch.
This method sets the modules to training mode.
- Return type:
- on_validation_epoch_end()[source]¶
Compute and log epoch-level metrics at the end of the validation epoch.
- Return type:
- on_validation_epoch_start()[source]¶
Prepare for the validation epoch.
This method sets the modules to evaluation mode and calls the on_evaluation_epoch_start method of each evaluation task.
- Return type:
- postprocessors¶
A ModuleDict, where the keys are the names of the modalities and the values are the postprocessor modules. This can be None if no postprocessor modules are provided.
- test_step(batch, batch_idx)[source]¶
Run a single test step.
- Parameters:
batch (dict[str, torch.Tensor]) – The batch of data to process.
batch_idx (int) – The index of the batch.
- Returns:
The loss for the batch or None if the loss function is not provided.
- Return type:
Optional[torch.Tensor]
- training_step(batch, batch_idx)[source]¶
Compute the loss for the batch.
- Parameters:
- Returns:
The loss for the batch.
- Return type:
- validation_step(batch, batch_idx)[source]¶
Run a single validation step.
- Parameters:
batch (dict[str, torch.Tensor]) – The batch of data to process.
batch_idx (int) – The index of the batch.
- Returns:
The loss for the batch or None if the loss function is not provided.
- Return type:
Optional[torch.Tensor]
- class EvaluationSpec(task, run_on_validation=True, run_on_test=True)[source]¶
Bases: object
Specification for an evaluation task.
- task: Any¶
The evaluation task module. This is expected to be an instance of EvaluationHooks.
- class LossPairSpec(modalities, weight=1.0)[source]¶
Bases: object
Specification for a pair of modalities to compute the contrastive loss.
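How per-pair weights might enter the total loss can be sketched as follows. `ToyLossPair` is a hypothetical stand-in that mirrors the `LossPairSpec(modalities, weight=1.0)` signature, the modality names are placeholders, and the weighted-sum reduction is an assumption rather than a documented behavior.

```python
from dataclasses import dataclass


@dataclass
class ToyLossPair:
    """Stand-in mirroring LossPairSpec(modalities, weight=1.0)."""

    modalities: tuple
    weight: float = 1.0


def combine_pair_losses(pairs, pair_losses):
    """Weighted sum of per-pair losses (the reduction is an assumption)."""
    return sum(p.weight * pair_losses[p.modalities] for p in pairs)


pairs = [
    ToyLossPair(("text", "rgb")),
    ToyLossPair(("text", "audio"), weight=0.5),
]

# Placeholder loss values for each modality pair.
pair_losses = {("text", "rgb"): 2.0, ("text", "audio"): 4.0}

total = combine_pair_losses(pairs, pair_losses)
```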
- class ModuleKeySpec(encoder_key=None, head_key=None, postprocessor_key=None)[source]¶
Bases: object
Module key specification for mapping modules to modalities.
- encoder_key: Optional[str] = None¶
The key of the encoder module. If not provided, the modality name is used.
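The fallback rule for `encoder_key` (use the modality name when no key is given) and the reuse of one module across modalities can be sketched like this. `ToyModuleKeySpec` and `resolve_encoder_key` are hypothetical stand-ins that mirror the documented signature, and the modality and module names are placeholders.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyModuleKeySpec:
    """Stand-in mirroring ModuleKeySpec's three optional keys."""

    encoder_key: Optional[str] = None
    head_key: Optional[str] = None
    postprocessor_key: Optional[str] = None


def resolve_encoder_key(modality, mapping):
    """Return the encoder key for a modality, defaulting to its name."""
    spec = mapping.get(modality)
    if spec is None or spec.encoder_key is None:
        return modality
    return spec.encoder_key


# Two modalities share the encoder registered under "vision".
mapping = {
    "rgb": ToyModuleKeySpec(encoder_key="vision"),
    "depth": ToyModuleKeySpec(encoder_key="vision"),
    "text": ToyModuleKeySpec(),
}

resolved = {m: resolve_encoder_key(m, mapping) for m in ("rgb", "depth", "text")}
```

Pointing several modalities at the same key is what allows a single module instance to be shared across them.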