mmlearn.modules.losses.data2vec

Implementation of the Data2Vec loss function.

Classes

Data2VecLoss

Data2Vec loss function.

class Data2VecLoss(beta=0, loss_scale=None, reduction='none')

Data2Vec loss function.

Parameters:
  • beta (float, optional, default=0) – Beta parameter for the smooth L1 loss, i.e. the threshold at which that loss changes from quadratic to linear. If 0, MSE loss is used instead.

  • loss_scale (Optional[float], optional, default=None) – Scaling factor for the loss. If None, uses 1 / sqrt(embedding_dim).

  • reduction (str, optional, default='none') – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.

Raises:

ValueError – If reduction is not one of 'none', 'mean', or 'sum'.
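
The parameter semantics above can be illustrated with a minimal pure-Python sketch. It is not the library's code: the helper names (elementwise_loss, resolve_scale, check_reduction) are hypothetical, and it assumes the smooth L1 branch follows the standard definition (quadratic below beta, linear above) and that the scale is applied as a multiplier.

```python
import math

# The three reduction modes listed in the parameter description.
SUPPORTED_REDUCTIONS = ("none", "mean", "sum")


def elementwise_loss(diff: float, beta: float = 0.0) -> float:
    """Loss for a single prediction/target difference.

    Assumption: beta == 0 selects squared error (MSE per element);
    otherwise the standard smooth L1 definition is used.
    """
    if beta == 0:
        return diff * diff
    abs_diff = abs(diff)
    if abs_diff < beta:
        # Quadratic region near zero.
        return 0.5 * abs_diff * abs_diff / beta
    # Linear region for large errors.
    return abs_diff - 0.5 * beta


def resolve_scale(loss_scale, embedding_dim: int) -> float:
    """Documented default: 1 / sqrt(embedding_dim) when loss_scale is None."""
    if loss_scale is None:
        return 1.0 / math.sqrt(embedding_dim)
    return float(loss_scale)


def check_reduction(reduction: str) -> str:
    """Reject any mode outside the three documented ones."""
    if reduction not in SUPPORTED_REDUCTIONS:
        raise ValueError(f"Unsupported reduction mode: {reduction!r}")
    return reduction
```

With these helpers, a difference of 2.0 costs 4.0 under beta=0 and 1.5 under beta=1.0, and an embedding dimension of 64 yields a default scale of 0.125.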

forward(x, y)

Compute the Data2Vec loss.

Parameters:
  • x (torch.Tensor) – Predicted embeddings of shape (batch_size, num_patches, embedding_dim).

  • y (torch.Tensor) – Target embeddings of shape (batch_size, num_patches, embedding_dim).

Returns:

Data2Vec loss value.

Return type:

torch.Tensor

Raises:

ValueError – If the shapes of x and y do not match.
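
As a rough illustration of the forward contract described above, the following functional sketch reproduces the documented behavior with plain PyTorch calls. The function name data2vec_loss_sketch is hypothetical, and two details are assumptions rather than documented facts: the scale is applied by multiplication, and reduction='none' returns a tensor with the same shape as the inputs.

```python
import math
from typing import Optional

import torch
import torch.nn.functional as F


def data2vec_loss_sketch(
    x: torch.Tensor,
    y: torch.Tensor,
    beta: float = 0.0,
    loss_scale: Optional[float] = None,
    reduction: str = "none",
) -> torch.Tensor:
    """Sketch of the documented behavior; not the library implementation."""
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"Unsupported reduction mode: {reduction!r}")
    if x.shape != y.shape:
        raise ValueError(f"Shape mismatch: {tuple(x.shape)} vs {tuple(y.shape)}")

    x, y = x.float(), y.float()
    if beta == 0:
        # Documented: beta == 0 falls back to MSE.
        loss = F.mse_loss(x, y, reduction="none")
    else:
        loss = F.smooth_l1_loss(x, y, reduction="none", beta=beta)

    # Documented default scale: 1 / sqrt(embedding_dim).
    scale = loss_scale if loss_scale is not None else 1.0 / math.sqrt(x.size(-1))
    loss = loss * scale  # assumption: the scale is a multiplier

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # assumption: unreduced loss keeps the input shape


# Example call with (batch_size, num_patches, embedding_dim) = (2, 5, 16).
pred = torch.randn(2, 5, 16)
target = torch.randn(2, 5, 16)
mean_loss = data2vec_loss_sketch(pred, target, beta=2.0, reduction="mean")
```

In a training loop the 'mean' or 'sum' result would typically be the scalar passed to backward(); the unreduced form is useful when only a subset of patches (for example, masked ones) should contribute.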