fl4health.losses.contrastive_loss module
- class MoonContrastiveLoss(device, temperature=0.5)
Bases:
Module
- __init__(device, temperature=0.5)
This contrastive loss is implemented based on https://github.com/QinbinLi/MOON. Contrastive loss aims to increase the similarity between features and their positive pairs while reducing the similarity between features and their negative pairs.
- Parameters:
device (torch.device) – device to use for computation
temperature (float) – temperature to scale the logits
- compute_negative_similarities(features, negative_pairs)
Computes the cosine similarity between each feature in the batch and the corresponding entry in each batch of negative pairs.
- Parameters:
features (torch.Tensor) – Main features, shape (batch_size, n_features)
negative_pairs (torch.Tensor) – Negative pairs of main features, shape (n_pairs, batch_size, n_features)
- Returns:
Cosine similarities between the batch of features and each batch of negative pairs, shape (n_pairs, batch_size)
- Return type:
torch.Tensor
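The shape contract above can be illustrated with a minimal sketch. It is not the library's implementation: the helper name negative_similarities is hypothetical, and the sketch assumes the similarities are obtained by broadcasting features against every batch of negative pairs with torch.nn.functional.cosine_similarity.

```python
import torch
import torch.nn.functional as F


def negative_similarities(features: torch.Tensor, negative_pairs: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, written only to illustrate the documented shapes.
    # features:       (batch_size, n_features)
    # negative_pairs: (n_pairs, batch_size, n_features)
    # unsqueeze(0) turns features into (1, batch_size, n_features) so that it
    # broadcasts against each of the n_pairs batches of negatives.
    return F.cosine_similarity(features.unsqueeze(0), negative_pairs, dim=-1)


torch.manual_seed(0)
features = torch.randn(4, 8)           # batch_size=4, n_features=8
negative_pairs = torch.randn(3, 4, 8)  # n_pairs=3
sims = negative_similarities(features, negative_pairs)
print(sims.shape)  # torch.Size([3, 4]), i.e. (n_pairs, batch_size)
```

Each entry sims[p, b] is the cosine similarity between features[b] and negative_pairs[p, b], so every value lies in [-1, 1].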
- forward(features, positive_pairs, negative_pairs)
Compute the contrastive loss from the features, their positive pairs, and their negative pairs. Each feature has exactly one positive pair but may have multiple negative pairs. The loss is computed from the similarity between each feature and its positive pair relative to its similarity with the negative pairs.
- Parameters:
features (torch.Tensor) – Main features, shape (batch_size, n_features)
positive_pairs (torch.Tensor) – Positive pair of main features, shape (1, batch_size, n_features)
negative_pairs (torch.Tensor) – Negative pairs of main features, shape (n_pairs, batch_size, n_features)
- Returns:
Contrastive loss value
- Return type:
torch.Tensor
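One way to read the computation described above is as a cross-entropy over temperature-scaled cosine similarities in which the positive pair is the correct class, following the formulation in the MOON repository. The sketch below makes that assumption; the function name moon_style_loss is hypothetical and the library's actual code may differ in detail.

```python
import torch
import torch.nn.functional as F


def moon_style_loss(
    features: torch.Tensor,
    positive_pairs: torch.Tensor,
    negative_pairs: torch.Tensor,
    temperature: float = 0.5,
) -> torch.Tensor:
    # Hypothetical sketch, not the fl4health implementation.
    # features:       (batch_size, n_features)
    # positive_pairs: (1, batch_size, n_features)
    # negative_pairs: (n_pairs, batch_size, n_features)
    anchor = features.unsqueeze(0)
    pos = F.cosine_similarity(anchor, positive_pairs, dim=-1)  # (1, batch_size)
    neg = F.cosine_similarity(anchor, negative_pairs, dim=-1)  # (n_pairs, batch_size)
    # One row of logits per sample; column 0 holds the positive similarity.
    logits = torch.cat([pos, neg], dim=0).T / temperature      # (batch_size, 1 + n_pairs)
    labels = torch.zeros(features.shape[0], dtype=torch.long, device=features.device)
    return F.cross_entropy(logits, labels)


torch.manual_seed(0)
features = torch.randn(4, 8)
positive_pairs = torch.randn(1, 4, 8)
negative_pairs = torch.randn(3, 4, 8)
loss = moon_style_loss(features, positive_pairs, negative_pairs)
print(loss.item())
```

Under this formulation the loss shrinks as features move toward their positive pair and away from their negatives, and a lower temperature sharpens that contrast.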
- class NtXentLoss(device, temperature=0.5)
Bases:
Module
- __init__(device, temperature=0.5)
Implementation of Normalized Temperature-Scaled Cross Entropy Loss (NT-Xent) proposed in https://papers.nips.cc/paper_files/paper/2016/hash/6b180037abbebea991d8b1232f8a8ca9-Abstract.html and notably used in SimCLR (https://arxiv.org/pdf/2002.05709) and FedSimCLR as proposed in Fed-X (https://arxiv.org/pdf/2207.09158).
NT-Xent is a contrastive loss in which each feature has a positive pair and the rest of the features are considered negative. It is computed based on the similarity of positive pairs relative to negative pairs.
- Parameters:
device (torch.device) – device to use for computation
temperature (float) – temperature to scale the logits
- forward(features, transformed_features)
Compute the contrastive loss based on the features and transformed_features. Given N features and N transformed_features per batch, features[i] and transformed_features[i] form a positive pair, and the remaining 2N - 2 features in the batch serve as negatives for each of them.
- Parameters:
features (torch.Tensor) – Features of input without transformation applied. Shaped (batch_size, feature_dimension).
transformed_features (torch.Tensor) – Features of input with transformation applied. Shaped (batch_size, feature_dimension).
- Returns:
Contrastive loss value
- Return type:
torch.Tensor
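The pairing scheme described above can be sketched in the style commonly used for SimCLR. This is an illustration under stated assumptions, not the library's implementation: the function name nt_xent_sketch is hypothetical, and the sketch assumes cosine similarity, exclusion of self-similarity, and a mean over all 2N anchors.

```python
import torch
import torch.nn.functional as F


def nt_xent_sketch(
    features: torch.Tensor,
    transformed_features: torch.Tensor,
    temperature: float = 0.5,
) -> torch.Tensor:
    # Hypothetical sketch, not the fl4health implementation.
    # Both inputs are shaped (batch_size, feature_dimension).
    n = features.shape[0]
    # Stack both views and L2-normalize so dot products are cosine similarities.
    z = F.normalize(torch.cat([features, transformed_features], dim=0), dim=1)  # (2N, d)
    sim = (z @ z.T) / temperature                                               # (2N, 2N)
    # Mask each sample's similarity with itself so it never counts as a candidate.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for row i is row i + N, and the positive for row i + N is row i.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Each row has 1 positive and 2N - 2 negatives among its 2N - 1 candidates.
    return F.cross_entropy(sim, targets)


torch.manual_seed(0)
features = torch.randn(4, 8)
transformed_features = features + 0.1 * torch.randn(4, 8)  # stand-in for an augmented view
loss = nt_xent_sketch(features, transformed_features)
print(loss.item())
```

Masking the diagonal with negative infinity removes each sample from its own softmax denominator, which leaves exactly the 2N - 2 negatives mentioned above alongside the single positive.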