fl4health.losses.cosine_similarity_loss module

class CosineSimilarityLoss(device, dim=-1)

Bases: Module

forward(x1, x2)

Assumes that the tensors are provided “batch first” and computes the mean, over the batch, of the absolute value of the cosine similarity between paired feature vectors (rows) of x1 and x2.
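Written out, let $B$ be the batch size and let $x_{1,i}$, $x_{2,i}$ be the $i$-th rows of x1 and x2. Assuming the default dim=-1, so that the similarity is taken along the feature dimension, the quantity described above is:

```latex
\mathcal{L}(x_1, x_2) = \frac{1}{B} \sum_{i=1}^{B}
  \left| \frac{x_{1,i} \cdot x_{2,i}}{\lVert x_{1,i} \rVert_2 \, \lVert x_{2,i} \rVert_2} \right|
```

Because of the absolute value, the result lies in $[0, 1]$. Both perfectly aligned and perfectly anti-aligned pairs contribute 1, while orthogonal pairs contribute 0.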

Parameters:
  • x1 (torch.Tensor) – First set of tensors used to compute the cosine similarity, shape (batch_size, n_features)

  • x2 (torch.Tensor) – Second set of tensors used to compute the cosine similarity, shape (batch_size, n_features)

Returns:

Mean, over the shared batch dimension, of the absolute cosine similarity between paired vectors of x1 and x2, returned as a scalar tensor.

Return type:

torch.Tensor
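The following is a minimal, dependency-free reference sketch of what forward is documented to compute. It could be used to sanity-check values by hand. The helper names cosine_similarity and mean_abs_cosine_similarity are hypothetical, and so is the choice to treat a zero-norm vector as having similarity 0. This code is not the FL4Health implementation.

```python
import math


def cosine_similarity(u: list[float], v: list[float]) -> float:
    """Cosine similarity between two equal-length feature vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    # Assumption for this sketch: a zero vector has similarity 0 with anything.
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def mean_abs_cosine_similarity(x1: list[list[float]], x2: list[list[float]]) -> float:
    """Mean over the batch of |cos(x1[i], x2[i])|; rows are samples ("batch first")."""
    if len(x1) != len(x2):
        raise ValueError("x1 and x2 must have the same batch size")
    if not x1:
        raise ValueError("the batch must be non-empty")
    # One similarity per row pair, taken along the feature dimension.
    sims = [abs(cosine_similarity(a, b)) for a, b in zip(x1, x2)]
    return sum(sims) / len(sims)


if __name__ == "__main__":
    # Row 0 is aligned (cos = 1) and row 1 is anti-aligned (cos = -1), so the mean of |cos| is 1.0.
    print(mean_abs_cosine_similarity([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, -1.0]]))
```

In PyTorch, a comparable expression is torch.nn.functional.cosine_similarity(x1, x2, dim=-1).abs().mean(). It is shown here only for comparison, not as the class's actual code.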