mmlearn.tasks.zero_shot_classification module

Zero-shot classification evaluation task.

class ClassificationTaskSpec(query_modality, top_k)

Bases: object

Specification for a classification task.

query_modality: str

The modality of the query input.

top_k: list[int]

The values of k for which to compute top-k classification metrics, such as top-k accuracy.
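To illustrate what top_k controls, the sketch below uses a hypothetical stand-in dataclass, TaskSpecSketch, that mirrors the two documented fields, plus a plain-Python top_k_accuracy helper (also hypothetical, not part of mmlearn). The helper produces one accuracy value per entry in top_k. The modality name "rgb" is only an example value.

```python
from dataclasses import dataclass


@dataclass
class TaskSpecSketch:
    """Hypothetical stand-in mirroring the documented fields of ClassificationTaskSpec."""

    query_modality: str
    top_k: list[int]


def top_k_accuracy(logits: list[list[float]], targets: list[int], k: int) -> float:
    """Fraction of rows whose target class is among the k highest-scoring classes."""
    hits = 0
    for row, target in zip(logits, targets):
        # Indices of the k largest scores in this row.
        top = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
        if target in top:
            hits += 1
    return hits / len(targets)


spec = TaskSpecSketch(query_modality="rgb", top_k=[1, 2])
logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
targets = [1, 1, 0]
# One metric per requested k, e.g. "top1_accuracy" and "top2_accuracy".
results = {f"top{k}_accuracy": top_k_accuracy(logits, targets, k) for k in spec.top_k}
print(results)
```

Larger values of k can only keep or raise the accuracy, because the top-k set for a larger k contains the top-k set for every smaller k.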

class ZeroShotClassification(task_specs, tokenizer)

Bases: EvaluationHooks

Zero-shot classification evaluation task.

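This task evaluates how well a model classifies inputs of a given query modality without task-specific training. It does this using text prompts built from the dataset's class labels and prompt templates.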
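A common way to build the class side of zero-shot classification is CLIP-style prompt ensembling. Each prompt template is filled with a class label, every resulting prompt is embedded, and the normalized embeddings are averaged per class. The sketch below assumes that approach. It also assumes the templates are str.format strings with a {} placeholder, and it uses a toy toy_encode function in place of a real text encoder. None of these names come from mmlearn.

```python
import math
from typing import Callable


def normalize(vec: list[float]) -> list[float]:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def class_embeddings(
    id2label: dict[int, str],
    templates: list[str],
    encode_text: Callable[[str], list[float]],
) -> dict[int, list[float]]:
    """Return one unit-norm embedding per class by averaging its templated prompts."""
    result: dict[int, list[float]] = {}
    for class_id, label in id2label.items():
        # Fill every template with the class label, e.g. "a photo of a cat."
        prompts = [template.format(label) for template in templates]
        embeddings = [normalize(encode_text(prompt)) for prompt in prompts]
        # Element-wise mean over the prompt embeddings, then re-normalize.
        mean = [sum(values) / len(embeddings) for values in zip(*embeddings)]
        result[class_id] = normalize(mean)
    return result


def toy_encode(text: str) -> list[float]:
    """Toy stand-in for a text encoder (hypothetical): a few character counts."""
    return [float(len(text)), float(text.count("a") + 1), float(text.count(" ") + 1)]


id2label = {0: "cat", 1: "dog"}
templates = ["a photo of a {}.", "a drawing of a {}."]
embeddings = class_embeddings(id2label, templates, toy_encode)
print(embeddings)
```

Averaging over several templates makes the class embedding less sensitive to the wording of any single prompt.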

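Parameters:
  • task_specs (list[ClassificationTaskSpec]) – The specifications of the classification tasks to evaluate.

  • tokenizer (Callable) – A function that tokenizes the text prompts generated for each class.
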
evaluation_step(pl_module, batch, batch_idx)

Compute logits and update metrics.

Parameters:
  • pl_module (pl.LightningModule) – A reference to the Lightning module being evaluated.

  • batch (dict[str, torch.Tensor]) – A batch of data.

  • batch_idx (int) – The index of the batch.

Return type:

None
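One plausible reading of "compute logits" is a scaled cosine similarity between each query embedding and each class embedding, as in CLIP. The sketch below assumes that reading, with a scale factor of 100.0 chosen for illustration. zero_shot_logits and predict are hypothetical helpers that operate on plain Python lists rather than tensors.

```python
import math


def l2_normalize(vec: list[float]) -> list[float]:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def zero_shot_logits(
    query_embs: list[list[float]],
    class_embs: list[list[float]],
    scale: float = 100.0,
) -> list[list[float]]:
    """Scaled cosine similarity: one row per query, one column per class."""
    queries = [l2_normalize(q) for q in query_embs]
    classes = [l2_normalize(c) for c in class_embs]
    return [[scale * sum(a * b for a, b in zip(q, c)) for c in classes] for q in queries]


def predict(logits: list[list[float]]) -> list[int]:
    """Index of the highest-scoring class for each query."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


logits = zero_shot_logits([[1.0, 0.0], [0.2, 0.9]], [[1.0, 0.1], [0.0, 1.0]])
print(predict(logits))  # [0, 1]
```

Each query is assigned to the class whose embedding points in the most similar direction. The scale factor changes the logit magnitudes but not the ranking.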

on_evaluation_epoch_end(pl_module)

Compute and reset metrics.

Parameters:

pl_module (pl.LightningModule) – A reference to the Lightning module being evaluated.

Returns:

The computed metrics.

Return type:

dict[str, Any]
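The compute-then-reset pattern can be sketched with a hypothetical accumulator, RunningTopKAccuracy. It is updated once per batch and read once per epoch. The metric names such as "top1_accuracy" are assumptions, and the keys in the dictionary returned by mmlearn may differ.

```python
from typing import Any


class RunningTopKAccuracy:
    """Hypothetical accumulator: update per batch, compute at epoch end, then reset."""

    def __init__(self, top_k: list[int]) -> None:
        self.top_k = top_k
        self.reset()

    def reset(self) -> None:
        # Clear the running counts so the next epoch starts from zero.
        self.hits = {k: 0 for k in self.top_k}
        self.total = 0

    def update(self, logits: list[list[float]], targets: list[int]) -> None:
        for row, target in zip(logits, targets):
            ranked = sorted(range(len(row)), key=row.__getitem__, reverse=True)
            for k in self.top_k:
                if target in ranked[:k]:
                    self.hits[k] += 1
        self.total += len(targets)

    def compute(self) -> dict[str, Any]:
        return {f"top{k}_accuracy": self.hits[k] / self.total for k in self.top_k}


def epoch_end(metric: RunningTopKAccuracy) -> dict[str, Any]:
    """Compute the epoch's metrics, then reset the state for the next epoch."""
    results = metric.compute()
    metric.reset()
    return results


metric = RunningTopKAccuracy(top_k=[1, 2])
metric.update([[0.9, 0.1, 0.0]], [0])
metric.update([[0.2, 0.5, 0.3], [0.1, 0.2, 0.7]], [2, 0])
results = epoch_end(metric)
print(results)
```

Accumulating counts across batches and dividing only at the end gives the accuracy over the whole epoch, rather than an average of per-batch accuracies.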

on_evaluation_epoch_start(pl_module)

Set up the evaluation task.

Parameters:

pl_module (pl.LightningModule) – A reference to the Lightning module being evaluated.

Raises:

ValueError

  • If the task is not being run for validation or testing.

  • If the dataset does not have the attributes required for zero-shot classification (i.e., id2label and zero_shot_prompt_templates).

Return type:

None
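The two documented failure cases can be expressed as simple precondition checks. In this sketch, check_zero_shot_setup and its validating/testing flags are hypothetical. In a Lightning setting, the flags would presumably come from the trainer's current state.

```python
from types import SimpleNamespace
from typing import Any

REQUIRED_ATTRIBUTES = ("id2label", "zero_shot_prompt_templates")


def check_zero_shot_setup(dataset: Any, validating: bool, testing: bool) -> None:
    """Raise ValueError when either documented precondition is violated."""
    if not (validating or testing):
        raise ValueError("Zero-shot classification only runs during validation or testing.")
    missing = [name for name in REQUIRED_ATTRIBUTES if not hasattr(dataset, name)]
    if missing:
        raise ValueError(f"Dataset is missing attributes required for zero-shot classification: {missing}")


ok = SimpleNamespace(id2label={0: "cat"}, zero_shot_prompt_templates=["a photo of a {}."])
check_zero_shot_setup(ok, validating=True, testing=False)  # passes silently

try:
    check_zero_shot_setup(SimpleNamespace(id2label={0: "cat"}), validating=False, testing=True)
except ValueError as err:
    print(err)
```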