mmlearn.tasks.zero_shot_classification module
Zero-shot classification evaluation task.
- class ClassificationTaskSpec(query_modality, top_k)
Bases: object
Specification for a classification task.
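The spec documents only two fields. The stand-in below is a rough illustration, not the library's class. It assumes that `query_modality` is a modality name such as `"rgb"` and that `top_k` is a list of k values for which top-k accuracy is reported. Both types, and the class name `ClassificationTaskSpecSketch`, are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class ClassificationTaskSpecSketch:
    """Hypothetical stand-in mirroring the two documented fields.

    Assumed types (not confirmed by the docs): query_modality is a modality
    name, and top_k lists the k values for top-k accuracy.
    """

    query_modality: str
    top_k: list[int]


# task_specs is a list, so one task can evaluate several specs.
specs = [ClassificationTaskSpecSketch(query_modality="rgb", top_k=[1, 5])]
```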
- class ZeroShotClassification(task_specs, tokenizer)
Bases: EvaluationHooks
Zero-shot classification evaluation task.
This task measures how well a model classifies inputs into a dataset's classes without being trained on those classes, using the settings given in task_specs.
- Parameters:
task_specs (list[ClassificationTaskSpec]) – A list of classification task specifications.
tokenizer (Callable[[Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]]) – A function to tokenize text inputs.
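The tokenizer accepts either a single string or a list of strings and returns either a tensor or a dict of tensors. The toy whitespace tokenizer below only illustrates that call shape. The names `toy_tokenizer` and `VOCAB`, the `input_ids`/`attention_mask` keys, and the padding length are all hypothetical. Plain lists stand in for `torch.Tensor` so the sketch runs without PyTorch. In practice you would pass a trained tokenizer that returns real tensors.

```python
from typing import Union

# Hypothetical toy vocabulary, grown on the fly; id 0 is reserved for padding.
VOCAB: dict[str, int] = {"<pad>": 0}


def toy_tokenizer(
    text: Union[str, list[str]], max_len: int = 8
) -> dict[str, list[list[int]]]:
    """Whitespace-split, map words to ids, then pad or truncate to max_len.

    Returns plain lists where a real tokenizer would return torch.Tensor.
    """
    # A single string is treated as a batch of one.
    texts = [text] if isinstance(text, str) else text
    input_ids, attention_mask = [], []
    for t in texts:
        # setdefault gives each unseen word the next free id.
        ids = [VOCAB.setdefault(w, len(VOCAB)) for w in t.lower().split()][:max_len]
        pad = max_len - len(ids)
        input_ids.append(ids + [VOCAB["<pad>"]] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```

For example, `toy_tokenizer("a photo of a cat")` returns one padded row of ids. Its attention mask marks the five real tokens with 1 and the padding with 0.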
- evaluation_step(pl_module, batch, batch_idx)
Compute logits and update metrics.
- Parameters:
pl_module (pl.LightningModule) – A reference to the Lightning module being evaluated.
batch (dict[str, torch.Tensor]) – A batch of data.
batch_idx (int) – The index of the batch.
- Return type:
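The docs say only that this step computes logits and updates metrics. The sketch below shows one common way to do that in zero-shot classification, under two assumptions the page does not confirm. First, the logits are cosine similarities between query embeddings and one text embedding per class. Second, the metric is top-k accuracy accumulated across batches. The embeddings are precomputed toy vectors. The helper names `cosine_logits` and `TopKAccuracy` are hypothetical.

```python
import math


def _normalize(v: list[float]) -> list[float]:
    norm = math.sqrt(sum(x * x for x in v)) or 1.0  # guard against zero vectors
    return [x / norm for x in v]


def cosine_logits(
    queries: list[list[float]], class_embs: list[list[float]]
) -> list[list[float]]:
    """Return a [num_queries][num_classes] matrix of cosine similarities."""
    q = [_normalize(v) for v in queries]
    c = [_normalize(v) for v in class_embs]
    return [[sum(a * b for a, b in zip(qi, cj)) for cj in c] for qi in q]


class TopKAccuracy:
    """Running top-k accuracy, updated once per batch."""

    def __init__(self, k: int) -> None:
        self.k = k
        self.correct = 0
        self.total = 0

    def update(self, logits: list[list[float]], targets: list[int]) -> None:
        for row, target in zip(logits, targets):
            # Indices of the k highest-scoring classes for this query.
            top = sorted(range(len(row)), key=row.__getitem__, reverse=True)[: self.k]
            self.correct += int(target in top)
            self.total += 1

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0
```

A real module would more likely use batched tensor operations and a metrics library. The pure-Python loops here only keep the logic visible.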
- on_evaluation_epoch_start(pl_module)
Set up the evaluation task.
- Parameters:
pl_module (pl.LightningModule) – A reference to the Lightning module being evaluated.
- Raises:
If the task is not being run for validation or testing.
If the dataset does not have the required attributes to perform zero-shot classification (i.e., id2label and zero_shot_prompt_templates).
- Return type:
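Setup depends on the two dataset attributes listed under Raises. The sketch below shows how they might be combined into per-class prompts, which would then be tokenized and encoded into class embeddings. It assumes that `id2label` maps class indices to names and that `zero_shot_prompt_templates` is a list of `"{}"` format strings; the library may represent templates differently. `build_class_prompts` is hypothetical, and the docs do not name the exception type, so `ValueError` is an assumption.

```python
from typing import Any


def build_class_prompts(dataset: Any) -> dict[int, list[str]]:
    """Expand every class name with every prompt template.

    Assumes dataset.id2label maps class index -> name and
    dataset.zero_shot_prompt_templates is a list of "{}" format strings.
    """
    for attr in ("id2label", "zero_shot_prompt_templates"):
        if not hasattr(dataset, attr):
            # Exception type is an assumption; the docs omit it.
            raise ValueError(f"Dataset is missing required attribute {attr!r}")
    return {
        idx: [template.format(label) for template in dataset.zero_shot_prompt_templates]
        for idx, label in dataset.id2label.items()
    }
```

Using several templates per class and averaging their text embeddings is a common way to make zero-shot class embeddings less sensitive to the wording of any single prompt.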