fl4health.losses.cosine_similarity_loss module
- class CosineSimilarityLoss(device, dim=-1)
Bases: torch.nn.Module
- __init__(device, dim=-1)
Cosine similarity loss between two torch Tensors.
- Parameters:
device (torch.device) – Device on which this loss should be computed.
dim (int, optional) – Dimension along which cosine similarity is computed. Defaults to -1.
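With the default dim=-1 and batch-first inputs of shape (batch_size, n_features), each sample's feature vector in one input is compared with the matching sample's vector in the other. The pure-Python sketch below works through that arithmetic on a tiny batch; the helper names and the eps guard are hypothetical illustrations, not part of fl4health.

```python
import math


def cosine_similarity(u: list[float], v: list[float], eps: float = 1e-8) -> float:
    """Cosine similarity (u . v) / (|u| |v|); eps (an assumption) avoids division by zero."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / max(norm_u * norm_v, eps)


def mean_abs_cosine_similarity(x1: list[list[float]], x2: list[list[float]]) -> float:
    """Mean over the batch of |cos(x1[i], x2[i])|, comparing rows (the dim=-1 case)."""
    if len(x1) != len(x2):
        raise ValueError("x1 and x2 must have the same batch size")
    sims = [abs(cosine_similarity(a, b)) for a, b in zip(x1, x2)]
    return sum(sims) / len(sims)


# Batch of 3 samples with 2 features each: shape (batch_size=3, n_features=2).
x1 = [[1.0, 0.0], [2.0, 0.0], [3.0, 4.0]]
x2 = [[0.0, 1.0], [-3.0, 0.0], [3.0, 4.0]]
result = mean_abs_cosine_similarity(x1, x2)
print(result)
```

The per-sample similarities are 0 (orthogonal), -1 (opposite directions) and 1 (identical), so after taking absolute values the mean is 2/3. Anti-aligned features are therefore penalized just as much as aligned ones.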
- forward(x1, x2)
Assumes that the tensors are provided “batch first” and computes the mean, over the batch, of the absolute value of the cosine similarity between the feature vectors in x1 and x2.
- Parameters:
x1 (torch.Tensor) – First set of tensors to compute cosine similarity on, shape (batch_size, n_features).
x2 (torch.Tensor) – Second set of tensors to compute cosine similarity on, shape (batch_size, n_features).
- Returns:
Mean absolute value of the cosine similarity between corresponding vectors in x1 and x2, taken over their shared batch dimension.
- Return type:
torch.Tensor
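One way the documented behaviour could be reproduced is with torch.nn.functional.cosine_similarity followed by an absolute value and a batch mean. The sketch below is an assumption, not the fl4health source: the class name CosineSimilarityLossSketch is hypothetical, and moving the inputs onto device is a guess at how the device argument is used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineSimilarityLossSketch(nn.Module):
    """Hypothetical sketch of the documented loss; not the fl4health implementation."""

    def __init__(self, device: torch.device, dim: int = -1) -> None:
        super().__init__()
        self.device = device
        self.dim = dim

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # Inputs are assumed batch first: (batch_size, n_features).
        assert x1.shape[0] == x2.shape[0], "x1 and x2 must share the batch size"
        # Assumption: the device argument is used to place the inputs.
        x1 = x1.to(self.device)
        x2 = x2.to(self.device)
        # Per-sample cosine similarity along self.dim -> shape (batch_size,).
        cos_sim = F.cosine_similarity(x1, x2, dim=self.dim)
        # Mean over the batch of the absolute similarity -> scalar tensor.
        return torch.abs(cos_sim).mean()


loss_fn = CosineSimilarityLossSketch(device=torch.device("cpu"))
x1 = torch.tensor([[1.0, 0.0], [2.0, 0.0], [3.0, 4.0]])
x2 = torch.tensor([[0.0, 1.0], [-3.0, 0.0], [3.0, 4.0]])
loss = loss_fn(x1, x2)
print(loss.item())
```

Here the three pairs are orthogonal, opposite and identical, so the loss is approximately 2/3. Because the absolute value is taken per sample before averaging, positive and negative similarities cannot cancel each other out.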