Source code for fl4health.losses.contrastive_loss

import torch
import torch.nn as nn
import torch.nn.functional as F


class MoonContrastiveLoss(nn.Module):
    def __init__(
        self,
        device: torch.device,
        temperature: float = 0.5,
    ) -> None:
        """
        This contrastive loss is implemented based on https://github.com/QinbinLi/MOON. Contrastive loss aims to
        enhance the similarity between the features and their positive pairs while reducing the similarity between
        the features and their negative pairs.

        Args:
            device (torch.device): Device to use for computation.
            temperature (float): Temperature to scale the logits.
        """
        super().__init__()
        self.device = device
        self.temperature = temperature
        self.cosine_similarity_function = torch.nn.CosineSimilarity(dim=-1).to(self.device)
        self.cross_entropy_function = torch.nn.CrossEntropyLoss().to(self.device)
    def compute_negative_similarities(self, features: torch.Tensor, negative_pairs: torch.Tensor) -> torch.Tensor:
        """
        This function computes the cosine similarities of the batch of features provided with the set of batches
        of negative pairs.

        Args:
            features (torch.Tensor): Main features, shape (batch_size, n_features).
            negative_pairs (torch.Tensor): Negative pairs of the main features, shape
                (n_pairs, batch_size, n_features).

        Returns:
            torch.Tensor: Cosine similarities of the batch of features provided with the set of batches of
            negative pairs. The shape is (n_pairs, batch_size).
        """
        # Check that features and each of the negative pairs have the same shape
        assert features.shape == negative_pairs.shape[1:]
        # Repeat the features tensor so its similarity with all negative pairs can be computed in one call.
        repeated_features = features.unsqueeze(0).repeat(len(negative_pairs), 1, 1)
        return self.cosine_similarity_function(repeated_features, negative_pairs)
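To make the broadcasting above concrete, here is a minimal sketch (not part of the module source) showing the shape flow; the tensor sizes are purely illustrative:

    import torch

    cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
    features = torch.randn(4, 16)           # (batch_size, n_features)
    negative_pairs = torch.randn(3, 4, 16)  # (n_pairs, batch_size, n_features)

    # Repeating the features along a new leading dimension aligns them with every batch of
    # negatives, so one CosineSimilarity call produces all similarities at once.
    repeated = features.unsqueeze(0).repeat(len(negative_pairs), 1, 1)  # (3, 4, 16)
    print(cosine_similarity(repeated, negative_pairs).shape)  # torch.Size([3, 4])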
    def forward(
        self, features: torch.Tensor, positive_pairs: torch.Tensor, negative_pairs: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the contrastive loss based on the features, positive pair, and negative pairs. While every feature
        has exactly one positive pair, it can have multiple negative pairs. The loss is computed based on the
        similarity between each feature and its positive pair relative to its negative pairs.

        Args:
            features (torch.Tensor): Main features, shape (batch_size, n_features).
            positive_pairs (torch.Tensor): Positive pair of the main features, shape (1, batch_size, n_features).
            negative_pairs (torch.Tensor): Negative pairs of the main features, shape
                (n_pairs, batch_size, n_features).

        Returns:
            torch.Tensor: Contrastive loss value.
        """
        # TODO: We can extend this to support multiple positive pairs using multi-label classification
        features = features.to(self.device)
        positive_pairs = positive_pairs.to(self.device)
        negative_pairs = negative_pairs.to(self.device)
        if len(positive_pairs) != 1:
            raise AssertionError(
                "Each feature can have only one positive pair. "
                "Thus positive pairs should be a tensor of shape (1, batch_size, n_features) "
                f"rather than {positive_pairs.shape}"
            )
        positive_pair = positive_pairs[0]
        assert len(features) == len(positive_pair)

        # Compute similarity of the batch of features with the provided batch of positive pair features
        positive_similarity = self.cosine_similarity_function(features, positive_pair)
        # Store similarities with shape (batch_size, 1)
        logits = positive_similarity.reshape(-1, 1)

        # Compute the similarity of the batch of features with the collection of batches of negative pair
        # features. The resulting tensor has shape (n_pairs, batch_size).
        negative_pair_similarities = self.compute_negative_similarities(features, negative_pairs)
        # Transpose so each row holds a feature's positive similarity followed by its negative similarities.
        logits = torch.cat((logits, negative_pair_similarities.T), dim=1)
        logits /= self.temperature

        # The positive pair sits in column 0, so the "correct class" for every feature is index 0.
        labels = torch.zeros(features.size(0)).to(self.device).long()
        return self.cross_entropy_function(logits, labels)
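A minimal usage sketch (not part of the module source): in MOON, the positive pair typically comes from the global model's features and the negative pairs from previous local models' features for the same inputs; the random tensors and sizes below are purely illustrative.

    import torch

    from fl4health.losses.contrastive_loss import MoonContrastiveLoss

    loss_module = MoonContrastiveLoss(device=torch.device("cpu"), temperature=0.5)
    batch_size, n_features = 8, 128

    features = torch.randn(batch_size, n_features)               # current local model features
    positive_pairs = torch.randn(1, batch_size, n_features)      # e.g. global model features
    negative_pairs = torch.randn(2, batch_size, n_features)      # e.g. previous local model features

    loss = loss_module(features, positive_pairs, negative_pairs)
    print(loss)  # scalar tensor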
class NtXentLoss(nn.Module):
    def __init__(self, device: torch.device, temperature: float = 0.5) -> None:
        """
        Implementation of the Normalized Temperature-Scaled Cross Entropy Loss (NT-Xent) proposed in
        https://papers.nips.cc/paper_files/paper/2016/hash/6b180037abbebea991d8b1232f8a8ca9-Abstract.html and
        notably used in SimCLR (https://arxiv.org/pdf/2002.05709) and in FedSimCLR as proposed in Fed-X
        (https://arxiv.org/pdf/2207.09158). NT-Xent is a contrastive loss in which each feature has a positive
        pair and the rest of the features are considered negative. It is computed based on the similarity of
        positive pairs relative to negative pairs.

        Args:
            device (torch.device): Device to use for computation.
            temperature (float): Temperature to scale the logits.
        """
        super().__init__()
        self.device = device
        self.temperature = temperature
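As a point of reference (not part of the module source), the forward pass below computes the NT-Xent objective as defined in SimCLR, where the z_i are the normalized features, sim is cosine similarity, tau is the temperature, and N is the batch size:

    \ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j) / \tau\right)}
        {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k) / \tau\right)},
    \qquad
    \mathcal{L} = \frac{1}{2N} \sum_{\text{positive pairs } (i, j)} \ell_{i,j}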
    def forward(self, features: torch.Tensor, transformed_features: torch.Tensor) -> torch.Tensor:
        """
        Compute the contrastive loss based on the features and transformed_features. Given N features and N
        transformed_features per batch, features[i] and transformed_features[i] form a positive pair and the
        remaining 2N - 2 features are treated as negatives.

        Args:
            features (torch.Tensor): Features of the input without the transformation applied, shape
                (batch_size, feature_dimension).
            transformed_features (torch.Tensor): Features of the input with the transformation applied, shape
                (batch_size, feature_dimension).

        Returns:
            torch.Tensor: Contrastive loss value.
        """
        features = features.to(self.device)
        transformed_features = transformed_features.to(self.device)

        # Ensure features and transformed_features have the same shape
        assert features.shape == transformed_features.shape
        batch_size = features.shape[0]

        # Concatenate features and transformed features and normalize each feature to unit euclidean norm.
        all_features = torch.concatenate([features, transformed_features], dim=0)
        all_features = F.normalize(all_features, dim=-1)

        # Compute the similarity of each feature with every other feature. This is equivalent to cosine
        # similarity since the features are normalized.
        similarity_matrix = torch.matmul(all_features, all_features.T)

        # Extract positive pairs from the similarity matrix. Because features and transformed_features are
        # stacked, positive pairs are the elements (i, j) offset from the main diagonal by batch_size.
        similarity_ij = torch.diag(similarity_matrix, diagonal=batch_size)
        similarity_ji = torch.diag(similarity_matrix, diagonal=-batch_size)
        positives = torch.concatenate([similarity_ij, similarity_ji], dim=0)

        # The numerator is the exponential of the positive similarities
        numerator = torch.exp(positives / self.temperature)

        # The denominator sums over all pair combinations except the diagonal, which corresponds to a feature's
        # similarity with itself. The mask is applied after exponentiation so the diagonal terms (which would
        # otherwise each contribute exp(0) = 1 to their row's sum) are truly excluded.
        mask = (torch.ones(2 * batch_size, 2 * batch_size) - torch.eye(2 * batch_size)).to(self.device)
        denominator = torch.exp(similarity_matrix / self.temperature) * mask

        # Final loss is the negative log likelihood of each positive pair
        losses = -torch.log(numerator / denominator.sum(dim=1))
        # Divide by 2 * batch_size because pairs are double counted due to the symmetry of the similarity matrix
        loss = torch.sum(losses) / (2 * batch_size)
        return loss
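A minimal usage sketch (not part of the module source): in SimCLR-style training, features and transformed_features would come from an encoder applied to a batch and its augmented counterpart; the random tensors below are purely illustrative.

    import torch

    from fl4health.losses.contrastive_loss import NtXentLoss

    loss_module = NtXentLoss(device=torch.device("cpu"), temperature=0.5)

    features = torch.randn(8, 128)              # encoder output for the original batch
    transformed_features = torch.randn(8, 128)  # encoder output for the augmented batch

    loss = loss_module(features, transformed_features)
    print(loss)  # scalar tensor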