mmlearn.tasks.zero_shot_retrieval module
Zero-shot cross-modal retrieval evaluation task.
- class RetrievalTaskSpec(query_modality, target_modality, top_k)
Bases: object
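Specification for a single retrieval task: the query modality, the target modality, and the top-k values at which to compute retrieval recall.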
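For illustration, a spec can be pictured as a small record holding the three documented fields. The sketch below uses a local stand-in dataclass instead of importing mmlearn; the field types (str modalities, a list[int] of k values) and the modality names "text" and "rgb" are assumptions. Because each spec names one query modality and one target modality, text-to-image and image-to-text retrieval would be described by two separate specs.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class RetrievalTaskSpec:
    """Hypothetical stand-in mirroring the documented fields; the types are assumptions."""

    query_modality: str   # modality of the queries, e.g. "text" (assumed name)
    target_modality: str  # modality searched over, e.g. "rgb" (assumed name)
    top_k: List[int]      # k values at which recall is reported


# Retrieval is directional, so each direction gets its own spec.
text_to_image = RetrievalTaskSpec(query_modality="text", target_modality="rgb", top_k=[1, 5, 10])
image_to_text = RetrievalTaskSpec(query_modality="rgb", target_modality="text", top_k=[1, 5, 10])
```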
- class ZeroShotCrossModalRetrieval(task_specs)
Bases: EvaluationHooks
Zero-shot cross-modal retrieval evaluation task.
This task evaluates the retrieval performance of a model on a set of query-target pairs. The model is expected to produce embeddings for both the query and target modalities, and the task computes retrieval recall at each requested k for every query-target modality pair given in task_specs.
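To make "recall at k" concrete, here is a minimal plain-Python sketch. It assumes query i is paired with target i, scores pairs by cosine similarity, and counts a query as a hit at k when its paired target is among the k highest-scoring targets. The function name recall_at_ks is hypothetical, and the module's actual metric may differ (for example in similarity function, tie handling, or how results are aggregated across batches and devices).

```python
import math
from typing import Dict, List, Sequence


def _l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length so dot products equal cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vec))
    if norm == 0.0:
        return list(vec)  # leave all-zero vectors unchanged to avoid division by zero
    return [x / norm for x in vec]


def recall_at_ks(
    query_emb: Sequence[Sequence[float]],
    target_emb: Sequence[Sequence[float]],
    top_k: Sequence[int],
) -> Dict[int, float]:
    """Hypothetical helper: recall@k for each k, assuming query i is paired with target i."""
    if len(query_emb) != len(target_emb):
        raise ValueError("expected exactly one paired target per query")
    queries = [_l2_normalize(q) for q in query_emb]
    targets = [_l2_normalize(t) for t in target_emb]

    # 0-based rank of each query's paired target among all targets.
    ranks = []
    for i, q in enumerate(queries):
        scores = [sum(a * b for a, b in zip(q, t)) for t in targets]
        order = sorted(range(len(targets)), key=lambda j: scores[j], reverse=True)
        ranks.append(order.index(i))

    # Fraction of queries whose paired target lands in the top k.
    return {k: sum(r < k for r in ranks) / len(ranks) for k in top_k}


queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
targets = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
# The third query's paired target ranks last, so recall@1 is 2/3 and recall@3 is 1.0.
print(recall_at_ks(queries, targets, top_k=[1, 3]))
```

Ranking every target for every query is quadratic in the number of pairs; that is fine for a sketch, while a practical implementation would typically compute the similarity matrix in one batched matrix product.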
- Parameters:
task_specs (list[RetrievalTaskSpec]) – A list of retrieval task specifications. Each specification defines the query and target modalities, as well as the top-k values for which to compute the retrieval recall metrics.
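Because each specification carries its own list of k values, task_specs implies one recall value per (query modality, target modality, k) combination. The sketch below enumerates those combinations; the Spec named tuple, the expand_metrics helper, and the printed label format are hypothetical illustrations, not the names this module logs.

```python
from typing import List, NamedTuple, Tuple


class Spec(NamedTuple):
    """Hypothetical stand-in for RetrievalTaskSpec."""

    query_modality: str
    target_modality: str
    top_k: List[int]


def expand_metrics(task_specs: List[Spec]) -> List[Tuple[str, str, int]]:
    """List every (query, target, k) combination that gets its own recall value."""
    return [
        (spec.query_modality, spec.target_modality, k)
        for spec in task_specs
        for k in spec.top_k
    ]


specs = [Spec("text", "rgb", [1, 5]), Spec("rgb", "text", [1])]
for query, target, k in expand_metrics(specs):
    print(f"{query} -> {target}: recall@{k}")  # hypothetical label format
```

Here two specs with three k values in total yield three recall values.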
- evaluation_step(pl_module, batch, batch_idx)
Run the forward pass and update retrieval recall metrics.
- Parameters:
pl_module (pl.LightningModule) – A reference to the Lightning module being evaluated.
batch (dict[str, torch.Tensor]) – A dictionary of batched input tensors.
batch_idx (int) – The index of the batch.
- Return type: