Getting Started

mmlearn contains a collection of tools and utilities to help researchers and practitioners easily set up and run training or evaluation experiments for multimodal representation learning methods. The toolkit is designed to be modular and extensible. We aim to provide a high degree of flexibility in using existing methods, while also allowing users to easily add support for new modalities of data, datasets, models and pretraining or evaluation methods.

Much of the power and flexibility of mmlearn comes from building on top of the PyTorch Lightning framework and using Hydra and hydra-zen for configuration management. Together, these tools make it easy to define and run experiments with different configurations, and to scale up experiments to run on a SLURM cluster.

The goal of this guide is to give you a brief overview of what mmlearn is and how you can get started using it.

Note

mmlearn currently only supports training and evaluation of encoder-only models.

For more detailed information on the features and capabilities of mmlearn, please refer to the API Reference.

Defining a Dataset

Datasets in mmlearn can be defined using PyTorch’s Dataset or IterableDataset classes. However, there are two additional requirements for datasets in mmlearn:

  1. The dataset must return an instance of Example from the __getitem__() method or the __iter__() method.

  2. The Example object returned by the dataset must contain the key 'example_index' and use modality-specific keys from the Modalities registry to store the data.

Example 1: Defining a map-style dataset in mmlearn:

from torch.utils.data.dataset import Dataset

from mmlearn.datasets.core import Example, Modalities
from mmlearn.constants import EXAMPLE_INDEX_KEY


class MyMapStyleDataset(Dataset[Example]):
   ...
   def __getitem__(self, idx: int) -> Example:
      ...
      return Example(
         {
            EXAMPLE_INDEX_KEY: idx,
            Modalities.TEXT.name: ...,
            Modalities.RGB.name: ...,
            Modalities.RGB.target: ...,
            Modalities.TEXT.mask: ...,
            ...
         }
      )

Example 2: Defining an iterable-style dataset in mmlearn:

from collections.abc import Generator

from torch.utils.data.dataset import IterableDataset

from mmlearn.datasets.core import Example, Modalities
from mmlearn.constants import EXAMPLE_INDEX_KEY


class MyIterableStyleDataset(IterableDataset[Example]):
   ...
   def __iter__(self) -> Generator[Example, None, None]:
      ...  # load or open the underlying data source as `items`
      idx = 0
      for item in items:
         yield Example(
            {
               EXAMPLE_INDEX_KEY: idx,
               Modalities.TEXT.name: ...,
               Modalities.AUDIO.name: ...,
               Modalities.TEXT.mask: ...,
               Modalities.AUDIO.mask: ...,
               ...
            }
         )
         idx += 1

The Example class represents a single example in the dataset and all the attributes associated with it. The class is an extension of the OrderedDict class that provides attribute-style access to the dictionary values and handles the creation of the 'example_ids' tuple, combining the 'example_index' and 'dataset_index' values.
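The behaviour described above can be sketched in a few lines of standard-library Python. This is a dict with attribute-style access that builds an 'example_ids' tuple once both indices are set. SimpleExample, the key constants, and the (dataset_index, example_index) ordering of the tuple are assumptions for illustration. This is not mmlearn's Example implementation.

```python
from collections import OrderedDict
from typing import Any

# Key names as given in the text above; assumed to match mmlearn's constants.
EXAMPLE_INDEX_KEY = "example_index"
DATASET_INDEX_KEY = "dataset_index"
EXAMPLE_IDS_KEY = "example_ids"


class SimpleExample(OrderedDict):
    """Hypothetical, simplified stand-in for mmlearn's Example class."""

    def __getattr__(self, key: str) -> Any:
        # Only reached when normal attribute lookup fails: fall back to the dict.
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key: str, value: Any) -> None:
        # Attribute assignment writes into the dict instead of __dict__.
        self[key] = value

    def __setitem__(self, key: Any, value: Any) -> None:
        super().__setitem__(key, value)
        if key in (EXAMPLE_INDEX_KEY, DATASET_INDEX_KEY):
            self._create_ids()

    def _create_ids(self) -> None:
        # Build the ids tuple once both indices exist (ordering is an assumption).
        if EXAMPLE_INDEX_KEY in self and DATASET_INDEX_KEY in self:
            super().__setitem__(
                EXAMPLE_IDS_KEY, (self[DATASET_INDEX_KEY], self[EXAMPLE_INDEX_KEY])
            )
```

For example, SimpleExample({"example_index": 5, "text": "hello"}).text gives "hello". The 'example_ids' key only appears after a dataset index is also assigned.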

Modalities is an instance of the ModalityRegistry singleton class, which serves as a global registry for all the modalities supported by mmlearn. It allows dot-style access to registered modalities and their properties. For example, the 'RGB' modality can be accessed using Modalities.RGB (whose name is the string 'rgb'), and the 'target' property of the 'RGB' modality can be accessed using Modalities.RGB.target (which returns the string 'rgb_target'). The registry also provides a method to register new modalities and their properties. For example, the following code snippet shows how to register a new 'DNA' modality:

from mmlearn.datasets.core import Modalities

Modalities.register_modality("dna")
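A registry like this might work as follows. It is a singleton whose unknown attribute lookups are routed to registered modalities, and each modality derives its property keys from its name. Modality, ModalityRegistry, and the property suffixes below are hypothetical stand-ins modelled on the behaviour described above. They are not mmlearn's actual classes.

```python
from typing import Dict


class Modality:
    """Hypothetical, simplified modality: a name plus derived key names."""

    # Property suffixes modelled on keys used in this guide (e.g. 'rgb_target').
    PROPERTY_SUFFIXES = ("target", "mask", "attention_mask")

    def __init__(self, name: str) -> None:
        self.name = name
        for suffix in self.PROPERTY_SUFFIXES:
            setattr(self, suffix, f"{name}_{suffix}")

    def __str__(self) -> str:
        return self.name


class ModalityRegistry:
    """Hypothetical singleton registry with dot-style access (registry.RGB)."""

    _instance = None
    _modalities: Dict[str, Modality]

    def __new__(cls) -> "ModalityRegistry":
        # Every call returns the same shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._modalities = {}
        return cls._instance

    def register_modality(self, name: str) -> None:
        self._modalities[name.lower()] = Modality(name.lower())

    def __getattr__(self, attr: str) -> Modality:
        # Only reached for unknown attributes, so RGB resolves to 'rgb'.
        try:
            return self._modalities[attr.lower()]
        except KeyError as exc:
            raise AttributeError(f"modality {attr!r} is not registered") from exc
```

Consider Modalities = ModalityRegistry() followed by Modalities.register_modality("dna"). After these calls, Modalities.DNA.mask evaluates to 'dna_mask'.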

Creating a Model

Models in mmlearn are generally defined by extending PyTorch’s nn.Module class. The input to the model’s forward method should be a dictionary, where the keys are the names of the modalities and the values are the corresponding (batched) tensors/data. Models must also return a list-like object (e.g., a tuple) whose first element is the output of the last layer.

import torch
from torch import nn

from mmlearn.datasets.core import Modalities


class MyTextEncoder(nn.Module):
   def __init__(self, input_dim: int, output_dim: int):
      super().__init__()
      self.encoder = ...

   def forward(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor]:
      out = self.encoder(
         inputs[Modalities.TEXT.name],
         inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
         ),
      )
      return (out,)

Passing a dictionary of the (batched) inputs to the model’s forward method makes it easier to reuse the same model for different tasks.
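The toy encoder below is a framework-free illustration of why the dictionary interface helps. It reads only the keys it needs and ignores everything else in the batch, so the same batch dictionary can be handed to different models. ToyTextEncoder, embed, and the key strings are hypothetical. A real mmlearn encoder would subclass nn.Module and operate on tensors.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical key names standing in for Modalities.TEXT.name and
# Modalities.TEXT.attention_mask.
TEXT_KEY = "text"
TEXT_MASK_KEY = "text_attention_mask"


class ToyTextEncoder:
    """Framework-free stand-in for an nn.Module text encoder."""

    def forward(self, inputs: Dict[str, Any]) -> Tuple[List[int], ...]:
        tokens = inputs[TEXT_KEY]
        # Same fallback order as MyTextEncoder: generic key, then modality key.
        mask = inputs.get("attention_mask", inputs.get(TEXT_MASK_KEY))
        if mask is None:
            mask = [[1] * len(row) for row in tokens]
        # "Encode" each sequence as the sum of its unmasked token ids.
        out = [
            sum(tok for tok, keep in zip(row, mask_row) if keep)
            for row, mask_row in zip(tokens, mask)
        ]
        return (out,)  # list-like; element 0 is the last layer's output


def embed(encoder: ToyTextEncoder, batch: Dict[str, Any]) -> List[int]:
    # A task passes the whole batch and keeps only the first output.
    return encoder.forward(batch)[0]
```

Extra keys such as 'rgb' or 'example_index' in the batch are simply ignored. This is what lets one batch dictionary serve several models.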

Creating and Configuring a Project

A project in mmlearn can be thought of as a collection of related experiments. Within a project, you can reuse components from mmlearn (e.g., datasets, models, tasks) or define new ones and use them all together for experiments.

To create a new project, create a new directory following the structure below:

my_project/
├── configs/
│   ├── __init__.py
│   └── experiment/
│       └── my_experiment.yaml
├── README.md (optional)
└── requirements.txt (optional)

The configs/ directory contains all the configurations for the experiments in the project, both structured configs and YAML config files. The configs/experiment/ directory contains the .yaml files for the experiments associated with the project. These .yaml files use the Hydra configuration format, which also allows configuration values to be overridden from the command line.

The __init__.py file in the configs/ directory is required to make the configs/ directory a Python package, allowing Hydra to compose configurations from .yaml files as well as structured configs from Python modules. More on this in the next section.

Optionally, you can also include a README.md file with a brief description of the project and a requirements.txt file with the dependencies required to run the project.

Specifying Configurable Components

One of the key features of the Hydra configuration system is the ability to compose configurations from multiple sources, including the command line, .yaml files and structured configs from Python modules. Structured configs in Hydra use Python dataclasses to define the configuration schema, which allows for both static and runtime type-checking of the configuration. Hydra-zen extends Hydra to make it easy to dynamically generate dataclass-backed configurations for any class or function, simply by adding a decorator to the class or function.
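The sketch below shows what a dataclass-backed configuration generated from a class might look like. It inspects a constructor's signature and builds a matching dataclass with a `_target_` field, which is the key Hydra uses to locate the object to instantiate. builds_config is a hypothetical, simplified analogue written with the standard library. It is not hydra-zen's API and omits most of what hydra-zen does.

```python
import dataclasses
import inspect
from typing import Any, Callable, List, Tuple


def builds_config(target: Callable[..., Any]) -> type:
    """Return a dataclass whose fields mirror ``target``'s parameters.

    Hypothetical illustration only; mutable defaults (lists, dicts) are
    not handled and would make ``make_dataclass`` raise.
    """
    required: List[Tuple[Any, ...]] = []
    optional: List[Tuple[Any, ...]] = []
    for name, param in inspect.signature(target).parameters.items():
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue  # *args / **kwargs have no fixed field to map to
        annotation = Any if param.annotation is param.empty else param.annotation
        if param.default is param.empty:
            required.append((name, annotation))
        else:
            optional.append((name, annotation, dataclasses.field(default=param.default)))
    # Hydra instantiates a config by importing the path stored in `_target_`.
    target_path = f"{target.__module__}.{target.__qualname__}"
    optional.append(("_target_", str, dataclasses.field(default=target_path)))
    # Required fields must precede fields with defaults in a dataclass.
    return dataclasses.make_dataclass(f"{target.__name__}Conf", required + optional)
```

Take a class whose constructor is __init__(self, root: str, split: str = "train"). For such a class, builds_config returns a dataclass with the fields root, split and _target_. Only root has to be supplied.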

mmlearn provides a pre-populated config store, external_store, which can be used as a decorator to register configurable components. This config store already contains configurations for common components like PyTorch optimizers, learning rate schedulers, loss functions and samplers, as well as PyTorch Lightning’s Trainer callbacks and loggers. To dynamically add new configurable components to the store, simply add the external_store decorator to the class or function definition.

For example, the following code snippet shows how to register a new dataset class:

from torch.utils.data.dataset import Dataset

from mmlearn.conf import external_store
from mmlearn.constants import EXAMPLE_INDEX_KEY
from mmlearn.datasets.core import Example, Modalities


@external_store(group="datasets")
class MyMapStyleDataset(Dataset[Example]):
   ...
   def __getitem__(self, idx: int) -> Example:
      ...
      return Example(
         {
            EXAMPLE_INDEX_KEY: idx,
            Modalities.TEXT.name: ...,
            Modalities.RGB.name: ...,
            Modalities.RGB.target: ...,
            Modalities.TEXT.mask: ...,
            ...
         }
      )

The external_store decorator adds the class to the config store as soon as the Python interpreter loads the module containing the class. This is why the configs/ directory must be a Python package and why modules containing user-defined configurable components must be imported in the configs/__init__.py file.
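A toy decorator can demonstrate the import-time side effect. It writes into a dictionary keyed by (group, name), with the name defaulting to the class name. For simplicity it stores the object itself, whereas a real config store would keep a generated config. toy_store and CONFIG_STORE are hypothetical stand-ins, not mmlearn's external_store.

```python
from typing import Any, Callable, Dict, Optional, Tuple

# Hypothetical global store keyed by (group, name); a toy stand-in for a
# hydra-zen store, not mmlearn's external_store.
CONFIG_STORE: Dict[Tuple[str, str], Any] = {}


def toy_store(group: str, name: Optional[str] = None) -> Callable[[Any], Any]:
    """Register the decorated class or function under (group, name)."""

    def decorator(obj: Any) -> Any:
        # Runs when the class/def statement executes, i.e. when the module is imported.
        CONFIG_STORE[(group, name or obj.__name__)] = obj
        return obj

    return decorator
```

Suppose the module defining a decorated class is never imported. Then the decorator never runs, and no (group, name) entry ever appears in the store.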

The group argument specifies the config group under which the configurable component will be registered. This allows users to easily reference the component in the configurations using the group name and the class name. The available config groups in mmlearn are:

  • datasets: Contains all the dataset classes.

  • datasets/masking: Contains all the configurable classes and functions for masking input data.

  • datasets/tokenizers: Contains all the configurable classes and functions for converting raw inputs to tokens.

  • datasets/transforms: Contains all the configurable classes and functions for transforming input data.

  • dataloader/sampler: Contains all the dataloader sampler classes.

  • modules/encoders: Contains all the encoder modules.

  • modules/layers: Contains layers that can be used independently of the model.

  • modules/losses: Contains all the loss functions.

  • modules/optimizers: Contains all the optimizers.

  • modules/lr_schedulers: Contains all the learning rate schedulers.

  • modules/metrics: Contains all the evaluation metrics.

  • tasks: Contains all the task classes.

  • trainer/callbacks: Contains all the PyTorch Lightning Trainer callbacks.

  • trainer/logger: Contains all the PyTorch Lightning Trainer loggers.

The Base Configuration

The base configuration for all experiments in mmlearn is defined in the MMLearnConf dataclass. It can be extended to include additional configuration options, following Hydra’s override syntax.
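The simplified function below sketches the basic dotted-key override rules. A plain key=value must target an existing key, a + prefix adds a key that must not exist yet, and ++ adds or overrides. apply_overrides is hypothetical and keeps all values as strings. Hydra's real override grammar also parses types and lists, supports deletion with ~key, and handles config group selections such as +experiment=my_experiment.

```python
from typing import Any, Dict, List


def apply_overrides(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply simplified Hydra-style overrides to a nested dict, in place.

    key=value    overrides an existing key
    +key=value   adds a key that must not exist yet
    ++key=value  adds the key, or overrides it if present
    """
    for override in overrides:
        if override.startswith("++"):
            mode, body = "force", override[2:]
        elif override.startswith("+"):
            mode, body = "add", override[1:]
        else:
            mode, body = "set", override
        dotted_key, value = body.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for part in parents:
            # Only additions may create intermediate levels.
            node = node[part] if mode == "set" else node.setdefault(part, {})
        if mode == "add" and leaf in node:
            raise KeyError(f"{dotted_key!r} already exists; drop the '+' to override it")
        if mode == "set" and leaf not in node:
            raise KeyError(f"{dotted_key!r} is not in the config; use '+{dotted_key}=...'")
        node[leaf] = value
    return cfg
```

Note that experiment_name=... works as a plain override because the base configuration already defines experiment_name, albeit as a ??? placeholder.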

The base configuration for mmlearn is shown below:

experiment_name: ???
job_type: train
seed: null
datasets:
   train: null
   val: null
   test: null
dataloader:
   train:
      _target_: torch.utils.data.dataloader.DataLoader
      _convert_: object
      dataset: ???
      batch_size: 1
      shuffle: null
      sampler: null
      batch_sampler: null
      num_workers: 0
      collate_fn:
         _target_: mmlearn.datasets.core.data_collator.DefaultDataCollator
         batch_processors: null
      pin_memory: true
      drop_last: false
      timeout: 0.0
      worker_init_fn: null
      multiprocessing_context: null
      generator: null
      prefetch_factor: null
      persistent_workers: false
      pin_memory_device: ''
   val:
      _target_: torch.utils.data.dataloader.DataLoader
      _convert_: object
      dataset: ???
      batch_size: 1
      shuffle: null
      sampler: null
      batch_sampler: null
      num_workers: 0
      collate_fn:
         _target_: mmlearn.datasets.core.data_collator.DefaultDataCollator
         batch_processors: null
      pin_memory: true
      drop_last: false
      timeout: 0.0
      worker_init_fn: null
      multiprocessing_context: null
      generator: null
      prefetch_factor: null
      persistent_workers: false
      pin_memory_device: ''
   test:
      _target_: torch.utils.data.dataloader.DataLoader
      _convert_: object
      dataset: ???
      batch_size: 1
      shuffle: null
      sampler: null
      batch_sampler: null
      num_workers: 0
      collate_fn:
         _target_: mmlearn.datasets.core.data_collator.DefaultDataCollator
         batch_processors: null
      pin_memory: true
      drop_last: false
      timeout: 0.0
      worker_init_fn: null
      multiprocessing_context: null
      generator: null
      prefetch_factor: null
      persistent_workers: false
      pin_memory_device: ''
task: ???
trainer:
   _target_: lightning.pytorch.trainer.trainer.Trainer
   accelerator: auto
   strategy: auto
   devices: auto
   num_nodes: 1
   precision: null
   logger: null
   callbacks: null
   fast_dev_run: false
   max_epochs: null
   min_epochs: null
   max_steps: -1
   min_steps: null
   max_time: null
   limit_train_batches: null
   limit_val_batches: null
   limit_test_batches: null
   limit_predict_batches: null
   overfit_batches: 0.0
   val_check_interval: null
   check_val_every_n_epoch: 1
   num_sanity_val_steps: null
   log_every_n_steps: null
   enable_checkpointing: true
   enable_progress_bar: true
   enable_model_summary: true
   accumulate_grad_batches: 1
   gradient_clip_val: null
   gradient_clip_algorithm: null
   deterministic: null
   benchmark: null
   inference_mode: true
   use_distributed_sampler: true
   profiler: null
   detect_anomaly: false
   barebones: false
   plugins: null
   sync_batchnorm: false
   reload_dataloaders_every_n_epochs: 0
   default_root_dir: ${hydra:runtime.output_dir}/checkpoints
tags:
   - ${experiment_name}
resume_from_checkpoint: null
strict_loading: true
torch_compile_kwargs:
   disable: true
   fullgraph: false
   dynamic: null
   backend: inductor
   mode: null
   options: null

The config keys with a value of ??? are placeholders that must be overridden in the experiment configurations. While the dataset key in the dataloader group is also a placeholder, it should not be provided as it will be automatically filled in from the datasets group.
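The placeholder rule can be checked mechanically. The sketch below lists every dotted path still set to ???. It also copies each non-null datasets.<split> entry into dataloader.<split>.dataset. Both functions are hypothetical illustrations that operate on plain dictionaries. mmlearn and Hydra (via OmegaConf) perform these steps with their own machinery, which may differ.

```python
from typing import Any, Dict, List

MISSING = "???"  # Hydra/OmegaConf marker for a mandatory value


def missing_keys(cfg: Dict[str, Any], prefix: str = "") -> List[str]:
    """Return the dotted paths of every value still set to '???'."""
    paths: List[str] = []
    for key, value in cfg.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            paths.extend(missing_keys(value, prefix=f"{path}."))
        elif value == MISSING:
            paths.append(path)
    return paths


def fill_dataloader_datasets(cfg: Dict[str, Any]) -> None:
    """Copy each non-null datasets.<split> entry into dataloader.<split>.dataset.

    Hypothetical illustration; splits whose dataset is null are left untouched.
    """
    for split, dataset in cfg["datasets"].items():
        if dataset is not None:
            cfg["dataloader"][split]["dataset"] = dataset
```

After the fill-in step, the remaining ??? paths (such as experiment_name and task) are exactly the ones an experiment configuration must provide.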

Running an Experiment

To run an experiment locally, use the following command:

mmlearn_run 'hydra.searchpath=[pkg://path.to.my_project.configs]' \
   +experiment=my_experiment \
   experiment_name=my_experiment_name

Tip

You can see the full config for an experiment without running it by adding the --help flag to the command.

# task=my_task is required for the command to run
mmlearn_run 'hydra.searchpath=[pkg://path.to.my_project.configs]' \
   +experiment=my_experiment \
   experiment_name=my_experiment_name \
   task=my_task \
   --help

To run the experiment on a SLURM cluster, use the following command:

mmlearn_run --multirun \
   hydra.launcher.mem_per_cpu=5G \
   hydra.launcher.qos=your_qos \
   hydra.launcher.partition=your_partition \
   hydra.launcher.gres=gpu:4 \
   hydra.launcher.cpus_per_task=8 \
   hydra.launcher.tasks_per_node=4 \
   hydra.launcher.nodes=1 \
   hydra.launcher.stderr_to_stdout=true \
   hydra.launcher.timeout_min=720 \
   'hydra.searchpath=[pkg://path.to.my_project.configs]' \
   +experiment=my_experiment \
   experiment_name=my_experiment_name

This uses Hydra’s submitit launcher plugin to submit the experiment to the SLURM scheduler with the specified resources.
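Here is a quick worked calculation of what the launcher flags above request. It assumes standard SLURM semantics: memory is requested per allocated CPU, and each of the four tasks per node drives one of the four GPUs. slurm_allocation is a hypothetical helper, not part of mmlearn or submitit.

```python
def slurm_allocation(
    nodes: int,
    tasks_per_node: int,
    cpus_per_task: int,
    mem_per_cpu_gb: float,
    gpus_per_node: int,
) -> dict:
    """Totals implied by the launcher flags (standard SLURM semantics assumed)."""
    cpus_per_node = tasks_per_node * cpus_per_task
    return {
        "world_size": nodes * tasks_per_node,  # total processes across all nodes
        "gpus_per_task": gpus_per_node / tasks_per_node,  # 1.0 means one GPU per process
        "cpus_per_node": cpus_per_node,
        # --mem-per-cpu is multiplied by the number of CPUs allocated on the node.
        "mem_per_node_gb": cpus_per_node * mem_per_cpu_gb,
    }
```

The command above uses 1 node, 4 tasks per node, 8 CPUs per task, 5 GB per CPU and 4 GPUs. With these values, the request amounts to 4 processes, one GPU each, and 32 CPUs and 160 GB of memory per node.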

Note

After the job is submitted, it is safe to stop the command with Ctrl+C; the job will continue running on the cluster. You can also add & at the end of the command to run it in the background.