fl4health.losses.contrastive_loss module

class MoonContrastiveLoss(device, temperature=0.5)

Bases: Module

__init__(device, temperature=0.5)

This contrastive loss is based on the implementation at https://github.com/QinbinLi/MOON. It aims to increase the similarity between features and their positive pairs while decreasing the similarity between features and their negative pairs.

Parameters:
  • device (torch.device) – device to use for computation

  • temperature (float) – temperature used to scale the logits

compute_negative_similarities(features, negative_pairs)

Computes the cosine similarity between each feature in the batch and the corresponding feature in each of the n_pairs batches of negative pairs.

Parameters:
  • features (torch.Tensor) – Main features, shape (batch_size, n_features)

  • negative_pairs (torch.Tensor) – Negative pairs of main features, shape (n_pairs, batch_size, n_features)

Returns:

Cosine similarities between the features and each batch of negative pairs, with shape (n_pairs, batch_size).

Return type:

torch.Tensor
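As a minimal sketch of this computation (a hypothetical standalone function named negative_similarities, not the library method itself), the per-batch loop can be vectorized by broadcasting the features across the n_pairs dimension. Each feature is compared only with the row at the same batch index in every negative batch, which is why the result has shape (n_pairs, batch_size) rather than being a full similarity matrix.

```python
import torch
import torch.nn.functional as F


def negative_similarities(features: torch.Tensor, negative_pairs: torch.Tensor) -> torch.Tensor:
    """Hypothetical standalone version of the negative-similarity computation.

    features:       (batch_size, n_features)
    negative_pairs: (n_pairs, batch_size, n_features)
    returns:        (n_pairs, batch_size)
    """
    # View the features as if repeated along a new leading n_pairs dimension (no copy is made).
    expanded = features.unsqueeze(0).expand_as(negative_pairs)
    # Row-wise cosine similarity along the feature dimension.
    return F.cosine_similarity(expanded, negative_pairs, dim=-1)


features = torch.randn(4, 8)           # batch_size = 4, n_features = 8
negative_pairs = torch.randn(3, 4, 8)  # n_pairs = 3
sims = negative_similarities(features, negative_pairs)
print(sims.shape)  # torch.Size([3, 4])
```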

forward(features, positive_pairs, negative_pairs)

Compute the contrastive loss from the features, their positive pairs, and their negative pairs. Each feature has exactly one positive pair but may have multiple negative pairs. The loss depends on the similarity between each feature and its positive pair relative to its similarities with the negative pairs.
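Written out, and assuming the batch-averaged cross-entropy form used in the MOON reference code, the loss for a batch of B features z_i, each with a positive pair z_i^+ and K = n_pairs negative pairs z_{i,k}^-, can be expressed as follows, where sim is cosine similarity and τ is the temperature:

```latex
\mathcal{L} = -\frac{1}{B} \sum_{i=1}^{B} \log
\frac{\exp\left(\operatorname{sim}(z_i, z_i^{+}) / \tau\right)}
     {\exp\left(\operatorname{sim}(z_i, z_i^{+}) / \tau\right)
      + \sum_{k=1}^{K} \exp\left(\operatorname{sim}(z_i, z_{i,k}^{-}) / \tau\right)}
```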

Parameters:
  • features (torch.Tensor) – Main features, shape (batch_size, n_features)

  • positive_pairs (torch.Tensor) – Positive pair of main features, shape (1, batch_size, n_features)

  • negative_pairs (torch.Tensor) – Negative pairs of main features, shape (n_pairs, batch_size, n_features)

Returns:

Contrastive loss value

Return type:

torch.Tensor
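The following is a minimal sketch, not the fl4health source, of how such a loss can be computed. It assumes, as in the MOON reference code, that the positive similarity and the negative similarities are concatenated into per-sample logits, divided by the temperature, and scored with cross-entropy against label 0 (the position of the positive pair). The function name moon_contrastive_loss is hypothetical.

```python
import torch
import torch.nn.functional as F


def moon_contrastive_loss(
    features: torch.Tensor,
    positive_pairs: torch.Tensor,
    negative_pairs: torch.Tensor,
    temperature: float = 0.5,
) -> torch.Tensor:
    """Hypothetical sketch of a MOON-style contrastive loss.

    features:       (batch_size, n_features)
    positive_pairs: (1, batch_size, n_features)
    negative_pairs: (n_pairs, batch_size, n_features)
    """
    # Similarity with the single positive pair: (batch_size,)
    positive_sim = F.cosine_similarity(features, positive_pairs[0], dim=-1)
    # Similarity with every negative pair: (n_pairs, batch_size)
    negative_sim = F.cosine_similarity(
        features.unsqueeze(0).expand_as(negative_pairs), negative_pairs, dim=-1
    )
    # Per-sample logits with the positive in column 0: (batch_size, 1 + n_pairs)
    logits = torch.cat([positive_sim.unsqueeze(1), negative_sim.T], dim=1) / temperature
    # The target "class" for every sample is the positive pair at index 0.
    labels = torch.zeros(features.size(0), dtype=torch.long, device=features.device)
    # Cross-entropy averages the per-sample loss over the batch.
    return F.cross_entropy(logits, labels)


features = torch.randn(4, 8)
loss = moon_contrastive_loss(features, torch.randn(1, 4, 8), torch.randn(2, 4, 8))
print(loss.item())
```

Treating the positive pair as the correct class of a (1 + n_pairs)-way classification problem means the loss shrinks as the positive similarity grows relative to the negative ones, and the temperature controls how sharply that difference is rewarded.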

class NtXentLoss(device, temperature=0.5)

Bases: Module

__init__(device, temperature=0.5)

Implementation of Normalized Temperature-Scaled Cross Entropy Loss (NT-Xent) proposed in https://papers.nips.cc/paper_files/paper/2016/hash/6b180037abbebea991d8b1232f8a8ca9-Abstract.html and notably used in SimCLR (https://arxiv.org/pdf/2002.05709) and FedSimCLR as proposed in Fed-X (https://arxiv.org/pdf/2207.09158).

NT-Xent is a contrastive loss in which each feature has a positive pair and the rest of the features are considered negative. It is computed based on the similarity of positive pairs relative to negative pairs.

Parameters:
  • device (torch.device) – device to use for computation

  • temperature (float) – temperature to scale the logits
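For reference, the standard formulation from SimCLR stacks the two views into 2N features z_1, ..., z_{2N}, lets p(i) be the index of the other view of the same sample, and averages over all 2N anchors. This assumes cosine similarity for sim and temperature τ; it is the textbook form, not a transcription of this module's code.

```latex
\ell_{i,\,p(i)} = -\log
\frac{\exp\left(\operatorname{sim}(z_i, z_{p(i)}) / \tau\right)}
     {\sum_{k=1}^{2N} \mathbb{1}[k \neq i] \, \exp\left(\operatorname{sim}(z_i, z_k) / \tau\right)},
\qquad
\mathcal{L} = \frac{1}{2N} \sum_{i=1}^{2N} \ell_{i,\,p(i)}
```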

forward(features, transformed_features)

Compute the contrastive loss from features and transformed_features. Given N features and N transformed_features per batch, features[i] and transformed_features[i] form a positive pair, and the remaining 2N - 2 features in the combined batch serve as negatives for that pair.

Parameters:
  • features (torch.Tensor) – Features of input without transformation applied. Shaped (batch_size, feature_dimension).

  • transformed_features (torch.Tensor) – Features of input with transformation applied. Shaped (batch_size, feature_dimension).

Returns:

Contrastive loss value

Return type:

torch.Tensor
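A minimal sketch of the standard SimCLR-style NT-Xent computation follows. It is not necessarily identical to the fl4health implementation, and nt_xent_loss is a hypothetical name. Both views are stacked and L2-normalized so that dot products are cosine similarities, self-similarity is masked out, and the other view of the same sample is used as the target class.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(
    features: torch.Tensor, transformed_features: torch.Tensor, temperature: float = 0.5
) -> torch.Tensor:
    """Hypothetical NT-Xent sketch for two (batch_size, feature_dimension) views."""
    n = features.size(0)
    # Stack both views and L2-normalize each row: (2N, feature_dimension)
    z = F.normalize(torch.cat([features, transformed_features], dim=0), dim=1)
    # Pairwise cosine similarities scaled by the temperature: (2N, 2N)
    logits = z @ z.T / temperature
    # Remove each sample's similarity with itself from its softmax denominator.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Row i < N has its positive at i + N; row i + N has its positive at i.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Cross-entropy averages the per-anchor loss over all 2N rows.
    return F.cross_entropy(logits, targets)


features = torch.randn(8, 16)
transformed_features = features + 0.1 * torch.randn(8, 16)
loss = nt_xent_loss(features, transformed_features)
print(loss.item())
```

Masking the diagonal with -inf leaves each row with exactly one positive and 2N - 2 negatives in the softmax, matching the pairing described for forward.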