mmlearn.modules.metrics.retrieval_recall

Retrieval Recall@K metric.

Classes

RetrievalRecallAtK

Retrieval Recall@K metric.

class RetrievalRecallAtK(top_k, reduction=None, aggregation='mean', **kwargs)

Retrieval Recall@K metric.

Computes the Recall@K for retrieval tasks. The metric is computed as follows:

  1. Compute the cosine similarity between the query and the database.

  2. For each query, sort the database in decreasing order of similarity.

  3. Compute the Recall@K as the number of true positives among the top K elements, divided by the total number of positives for that query.
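
The three steps can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not mmlearn's implementation: the helper name recall_at_k and the boolean positives mask of shape (N, M) are hypothetical choices made for the example. Averaging the per-query values gives a single dataset-level score.

```python
import torch
import torch.nn.functional as F


def recall_at_k(queries: torch.Tensor, database: torch.Tensor,
                positives: torch.Tensor, k: int) -> torch.Tensor:
    """Per-query Recall@K (hypothetical helper, not part of mmlearn).

    queries:   (N, D) query embeddings, unnormalized
    database:  (M, D) database embeddings, unnormalized
    positives: (N, M) boolean mask, True where database item j is relevant to query i
    """
    # Step 1: cosine similarity = dot product of L2-normalized embeddings -> (N, M)
    sim = F.normalize(queries, dim=-1) @ F.normalize(database, dim=-1).T
    # Step 2: indices of the K most similar database items per query, in decreasing order
    topk_idx = sim.topk(k, dim=1).indices  # (N, K)
    # Step 3: relevant items among the top K, divided by the number of relevant items
    hits = positives.gather(1, topk_idx).sum(dim=1).float()
    return hits / positives.sum(dim=1).float()


queries = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
database = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
positives = torch.zeros(2, 3, dtype=torch.bool)
positives[torch.arange(2), torch.tensor([2, 1])] = True  # query 0 -> item 2, query 1 -> item 1

print(recall_at_k(queries, database, positives, k=1))  # tensor([0., 1.])
print(recall_at_k(queries, database, positives, k=2))  # tensor([1., 1.])
```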

Parameters:
  • top_k (int) – The number of top elements to consider for computing the Recall@K.

  • reduction ({"mean", "sum", "none", None}, default=None) – Specifies the reduction to apply after computing the pairwise cosine similarity scores.

  • aggregation ({"mean", "median", "min", "max"} or callable, default="mean") – Specifies the aggregation function to apply to the Recall@K values computed in batches. If a callable is provided, it should accept a tensor of values and a keyword argument 'dim' and return a single scalar value.

  • kwargs (Any) – Additional arguments to be passed to the torchmetrics.Metric class.
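
For the aggregation parameter, the built-in names map naturally onto torch reductions, and a custom callable only needs to accept the values plus a dim keyword and return a scalar. The sketch below is hypothetical: the AGGREGATIONS table, the aggregate helper, and the trimmed_mean callable are illustrative names, not mmlearn APIs. Note that Tensor.median, Tensor.min, and Tensor.max return a (values, indices) pair when given dim, so the sketch takes .values.

```python
import torch

# Hypothetical mapping from the documented aggregation names to torch reductions.
AGGREGATIONS = {
    "mean": lambda t, dim: t.mean(dim=dim),
    "median": lambda t, dim: t.median(dim=dim).values,
    "min": lambda t, dim: t.min(dim=dim).values,
    "max": lambda t, dim: t.max(dim=dim).values,
}


def trimmed_mean(values: torch.Tensor, dim: int = 0, trim: float = 0.1) -> torch.Tensor:
    """Example custom aggregation: mean after dropping the lowest and highest `trim` fraction."""
    sorted_vals, _ = torch.sort(values, dim=dim)
    n = values.shape[dim]
    k = int(n * trim)
    kept = sorted_vals.narrow(dim, k, n - 2 * k)  # keep the middle n - 2k values
    return kept.mean(dim=dim)


def aggregate(values: torch.Tensor, aggregation, dim: int = 0) -> torch.Tensor:
    # A callable is invoked with the `dim` keyword; a string is looked up in the table.
    if callable(aggregation):
        return aggregation(values, dim=dim)
    return AGGREGATIONS[aggregation](values, dim)


recalls = torch.tensor([0.0] + [1.0] * 9)  # e.g. per-query Recall@K values
print(aggregate(recalls, "mean"), aggregate(recalls, trimmed_mean))
```

A trimmed mean is just one example of a robust alternative; any function with the same call shape would fit the documented contract.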

Raises:

ValueError

  • If top_k is None or is not a positive integer.

  • If reduction is not one of {"mean", "sum", "none", None}.

  • If aggregation is not one of {"mean", "median", "min", "max"} or a custom callable function.
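
The error conditions above can be mirrored by a small argument check, for example to validate configuration values before constructing the metric. This is a hypothetical sketch; validate_recall_args is an illustrative name and not part of mmlearn, and the library's own checks may differ in detail.

```python
from typing import Any

# Accepted values as listed in the documentation above.
_REDUCTIONS = ("mean", "sum", "none", None)
_AGGREGATIONS = ("mean", "median", "min", "max")


def validate_recall_args(top_k: Any, reduction: Any, aggregation: Any) -> None:
    """Hypothetical pre-check mirroring the documented ValueError conditions."""
    # top_k must be a positive integer; None fails the isinstance check
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}")
    # tuples (not sets) so that unhashable inputs raise ValueError rather than TypeError
    if reduction not in _REDUCTIONS:
        raise ValueError(f"reduction must be one of {_REDUCTIONS}, got {reduction!r}")
    # any callable is accepted; otherwise the name must be one of the built-ins
    if not callable(aggregation) and aggregation not in _AGGREGATIONS:
        raise ValueError(
            f"aggregation must be one of {_AGGREGATIONS} or a callable, got {aggregation!r}"
        )


validate_recall_args(5, None, "mean")  # passes silently
```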

compute()

Compute the metric.

Returns:

The Recall@K value, reduced to a single scalar by the configured aggregation.

Return type:

torch.Tensor

forward(*args, **kwargs)

Forward method is not supported.

Raises:

NotImplementedError – The forward method is not supported for this metric.

Return type:

Any

full_state_update: bool = False
higher_is_better: bool = True
indexes: list[Tensor]
is_differentiable: bool = False
num_samples: Tensor

update(x, y, indexes)

Check shape, convert dtypes and add to accumulators.

Parameters:
  • x (torch.Tensor) – Query embeddings (unnormalized) of shape (N, D), where N is the number of queries and D is the embedding dimension.

  • y (torch.Tensor) – Database embeddings (unnormalized) of shape (M, D), where M is the number of database samples and D is the same embedding dimension as in x.

  • indexes (torch.Tensor) – Index tensor of shape (N,). For each sample in x, it gives the position of its positive match in y.

Raises:

ValueError – If indexes is None.

Return type:

None
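
Because indexes maps each row of x to its positive row in y, N and M need not be equal. The sketch below builds indexes for two common layouts (one-to-one pairs, and several captions per image) and runs some caller-side sanity checks. All three helpers are hypothetical; per the documentation above, update itself only guarantees a ValueError when indexes is None, so the extra shape checks are assumptions about what a caller may want to verify.

```python
from typing import Optional

import torch


def paired_indexes(num_pairs: int) -> torch.Tensor:
    # One-to-one pairs: x[i] matches y[i], so the positive for query i is row i of y.
    return torch.arange(num_pairs)


def captions_to_images(num_images: int, captions_per_image: int) -> torch.Tensor:
    # x holds captions grouped by image (image 0's captions first), y holds one row per image.
    return torch.arange(num_images).repeat_interleave(captions_per_image)


def check_update_inputs(x: torch.Tensor, y: torch.Tensor,
                        indexes: Optional[torch.Tensor]) -> None:
    """Hypothetical sanity checks for the (N, D), (M, D), (N,) contract."""
    if indexes is None:
        raise ValueError("indexes must not be None")
    if x.ndim != 2 or y.ndim != 2 or x.shape[1] != y.shape[1]:
        raise ValueError("x and y must both be 2-D with the same embedding dimension D")
    if indexes.shape != (x.shape[0],):
        raise ValueError("indexes must have shape (N,)")
    if indexes.min() < 0 or indexes.max() >= y.shape[0]:
        raise ValueError("every index must refer to a row of y")


captions = torch.randn(6, 4)  # N = 6 caption embeddings (queries)
images = torch.randn(3, 4)    # M = 3 image embeddings (database)
idx = captions_to_images(3, 2)
print(idx)  # tensor([0, 0, 1, 1, 2, 2])
check_update_inputs(captions, images, idx)
```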

x: list[Tensor]
y: list[Tensor]
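
The list-valued states (x, y, indexes) suggest an accumulate-then-compute pattern: each update call appends one batch, and the similarity search runs once over everything collected. A toy version without torchmetrics is sketched below. It assumes, for illustration only, that the indexes passed to each update are local to that call's y and must be shifted by the number of y rows already stored; the role of num_samples in mmlearn is not documented here, so treat this offset logic as hypothetical. With exactly one positive per query, Recall@K reduces to checking whether that index appears in the top K.

```python
import torch
import torch.nn.functional as F


class ToyRecallAccumulator:
    """Hypothetical accumulator illustrating list-valued metric states."""

    def __init__(self, top_k: int) -> None:
        self.top_k = top_k
        self.x: list[torch.Tensor] = []
        self.y: list[torch.Tensor] = []
        self.indexes: list[torch.Tensor] = []
        self._db_rows = 0  # y rows stored so far (assumed offset for local indexes)

    def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
        # Store the batch and shift its local indexes into the concatenated y.
        self.x.append(x)
        self.y.append(y)
        self.indexes.append(indexes + self._db_rows)
        self._db_rows += y.shape[0]

    def compute(self) -> torch.Tensor:
        # Concatenate every batch, then search each query against the full database.
        x = F.normalize(torch.cat(self.x), dim=-1)
        y = F.normalize(torch.cat(self.y), dim=-1)
        idx = torch.cat(self.indexes)
        topk_idx = (x @ y.T).topk(self.top_k, dim=1).indices  # (N, K)
        # One positive per query: recall is 1 if it appears in the top K, else 0.
        hits = (topk_idx == idx.unsqueeze(1)).any(dim=1).float()
        return hits.mean()  # mean aggregation over queries


metric = ToyRecallAccumulator(top_k=1)
metric.update(torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
              torch.tensor([[1.0, 0.0], [0.0, 1.0]]), torch.tensor([0, 1]))
metric.update(torch.tensor([[1.0, 0.2], [1.0, -1.0]]),
              torch.tensor([[1.0, 1.0], [1.0, -1.0]]), torch.tensor([0, 1]))
print(metric.compute())  # tensor(0.7500)
```

Here the query [1.0, 0.2] from the second batch is closer to the first batch's [1.0, 0.0] than to its own positive, which is only visible because the search runs over the full accumulated database.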