mmlearn.tasks.contrastive_pretraining module

Contrastive pretraining task.

class AuxiliaryTaskSpec(modality, task, loss_weight=1.0)[source]

Bases: object

Specification for an auxiliary task to run alongside the main task.

loss_weight: float = 1.0

The weight to apply to the auxiliary task loss.

modality: str

The modality of the encoder to use for the auxiliary task.

task: Any

The auxiliary task module. This is expected to be a partially-initialized instance of a LightningModule created using functools.partial(), such that an initialized encoder can be passed as the only argument.
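For illustration, a minimal sketch of constructing such a specification. TextMLM here is a hypothetical LightningModule whose only remaining required constructor argument (after the partial) is the encoder; it is not part of mmlearn.

    import functools

    aux_task_spec = AuxiliaryTaskSpec(
        modality="text",  # use the text encoder for this task
        # TextMLM(encoder) is instantiated internally; other hyperparameters
        # are bound here via functools.partial().
        task=functools.partial(TextMLM, mask_probability=0.15),
        loss_weight=0.5,  # down-weight the auxiliary loss
    )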

class ContrastivePretraining(encoders, heads=None, postprocessors=None, modality_module_mapping=None, optimizer=None, lr_scheduler=None, init_logit_scale=14.285714285714285, max_logit_scale=100, learnable_logit_scale=True, loss=None, modality_loss_pairs=None, auxiliary_tasks=None, log_auxiliary_tasks_loss=False, compute_validation_loss=True, compute_test_loss=True, evaluation_tasks=None)[source]

Bases: TrainingTask

Contrastive pretraining task.

This class supports contrastive pretraining with N modalities of data. It allows encoders, heads, and postprocessors to be shared across modalities, supports computing the contrastive loss between specified pairs of modalities, and can train auxiliary tasks alongside the main contrastive objective. A construction sketch follows the parameter and exception listings below.

Parameters:
  • encoders (dict[str, torch.nn.Module]) – A dictionary of encoders. The keys can be any string, including the names of any supported modalities. If the keys are not supported modalities, the modality_module_mapping parameter must be provided to map the encoders to specific modalities. The encoders are expected to take a dictionary of input values and return a list-like object with the first element being the encoded values. This first element is passed on to the heads or postprocessors and the remaining elements are ignored.

  • heads (Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None) – A dictionary of modules that process the encoder outputs, usually projection heads. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a torch.nn.Sequential module. All head modules are expected to take a single input tensor and return a single output tensor.

  • postprocessors (Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None) – A dictionary of modules that process the head outputs. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a torch.nn.Sequential module. All postprocessor modules are expected to take a single input tensor and return a single output tensor.

  • modality_module_mapping (Optional[dict[str, ModuleKeySpec]], optional, default=None) – A dictionary mapping modalities to encoders, heads, and postprocessors. Useful for reusing the same instance of a module across multiple modalities.

  • optimizer (Optional[partial[torch.optim.Optimizer]], optional, default=None) – The optimizer to use for training. This is expected to be a partial() function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

  • lr_scheduler (Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None) – The learning rate scheduler to use for training. This can be a partial() function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

  • init_logit_scale (float, optional, default=1 / 0.07) – The initial value of the logit scale, i.e. the multiplicative factor applied to the logits before the contrastive loss is computed. The parameter actually stored (and optionally learned) is the log of this value.

  • max_logit_scale (float, optional, default=100) – The maximum value of the logit scale parameter. The logit scale parameter is clamped to the range [0, log(max_logit_scale)].

  • learnable_logit_scale (bool, optional, default=True) – Whether the logit scale parameter is learnable. If set to False, the logit scale parameter is treated as a constant.

  • loss (Optional[torch.nn.Module], optional, default=None) – The loss function to use. Must be provided if compute_validation_loss or compute_test_loss is set to True.

  • modality_loss_pairs (Optional[list[LossPairSpec]], optional, default=None) – A list of pairs of modalities to compute the contrastive loss between and the weight to apply to each pair.

  • auxiliary_tasks (Optional[dict[str, AuxiliaryTaskSpec]], optional, default=None) –

    Auxiliary tasks to run alongside the main contrastive pretraining task.

    • The auxiliary task module is expected to be a partially-initialized instance of a LightningModule created using functools.partial(), such that an initialized encoder can be passed as the only argument.

    • The modality parameter specifies the modality of the encoder to use for the auxiliary task. The loss_weight parameter specifies the weight to apply to the auxiliary task loss.

  • log_auxiliary_tasks_loss (bool, optional, default=False) – Whether to log the loss of auxiliary tasks to the main logger.

  • compute_validation_loss (bool, optional, default=True) – Whether to compute the validation loss if a validation dataloader is provided. The loss function must be provided to compute the validation loss.

  • compute_test_loss (bool, optional, default=True) – Whether to compute the test loss if a test dataloader is provided. The loss function must be provided to compute the test loss.

  • evaluation_tasks (Optional[dict[str, EvaluationSpec]], optional, default=None) – Evaluation tasks to run during validation (while training) and/or during testing.

Raises:

ValueError

  • If the loss function is not provided and either the validation or test loss needs to be computed.

  • If a given modality is not supported.

  • If an encoder, head, or postprocessor is not mapped to a modality.

  • If an unsupported modality is found in the loss pair specification.

  • If an unsupported modality is found in the auxiliary tasks.

  • If an auxiliary task is not a partial function.

  • If an evaluation task is not an instance of EvaluationHooks.
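As a minimal construction sketch, the following ties the parameters above together for two modalities. The ToyEncoder class and all hyperparameter values are illustrative, and ContrastiveLoss is assumed to be exposed under mmlearn.modules.losses.

    import functools

    import torch

    from mmlearn.modules.losses import ContrastiveLoss  # assumed import path
    from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining

    class ToyEncoder(torch.nn.Module):
        """Stand-in encoder: takes a dict of inputs and returns a list whose
        first element is the encoded tensor, per the contract above."""

        def __init__(self, input_key: str, in_dim: int, out_dim: int) -> None:
            super().__init__()
            self.input_key = input_key
            self.proj = torch.nn.Linear(in_dim, out_dim)

        def forward(self, inputs: dict) -> list:
            return [self.proj(inputs[self.input_key])]

    task = ContrastivePretraining(
        encoders={
            "rgb": ToyEncoder("rgb", in_dim=1024, out_dim=512),
            "text": ToyEncoder("text", in_dim=768, out_dim=512),
        },
        heads={  # project both modalities into a shared 256-d space
            "rgb": torch.nn.Linear(512, 256),
            "text": torch.nn.Linear(512, 256),
        },
        loss=ContrastiveLoss(),  # assuming a no-argument constructor
        optimizer=functools.partial(torch.optim.AdamW, lr=1e-4),
        lr_scheduler=functools.partial(
            torch.optim.lr_scheduler.CosineAnnealingLR, T_max=10
        ),
    )

The resulting task behaves as a LightningModule and can then be passed to a Lightning Trainer, e.g. trainer.fit(task, datamodule=...).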

configure_model()[source]

Configure the model.

Return type:

None

encode(inputs, modality, normalize=False)[source]

Encode the input values for the given modality.

Parameters:
  • inputs (dict[str, Any]) – Input values.

  • modality (Modality) – The modality to encode.

  • normalize (bool, optional, default=False) – Whether to apply L2 normalization to the output (after the head and postprocessor layers, if present).

Returns:

The encoded values for the specified modality.

Return type:

torch.Tensor
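A hedged usage sketch, assuming the task object from the construction example above and that Modalities (one Modality object per supported modality name) is importable from mmlearn.datasets.core; the validation dataloader is not shown.

    from mmlearn.datasets.core import Modalities  # assumed import path

    batch = next(iter(val_dataloader))  # a dict[str, Any] from a dataloader
    text_emb = task.encode(batch, Modalities.TEXT, normalize=True)
    print(text_emb.shape)  # (batch_size, projection_dim)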

encoders

A ModuleDict, where the keys are the names of the modalities and the values are the encoder modules.

evaluation_tasks

A dictionary of evaluation tasks to run during validation (while training) and/or during testing.

forward(inputs)[source]

Run the forward pass.

Parameters:

inputs (dict[str, Any]) – The input tensors to encode.

Returns:

The encodings for each modality.

Return type:

dict[str, torch.Tensor]
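Continuing the sketch above, calling the task invokes forward() and yields one embedding tensor per modality present in the batch:

    encodings = task(batch)  # equivalent to task.forward(batch)
    for name, emb in encodings.items():
        print(name, tuple(emb.shape))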

heads

A ModuleDict, where the keys are the names of the modalities and the values are the projection head modules. This can be None if no head modules are provided.

modality_loss_pairs

A list of LossPairSpec instances specifying the pairs of modalities to compute the contrastive loss between and the weight to apply to each pair.

on_before_zero_grad(optimizer)[source]

Hook called before the optimizer zeroes the gradients of the model.

Return type:

None

on_load_checkpoint(checkpoint)[source]

Modify the model checkpoint after loading.

The on_load_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint after loading.

Parameters:

checkpoint (Dict[str, Any]) – The loaded checkpoint.

Return type:

None

on_save_checkpoint(checkpoint)[source]

Modify the checkpoint before saving.

The on_save_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint before saving.

Parameters:

checkpoint (Dict[str, Any]) – The checkpoint to save.

Return type:

None

on_test_epoch_end()[source]

Compute and log epoch-level metrics at the end of the test epoch.

Return type:

None

on_test_epoch_start()[source]

Prepare for the test epoch.

Return type:

None

on_train_epoch_start()[source]

Prepare for the training epoch.

This method sets the modules to training mode.

Return type:

None

on_validation_epoch_end()[source]

Compute and log epoch-level metrics at the end of the validation epoch.

Return type:

None

on_validation_epoch_start()[source]

Prepare for the validation epoch.

This method sets the modules to evaluation mode and calls the on_evaluation_epoch_start method of each evaluation task.

Return type:

None

postprocessors

A ModuleDict, where the keys are the names of the modalities and the values are the postprocessor modules. This can be None if no postprocessor modules are provided.

test_step(batch, batch_idx)[source]

Run a single test step.

Parameters:
  • batch (dict[str, torch.Tensor]) – The batch of data to process.

  • batch_idx (int) – The index of the batch.

Returns:

The loss for the batch or None if the loss function is not provided.

Return type:

Optional[torch.Tensor]

training_step(batch, batch_idx)[source]

Compute the loss for the batch.

Parameters:
  • batch (dict[str, Any]) – The batch of data to process.

  • batch_idx (int) – The index of the batch.

Returns:

The loss for the batch.

Return type:

torch.Tensor

validation_step(batch, batch_idx)[source]

Run a single validation step.

Parameters:
  • batch (dict[str, torch.Tensor]) – The batch of data to process.

  • batch_idx (int) – The index of the batch.

Returns:

The loss for the batch or None if the loss function is not provided.

Return type:

Optional[torch.Tensor]

class EvaluationSpec(task, run_on_validation=True, run_on_test=True)[source]

Bases: object

Specification for an evaluation task.

run_on_test: bool = True

Whether to run the evaluation task during testing.

run_on_validation: bool = True

Whether to run the evaluation task during validation.

task: Any

The evaluation task module. This is expected to be an instance of EvaluationHooks.
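An illustrative sketch; MyRetrievalEvaluation stands in for any user-defined EvaluationHooks subclass and is not part of mmlearn.

    evaluation_tasks = {
        "retrieval": EvaluationSpec(
            task=MyRetrievalEvaluation(),  # hypothetical EvaluationHooks subclass
            run_on_validation=True,        # evaluate while training
            run_on_test=True,              # also evaluate during testing
        ),
    }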

class LossPairSpec(modalities, weight=1.0)[source]

Bases: object

Specification for a pair of modalities to compute the contrastive loss.

modalities: tuple[str, str]

The pair of modalities to compute the contrastive loss between.

weight: float = 1.0

The weight to apply to the contrastive loss for the pair of modalities.
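For example, with three modalities the contrastive objective can be restricted to selected pairs and weighted unequally (modality names are illustrative):

    modality_loss_pairs = [
        LossPairSpec(modalities=("rgb", "text")),               # weight=1.0
        LossPairSpec(modalities=("rgb", "audio"), weight=0.5),  # down-weighted
    ]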

class ModuleKeySpec(encoder_key=None, head_key=None, postprocessor_key=None)[source]

Bases: object

Module key specification for mapping modules to modalities.

encoder_key: Optional[str] = None

The key of the encoder module. If not provided, the modality name is used.

head_key: Optional[str] = None

The key of the head module. If not provided, the modality name is used.

postprocessor_key: Optional[str] = None

The key of the postprocessor module. If not provided, the modality name is used.
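For example, a single encoder instance registered in the encoders dictionary under the key "shared_encoder" can be reused across two modalities while keeping per-modality heads (keys and names are illustrative):

    modality_module_mapping = {
        "rgb": ModuleKeySpec(encoder_key="shared_encoder"),   # head/postprocessor keys default to "rgb"
        "text": ModuleKeySpec(encoder_key="shared_encoder"),  # head/postprocessor keys default to "text"
    }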