mmlearn.modules.metrics.retrieval_recall
Retrieval Recall@K metric.
Classes
RetrievalRecallAtK – Retrieval Recall@K metric.
- class RetrievalRecallAtK(top_k, reduction=None, aggregation='mean', **kwargs)
Retrieval Recall@K metric.
Computes the Recall@K for retrieval tasks. The metric is computed as follows:
1. Compute the cosine similarity between each query and every item in the database.
2. For each query, sort the database in decreasing order of similarity.
3. Compute the Recall@K as the number of true positives among the top K elements, divided by the total number of positives for that query.
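The three steps can be sketched in plain Python as follows. This is an illustrative sketch, not the mmlearn implementation: lists stand in for tensors, the helper names cosine_similarity and recall_at_k are hypothetical, all vectors are assumed to be non-zero, and Recall@K is taken as the usual definition (true positives in the top K divided by the number of positives for the query), averaged over all queries.

```python
import math


def cosine_similarity(a, b):
    # Step 1 helper: cosine similarity of two non-zero, unnormalized vectors.
    dot = sum(p * q for p, q in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


def recall_at_k(queries, database, positives, k):
    """Mean Recall@K, where positives[i] is the set of database rows relevant to queries[i]."""
    recalls = []
    for query, relevant in zip(queries, positives):
        # Step 1: similarity between this query and every database item.
        scores = [cosine_similarity(query, item) for item in database]
        # Step 2: database indices sorted by decreasing similarity.
        ranking = sorted(range(len(database)), key=lambda j: scores[j], reverse=True)
        # Step 3: true positives among the top K, over the number of positives.
        true_positives = len(relevant.intersection(ranking[:k]))
        recalls.append(true_positives / len(relevant))
    return sum(recalls) / len(recalls)
```

With queries [[1.0, 0.0], [0.0, 1.0]], database [[1.0, 0.1], [0.0, 1.0], [1.0, 1.0]] and positives [{0}, {1, 2}], the first query ranks its only positive first, while the second query finds one of its two positives at K=1 and both at K=2.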
- Parameters:
top_k (int) – The number of top elements to consider for computing the Recall@K.
reduction ({"mean", "sum", "none", None}, default=None) – Specifies the reduction to apply after computing the pairwise cosine similarity scores.
aggregation ({"mean", "median", "min", "max"} or callable, default="mean") – Specifies the aggregation function to apply to the Recall@K values computed in batches. If a callable is provided, it should accept a tensor of values and a keyword argument 'dim' and return a single scalar value.
kwargs (Any) – Additional arguments to be passed to the torchmetrics.Metric class.
- Raises:
If top_k is None or not a positive integer.
If reduction is not one of {"mean", "sum", "none", None}.
If aggregation is not one of {"mean", "median", "min", "max"} or a custom callable function.
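A minimal sketch of how the constructor arguments might be checked and how the batch-level Recall@K values might then be aggregated. The helper names (validate_args, aggregate, trimmed_mean) are hypothetical, plain lists stand in for tensors, and raising ValueError is an assumption because the page does not name the exception type. The "median" entry uses statistics.median_low on the assumption that the string option behaves like torch.median, which returns the lower of the two middle values for an even count.

```python
import statistics

# Allowed values, taken from the parameter descriptions above.
_REDUCTIONS = ("mean", "sum", "none", None)
_AGGREGATIONS = {
    "mean": statistics.mean,
    # median_low returns the lower middle value for an even count (like torch.median).
    "median": statistics.median_low,
    "min": min,
    "max": max,
}


def validate_args(top_k, reduction=None, aggregation="mean"):
    # bool is a subclass of int, so it is rejected explicitly; None fails the int check.
    if isinstance(top_k, bool) or not isinstance(top_k, int) or top_k <= 0:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}")
    if reduction not in _REDUCTIONS:
        raise ValueError(f"reduction must be one of {_REDUCTIONS}, got {reduction!r}")
    # Tuple membership uses ==, so unhashable inputs do not raise TypeError here.
    if not callable(aggregation) and aggregation not in tuple(_AGGREGATIONS):
        raise ValueError(f"aggregation must be one of {tuple(_AGGREGATIONS)} or a callable")


def aggregate(batch_values, aggregation="mean"):
    """Combine per-batch Recall@K values into a single number."""
    if callable(aggregation):
        # Custom callables receive the values plus a ``dim`` keyword argument.
        return aggregation(batch_values, dim=0)
    return _AGGREGATIONS[aggregation](batch_values)


def trimmed_mean(values, dim=0):
    # Hypothetical custom aggregation: drop the smallest and the largest value.
    # ``dim`` is accepted to honour the calling convention; plain lists are 1-D.
    kept = sorted(values)[1:-1]
    return sum(kept) / len(kept)
```

Under this convention, a built-in such as max cannot be passed directly as the callable, since it does not accept dim; the "max" string covers that case.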
- forward(*args, **kwargs)
Forward method is not supported.
- Raises:
NotImplementedError – The forward method is not supported for this metric.
- Return type:
- update(x, y, indexes)
Check shape, convert dtypes and add to accumulators.
- Parameters:
x (torch.Tensor) – Embeddings (unnormalized) of shape (N, D), where N is the number of samples and D is the number of dimensions.
y (torch.Tensor) – Embeddings (unnormalized) of shape (M, D), where M is the number of samples and D is the number of dimensions.
indexes (torch.Tensor) – Index tensor of shape (N,), where N is the number of samples. This specifies which sample in y is the positive match for each sample in x.
- Raises:
ValueError – If indexes is None.
- Return type:
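Because forward() is not supported, the metric is fed through update() and read out with compute(), the usual torchmetrics pattern. The class below is a hypothetical pure-Python stand-in for that workflow, not the mmlearn class. It assumes each update() call forms its own retrieval problem (indexes refer to rows of the y passed in that same call) and averages the per-batch Recall@K values; the real metric may instead rank against the full accumulated database. The range check on indexes is also an assumption; only the None check is documented.

```python
import math
import statistics


def _cosine(a, b):
    # Cosine similarity of two non-zero, unnormalized vectors.
    dot = sum(p * q for p, q in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


class RecallAtKSketch:
    """Hypothetical stand-in for the update()/compute() workflow."""

    def __init__(self, top_k):
        self.top_k = top_k
        self._batches = []  # one (x, y, indexes) triple per update() call

    def forward(self, *args, **kwargs):
        # Mirrors the documented behaviour: forward() is not supported.
        raise NotImplementedError("call update() and then compute()")

    def update(self, x, y, indexes):
        if indexes is None:
            raise ValueError("indexes must not be None")
        if len(indexes) != len(x):
            raise ValueError("indexes must have shape (N,), matching x")
        if any(not 0 <= j < len(y) for j in indexes):
            raise ValueError("each index must refer to a row of y")
        self._batches.append((list(x), list(y), list(indexes)))

    def compute(self):
        if not self._batches:
            raise ValueError("call update() before compute()")
        batch_recalls = []
        for x, y, indexes in self._batches:
            hits = 0
            for query, positive in zip(x, indexes):
                # Rank the rows of y by decreasing cosine similarity to the query.
                ranking = sorted(range(len(y)), key=lambda j: _cosine(query, y[j]), reverse=True)
                if positive in ranking[: self.top_k]:
                    hits += 1
            batch_recalls.append(hits / len(x))
        # Aggregate the per-batch values with "mean", the documented default.
        return statistics.mean(batch_recalls)
```

For example, a first batch where one of two queries finds its positive at rank 1, followed by a second batch where its single query does, gives per-batch values of 0.5 and 1.0 at top_k=1 and a mean of 0.75.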