mmlearn.modules.losses.contrastive.ContrastiveLoss

class ContrastiveLoss(l2_normalize=False, local_loss=False, gather_with_grad=False, modality_alignment=False, cache_labels=False)

Bases: Module

Contrastive loss computed between the embeddings of pairs of modalities.

Parameters:
  • l2_normalize (bool, optional, default=False) – Whether to L2-normalize the features before computing the similarity logits.

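  • local_loss (bool, optional, default=False) – Whether to compute the loss locally, i.e., from the similarities between the features on the current process and the features gathered from all processes (local_features @ global_features.T).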

  • gather_with_grad (bool, optional, default=False) – Whether to gather tensors across processes with gradients, so that gradients flow back through the gathered features.

  • modality_alignment (bool, optional, default=False) – Whether to include a modality alignment loss. This loss treats all features from the same modality as positive pairs and all features from different modalities as negative pairs.

  • cache_labels (bool, optional, default=False) – Whether to cache the labels (the target indices of the positive pairs) instead of recomputing them on every call.
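For intuition, the core computation between two modalities can be sketched as a symmetric, CLIP-style InfoNCE loss. The following is a minimal, framework-free sketch on plain Python lists, assuming a single process (no gathering, so local_loss and gather_with_grad play no role) and that row i of each modality forms a positive pair. The function names are hypothetical, and this is an illustration of the technique, not mmlearn's implementation.

```python
import math


def l2_normalize(rows):
    """Scale each row vector to unit Euclidean norm."""
    out = []
    for row in rows:
        norm = math.sqrt(sum(x * x for x in row)) or 1.0  # guard against zero vectors
        out.append([x / norm for x in row])
    return out


def cross_entropy(logits, target):
    """Negative log-softmax of logits[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]


def clip_style_loss(feats_a, feats_b, logit_scale, normalize=False):
    """Symmetric InfoNCE: row i of feats_a and row i of feats_b are positives."""
    if normalize:  # mirrors the intent of the l2_normalize flag
        feats_a, feats_b = l2_normalize(feats_a), l2_normalize(feats_b)
    # logits[i][j] = logit_scale * <a_i, b_j>
    logits = [
        [logit_scale * sum(x * y for x, y in zip(a, b)) for b in feats_b]
        for a in feats_a
    ]
    n = len(logits)
    loss_a_to_b = sum(cross_entropy(logits[i], i) for i in range(n)) / n
    logits_t = [list(col) for col in zip(*logits)]  # transpose: b -> a direction
    loss_b_to_a = sum(cross_entropy(logits_t[i], i) for i in range(n)) / n
    return (loss_a_to_b + loss_b_to_a) / 2
```

With normalize=False the dot products grow with the vectors' magnitudes, so the sharpness of the softmax is no longer controlled by logit_scale alone; normalizing first makes the logits scaled cosine similarities.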
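The local_loss, gather_with_grad, and cache_labels flags share their names with parameters of open_clip's ClipLoss. The sketch below follows that library's convention for local_loss, and it is an assumption that mmlearn behaves the same way. With W processes each holding n features, a global loss builds the full (W·n × W·n) similarity matrix on every process, whereas a local loss builds only the n rows for the process's own features against all W·n gathered columns, which shifts where each positive sits. The helper names are hypothetical.

```python
def positive_targets(rank, local_batch_size, world_size, local_loss):
    """Column index of the positive pair for each row of the logits matrix.

    Assumes every process holds `local_batch_size` features and that the
    gathered (global) features are concatenated in rank order.
    """
    if local_loss:
        # Rows: this process's features only; columns: features from all ranks.
        # Local row i matches global column rank * local_batch_size + i.
        offset = rank * local_batch_size
        return [offset + i for i in range(local_batch_size)]
    # Global loss: every process builds the full (N x N) matrix, N = W * n.
    return list(range(world_size * local_batch_size))


def logits_shape(local_batch_size, world_size, local_loss):
    """Shape (rows, columns) of the similarity matrix each process computes."""
    n_global = world_size * local_batch_size
    n_rows = local_batch_size if local_loss else n_global
    return (n_rows, n_global)
```

For example, rank 1 of 2 with 4 features per process computes a (4, 8) matrix whose positives are columns 4 to 7. Because these targets depend only on the rank and batch size, they can be reused across steps, which is presumably what cache_labels is for.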
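The modality_alignment description ("same modality = positive, different modality = negative") admits several formulations. One way to express it is a supervised-contrastive (SupCon-style) loss in which modality IDs play the role of class labels. The sketch below does that on plain lists; it is an illustrative stand-in, not necessarily the formulation mmlearn uses. The function name and the modality_ids argument are hypothetical, and the features are assumed to be already normalized.

```python
import math


def modality_alignment_loss(features, modality_ids, logit_scale):
    """SupCon-style loss: same-modality pairs are positives, all others negatives."""
    n = len(features)
    total, count = 0.0, 0
    for i in range(n):
        # Similarity of anchor i to every other feature (self-pairs excluded).
        others = [j for j in range(n) if j != i]
        logits = [
            logit_scale * sum(x * y for x, y in zip(features[i], features[j]))
            for j in others
        ]
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(v - m) for v in logits))
        # Positions (within `others`) of features sharing the anchor's modality.
        positives = [k for k, j in enumerate(others) if modality_ids[j] == modality_ids[i]]
        if not positives:
            continue  # an anchor with no same-modality partner contributes nothing
        # Mean negative log-probability over this anchor's positives.
        total += sum(log_denominator - logits[k] for k in positives) / len(positives)
        count += 1
    return total / count if count else 0.0
```

The loss is small when features cluster by modality and large when features of different modalities are more similar to each other than to their own modality.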

forward(embeddings, example_ids, logit_scale, modality_loss_pairs)

Calculate the contrastive loss.

Parameters:
  • embeddings (dict[str, torch.Tensor]) – Dictionary of embeddings, where the key is the modality name and the value is the corresponding embedding tensor.

  • example_ids (dict[str, torch.Tensor]) – Dictionary of example IDs, where the key is the modality name and the value is a tensor tuple of the dataset index and the example index.

  • logit_scale (torch.Tensor) – Scale factor by which the similarity logits are multiplied.

  • modality_loss_pairs (list[LossPairSpec]) – Specification of the modality pairs for which the loss should be calculated.
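The forward signature suggests the loss is computed once per requested modality pair and then combined. A minimal sketch of that dispatch follows, assuming plain (modality_a, modality_b) tuples as a hypothetical stand-in for LossPairSpec and a simple mean over pairs. The actual method may weight or combine the pair losses differently, and it also uses example_ids, which this sketch ignores.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical signature for a two-modality loss: (feats_a, feats_b, logit_scale) -> loss
PairLoss = Callable[[Sequence, Sequence, float], float]


def pairwise_contrastive_loss(
    embeddings: Dict[str, Sequence],
    modality_pairs: List[Tuple[str, str]],
    pair_loss: PairLoss,
    logit_scale: float,
) -> float:
    """Mean of `pair_loss` over the requested (modality_a, modality_b) pairs."""
    if not modality_pairs:
        raise ValueError("at least one modality pair is required")
    losses = []
    for mod_a, mod_b in modality_pairs:
        # A missing modality raises KeyError, surfacing configuration mistakes early.
        losses.append(pair_loss(embeddings[mod_a], embeddings[mod_b], logit_scale))
    return sum(losses) / len(losses)
```

Any two-modality loss, such as a symmetric InfoNCE function, can be passed as pair_loss; keeping the pair loss pluggable separates "which modalities to compare" from "how to compare them".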

Returns:

The contrastive loss.

Return type:

torch.Tensor