User Guide

mmlearn contains a collection of tools and utilities to help researchers and practitioners easily set up and run training or evaluation experiments for multimodal representation learning methods. The toolkit is designed to be modular and extensible. We aim to provide a high degree of flexibility in using existing methods, while also allowing users to easily add support for new modalities of data, datasets, models and pretraining or evaluation methods.

Much of the power and flexibility of mmlearn comes from building on top of the PyTorch Lightning framework and using Hydra and hydra-zen for configuration management. Together, these tools make it easy to define and run experiments with different configurations, and to scale up experiments to run on a SLURM cluster.

The goal of this guide is to give you a brief overview of what mmlearn is and how you can get started using it.

Note

mmlearn currently only supports training and evaluation of encoder-only models.

For more detailed information on the features and capabilities of mmlearn, please refer to the API Reference.

Defining a Dataset

Datasets in mmlearn can be defined using PyTorch’s Dataset or IterableDataset classes. However, there are two additional requirements for datasets in mmlearn:

  1. The dataset must return an instance of Example from the __getitem__() method or the __iter__() method.

  2. The Example object returned by the dataset must contain the key 'example_index' and use modality-specific keys from the Modalities registry to store the data.

Example 1: Defining a map-style dataset in mmlearn:

from torch.utils.data.dataset import Dataset

from mmlearn.datasets.core import Example, Modalities
from mmlearn.constants import EXAMPLE_INDEX_KEY


class MyMapStyleDataset(Dataset[Example]):
   ...
   def __getitem__(self, idx: int) -> Example:
      ...
      return Example(
         {
            EXAMPLE_INDEX_KEY: idx,
            Modalities.TEXT.name: ...,
            Modalities.RGB.name: ...,
            Modalities.RGB.target: ...,
            Modalities.TEXT.mask: ...,
            # ... other modality-specific keys
         }
      )

Example 2: Defining an iterable-style dataset in mmlearn:

from collections.abc import Generator

from torch.utils.data.dataset import IterableDataset

from mmlearn.datasets.core import Example, Modalities
from mmlearn.constants import EXAMPLE_INDEX_KEY


class MyIterableStyleDataset(IterableDataset[Example]):
   ...
   def __iter__(self) -> Generator[Example, None, None]:
      ...
      # `items` stands for whatever raw data source the dataset iterates over
      for idx, item in enumerate(items):
         yield Example(
            {
               EXAMPLE_INDEX_KEY: idx,
               Modalities.TEXT.name: ...,
               Modalities.AUDIO.name: ...,
               Modalities.TEXT.mask: ...,
               Modalities.AUDIO.mask: ...,
               # ... other modality-specific keys
            }
         )

The Example class represents a single example in the dataset and all the attributes associated with it. It extends the OrderedDict class, provides attribute-style access to the dictionary values and handles the creation of the 'example_ids' tuple, which combines the 'example_index' and 'dataset_index' values. The dataset object sets the 'example_index' key for each example it returns. The 'dataset_index' key, on the other hand, is added by the CombinedDataset to each Example object returned by the dataset.

Note

All dataset objects in mmlearn are wrapped in the CombinedDataset class, a subclass of torch.utils.data.Dataset that adds the 'dataset_index' key to every example. As a result, users almost never need to add the 'dataset_index' key explicitly.
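The following is a torch-free sketch of how this works. ToyExample and ToyCombinedDataset are hypothetical stand-ins written for illustration. They are not mmlearn's Example or CombinedDataset, and the order of the values inside the 'example_ids' tuple is an assumption.

```python
from collections import OrderedDict
from collections.abc import Sequence
from typing import Any


class ToyExample(OrderedDict):
    """Hypothetical stand-in for mmlearn's Example (illustration only)."""

    def __getattr__(self, key: str) -> Any:
        # Called only when normal attribute lookup fails: fall back to dict keys.
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def create_ids(self) -> None:
        # Combine both indices into one tuple (the order here is an assumption).
        if "example_index" in self and "dataset_index" in self:
            self["example_ids"] = (self["dataset_index"], self["example_index"])


class ToyCombinedDataset:
    """Hypothetical stand-in for CombinedDataset: concatenates and tags examples."""

    def __init__(self, datasets: Sequence[Sequence[ToyExample]]) -> None:
        self.datasets = list(datasets)

    def __len__(self) -> int:
        return sum(len(ds) for ds in self.datasets)

    def __getitem__(self, idx: int) -> ToyExample:
        for dataset_index, dataset in enumerate(self.datasets):
            if idx < len(dataset):
                example = ToyExample(dataset[idx])  # copy, leave the source untouched
                example["dataset_index"] = dataset_index  # added by the wrapper
                example.create_ids()
                return example
            idx -= len(dataset)
        raise IndexError("index out of range")
```

The user-defined datasets only set 'example_index'. The wrapper adds the dataset position and derives 'example_ids'.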

Because batching typically stacks data from the same modality into a single tensor, the 'example_index' and 'dataset_index' keys are both needed to uniquely identify paired examples across different modalities. The 'example_index' identifies an example within its dataset, and the 'dataset_index' distinguishes datasets whose example indices overlap. The find_matching_indices() function does exactly this: it finds the indices of the examples in a batch that share the same 'example_ids' tuple.
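The matching idea can be sketched in plain Python. match_positions is a hypothetical helper, not mmlearn's find_matching_indices(), which operates on tensors. It assumes ids are (dataset_index, example_index) tuples that are unique within each modality's batch.

```python
from collections.abc import Sequence

ExampleId = tuple[int, int]  # (dataset_index, example_index); order assumed


def match_positions(
    first_ids: Sequence[ExampleId], second_ids: Sequence[ExampleId]
) -> tuple[list[int], list[int]]:
    """Return positions i, j such that first_ids[i] == second_ids[j]."""
    position_in_second = {ids: j for j, ids in enumerate(second_ids)}
    first_positions: list[int] = []
    second_positions: list[int] = []
    for i, ids in enumerate(first_ids):
        j = position_in_second.get(ids)
        if j is not None:  # this example has a partner in the other modality
            first_positions.append(i)
            second_positions.append(j)
    return first_positions, second_positions


# Text and RGB batches drawn from two datasets; (0, 0) and (1, 0) share an
# example_index but come from different datasets, so they must not be paired.
text_ids = [(0, 0), (0, 1), (1, 0)]
rgb_ids = [(1, 0), (0, 0)]
print(match_positions(text_ids, rgb_ids))  # ([0, 2], [1, 0])
```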

Modalities is an instance of the ModalityRegistry singleton class and serves as a global registry for all the modalities supported by mmlearn. It allows dot-style access to registered modalities and their properties. For example, the 'RGB' modality can be accessed using Modalities.RGB (its name, Modalities.RGB.name, is the string 'rgb'), and the 'target' property of the 'RGB' modality can be accessed using Modalities.RGB.target (returns the string 'rgb_target'). The registry also provides a method to register new modalities and their properties. For example, the following code snippet shows how to register a new 'DNA' modality:

from mmlearn.datasets.core import Modalities

Modalities.register_modality("dna")
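Here is a minimal sketch of how a registry with dot-style access and derived property keys can work. ToyModality and ToyModalityRegistry are hypothetical, and the default property names below are an assumption. mmlearn's actual set of properties may differ.

```python
from collections.abc import Iterable
from typing import Any

# Assumed default property suffixes; mmlearn's actual set may differ.
DEFAULT_PROPERTIES = ("target", "mask", "attention_mask")


class ToyModality:
    """A modality name plus derived, modality-specific keys such as 'rgb_target'."""

    def __init__(self, name: str, properties: Iterable[str]) -> None:
        self.name = name
        for prop in properties:
            setattr(self, prop, f"{name}_{prop}")


class ToyModalityRegistry:
    """Hypothetical singleton registry offering dot-style access (registry.RGB)."""

    _instance: "ToyModalityRegistry | None" = None

    def __new__(cls) -> "ToyModalityRegistry":
        # Every call returns the same instance, so the registry is global.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._modalities = {}
        return cls._instance

    def register_modality(
        self, name: str, properties: Iterable[str] = DEFAULT_PROPERTIES
    ) -> None:
        self._modalities[name.lower()] = ToyModality(name.lower(), properties)

    def __getattr__(self, key: str) -> Any:
        # Read _modalities via __dict__ to avoid recursing into __getattr__.
        modalities = self.__dict__.get("_modalities", {})
        if key.lower() in modalities:
            return modalities[key.lower()]
        raise AttributeError(key)
```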

Adding New Modules

Modules are the building blocks for models and tasks in mmlearn. They include encoders, layers, losses, optimizers, learning rate schedulers, metrics and more. Modules in mmlearn are generally defined by extending PyTorch’s nn.Module class.

Users have the flexibility to design new modules according to their requirements, with the exception of encoder modules and modules associated with specific pre-defined tasks (e.g., loss functions for the ContrastivePretraining task). The forward method of encoder modules must accept a dictionary as input, where the keys are the names of the modalities and the values are the corresponding (batched) tensors/data. This format makes it easier to reuse the encoder with different modalities and different tasks. In addition, the forward method must return a list-like object where the first element is the last layer’s output. The following code snippet shows how to define a new text encoder module:

import torch
from torch import nn

from mmlearn.datasets.core import Modalities


class MyTextEncoder(nn.Module):
   def __init__(self, input_dim: int, output_dim: int):
      super().__init__()
      self.encoder = ...

   def forward(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor]:
      out = self.encoder(
         inputs[Modalities.TEXT.name],
         inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
         ),
      )
      return (out,)
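The input/output contract can be shown without PyTorch. ToyBagOfWordsEncoder is a hypothetical, list-based toy, and the "text" key stands in for Modalities.TEXT.name. It takes a dictionary keyed by modality name and returns a tuple whose first element is the last layer's output.

```python
from typing import Any

TEXT_KEY = "text"  # hypothetical stand-in for Modalities.TEXT.name


class ToyBagOfWordsEncoder:
    """Torch-free toy that follows the encoder contract: dict in, tuple out."""

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size

    def forward(self, inputs: dict[str, Any]) -> tuple[list[list[int]], list[list[int]]]:
        token_ids: list[list[int]] = inputs[TEXT_KEY]
        # Fall back to "attend to everything" when no mask is supplied.
        mask = inputs.get("attention_mask") or [[1] * len(seq) for seq in token_ids]
        # Intermediate "layer": keep only the unmasked tokens.
        kept = [[t for t, m in zip(seq, row) if m] for seq, row in zip(token_ids, mask)]
        # Final "layer": token counts per vocabulary entry.
        counts = [[seq.count(v) for v in range(self.vocab_size)] for seq in kept]
        # The first element must be the last layer's output.
        return counts, kept
```

Because every encoder reads its inputs from the same kind of dictionary, a task can call any encoder the same way and take `outputs[0]` as the final representation.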

For modules associated with pre-defined tasks, the new modules must adhere to the same function signature as the existing modules for that task. For instance, the forward method of a new loss function for the ContrastivePretraining task must accept the arguments shown below (written as a standalone function for brevity) to be compatible with the existing loss functions for the task:

import torch

from mmlearn.tasks.contrastive_pretraining import LossPairSpec

def my_contrastive_loss(
   embeddings: dict[str, torch.Tensor],
   example_ids: dict[str, torch.Tensor],
   logit_scale: torch.Tensor,
   modality_loss_pairs: list[LossPairSpec],
) -> torch.Tensor:
   ...
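For intuition about what such a loss might compute, here is a torch-free sketch of a CLIP-style symmetric cross-entropy for a single modality pair, using plain lists instead of tensors. It does not reproduce the signature above (there are no example_ids or LossPairSpec here). It assumes row i of both inputs forms a matched pair, and it is not necessarily what mmlearn's ContrastiveLoss computes.

```python
import math
from collections.abc import Sequence


def _dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def _mean_cross_entropy_on_diagonal(logits: list[list[float]]) -> float:
    # Row i treats column i as the correct class.
    total = 0.0
    for i, row in enumerate(logits):
        peak = max(row)  # subtract the max for numerical stability
        log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def toy_symmetric_contrastive_loss(
    first: Sequence[Sequence[float]],
    second: Sequence[Sequence[float]],
    logit_scale: float,
) -> float:
    """Average of first->second and second->first cross-entropy losses."""
    logits = [[logit_scale * _dot(a, b) for b in second] for a in first]
    logits_t = [list(col) for col in zip(*logits)]
    return 0.5 * (
        _mean_cross_entropy_on_diagonal(logits)
        + _mean_cross_entropy_on_diagonal(logits_t)
    )
```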

Adding New Tasks

Tasks in mmlearn represent the different training and/or evaluation objectives that can be performed on the data using the different modules. Tasks that require training should extend the TrainingTask class, while tasks involving only evaluation should extend the EvaluationHooks class.

Training Tasks

The TrainingTask class extends PyTorch Lightning’s LightningModule class, which in turn extends PyTorch’s nn.Module class. It provides a common interface for training tasks in mmlearn. It lets users define the training, validation and test loops, along with the setup for the model, optimizer, learning rate scheduler and loss function, all in one place (a capability inherited from PyTorch Lightning). The class also provides hooks for customizing the training, validation and test loops. It also offers other features such as logging, checkpointing and support for distributed training.

See also

For more information on the features and capabilities of the TrainingTask class inherited from PyTorch Lightning, please refer to the PyTorch Lightning documentation.

To be used with the PyTorch Lightning Trainer, extensions of the TrainingTask class must define a training_step method. The following code snippet shows the minimum requirements for defining a new task in mmlearn:

from typing import Any, Optional, Union
from functools import partial

import torch

from mmlearn.tasks.base import TrainingTask

class MyTask(TrainingTask):
   def __init__(
      self,
      optimizer: Optional[partial[torch.optim.Optimizer]],
      loss_fn: Optional[torch.nn.Module],
      lr_scheduler: Optional[
         Union[
            dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
            partial[torch.optim.lr_scheduler.LRScheduler],
         ]
      ] = None,
   ) -> None:
      super().__init__(optimizer=optimizer, loss_fn=loss_fn, lr_scheduler=lr_scheduler)

      # Since this class also inherits from torch.nn.Module, we can define the
      # model and its components directly in the constructor and also define
      # a forward method for the model as an instance method of this class.
      # Alternatively, we can pass the model as an argument to the constructor
      # and assign it to an instance variable.
      self.model = ...

   def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
      outputs = self.model(batch) # or self(batch) if a forward method is defined in this class

      # maybe process outputs here

      loss = self.loss_fn(outputs, ...)
      return loss

Evaluation Tasks

The EvaluationHooks class is intended to be used for evaluation tasks that don’t require training, e.g. zero-shot evaluation tasks (as opposed to evaluation tasks like linear probing, which require training). The class provides an interface for defining and customizing the evaluation loop.

Classes that inherit from EvaluationHooks cannot be run on their own. They must be used together with a training task, which calls the hooks defined in the evaluation task during the evaluation phase. This way, multiple evaluation tasks can be defined and used with the same training task. The training task also provides the model to be evaluated.

Training tasks that wish to use one or more evaluation tasks must accept an instance of the evaluation task(s) as an argument to the constructor and must define a validation_step and/or test_step method that calls the evaluation_step method of the evaluation task(s).

Creating and Configuring a Project

A project in mmlearn can be thought of as a collection of related experiments. Within a project, you can reuse components from mmlearn (e.g., datasets, models, tasks) or define new ones and use them all together for experiments.

To create a new project, create a new directory following the structure below:

my_project/
├── configs/
│   ├── __init__.py
│   └── experiment/
│       └── my_experiment.yaml
├── README.md (optional)
└── requirements.txt (optional)

The configs/ directory contains all of the project's configurations, both structured configs and YAML config files. The configs/experiment/ directory contains the .yaml files for the experiments in the project. These files use the Hydra configuration format, which also allows configuration values to be overridden from the command line.

The __init__.py file in the configs/ directory is required to make configs/ a Python package. This allows Hydra to compose configurations from .yaml files as well as structured configs from Python modules. More on this in the next section.

Optionally, you can also include a README.md file with a brief description of the project and a requirements.txt file with the dependencies required to run the project.

Specifying Configurable Components

One of the key features of the Hydra configuration system is the ability to compose configurations from multiple sources, including the command line, .yaml files and structured configs from Python modules. Structured configs in Hydra use Python dataclasses to define the configuration schema. This allows for both static and runtime type checking of the configuration. Hydra-zen extends Hydra to make it easy to dynamically generate dataclass-backed configurations for any class or function, simply by adding a decorator to the class or function.

mmlearn provides a pre-populated config store, external_store, which can be used as a decorator to register configurable components. This config store already contains configurations for common components like PyTorch optimizers, learning rate schedulers, loss functions and samplers, as well as PyTorch Lightning’s Trainer callbacks and loggers. To dynamically add new configurable components to the store, simply add the external_store decorator to the class or function definition.

For example, the following code snippet shows how to register a new dataset class:

from torch.utils.data.dataset import Dataset

from mmlearn.conf import external_store
from mmlearn.constants import EXAMPLE_INDEX_KEY
from mmlearn.datasets.core import Example, Modalities


@external_store(group="datasets")
class MyMapStyleDataset(Dataset[Example]):
   ...
   def __getitem__(self, idx: int) -> Example:
      ...
      return Example(
         {
            EXAMPLE_INDEX_KEY: idx,
            Modalities.TEXT.name: ...,
            Modalities.RGB.name: ...,
            Modalities.RGB.target: ...,
            Modalities.TEXT.mask: ...,
            # ... other modality-specific keys
         }
      )

The external_store decorator adds the class to the config store as soon as the Python interpreter loads the module containing the class. This is why the configs/ directory must be a Python package. It is also why modules containing user-defined configurable components must be imported in the configs/__init__.py file.
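The "register on import" behavior comes from how decorators work: the decorator body runs when the class statement executes, which happens while the module is being imported. The sketch below uses a hypothetical dictionary-backed store and decorator, not mmlearn's external_store, to show this side effect.

```python
from collections.abc import Callable
from typing import Any

# Hypothetical, minimal config store: (group, name) -> registered class or function.
TOY_STORE: dict[tuple[str, str], Any] = {}


def toy_external_store(group: str) -> Callable[[Any], Any]:
    """Hypothetical decorator that registers its target as a side effect."""

    def decorator(target: Any) -> Any:
        # Runs while the defining module executes, i.e. when it is imported.
        TOY_STORE[(group, target.__name__)] = target
        return target  # the class itself is returned unchanged

    return decorator


@toy_external_store(group="datasets")
class MyMapStyleDataset:
    pass


# In a real project, configs/__init__.py would import every module that
# defines decorated components, e.g. (hypothetical module path):
#     from my_project.datasets import my_dataset  # noqa: F401
```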

The group argument specifies the config group under which the configurable component will be registered. This allows users to easily reference the component in the configurations using the group name and the class name. The available config groups in mmlearn are:

  • datasets: Contains all the dataset classes.

  • datasets/masking: Contains all the configurable classes and functions for masking input data.

  • datasets/tokenizers: Contains all the configurable classes and functions for converting raw inputs to tokens.

  • datasets/transforms: Contains all the configurable classes and functions for transforming input data.

  • dataloader/sampler: Contains all the dataloader sampler classes.

  • modules/encoders: Contains all the encoder modules.

  • modules/layers: Contains layers that can be used independently of a model.

  • modules/losses: Contains all the loss functions.

  • modules/optimizers: Contains all the optimizers.

  • modules/lr_schedulers: Contains all the learning rate schedulers.

  • modules/metrics: Contains all the evaluation metrics.

  • tasks: Contains all the task classes.

  • trainer/callbacks: Contains all the PyTorch Lightning Trainer callbacks.

  • trainer/logger: Contains all the PyTorch Lightning Trainer loggers.

The Base Configuration

The base configuration for all experiments in mmlearn is defined in the MMLearnConf dataclass. It can be extended with additional configuration options, and its values can be overridden using Hydra’s override syntax.

The base configuration for mmlearn is shown below:

experiment_name: ???
job_type: train
seed: null
datasets:
   train: null
   val: null
   test: null
dataloader:
   train:
      _target_: torch.utils.data.dataloader.DataLoader
      _convert_: object
      dataset: ???
      batch_size: 1
      shuffle: null
      sampler: null
      batch_sampler: null
      num_workers: 0
      collate_fn:
         _target_: mmlearn.datasets.core.data_collator.DefaultDataCollator
         batch_processors: null
      pin_memory: true
      drop_last: false
      timeout: 0.0
      worker_init_fn: null
      multiprocessing_context: null
      generator: null
      prefetch_factor: null
      persistent_workers: false
      pin_memory_device: ''
   val:
      _target_: torch.utils.data.dataloader.DataLoader
      _convert_: object
      dataset: ???
      batch_size: 1
      shuffle: null
      sampler: null
      batch_sampler: null
      num_workers: 0
      collate_fn:
         _target_: mmlearn.datasets.core.data_collator.DefaultDataCollator
         batch_processors: null
      pin_memory: true
      drop_last: false
      timeout: 0.0
      worker_init_fn: null
      multiprocessing_context: null
      generator: null
      prefetch_factor: null
      persistent_workers: false
      pin_memory_device: ''
   test:
      _target_: torch.utils.data.dataloader.DataLoader
      _convert_: object
      dataset: ???
      batch_size: 1
      shuffle: null
      sampler: null
      batch_sampler: null
      num_workers: 0
      collate_fn:
         _target_: mmlearn.datasets.core.data_collator.DefaultDataCollator
         batch_processors: null
      pin_memory: true
      drop_last: false
      timeout: 0.0
      worker_init_fn: null
      multiprocessing_context: null
      generator: null
      prefetch_factor: null
      persistent_workers: false
      pin_memory_device: ''
task: ???
trainer:
   _target_: lightning.pytorch.trainer.trainer.Trainer
   accelerator: auto
   strategy: auto
   devices: auto
   num_nodes: 1
   precision: null
   logger: null
   callbacks: null
   fast_dev_run: false
   max_epochs: null
   min_epochs: null
   max_steps: -1
   min_steps: null
   max_time: null
   limit_train_batches: null
   limit_val_batches: null
   limit_test_batches: null
   limit_predict_batches: null
   overfit_batches: 0.0
   val_check_interval: null
   check_val_every_n_epoch: 1
   num_sanity_val_steps: null
   log_every_n_steps: null
   enable_checkpointing: true
   enable_progress_bar: true
   enable_model_summary: true
   accumulate_grad_batches: 1
   gradient_clip_val: null
   gradient_clip_algorithm: null
   deterministic: null
   benchmark: null
   inference_mode: true
   use_distributed_sampler: true
   profiler: null
   detect_anomaly: false
   barebones: false
   plugins: null
   sync_batchnorm: false
   reload_dataloaders_every_n_epochs: 0
   default_root_dir: ${hydra:runtime.output_dir}/checkpoints
tags:
   - ${experiment_name}
resume_from_checkpoint: null
strict_loading: true
torch_compile_kwargs:
   disable: true
   fullgraph: false
   dynamic: null
   backend: inductor
   mode: null
   options: null

Config keys with a value of ??? are placeholders (mandatory values) that must be set in the experiment configuration. The dataset key under each dataloader entry (train, val and test) is also a placeholder. However, it should not be set manually, because it is filled in automatically from the datasets group.
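To see which keys an experiment must actually provide, one can walk the configuration and collect every ??? value, skipping the keys that are filled in automatically. required_overrides is a hypothetical helper, not part of mmlearn or OmegaConf. It operates on a plain-dict excerpt of the base configuration above.

```python
from typing import Any

MISSING = "???"  # OmegaConf's marker for a mandatory value


def required_overrides(
    cfg: dict[str, Any], auto_filled: set[str], prefix: str = ""
) -> list[str]:
    """List dotted keys still set to '???', skipping automatically filled keys."""
    missing: list[str] = []
    for key, value in cfg.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            missing.extend(required_overrides(value, auto_filled, prefix=f"{path}."))
        elif value == MISSING and path not in auto_filled:
            missing.append(path)
    return missing


base_excerpt = {
    "experiment_name": "???",
    "seed": None,
    "dataloader": {"train": {"dataset": "???", "batch_size": 1}},
    "task": "???",
}
print(required_overrides(base_excerpt, {"dataloader.train.dataset"}))
# ['experiment_name', 'task']
```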

Configuring an Experiment

To configure an experiment, create a new .yaml file in the configs/experiment/ directory of the project. The configuration file should define the experiment-specific configuration options and override the base configuration options as needed. Configurable components from the config store can be referenced by name in the configuration file under the defaults list. The following code snippet shows an example configuration file for an experiment:

# @package _global_

defaults:
- /datasets@datasets.train.my_iterable: MyIterableStyleDataset
- /datasets@datasets.train.my_map: MyMapStyleDataset
- /modules/encoders@task.encoders.text: MyTextEncoder
- /modules/encoders@task.encoders.rgb: MyRGBEncoder
- /modules/losses@task.loss: ContrastiveLoss
- /modules/optimizers@task.optimizer: AdamW
- /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR
- /eval_task@task.evaluation_tasks.retrieval.task: ZeroShotCrossModalRetrieval
- /trainer/callbacks@trainer.callbacks.lr_monitor: LearningRateMonitor
- /trainer/callbacks@trainer.callbacks.model_checkpoint: ModelCheckpoint
- /trainer/callbacks@trainer.callbacks.early_stopping: EarlyStopping
- /trainer/callbacks@trainer.callbacks.model_summary: ModelSummary
- /trainer/logger@trainer.logger.wandb: WandbLogger
- override /task: ContrastivePretraining
- _self_

seed: 42

datasets:
   train:
      my_iterable:
         my_iterable_arg1: ...
      my_map:
         my_map_arg1: ...

dataloader:
   train:
      batch_size: 64

task:
   encoders:
      text:
         text_arg1: ...
      rgb:
         rgb_arg1: ...
   evaluation_tasks:
      retrieval:
         task:
            task_specs:
               - query_modality: text
                 target_modality: rgb
                 top_k: [10, 200]
               - query_modality: rgb
                 target_modality: text
                 top_k: [10, 200]
         run_on_validation: false
         run_on_test: true
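Each entry in the defaults list above has the form /group@package: option. The group selects where in the config store to look, the option names the stored component, and the package is the dotted location in the final config where it is placed. For example, task_specs is nested under task.evaluation_tasks.retrieval.task because that is the package of the ZeroShotCrossModalRetrieval entry. The helpers below are a hypothetical, simplified illustration of that mapping, not Hydra's parser. When no @package is given, they fall back to the group path, mirroring Hydra's default package for a group.

```python
from typing import Any


def parse_defaults_entry(entry: str) -> tuple[str, str, str]:
    """Split '[override ]/group[@package]: option' into (group, package, option)."""
    key, option = (part.strip() for part in entry.split(":", 1))
    key = key.removeprefix("override ").strip()
    group, _, package = key.lstrip("/").partition("@")
    # Without '@', fall back to the group path in dot notation.
    return group, package or group.replace("/", "."), option


def place(cfg: dict[str, Any], package: str, value: Any) -> None:
    """Store value at the dotted package path, creating nested dicts as needed."""
    *parents, leaf = package.split(".")
    node = cfg
    for name in parents:
        node = node.setdefault(name, {})
    node[leaf] = value
```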

Running an Experiment

To run an experiment locally, use the following command:

mmlearn_run 'hydra.searchpath=[pkg://path.to.my_project.configs]' \
   +experiment=my_experiment \
   experiment_name=my_experiment_name

Tip

You can see the full config for an experiment without running it by adding the --help flag to the command.

# The task=my_task override is required for the command to run.
mmlearn_run 'hydra.searchpath=[pkg://path.to.my_project.configs]' \
   +experiment=my_experiment \
   experiment_name=my_experiment_name \
   task=my_task \
   --help

To run the experiment on a SLURM cluster, use the following command:

mmlearn_run --multirun \
   hydra.launcher.mem_per_cpu=5G \
   hydra.launcher.qos=your_qos \
   hydra.launcher.partition=your_partition \
   hydra.launcher.gres=gpu:4 \
   hydra.launcher.cpus_per_task=8 \
   hydra.launcher.tasks_per_node=4 \
   hydra.launcher.nodes=1 \
   hydra.launcher.stderr_to_stdout=true \
   hydra.launcher.timeout_min=720 \
   'hydra.searchpath=[pkg://path.to.my_project.configs]' \
   +experiment=my_experiment \
   experiment_name=my_experiment_name

This uses Hydra's submitit launcher plugin (distributed separately from Hydra as hydra-submitit-launcher) to submit the experiment to the SLURM scheduler with the specified resources.
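It can help to work out what the launcher flags add up to before submitting. The helper below is a hypothetical back-of-the-envelope calculation using the values from the command above. Note that PyTorch Lightning's SLURM documentation asks for the number of tasks per node to match the Trainer's devices setting, so tasks_per_node=4 with gres=gpu:4 corresponds to one process per GPU.

```python
def slurm_request_summary(
    nodes: int,
    tasks_per_node: int,
    cpus_per_task: int,
    mem_per_cpu_gb: float,
    gpus_per_node: int,
) -> dict[str, float]:
    """Totals implied by a submitit/SLURM resource request (hypothetical helper)."""
    cpus_per_node = tasks_per_node * cpus_per_task
    return {
        "total_tasks": nodes * tasks_per_node,
        "cpus_per_node": cpus_per_node,
        "mem_per_node_gb": cpus_per_node * mem_per_cpu_gb,
        "gpus_per_task": gpus_per_node / tasks_per_node,
        "total_gpus": nodes * gpus_per_node,
    }


# Values from the command above: mem_per_cpu=5G, gres=gpu:4, cpus_per_task=8,
# tasks_per_node=4, nodes=1 -> 32 CPUs and 160 GB of memory on the node.
print(slurm_request_summary(1, 4, 8, 5, 4))
```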

Note

After the job is submitted, it is okay to cancel the program with Ctrl+C. The job will continue running on the cluster. You can also add & at the end of the command to run it in the background.