mmlearn.modules.losses.data2vec
Implementation of the Data2Vec loss function.
Classes
Data2VecLoss | Data2Vec loss function.
- class Data2VecLoss(beta=0, loss_scale=None, reduction='none')
Data2Vec loss function.
- Parameters:
beta (float, optional, default=0) – Specifies the beta parameter for smooth L1 loss. If 0, MSE loss is used instead.
loss_scale (Optional[float], optional, default=None) – Scaling factor applied to the loss. If None, 1 / sqrt(embedding_dim) is used, where embedding_dim is the size of the last dimension of the inputs.
reduction (str, optional, default='none') – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
- Raises:
ValueError – If the reduction mode is not supported.
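To make the parameter semantics concrete, here is a minimal pure-Python sketch of the per-element loss for a single embedding component. It assumes that the smooth L1 branch follows PyTorch's definition of `smooth_l1_loss` (quadratic below `beta`, linear above it). The helpers `scaled_elementwise_loss` and `check_reduction` are hypothetical illustrations, not part of mmlearn.

```python
import math
from typing import Optional

SUPPORTED_REDUCTIONS = ("none", "mean", "sum")


def check_reduction(reduction: str) -> None:
    """Hypothetical helper: reject reduction modes other than the documented ones."""
    if reduction not in SUPPORTED_REDUCTIONS:
        raise ValueError(f"Unsupported reduction mode: {reduction!r}")


def scaled_elementwise_loss(
    pred: float,
    target: float,
    embedding_dim: int,
    beta: float = 0.0,
    loss_scale: Optional[float] = None,
) -> float:
    """Hypothetical helper: loss for one embedding component, following the documented parameters."""
    diff = abs(pred - target)
    if beta == 0:
        # beta == 0: plain squared error (the per-element term of MSE).
        loss = diff ** 2
    elif diff < beta:
        # Smooth L1, quadratic region (assumed to match PyTorch's definition).
        loss = 0.5 * diff ** 2 / beta
    else:
        # Smooth L1, linear region.
        loss = diff - 0.5 * beta
    # When loss_scale is None, fall back to 1 / sqrt(embedding_dim).
    scale = loss_scale if loss_scale is not None else 1.0 / math.sqrt(embedding_dim)
    return loss * scale


print(scaled_elementwise_loss(1.0, 0.0, embedding_dim=4))  # → 0.5
```

With embedding_dim=4 the default scale is 1 / sqrt(4) = 0.5, so a squared error of 1.0 becomes 0.5. Scaling by the inverse square root of the embedding size keeps the loss magnitude from growing with wider embeddings.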
- forward(x, y)
Compute the Data2Vec loss.
- Parameters:
x (torch.Tensor) – Predicted embeddings of shape (batch_size, num_patches, embedding_dim).
y (torch.Tensor) – Target embeddings of shape (batch_size, num_patches, embedding_dim).
- Returns:
Data2Vec loss value.
- Return type:
torch.Tensor
- Raises:
ValueError – If the shapes of x and y do not match.
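The documented behavior of `forward` can be sketched as a small PyTorch module. `Data2VecLossSketch` is a hypothetical stand-in, not the mmlearn source. It assumes the loss is computed with `torch.nn.functional.mse_loss` / `smooth_l1_loss` and that `reduction='none'` returns the element-wise loss unchanged; the real class may reduce differently in that mode.

```python
import math
from typing import Optional

import torch
import torch.nn.functional as F


class Data2VecLossSketch(torch.nn.Module):
    """Hypothetical stand-in for Data2VecLoss; not the mmlearn implementation."""

    def __init__(
        self,
        beta: float = 0.0,
        loss_scale: Optional[float] = None,
        reduction: str = "none",
    ) -> None:
        super().__init__()
        # Reject unsupported reduction modes at construction time.
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(f"Unsupported reduction mode: {reduction!r}")
        self.beta = beta
        self.loss_scale = loss_scale
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # x and y: (batch_size, num_patches, embedding_dim); shapes must match.
        if x.shape != y.shape:
            raise ValueError(f"Shape mismatch: x {tuple(x.shape)} vs y {tuple(y.shape)}")

        # beta == 0 selects MSE; otherwise smooth L1 with the given beta.
        if self.beta == 0:
            loss = F.mse_loss(x, y, reduction="none")
        else:
            loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.beta)

        # Fall back to 1 / sqrt(embedding_dim) when no loss_scale is given.
        scale = self.loss_scale if self.loss_scale is not None else 1.0 / math.sqrt(x.size(-1))
        loss = loss * scale

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # Assumption: 'none' keeps the element-wise loss, same shape as x.
        return loss


# Usage: batch_size=2, num_patches=3, embedding_dim=4.
pred = torch.randn(2, 3, 4)
target = torch.randn(2, 3, 4)
loss = Data2VecLossSketch(beta=1.0, reduction="mean")(pred, target)
```

Checking shapes before computing anything turns a silent broadcast between mismatched embeddings into an explicit error.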