fl4health.losses.contrastive_loss module¶
- class MoonContrastiveLoss(device, temperature=0.5)[source]¶
Bases:
Module
- __init__(device, temperature=0.5)[source]¶
This contrastive loss is implemented based on https://github.com/QinbinLi/MOON.
Contrastive loss aims to enhance the similarity between the features and their positive pairs while reducing the similarity between the features and their negative pairs.
- Parameters:
device (torch.device) – device to use for computation
temperature (float) – temperature to scale the logits
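Schematically, for a feature \(z\) with positive pair \(z^{+}\), negative pairs \(z^{-}_{k}\), cosine similarity \(\mathrm{sim}\), and temperature \(\tau\), the model-contrastive loss described in the MOON paper takes the following form (a sketch of the formulation, not necessarily the exact implementation details of this class):
\[ \ell_{con} = -\log \frac{\exp(\mathrm{sim}(z, z^{+}) / \tau)}{\exp(\mathrm{sim}(z, z^{+}) / \tau) + \sum_{k} \exp(\mathrm{sim}(z, z^{-}_{k}) / \tau)} \]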
- compute_negative_similarities(features, negative_pairs)[source]¶
This function computes the cosine similarities of the batch of features provided with the set of batches of negative pairs.
- Parameters:
features (torch.Tensor) – Main features, shape (batch_size, n_features)
negative_pairs (torch.Tensor) – Negative pairs of main features, shape (n_pairs, batch_size, n_features)
- Returns:
Cosine similarities of the batch of features provided with the set of batches of negative pairs. The shape is n_pairs x batch_size.
- Return type:
torch.Tensor
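As an illustration of the expected shapes, here is a minimal sketch of a per-pair cosine-similarity computation (a sketch only, not necessarily this method's exact implementation; tensor sizes are arbitrary):

import torch

batch_size, n_features, n_pairs = 4, 8, 2
features = torch.randn(batch_size, n_features)
negative_pairs = torch.randn(n_pairs, batch_size, n_features)

cosine_similarity = torch.nn.CosineSimilarity(dim=1)
# One (batch_size,) similarity vector per batch of negative pairs, stacked to (n_pairs, batch_size).
similarities = torch.stack([cosine_similarity(features, negative_pair) for negative_pair in negative_pairs])
print(similarities.shape)  # torch.Size([2, 4])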
- forward(features, positive_pairs, negative_pairs)[source]¶
Compute the contrastive loss based on the features, positive pair and negative pairs. While every feature has a positive pair, it can have multiple negative pairs. The loss is computed based on the similarity between the feature and its positive pair relative to negative pairs.
- Parameters:
features (torch.Tensor) – Main features, shape (batch_size, n_features)
positive_pairs (torch.Tensor) – Positive pair of main features, shape (1, batch_size, n_features)
negative_pairs (torch.Tensor) – Negative pairs of main features, shape (n_pairs, batch_size, n_features)
- Returns:
Contrastive loss value
- Return type:
torch.Tensor
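A minimal usage sketch with random tensors standing in for model features, following the documented shapes (in the MOON setting, positive pairs would typically come from, e.g., the global model and negative pairs from previous local models):

import torch
from fl4health.losses.contrastive_loss import MoonContrastiveLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_fn = MoonContrastiveLoss(device, temperature=0.5)

batch_size, n_features, n_pairs = 32, 128, 2
features = torch.randn(batch_size, n_features, device=device)
positive_pairs = torch.randn(1, batch_size, n_features, device=device)
negative_pairs = torch.randn(n_pairs, batch_size, n_features, device=device)

loss = loss_fn(features, positive_pairs, negative_pairs)  # scalar torch.Tensor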
- class NtXentLoss(device, temperature=0.5)[source]¶
Bases:
Module
- __init__(device, temperature=0.5)[source]¶
Implementation of Normalized Temperature-Scaled Cross Entropy Loss (NT-Xent) proposed in https://papers.nips.cc/paper_files/paper/2016/hash/6b180037abbebea991d8b1232f8a8ca9-Abstract.html
and notably used in:
- SimCLR (https://arxiv.org/pdf/2002.05709)
- FedSimCLR as proposed in Fed-X (https://arxiv.org/pdf/2207.09158)
NT-Xent is a contrastive loss in which each feature has a positive pair and the rest of the features are considered negative. It is computed based on the similarity of positive pairs relative to negative pairs.
- Parameters:
device (torch.device) – device to use for computation
temperature (float) – temperature to scale the logits
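For reference, the standard NT-Xent formulation from the SimCLR paper: for a positive pair \((i, j)\) among the \(2N\) features of a batch, with cosine similarity \(\mathrm{sim}\) and temperature \(\tau\),
\[ \ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k) / \tau)} \]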
- forward(features, transformed_features)[source]¶
Compute the contrastive loss based on the features and transformed_features. Given N features and N transformed_features per batch, features[i] and transformed_features[i] are positive pairs and the remaining \(2N - 2\) features are negative pairs.
- Parameters:
features (torch.Tensor) – Features of input without transformation applied. Shaped (batch_size, feature_dimension).
transformed_features (torch.Tensor) – Features of input with transformation applied. Shaped (batch_size, feature_dimension).
- Returns:
Contrastive loss value
- Return type:
torch.Tensor
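A minimal usage sketch following the documented shapes, with random tensors standing in for the features of original and transformed (e.g. augmented) inputs:

import torch
from fl4health.losses.contrastive_loss import NtXentLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nt_xent = NtXentLoss(device, temperature=0.5)

batch_size, feature_dimension = 64, 128
features = torch.randn(batch_size, feature_dimension, device=device)
transformed_features = torch.randn(batch_size, feature_dimension, device=device)

loss = nt_xent(features, transformed_features)  # scalar torch.Tensor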