mmlearn.modules.losses.contrastive
Implementations of the contrastive loss and its variants.
Classes
ContrastiveLoss | Contrastive Loss.
- class ContrastiveLoss(l2_normalize=False, local_loss=False, gather_with_grad=False, modality_alignment=False, cache_labels=False)
Contrastive Loss.
- Parameters:
l2_normalize (bool, optional, default=False) – Whether to L2-normalize the features.
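local_loss (bool, optional, default=False) – Whether to calculate the loss locally, i.e. as local_features @ global_features (the features on the current process scored against the features gathered from all processes).
gather_with_grad (bool, optional, default=False) – Whether to gather tensors from all processes with gradients attached (relevant only in distributed training).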
modality_alignment (bool, optional, default=False) – Whether to include a modality alignment loss. This loss treats all features from the same modality as positive pairs and all features from different modalities as negative pairs.
cache_labels (bool, optional, default=False) – Whether to cache the labels.
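The parameter names above mirror those of a CLIP-style symmetric contrastive (InfoNCE) loss, as in OpenCLIP's ClipLoss. The sketch below is a minimal, single-process illustration under that assumption. It is not mmlearn's implementation, and clip_style_loss, local_clip_style_loss and modality_alignment_loss are hypothetical helpers.

- local_clip_style_loss mimics local_loss=True. It scores one rank's rows against the globally gathered rows, with labels offset by rank * batch_size. The actual gathering step and gather_with_grad are omitted.
- modality_alignment_loss is one possible multi-positive formulation of the modality alignment term described above, not necessarily the one the library uses.

```python
import torch
import torch.nn.functional as F


def clip_style_loss(
    features_a: torch.Tensor,
    features_b: torch.Tensor,
    logit_scale: torch.Tensor,
    l2_normalize: bool = False,
) -> torch.Tensor:
    """Symmetric InfoNCE loss: row i of features_a and row i of features_b are positives."""
    if l2_normalize:
        features_a = F.normalize(features_a, p=2, dim=-1)
        features_b = F.normalize(features_b, p=2, dim=-1)
    logits = logit_scale * features_a @ features_b.T  # (N, N) similarity matrix
    labels = torch.arange(features_a.shape[0], device=features_a.device)
    # Average the a->b and b->a cross-entropy terms.
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2


def local_clip_style_loss(
    local_a: torch.Tensor,
    local_b: torch.Tensor,
    global_a: torch.Tensor,
    global_b: torch.Tensor,
    rank: int,
    logit_scale: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical per-rank loss: local rows scored against all gathered rows."""
    n = local_a.shape[0]
    # This rank's positives sit at offset rank * n in the gathered batch.
    labels = torch.arange(n, device=local_a.device) + n * rank
    logits_a = logit_scale * local_a @ global_b.T  # (n, N) instead of (N, N)
    logits_b = logit_scale * local_b @ global_a.T
    return (F.cross_entropy(logits_a, labels) + F.cross_entropy(logits_b, labels)) / 2


def modality_alignment_loss(
    features: dict[str, torch.Tensor], logit_scale: torch.Tensor
) -> torch.Tensor:
    """Hypothetical multi-positive loss: same-modality rows are positives.

    Assumes every modality contributes at least two rows; otherwise an anchor
    has no positives and its loss is infinite.
    """
    feats = F.normalize(torch.cat(list(features.values()), dim=0), dim=-1)
    modality = torch.cat(
        [torch.full((v.shape[0],), i, device=v.device) for i, v in enumerate(features.values())]
    )
    n = feats.shape[0]
    self_mask = torch.eye(n, dtype=torch.bool, device=feats.device)
    # Exclude each row's similarity with itself from numerator and denominator.
    logits = (logit_scale * feats @ feats.T).masked_fill(self_mask, float("-inf"))
    pos_mask = (modality[:, None] == modality[None, :]) & ~self_mask
    pos_logits = logits.masked_fill(~pos_mask, float("-inf"))
    # Per anchor: -log(sum of positive exp-logits / sum of all non-self exp-logits).
    return (torch.logsumexp(logits, dim=1) - torch.logsumexp(pos_logits, dim=1)).mean()
```

When every rank holds the same batch size, averaging the per-rank local losses recovers the global loss. In this sketch, that is what lets the local variant avoid materializing the full N x N logit matrix on each device.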
- forward(embeddings, example_ids, logit_scale, modality_loss_pairs)
Calculate the contrastive loss.
- Parameters:
embeddings (dict[str, torch.Tensor]) – Dictionary of embeddings, where the key is the modality name and the value is the corresponding embedding tensor.
example_ids (dict[str, torch.Tensor]) – Dictionary of example IDs, where the key is the modality name and the value is a tensor holding the dataset index and the example index of each example.
logit_scale (torch.Tensor) – Scale factor for the logits.
modality_loss_pairs (list[LossPairSpec]) – Specification of the modality pairs for which the loss should be calculated.
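Here is a hypothetical sketch of how a forward pass over several modality pairs could combine per-pair losses, to make the expected inputs concrete. It is not mmlearn's code, and it rests on these assumptions:

- pairwise_contrastive_loss is an illustrative name.
- Modality pairs are plain (modality_a, modality_b) tuples standing in for LossPairSpec, whose fields are not described here.
- Each example_ids value is an (N, 2) integer tensor of (dataset index, example index).
- Two rows form a positive pair exactly when those IDs match.
- Averaging over pairs is also assumed.

```python
import torch
import torch.nn.functional as F


def pairwise_contrastive_loss(
    embeddings: dict[str, torch.Tensor],
    example_ids: dict[str, torch.Tensor],
    logit_scale: torch.Tensor,
    modality_pairs: list[tuple[str, str]],  # hypothetical stand-in for list[LossPairSpec]
) -> torch.Tensor:
    """Hypothetical forward sketch: average a symmetric loss over modality pairs.

    Assumes every row has at least one matching ID in the other modality;
    otherwise its soft-target row divides by zero.
    """
    losses = []
    for mod_a, mod_b in modality_pairs:
        emb_a, emb_b = embeddings[mod_a], embeddings[mod_b]
        ids_a, ids_b = example_ids[mod_a], example_ids[mod_b]  # (N, 2) each
        # Positive mask: True where dataset index and example index both match.
        pos = (ids_a[:, None, :] == ids_b[None, :, :]).all(dim=-1).float()
        logits = logit_scale * emb_a @ emb_b.T
        # Soft targets spread probability evenly over all matching rows.
        targets_a = pos / pos.sum(dim=1, keepdim=True)
        targets_b = pos.T / pos.T.sum(dim=1, keepdim=True)
        loss = (F.cross_entropy(logits, targets_a) + F.cross_entropy(logits.T, targets_b)) / 2
        losses.append(loss)
    return torch.stack(losses).mean()
```

Under these assumptions, matching on IDs rather than on row position keeps the loss correct if one modality's batch is reordered. A plain diagonal-label loss would silently pair the wrong rows.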
- Returns:
The contrastive loss.
- Return type: