Source code for fl4health.losses.cosine_similarity_loss
import torch
import torch.nn as nn
[docs]
class CosineSimilarityLoss(nn.Module):
    """Loss equal to the mean absolute cosine similarity between two tensors."""

    def __init__(self, device: torch.device, dim: int = -1) -> None:
        """
        Cosine similarity loss between two torch Tensors

        Args:
            device (torch.device): Device that the wrapped ``nn.CosineSimilarity`` module is moved to.
            dim (int, optional): Dimension along which cosine similarity is computed. Defaults to -1.
        """
        super().__init__()
        similarity_module = nn.CosineSimilarity(dim=dim)
        # Stored under this attribute name so existing references to it keep working.
        self.cosine_similarity_function = similarity_module.to(device)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Average the absolute cosine similarity between ``x1`` and ``x2``.

        The inputs are expected "batch first". The leading dimensions of the two
        inputs must match. Cosine similarity is taken along the ``dim`` given at
        construction. The absolute values of the results are then averaged over
        every remaining element. For 2-D inputs with ``dim=-1``, that is the mean
        over the batch.

        Args:
            x1 (torch.Tensor): First set of tensors to compute cosine sim, shape (``batch_size``, ``n_features``)
            x2 (torch.Tensor): Second set of tensors to compute cosine sim, shape (``batch_size``, ``n_features``)

        Returns:
            torch.Tensor: Scalar tensor with the mean absolute cosine similarity across the mutual batch size.

        Raises:
            AssertionError: If the leading (batch) dimensions of ``x1`` and ``x2`` differ.
        """
        batch_size_1, batch_size_2 = len(x1), len(x2)
        assert batch_size_1 == batch_size_2, "Tensors have different batch sizes"
        similarities = self.cosine_similarity_function(x1, x2)
        # The sign is discarded, so anti-aligned vectors count the same as aligned ones.
        return similarities.abs().mean()