mmlearn.tasks.zero_shot_retrieval module

Zero-shot cross-modal retrieval evaluation task.

class RetrievalTaskSpec(query_modality, target_modality, top_k)

Bases: object

Specification for a retrieval task.

query_modality: str

The query modality.

target_modality: str

The target modality.

top_k: list[int]

The top-k values for which to compute the retrieval recall metrics.
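A minimal sketch of how such specifications might be written. To stay self-contained it defines a stand-in dataclass with the three documented fields instead of importing mmlearn, and the modality names "text" and "rgb" are assumptions chosen for illustration only.

```python
from dataclasses import dataclass


@dataclass
class RetrievalTaskSpec:
    """Stand-in mirroring the documented fields; not the mmlearn class itself."""

    query_modality: str
    target_modality: str
    top_k: list[int]


# Hypothetical modality names; substitute the names your model and dataset use.
text_to_image = RetrievalTaskSpec(query_modality="text", target_modality="rgb", top_k=[1, 5, 10])
image_to_text = RetrievalTaskSpec(query_modality="rgb", target_modality="text", top_k=[1, 5, 10])

# One spec per retrieval direction; recall is reported for every k in top_k.
task_specs = [text_to_image, image_to_text]
```

Listing both directions separately reflects that each specification names exactly one query modality and one target modality.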

class ZeroShotCrossModalRetrieval(task_specs)

Bases: EvaluationHooks

Zero-shot cross-modal retrieval evaluation task.

This task evaluates a model's retrieval performance on a set of query-target pairs. The model is expected to produce embeddings for both the query and target modalities. For each task specification, the task computes retrieval recall at k (recall@k) for every value of k listed in top_k.

Parameters:

task_specs (list[RetrievalTaskSpec]) – A list of retrieval task specifications. Each specification defines the query and target modalities, as well as the top-k values for which to compute the retrieval recall metrics.
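To make "retrieval recall at k" concrete, the sketch below computes it in plain Python on toy embeddings. It rests on two assumptions that the text above does not confirm: similarity is cosine similarity, and the query at index i is paired with the target at index i. The helper names cosine and recall_at_k are hypothetical and are not part of mmlearn.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def recall_at_k(query_embs, target_embs, k):
    """Fraction of queries whose paired target (same index) ranks in the top k."""
    hits = 0
    for i, query in enumerate(query_embs):
        sims = [cosine(query, target) for target in target_embs]
        # Target indices ordered from most to least similar to this query.
        ranked = sorted(range(len(sims)), key=lambda j: sims[j], reverse=True)
        if i in ranked[:k]:
            hits += 1
    return hits / len(query_embs)


# Toy 2-D embeddings; query i is assumed to match target i.
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
targets = [[1.0, 0.1], [0.6, 0.5], [0.7, 0.7]]

recall_1 = recall_at_k(queries, targets, k=1)  # queries 0 and 2 hit, query 1 misses
recall_2 = recall_at_k(queries, targets, k=2)  # query 1's match is ranked second
```

Here the second query's paired target is only its second-best match, so it counts toward recall at 2 but not recall at 1.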

evaluation_step(pl_module, batch, batch_idx)

Run the forward pass and update retrieval recall metrics.

Parameters:
  • pl_module (pl.LightningModule) – A reference to the Lightning module being evaluated.

  • batch (dict[str, torch.Tensor]) – A dictionary of batched input tensors.

  • batch_idx (int) – The index of the batch.

Return type:

None

on_evaluation_epoch_end(pl_module)

Compute the retrieval recall metrics.

Parameters:

pl_module (pl.LightningModule) – A reference to the Lightning module being evaluated.

Returns:

A dictionary of evaluation results or None if no results are available.

Return type:

Optional[dict[str, Any]]

on_evaluation_epoch_start(pl_module)

Move the metrics to the device of the Lightning module.

Parameters:

pl_module (pl.LightningModule) – A reference to the Lightning module being evaluated.

Return type:

None
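The three hooks suggest an update-then-compute lifecycle: prepare at epoch start, accumulate per batch, and aggregate at epoch end. The toy class below sketches that pattern under stated assumptions. ToyRetrievalTask, the "match_rank" batch key, the "recall@k" result keys, and the driver loop are all hypothetical and do not reflect mmlearn's internals. The toy start hook resets state because it has no device to move metrics to.

```python
from typing import Any, Optional


class ToyRetrievalTask:
    """Hypothetical stand-in that mimics the three documented hooks."""

    def __init__(self, top_k):
        self.top_k = top_k
        self.ranks = []

    def on_evaluation_epoch_start(self, pl_module) -> None:
        # The documented hook moves metrics to the module's device;
        # this toy version only clears accumulated state.
        self.ranks = []

    def evaluation_step(self, pl_module, batch, batch_idx) -> None:
        # Stand-in for "forward pass + metric update": each batch already
        # carries the 1-based rank of the correct target for every query.
        self.ranks.extend(batch["match_rank"])

    def on_evaluation_epoch_end(self, pl_module) -> Optional[dict[str, Any]]:
        if not self.ranks:
            return None  # no results available
        total = len(self.ranks)
        return {f"recall@{k}": sum(r <= k for r in self.ranks) / total for k in self.top_k}


# Assumed call order; a real trainer would drive these hooks.
task = ToyRetrievalTask(top_k=[1, 5])
batches = [{"match_rank": [1, 3]}, {"match_rank": [7, 1]}]

task.on_evaluation_epoch_start(pl_module=None)
for idx, batch in enumerate(batches):
    task.evaluation_step(None, batch, idx)
results = task.on_evaluation_epoch_end(None)
```

Accumulating in evaluation_step and aggregating once at the end keeps the metric defined over the whole evaluation set rather than averaged per batch.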