API Reference

Top Level Module

mmlearn

Multimodal learning library.

cli

Command Line Interface for mmlearn.

run

Main entry point for training and evaluation.

CombinedDataset

Bases: Dataset[Example]

Combine multiple datasets into one.

This class is similar to torch.utils.data.ConcatDataset but allows for combining iterable-style datasets with map-style datasets. The iterable-style datasets must implement the __len__ method, which is used to determine the total length of the combined dataset. When an index is passed to the combined dataset, the dataset that contains the example at that index is determined and the example is retrieved from that dataset. Since iterable-style datasets do not support random access, the examples are retrieved sequentially from the iterable-style datasets. When the end of an iterable-style dataset is reached, the iterator is reset and the next example is retrieved from the beginning of the dataset.

Parameters:

  • datasets (Iterable[Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]]), required: Iterable of datasets to combine.

Raises:

  • TypeError: If any of the datasets in the input iterable are not instances of torch.utils.data.Dataset or torch.utils.data.IterableDataset.
  • ValueError: If the input iterable of datasets is empty.
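
For illustration, a minimal sketch of combining a map-style and an iterable-style dataset; the toy dataset classes and the import path are assumptions, and both return Example objects since __getitem__ requires them.

from torch.utils.data import Dataset, IterableDataset

from mmlearn.datasets.core import CombinedDataset, Example  # assumed import path


class ToyMapDataset(Dataset):
    """Map-style toy dataset that returns Example objects."""

    def __len__(self) -> int:
        return 10

    def __getitem__(self, idx: int) -> Example:
        return Example({"text": f"map-{idx}", "example_index": idx})


class ToyIterableDataset(IterableDataset):
    """Iterable-style toy dataset; it must also define __len__ for CombinedDataset."""

    def __len__(self) -> int:
        return 5

    def __iter__(self):
        for idx in range(len(self)):
            yield Example({"text": f"iter-{idx}", "example_index": idx})


combined = CombinedDataset([ToyMapDataset(), ToyIterableDataset()])
print(len(combined))      # 15
print(combined[0].text)   # indexed access into the map-style dataset
print(combined[12].text)  # served sequentially from the iterable-style dataset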

Source code in mmlearn/datasets/core/combined_dataset.py
class CombinedDataset(Dataset[Example]):
    """Combine multiple datasets into one.

    This class is similar to :py:class:`~torch.utils.data.ConcatDataset` but allows
    for combining iterable-style datasets with map-style datasets. The iterable-style
    datasets must implement the :meth:`__len__` method, which is used to determine the
    total length of the combined dataset. When an index is passed to the combined
    dataset, the dataset that contains the example at that index is determined and
    the example is retrieved from that dataset. Since iterable-style datasets do
    not support random access, the examples are retrieved sequentially from the
    iterable-style datasets. When the end of an iterable-style dataset is reached,
    the iterator is reset and the next example is retrieved from the beginning of
    the dataset.


    Parameters
    ----------
    datasets : Iterable[Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]]
        Iterable of datasets to combine.

    Raises
    ------
    TypeError
        If any of the datasets in the input iterable are not instances of
        :py:class:`~torch.utils.data.Dataset` or :py:class:`~torch.utils.data.IterableDataset`.
    ValueError
        If the input iterable of datasets is empty.

    """  # noqa: W505

    def __init__(
        self, datasets: Iterable[Union[Dataset[Example], IterableDataset[Example]]]
    ) -> None:
        self.datasets, _ = tree_flatten(datasets)
        if not all(
            isinstance(dataset, (Dataset, IterableDataset)) for dataset in self.datasets
        ):
            raise TypeError(
                "Expected argument `datasets` to be an iterable of `Dataset` or "
                f"`IterableDataset` instances, but found: {self.datasets}",
            )
        if len(self.datasets) == 0:
            raise ValueError(
                "Expected a non-empty iterable of datasets but found an empty iterable",
            )

        self._cumulative_sizes: list[int] = np.cumsum(
            [len(dataset) for dataset in self.datasets]
        ).tolist()
        self._iterators: list[Iterator[Example]] = []
        self._iter_dataset_mapping: dict[int, int] = {}

        # create iterators for iterable datasets and map dataset index to iterator index
        for idx, dataset in enumerate(self.datasets):
            if isinstance(dataset, IterableDataset):
                self._iterators.append(iter(dataset))
                self._iter_dataset_mapping[idx] = len(self._iterators) - 1

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the combined dataset."""
        if idx < 0:  # handle negative indices
            if -idx > len(self):
                raise IndexError(
                    f"Index {idx} is out of bounds for the combined dataset with "
                    f"length {len(self)}",
                )
            idx = len(self) + idx

        dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

        curr_dataset = self.datasets[dataset_idx]
        if isinstance(curr_dataset, IterableDataset):
            iter_idx = self._iter_dataset_mapping[dataset_idx]
            try:
                example = next(self._iterators[iter_idx])
            except StopIteration:
                self._iterators[iter_idx] = iter(curr_dataset)
                example = next(self._iterators[iter_idx])
        else:
            if dataset_idx == 0:
                example_idx = idx
            else:
                example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
            example = curr_dataset[example_idx]

        if not isinstance(example, Example):
            raise TypeError(
                "Expected dataset examples to be instances of `Example` "
                f"but found {type(example)}",
            )

        if not hasattr(example, "dataset_index"):
            example.dataset_index = dataset_idx
        if not hasattr(example, "example_ids"):
            example.create_ids()

        return example

    def __len__(self) -> int:
        """Return the total number of examples in the combined dataset."""
        return self._cumulative_sizes[-1]
__getitem__
__getitem__(idx)

Return an example from the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the combined dataset."""
    if idx < 0:  # handle negative indices
        if -idx > len(self):
            raise IndexError(
                f"Index {idx} is out of bounds for the combined dataset with "
                f"length {len(self)}",
            )
        idx = len(self) + idx

    dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

    curr_dataset = self.datasets[dataset_idx]
    if isinstance(curr_dataset, IterableDataset):
        iter_idx = self._iter_dataset_mapping[dataset_idx]
        try:
            example = next(self._iterators[iter_idx])
        except StopIteration:
            self._iterators[iter_idx] = iter(curr_dataset)
            example = next(self._iterators[iter_idx])
    else:
        if dataset_idx == 0:
            example_idx = idx
        else:
            example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
        example = curr_dataset[example_idx]

    if not isinstance(example, Example):
        raise TypeError(
            "Expected dataset examples to be instances of `Example` "
            f"but found {type(example)}",
        )

    if not hasattr(example, "dataset_index"):
        example.dataset_index = dataset_idx
    if not hasattr(example, "example_ids"):
        example.create_ids()

    return example
__len__
__len__()

Return the total number of examples in the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __len__(self) -> int:
    """Return the total number of examples in the combined dataset."""
    return self._cumulative_sizes[-1]
DefaultDataCollator dataclass

Default data collator for batching examples.

This data collator will collate a list of mmlearn.datasets.core.example.Example objects into a batch. It can also apply processing functions to specified keys in the batch before returning it.

Parameters:

  • batch_processors (Optional[dict[str, Callable[[Any], Any]]]), default=None: Dictionary of callables to apply to the batch before returning it.

Raises:

  • ValueError: If the batch processor for a key does not return a dictionary with the key in it.
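
A minimal sketch of the collator with a batch processor; the "length" key, the processor and the import path are illustrative assumptions.

from mmlearn.datasets.core import DefaultDataCollator, Example  # assumed import path


def double_length(lengths):
    # Toy batch processor: receives the collated value stored under "length".
    return [2 * n for n in lengths]


collator = DefaultDataCollator(batch_processors={"length": double_length})

examples = [Example({"length": 3}), Example({"length": 5})]
batch = collator(examples)  # collated dict; batch["length"] has been passed through double_length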

Source code in mmlearn/datasets/core/data_collator.py
@dataclass
class DefaultDataCollator:
    """Default data collator for batching examples.

    This data collator will collate a list of :py:class:`~mmlearn.datasets.core.example.Example`
    objects into a batch. It can also apply processing functions to specified keys
    in the batch before returning it.

    Parameters
    ----------
    batch_processors : Optional[dict[str, Callable[[Any], Any]]], optional, default=None
        Dictionary of callables to apply to the batch before returning it.

    Raises
    ------
    ValueError
        If the batch processor for a key does not return a dictionary with the
        key in it.
    """  # noqa: W505

    #: Dictionary of callables to apply to the batch before returning it.
    #: The key is the name of the key in the batch, and the value is the processing
    #: function to apply to the key. The processing function must take a single
    #: argument and return a single value. If the processing function returns
    #: a dictionary, it must contain the key that was processed in it (all the
    #: other keys will also be included in the batch).
    batch_processors: Optional[dict[str, Callable[[Any], Any]]] = None

    def __call__(self, examples: list[Example]) -> dict[str, Any]:
        """Collate a list of `Example` objects and apply processing functions."""
        batch = collate_example_list(examples)

        if self.batch_processors is not None:
            for key, processor in self.batch_processors.items():
                batch_key: str = key
                if Modalities.has_modality(key):
                    batch_key = Modalities.get_modality(key).name

                if batch_key in batch:
                    batch_processed = processor(batch[batch_key])
                    if isinstance(batch_processed, Mapping):
                        if batch_key not in batch_processed:
                            raise ValueError(
                                f"Batch processor for '{key}' key must return a dictionary "
                                f"with '{batch_key}' in it."
                            )
                        batch.update(batch_processed)
                    else:
                        batch[batch_key] = batch_processed

        return batch
__call__
__call__(examples)

Collate a list of Example objects and apply processing functions.

Source code in mmlearn/datasets/core/data_collator.py
def __call__(self, examples: list[Example]) -> dict[str, Any]:
    """Collate a list of `Example` objects and apply processing functions."""
    batch = collate_example_list(examples)

    if self.batch_processors is not None:
        for key, processor in self.batch_processors.items():
            batch_key: str = key
            if Modalities.has_modality(key):
                batch_key = Modalities.get_modality(key).name

            if batch_key in batch:
                batch_processed = processor(batch[batch_key])
                if isinstance(batch_processed, Mapping):
                    if batch_key not in batch_processed:
                        raise ValueError(
                            f"Batch processor for '{key}' key must return a dictionary "
                            f"with '{batch_key}' in it."
                        )
                    batch.update(batch_processed)
                else:
                    batch[batch_key] = batch_processed

    return batch
Example

Bases: OrderedDict[Any, Any]

A representation of a single example from a dataset.

This class is a subclass of collections.OrderedDict and provides attribute-style access. This means that example["text"] and example.text are equivalent. All datasets in this library return examples as mmlearn.datasets.core.example.Example objects.

Parameters:

  • init_dict (Optional[MutableMapping[Hashable, Any]]), default=None: Dictionary to init Example class with.

Examples:

>>> example = Example({"text": torch.tensor(2)})
>>> example.text.zero_()
tensor(0)
>>> example.context = torch.tensor(4)  # set custom attributes after initialization
Source code in mmlearn/datasets/core/example.py
class Example(OrderedDict[Any, Any]):
    """A representation of a single example from a dataset.

    This class is a subclass of :py:class:`~collections.OrderedDict` and provides
    attribute-style access. This means that `example["text"]` and `example.text`
    are equivalent. All datasets in this library return examples as
    :py:class:`~mmlearn.datasets.core.example.Example` objects.


    Parameters
    ----------
    init_dict : Optional[MutableMapping[Hashable, Any]], optional, default=None
        Dictionary to init `Example` class with.

    Examples
    --------
    >>> example = Example({"text": torch.tensor(2)})
    >>> example.text.zero_()
    tensor(0)
    >>> example.context = torch.tensor(4)  # set custom attributes after initialization
    """

    def __init__(
        self,
        init_dict: Optional[MutableMapping[Hashable, Any]] = None,
    ) -> None:
        if init_dict is None:
            init_dict = {}
        super().__init__(init_dict)

    def create_ids(self) -> None:
        """Create a unique id for the example from the dataset and example index.

        This method combines the dataset index and example index to create an
        attribute called `example_ids`, which is a dictionary of tensors. The
        dictionary keys are all the keys in the example except for `example_ids`,
        `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
        containing the tuple `(dataset_index, example_index)` for each key.
        The `example_ids` is used to (re-)identify pairs of examples from different
        modalities after they have been combined into a batch.

        Warns
        -----
        UserWarning
            If the `example_index` and `dataset_index` attributes are not set.

        Notes
        -----
        - The Example must have the following attributes set before calling
          this method: `example_index` (usually set/returned by the dataset) and
          `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
        - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
          function can be used to find matching examples given two tensors of example ids.

        """  # noqa: W505
        if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
            self.example_ids = {
                key: torch.tensor([self.dataset_index, self.example_index])
                for key in self.keys()
                if key not in ("example_ids", "example_index", "dataset_index")
            }
        else:
            rank_zero_warn(
                "Cannot create `example_ids` without `example_index` and `dataset_index` "
                "attributes. Set these attributes before calling `create_ids`. "
                "No `example_ids` was created.",
                stacklevel=2,
                category=UserWarning,
            )

    def __getattr__(self, key: str) -> Any:
        """Get attribute by key."""
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key: str, value: Any) -> None:
        """Set attribute by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        self[key] = value

    def __setitem__(self, key: Hashable, value: Any) -> None:
        """Set item by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        super().__setitem__(key, value)
create_ids
create_ids()

Create a unique id for the example from the dataset and example index.

This method combines the dataset index and example index to create an attribute called example_ids, which is a dictionary of tensors. The dictionary keys are all the keys in the example except for example_ids, example_index, and dataset_index. The values are tensors of shape (2,) containing the tuple (dataset_index, example_index) for each key. The example_ids is used to (re-)identify pairs of examples from different modalities after they have been combined into a batch.

Warns:

  • UserWarning: If the example_index and dataset_index attributes are not set.

Notes
  • The Example must have the following attributes set before calling this method: example_index (usually set/returned by the dataset) and dataset_index (usually set by the mmlearn.datasets.core.combined_dataset.CombinedDataset object).
  • The mmlearn.datasets.core.example.find_matching_indices function can be used to find matching examples given two tensors of example ids.
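
A small sketch of how example_ids are produced; the keys, indices and import path below are assumptions for illustration.

import torch

from mmlearn.datasets.core import Example  # assumed import path

example = Example({"text": "a caption", "image": torch.zeros(3, 224, 224)})
example.example_index = 7  # normally set/returned by the dataset
example.dataset_index = 0  # normally set by CombinedDataset
example.create_ids()

# One (dataset_index, example_index) tensor per key, excluding the bookkeeping keys.
print(example.example_ids["text"])   # tensor([0, 7])
print(example.example_ids["image"])  # tensor([0, 7])
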
Source code in mmlearn/datasets/core/example.py
def create_ids(self) -> None:
    """Create a unique id for the example from the dataset and example index.

    This method combines the dataset index and example index to create an
    attribute called `example_ids`, which is a dictionary of tensors. The
    dictionary keys are all the keys in the example except for `example_ids`,
    `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
    containing the tuple `(dataset_index, example_index)` for each key.
    The `example_ids` is used to (re-)identify pairs of examples from different
    modalities after they have been combined into a batch.

    Warns
    -----
    UserWarning
        If the `example_index` and `dataset_index` attributes are not set.

    Notes
    -----
    - The Example must have the following attributes set before calling
      this method: `example_index` (usually set/returned by the dataset) and
      `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
    - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
      function can be used to find matching examples given two tensors of example ids.

    """  # noqa: W505
    if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
        self.example_ids = {
            key: torch.tensor([self.dataset_index, self.example_index])
            for key in self.keys()
            if key not in ("example_ids", "example_index", "dataset_index")
        }
    else:
        rank_zero_warn(
            "Cannot create `example_ids` without `example_index` and `dataset_index` "
            "attributes. Set these attributes before calling `create_ids`. "
            "No `example_ids` was created.",
            stacklevel=2,
            category=UserWarning,
        )
__getattr__
__getattr__(key)

Get attribute by key.

Source code in mmlearn/datasets/core/example.py
def __getattr__(self, key: str) -> Any:
    """Get attribute by key."""
    try:
        return self[key]
    except KeyError:
        raise AttributeError(key) from None
__setattr__
__setattr__(key, value)

Set attribute by key.

Source code in mmlearn/datasets/core/example.py
def __setattr__(self, key: str, value: Any) -> None:
    """Set attribute by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    self[key] = value
__setitem__
__setitem__(key, value)

Set item by key.

Source code in mmlearn/datasets/core/example.py
def __setitem__(self, key: Hashable, value: Any) -> None:
    """Set item by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    super().__setitem__(key, value)
CombinedDatasetRatioSampler

Bases: Sampler[int]

Sampler for weighted sampling from a mmlearn.datasets.core.combined_dataset.CombinedDataset.

Parameters:

  • dataset (CombinedDataset), required: An instance of mmlearn.datasets.core.combined_dataset.CombinedDataset to sample from.
  • ratios (Optional[Sequence[float]]), default=None: A sequence of ratios for sampling from each dataset in the combined dataset. The length of the sequence must be equal to the number of datasets in the combined dataset (dataset). If None, the length of each dataset in the combined dataset is used as the ratio. The ratios are normalized to sum to 1.
  • num_samples (Optional[int]), default=None: The number of samples to draw from the combined dataset. If None, the sampler will draw as many samples as there are in the combined dataset. This number must yield at least one sample per dataset in the combined dataset, when multiplied by the corresponding ratio.
  • replacement (bool), default=False: Whether to sample with replacement or not.
  • shuffle (bool), default=True: Whether to shuffle the sampled indices or not. If False, the indices of each dataset will appear in the order they are stored in the combined dataset. This is similar to sequential sampling from each dataset. The datasets that make up the combined dataset are still sampled randomly.
  • rank (Optional[int]), default=None: Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.
  • num_replicas (Optional[int]), default=None: Number of processes participating in distributed training. By default, num_replicas is retrieved from the current distributed group.
  • drop_last (bool), default=False: Whether to drop the last incomplete batch or not. If True, the sampler will drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.
  • seed (int), default=0: Random seed used when sampling from the combined dataset and shuffling the sampled indices.

Attributes:

  • dataset (CombinedDataset): The dataset to sample from.
  • num_samples (int): The number of samples to draw from the combined dataset.
  • probs (torch.Tensor): The probabilities for sampling from each dataset in the combined dataset. This is computed from the ratios argument and is normalized to sum to 1.
  • replacement (bool): Whether to sample with replacement or not.
  • shuffle (bool): Whether to shuffle the sampled indices or not.
  • rank (int): Rank of the current process within num_replicas.
  • num_replicas (int): Number of processes participating in distributed training.
  • drop_last (bool): Whether to drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.
  • seed (int): Random seed used when sampling from the combined dataset and shuffling the sampled indices.
  • epoch (int): Current epoch number. This is used to set the random seed. This is useful in distributed mode to ensure that each process receives a different random ordering of the samples.
  • total_size (int): The total number of samples across all processes.
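
A sketch of drawing roughly 70/30 from two datasets in a single-process run; it reuses the toy datasets from the CombinedDataset sketch above, and rank and num_replicas are passed explicitly because no distributed group is assumed to be initialized.

from torch.utils.data import DataLoader

from mmlearn.datasets.core import CombinedDataset, DefaultDataCollator  # assumed import paths
from mmlearn.datasets.core.samplers import CombinedDatasetRatioSampler

combined = CombinedDataset([ToyMapDataset(), ToyIterableDataset()])  # toy datasets from the sketch above

sampler = CombinedDatasetRatioSampler(
    combined,
    ratios=[0.7, 0.3],  # normalized to sum to 1
    num_samples=100,    # must yield at least one sample per dataset after applying the ratios
    replacement=True,   # the toy datasets are smaller than their share of samples
    rank=0,             # single process, so rank/num_replicas are set explicitly
    num_replicas=1,
)

loader = DataLoader(combined, batch_size=8, sampler=sampler, collate_fn=DefaultDataCollator())
for epoch in range(2):
    sampler.set_epoch(epoch)  # re-seeds the per-epoch sampling
    for batch in loader:
        ...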

Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class CombinedDatasetRatioSampler(Sampler[int]):
    """Sampler for weighted sampling from a :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`.

    Parameters
    ----------
    dataset : CombinedDataset
        An instance of :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`
        to sample from.
    ratios : Optional[Sequence[float]], optional, default=None
        A sequence of ratios for sampling from each dataset in the combined dataset.
        The length of the sequence must be equal to the number of datasets in the
        combined dataset (`dataset`). If `None`, the length of each dataset in the
        combined dataset is used as the ratio. The ratios are normalized to sum to 1.
    num_samples : Optional[int], optional, default=None
        The number of samples to draw from the combined dataset. If `None`, the
        sampler will draw as many samples as there are in the combined dataset.
        This number must yield at least one sample per dataset in the combined
        dataset, when multiplied by the corresponding ratio.
    replacement : bool, default=False
        Whether to sample with replacement or not.
    shuffle : bool, default=True
        Whether to shuffle the sampled indices or not. If `False`, the indices of
        each dataset will appear in the order they are stored in the combined dataset.
        This is similar to sequential sampling from each dataset. The datasets
        that make up the combined dataset are still sampled randomly.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    drop_last : bool, default=False
        Whether to drop the last incomplete batch or not. If `True`, the sampler will
        drop samples to make the number of samples evenly divisible by the number of
        replicas in distributed mode.
    seed : int, default=0
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.

    Attributes
    ----------
    dataset : CombinedDataset
        The dataset to sample from.
    num_samples : int
        The number of samples to draw from the combined dataset.
    probs : torch.Tensor
        The probabilities for sampling from each dataset in the combined dataset.
        This is computed from the `ratios` argument and is normalized to sum to 1.
    replacement : bool
        Whether to sample with replacement or not.
    shuffle : bool
        Whether to shuffle the sampled indices or not.
    rank : int
        Rank of the current process within :attr:`num_replicas`.
    num_replicas : int
        Number of processes participating in distributed training.
    drop_last : bool
        Whether to drop samples to make the number of samples evenly divisible by the
        number of replicas in distributed mode.
    seed : int
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.
    epoch : int
        Current epoch number. This is used to set the random seed. This is useful
        in distributed mode to ensure that each process receives a different random
        ordering of the samples.
    total_size : int
        The total number of samples across all processes.
    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        dataset: CombinedDataset,
        ratios: Optional[Sequence[float]] = None,
        num_samples: Optional[int] = None,
        replacement: bool = False,
        shuffle: bool = True,
        rank: Optional[int] = None,
        num_replicas: Optional[int] = None,
        drop_last: bool = False,
        seed: int = 0,
    ):
        if not isinstance(dataset, CombinedDataset):
            raise TypeError(
                "Expected argument `dataset` to be of type `CombinedDataset`, "
                f"but got {type(dataset)}.",
            )
        if not isinstance(seed, int):
            raise TypeError(
                f"Expected argument `seed` to be an integer, but got {type(seed)}.",
            )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.drop_last = drop_last
        self.replacement = replacement
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        self._num_samples = num_samples
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "Expected argument `num_samples` to be a positive integer, but got "
                f"{self.num_samples}.",
            )

        if ratios is None:
            ratios = [len(subset) for subset in self.dataset.datasets]

        num_datasets = len(self.dataset.datasets)
        if len(ratios) != num_datasets:
            raise ValueError(
                f"Expected argument `ratios` to be of length {num_datasets}, "
                f"but got length {len(ratios)}.",
            )
        prob_sum = sum(ratios)
        if not all(ratio >= 0 for ratio in ratios) and prob_sum > 0:
            raise ValueError(
                "Expected argument `ratios` to be a sequence of non-negative numbers. "
                f"Got {ratios}.",
            )
        self.probs = torch.tensor(
            [ratio / prob_sum for ratio in ratios],
            dtype=torch.double,
        )
        if any((prob * self.num_samples) <= 0 for prob in self.probs):
            raise ValueError(
                "Expected dataset ratio to result in at least one sample per dataset. "
                f"Got dataset sizes {self.probs * self.num_samples}.",
            )

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        # dataset size might change at runtime
        if self._num_samples is None:
            num_samples = len(self.dataset)
        else:
            num_samples = self._num_samples

        if self.drop_last and num_samples % self.num_replicas != 0:
            # split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            num_samples = math.ceil(
                (num_samples - self.num_replicas) / self.num_replicas,
            )
        else:
            num_samples = math.ceil(num_samples / self.num_replicas)
        return num_samples

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return self.num_samples * self.num_replicas

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that yields sample indices for the combined dataset."""
        generator = torch.Generator()
        seed = self.seed + self.epoch
        generator.manual_seed(seed)

        cumulative_sizes = [0] + self.dataset._cumulative_sizes
        num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
        indices = []
        for i in range(len(self.dataset.datasets)):
            per_dataset_indices: torch.Tensor = torch.multinomial(
                torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
                num_samples_per_dataset[i],
                replacement=self.replacement,
                generator=generator,
            )
            # adjust indices to reflect position in cumulative dataset
            per_dataset_indices += cumulative_sizes[i]
            assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
                f"Indices from dataset {i} exceed dataset size. "
                f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
            )
            indices.append(per_dataset_indices)

        indices = torch.cat(indices)
        if self.shuffle:
            rand_indices = torch.randperm(len(indices), generator=generator)
            indices = indices[rand_indices]

        indices = indices.tolist()  # type: ignore[attr-defined]
        num_indices = len(indices)

        if num_indices < self.total_size:
            padding_size = self.total_size - num_indices
            if padding_size <= num_indices:
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / num_indices))[
                    :padding_size
                ]
        elif num_indices > self.total_size:
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples, (
            f"Expected {self.num_samples} samples, but got {len(indices)}.",
        )

        yield from iter(indices)

    def __len__(self) -> int:
        """Return the total number of samples in the sampler."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch

        # some iterable datasets (especially huggingface iterable datasets) might
        # require setting the epoch to ensure shuffling works properly
        for dataset in self.dataset.datasets:
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
num_samples property
num_samples

Return the number of samples managed by the sampler.

total_size property
total_size

Return the total size of the dataset.

__iter__
__iter__()

Return an iterator that yields sample indices for the combined dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that yields sample indices for the combined dataset."""
    generator = torch.Generator()
    seed = self.seed + self.epoch
    generator.manual_seed(seed)

    cumulative_sizes = [0] + self.dataset._cumulative_sizes
    num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
    indices = []
    for i in range(len(self.dataset.datasets)):
        per_dataset_indices: torch.Tensor = torch.multinomial(
            torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
            num_samples_per_dataset[i],
            replacement=self.replacement,
            generator=generator,
        )
        # adjust indices to reflect position in cumulative dataset
        per_dataset_indices += cumulative_sizes[i]
        assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
            f"Indices from dataset {i} exceed dataset size. "
            f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
        )
        indices.append(per_dataset_indices)

    indices = torch.cat(indices)
    if self.shuffle:
        rand_indices = torch.randperm(len(indices), generator=generator)
        indices = indices[rand_indices]

    indices = indices.tolist()  # type: ignore[attr-defined]
    num_indices = len(indices)

    if num_indices < self.total_size:
        padding_size = self.total_size - num_indices
        if padding_size <= num_indices:
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / num_indices))[
                :padding_size
            ]
    elif num_indices > self.total_size:
        indices = indices[: self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples, (
        f"Expected {self.num_samples} samples, but got {len(indices)}.",
    )

    yield from iter(indices)
__len__
__len__()

Return the total number of samples in the sampler.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the total number of samples in the sampler."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

  • epoch (int), required: Epoch number.
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

    # some iterable datasets (especially huggingface iterable datasets) might
    # require setting the epoch to ensure shuffling works properly
    for dataset in self.dataset.datasets:
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)
DistributedEvalSampler

Bases: Sampler[int]

Sampler for distributed evaluation.

The main differences between this and torch.utils.data.DistributedSampler are that this sampler does not add extra samples to make the dataset evenly divisible across replicas, and that shuffling is disabled by default.

Parameters:

  • dataset (torch.utils.data.Dataset), required: Dataset used for sampling.
  • num_replicas (Optional[int]), default=None: Number of processes participating in distributed training. By default, num_replicas is retrieved from the current distributed group.
  • rank (Optional[int]), default=None: Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.
  • shuffle (bool), default=False: If True, the sampler will shuffle the indices.
  • seed (int), default=0: Random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group.
Warnings

DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever. See [1] for details.

Notes
  • This sampler is for evaluation purposes, where synchronization does not happen every epoch. Synchronization should be done outside the dataloader loop. It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel [2].
  • The input Dataset is assumed to be of constant size.
  • This implementation is adapted from [3].

References

  [1] https://github.com/pytorch/pytorch/issues/22584
  [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
  [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py

Examples:

>>> def example():
...     start_epoch, n_epochs = 0, 2
...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
...     for epoch in range(start_epoch, n_epochs):
...         if is_distributed:
...             sampler.set_epoch(epoch)
...         evaluate(loader)
Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class DistributedEvalSampler(Sampler[int]):
    """Sampler for distributed evaluation.

    The main differences between this and :py:class:`torch.utils.data.DistributedSampler`
    are that this sampler does not add extra samples to make it evenly divisible and
    shuffling is disabled by default.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset used for sampling.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    shuffle : bool, optional, default=False
        If `True`, sampler will shuffle the indices.
    seed : int, optional, default=0
        Random seed used to shuffle the sampler if :attr:`shuffle=True`.
        This number should be identical across all processes in the
        distributed group.

    Warnings
    --------
    DistributedEvalSampler should NOT be used for training. The distributed processes
    could hang forever. See [1]_ for details.

    Notes
    -----
    - This sampler is for evaluation purpose where synchronization does not happen
      every epoch. Synchronization should be done outside the dataloader loop.
      It is especially useful in conjunction with
      :py:class:`torch.nn.parallel.DistributedDataParallel` [2]_.
    - The input Dataset is assumed to be of constant size.
    - This implementation is adapted from [3]_.

    References
    ----------
    .. [1] https://github.com/pytorch/pytorch/issues/22584
    .. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
    .. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py


    Examples
    --------
    >>> def example():
    ...     start_epoch, n_epochs = 0, 2
    ...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
    ...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
    ...     for epoch in range(start_epoch, n_epochs):
    ...         if is_distributed:
    ...             sampler.set_epoch(epoch)
    ...         evaluate(loader)

    """  # noqa: W505

    def __init__(
        self,
        dataset: Dataset[Sized],
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return len(self.dataset)

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        indices = list(range(self.total_size))[
            self.rank : self.total_size : self.num_replicas
        ]
        return len(indices)  # true value without extra samples

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that iterates over the indices of the dataset."""
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(self.total_size, generator=g).tolist()
        else:
            indices = list(range(self.total_size))

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch
total_size property
total_size

Return the total size of the dataset.

num_samples property
num_samples

Return the number of samples managed by the sampler.

__iter__
__iter__()

Return an iterator that iterates over the indices of the dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that iterates over the indices of the dataset."""
    if self.shuffle:
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()
    else:
        indices = list(range(self.total_size))

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples

    return iter(indices)
__len__
__len__()

Return the number of samples.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the number of samples."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

  • epoch (int), required: Epoch number.
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch
BlockwiseImagePatchMaskGenerator

Blockwise image patch mask generator.

This is primarily intended for the data2vec method.

Parameters:

  • input_size (Union[int, tuple[int, int]]), required: The size of the input image. If an integer is provided, the image is assumed to be square.
  • num_masking_patches (int), required: The number of patches to mask.
  • min_num_patches (int), default=4: The minimum number of patches to mask.
  • max_num_patches (int), default=None: The maximum number of patches to mask.
  • min_aspect_ratio (float), default=0.3: The minimum aspect ratio of the patch.
  • max_aspect_ratio (float), default=None: The maximum aspect ratio of the patch.
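
A brief sketch for a ViT-style 14x14 patch grid (for example, a 224-pixel image with 16-pixel patches); the numbers and the import path are illustrative.

from mmlearn.datasets.processors.masking import BlockwiseImagePatchMaskGenerator  # assumed import path

# 14 x 14 patch grid; mask roughly 60 patches in contiguous blocks.
mask_generator = BlockwiseImagePatchMaskGenerator(
    input_size=14,
    num_masking_patches=60,
    min_num_patches=4,
    min_aspect_ratio=0.3,
)

mask = mask_generator()   # torch.int tensor of shape (14, 14)
print(mask.sum().item())  # number of masked patches, at most num_masking_patches
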
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn")
class BlockwiseImagePatchMaskGenerator:
    """Blockwise image patch mask generator.

    This is primarily intended for the data2vec method.

    Parameters
    ----------
    input_size : Union[int, tuple[int, int]]
        The size of the input image. If an integer is provided, the image is assumed
        to be square.
    num_masking_patches : int
        The number of patches to mask.
    min_num_patches : int, default=4
        The minimum number of patches to mask.
    max_num_patches : int, default=None
        The maximum number of patches to mask.
    min_aspect_ratio : float, default=0.3
        The minimum aspect ratio of the patch.
    max_aspect_ratio : float, default=None
        The maximum aspect ratio of the patch.
    """

    def __init__(
        self,
        input_size: Union[int, tuple[int, int]],
        num_masking_patches: int,
        min_num_patches: int = 4,
        max_num_patches: Any = None,
        min_aspect_ratio: float = 0.3,
        max_aspect_ratio: Any = None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = (
            num_masking_patches if max_num_patches is None else max_num_patches
        )

        max_aspect_ratio = max_aspect_ratio or 1 / min_aspect_ratio
        self.log_aspect_ratio = (math.log(min_aspect_ratio), math.log(max_aspect_ratio))

    def __repr__(self) -> str:
        """Generate a printable representation.

        Returns
        -------
        str
            A printable representation of the object.

        """
        return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self.max_num_patches,
            self.num_masking_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )

    def get_shape(self) -> tuple[int, int]:
        """Get the shape of the input.

        Returns
        -------
        tuple[int, int]
            The shape of the input as a tuple `(height, width)`.
        """
        return self.height, self.width

    def _mask(self, mask: torch.Tensor, max_mask_patches: int) -> int:
        """Masking function.

        This function masks adjacent patches by first selecting a target area and
        aspect ratio. Since there might be overlap between selected areas, or the
        selected area might already be masked, it runs for a maximum of 10 attempts
        or until the specified number of patches (max_mask_patches) is achieved.


        Parameters
        ----------
        mask: torch.Tensor
            Current mask. The mask to be updated.
        max_mask_patches: int
            The maximum number of patches to be masked.

        Returns
        -------
        delta: int
            The number of patches that were successfully masked.

        Notes
        -----
        - `target_area`: Randomly chosen target area for the patch.
        - `aspect_ratio`: Randomly chosen aspect ratio for the patch.
        - `h`: Height of the patch based on the target area and aspect ratio.
        - `w`: Width of the patch based on the target area and aspect ratio.
        - `top`: Randomly chosen top position for the patch.
        - `left`: Randomly chosen left position for the patch.
        - `num_masked`: Number of masked pixels within the proposed patch area.
        - `delta`: Accumulated count of modified pixels.
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top : top + h, left : left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self) -> torch.Tensor:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        mask: torch.Tensor
            A mask of shape (nb_patches, nb_patches)

        """
        mask = torch.zeros(self.get_shape(), dtype=torch.int)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask
__repr__
__repr__()

Generate a printable representation.

Returns:

  • str: A printable representation of the object.

Source code in mmlearn/datasets/processors/masking.py
def __repr__(self) -> str:
    """Generate a printable representation.

    Returns
    -------
    str
        A printable representation of the object.

    """
    return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
        self.height,
        self.width,
        self.min_num_patches,
        self.max_num_patches,
        self.num_masking_patches,
        self.log_aspect_ratio[0],
        self.log_aspect_ratio[1],
    )
get_shape
get_shape()

Get the shape of the input.

Returns:

  • tuple[int, int]: The shape of the input as a tuple (height, width).

Source code in mmlearn/datasets/processors/masking.py
def get_shape(self) -> tuple[int, int]:
    """Get the shape of the input.

    Returns
    -------
    tuple[int, int]
        The shape of the input as a tuple `(height, width)`.
    """
    return self.height, self.width
__call__
__call__()

Generate a random mask.

Returns a random mask of shape (nb_patches, nb_patches) based on the configuration where the number of patches to be masked is num_masking_patches.

Returns:

  • mask (torch.Tensor): A mask of shape (nb_patches, nb_patches).

Source code in mmlearn/datasets/processors/masking.py
def __call__(self) -> torch.Tensor:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    mask: torch.Tensor
        A mask of shape (nb_patches, nb_patches)

    """
    mask = torch.zeros(self.get_shape(), dtype=torch.int)
    mask_count = 0
    while mask_count < self.num_masking_patches:
        max_mask_patches = self.num_masking_patches - mask_count
        max_mask_patches = min(max_mask_patches, self.max_num_patches)

        delta = self._mask(mask, max_mask_patches)
        if delta == 0:
            break
        mask_count += delta

    return mask
RandomMaskGenerator

Random mask generator.

Generates a random token mask in which each token is masked with the configured probability. This is intended to be used for tasks like masked language modeling.

Parameters:

  • probability (float), required: Probability of masking a token.
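
A sketch of applying the generator to a single tokenized sentence; the transformers library, the bert-base-uncased checkpoint and the import path are assumptions for illustration.

import torch
from transformers import AutoTokenizer

from mmlearn.datasets.processors.masking import RandomMaskGenerator  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
mask_generator = RandomMaskGenerator(probability=0.15)

encoding = tokenizer("A short example sentence.")
special_tokens_mask = torch.tensor(
    tokenizer.get_special_tokens_mask(
        encoding["input_ids"], already_has_special_tokens=True
    )
)

inputs, labels, masked_indices = mask_generator(
    encoding, tokenizer, special_tokens_mask=special_tokens_mask
)
# `inputs` holds the corrupted token ids, `labels` the original ids at masked
# positions (pad token id elsewhere), and `masked_indices` the boolean mask.
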
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn", probability=0.15)
class RandomMaskGenerator:
    """Random mask generator.

    Generates a random token mask in which each token is masked with the given
    probability.
    **This is intended to be used for tasks like masked language modeling.**

    Parameters
    ----------
    probability : float
        Probability of masking a token.
    """

    def __init__(self, probability: float):
        self.probability = probability

    def __call__(
        self,
        inputs: torch.Tensor,
        tokenizer: PreTrainedTokenizerBase,
        special_tokens_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        inputs : torch.Tensor
            The encoded inputs.
        tokenizer : PreTrainedTokenizer
            The tokenizer.
        special_tokens_mask : Optional[torch.Tensor], default=None
            Mask for special tokens.
        """
        inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training
        # (with probability `self.probability`)
        probability_matrix = torch.full(labels.shape, self.probability)
        if special_tokens_mask is None:
            special_tokens_mask = tokenizer.get_special_tokens_mask(
                labels, already_has_special_tokens=True
            )
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = tokenizer.pad_token_id
        # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # Rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, masked_indices
__call__
__call__(inputs, tokenizer, special_tokens_mask=None)

Generate a random token mask for masked language modeling.

Each non-special token is selected for masking with the configured probability.

Parameters:

Name Type Description Default
inputs Tensor

The encoded inputs.

required
tokenizer PreTrainedTokenizerBase

The tokenizer.

required
special_tokens_mask Optional[Tensor]

Mask for special tokens.

None

Returns:

Name Type Description
inputs Tensor

The (possibly masked) input token ids.

labels Tensor

The labels for masked language modeling.

masked_indices Tensor

A boolean mask of the token positions selected for masking.

Source code in mmlearn/datasets/processors/masking.py
def __call__(
    self,
    inputs: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
    special_tokens_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    inputs : torch.Tensor
        The encoded inputs.
    tokenizer : PreTrainedTokenizer
        The tokenizer.
    special_tokens_mask : Optional[torch.Tensor], default=None
        Mask for special tokens.
    """
    inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
    labels = inputs.clone()
    # We sample a few tokens in each sequence for MLM training
    # (with probability `self.probability`)
    probability_matrix = torch.full(labels.shape, self.probability)
    if special_tokens_mask is None:
        special_tokens_mask = tokenizer.get_special_tokens_mask(
            labels, already_has_special_tokens=True
        )
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
    else:
        special_tokens_mask = special_tokens_mask.bool()

    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = tokenizer.pad_token_id
    # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
    indices_replaced = (
        torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    )
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # Rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels, masked_indices
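
A hypothetical usage sketch pairing RandomMaskGenerator with a HuggingFace tokenizer; the checkpoint name and sentences are placeholders, and the special-tokens mask is passed explicitly so its shape matches the padded batch.

from transformers import AutoTokenizer

from mmlearn.datasets.processors.masking import RandomMaskGenerator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
mask_generator = RandomMaskGenerator(probability=0.15)

encoding = tokenizer(
    ["a short example", "another slightly longer example"],
    padding=True,
    return_special_tokens_mask=True,
    return_tensors="pt",
)

inputs, labels, masked_indices = mask_generator(
    encoding, tokenizer, special_tokens_mask=encoding["special_tokens_mask"]
)
# all three tensors share the padded shape (batch_size, seq_len)
print(inputs.shape, labels.shape, masked_indices.shape)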
HFTokenizer

A wrapper for loading HuggingFace tokenizers.

This class wraps any huggingface tokenizer that can be initialized with transformers.AutoTokenizer.from_pretrained. It preprocesses the input text and returns a dictionary with the tokenized text and other relevant information like attention masks.

Parameters:

Name Type Description Default
model_name_or_path str

Pretrained model name or path - same as in transformers.AutoTokenizer.from_pretrained.

required
max_length Optional[int]

Maximum length of the tokenized sequence. This is passed to the tokenizer __call__ method.

None
padding bool or str

Padding strategy. Same as in transformers.AutoTokenizer.from_pretrained; passed to the tokenizer __call__ method.

False
truncation Optional[Union[bool, str]]

Truncation strategy. Same as in transformers.AutoTokenizer.from_pretrained; passed to the tokenizer __call__ method.

None
**kwargs Any

Additional arguments passed to transformers.AutoTokenizer.from_pretrained.

{}
Source code in mmlearn/datasets/processors/tokenizers.py
@store(group="datasets/tokenizers", provider="mmlearn")
class HFTokenizer:
    """A wrapper for loading HuggingFace tokenizers.

    This class wraps any huggingface tokenizer that can be initialized with
    :py:meth:`transformers.AutoTokenizer.from_pretrained`. It preprocesses the
    input text and returns a dictionary with the tokenized text and other
    relevant information like attention masks.

    Parameters
    ----------
    model_name_or_path : str
        Pretrained model name or path - same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    max_length : Optional[int], optional, default=None
        Maximum length of the tokenized sequence. This is passed to the tokenizer
        :meth:`__call__` method.
    padding : bool or str, default=False
        Padding strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    truncation : Optional[Union[bool, str]], optional, default=None
        Truncation strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    **kwargs : Any
        Additional arguments passed to :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    """  # noqa: W505

    def __init__(
        self,
        model_name_or_path: str,
        max_length: Optional[int] = None,
        padding: Union[bool, str] = False,
        truncation: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def __call__(
        self, sentence: Union[str, list[str]], **kwargs: Any
    ) -> dict[str, torch.Tensor]:
        """Tokenize a text or a list of texts using the HuggingFace tokenizer.

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be tokenized.
        **kwargs : Any
            Additional arguments passed to the tokenizer :meth:`__call__` method.

        Returns
        -------
        dict[str, torch.Tensor]
            Tokenized sentence(s).

        Notes
        -----
        The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
        """
        batch_encoding = self.tokenizer(
            sentence,
            max_length=self.max_length,
            padding=self.padding,
            truncation=self.truncation,
            return_tensors="pt",
            **kwargs,
        )

        if isinstance(
            sentence, str
        ):  # remove batch dimension if input is a single sentence
            for key, value in batch_encoding.items():
                if isinstance(value, torch.Tensor):
                    batch_encoding[key] = torch.squeeze(value, 0)

        # use 'Modalities.TEXT' key for input_ids for consistency
        batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
        return dict(batch_encoding)
__call__
__call__(sentence, **kwargs)

Tokenize a text or a list of texts using the HuggingFace tokenizer.

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be tokenized.

required
**kwargs Any

Additional arguments passed to the tokenizer __call__ method.

{}

Returns:

Type Description
dict[str, Tensor]

Tokenized sentence(s).

Notes

The input_ids key is replaced with Modalities.TEXT for consistency.

Source code in mmlearn/datasets/processors/tokenizers.py
def __call__(
    self, sentence: Union[str, list[str]], **kwargs: Any
) -> dict[str, torch.Tensor]:
    """Tokenize a text or a list of texts using the HuggingFace tokenizer.

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be tokenized.
    **kwargs : Any
        Additional arguments passed to the tokenizer :meth:`__call__` method.

    Returns
    -------
    dict[str, torch.Tensor]
        Tokenized sentence(s).

    Notes
    -----
    The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
    """
    batch_encoding = self.tokenizer(
        sentence,
        max_length=self.max_length,
        padding=self.padding,
        truncation=self.truncation,
        return_tensors="pt",
        **kwargs,
    )

    if isinstance(
        sentence, str
    ):  # remove batch dimension if input is a single sentence
        for key, value in batch_encoding.items():
            if isinstance(value, torch.Tensor):
                batch_encoding[key] = torch.squeeze(value, 0)

    # use 'Modalities.TEXT' key for input_ids for consistency
    batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
    return dict(batch_encoding)
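
A short usage sketch; the checkpoint name is a placeholder and any AutoTokenizer-compatible model works.

from mmlearn.datasets.processors.tokenizers import HFTokenizer

tokenizer = HFTokenizer(
    "bert-base-uncased",      # placeholder checkpoint
    max_length=32,
    padding="max_length",
    truncation=True,
)

single = tokenizer("a single sentence")                    # batch dim removed
batch = tokenizer(["first sentence", "second sentence"])

print(single["input_ids"].shape)  # torch.Size([32])
print(batch["input_ids"].shape)   # torch.Size([2, 32])
# the same ids are also stored under the Modalities.TEXT key, as noted above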
TrimText

Trim text strings as a preprocessing step before tokenization.

Parameters:

Name Type Description Default
trim_size int

The maximum length of the trimmed text.

required
Source code in mmlearn/datasets/processors/transforms.py
@store(group="datasets/transforms", provider="mmlearn")
class TrimText:
    """Trim text strings as a preprocessing step before tokenization.

    Parameters
    ----------
    trim_size : int
        The maximum length of the trimmed text.
    """

    def __init__(self, trim_size: int) -> None:
        self.trim_size = trim_size

    def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
        """Trim the given sentence(s).

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be trimmed.

        Returns
        -------
        Union[str, list[str]]
            Trimmed sentence(s).

        Raises
        ------
        TypeError
            If the input sentence is not a string or list of strings.
        """
        if not isinstance(sentence, (list, str)):
            raise TypeError(
                "Expected argument `sentence` to be a string or list of strings, "
                f"but got {type(sentence)}"
            )

        if isinstance(sentence, str):
            return sentence[: self.trim_size]

        for i, s in enumerate(sentence):
            sentence[i] = s[: self.trim_size]

        return sentence
__call__
__call__(sentence)

Trim the given sentence(s).

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be trimmed.

required

Returns:

Type Description
Union[str, list[str]]

Trimmed sentence(s).

Raises:

Type Description
TypeError

If the input sentence is not a string or list of strings.

Source code in mmlearn/datasets/processors/transforms.py
def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
    """Trim the given sentence(s).

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be trimmed.

    Returns
    -------
    Union[str, list[str]]
        Trimmed sentence(s).

    Raises
    ------
    TypeError
        If the input sentence is not a string or list of strings.
    """
    if not isinstance(sentence, (list, str)):
        raise TypeError(
            "Expected argument `sentence` to be a string or list of strings, "
            f"but got {type(sentence)}"
        )

    if isinstance(sentence, str):
        return sentence[: self.trim_size]

    for i, s in enumerate(sentence):
        sentence[i] = s[: self.trim_size]

    return sentence
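
A quick illustration of the transform:

from mmlearn.datasets.processors.transforms import TrimText

trim = TrimText(trim_size=16)

print(trim("a sentence that is definitely longer than sixteen characters"))
# 'a sentence that '
print(trim(["short", "another fairly long sentence"]))
# ['short', 'another fairly l']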
HFCLIPTextEncoder

Bases: Module

Wrapper around the CLIPTextModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",  # required for `peft_config` to be converted to a `PeftConfig` object
)
class HFCLIPTextEncoder(nn.Module):
    """Wrapper around the ``CLIPTextModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden
            states, and the attention weights, if ``output_attentions`` is set
            to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get("attention_mask")
            or inputs.get(Modalities.TEXT.attention_mask),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden
        states, and the attention weights, if ``output_attentions`` is set
        to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get("attention_mask")
        or inputs.get(Modalities.TEXT.attention_mask),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
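
A hypothetical end-to-end sketch: tokenize with HFTokenizer (documented above) so the Modalities.TEXT key is set, then encode. Requesting output_hidden_states through model_config_kwargs is an assumption made here so that forward, which reads hidden_states[-1], has hidden states available; only the token-id keys are passed, since CLIP's text model applies its own causal mask.

import torch

from mmlearn.datasets.processors.tokenizers import HFTokenizer
from mmlearn.modules.encoders.clip import HFCLIPTextEncoder

tokenizer = HFTokenizer(
    "openai/clip-vit-base-patch16", padding="max_length", truncation=True, max_length=77
)
encoder = HFCLIPTextEncoder(
    "openai/clip-vit-base-patch16",
    model_config_kwargs={"output_hidden_states": True},  # forward reads hidden_states[-1]
)

encoding = tokenizer(["a photo of a cat", "a photo of a dog"])
# keep only the token-id keys for this minimal sketch
inputs = {k: v for k, v in encoding.items() if k != "attention_mask"}

with torch.no_grad():
    output = encoder(inputs)
print(output.last_hidden_state.shape)  # (2, 77, hidden_size)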
HFCLIPTextEncoderWithProjection

Bases: Module

Wrapper around the CLIPTextModelWithProjection from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the text. If False the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPTextEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPTextModelWithProjection`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the text. If ``False`` the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPTextConfig,
        )

        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The text embeddings. Will be a tuple with a single element.
        """
        input_ids = inputs[Modalities.TEXT.name]
        attention_mask: Optional[torch.Tensor] = inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        )
        position_ids = inputs.get("position_ids")

        if self.use_all_token_embeddings:
            text_outputs = self.model.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            )
            # TODO: add more options for pooling before projection
            text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
        else:
            text_embeds = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            ).text_embeds

        return (text_embeds,)
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
tuple[Tensor]

The text embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The text embeddings. Will be a tuple with a single element.
    """
    input_ids = inputs[Modalities.TEXT.name]
    attention_mask: Optional[torch.Tensor] = inputs.get(
        "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
    )
    position_ids = inputs.get("position_ids")

    if self.use_all_token_embeddings:
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        # TODO: add more options for pooling before projection
        text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
    else:
        text_embeds = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        ).text_embeds

    return (text_embeds,)
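
A similar sketch for the projected variant, which returns a single-element tuple of text embeddings; the checkpoint name matches the registered default and the prompts are placeholders.

import torch

from mmlearn.datasets.processors.tokenizers import HFTokenizer
from mmlearn.modules.encoders.clip import HFCLIPTextEncoderWithProjection

tokenizer = HFTokenizer(
    "openai/clip-vit-base-patch16", padding="max_length", truncation=True, max_length=77
)
encoder = HFCLIPTextEncoderWithProjection("openai/clip-vit-base-patch16")

inputs = tokenizer(["a photo of a cat", "a photo of a dog"])
with torch.no_grad():
    (text_embeds,) = encoder(inputs)
print(text_embeds.shape)  # (2, projection_dim), e.g. (2, 512) for the base model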
HFCLIPVisionEncoder

Bases: Module

Wrapper around the CLIPVisionModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias Optional[float]

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoder(nn.Module):
    """Wrapper around the ``CLIPVisionModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : Optional[float], optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model.vision_model
        self.pooling_layer = pooling_layer
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.

        """
        # FIXME: handle other vision modalities
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.encoder(
            inputs_embeds=hidden_states,
            output_attentions=inputs.get(
                "output_attentions", self.model.config.output_attentions
            ),
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = encoder_outputs[0]
        if self.pooling_layer is not None:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.

    """
    # FIXME: handle other vision modalities
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.encoder(
        inputs_embeds=hidden_states,
        output_attentions=inputs.get(
            "output_attentions", self.model.config.output_attentions
        ),
        output_hidden_states=True,
        return_dict=True,
    )

    last_hidden_state = encoder_outputs[0]
    if self.pooling_layer is not None:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
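
A hypothetical sketch for the vision encoder; the random tensor stands in for a batch of preprocessed 224x224 images, and the import path for Modalities is an assumption.

import torch

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoder

encoder = HFCLIPVisionEncoder("openai/clip-vit-base-patch16")

inputs = {Modalities.RGB.name: torch.randn(2, 3, 224, 224)}  # stand-in for preprocessed images
with torch.no_grad():
    output = encoder(inputs)
print(output.last_hidden_state.shape)  # (2, 197, hidden_size): 14x14 patches + [CLS]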
HFCLIPVisionEncoderWithProjection

Bases: Module

Wrapper around the CLIPVisionModelWithProjection class from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the image. If False the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias float

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs dict[str, Any]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPVisionModelWithProjection`` class from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the image. If ``False`` the first
        token embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : float, optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : dict[str, Any], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPVisionConfig,
        )

        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The image embeddings. Will be a tuple with a single element.
        """
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.vision_model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.vision_model.encoder(
            inputs_embeds=hidden_states, return_dict=True
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        if self.use_all_token_embeddings:
            pooled_output = last_hidden_state
        else:
            pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.model.vision_model.post_layernorm(pooled_output)

        return (self.model.visual_projection(pooled_output),)
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
tuple[Tensor]

The image embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The image embeddings. Will be a tuple with a single element.
    """
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.vision_model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.vision_model.encoder(
        inputs_embeds=hidden_states, return_dict=True
    )

    last_hidden_state = encoder_outputs.last_hidden_state
    if self.use_all_token_embeddings:
        pooled_output = last_hidden_state
    else:
        pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.model.vision_model.post_layernorm(pooled_output)

    return (self.model.visual_projection(pooled_output),)
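
And for the projected vision variant, which returns a single-element tuple of image embeddings (same assumptions as in the previous sketch).

import torch

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoderWithProjection

encoder = HFCLIPVisionEncoderWithProjection("openai/clip-vit-base-patch16")

inputs = {Modalities.RGB.name: torch.randn(2, 3, 224, 224)}
with torch.no_grad():
    (image_embeds,) = encoder(inputs)
print(image_embeds.shape)  # (2, projection_dim), e.g. (2, 512) for the base model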
HFTextEncoder

Bases: Module

Wrapper around huggingface models in the AutoModelForTextEncoding class.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Raises:

Type Description
ValueError

If the model is a decoder model or if freezing individual layers is not supported for the model type.

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/text.py
@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
class HFTextEncoder(nn.Module):
    """Wrapper around huggingface models in the ``AutoModelForTextEncoding`` class.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Raises
    ------
    ValueError
        If the model is a decoder model or if freezing individual layers is not
        supported for the model type.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.


    """

    def __init__(  # noqa: PLR0912
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ):
        super().__init__()
        if model_config_kwargs is None:
            model_config_kwargs = {}
        model_config_kwargs["output_hidden_states"] = True
        model_config_kwargs["add_pooling_layer"] = False
        model = hf_utils.load_huggingface_model(
            AutoModelForTextEncoding,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        if hasattr(model.config, "is_decoder") and model.config.is_decoder:
            raise ValueError("Model is a decoder. Only encoder models are supported.")

        if not pretrained and freeze_layers:
            rank_zero_warn(
                "Freezing layers when loading a model with random weights may lead to "
                "unexpected behavior. Consider setting `freeze_layers=False` if "
                "`pretrained=False`.",
            )

        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "LayerNorm" in name else False
                )

        if isinstance(
            freeze_layers, (float, int, list)
        ) and model.config.model_type in ["flaubert", "xlm"]:
            # flaubert and xlm models have a different architecture that does not
            # support freezing individual layers in the same way as other models
            raise ValueError(
                f"Freezing individual layers is not supported for {model.config.model_type} "
                "models. Please use `freeze_layers=False` or `freeze_layers=True`."
            )

        # get list of layers
        embeddings = model.embeddings
        encoder = getattr(model, "encoder", None) or getattr(
            model, "transformer", model
        )
        encoder_layers = (
            getattr(encoder, "layer", None)
            or getattr(encoder, "layers", None)
            or getattr(encoder, "block", None)
        )
        if encoder_layers is None and hasattr(encoder, "albert_layer_groups"):
            encoder_layers = [
                layer
                for group in encoder.albert_layer_groups
                for layer in group.albert_layers
            ]
        modules = [embeddings]
        if encoder_layers is not None and isinstance(encoder_layers, list):
            modules.extend(encoder_layers)

        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "LayerNorm" in name else False
                        )

        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get(
                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
            ),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/text.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        ),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
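
A hypothetical sketch with a BERT-style checkpoint: HFTokenizer (documented above) produces the Modalities.TEXT key the forward pass expects, and freeze_layers=0.5 freezes the embeddings plus roughly the first half of the encoder layers. The checkpoint name and sentences are placeholders.

import torch

from mmlearn.datasets.processors.tokenizers import HFTokenizer
from mmlearn.modules.encoders.text import HFTextEncoder

tokenizer = HFTokenizer("bert-base-uncased", padding=True, truncation=True)
encoder = HFTextEncoder(
    "bert-base-uncased",
    freeze_layers=0.5,       # freeze the first ~half of [embeddings, *encoder layers]
    freeze_layer_norm=True,  # keep the LayerNorm parameters of frozen modules frozen too
)

inputs = tokenizer(["an example sentence", "another example"])
with torch.no_grad():
    output = encoder(inputs)
print(output.last_hidden_state.shape)  # (2, seq_len, hidden_size)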
TimmViT

Bases: Module

Vision Transformer model from timm.

Parameters:

Name Type Description Default
model_name str

The name of the model to use.

required
modality str

The modality of the input data. This allows the model to be used with different image modalities, e.g. RGB, Depth, etc.

"RGB"
projection_dim int

The dimension of the projection head.

768
pretrained bool

Whether to use the pretrained weights.

True
freeze_layers Union[int, float, list[int], bool]

Whether to freeze the layers.

False
freeze_layer_norm bool

Whether to freeze the layer norm.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_kwargs Optional[dict[str, Any]]

Additional keyword arguments for the model.

None
Source code in mmlearn/modules/encoders/vision.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name="vit_base_patch16_224",
    hydra_convert="object",
)
class TimmViT(nn.Module):
    """Vision Transformer model from timm.

    Parameters
    ----------
    model_name : str
        The name of the model to use.
    modality : str, default="RGB"
        The modality of the input data. This allows the model to be used with
        different image modalities, e.g. RGB, Depth, etc.
    projection_dim : int, default=768
        The dimension of the projection head.
    pretrained : bool, default=True
        Whether to use the pretrained weights.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze the layers.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer norm.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_kwargs : Optional[dict[str, Any]], default=None
        Additional keyword arguments for the model.
    """

    def __init__(
        self,
        model_name: str,
        modality: str = "RGB",
        projection_dim: int = 768,
        pretrained: bool = True,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.modality = Modalities.get_modality(modality)
        if model_kwargs is None:
            model_kwargs = {}

        self.model: TimmVisionTransformer = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=projection_dim,
            **model_kwargs,
        )
        assert isinstance(self.model, TimmVisionTransformer), (
            f"Model {model_name} is not a Vision Transformer. "
            "Please provide a model name that corresponds to a Vision Transformer."
        )

        self._freeze_layers(freeze_layers, freeze_layer_norm)

        if peft_config is not None:
            self.model = hf_utils._wrap_peft_model(self.model, peft_config)

    def _freeze_layers(
        self, freeze_layers: Union[int, float, list[int], bool], freeze_layer_norm: bool
    ) -> None:
        """Freeze the layers of the model.

        Parameters
        ----------
        freeze_layers : Union[int, float, list[int], bool]
            Whether to freeze the layers.
        freeze_layer_norm : bool
            Whether to freeze the layer norm.
        """
        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in self.model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "norm" in name else False
                )

        modules = [self.model.patch_embed, *self.model.blocks, self.model.norm]
        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "norm" in name else False
                        )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model.
        """
        x = inputs[self.modality.name]
        last_hidden_state, hidden_states = self.model.forward_intermediates(
            x, output_fmt="NLC"
        )
        last_hidden_state = self.model.forward_head(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states
        )

    def get_intermediate_layers(
        self, inputs: dict[str, Any], n: int = 1
    ) -> list[torch.Tensor]:
        """Get the output of the intermediate layers.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the ``Modalities.RGB``
            key.
        n : int, default=1
            The number of intermediate layers to return.

        Returns
        -------
        list[torch.Tensor]
            The outputs of the last n intermediate layers.
        """
        return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore

    def get_patch_info(self) -> tuple[int, int]:
        """Get patch size and number of patches.

        Returns
        -------
        tuple[int, int]
            Patch size and number of patches.
        """
        patch_size = self.model.patch_embed.patch_size[0]
        num_patches = self.model.patch_embed.num_patches
        return patch_size, num_patches
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model.

Source code in mmlearn/modules/encoders/vision.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model.
    """
    x = inputs[self.modality.name]
    last_hidden_state, hidden_states = self.model.forward_intermediates(
        x, output_fmt="NLC"
    )
    last_hidden_state = self.model.forward_head(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state, hidden_states=hidden_states
    )
get_intermediate_layers
get_intermediate_layers(inputs, n=1)

Get the output of the intermediate layers.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required
n int

The number of intermediate layers to return.

1

Returns:

Type Description
list[Tensor]

The outputs of the last n intermediate layers.

Source code in mmlearn/modules/encoders/vision.py
def get_intermediate_layers(
    self, inputs: dict[str, Any], n: int = 1
) -> list[torch.Tensor]:
    """Get the output of the intermediate layers.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the ``Modalities.RGB``
        key.
    n : int, default=1
        The number of intermediate layers to return.

    Returns
    -------
    list[torch.Tensor]
        The outputs of the last n intermediate layers.
    """
    return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore
get_patch_info
get_patch_info()

Get patch size and number of patches.

Returns:

Type Description
tuple[int, int]

Patch size and number of patches.

Source code in mmlearn/modules/encoders/vision.py
def get_patch_info(self) -> tuple[int, int]:
    """Get patch size and number of patches.

    Returns
    -------
    tuple[int, int]
        Patch size and number of patches.
    """
    patch_size = self.model.patch_embed.patch_size[0]
    num_patches = self.model.patch_embed.num_patches
    return patch_size, num_patches
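Example usage (a minimal sketch): the encoder class name TimmViT and the import paths below are assumptions inferred from the source path mmlearn/modules/encoders/vision.py and may differ in your installed version; the input key is likewise assumed to be Modalities.RGB.name.

import torch

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.vision import TimmViT  # assumed class name

# Wrap a timm Vision Transformer and project its pooled output to 768 dimensions.
encoder = TimmViT("vit_base_patch16_224", pretrained=False, projection_dim=768)
encoder.eval()

images = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width)
with torch.no_grad():
    output = encoder({Modalities.RGB.name: images})

print(output.last_hidden_state.shape)  # projected features, e.g. (2, 768)
print(len(output.hidden_states))       # one entry per transformer block

patch_size, num_patches = encoder.get_patch_info()  # e.g. (16, 196) for 224x224 inputs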
LearnableLogitScaling

Bases: Module

Logit scaling layer.

Parameters:

Name Type Description Default
init_logit_scale float

Initial value of the logit scale.

1/0.07
learnable bool

If True, the logit scale is learnable. Otherwise, it is fixed.

True
max_logit_scale float

Maximum value of the logit scale.

100
Source code in mmlearn/modules/layers/logit_scaling.py
@store(group="modules/layers", provider="mmlearn")
class LearnableLogitScaling(torch.nn.Module):
    """Logit scaling layer.

    Parameters
    ----------
    init_logit_scale : float, optional, default=1/0.07
        Initial value of the logit scale.
    learnable : bool, optional, default=True
        If True, the logit scale is learnable. Otherwise, it is fixed.
    max_logit_scale : float, optional, default=100
        Maximum value of the logit scale.
    """

    def __init__(
        self,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.init_logit_scale = init_logit_scale
        self.learnable = learnable
        log_logit_scale = torch.ones([]) * np.log(self.init_logit_scale)
        if learnable:
            self.log_logit_scale = torch.nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the logit scaling to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self) -> str:
        """Return the string representation of the layer."""
        return (
            f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
            f" max_logit_scale={self.max_logit_scale}"
        )
forward
forward(x)

Apply the logit scaling to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
Source code in mmlearn/modules/layers/logit_scaling.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the logit scaling to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
extra_repr
extra_repr()

Return the string representation of the layer.

Source code in mmlearn/modules/layers/logit_scaling.py
def extra_repr(self) -> str:
    """Return the string representation of the layer."""
    return (
        f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
        f" max_logit_scale={self.max_logit_scale}"
    )
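Example usage (a minimal sketch; the import path is assumed from the source path shown above):

import torch

from mmlearn.modules.layers.logit_scaling import LearnableLogitScaling  # assumed import path

scaling = LearnableLogitScaling(init_logit_scale=1 / 0.07, learnable=True, max_logit_scale=100)

logits = torch.randn(4, 4)  # e.g. pairwise cosine similarities
scaled = scaling(logits)    # multiplies by clip(exp(log_logit_scale), max=100)

# Before any training, the effective multiplier is exp(log(1 / 0.07)) = 1 / 0.07 ≈ 14.29.
print((scaled / logits).mean())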
MLP

Bases: Sequential

Multi-layer perceptron (MLP).

This module will create a block of Linear -> Normalization -> Activation -> Dropout layers.

Parameters:

Name Type Description Default
in_dim int

The input dimension.

required
out_dim Optional[int]

The output dimension. If not specified, it is set to in_dim.

None
hidden_dims Optional[list]

The dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with hidden_dims_multiplier.

None
hidden_dims_multiplier Optional[list]

The multipliers applied to the input dimension to derive the dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with hidden_dims.

None
apply_multiplier_to_in_dim bool

Whether to apply the hidden_dims_multiplier to in_dim to get the dimensions of the hidden layers. If False, the multipliers will be applied to the dimensions of the previous hidden layer, starting from in_dim. This parameter is only relevant when hidden_dims_multiplier is specified.

False
norm_layer Optional[Callable[..., Module]]

The normalization layer to use. If not specified, no normalization is used. Partial functions can be used to specify the normalization layer with specific parameters.

None
activation_layer Optional[Callable[..., Module]]

The activation layer to use. If not specified, ReLU is used. Partial functions can be used to specify the activation layer with specific parameters.

torch.nn.ReLU
bias Union[bool, list[bool]]

Whether to use bias in the linear layers.

True
dropout Union[float, list[float]]

The dropout probability to use.

0.0

Raises:

Type Description
ValueError

If both hidden_dims and hidden_dims_multiplier are specified, or if the lengths of bias and hidden_dims do not match, or if the lengths of dropout and hidden_dims do not match.

Source code in mmlearn/modules/layers/mlp.py
@store(group="modules/layers", provider="mmlearn")
class MLP(torch.nn.Sequential):
    """Multi-layer perceptron (MLP).

    This module will create a block of ``Linear -> Normalization -> Activation -> Dropout``
    layers.

    Parameters
    ----------
    in_dim : int
        The input dimension.
    out_dim : Optional[int], optional, default=None
        The output dimension. If not specified, it is set to :attr:`in_dim`.
    hidden_dims : Optional[list], optional, default=None
        The dimensions of the hidden layers. The length of the list determines the
        number of hidden layers. This parameter is mutually exclusive with
        :attr:`hidden_dims_multiplier`.
    hidden_dims_multiplier : Optional[list], optional, default=None
        The multipliers to apply to the input dimension to get the dimensions of
        the hidden layers. The length of the list determines the number of hidden
        layers. The multipliers will be used to get the dimensions of the hidden
        layers. This parameter is mutually exclusive with `hidden_dims`.
    apply_multiplier_to_in_dim : bool, optional, default=False
        Whether to apply the :attr:`hidden_dims_multiplier` to :attr:`in_dim` to get the
        dimensions of the hidden layers. If ``False``, the multipliers will be applied
        to the dimensions of the previous hidden layer, starting from :attr:`in_dim`.
        This parameter is only relevant when :attr:`hidden_dims_multiplier` is
        specified.
    norm_layer : Optional[Callable[..., torch.nn.Module]], optional, default=None
        The normalization layer to use. If not specified, no normalization is used.
        Partial functions can be used to specify the normalization layer with specific
        parameters.
    activation_layer : Optional[Callable[..., torch.nn.Module]], optional, default=torch.nn.ReLU
        The activation layer to use. If not specified, ReLU is used. Partial functions
        can be used to specify the activation layer with specific parameters.
    bias : Union[bool, list[bool]], optional, default=True
        Whether to use bias in the linear layers.
    dropout : Union[float, list[float]], optional, default=0.0
        The dropout probability to use.

    Raises
    ------
    ValueError
        If both :attr:`hidden_dims` and :attr:`hidden_dims_multiplier` are specified
        or if the lengths of :attr:`bias` and :attr:`hidden_dims` do not match or if
        the lengths of :attr:`dropout` and :attr:`hidden_dims` do not match.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        hidden_dims: Optional[list[int]] = None,
        hidden_dims_multiplier: Optional[list[float]] = None,
        apply_multiplier_to_in_dim: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        bias: Union[bool, list[bool]] = True,
        dropout: Union[float, list[float]] = 0.0,
    ) -> None:
        if hidden_dims is None and hidden_dims_multiplier is None:
            hidden_dims = []
        if hidden_dims is not None and hidden_dims_multiplier is not None:
            raise ValueError(
                "Only one of `hidden_dims` or `hidden_dims_multiplier` must be specified."
            )

        if hidden_dims is None and hidden_dims_multiplier is not None:
            if apply_multiplier_to_in_dim:
                hidden_dims = [
                    int(in_dim * multiplier) for multiplier in hidden_dims_multiplier
                ]
            else:
                hidden_dims = [int(in_dim * hidden_dims_multiplier[0])]
                for multiplier in hidden_dims_multiplier[1:]:
                    hidden_dims.append(int(hidden_dims[-1] * multiplier))

        if isinstance(bias, bool):
            bias_list: list[bool] = [bias] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            bias_list = bias
        if len(bias_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `bias` to be a boolean or a list of booleans with length "
                "equal to the number of linear layers in the MLP."
            )

        if isinstance(dropout, float):
            dropout_list: list[float] = [dropout] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            dropout_list = dropout
        if len(dropout_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `dropout` to be a float or a list of floats with length "
                "equal to the number of linear layers in the MLP."
            )

        # construct list of dimensions for the layers
        dims = [in_dim] + hidden_dims  # type: ignore[operator]
        layers = []
        for layer_idx, (in_features, hidden_features) in enumerate(
            zip(dims[:-1], dims[1:], strict=False)
        ):
            layers.append(
                torch.nn.Linear(in_features, hidden_features, bias=bias_list[layer_idx])
            )
            if norm_layer is not None:
                layers.append(norm_layer(hidden_features))
            if activation_layer is not None:
                layers.append(activation_layer())
            layers.append(torch.nn.Dropout(dropout_list[layer_idx]))

        out_dim = out_dim or in_dim

        layers.append(torch.nn.Linear(dims[-1], out_dim, bias=bias_list[-1]))
        layers.append(torch.nn.Dropout(dropout_list[-1]))

        super().__init__(*layers)
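Example usage (a minimal sketch; the import path is assumed from the source path shown above). With hidden_dims_multiplier=[4.0, 0.5] and apply_multiplier_to_in_dim left at False, the hidden widths become int(768 * 4.0) = 3072 and int(3072 * 0.5) = 1536:

from functools import partial

import torch

from mmlearn.modules.layers.mlp import MLP  # assumed import path

mlp = MLP(
    in_dim=768,
    out_dim=512,
    hidden_dims_multiplier=[4.0, 0.5],
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    activation_layer=torch.nn.GELU,
    dropout=0.1,
)

x = torch.randn(8, 768)
print(mlp(x).shape)  # torch.Size([8, 512])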
L2Norm

Bases: Module

L2 normalization.

Parameters:

Name Type Description Default
dim int

The dimension along which to normalize.

required
Source code in mmlearn/modules/layers/normalization.py
@store(group="modules/layers", provider="mmlearn")
class L2Norm(torch.nn.Module):
    """L2 normalization.

    Parameters
    ----------
    dim : int
        The dimension along which to normalize.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply L2 normalization to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)
forward
forward(x)

Apply L2 normalization to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Normalized tensor of shape (batch_sz, seq_len, dim).

Source code in mmlearn/modules/layers/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply L2 normalization to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.nn.functional.normalize(x, dim=self.dim, p=2)
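Example usage (a minimal sketch; the import path is assumed from the source path shown above):

import torch

from mmlearn.modules.layers.normalization import L2Norm  # assumed import path

l2_norm = L2Norm(dim=-1)

x = torch.randn(2, 10, 768)
normed = l2_norm(x)
print(normed.norm(p=2, dim=-1))  # all values are ~1.0 after normalization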
PatchDropout

Bases: Module

Patch dropout layer.

Drops patch tokens (after embedding and adding the CLS token) from the input tensor. Usually used in vision transformers to reduce the number of tokens [1].

Parameters:

Name Type Description Default
keep_rate float

The proportion of tokens to keep.

0.5
bias Optional[float]

The bias to add to the random noise before sorting.

None
token_shuffling bool

If True, the tokens are shuffled.

False
References

[1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023). PatchDropout: Economizing vision transformers using patch dropout. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 3953-3962).

Source code in mmlearn/modules/layers/patch_dropout.py
class PatchDropout(torch.nn.Module):
    """Patch dropout layer.

    Drops patch tokens (after embedding and adding CLS token) from the input tensor.
    Usually used in vision transformers to reduce the number of tokens. [1]_

    Parameters
    ----------
    keep_rate : float, optional, default=0.5
        The proportion of tokens to keep.
    bias : Optional[float], optional, default=None
        The bias to add to the random noise before sorting.
    token_shuffling : bool, optional, default=False
        If True, the tokens are shuffled.

    References
    ----------
    .. [1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023).
       Patchdropout: Economizing vision transformers using patch dropout. In Proceedings
       of the IEEE/CVF Winter Conference on Applications of Computer Vision
       (pp. 3953-3962).
    """

    def __init__(
        self,
        keep_rate: float = 0.5,
        bias: Optional[float] = None,
        token_shuffling: bool = False,
    ):
        super().__init__()
        assert 0 < keep_rate <= 1, "The keep_rate must be in (0,1]"

        self.bias = bias
        self.keep_rate = keep_rate
        self.token_shuffling = token_shuffling

    def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
        """Drop tokens from the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        force_drop : bool, optional, default=False
            If True, the tokens are always dropped, even when the model is in
            evaluation mode.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
        """
        if (not self.training and not force_drop) or self.keep_rate == 1:
            return x

        batch_sz, _, dim = x.shape

        cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
            batch_sz, 1, dtype=torch.int64, device=x.device
        )
        patch_mask = self.uniform_mask(x)
        patch_mask = torch.hstack([cls_mask, patch_mask])

        return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))

    def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Generate token ids to keep from uniform random distribution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

        """
        batch_sz, seq_len, _ = x.shape
        seq_len = seq_len - 1  # patch length (without CLS)

        keep_len = int(seq_len * self.keep_rate)
        noise = torch.rand(batch_sz, seq_len, device=x.device)
        if self.bias is not None:
            noise += self.bias
        ids = torch.argsort(noise, dim=1)
        keep_ids = ids[:, :keep_len]
        if not self.token_shuffling:
            keep_ids = keep_ids.sort(1)[0]
        return keep_ids
forward
forward(x, force_drop=False)

Drop tokens from the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
force_drop bool

If True, the tokens are always dropped, even when the model is in evaluation mode.

False

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len, dim) containing the kept tokens.

Source code in mmlearn/modules/layers/patch_dropout.py
def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
    """Drop tokens from the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    force_drop : bool, optional, default=False
        If True, the tokens are always dropped, even when the model is in
        evaluation mode.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
    """
    if (not self.training and not force_drop) or self.keep_rate == 1:
        return x

    batch_sz, _, dim = x.shape

    cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
        batch_sz, 1, dtype=torch.int64, device=x.device
    )
    patch_mask = self.uniform_mask(x)
    patch_mask = torch.hstack([cls_mask, patch_mask])

    return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))
uniform_mask
uniform_mask(x)

Generate token ids to keep from uniform random distribution.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len) containing the token ids to keep.

Source code in mmlearn/modules/layers/patch_dropout.py
def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
    """Generate token ids to keep from uniform random distribution.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

    """
    batch_sz, seq_len, _ = x.shape
    seq_len = seq_len - 1  # patch length (without CLS)

    keep_len = int(seq_len * self.keep_rate)
    noise = torch.rand(batch_sz, seq_len, device=x.device)
    if self.bias is not None:
        noise += self.bias
    ids = torch.argsort(noise, dim=1)
    keep_ids = ids[:, :keep_len]
    if not self.token_shuffling:
        keep_ids = keep_ids.sort(1)[0]
    return keep_ids
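Example usage (a minimal sketch; the import path is assumed from the source path shown above). With keep_rate=0.5 and a 197-token input (1 CLS token plus 196 patches), the layer keeps the CLS token and int(196 * 0.5) = 98 patch tokens:

import torch

from mmlearn.modules.layers.patch_dropout import PatchDropout  # assumed import path

patch_dropout = PatchDropout(keep_rate=0.5)
patch_dropout.train()  # tokens are only dropped in training mode (or with force_drop=True)

tokens = torch.randn(4, 197, 768)  # (batch_sz, seq_len, dim), CLS token first
kept = patch_dropout(tokens)
print(kept.shape)  # torch.Size([4, 99, 768])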
ContrastiveLoss

Bases: Module

Contrastive Loss.

Parameters:

Name Type Description Default
l2_normalize bool

Whether to L2 normalize the features.

False
local_loss bool

Whether to calculate the loss locally i.e. local_features@global_features.

False
gather_with_grad bool

Whether to gather tensors with gradients.

False
modality_alignment bool

Whether to include modality alignment loss. This loss considers all features from the same modality as positive pairs and all features from different modalities as negative pairs.

False
cache_labels bool

Whether to cache the labels.

False
Source code in mmlearn/modules/losses/contrastive.py
@store(group="modules/losses", provider="mmlearn")
class ContrastiveLoss(nn.Module):
    """Contrastive Loss.

    Parameters
    ----------
    l2_normalize : bool, optional, default=False
        Whether to L2 normalize the features.
    local_loss : bool, optional, default=False
        Whether to calculate the loss locally i.e. ``local_features@global_features``.
    gather_with_grad : bool, optional, default=False
        Whether to gather tensors with gradients.
    modality_alignment : bool, optional, default=False
        Whether to include modality alignment loss. This loss considers all features
        from the same modality as positive pairs and all features from different
        modalities as negative pairs.
    cache_labels : bool, optional, default=False
        Whether to cache the labels.

    """

    def __init__(
        self,
        l2_normalize: bool = False,
        local_loss: bool = False,
        gather_with_grad: bool = False,
        modality_alignment: bool = False,
        cache_labels: bool = False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.l2_normalize = l2_normalize
        self.modality_alignment = modality_alignment

        # cache state
        self._prev_num_logits = 0
        self._labels: dict[torch.device, torch.Tensor] = {}

    def forward(
        self,
        embeddings: dict[str, torch.Tensor],
        example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        modality_loss_pairs: list[LossPairSpec],
    ) -> torch.Tensor:
        """Calculate the contrastive loss.

        Parameters
        ----------
        embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        modality_loss_pairs : list[LossPairSpec]
            Specification of the modality pairs for which the loss should be calculated.

        Returns
        -------
        torch.Tensor
            The contrastive loss.
        """
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if world_size > 1 else 0

        if self.l2_normalize:
            embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

        if world_size > 1:  # gather embeddings and example_ids across all processes
            # NOTE: gathering dictionaries of tensors across all processes
            # (keys + values, as opposed to just values) is especially important
            # for the modality_alignment loss, which requires all embeddings
            all_embeddings = _gather_dicts(
                embeddings,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
            all_example_ids = _gather_dicts(
                example_ids,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
        else:
            all_embeddings = embeddings
            all_example_ids = example_ids

        losses = []
        for loss_pairs in modality_loss_pairs:
            logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
                loss_pairs.modalities,
                per_device_embeddings=embeddings,
                all_embeddings=all_embeddings,
                per_device_example_ids=example_ids,
                all_example_ids=all_example_ids,
                logit_scale=logit_scale,
                world_size=world_size,
            )
            if logits_per_feature_a is None or logits_per_feature_b is None:
                continue

            labels = self._get_ground_truth(
                logits_per_feature_a.shape,
                device=logits_per_feature_a.device,
                rank=rank,
                world_size=world_size,
                skipped_process=skip_flag,
            )

            if labels.numel() != 0:
                losses.append(
                    (
                        (
                            F.cross_entropy(logits_per_feature_a, labels)
                            + F.cross_entropy(logits_per_feature_b, labels)
                        )
                        / 2
                    )
                    * loss_pairs.weight
                )

        if self.modality_alignment:
            losses.append(
                self._compute_modality_alignment_loss(all_embeddings, logit_scale)
            )

        if not losses:  # no loss to compute (e.g. no paired data in batch)
            losses.append(
                torch.tensor(
                    0.0,
                    device=logit_scale.device,
                    dtype=next(iter(embeddings.values())).dtype,
                )
            )

        return torch.stack(losses).sum()

    def _get_ground_truth(
        self,
        logits_shape: tuple[int, int],
        device: torch.device,
        rank: int,
        world_size: int,
        skipped_process: bool,
    ) -> torch.Tensor:
        """Return the ground-truth labels.

        Parameters
        ----------
        logits_shape : tuple[int, int]
            Shape of the logits tensor.
        device : torch.device
            Device on which the labels should be created.
        rank : int
            Rank of the current process.
        world_size : int
            Number of processes.
        skipped_process : bool
            Whether the current process skipped the computation due to lack of data.

        Returns
        -------
        torch.Tensor
            Ground-truth labels.
        """
        num_logits = logits_shape[-1]

        # calculate ground-truth and cache if enabled
        if self._prev_num_logits != num_logits or device not in self._labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)

            if world_size > 1 and self.local_loss:
                local_size = torch.tensor(
                    0 if skipped_process else logits_shape[0], device=device
                )
                # NOTE: all processes must participate in the all_gather operation
                # even if they have no data to contribute.
                sizes = torch.stack(
                    _simple_gather_all_tensors(
                        local_size, group=dist.group.WORLD, world_size=world_size
                    )
                )
                sizes = torch.cat(
                    [torch.tensor([0], device=sizes.device), torch.cumsum(sizes, dim=0)]
                )
                labels = labels[
                    sizes[rank] : sizes[rank + 1] if rank + 1 < world_size else None
                ]

            if self.cache_labels:
                self._labels[device] = labels
                self._prev_num_logits = num_logits
        else:
            labels = self._labels[device]
        return labels

    def _get_logits(  # noqa: PLR0912
        self,
        modalities: tuple[str, str],
        per_device_embeddings: dict[str, torch.Tensor],
        all_embeddings: dict[str, torch.Tensor],
        per_device_example_ids: dict[str, torch.Tensor],
        all_example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        world_size: int,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
        """Calculate the logits for the given modalities.

        Parameters
        ----------
        modalities : tuple[str, str]
            Tuple of modality names.
        per_device_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor. In distributed mode, this contains
            embeddings from all processes.
        per_device_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        all_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index. In distributed
            mode, this contains example IDs from all processes.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        world_size : int
            Number of processes.

        Returns
        -------
        tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]
            Tuple of logits for the given modalities. If embeddings for the given
            modalities are not available, returns `None` for the logits. The last
            element is a flag indicating whether the process skipped the computation
            due to lack of data.
        """
        modality_a = Modalities.get_modality(modalities[0])
        modality_b = Modalities.get_modality(modalities[1])
        skip_flag = False

        if self.local_loss or world_size == 1:
            if not (
                modality_a.embedding in per_device_embeddings
                and modality_b.embedding in per_device_embeddings
            ):
                if world_size > 1:  # NOTE: not all processes exit here, hence skip_flag
                    skip_flag = True
                else:
                    return None, None, skip_flag

            if not skip_flag:
                indices_a, indices_b = find_matching_indices(
                    per_device_example_ids[modality_a.name],
                    per_device_example_ids[modality_b.name],
                )
                if indices_a.numel() == 0 or indices_b.numel() == 0:
                    if world_size > 1:  # not all processes exit here
                        skip_flag = True
                    else:
                        return None, None, skip_flag

            if not skip_flag:
                features_a = per_device_embeddings[modality_a.embedding][indices_a]
                features_b = per_device_embeddings[modality_b.embedding][indices_b]
            else:
                # all processes must participate in the all_gather operation
                # that follows, even if they have no data to contribute. So,
                # we create empty tensors here.
                features_a = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )
                features_b = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )

        if world_size > 1:
            if not (
                modality_a.embedding in all_embeddings
                and modality_b.embedding in all_embeddings
            ):  # all processes exit here
                return None, None, skip_flag

            indices_a, indices_b = find_matching_indices(
                all_example_ids[modality_a.name],
                all_example_ids[modality_b.name],
            )
            if indices_a.numel() == 0 or indices_b.numel() == 0:
                # all processes exit here
                return None, None, skip_flag

            all_features_a = all_embeddings[modality_a.embedding][indices_a]
            all_features_b = all_embeddings[modality_b.embedding][indices_b]

            if self.local_loss:
                if features_a.numel() == 0:
                    features_a = all_features_a
                if features_b.numel() == 0:
                    features_b = all_features_b

                logits_per_feature_a = logit_scale * _safe_matmul(
                    features_a, all_features_b
                )
                logits_per_feature_b = logit_scale * _safe_matmul(
                    features_b, all_features_a
                )
            else:
                logits_per_feature_a = logit_scale * _safe_matmul(
                    all_features_a, all_features_b
                )
                logits_per_feature_b = logits_per_feature_a.T
        else:
            logits_per_feature_a = logit_scale * _safe_matmul(features_a, features_b)
            logits_per_feature_b = logit_scale * _safe_matmul(features_b, features_a)

        return logits_per_feature_a, logits_per_feature_b, skip_flag

    def _compute_modality_alignment_loss(
        self, all_embeddings: dict[str, torch.Tensor], logit_scale: torch.Tensor
    ) -> torch.Tensor:
        """Compute the modality alignment loss.

        This loss considers all features from the same modality as positive pairs
        and all features from different modalities as negative pairs.

        Parameters
        ----------
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        logit_scale : torch.Tensor
            Scale factor for the logits.

        Returns
        -------
        torch.Tensor
            Modality alignment loss.

        Notes
        -----
        This loss does not support `local_loss=True`.
        """
        available_modalities = list(all_embeddings.keys())
        # TODO: support local_loss for modality_alignment?
        # if world_size == 1, all_embeddings == embeddings
        all_features = torch.cat(list(all_embeddings.values()), dim=0)

        positive_indices = torch.tensor(
            [
                (i, j)
                if idx == 0
                else (
                    i + all_embeddings[available_modalities[idx - 1]].size(0),
                    j + all_embeddings[available_modalities[idx - 1]].size(0),
                )
                for idx, k in enumerate(all_embeddings)
                for i, j in itertools.combinations(range(all_embeddings[k].size(0)), 2)
            ],
            device=all_features.device,
        )
        logits = logit_scale * _safe_matmul(all_features, all_features)

        target = torch.eye(all_features.size(0), device=all_features.device)
        target[positive_indices[:, 0], positive_indices[:, 1]] = 1

        modality_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, target, reduction="none"
        )

        target_pos = target.bool()
        target_neg = ~target_pos

        # loss_pos and loss_neg below contain non-zero values only for those
        # elements that are positive pairs and negative pairs respectively.
        loss_pos = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_pos, modality_loss[target_pos])
        loss_neg = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_neg, modality_loss[target_neg])

        loss_pos = loss_pos.sum(dim=1)
        loss_neg = loss_neg.sum(dim=1)
        num_pos = target.sum(dim=1)
        num_neg = logits.size(0) - num_pos

        return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()
forward
forward(
    embeddings,
    example_ids,
    logit_scale,
    modality_loss_pairs,
)

Calculate the contrastive loss.

Parameters:

Name Type Description Default
embeddings dict[str, Tensor]

Dictionary of embeddings, where the key is the modality name and the value is the corresponding embedding tensor.

required
example_ids dict[str, Tensor]

Dictionary of example IDs, where the key is the modality name and the value is a tensor tuple of the dataset index and the example index.

required
logit_scale Tensor

Scale factor for the logits.

required
modality_loss_pairs list[LossPairSpec]

Specification of the modality pairs for which the loss should be calculated.

required

Returns:

Type Description
Tensor

The contrastive loss.

Source code in mmlearn/modules/losses/contrastive.py
def forward(
    self,
    embeddings: dict[str, torch.Tensor],
    example_ids: dict[str, torch.Tensor],
    logit_scale: torch.Tensor,
    modality_loss_pairs: list[LossPairSpec],
) -> torch.Tensor:
    """Calculate the contrastive loss.

    Parameters
    ----------
    embeddings : dict[str, torch.Tensor]
        Dictionary of embeddings, where the key is the modality name and the value
        is the corresponding embedding tensor.
    example_ids : dict[str, torch.Tensor]
        Dictionary of example IDs, where the key is the modality name and the value
        is a tensor tuple of the dataset index and the example index.
    logit_scale : torch.Tensor
        Scale factor for the logits.
    modality_loss_pairs : list[LossPairSpec]
        Specification of the modality pairs for which the loss should be calculated.

    Returns
    -------
    torch.Tensor
        The contrastive loss.
    """
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if world_size > 1 else 0

    if self.l2_normalize:
        embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

    if world_size > 1:  # gather embeddings and example_ids across all processes
        # NOTE: gathering dictionaries of tensors across all processes
        # (keys + values, as opposed to just values) is especially important
        # for the modality_alignment loss, which requires all embeddings
        all_embeddings = _gather_dicts(
            embeddings,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
        all_example_ids = _gather_dicts(
            example_ids,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
    else:
        all_embeddings = embeddings
        all_example_ids = example_ids

    losses = []
    for loss_pairs in modality_loss_pairs:
        logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
            loss_pairs.modalities,
            per_device_embeddings=embeddings,
            all_embeddings=all_embeddings,
            per_device_example_ids=example_ids,
            all_example_ids=all_example_ids,
            logit_scale=logit_scale,
            world_size=world_size,
        )
        if logits_per_feature_a is None or logits_per_feature_b is None:
            continue

        labels = self._get_ground_truth(
            logits_per_feature_a.shape,
            device=logits_per_feature_a.device,
            rank=rank,
            world_size=world_size,
            skipped_process=skip_flag,
        )

        if labels.numel() != 0:
            losses.append(
                (
                    (
                        F.cross_entropy(logits_per_feature_a, labels)
                        + F.cross_entropy(logits_per_feature_b, labels)
                    )
                    / 2
                )
                * loss_pairs.weight
            )

    if self.modality_alignment:
        losses.append(
            self._compute_modality_alignment_loss(all_embeddings, logit_scale)
        )

    if not losses:  # no loss to compute (e.g. no paired data in batch)
        losses.append(
            torch.tensor(
                0.0,
                device=logit_scale.device,
                dtype=next(iter(embeddings.values())).dtype,
            )
        )

    return torch.stack(losses).sum()
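For intuition: on a single process, with every example paired across a single pair of modalities, the loss computed above reduces to the standard symmetric cross-entropy over logit-scaled similarities. The sketch below reproduces that core in plain PyTorch; it is not the class API, which additionally handles example-ID matching, distributed gathering, and the optional modality-alignment term.

import torch
import torch.nn.functional as F

batch_size, dim = 8, 256
features_a = F.normalize(torch.randn(batch_size, dim), dim=-1)  # as with l2_normalize=True
features_b = F.normalize(torch.randn(batch_size, dim), dim=-1)
logit_scale = torch.tensor(1 / 0.07)  # exp(log_logit_scale), cf. LearnableLogitScaling

logits_ab = logit_scale * features_a @ features_b.t()
logits_ba = logit_scale * features_b @ features_a.t()

# The i-th example in modality A matches the i-th example in modality B.
labels = torch.arange(batch_size)
loss = (F.cross_entropy(logits_ab, labels) + F.cross_entropy(logits_ba, labels)) / 2
print(loss.item())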
Data2VecLoss

Bases: Module

Data2Vec loss function.

Parameters:

Name Type Description Default
beta float

Specifies the beta parameter for smooth L1 loss. If 0, MSE loss is used.

0
loss_scale Optional[float]

Scaling factor for the loss. If None, uses 1 / sqrt(embedding_dim).

None
reduction str

Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.

'none'

Raises:

Type Description
ValueError

If the reduction mode is not supported.

Source code in mmlearn/modules/losses/data2vec.py
@store(group="modules/losses", provider="mmlearn")
class Data2VecLoss(nn.Module):
    """Data2Vec loss function.

    Parameters
    ----------
    beta : float, optional, default=0
        Specifies the beta parameter for smooth L1 loss. If ``0``, MSE loss is used.
    loss_scale : Optional[float], optional, default=None
        Scaling factor for the loss. If None, uses ``1 / sqrt(embedding_dim)``.
    reduction : str, optional, default='none'
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``.

    Raises
    ------
    ValueError
        If the reduction mode is not supported.
    """

    def __init__(
        self,
        beta: float = 0,
        loss_scale: Optional[float] = None,
        reduction: str = "none",
    ) -> None:
        super().__init__()
        self.beta = beta
        self.loss_scale = loss_scale
        if reduction not in ["none", "mean", "sum"]:
            raise ValueError(f"Unsupported reduction mode: {reduction}")
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the Data2Vec loss.

        Parameters
        ----------
        x : torch.Tensor
            Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
        y : torch.Tensor
            Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

        Returns
        -------
        torch.Tensor
            Data2Vec loss value.

        Raises
        ------
        ValueError
            If the shapes of x and y do not match.
        """
        if x.shape != y.shape:
            raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, y.size(-1))

        if self.beta == 0:
            loss = mse_loss(x, y, reduction="none")
        else:
            loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        loss = loss * scale

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # 'none'
        return loss.view(x.size(0), -1).sum(1)
forward
forward(x, y)

Compute the Data2Vec loss.

Parameters:

Name Type Description Default
x Tensor

Predicted embeddings of shape (batch_size, num_patches, embedding_dim).

required
y Tensor

Target embeddings of shape (batch_size, num_patches, embedding_dim).

required

Returns:

Type Description
Tensor

Data2Vec loss value.

Raises:

Type Description
ValueError

If the shapes of x and y do not match.

Source code in mmlearn/modules/losses/data2vec.py
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the Data2Vec loss.

    Parameters
    ----------
    x : torch.Tensor
        Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
    y : torch.Tensor
        Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

    Returns
    -------
    torch.Tensor
        Data2Vec loss value.

    Raises
    ------
    ValueError
        If the shapes of x and y do not match.
    """
    if x.shape != y.shape:
        raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

    x = x.view(-1, x.size(-1)).float()
    y = y.view(-1, y.size(-1))

    if self.beta == 0:
        loss = mse_loss(x, y, reduction="none")
    else:
        loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

    if self.loss_scale is not None:
        scale = self.loss_scale
    else:
        scale = 1 / math.sqrt(x.size(-1))

    loss = loss * scale

    if self.reduction == "mean":
        return loss.mean()
    if self.reduction == "sum":
        return loss.sum()
    # 'none'
    return loss.view(x.size(0), -1).sum(1)
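Example usage (a minimal sketch; the import path is assumed from the source path shown above):

import torch

from mmlearn.modules.losses.data2vec import Data2VecLoss  # assumed import path

loss_fn = Data2VecLoss(beta=2.0, reduction="mean")  # smooth L1 with beta=2, averaged

predictions = torch.randn(4, 196, 768)  # (batch_size, num_patches, embedding_dim)
targets = torch.randn(4, 196, 768)

# With loss_scale=None, the loss is scaled by 1 / sqrt(embedding_dim) = 1 / sqrt(768).
loss = loss_fn(predictions, targets)
print(loss.item())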
RetrievalRecallAtK

Bases: Metric

Retrieval Recall@K metric.

Computes the Recall@K for retrieval tasks. The metric is computed as follows:

  1. Compute the cosine similarity between the query and the database.
  2. For each query, sort the database in decreasing order of similarity.
  3. Compute the Recall@K as the number of true positives among the top K elements.

Parameters:

Name Type Description Default
top_k int

The number of top elements to consider for computing the Recall@K.

required
reduction (mean, sum, none, None)

Specifies the reduction to apply after computing the pairwise cosine similarity scores.

"mean"
aggregation (mean, median, min, max)

Specifies the aggregation function to apply to the Recall@K values computed in batches. If a callable is provided, it should accept a tensor of values and a keyword argument 'dim' and return a single scalar value.

"mean"
kwargs Any

Additional arguments to be passed to the torchmetrics.Metric class.

{}

Raises:

Type Description
ValueError
  • If the top_k is not a positive integer or None.
  • If the reduction is not one of {"mean", "sum", "none", None}.
  • If the aggregation is not one of {"mean", "median", "min", "max"} or a custom callable function.
Source code in mmlearn/modules/metrics/retrieval_recall.py
@store(group="modules/metrics", provider="mmlearn")
class RetrievalRecallAtK(Metric):
    """Retrieval Recall@K metric.

    Computes the Recall@K for retrieval tasks. The metric is computed as follows:

    1. Compute the cosine similarity between the query and the database.
    2. For each query, sort the database in decreasing order of similarity.
    3. Compute the Recall@K as the number of true positives among the top K elements.

    Parameters
    ----------
    top_k : int
        The number of top elements to consider for computing the Recall@K.
    reduction : {"mean", "sum", "none", None}, optional, default="sum"
        Specifies the reduction to apply after computing the pairwise cosine similarity
        scores.
    aggregation : {"mean", "median", "min", "max"} or callable, default="mean"
        Specifies the aggregation function to apply to the Recall@K values computed
        in batches. If a callable is provided, it should accept a tensor of values
        and a keyword argument ``'dim'`` and return a single scalar value.
    kwargs : Any
        Additional arguments to be passed to the :py:class:`torchmetrics.Metric` class.

    Raises
    ------
    ValueError

        - If the `top_k` is not a positive integer or None.
        - If the `reduction` is not one of {"mean", "sum", "none", None}.
        - If the `aggregation` is not one of {"mean", "median", "min", "max"} or a
          custom callable function.

    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    indexes: list[torch.Tensor]
    x: list[torch.Tensor]
    y: list[torch.Tensor]
    num_samples: torch.Tensor

    def __init__(
        self,
        top_k: int,
        reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
        aggregation: Union[
            Literal["mean", "median", "min", "max"],
            Callable[[torch.Tensor, int], torch.Tensor],
        ] = "mean",
        **kwargs: Any,
    ) -> None:
        """Initialize the metric."""
        super().__init__(**kwargs)

        if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
            raise ValueError("`top_k` has to be a positive integer or None")
        self.top_k = top_k

        allowed_reduction = ("sum", "mean", "none", None)
        if reduction not in allowed_reduction:
            raise ValueError(
                f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
            )
        self.reduction = reduction

        if not (
            aggregation in ("mean", "median", "min", "max") or callable(aggregation)
        ):
            raise ValueError(
                "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
                f"which takes tensor of values, but got {aggregation}."
            )
        self.aggregation = aggregation

        self.add_state("x", default=[], dist_reduce_fx="cat")
        self.add_state("y", default=[], dist_reduce_fx="cat")
        self.add_state("indexes", default=[], dist_reduce_fx="cat")
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

        self._batch_size: int = -1

        self.compute_on_cpu = True
        self.sync_on_compute = False
        self.dist_sync_on_step = False
        self._to_sync = self.sync_on_compute
        self._should_unsync = False

    def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
        """Check shape, convert dtypes and add to accumulators.

        Parameters
        ----------
        x : torch.Tensor
            Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
            of samples and `D` is the number of dimensions.
        y : torch.Tensor
            Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
            of samples and ``D`` is the number of dimensions.
        indexes : torch.Tensor
            Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
            This specifies which sample in ``y`` is the positive match for each
            sample in ``x``.

        Raises
        ------
        ValueError
            If `indexes` is None.

        """
        if indexes is None:
            raise ValueError("Argument `indexes` cannot be None")

        x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

        # offset batch indexes by the number of samples seen so far
        indexes += self.num_samples

        local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
        if self._is_distributed():
            x = dim_zero_cat(gather_all_tensors(x, self.process_group))
            y = dim_zero_cat(gather_all_tensors(y, self.process_group))
            indexes = dim_zero_cat(
                gather_all_tensors(indexes.clone(), self.process_group)
            )

            # offset indexes for each device
            bsz_per_device = dim_zero_cat(
                gather_all_tensors(local_batch_size, self.process_group)
            )
            cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
            for device_idx in range(1, bsz_per_device.numel()):
                indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                    cum_local_bsz[device_idx - 1]
                )

            # update the global sample count
            self.num_samples += cum_local_bsz[-1]

            self._is_synced = True
        else:
            self.num_samples += x.size(0)

        self.x.append(x)
        self.y.append(y)
        self.indexes.append(indexes)

        if self._batch_size == -1:
            self._batch_size = x.size(0)  # global batch size

    def compute(self) -> torch.Tensor:
        """Compute the metric.

        Returns
        -------
        torch.Tensor
            The computed metric.
        """
        x = dim_zero_cat(self.x)
        y = dim_zero_cat(self.y)

        # normalize embeddings
        x /= x.norm(dim=-1, p=2, keepdim=True)
        y /= y.norm(dim=-1, p=2, keepdim=True)

        # instantiate reduction function
        reduction_mapping: Dict[
            Optional[str], Callable[[torch.Tensor], torch.Tensor]
        ] = {
            "sum": partial(torch.sum, dim=-1),
            "mean": partial(torch.mean, dim=-1),
            "none": lambda x: x,
            None: lambda x: x,
        }

        # concatenate indexes of true pairs
        indexes = dim_zero_cat(self.indexes)

        results: list[torch.Tensor] = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=os.cpu_count() or 1  # use all available CPUs
        ) as executor:
            futures = [
                executor.submit(
                    self._process_batch,
                    start,
                    x,
                    y,
                    indexes,
                    reduction_mapping,
                    self.top_k,
                )
                for start in tqdm(
                    range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
                )
            ]
            for future in concurrent.futures.as_completed(futures):
                results.append(future.result())

        return _retrieval_aggregate(
            (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward method is not supported.

        Raises
        ------
        NotImplementedError
            The forward method is not supported for this metric.
        """
        raise NotImplementedError(
            "RetrievalRecallAtK metric does not support forward method"
        )

    def _is_distributed(self) -> bool:
        if self.distributed_available_fn is not None:
            distributed_available = self.distributed_available_fn

        return distributed_available() if callable(distributed_available) else False

    def _process_batch(
        self,
        start: int,
        x_norm: torch.Tensor,
        y_norm: torch.Tensor,
        indexes: torch.Tensor,
        reduction_mapping: Dict[Optional[str], Callable[[torch.Tensor], torch.Tensor]],
        top_k: int,
    ) -> torch.Tensor:
        """Compute the Recall@K for a batch of samples."""
        end = start + self._batch_size
        x_norm_batch = x_norm[start:end]
        indexes_batch = indexes[start:end]

        similarity = _safe_matmul(x_norm_batch, y_norm)
        scores: torch.Tensor = reduction_mapping[self.reduction](similarity)

        with torch.inference_mode():
            positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
            positive_pairs[torch.arange(len(scores)), indexes_batch] = True

        return _recall_at_k(scores, positive_pairs, top_k)
__init__
__init__(
    top_k, reduction="sum", aggregation="mean", **kwargs
)

Initialize the metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def __init__(
    self,
    top_k: int,
    reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
    aggregation: Union[
        Literal["mean", "median", "min", "max"],
        Callable[[torch.Tensor, int], torch.Tensor],
    ] = "mean",
    **kwargs: Any,
) -> None:
    """Initialize the metric."""
    super().__init__(**kwargs)

    if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
        raise ValueError("`top_k` has to be a positive integer or None")
    self.top_k = top_k

    allowed_reduction = ("sum", "mean", "none", None)
    if reduction not in allowed_reduction:
        raise ValueError(
            f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
        )
    self.reduction = reduction

    if not (
        aggregation in ("mean", "median", "min", "max") or callable(aggregation)
    ):
        raise ValueError(
            "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
            f"which takes tensor of values, but got {aggregation}."
        )
    self.aggregation = aggregation

    self.add_state("x", default=[], dist_reduce_fx="cat")
    self.add_state("y", default=[], dist_reduce_fx="cat")
    self.add_state("indexes", default=[], dist_reduce_fx="cat")
    self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

    self._batch_size: int = -1

    self.compute_on_cpu = True
    self.sync_on_compute = False
    self.dist_sync_on_step = False
    self._to_sync = self.sync_on_compute
    self._should_unsync = False
update
update(x, y, indexes)

Check shape, convert dtypes and add to accumulators.

Parameters:

Name Type Description Default
x Tensor

Embeddings (unnormalized) of shape (N, D) where N is the number of samples and D is the number of dimensions.

required
y Tensor

Embeddings (unnormalized) of shape (M, D) where M is the number of samples and D is the number of dimensions.

required
indexes Tensor

Index tensor of shape (N,) where N is the number of samples. This specifies which sample in y is the positive match for each sample in x.

required

Raises:

Type Description
ValueError

If indexes is None.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
    """Check shape, convert dtypes and add to accumulators.

    Parameters
    ----------
    x : torch.Tensor
        Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
        of samples and ``D`` is the number of dimensions.
    y : torch.Tensor
        Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
        of samples and ``D`` is the number of dimensions.
    indexes : torch.Tensor
        Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
        This specifies which sample in ``y`` is the positive match for each
        sample in ``x``.

    Raises
    ------
    ValueError
        If ``indexes`` is None.

    """
    if indexes is None:
        raise ValueError("Argument `indexes` cannot be None")

    x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

    # offset batch indexes by the number of samples seen so far
    indexes += self.num_samples

    local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
    if self._is_distributed():
        x = dim_zero_cat(gather_all_tensors(x, self.process_group))
        y = dim_zero_cat(gather_all_tensors(y, self.process_group))
        indexes = dim_zero_cat(
            gather_all_tensors(indexes.clone(), self.process_group)
        )

        # offset indexes for each device
        bsz_per_device = dim_zero_cat(
            gather_all_tensors(local_batch_size, self.process_group)
        )
        cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
        for device_idx in range(1, bsz_per_device.numel()):
            indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                cum_local_bsz[device_idx - 1]
            )

        # update the global sample count
        self.num_samples += cum_local_bsz[-1]

        self._is_synced = True
    else:
        self.num_samples += x.size(0)

    self.x.append(x)
    self.y.append(y)
    self.indexes.append(indexes)

    if self._batch_size == -1:
        self._batch_size = x.size(0)  # global batch size
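
To make the per-device index offsetting above concrete, here is a tiny standalone sketch of the same loop with two devices and a local batch size of 4 (the tensors are hand-built stand-ins for the gathered values, not outputs of a real process group):

import torch

bsz_per_device = torch.tensor([4, 4])             # local batch sizes gathered from 2 devices
indexes = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])  # gathered positive-pair indexes, still device-local
cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)  # tensor([4, 8])

for device_idx in range(1, bsz_per_device.numel()):
    indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += cum_local_bsz[device_idx - 1]

print(indexes)  # tensor([0, 1, 2, 3, 4, 5, 6, 7]); the second device now addresses rows 4-7 of the gathered y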
compute
compute()

Compute the metric.

Returns:

Type Description
Tensor

The computed metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def compute(self) -> torch.Tensor:
    """Compute the metric.

    Returns
    -------
    torch.Tensor
        The computed metric.
    """
    x = dim_zero_cat(self.x)
    y = dim_zero_cat(self.y)

    # normalize embeddings
    x /= x.norm(dim=-1, p=2, keepdim=True)
    y /= y.norm(dim=-1, p=2, keepdim=True)

    # instantiate reduction function
    reduction_mapping: Dict[
        Optional[str], Callable[[torch.Tensor], torch.Tensor]
    ] = {
        "sum": partial(torch.sum, dim=-1),
        "mean": partial(torch.mean, dim=-1),
        "none": lambda x: x,
        None: lambda x: x,
    }

    # concatenate indexes of true pairs
    indexes = dim_zero_cat(self.indexes)

    results: list[torch.Tensor] = []
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count() or 1  # use all available CPUs
    ) as executor:
        futures = [
            executor.submit(
                self._process_batch,
                start,
                x,
                y,
                indexes,
                reduction_mapping,
                self.top_k,
            )
            for start in tqdm(
                range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
            )
        ]
        for future in concurrent.futures.as_completed(futures):
            results.append(future.result())

    return _retrieval_aggregate(
        (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
    )
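
The helper _recall_at_k is not shown in this listing. Conceptually, for each query it looks at the k highest-scoring candidates and counts how many of the query's positive pairs appear among them; because compute only checks whether the result is greater than zero, the aggregated value is the fraction of queries with at least one positive retrieved in the top k. A minimal sketch of that idea, with argument names borrowed from _process_batch (the body is an assumption, not the library's implementation):

import torch

def recall_at_k_sketch(scores: torch.Tensor, positive_pairs: torch.Tensor, k: int) -> torch.Tensor:
    """Per query: fraction of its positives found among the top-k scores.

    scores: (num_queries, num_candidates) similarity matrix.
    positive_pairs: boolean mask of the same shape marking true matches.
    """
    topk_indices = scores.topk(k, dim=-1).indices               # top-k candidate indices per query
    hits = positive_pairs.gather(1, topk_indices).sum(dim=-1)   # positives retrieved in the top k
    num_positives = positive_pairs.sum(dim=-1).clamp(min=1)     # avoid division by zero
    return hits.float() / num_positives.float()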
forward
forward(*args, **kwargs)

Forward method is not supported.

Raises:

Type Description
NotImplementedError

The forward method is not supported for this metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Forward method is not supported.

    Raises
    ------
    NotImplementedError
        The forward method is not supported for this metric.
    """
    raise NotImplementedError(
        "RetrievalRecallAtK metric does not support forward method"
    )
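
For reference, a minimal end-to-end usage sketch of the metric (the import path follows the source file shown above and should be treated as an assumption; the tensors are random toy data):

import torch
from mmlearn.modules.metrics.retrieval_recall import RetrievalRecallAtK

metric = RetrievalRecallAtK(top_k=5, reduction="sum", aggregation="mean")

# toy retrieval setup: 8 queries, 8 candidates, 16-dimensional embeddings
x = torch.randn(8, 16)       # query embeddings (unnormalized)
y = torch.randn(8, 16)       # candidate embeddings (unnormalized)
indexes = torch.arange(8)    # indexes[i] is the positive candidate in y for query i

metric.update(x, y, indexes)  # accumulate state; forward() is intentionally unsupported
recall = metric.compute()     # fraction of queries whose match appears in the top 5
print(recall)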
ContrastivePretraining

Bases: TrainingTask

Contrastive pretraining task.

This class supports contrastive pretraining with N modalities of data. It allows the sharing of encoders, heads, and postprocessors across modalities. It also supports computing the contrastive loss between specified pairs of modalities, as well as training auxiliary tasks alongside the main contrastive pretraining task.

Parameters:

Name Type Description Default
encoders dict[str, Module]

A dictionary of encoders. The keys can be any string, including the names of any supported modalities. If the keys are not supported modalities, the modality_module_mapping parameter must be provided to map the encoders to specific modalities. The encoders are expected to take a dictionary of input values and return a list-like object with the first element being the encoded values. This first element is passed on to the heads or postprocessors and the remaining elements are ignored.

required
heads Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the encoder outputs, usually projection heads. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a torch.nn.Sequential module. All head modules are expected to take a single input tensor and return a single output tensor.

None
postprocessors Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the head outputs. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a nn.Sequential module. All postprocessor modules are expected to take a single input tensor and return a single output tensor.

None
modality_module_mapping Optional[dict[str, ModuleKeySpec]]

A dictionary mapping modalities to encoders, heads, and postprocessors. Useful for reusing the same instance of a module across multiple modalities.

None
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a functools.partial function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
init_logit_scale float

The initial value of the logit scale parameter. This is the log of the scale factor applied to the logits before computing the contrastive loss.

1 / 0.07
max_logit_scale float

The maximum value of the logit scale parameter. The logit scale parameter is clamped to the range [0, log(max_logit_scale)].

100
learnable_logit_scale bool

Whether the logit scale parameter is learnable. If set to False, the logit scale parameter is treated as a constant.

True
loss Optional[Module]

The loss function to use.

None
modality_loss_pairs Optional[list[LossPairSpec]]

A list of pairs of modalities to compute the contrastive loss between and the weight to apply to each pair.

None
auxiliary_tasks dict[str, AuxiliaryTaskSpec]

Auxiliary tasks to run alongside the main contrastive pretraining task.

  • The auxiliary task module is expected to be a partially-initialized instance of a lightning.pytorch.core.LightningModule created using functools.partial, such that an initialized encoder can be passed as the only argument.
  • The modality parameter specifies the modality of the encoder to use for the auxiliary task. The loss_weight parameter specifies the weight to apply to the auxiliary task loss.
None
log_auxiliary_tasks_loss bool

Whether to log the loss of auxiliary tasks to the main logger.

False
compute_validation_loss bool

Whether to compute the validation loss if a validation dataloader is provided. The loss function must be provided to compute the validation loss.

True
compute_test_loss bool

Whether to compute the test loss if a test dataloader is provided. The loss function must be provided to compute the test loss.

True
evaluation_tasks Optional[dict[str, EvaluationSpec]]

Evaluation tasks to run during validation, while training, and during testing.

None

Raises:

Type Description
ValueError
  • If the loss function is not provided and either the validation or test loss needs to be computed.
  • If the given modality is not supported.
  • If the encoder, head, or postprocessor is not mapped to a modality.
  • If an unsupported modality is found in the loss pair specification.
  • If an unsupported modality is found in the auxiliary tasks.
  • If the auxiliary task is not a partial function.
  • If the evaluation task is not an instance of mmlearn.tasks.hooks.EvaluationHooks.
Source code in mmlearn/tasks/contrastive_pretraining.py
@store(group="task", provider="mmlearn")
class ContrastivePretraining(TrainingTask):
    """Contrastive pretraining task.

    This class supports contrastive pretraining with ``N`` modalities of data. It
    allows the sharing of encoders, heads, and postprocessors across modalities.
    It also supports computing the contrastive loss between specified pairs of
    modalities, as well as training auxiliary tasks alongside the main contrastive
    pretraining task.

    Parameters
    ----------
    encoders : dict[str, torch.nn.Module]
        A dictionary of encoders. The keys can be any string, including the names of
        any supported modalities. If the keys are not supported modalities, the
        ``modality_module_mapping`` parameter must be provided to map the encoders to
        specific modalities. The encoders are expected to take a dictionary of input
        values and return a list-like object with the first element being the encoded
        values. This first element is passed on to the heads or postprocessors and
        the remaining elements are ignored.
    heads : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the encoder outputs, usually projection
        heads. If the keys do not correspond to the name of a supported modality,
        the ``modality_module_mapping`` parameter must be provided. If any of the values
        are dictionaries, they will be wrapped in a :py:class:`torch.nn.Sequential`
        module. All head modules are expected to take a single input tensor and
        return a single output tensor.
    postprocessors : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the head outputs. If the keys do not
        correspond to the name of a supported modality, the `modality_module_mapping`
        parameter must be provided. If any of the values are dictionaries, they will
        be wrapped in a `nn.Sequential` module. All postprocessor modules are expected
        to take a single input tensor and return a single output tensor.
    modality_module_mapping : Optional[dict[str, ModuleKeySpec]], optional, default=None
        A dictionary mapping modalities to encoders, heads, and postprocessors.
        Useful for reusing the same instance of a module across multiple modalities.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    init_logit_scale : float, optional, default=1 / 0.07
        The initial value of the logit scale parameter. This is the log of the scale
        factor applied to the logits before computing the contrastive loss.
    max_logit_scale : float, optional, default=100
        The maximum value of the logit scale parameter. The logit scale parameter
        is clamped to the range ``[0, log(max_logit_scale)]``.
    learnable_logit_scale : bool, optional, default=True
        Whether the logit scale parameter is learnable. If set to False, the logit
        scale parameter is treated as a constant.
    loss : Optional[torch.nn.Module], optional, default=None
        The loss function to use.
    modality_loss_pairs : Optional[list[LossPairSpec]], optional, default=None
        A list of pairs of modalities to compute the contrastive loss between and
        the weight to apply to each pair.
    auxiliary_tasks : dict[str, AuxiliaryTaskSpec], optional, default=None
        Auxiliary tasks to run alongside the main contrastive pretraining task.

        - The auxiliary task module is expected to be a partially-initialized instance
          of a :py:class:`~lightning.pytorch.core.LightningModule` created using
          :py:func:`functools.partial`, such that an initialized encoder can be
          passed as the only argument.
        - The ``modality`` parameter specifies the modality of the encoder to use
          for the auxiliary task. The ``loss_weight`` parameter specifies the weight
          to apply to the auxiliary task loss.
    log_auxiliary_tasks_loss : bool, optional, default=False
        Whether to log the loss of auxiliary tasks to the main logger.
    compute_validation_loss : bool, optional, default=True
        Whether to compute the validation loss if a validation dataloader is provided.
        The loss function must be provided to compute the validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute the test loss if a test dataloader is provided. The loss
        function must be provided to compute the test loss.
    evaluation_tasks : Optional[dict[str, EvaluationSpec]], optional, default=None
        Evaluation tasks to run during validation, while training, and during testing.

    Raises
    ------
    ValueError

        - If the loss function is not provided and either the validation or test loss
          needs to be computed.
        - If the given modality is not supported.
        - If the encoder, head, or postprocessor is not mapped to a modality.
        - If an unsupported modality is found in the loss pair specification.
        - If an unsupported modality is found in the auxiliary tasks.
        - If the auxiliary task is not a partial function.
        - If the evaluation task is not an instance of :py:class:`~mmlearn.tasks.hooks.EvaluationHooks`.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912, PLR0915
        self,
        encoders: dict[str, nn.Module],
        heads: Optional[dict[str, Union[nn.Module, dict[str, nn.Module]]]] = None,
        postprocessors: Optional[
            dict[str, Union[nn.Module, dict[str, nn.Module]]]
        ] = None,
        modality_module_mapping: Optional[dict[str, ModuleKeySpec]] = None,
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable_logit_scale: bool = True,
        loss: Optional[nn.Module] = None,
        modality_loss_pairs: Optional[list[LossPairSpec]] = None,
        auxiliary_tasks: Optional[dict[str, AuxiliaryTaskSpec]] = None,
        log_auxiliary_tasks_loss: bool = False,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
        evaluation_tasks: Optional[dict[str, EvaluationSpec]] = None,
    ) -> None:
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )

        self.save_hyperparameters(
            ignore=[
                "encoders",
                "heads",
                "postprocessors",
                "modality_module_mapping",
                "loss",
                "auxiliary_tasks",
                "evaluation_tasks",
                "modality_loss_pairs",
            ]
        )

        if modality_module_mapping is None:
            # assume all the module dictionaries use the same keys corresponding
            # to modalities
            modality_module_mapping = {}
            for key in encoders:
                modality_module_mapping[key] = ModuleKeySpec(
                    encoder_key=key,
                    head_key=key,
                    postprocessor_key=key,
                )

        # match modalities to encoders, heads, and postprocessors
        modality_encoder_mapping: dict[str, Optional[str]] = {}
        modality_head_mapping: dict[str, Optional[str]] = {}
        modality_postprocessor_mapping: dict[str, Optional[str]] = {}
        for modality_key, module_mapping in modality_module_mapping.items():
            if not Modalities.has_modality(modality_key):
                raise ValueError(_unsupported_modality_error.format(modality_key))
            modality_encoder_mapping[modality_key] = module_mapping.encoder_key
            modality_head_mapping[modality_key] = module_mapping.head_key
            modality_postprocessor_mapping[modality_key] = (
                module_mapping.postprocessor_key
            )

        # ensure all modules are mapped to a modality
        for key in encoders:
            if key not in modality_encoder_mapping.values():
                if not Modalities.has_modality(key):
                    raise ValueError(_unsupported_modality_error.format(key))
                modality_encoder_mapping[key] = key

        if heads is not None:
            for key in heads:
                if key not in modality_head_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_head_mapping[key] = key

        if postprocessors is not None:
            for key in postprocessors:
                if key not in modality_postprocessor_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_postprocessor_mapping[key] = key

        self._available_modalities: list[Modality] = [
            Modalities.get_modality(modality_key)
            for modality_key in modality_encoder_mapping
        ]
        assert len(self._available_modalities) >= 2, (
            "Expected at least two modalities to be available. "
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the encoder modules.
        self.encoders = nn.ModuleDict(
            {
                Modalities.get_modality(modality_key).name: encoders[encoder_key]
                for modality_key, encoder_key in modality_encoder_mapping.items()
                if encoder_key is not None
            }
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the projection head modules. This can be
        #: ``None`` if no heads modules are provided.
        self.heads = None
        if heads is not None:
            self.heads = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: heads[head_key]
                    if isinstance(heads[head_key], nn.Module)
                    else nn.Sequential(*heads[head_key].values())
                    for modality_key, head_key in modality_head_mapping.items()
                    if head_key is not None and head_key in heads
                }
            )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the postprocessor modules. This can be
        #: ``None`` if no postprocessor modules are provided.
        self.postprocessors = None
        if postprocessors is not None:
            self.postprocessors = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: postprocessors[
                        postprocessor_key
                    ]
                    if isinstance(postprocessors[postprocessor_key], nn.Module)
                    else nn.Sequential(*postprocessors[postprocessor_key].values())
                    for modality_key, postprocessor_key in modality_postprocessor_mapping.items()
                    if postprocessor_key is not None
                    and postprocessor_key in postprocessors
                }
            )

        # set up logit scaling
        log_logit_scale = torch.ones([]) * np.log(init_logit_scale)
        self.max_logit_scale = max_logit_scale
        self.learnable_logit_scale = learnable_logit_scale

        if self.learnable_logit_scale:
            self.log_logit_scale = torch.nn.Parameter(
                log_logit_scale, requires_grad=True
            )
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

        # set up contrastive loss pairs
        if modality_loss_pairs is None:
            modality_loss_pairs = [
                LossPairSpec(modalities=(m1.name, m2.name))
                for m1, m2 in itertools.combinations(self._available_modalities, 2)
            ]

        for modality_pair in modality_loss_pairs:
            if not all(
                Modalities.get_modality(modality) in self._available_modalities
                for modality in modality_pair.modalities
            ):
                raise ValueError(
                    "Found unspecified modality in the loss pair specification "
                    f"{modality_pair.modalities}. Available modalities are "
                    f"{self._available_modalities}."
                )

        #: A list of :py:class:`LossPairSpec` instances specifying the pairs of
        #: modalities to compute the contrastive loss between and the weight to
        #: apply to each pair.
        self.modality_loss_pairs = modality_loss_pairs

        # set up auxiliary tasks
        self.aux_task_specs = auxiliary_tasks or {}
        self.auxiliary_tasks: nn.ModuleDict[str, L.LightningModule] = nn.ModuleDict()
        for task_name, task_spec in self.aux_task_specs.items():
            if not Modalities.has_modality(task_spec.modality):
                raise ValueError(
                    f"Found unsupported modality `{task_spec.modality}` in the auxiliary tasks. "
                    f"Available modalities are {self._available_modalities}."
                )
            if not isinstance(task_spec.task, partial):
                raise TypeError(
                    f"Expected auxiliary task to be a partial function, but got {type(task_spec.task)}."
                )

            self.auxiliary_tasks[task_name] = task_spec.task(
                self.encoders[Modalities.get_modality(task_spec.modality).name]
            )

        self.log_auxiliary_tasks_loss = log_auxiliary_tasks_loss

        if evaluation_tasks is not None:
            for eval_task_spec in evaluation_tasks.values():
                if not isinstance(eval_task_spec.task, EvaluationHooks):
                    raise TypeError(
                        f"Expected {eval_task_spec.task} to be an instance of `EvaluationHooks` "
                        f"but got {type(eval_task_spec.task)}."
                    )

        #: A dictionary of evaluation tasks to run during validation, while training,
        #: or during testing.
        self.evaluation_tasks = evaluation_tasks

    def configure_model(self) -> None:
        """Configure the model."""
        if self.auxiliary_tasks:
            for task_name in self.auxiliary_tasks:
                self.auxiliary_tasks[task_name].configure_model()

    def encode(
        self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
    ) -> torch.Tensor:
        """Encode the input values for the given modality.

        Parameters
        ----------
        inputs : dict[str, Any]
            Input values.
        modality : Modality
            The modality to encode.
        normalize : bool, optional, default=False
            Whether to apply L2 normalization to the output (after the head and
            postprocessor layers, if present).

        Returns
        -------
        torch.Tensor
            The encoded values for the specified modality.
        """
        output = self.encoders[modality.name](inputs)[0]

        if self.postprocessors and modality.name in self.postprocessors:
            output = self.postprocessors[modality.name](output)

        if self.heads and modality.name in self.heads:
            output = self.heads[modality.name](output)

        if normalize:
            output = torch.nn.functional.normalize(output, p=2, dim=-1)

        return output

    def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input tensors to encode.

        Returns
        -------
        dict[str, torch.Tensor]
            The encodings for each modality.
        """
        outputs = {
            modality.embedding: self.encode(inputs, modality, normalize=True)
            for modality in self._available_modalities
            if modality.name in inputs
        }

        if not all(
            output.size(-1) == list(outputs.values())[0].size(-1)
            for output in outputs.values()
        ):
            raise ValueError("Expected all model outputs to have the same dimension.")

        return outputs

    def on_train_epoch_start(self) -> None:
        """Prepare for the training epoch.

        This method sets the modules to training mode.
        """
        self.encoders.train()
        if self.heads:
            self.heads.train()
        if self.postprocessors:
            self.postprocessors.train()

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Compute the loss for the batch.

        Parameters
        ----------
        batch : dict[str, Any]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        torch.Tensor
            The loss for the batch.
        """
        outputs = self(batch)

        with torch.no_grad():
            self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

        loss = self._compute_loss(batch, batch_idx, outputs)

        if loss is None:
            raise ValueError("The loss function must be provided for training.")

        self.log("train/loss", loss, prog_bar=True, sync_dist=True)
        self.log(
            "train/logit_scale",
            self.log_logit_scale.exp(),
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )

        return loss

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Zero out the gradients of the model."""
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_before_zero_grad(optimizer)

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch.

        This method sets the modules to evaluation mode and calls the
        ``on_evaluation_epoch_start`` method of each evaluation task.
        """
        self._on_eval_epoch_start("val")

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "val")

    def on_validation_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def test_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "test")

    def on_test_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the model checkpoint after loading.

        The `on_load_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint after loading.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The loaded checkpoint.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_load_checkpoint(checkpoint)

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the checkpoint before saving.

        The `on_save_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint before saving.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The checkpoint to save.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_save_checkpoint(checkpoint)

    def _compute_loss(
        self, batch: dict[str, Any], batch_idx: int, outputs: dict[str, torch.Tensor]
    ) -> Optional[torch.Tensor]:
        if self.loss_fn is None:
            return None

        contrastive_loss = self.loss_fn(
            outputs,
            batch["example_ids"],
            self.log_logit_scale.exp(),
            self.modality_loss_pairs,
        )

        auxiliary_losses: list[torch.Tensor] = []
        if self.auxiliary_tasks:
            for task_name, task_spec in self.aux_task_specs.items():
                auxiliary_task_output = self.auxiliary_tasks[task_name].training_step(
                    batch, batch_idx
                )
                if isinstance(auxiliary_task_output, torch.Tensor):
                    auxiliary_task_loss = auxiliary_task_output
                elif isinstance(auxiliary_task_output, Mapping):
                    auxiliary_task_loss = auxiliary_task_output["loss"]
                else:
                    raise ValueError(
                        "Expected auxiliary task output to be a tensor or a mapping "
                        f"containing a 'loss' key, but got {type(auxiliary_task_output)}."
                    )

                auxiliary_task_loss *= task_spec.loss_weight
                auxiliary_losses.append(auxiliary_task_loss)
                if self.log_auxiliary_tasks_loss:
                    self.log(
                        f"train/{task_name}_loss", auxiliary_task_loss, sync_dist=True
                    )

        if not auxiliary_losses:
            return contrastive_loss

        return torch.stack(auxiliary_losses).sum() + contrastive_loss

    def _on_eval_epoch_start(self, eval_type: Literal["val", "test"]) -> None:
        """Prepare for the evaluation epoch."""
        self.encoders.eval()
        if self.heads:
            self.heads.eval()
        if self.postprocessors:
            self.postprocessors.eval()
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_start(self)

    def _shared_eval_step(
        self,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
        eval_type: Literal["val", "test"],
    ) -> Optional[torch.Tensor]:
        """Run a single evaluation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        loss: Optional[torch.Tensor] = None
        if (eval_type == "val" and self.compute_validation_loss) or (
            eval_type == "test" and self.compute_test_loss
        ):
            outputs = self(batch)
            loss = self._compute_loss(batch, batch_idx, outputs)
            if loss is not None and not self.trainer.sanity_checking:
                self.log(f"{eval_type}/loss", loss, prog_bar=True, sync_dist=True)

        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.evaluation_step(self, batch, batch_idx)

        return loss

    def _on_eval_epoch_end(self, eval_type: Literal["val", "test"]) -> None:
        """Compute and log epoch-level metrics at the end of the evaluation epoch."""
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_end(self)
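
To illustrate how the pieces above fit together, here is a rough configuration sketch with two toy encoders. The encoder class, the input key "features", and the modality names "text" and "rgb" are assumptions for illustration; a real setup would also pass a contrastive loss module (required whenever validation or test loss is computed):

from functools import partial

import torch
from torch import nn

from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining

class ToyEncoder(nn.Module):
    """Takes a dict of inputs and returns a list whose first element is the encoding."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, inputs: dict) -> list:
        return [self.proj(inputs["features"])]  # "features" is a made-up input key

task = ContrastivePretraining(
    encoders={"text": ToyEncoder(128, 256), "rgb": ToyEncoder(512, 256)},
    heads={"text": nn.Linear(256, 64), "rgb": nn.Linear(256, 64)},
    optimizer=partial(torch.optim.AdamW, lr=1e-4),
    loss=None,  # substitute the library's contrastive loss for real training
    compute_validation_loss=False,  # disabled here only because no loss is given
    compute_test_loss=False,
)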
configure_model
configure_model()

Configure the model.

Source code in mmlearn/tasks/contrastive_pretraining.py
def configure_model(self) -> None:
    """Configure the model."""
    if self.auxiliary_tasks:
        for task_name in self.auxiliary_tasks:
            self.auxiliary_tasks[task_name].configure_model()
encode
encode(inputs, modality, normalize=False)

Encode the input values for the given modality.

Parameters:

Name Type Description Default
inputs dict[str, Any]

Input values.

required
modality Modality

The modality to encode.

required
normalize bool

Whether to apply L2 normalization to the output (after the head and postprocessor layers, if present).

False

Returns:

Type Description
Tensor

The encoded values for the specified modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def encode(
    self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
) -> torch.Tensor:
    """Encode the input values for the given modality.

    Parameters
    ----------
    inputs : dict[str, Any]
        Input values.
    modality : Modality
        The modality to encode.
    normalize : bool, optional, default=False
        Whether to apply L2 normalization to the output (after the head and
        postprocessor layers, if present).

    Returns
    -------
    torch.Tensor
        The encoded values for the specified modality.
    """
    output = self.encoders[modality.name](inputs)[0]

    if self.postprocessors and modality.name in self.postprocessors:
        output = self.postprocessors[modality.name](output)

    if self.heads and modality.name in self.heads:
        output = self.heads[modality.name](output)

    if normalize:
        output = torch.nn.functional.normalize(output, p=2, dim=-1)

    return output
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input tensors to encode.

required

Returns:

Type Description
dict[str, Tensor]

The encodings for each modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input tensors to encode.

    Returns
    -------
    dict[str, torch.Tensor]
        The encodings for each modality.
    """
    outputs = {
        modality.embedding: self.encode(inputs, modality, normalize=True)
        for modality in self._available_modalities
        if modality.name in inputs
    }

    if not all(
        output.size(-1) == list(outputs.values())[0].size(-1)
        for output in outputs.values()
    ):
        raise ValueError("Expected all model outputs to have the same dimension.")

    return outputs
on_train_epoch_start
on_train_epoch_start()

Prepare for the training epoch.

This method sets the modules to training mode.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_train_epoch_start(self) -> None:
    """Prepare for the training epoch.

    This method sets the modules to training mode.
    """
    self.encoders.train()
    if self.heads:
        self.heads.train()
    if self.postprocessors:
        self.postprocessors.train()
training_step
training_step(batch, batch_idx)

Compute the loss for the batch.

Parameters:

Name Type Description Default
batch dict[str, Any]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Tensor

The loss for the batch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Compute the loss for the batch.

    Parameters
    ----------
    batch : dict[str, Any]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    torch.Tensor
        The loss for the batch.
    """
    outputs = self(batch)

    with torch.no_grad():
        self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

    loss = self._compute_loss(batch, batch_idx, outputs)

    if loss is None:
        raise ValueError("The loss function must be provided for training.")

    self.log("train/loss", loss, prog_bar=True, sync_dist=True)
    self.log(
        "train/logit_scale",
        self.log_logit_scale.exp(),
        prog_bar=True,
        on_step=True,
        on_epoch=False,
    )

    return loss
on_before_zero_grad
on_before_zero_grad(optimizer)

Zero out the gradients of the model.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Zero out the gradients of the model."""
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_before_zero_grad(optimizer)
on_validation_epoch_start
on_validation_epoch_start()

Prepare for the validation epoch.

This method sets the modules to evaluation mode and calls the on_evaluation_epoch_start method of each evaluation task.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch.

    This method sets the modules to evaluation mode and calls the
    ``on_evaluation_epoch_start`` method of each evaluation task.
    """
    self._on_eval_epoch_start("val")
validation_step
validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def validation_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "val")
on_validation_epoch_end
on_validation_epoch_end()

Compute and log epoch-level metrics at the end of the validation epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the validation epoch."""
    self._on_eval_epoch_end("val")
on_test_epoch_start
on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")
test_step
test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def test_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "test")
on_test_epoch_end
on_test_epoch_end()

Compute and log epoch-level metrics at the end of the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the test epoch."""
    self._on_eval_epoch_end("test")
on_load_checkpoint
on_load_checkpoint(checkpoint)

Modify the model checkpoint after loading.

The on_load_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint after loading.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The loaded checkpoint.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the model checkpoint after loading.

    The `on_load_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint after loading.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The loaded checkpoint.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_load_checkpoint(checkpoint)
on_save_checkpoint
on_save_checkpoint(checkpoint)

Modify the checkpoint before saving.

The on_save_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint before saving.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The checkpoint to save.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the checkpoint before saving.

    The `on_save_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint before saving.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The checkpoint to save.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_save_checkpoint(checkpoint)
IJEPA

Bases: TrainingTask

Pretraining module for IJEPA.

This class implements the IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task using PyTorch Lightning. It trains an encoder and a predictor to reconstruct masked regions of an image based on its unmasked context.

Parameters:

Name Type Description Default
encoder VisionTransformer

Vision transformer encoder.

required
predictor VisionTransformerPredictor

Vision transformer predictor.

required
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a functools.partial function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
ema_decay float

Initial momentum for EMA of target encoder.

0.996
ema_decay_end float

Final momentum for EMA of target encoder.

1.0
ema_anneal_end_step int

Number of steps to anneal EMA momentum to ema_decay_end.

1000
loss_fn Optional[Callable[[Tensor, Tensor], Tensor]]

Loss function to use. If not provided, defaults to torch.nn.functional.smooth_l1_loss.

None
compute_validation_loss bool

Whether to compute validation loss.

True
compute_test_loss bool

Whether to compute test loss.

True
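
The three EMA parameters define how quickly the target encoder tracks the trained encoder: the momentum starts at ema_decay and is annealed towards ema_decay_end over ema_anneal_end_step updates. The exact schedule lives in ExponentialMovingAverage, which is not shown here; the sketch below assumes a linear anneal, a common choice for this kind of target-encoder EMA:

def ema_momentum(step: int, decay_start: float = 0.996, decay_end: float = 1.0, anneal_end_step: int = 1000) -> float:
    """Momentum for the target-encoder update at a given step (assumed linear anneal)."""
    if step >= anneal_end_step:
        return decay_end
    return decay_start + (step / anneal_end_step) * (decay_end - decay_start)

# target_param = m * target_param + (1 - m) * encoder_param, with m = ema_momentum(step)
print(ema_momentum(0), ema_momentum(500), ema_momentum(2000))  # 0.996 0.998 1.0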
Source code in mmlearn/tasks/ijepa.py
@store(group="task", provider="mmlearn", zen_partial=False)
class IJEPA(TrainingTask):
    """Pretraining module for IJEPA.

    This class implements the IJEPA (Image Joint-Embedding Predictive Architecture)
    pretraining task using PyTorch Lightning. It trains an encoder and a predictor to
    reconstruct masked regions of an image based on its unmasked context.

    Parameters
    ----------
    encoder : VisionTransformer
        Vision transformer encoder.
    predictor : VisionTransformerPredictor
        Vision transformer predictor.
    modality : str, optional, default="RGB"
        The name of the modality for the input images, used to select the image
        tensor from the batch.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    ema_decay : float, optional, default=0.996
        Initial momentum for EMA of target encoder.
    ema_decay_end : float, optional, default=1.0
        Final momentum for EMA of target encoder.
    ema_anneal_end_step : int, optional, default=1000
        Number of steps to anneal EMA momentum to ``ema_decay_end``.
    loss_fn : Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional
        Loss function to use. If not provided, defaults to
        :py:func:`~torch.nn.functional.smooth_l1_loss`.
    compute_validation_loss : bool, optional, default=True
        Whether to compute validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute test loss.
    """  # noqa: W505

    def __init__(
        self,
        encoder: VisionTransformer,
        predictor: VisionTransformerPredictor,
        modality: str = "RGB",
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        ema_decay: float = 0.996,
        ema_decay_end: float = 1.0,
        ema_anneal_end_step: int = 1000,
        loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
    ):
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss_fn if loss_fn is not None else F.smooth_l1_loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )
        self.modality = Modalities.get_modality(modality)
        self.mask_generator = IJEPAMaskGenerator()

        self.encoder = encoder
        self.predictor = predictor

        self.predictor.num_patches = encoder.patch_embed.num_patches
        self.predictor.embed_dim = encoder.embed_dim
        self.predictor.num_heads = encoder.num_heads

        self.target_encoder = ExponentialMovingAverage(
            self.encoder, ema_decay, ema_decay_end, ema_anneal_end_step
        )

    def configure_model(self) -> None:
        """Configure the model."""
        self.target_encoder.configure_model(self.device)

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Perform exponential moving average update of target encoder.

        This is done right after ``optimizer.step()``, which comes just before
        ``optimizer.zero_grad()`` to account for gradient accumulation.
        """
        if self.target_encoder is not None:
            self.target_encoder.step(self.encoder)

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Perform a single training step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        torch.Tensor
            Loss value.
        """
        return self._shared_step(batch, batch_idx, step_type="train")

    def validation_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed.
        """
        return self._shared_step(batch, batch_idx, step_type="val")

    def test_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed.
        """
        return self._shared_step(batch, batch_idx, step_type="test")

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch."""
        self._on_eval_epoch_start("val")

    def on_validation_epoch_end(self) -> None:
        """Actions at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def on_test_epoch_end(self) -> None:
        """Actions at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Add relevant EMA state to the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to save the EMA state to.
        """
        if self.target_encoder is not None:
            checkpoint["ema_params"] = {
                "decay": self.target_encoder.decay,
                "num_updates": self.target_encoder.num_updates,
            }

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Restore EMA state from the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to restore the EMA state from.
        """
        if "ema_params" in checkpoint and self.target_encoder is not None:
            ema_params = checkpoint.pop("ema_params")
            self.target_encoder.decay = ema_params["decay"]
            self.target_encoder.num_updates = ema_params["num_updates"]

            self.target_encoder.restore(self.encoder)

    def _shared_step(
        self, batch: dict[str, Any], batch_idx: int, step_type: str
    ) -> Optional[torch.Tensor]:
        images = batch[self.modality.name]

        # Generate masks
        batch_size = images.size(0)
        mask_info = self.mask_generator(batch_size=batch_size)

        # Extract masks and move to device
        device = images.device
        encoder_masks = [mask.to(device) for mask in mask_info["encoder_masks"]]
        predictor_masks = [mask.to(device) for mask in mask_info["predictor_masks"]]

        # Forward pass through target encoder to get h
        with torch.no_grad():
            h = self.target_encoder.model(batch)[0]
            h = F.layer_norm(h, h.size()[-1:])
            h_masked = apply_masks(h, predictor_masks)
            h_masked = repeat_interleave_batch(
                h_masked, images.size(0), repeat=len(encoder_masks)
            )

        # Forward pass through encoder with encoder_masks
        batch[self.modality.mask] = encoder_masks
        z = self.encoder(batch)[0]

        # Pass z through predictor with encoder_masks and predictor_masks
        z_pred = self.predictor(z, encoder_masks, predictor_masks)

        if step_type == "train":
            self.log("train/ema_decay", self.target_encoder.decay, prog_bar=True)

        if self.loss_fn is not None and (
            step_type == "train"
            or (step_type == "val" and self.compute_validation_loss)
            or (step_type == "test" and self.compute_test_loss)
        ):
            # Compute loss between z_pred and h_masked
            loss = self.loss_fn(z_pred, h_masked)

            # Log loss
            self.log(f"{step_type}/loss", loss, prog_bar=True, sync_dist=True)

            return loss

        return None

    def _on_eval_epoch_start(self, step_type: str) -> None:
        """Initialize states or configurations at the start of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (
            step_type == "val"
            and self.compute_validation_loss
            or step_type == "test"
            and self.compute_test_loss
        ):
            self.log(f"{step_type}/start", 1, prog_bar=True, sync_dist=True)

    def _on_eval_epoch_end(self, step_type: str) -> None:
        """Finalize states or logging at the end of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (
            step_type == "val"
            and self.compute_validation_loss
            or step_type == "test"
            and self.compute_test_loss
        ):
            self.log(f"{step_type}/end", 1, prog_bar=True, sync_dist=True)
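
The following is a minimal, hypothetical sketch of constructing this task. The import paths, the task class name and the encoder/predictor constructor arguments are assumptions for illustration and may differ in your installed version; what is grounded in the code above is that the encoder must expose patch_embed.num_patches, embed_dim and num_heads, and that optimizer and lr_scheduler follow the formats described in the parameter list.

from functools import partial

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# NOTE: import paths, the task class name and the constructor arguments below
# are assumptions for illustration only.
from mmlearn.modules.encoders.vision import VisionTransformer, VisionTransformerPredictor
from mmlearn.tasks.ijepa import IJEPA

encoder = VisionTransformer(
    img_size=[224], patch_size=16, embed_dim=768, depth=12, num_heads=12
)
predictor = VisionTransformerPredictor(
    embed_dim=768, predictor_embed_dim=384, depth=6, num_heads=12
)

task = IJEPA(
    encoder=encoder,
    predictor=predictor,
    modality="RGB",
    # A partial optimizer: the task calls it with the model parameters.
    optimizer=partial(torch.optim.AdamW, lr=1e-3, weight_decay=0.05),
    # Dictionary form of `lr_scheduler`: a partial scheduler under 'scheduler'
    # plus additional scheduler arguments under 'extras' (per the docstring).
    lr_scheduler={
        "scheduler": partial(CosineAnnealingLR, T_max=10_000),
        "extras": {"eta_min": 1e-6},
    },
    ema_decay=0.996,
    ema_decay_end=1.0,
    ema_anneal_end_step=1000,
)
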
configure_model
configure_model()

Configure the model.

Source code in mmlearn/tasks/ijepa.py
def configure_model(self) -> None:
    """Configure the model."""
    self.target_encoder.configure_model(self.device)
on_before_zero_grad
on_before_zero_grad(optimizer)

Perform exponential moving average update of target encoder.

This is done right after the optimizer.step() call, which comes just before optimizer.zero_grad(), to account for gradient accumulation.

Source code in mmlearn/tasks/ijepa.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Perform exponential moving average update of target encoder.

    This is done right after the ``optimizer.step()``, which comes just before
    ``optimizer.zero_grad()`` to account for gradient accumulation.
    """
    if self.target_encoder is not None:
        self.target_encoder.step(self.encoder)
training_step
training_step(batch, batch_idx)

Perform a single training step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Tensor

Loss value.

Source code in mmlearn/tasks/ijepa.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Perform a single training step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    torch.Tensor
        Loss value.
    """
    return self._shared_step(batch, batch_idx, step_type="train")
validation_step
validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def validation_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="val")
test_step
test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def test_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="test")
on_validation_epoch_start
on_validation_epoch_start()

Prepare for the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch."""
    self._on_eval_epoch_start("val")
on_validation_epoch_end
on_validation_epoch_end()

Actions at the end of the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_end(self) -> None:
    """Actions at the end of the validation epoch."""
    self._on_eval_epoch_end("val")
on_test_epoch_start
on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")
on_test_epoch_end
on_test_epoch_end()

Actions at the end of the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_end(self) -> None:
    """Actions at the end of the test epoch."""
    self._on_eval_epoch_end("test")
on_save_checkpoint
on_save_checkpoint(checkpoint)

Add relevant EMA state to the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to save the EMA state to.

required
Source code in mmlearn/tasks/ijepa.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Add relevant EMA state to the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to save the EMA state to.
    """
    if self.target_encoder is not None:
        checkpoint["ema_params"] = {
            "decay": self.target_encoder.decay,
            "num_updates": self.target_encoder.num_updates,
        }
on_load_checkpoint
on_load_checkpoint(checkpoint)

Restore EMA state from the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to restore the EMA state from.

required
Source code in mmlearn/tasks/ijepa.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Restore EMA state from the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to restore the EMA state from.
    """
    if "ema_params" in checkpoint and self.target_encoder is not None:
        ema_params = checkpoint.pop("ema_params")
        self.target_encoder.decay = ema_params["decay"]
        self.target_encoder.num_updates = ema_params["num_updates"]

        self.target_encoder.restore(self.encoder)
ZeroShotClassification

Bases: EvaluationHooks

Zero-shot classification evaluation task.

This task evaluates the zero-shot classification performance.

Parameters:

Name Type Description Default
task_specs list[ClassificationTaskSpec]

A list of classification task specifications.

required
tokenizer Callable[[Union[str, list[str]]], Union[Tensor, dict[str, Tensor]]]

A function to tokenize text inputs.

required
Source code in mmlearn/tasks/zero_shot_classification.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotClassification(EvaluationHooks):
    """Zero-shot classification evaluation task.

    This task evaluates the zero-shot classification performance.

    Parameters
    ----------
    task_specs : list[ClassificationTaskSpec]
        A list of classification task specifications.
    tokenizer : Callable[[Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]]
        A function to tokenize text inputs.
    """  # noqa: W505

    def __init__(
        self,
        task_specs: list[ClassificationTaskSpec],
        tokenizer: Callable[
            [Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]
        ],
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.task_specs = task_specs
        for spec in self.task_specs:
            assert Modalities.has_modality(spec.query_modality)

        self.metrics: dict[tuple[str, int], MetricCollection] = {}
        self._embeddings_store: dict[int, torch.Tensor] = {}

    def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
        """Set up the evaluation task.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Raises
        ------
        ValueError
            - If the task is not being run for validation or testing.
            - If the dataset does not have the required attributes to perform zero-shot
              classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
        """
        if pl_module.trainer.validating:
            eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
        elif pl_module.trainer.testing:
            eval_dataset = pl_module.trainer.test_dataloaders.dataset
        else:
            raise ValueError(
                "ZeroShotClassification task is only supported for validation and testing."
            )

        self.all_dataset_info = {}

        # create metrics for each dataset/query_modality combination
        if not self.metrics:
            for dataset_index, dataset in enumerate(eval_dataset.datasets):
                dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
                try:
                    id2label: dict[int, str] = dataset.id2label
                except AttributeError:
                    raise ValueError(
                        f"Dataset '{dataset_name}' must have an `id2label` attribute "
                        "to perform zero-shot classification."
                    ) from None

                try:
                    zero_shot_prompt_templates: list[str] = (
                        dataset.zero_shot_prompt_templates
                    )
                except AttributeError:
                    raise ValueError(
                        "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                    ) from None

                num_classes = len(id2label)

                self.all_dataset_info[dataset_index] = {
                    "name": dataset_name,
                    "id2label": id2label,
                    "prompt_templates": zero_shot_prompt_templates,
                    "num_classes": num_classes,
                }

                for spec in self.task_specs:
                    query_modality = Modalities.get_modality(spec.query_modality).name
                    self.metrics[(query_modality, dataset_index)] = (
                        self._create_metrics(
                            num_classes,
                            spec.top_k,
                            prefix=f"{dataset_name}/{query_modality}_",
                            postfix="",
                        )
                    )

        for metric in self.metrics.values():
            metric.to(pl_module.device)

        for dataset_index, dataset_info in self.all_dataset_info.items():
            id2label = dataset_info["id2label"]
            prompt_templates: list[str] = dataset_info["prompt_templates"]
            labels = list(id2label.values())

            with torch.no_grad():
                chunk_size = 10
                all_embeddings = []

                for i in tqdm(
                    range(0, len(labels), chunk_size),
                    desc="Encoding class descriptions",
                ):
                    batch_labels = labels[i : min(i + chunk_size, len(labels))]
                    descriptions = [
                        template.format(label)
                        for label in batch_labels
                        for template in prompt_templates
                    ]
                    tokenized_descriptions = move_data_to_device(
                        self.tokenizer(descriptions),
                        pl_module.device,
                    )

                    # Encode the chunk using the pl_module's encode method
                    chunk_embeddings = pl_module.encode(
                        tokenized_descriptions, Modalities.TEXT
                    )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                    chunk_embeddings = chunk_embeddings.reshape(
                        len(batch_labels), len(prompt_templates), -1
                    ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                    # Append the chunk embeddings to the list
                    all_embeddings.append(chunk_embeddings)

                # Concatenate all chunk embeddings into a single tensor
                class_embeddings = torch.cat(all_embeddings, dim=0)

            self._embeddings_store[dataset_index] = class_embeddings

    def evaluation_step(
        self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> None:
        """Compute logits and update metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A batch of data.
        batch_idx : int
            The index of the batch.
        """
        if pl_module.trainer.sanity_checking:
            return

        for (query_modality, dataset_index), metric_collection in self.metrics.items():
            matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

            if not matching_indices.numel():
                continue

            class_embeddings = self._embeddings_store[dataset_index]
            query_embeddings: torch.Tensor = pl_module.encode(
                batch, Modalities.get_modality(query_modality)
            )
            query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
            query_embeddings = query_embeddings[matching_indices]

            if self.all_dataset_info[dataset_index]["num_classes"] == 2:
                softmax_output = _safe_matmul(
                    query_embeddings, class_embeddings
                ).softmax(dim=-1)
                logits = softmax_output[:, 1] - softmax_output[:, 0]
            else:
                logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
            targets = batch[Modalities.get_modality(query_modality).target][
                matching_indices
            ]

            metric_collection.update(logits, targets)

    def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
        """Compute and reset metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        dict[str, Any]
            The computed metrics.
        """
        results = {}
        for metric_collection in self.metrics.values():
            results.update(metric_collection.compute())
            metric_collection.reset()

        self._embeddings_store.clear()

        eval_type = "val" if pl_module.trainer.validating else "test"
        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results

    @staticmethod
    def _create_metrics(
        num_classes: int, top_k: list[int], prefix: str, postfix: str
    ) -> MetricCollection:
        """Create a collection of classification metrics."""
        task_type = "binary" if num_classes == 2 else "multiclass"
        acc_metrics = (
            {
                f"top{k}_accuracy": Accuracy(
                    task=task_type, num_classes=num_classes, top_k=k, average="micro"
                )
                for k in top_k
            }
            if num_classes > 2
            else {"accuracy": Accuracy(task=task_type, num_classes=num_classes)}
        )
        return MetricCollection(
            {
                "precision": Precision(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "recall": Recall(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "f1_score_macro": F1Score(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "aucroc": AUROC(task=task_type, num_classes=num_classes),
                **acc_metrics,
            },
            prefix=prefix,
            postfix=postfix,
            compute_groups=True,
        )
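
A minimal sketch of configuring this evaluation task. It assumes ClassificationTaskSpec is importable from the same module, that the image modality is registered under the name "rgb", and that a Hugging Face tokenizer is available; the checkpoint name is a placeholder. Note that, per on_evaluation_epoch_start above, any dataset evaluated with this task must expose id2label and zero_shot_prompt_templates.

from functools import partial

from transformers import AutoTokenizer

# Assumption: `ClassificationTaskSpec` lives alongside the task; the modality
# name "rgb" and the tokenizer checkpoint are placeholders.
from mmlearn.tasks.zero_shot_classification import (
    ClassificationTaskSpec,
    ZeroShotClassification,
)

hf_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = partial(hf_tokenizer, padding=True, truncation=True, return_tensors="pt")

eval_task = ZeroShotClassification(
    task_specs=[ClassificationTaskSpec(query_modality="rgb", top_k=[1, 5])],
    tokenizer=tokenizer,
)
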
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Set up the evaluation task.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Raises:

Type Description
ValueError
  • If the task is not being run for validation or testing.
  • If the dataset does not have the required attributes to perform zero-shot classification (i.e., id2label and zero_shot_prompt_templates).
Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
    """Set up the evaluation task.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Raises
    ------
    ValueError
        - If the task is not being run for validation or testing.
        - If the dataset does not have the required attributes to perform zero-shot
          classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
    """
    if pl_module.trainer.validating:
        eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
    elif pl_module.trainer.testing:
        eval_dataset = pl_module.trainer.test_dataloaders.dataset
    else:
        raise ValueError(
            "ZeroShotClassification task is only supported for validation and testing."
        )

    self.all_dataset_info = {}

    # create metrics for each dataset/query_modality combination
    if not self.metrics:
        for dataset_index, dataset in enumerate(eval_dataset.datasets):
            dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
            try:
                id2label: dict[int, str] = dataset.id2label
            except AttributeError:
                raise ValueError(
                    f"Dataset '{dataset_name}' must have an `id2label` attribute "
                    "to perform zero-shot classification."
                ) from None

            try:
                zero_shot_prompt_templates: list[str] = (
                    dataset.zero_shot_prompt_templates
                )
            except AttributeError:
                raise ValueError(
                    "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                ) from None

            num_classes = len(id2label)

            self.all_dataset_info[dataset_index] = {
                "name": dataset_name,
                "id2label": id2label,
                "prompt_templates": zero_shot_prompt_templates,
                "num_classes": num_classes,
            }

            for spec in self.task_specs:
                query_modality = Modalities.get_modality(spec.query_modality).name
                self.metrics[(query_modality, dataset_index)] = (
                    self._create_metrics(
                        num_classes,
                        spec.top_k,
                        prefix=f"{dataset_name}/{query_modality}_",
                        postfix="",
                    )
                )

    for metric in self.metrics.values():
        metric.to(pl_module.device)

    for dataset_index, dataset_info in self.all_dataset_info.items():
        id2label = dataset_info["id2label"]
        prompt_templates: list[str] = dataset_info["prompt_templates"]
        labels = list(id2label.values())

        with torch.no_grad():
            chunk_size = 10
            all_embeddings = []

            for i in tqdm(
                range(0, len(labels), chunk_size),
                desc="Encoding class descriptions",
            ):
                batch_labels = labels[i : min(i + chunk_size, len(labels))]
                descriptions = [
                    template.format(label)
                    for label in batch_labels
                    for template in prompt_templates
                ]
                tokenized_descriptions = move_data_to_device(
                    self.tokenizer(descriptions),
                    pl_module.device,
                )

                # Encode the chunk using the pl_module's encode method
                chunk_embeddings = pl_module.encode(
                    tokenized_descriptions, Modalities.TEXT
                )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                chunk_embeddings = chunk_embeddings.reshape(
                    len(batch_labels), len(prompt_templates), -1
                ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                # Append the chunk embeddings to the list
                all_embeddings.append(chunk_embeddings)

            # Concatenate all chunk embeddings into a single tensor
            class_embeddings = torch.cat(all_embeddings, dim=0)

        self._embeddings_store[dataset_index] = class_embeddings
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Compute logits and update metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A batch of data.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_classification.py
def evaluation_step(
    self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
) -> None:
    """Compute logits and update metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A batch of data.
    batch_idx : int
        The index of the batch.
    """
    if pl_module.trainer.sanity_checking:
        return

    for (query_modality, dataset_index), metric_collection in self.metrics.items():
        matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

        if not matching_indices.numel():
            continue

        class_embeddings = self._embeddings_store[dataset_index]
        query_embeddings: torch.Tensor = pl_module.encode(
            batch, Modalities.get_modality(query_modality)
        )
        query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
        query_embeddings = query_embeddings[matching_indices]

        if self.all_dataset_info[dataset_index]["num_classes"] == 2:
            softmax_output = _safe_matmul(
                query_embeddings, class_embeddings
            ).softmax(dim=-1)
            logits = softmax_output[:, 1] - softmax_output[:, 0]
        else:
            logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
        targets = batch[Modalities.get_modality(query_modality).target][
            matching_indices
        ]

        metric_collection.update(logits, targets)
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Compute and reset metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
dict[str, Any]

The computed metrics.

Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
    """Compute and reset metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    dict[str, Any]
        The computed metrics.
    """
    results = {}
    for metric_collection in self.metrics.values():
        results.update(metric_collection.compute())
        metric_collection.reset()

    self._embeddings_store.clear()

    eval_type = "val" if pl_module.trainer.validating else "test"
    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results
ZeroShotCrossModalRetrieval

Bases: EvaluationHooks

Zero-shot cross-modal retrieval evaluation task.

This task evaluates the retrieval performance of a model on a set of query-target pairs. The model is expected to produce embeddings for both the query and target modalities. The task computes the retrieval recall at k for each pair of modalities.

Parameters:

Name Type Description Default
task_specs list[RetrievalTaskSpec]

A list of retrieval task specifications. Each specification defines the query and target modalities, as well as the top-k values for which to compute the retrieval recall metrics.

required
Source code in mmlearn/tasks/zero_shot_retrieval.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotCrossModalRetrieval(EvaluationHooks):
    """Zero-shot cross-modal retrieval evaluation task.

    This task evaluates the retrieval performance of a model on a set of query-target
    pairs. The model is expected to produce embeddings for both the query and target
    modalities. The task computes the retrieval recall at `k` for each pair of
    modalities.

    Parameters
    ----------
    task_specs : list[RetrievalTaskSpec]
        A list of retrieval task specifications. Each specification defines the query
        and target modalities, as well as the top-k values for which to compute the
        retrieval recall metrics.

    """

    def __init__(self, task_specs: list[RetrievalTaskSpec]) -> None:
        super().__init__()

        self.task_specs = task_specs
        self.metrics: dict[tuple[str, str], MetricCollection] = {}
        self._available_modalities = set()

        for spec in self.task_specs:
            query_modality = spec.query_modality
            target_modality = spec.target_modality
            assert Modalities.has_modality(query_modality)
            assert Modalities.has_modality(target_modality)

            self.metrics[(query_modality, target_modality)] = MetricCollection(
                {
                    f"{query_modality}_to_{target_modality}_R@{k}": RetrievalRecallAtK(
                        top_k=k, aggregation="mean", reduction="none"
                    )
                    for k in spec.top_k
                }
            )
            self._available_modalities.add(query_modality)
            self._available_modalities.add(target_modality)

    def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
        """Move the metrics to the device of the Lightning module."""
        for metric in self.metrics.values():
            metric.to(pl_module.device)

    def evaluation_step(
        self,
        pl_module: pl.LightningModule,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
    ) -> None:
        """Run the forward pass and update retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A dictionary of batched input tensors.
        batch_idx : int
            The index of the batch.

        """
        if pl_module.trainer.sanity_checking:
            return

        outputs: dict[str, Any] = {}
        for modality_name in self._available_modalities:
            if modality_name in batch:
                outputs[modality_name] = pl_module.encode(
                    batch, Modalities.get_modality(modality_name), normalize=False
                )
        for (query_modality, target_modality), metric in self.metrics.items():
            if query_modality not in outputs or target_modality not in outputs:
                continue
            query_embeddings: torch.Tensor = outputs[query_modality]
            target_embeddings: torch.Tensor = outputs[target_modality]
            indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

            metric.update(query_embeddings, target_embeddings, indexes)

    def on_evaluation_epoch_end(
        self, pl_module: pl.LightningModule
    ) -> Optional[dict[str, Any]]:
        """Compute the retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        Optional[dict[str, Any]]
            A dictionary of evaluation results or `None` if no results are available.
        """
        if pl_module.trainer.sanity_checking:
            return None

        results = {}
        for metric in self.metrics.values():
            results.update(metric.compute())
            metric.reset()

        eval_type = "val" if pl_module.trainer.validating else "test"

        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results
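
A minimal sketch of configuring the retrieval task in both directions between a text and an image modality. The assumption that RetrievalTaskSpec is importable from the same module, and the modality names "text" and "rgb", are illustrative placeholders.

# Assumption: `RetrievalTaskSpec` lives alongside the task; "text" and "rgb"
# are placeholders for modality names registered in your setup.
from mmlearn.tasks.zero_shot_retrieval import (
    RetrievalTaskSpec,
    ZeroShotCrossModalRetrieval,
)

retrieval_task = ZeroShotCrossModalRetrieval(
    task_specs=[
        RetrievalTaskSpec(query_modality="text", target_modality="rgb", top_k=[1, 5, 10]),
        RetrievalTaskSpec(query_modality="rgb", target_modality="text", top_k=[1, 5, 10]),
    ]
)
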
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Move the metrics to the device of the Lightning module.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
    """Move the metrics to the device of the Lightning module."""
    for metric in self.metrics.values():
        metric.to(pl_module.device)
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Run the forward pass and update retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A dictionary of batched input tensors.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_retrieval.py
def evaluation_step(
    self,
    pl_module: pl.LightningModule,
    batch: dict[str, torch.Tensor],
    batch_idx: int,
) -> None:
    """Run the forward pass and update retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A dictionary of batched input tensors.
    batch_idx : int
        The index of the batch.

    """
    if pl_module.trainer.sanity_checking:
        return

    outputs: dict[str, Any] = {}
    for modality_name in self._available_modalities:
        if modality_name in batch:
            outputs[modality_name] = pl_module.encode(
                batch, Modalities.get_modality(modality_name), normalize=False
            )
    for (query_modality, target_modality), metric in self.metrics.items():
        if query_modality not in outputs or target_modality not in outputs:
            continue
        query_embeddings: torch.Tensor = outputs[query_modality]
        target_embeddings: torch.Tensor = outputs[target_modality]
        indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

        metric.update(query_embeddings, target_embeddings, indexes)
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Compute the retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
Optional[dict[str, Any]]

A dictionary of evaluation results or None if no results are available.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_end(
    self, pl_module: pl.LightningModule
) -> Optional[dict[str, Any]]:
    """Compute the retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    Optional[dict[str, Any]]
        A dictionary of evaluation results or `None` if no results are available.
    """
    if pl_module.trainer.sanity_checking:
        return None

    results = {}
    for metric in self.metrics.values():
        results.update(metric.compute())
        metric.reset()

    eval_type = "val" if pl_module.trainer.validating else "test"

    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results
find_matching_indices
find_matching_indices(
    first_example_ids, second_example_ids
)

Find the indices of matching examples given two tensors of example ids.

Matching examples are defined as examples with the same value in both tensors. This method is useful for finding pairs of examples from different modalities that are related to each other in a batch.

Parameters:

Name Type Description Default
first_example_ids Tensor

A tensor of example ids of shape (N, 2), where N is the number of examples.

required
second_example_ids Tensor

A tensor of example ids of shape (M, 2), where M is the number of examples.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of tensors containing the indices of matching examples in the first and second tensor, respectively.

Raises:

Type Description
TypeError

If either first_example_ids or second_example_ids is not a tensor.

ValueError

If either first_example_ids or second_example_ids is not a 2D tensor with the second dimension having a size of 2.

Examples:

>>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
>>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
>>> find_matching_indices(img_example_ids, text_example_ids)
(tensor([2, 3]), tensor([0, 1]))
Source code in mmlearn/datasets/core/example.py
def find_matching_indices(
    first_example_ids: torch.Tensor, second_example_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find the indices of matching examples given two tensors of example ids.

    Matching examples are defined as examples with the same value in both tensors.
    This method is useful for finding pairs of examples from different modalities
    that are related to each other in a batch.

    Parameters
    ----------
    first_example_ids : torch.Tensor
        A tensor of example ids of shape `(N, 2)`, where `N` is the number of examples.
    second_example_ids : torch.Tensor
        A tensor of example ids of shape `(M, 2)`, where `M` is the number of examples.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple of tensors containing the indices of matching examples in the first and
        second tensor, respectively.

    Raises
    ------
    TypeError
        If either `first_example_ids` or `second_example_ids` is not a tensor.
    ValueError
        If either `first_example_ids` or `second_example_ids` is not a 2D tensor
        with the second dimension having a size of `2`.

    Examples
    --------
    >>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
    >>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
    >>> find_matching_indices(img_example_ids, text_example_ids)
    (tensor([2, 3]), tensor([0, 1]))


    """
    if not isinstance(first_example_ids, torch.Tensor) or not isinstance(
        second_example_ids,
        torch.Tensor,
    ):
        raise TypeError(
            f"Expected inputs to be tensors, but got {type(first_example_ids)} "
            f"and {type(second_example_ids)}.",
        )
    val = 2
    if not (first_example_ids.ndim == val and first_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `first_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {first_example_ids.shape}.",
        )
    if not (second_example_ids.ndim == val and second_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `second_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {second_example_ids.shape}.",
        )

    first_example_ids = first_example_ids.unsqueeze(1)  # shape=(N, 1, 2)
    second_example_ids = second_example_ids.unsqueeze(0)  # shape=(1, M, 2)

    # compare all elements; results in a shape (N, M) tensor
    matches = torch.all(first_example_ids == second_example_ids, dim=-1)
    first_indices, second_indices = torch.where(matches)
    return first_indices, second_indices
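
Building on the doctest above, the sketch below shows how the returned index pairs are typically used to align embeddings of matching examples across two modalities. The embedding tensors are synthetic placeholders.

import torch

from mmlearn.datasets.core.example import find_matching_indices

# Synthetic embeddings for the example ids used in the doctest above.
img_embeddings = torch.randn(4, 8)
text_embeddings = torch.randn(5, 8)

img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])

img_idx, text_idx = find_matching_indices(img_example_ids, text_example_ids)

# Rows at the same position now refer to the same underlying example, so a
# pairwise similarity or contrastive loss can be computed on the aligned views.
aligned_img = img_embeddings[img_idx]      # shape: (2, 8)
aligned_text = text_embeddings[text_idx]   # shape: (2, 8)
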
linear_warmup_cosine_annealing_lr
linear_warmup_cosine_annealing_lr(
    optimizer,
    warmup_steps,
    max_steps,
    start_factor=1 / 3,
    eta_min=0.0,
    last_epoch=-1,
)

Create a linear warmup cosine annealing learning rate scheduler.

Parameters:

Name Type Description Default
optimizer Optimizer

The optimizer for which to schedule the learning rate.

required
warmup_steps int

Maximum number of iterations for linear warmup.

required
max_steps int

Maximum number of iterations.

required
start_factor float

Multiplicative factor for the learning rate at the start of the warmup phase.

1/3
eta_min float

Minimum learning rate.

0
last_epoch int

The index of the last epoch. If set to -1, the learning rate is initialized to the base learning rate.

-1

Returns:

Type Description
LRScheduler

The learning rate scheduler.

Raises:

Type Description
ValueError

If warmup_steps is greater than or equal to max_steps or if warmup_steps is less than or equal to 0.

Source code in mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py
@store(  # type: ignore[misc]
    group="modules/lr_schedulers",
    provider="mmlearn",
    zen_partial=True,
    warmup_steps=MISSING,
    max_steps=MISSING,
)
def linear_warmup_cosine_annealing_lr(
    optimizer: Optimizer,
    warmup_steps: int,
    max_steps: int,
    start_factor: float = 1 / 3,
    eta_min: float = 0.0,
    last_epoch: int = -1,
) -> LRScheduler:
    """Create a linear warmup cosine annealing learning rate scheduler.

    Parameters
    ----------
    optimizer : Optimizer
        The optimizer for which to schedule the learning rate.
    warmup_steps : int
        Maximum number of iterations for linear warmup.
    max_steps : int
        Maximum number of iterations.
    start_factor : float, optional, default=1/3
        Multiplicative factor for the learning rate at the start of the warmup phase.
    eta_min : float, optional, default=0
        Minimum learning rate.
    last_epoch : int, optional, default=-1
        The index of the last epoch. If set to ``-1``, the learning rate is
        initialized to the base learning rate.

    Returns
    -------
    LRScheduler
        The learning rate scheduler.

    Raises
    ------
    ValueError
        If `warmup_steps` is greater than or equal to `max_steps` or if `warmup_steps`
        is less than or equal to 0.
    """
    if warmup_steps >= max_steps:
        raise ValueError(
            "Expected `warmup_steps` to be less than `max_steps` but got "
            f"`warmup_steps={warmup_steps}` and `max_steps={max_steps}`."
        )
    if warmup_steps <= 0:
        raise ValueError(
            "Expected `warmup_steps` to be positive but got "
            f"`warmup_steps={warmup_steps}`."
        )

    linear_lr = LinearLR(
        optimizer,
        start_factor=start_factor,
        total_iters=warmup_steps,
        last_epoch=last_epoch,
    )
    cosine_lr = CosineAnnealingLR(
        optimizer,
        T_max=max_steps - warmup_steps,
        eta_min=eta_min,
        last_epoch=last_epoch,
    )
    return SequentialLR(
        optimizer,
        schedulers=[linear_lr, cosine_lr],
        milestones=[warmup_steps],
        last_epoch=last_epoch,
    )
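
A minimal usage sketch with a toy model; the hyperparameter values are placeholders. Note the usual ordering: the scheduler is stepped after the optimizer.

import torch

from mmlearn.modules.lr_schedulers.linear_warmup_cosine_lr import (
    linear_warmup_cosine_annealing_lr,
)

# Toy model and optimizer; all values below are placeholders.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

scheduler = linear_warmup_cosine_annealing_lr(
    optimizer,
    warmup_steps=10,   # linear ramp from lr * start_factor (1/3) up to lr
    max_steps=100,     # cosine decay over the remaining 90 steps
    eta_min=1e-6,
)

for _ in range(100):
    optimizer.step()   # gradients omitted in this sketch
    scheduler.step()
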
main
main(cfg)

Entry point for training or evaluation.

Source code in mmlearn/cli/run.py
@_hydra_main(
    config_path="pkg://mmlearn.conf", config_name="base_config", version_base=None
)
def main(cfg: MMLearnConf) -> None:  # noqa: PLR0912
    """Entry point for training or evaluation."""
    cfg_copy = copy.deepcopy(cfg)  # copy of the config for logging

    L.seed_everything(cfg.seed, workers=True)

    if is_torch_tf32_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        if "16-mixed" in str(cfg.trainer.precision):
            cfg.trainer.precision = "bf16-mixed"

    # setup trainer first so that we can get some variables for distributed training
    callbacks = instantiate_callbacks(cfg.trainer.get("callbacks"))
    cfg.trainer["callbacks"] = None  # will be replaced with the instantiated object
    loggers = instantiate_loggers(cfg.trainer.get("logger"))
    cfg.trainer["logger"] = None
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=loggers, _convert_="all"
    )
    assert isinstance(trainer, Trainer), (
        "Trainer must be an instance of `lightning.pytorch.trainer.Trainer`"
    )

    if rank_zero_only.rank == 0 and loggers is not None:  # update wandb config
        for trainer_logger in loggers:
            if isinstance(trainer_logger, WandbLogger):
                trainer_logger.experiment.config.update(
                    OmegaConf.to_container(cfg_copy, resolve=True, enum_to_str=True),
                    allow_val_change=True,
                )
    trainer.print(OmegaConf.to_yaml(cfg_copy, resolve=True))

    requires_distributed_sampler = (
        trainer.distributed_sampler_kwargs is not None
        and trainer._accelerator_connector.use_distributed_sampler
    )
    if requires_distributed_sampler:  # we handle distributed samplers
        trainer._accelerator_connector.use_distributed_sampler = False

    # prepare dataloaders
    if cfg.job_type == JobType.train:
        train_dataset = instantiate_datasets(cfg.datasets.train)
        assert train_dataset is not None, (
            "Train dataset (`cfg.datasets.train`) is required for training."
        )

        train_sampler = instantiate_sampler(
            cfg.dataloader.train.get("sampler"),
            train_dataset,
            requires_distributed_sampler=requires_distributed_sampler,
            distributed_sampler_kwargs=trainer.distributed_sampler_kwargs,
        )
        cfg.dataloader.train["sampler"] = None  # replaced with the instantiated object
        train_loader: DataLoader = hydra.utils.instantiate(
            cfg.dataloader.train, dataset=train_dataset, sampler=train_sampler
        )

        val_loader: Optional[DataLoader] = None
        val_dataset = instantiate_datasets(cfg.datasets.val)
        if val_dataset is not None:
            val_sampler = instantiate_sampler(
                cfg.dataloader.val.get("sampler"),
                val_dataset,
                requires_distributed_sampler=requires_distributed_sampler,
                distributed_sampler_kwargs=trainer.distributed_sampler_kwargs,
            )
            cfg.dataloader.val["sampler"] = None
            val_loader = hydra.utils.instantiate(
                cfg.dataloader.val, dataset=val_dataset, sampler=val_sampler
            )
    else:
        test_dataset = instantiate_datasets(cfg.datasets.test)
        assert test_dataset is not None, (
            "Test dataset (`cfg.datasets.test`) is required for evaluation."
        )

        test_sampler = instantiate_sampler(
            cfg.dataloader.test.get("sampler"),
            test_dataset,
            requires_distributed_sampler=requires_distributed_sampler,
            distributed_sampler_kwargs=trainer.distributed_sampler_kwargs,
        )
        cfg.dataloader.test["sampler"] = None
        test_loader = hydra.utils.instantiate(
            cfg.dataloader.test, dataset=test_dataset, sampler=test_sampler
        )

    # setup task module
    if cfg.task is None or "_target_" not in cfg.task:
        raise ValueError(
            "Expected a non-empty config for `cfg.task` with a `_target_` key. "
            f"But got: {cfg.task}"
        )
    logger.info(f"Instantiating task module: {cfg.task['_target_']}")
    model: L.LightningModule = hydra.utils.instantiate(cfg.task, _convert_="partial")
    assert isinstance(model, L.LightningModule), "Task must be a `LightningModule`"
    model.strict_loading = cfg.strict_loading

    # compile model
    model = torch.compile(model, **OmegaConf.to_object(cfg.torch_compile_kwargs))

    if cfg.job_type == JobType.train:
        trainer.fit(
            model, train_loader, val_loader, ckpt_path=cfg.resume_from_checkpoint
        )
    elif cfg.job_type == JobType.eval:
        trainer.test(model, test_loader, ckpt_path=cfg.resume_from_checkpoint)

conf

Hydra/Hydra-zen-based configurations.

JobType

Bases: str, Enum

Type of the job.

Source code in mmlearn/conf/__init__.py
class JobType(str, Enum):
    """Type of the job."""

    train = "train"
    eval = "eval"

DatasetConf dataclass

Configuration template for the datasets.

Source code in mmlearn/conf/__init__.py
@dataclass
class DatasetConf:
    """Configuration template for the datasets."""

    #: Configuration for the training dataset.
    train: Optional[Any] = field(
        default=None,
        metadata={"help": "Configuration for the training dataset."},
    )
    #: Configuration for the validation dataset.
    val: Optional[Any] = field(
        default=None, metadata={"help": "Configuration for the validation dataset."}
    )
    #: Configuration for the test dataset.
    test: Optional[Any] = field(
        default=None,
        metadata={"help": "Configuration for the test dataset."},
    )

DataLoaderConf dataclass

Configuration for the dataloader.

Source code in mmlearn/conf/__init__.py
@dataclass
class DataLoaderConf:
    """Configuration for the dataloader."""

    #: Configuration for the training dataloader.
    train: Any = field(
        default_factory=_DataLoaderConf,
        metadata={"help": "Configuration for the training dataloader."},
    )
    #: Configuration for the validation dataloader.
    val: Any = field(
        default_factory=_DataLoaderConf,
        metadata={"help": "Configuration for the validation dataloader."},
    )
    #: Configuration for the test dataloader.
    test: Any = field(
        default_factory=_DataLoaderConf,
        metadata={"help": "Configuration for the test dataloader."},
    )

MMLearnConf dataclass

Top-level configuration for mmlearn experiments.

Source code in mmlearn/conf/__init__.py
@dataclass
class MMLearnConf:
    """Top-level configuration for mmlearn experiments."""

    defaults: list[Any] = field(
        default_factory=lambda: [
            "_self_",  # See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/default_composition_order for more information
            {"task": MISSING},
            {"override hydra/launcher": "submitit_slurm"},
        ]
    )
    #: Name of the experiment. This must be specified for any experiment to run.
    experiment_name: str = field(default=MISSING)
    #: Type of the job.
    job_type: JobType = field(default=JobType.train)
    #: Seed for the random number generators. This is set for Python, Numpy and PyTorch,
    #: including the workers in PyTorch Dataloaders.
    seed: Optional[int] = field(default=None)
    #: Configuration for the datasets.
    datasets: DatasetConf = field(default_factory=DatasetConf)
    #: Configuration for the dataloaders.
    dataloader: DataLoaderConf = field(default_factory=DataLoaderConf)
    #: Configuration for the task. This is required to run any experiment.
    task: Any = field(default=MISSING)
    #: Configuration for the trainer. The options here are the same as in
    #: :py:class:`~lightning.pytorch.trainer.trainer.Trainer`
    trainer: Any = field(
        default_factory=builds(
            lightning_trainer.Trainer,
            populate_full_signature=True,
            enable_model_summary=True,
            enable_progress_bar=True,
            enable_checkpointing=True,
            default_root_dir=_get_default_ckpt_dir(),
        )
    )
    #: Tags for the experiment. This is useful for `wandb <https://docs.wandb.ai/ref/python/init>`_
    #: logging.
    tags: Optional[list[str]] = field(default_factory=lambda: [II("experiment_name")])
    #: Path to the checkpoint to resume training from.
    resume_from_checkpoint: Optional[Path] = field(default=None)
    #: Whether to strictly enforce loading of model weights i.e. `strict=True` in
    #: :py:meth:`~lightning.pytorch.core.module.LightningModule.load_from_checkpoint`.
    strict_loading: bool = field(default=True)
    #: Configuration for torch.compile. These are essentially the same as the
    #: arguments for :py:func:`torch.compile`.
    torch_compile_kwargs: dict[str, Any] = field(
        default_factory=lambda: {
            "disable": True,
            "fullgraph": False,
            "dynamic": None,
            "backend": "inductor",
            "mode": None,
            "options": None,
        }
    )
    #: Hydra configuration.
    hydra: HydraConf = field(
        default_factory=lambda: HydraConf(
            searchpath=["pkg://mmlearn.conf"],
            run=RunDir(
                dir=SI("./outputs/${experiment_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}")
            ),
            sweep=SweepDir(
                dir=SI("./outputs/${experiment_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}"),
                subdir=SI("${hydra.job.num}_${hydra.job.id}"),
            ),
            help=HelpConf(
                app_name="mmlearn",
                header="mmlearn: A modular framework for research on multimodal representation learning.",
            ),
            job=JobConf(
                name=II("experiment_name"),
                env_set={
                    "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
                    "HYDRA_FULL_ERROR": "1",
                },
            ),
        )
    )

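The sketch below builds the top-level config programmatically to inspect a few of the defaults shown above. The experiment name and the task `_target_` are placeholders supplied only because both fields are required (MISSING).

from mmlearn.conf import JobType, MMLearnConf

# `experiment_name` and `task` are required fields; the values below are
# placeholders for illustration, not a real experiment or task class.
cfg = MMLearnConf(
    experiment_name="demo",
    task={"_target_": "my_project.tasks.MyTask"},
)

assert cfg.job_type is JobType.train
assert cfg.strict_loading is True
assert cfg.torch_compile_kwargs["disable"] is True  # torch.compile is off by default
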
register_external_modules

register_external_modules(
    module,
    group,
    name=None,
    package=None,
    provider=None,
    base_cls=None,
    ignore_cls=None,
    ignore_prefix=None,
    **kwargs_for_builds
)

Add all classes in an external module to a ZenStore.

Parameters:

Name Type Description Default
module ModuleType

The module to add classes from.

required
group str

The config group to add the classes to.

required
name Optional[str]

The name to give to the dynamically-generated configs. If None, the class name is used.

None
package Optional[str]

The package to add the configs to.

None
provider Optional[str]

The provider to add the configs to.

None
base_cls Optional[type]

The base class to filter classes by. The base class is also excluded from the configs.

None
ignore_cls Optional[list[type]]

list of classes to ignore.

None
ignore_prefix Optional[str]

Ignore classes whose names start with this prefix.

None
kwargs_for_builds Any

Additional keyword arguments to pass to hydra_zen.builds.

{}
Source code in mmlearn/conf/__init__.py
def register_external_modules(
    module: ModuleType,
    group: str,
    name: Optional[str] = None,
    package: Optional[str] = None,
    provider: Optional[str] = None,
    base_cls: Optional[type] = None,
    ignore_cls: Optional[list[type]] = None,
    ignore_prefix: Optional[str] = None,
    **kwargs_for_builds: Any,
) -> None:
    """Add all classes in an external module to a ZenStore.

    Parameters
    ----------
    module : ModuleType
        The module to add classes from.
    group : str
        The config group to add the classes to.
    name : Optional[str], optional, default=None
        The name to give to the dynamically-generated configs. If `None`, the
        class name is used.
    package : Optional[str], optional, default=None
        The package to add the configs to.
    provider : Optional[str], optional, default=None
        The provider to add the configs to.
    base_cls : Optional[type], optional, default=None
        The base class to filter classes by. The base class is also excluded from
        the configs.
    ignore_cls : Optional[list[type]], optional, default=None
        list of classes to ignore.
    ignore_prefix : Optional[str], optional, default=None
        Ignore classes whose names start with this prefix.
    kwargs_for_builds : Any
        Additional keyword arguments to pass to ``hydra_zen.builds``.

    """
    for key, cls in module.__dict__.items():
        if (
            isinstance(cls, type)
            and (base_cls is None or issubclass(cls, base_cls))
            and cls != base_cls
            and (ignore_cls is None or cls not in ignore_cls)
            and (ignore_prefix is None or not key.startswith(ignore_prefix))
        ):
            external_store(
                builds(cls, populate_full_signature=True, **kwargs_for_builds),
                name=name or key,
                group=group,
                package=package,
                provider=provider,
            )
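
As an illustrative, hedged sketch of how this helper can be used (the group name, provider and the zen_partial flag below are assumptions, not taken from mmlearn's own registrations), all optimizer classes in torch.optim can be added to a config group in one call:

import torch

from mmlearn.conf import register_external_modules

# Hedged sketch: register every subclass of Optimizer found in `torch.optim`
# under an illustrative "modules/optimizers" config group. `base_cls` filters
# the module's namespace and excludes `Optimizer` itself; `zen_partial=True`
# is forwarded to `hydra_zen.builds` via `**kwargs_for_builds`.
register_external_modules(
    torch.optim,
    group="modules/optimizers",
    provider="torch",
    base_cls=torch.optim.Optimizer,
    zen_partial=True,
)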

constants

Constants.

datasets

Datasets.

CheXpert

Bases: Dataset[Example]

CheXpert dataset.

Each datapoint is a pair of (image, target label).

Parameters:

Name Type Description Default
root_dir str

Directory which contains .json files stating all dataset entries.

required
split (train, valid)

Dataset split.

"train"
labeler Optional[{chexpert, chexbert, vchexbert}]

Labeler used to extract labels from the training images. The "valid" split has no labeler; it was labeled by human radiologists.

None
transform Optional[Callable[[PIL.Image], torch.Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
Source code in mmlearn/datasets/chexpert.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("CHEXPERT_ROOT_DIR", MISSING),
    split="train",
)
class CheXpert(Dataset[Example]):
    """CheXpert dataset.

    Each datapoint is a pair of `(image, target label)`.

    Parameters
    ----------
    root_dir : str
        Directory which contains `.json` files stating all dataset entries.
    split : {"train", "valid"}
        Dataset split.
    labeler : Optional[{"chexpert", "chexbert", "vchexbert"}], optional, default=None
        Labeler used to extract labels from the training images. The "valid" split
        has no labeler; it was labeled by human radiologists.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "valid"],
        labeler: Optional[Literal["chexpert", "chexbert", "vchexbert"]] = None,
        transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
    ) -> None:
        assert split in ["train", "valid"], f"split {split} is not available."
        assert labeler in ["chexpert", "chexbert", "vchexbert"] or labeler is None, (
            f"labeler {labeler} is not available."
        )
        assert callable(transform) or transform is None, (
            "transform is not callable or None."
        )

        if split == "valid":
            data_file = f"{split}_data.json"
        elif split == "train":
            data_file = f"{labeler}_{split}_data.json"
        data_path = os.path.join(root_dir, data_file)

        assert os.path.isfile(data_path), f"entries file does not exist: {data_path}."

        with open(data_path, "rb") as file:
            entries = json.load(file)
        self.entries = entries

        if transform is not None:
            self.transform = transform
        else:
            self.transform = Compose([Resize(224), CenterCrop(224), ToTensor()])

    def __getitem__(self, idx: int) -> Example:
        """Return the idx'th datapoint."""
        entry = self.entries[idx]
        image = Image.open(entry["image_path"]).convert("RGB")
        image = self.transform(image)
        label = torch.tensor(entry["label"])

        return Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: label,
                "qid": entry["qid"],
                EXAMPLE_INDEX_KEY: idx,
            }
        )

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.entries)
__getitem__
__getitem__(idx)

Return the idx'th datapoint.

Source code in mmlearn/datasets/chexpert.py
def __getitem__(self, idx: int) -> Example:
    """Return the idx'th datapoint."""
    entry = self.entries[idx]
    image = Image.open(entry["image_path"]).convert("RGB")
    image = self.transform(image)
    label = torch.tensor(entry["label"])

    return Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: label,
            "qid": entry["qid"],
            EXAMPLE_INDEX_KEY: idx,
        }
    )
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/chexpert.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.entries)
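
A hedged usage sketch (the root directory below is a placeholder; it must already contain the expected .json entry files, e.g. chexpert_train_data.json for the chexpert labeler):

from mmlearn.datasets.chexpert import CheXpert

# Illustrative only: expects "<root_dir>/chexpert_train_data.json" on disk,
# with each entry holding an image path, a label vector and a "qid".
dataset = CheXpert(
    root_dir="/path/to/chexpert",
    split="train",
    labeler="chexpert",
)
example = dataset[0]  # Example with the transformed image, label, qid and index
print(len(dataset), list(example.keys()))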

ImageNet

Bases: ImageFolder

ImageNet dataset.

This is a wrapper around the torchvision.datasets.ImageFolder class that returns an mmlearn.datasets.core.example.Example object.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, val)

The split of the dataset to use.

"train"
transform Optional[Callable]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
target_transform Optional[Callable]

A callable that takes in the target and transforms it.

None
mask_generator Optional[Callable]

A callable that generates a mask for the image.

None
Source code in mmlearn/datasets/imagenet.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("IMAGENET_ROOT_DIR", MISSING),
)
class ImageNet(ImageFolder):
    """ImageNet dataset.

    This is a wrapper around the :py:class:`~torchvision.datasets.ImageFolder` class
    that returns an :py:class:`~mmlearn.datasets.core.example.Example` object.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "val"}, default="train"
        The split of the dataset to use.
    transform : Optional[Callable], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    target_transform : Optional[Callable], optional, default=None
        A callable that takes in the target and transforms it.
    mask_generator : Optional[Callable], optional, default=None
        A callable that generates a mask for the image.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "val"] = "train",
        transform: Optional[Callable[..., Any]] = None,
        target_transform: Optional[Callable[..., Any]] = None,
        mask_generator: Optional[Callable[..., Any]] = None,
    ) -> None:
        split = "train" if split == "train" else "val"
        root_dir = os.path.join(root_dir, split)
        super().__init__(
            root=root_dir, transform=transform, target_transform=target_transform
        )
        self.mask_generator = mask_generator

    def __getitem__(self, index: int) -> Example:
        """Get an example at the given index."""
        image, target = super().__getitem__(index)
        example = Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: target,
                EXAMPLE_INDEX_KEY: index,
            }
        )
        mask = self.mask_generator() if self.mask_generator else None
        if mask is not None:  # error will be raised during collation if `None`
            example[Modalities.RGB.mask] = mask
        return example

    @property
    def zero_shot_prompt_templates(self) -> list[str]:
        """Return the zero-shot prompt templates."""
        return [
            "a bad photo of a {}.",
            "a photo of many {}.",
            "a sculpture of a {}.",
            "a photo of the hard to see {}.",
            "a low resolution photo of the {}.",
            "a rendering of a {}.",
            "graffiti of a {}.",
            "a bad photo of the {}.",
            "a cropped photo of the {}.",
            "a tattoo of a {}.",
            "the embroidered {}.",
            "a photo of a hard to see {}.",
            "a bright photo of a {}.",
            "a photo of a clean {}.",
            "a photo of a dirty {}.",
            "a dark photo of the {}.",
            "a drawing of a {}.",
            "a photo of my {}.",
            "the plastic {}.",
            "a photo of the cool {}.",
            "a close-up photo of a {}.",
            "a black and white photo of the {}.",
            "a painting of the {}.",
            "a painting of a {}.",
            "a pixelated photo of the {}.",
            "a sculpture of the {}.",
            "a bright photo of the {}.",
            "a cropped photo of a {}.",
            "a plastic {}.",
            "a photo of the dirty {}.",
            "a jpeg corrupted photo of a {}.",
            "a blurry photo of the {}.",
            "a photo of the {}.",
            "a good photo of the {}.",
            "a rendering of the {}.",
            "a {} in a video game.",
            "a photo of one {}.",
            "a doodle of a {}.",
            "a close-up photo of the {}.",
            "a photo of a {}.",
            "the origami {}.",
            "the {} in a video game.",
            "a sketch of a {}.",
            "a doodle of the {}.",
            "a origami {}.",
            "a low resolution photo of a {}.",
            "the toy {}.",
            "a rendition of the {}.",
            "a photo of the clean {}.",
            "a photo of a large {}.",
            "a rendition of a {}.",
            "a photo of a nice {}.",
            "a photo of a weird {}.",
            "a blurry photo of a {}.",
            "a cartoon {}.",
            "art of a {}.",
            "a sketch of the {}.",
            "a embroidered {}.",
            "a pixelated photo of a {}.",
            "itap of the {}.",
            "a jpeg corrupted photo of the {}.",
            "a good photo of a {}.",
            "a plushie {}.",
            "a photo of the nice {}.",
            "a photo of the small {}.",
            "a photo of the weird {}.",
            "the cartoon {}.",
            "art of the {}.",
            "a drawing of the {}.",
            "a photo of the large {}.",
            "a black and white photo of a {}.",
            "the plushie {}.",
            "a dark photo of a {}.",
            "itap of a {}.",
            "graffiti of the {}.",
            "a toy {}.",
            "itap of my {}.",
            "a photo of a cool {}.",
            "a photo of a small {}.",
            "a tattoo of the {}.",
        ]

    @property
    def id2label(self) -> dict[int, str]:
        """Return the label mapping."""
        return {
            0: "tench",
            1: "goldfish",
            2: "great white shark",
            3: "tiger shark",
            4: "hammerhead shark",
            5: "electric ray",
            6: "stingray",
            7: "rooster",
            8: "hen",
            9: "ostrich",
            10: "brambling",
            11: "goldfinch",
            12: "house finch",
            13: "junco",
            14: "indigo bunting",
            15: "American robin",
            16: "bulbul",
            17: "jay",
            18: "magpie",
            19: "chickadee",
            20: "American dipper",
            21: "kite (bird of prey)",
            22: "bald eagle",
            23: "vulture",
            24: "great grey owl",
            25: "fire salamander",
            26: "smooth newt",
            27: "newt",
            28: "spotted salamander",
            29: "axolotl",
            30: "American bullfrog",
            31: "tree frog",
            32: "tailed frog",
            33: "loggerhead sea turtle",
            34: "leatherback sea turtle",
            35: "mud turtle",
            36: "terrapin",
            37: "box turtle",
            38: "banded gecko",
            39: "green iguana",
            40: "Carolina anole",
            41: "desert grassland whiptail lizard",
            42: "agama",
            43: "frilled-necked lizard",
            44: "alligator lizard",
            45: "Gila monster",
            46: "European green lizard",
            47: "chameleon",
            48: "Komodo dragon",
            49: "Nile crocodile",
            50: "American alligator",
            51: "triceratops",
            52: "worm snake",
            53: "ring-necked snake",
            54: "eastern hog-nosed snake",
            55: "smooth green snake",
            56: "kingsnake",
            57: "garter snake",
            58: "water snake",
            59: "vine snake",
            60: "night snake",
            61: "boa constrictor",
            62: "African rock python",
            63: "Indian cobra",
            64: "green mamba",
            65: "sea snake",
            66: "Saharan horned viper",
            67: "eastern diamondback rattlesnake",
            68: "sidewinder rattlesnake",
            69: "trilobite",
            70: "harvestman",
            71: "scorpion",
            72: "yellow garden spider",
            73: "barn spider",
            74: "European garden spider",
            75: "southern black widow",
            76: "tarantula",
            77: "wolf spider",
            78: "tick",
            79: "centipede",
            80: "black grouse",
            81: "ptarmigan",
            82: "ruffed grouse",
            83: "prairie grouse",
            84: "peafowl",
            85: "quail",
            86: "partridge",
            87: "african grey parrot",
            88: "macaw",
            89: "sulphur-crested cockatoo",
            90: "lorikeet",
            91: "coucal",
            92: "bee eater",
            93: "hornbill",
            94: "hummingbird",
            95: "jacamar",
            96: "toucan",
            97: "duck",
            98: "red-breasted merganser",
            99: "goose",
            100: "black swan",
            101: "tusker",
            102: "echidna",
            103: "platypus",
            104: "wallaby",
            105: "koala",
            106: "wombat",
            107: "jellyfish",
            108: "sea anemone",
            109: "brain coral",
            110: "flatworm",
            111: "nematode",
            112: "conch",
            113: "snail",
            114: "slug",
            115: "sea slug",
            116: "chiton",
            117: "chambered nautilus",
            118: "Dungeness crab",
            119: "rock crab",
            120: "fiddler crab",
            121: "red king crab",
            122: "American lobster",
            123: "spiny lobster",
            124: "crayfish",
            125: "hermit crab",
            126: "isopod",
            127: "white stork",
            128: "black stork",
            129: "spoonbill",
            130: "flamingo",
            131: "little blue heron",
            132: "great egret",
            133: "bittern bird",
            134: "crane bird",
            135: "limpkin",
            136: "common gallinule",
            137: "American coot",
            138: "bustard",
            139: "ruddy turnstone",
            140: "dunlin",
            141: "common redshank",
            142: "dowitcher",
            143: "oystercatcher",
            144: "pelican",
            145: "king penguin",
            146: "albatross",
            147: "grey whale",
            148: "killer whale",
            149: "dugong",
            150: "sea lion",
            151: "Chihuahua",
            152: "Japanese Chin",
            153: "Maltese",
            154: "Pekingese",
            155: "Shih Tzu",
            156: "King Charles Spaniel",
            157: "Papillon",
            158: "toy terrier",
            159: "Rhodesian Ridgeback",
            160: "Afghan Hound",
            161: "Basset Hound",
            162: "Beagle",
            163: "Bloodhound",
            164: "Bluetick Coonhound",
            165: "Black and Tan Coonhound",
            166: "Treeing Walker Coonhound",
            167: "English foxhound",
            168: "Redbone Coonhound",
            169: "borzoi",
            170: "Irish Wolfhound",
            171: "Italian Greyhound",
            172: "Whippet",
            173: "Ibizan Hound",
            174: "Norwegian Elkhound",
            175: "Otterhound",
            176: "Saluki",
            177: "Scottish Deerhound",
            178: "Weimaraner",
            179: "Staffordshire Bull Terrier",
            180: "American Staffordshire Terrier",
            181: "Bedlington Terrier",
            182: "Border Terrier",
            183: "Kerry Blue Terrier",
            184: "Irish Terrier",
            185: "Norfolk Terrier",
            186: "Norwich Terrier",
            187: "Yorkshire Terrier",
            188: "Wire Fox Terrier",
            189: "Lakeland Terrier",
            190: "Sealyham Terrier",
            191: "Airedale Terrier",
            192: "Cairn Terrier",
            193: "Australian Terrier",
            194: "Dandie Dinmont Terrier",
            195: "Boston Terrier",
            196: "Miniature Schnauzer",
            197: "Giant Schnauzer",
            198: "Standard Schnauzer",
            199: "Scottish Terrier",
            200: "Tibetan Terrier",
            201: "Australian Silky Terrier",
            202: "Soft-coated Wheaten Terrier",
            203: "West Highland White Terrier",
            204: "Lhasa Apso",
            205: "Flat-Coated Retriever",
            206: "Curly-coated Retriever",
            207: "Golden Retriever",
            208: "Labrador Retriever",
            209: "Chesapeake Bay Retriever",
            210: "German Shorthaired Pointer",
            211: "Vizsla",
            212: "English Setter",
            213: "Irish Setter",
            214: "Gordon Setter",
            215: "Brittany dog",
            216: "Clumber Spaniel",
            217: "English Springer Spaniel",
            218: "Welsh Springer Spaniel",
            219: "Cocker Spaniel",
            220: "Sussex Spaniel",
            221: "Irish Water Spaniel",
            222: "Kuvasz",
            223: "Schipperke",
            224: "Groenendael dog",
            225: "Malinois",
            226: "Briard",
            227: "Australian Kelpie",
            228: "Komondor",
            229: "Old English Sheepdog",
            230: "Shetland Sheepdog",
            231: "collie",
            232: "Border Collie",
            233: "Bouvier des Flandres dog",
            234: "Rottweiler",
            235: "German Shepherd Dog",
            236: "Dobermann",
            237: "Miniature Pinscher",
            238: "Greater Swiss Mountain Dog",
            239: "Bernese Mountain Dog",
            240: "Appenzeller Sennenhund",
            241: "Entlebucher Sennenhund",
            242: "Boxer",
            243: "Bullmastiff",
            244: "Tibetan Mastiff",
            245: "French Bulldog",
            246: "Great Dane",
            247: "St. Bernard",
            248: "husky",
            249: "Alaskan Malamute",
            250: "Siberian Husky",
            251: "Dalmatian",
            252: "Affenpinscher",
            253: "Basenji",
            254: "pug",
            255: "Leonberger",
            256: "Newfoundland dog",
            257: "Great Pyrenees dog",
            258: "Samoyed",
            259: "Pomeranian",
            260: "Chow Chow",
            261: "Keeshond",
            262: "brussels griffon",
            263: "Pembroke Welsh Corgi",
            264: "Cardigan Welsh Corgi",
            265: "Toy Poodle",
            266: "Miniature Poodle",
            267: "Standard Poodle",
            268: "Mexican hairless dog (xoloitzcuintli)",
            269: "grey wolf",
            270: "Alaskan tundra wolf",
            271: "red wolf or maned wolf",
            272: "coyote",
            273: "dingo",
            274: "dhole",
            275: "African wild dog",
            276: "hyena",
            277: "red fox",
            278: "kit fox",
            279: "Arctic fox",
            280: "grey fox",
            281: "tabby cat",
            282: "tiger cat",
            283: "Persian cat",
            284: "Siamese cat",
            285: "Egyptian Mau",
            286: "cougar",
            287: "lynx",
            288: "leopard",
            289: "snow leopard",
            290: "jaguar",
            291: "lion",
            292: "tiger",
            293: "cheetah",
            294: "brown bear",
            295: "American black bear",
            296: "polar bear",
            297: "sloth bear",
            298: "mongoose",
            299: "meerkat",
            300: "tiger beetle",
            301: "ladybug",
            302: "ground beetle",
            303: "longhorn beetle",
            304: "leaf beetle",
            305: "dung beetle",
            306: "rhinoceros beetle",
            307: "weevil",
            308: "fly",
            309: "bee",
            310: "ant",
            311: "grasshopper",
            312: "cricket insect",
            313: "stick insect",
            314: "cockroach",
            315: "praying mantis",
            316: "cicada",
            317: "leafhopper",
            318: "lacewing",
            319: "dragonfly",
            320: "damselfly",
            321: "red admiral butterfly",
            322: "ringlet butterfly",
            323: "monarch butterfly",
            324: "small white butterfly",
            325: "sulphur butterfly",
            326: "gossamer-winged butterfly",
            327: "starfish",
            328: "sea urchin",
            329: "sea cucumber",
            330: "cottontail rabbit",
            331: "hare",
            332: "Angora rabbit",
            333: "hamster",
            334: "porcupine",
            335: "fox squirrel",
            336: "marmot",
            337: "beaver",
            338: "guinea pig",
            339: "common sorrel horse",
            340: "zebra",
            341: "pig",
            342: "wild boar",
            343: "warthog",
            344: "hippopotamus",
            345: "ox",
            346: "water buffalo",
            347: "bison",
            348: "ram (adult male sheep)",
            349: "bighorn sheep",
            350: "Alpine ibex",
            351: "hartebeest",
            352: "impala (antelope)",
            353: "gazelle",
            354: "arabian camel",
            355: "llama",
            356: "weasel",
            357: "mink",
            358: "European polecat",
            359: "black-footed ferret",
            360: "otter",
            361: "skunk",
            362: "badger",
            363: "armadillo",
            364: "three-toed sloth",
            365: "orangutan",
            366: "gorilla",
            367: "chimpanzee",
            368: "gibbon",
            369: "siamang",
            370: "guenon",
            371: "patas monkey",
            372: "baboon",
            373: "macaque",
            374: "langur",
            375: "black-and-white colobus",
            376: "proboscis monkey",
            377: "marmoset",
            378: "white-headed capuchin",
            379: "howler monkey",
            380: "titi monkey",
            381: "Geoffroy's spider monkey",
            382: "common squirrel monkey",
            383: "ring-tailed lemur",
            384: "indri",
            385: "Asian elephant",
            386: "African bush elephant",
            387: "red panda",
            388: "giant panda",
            389: "snoek fish",
            390: "eel",
            391: "silver salmon",
            392: "rock beauty fish",
            393: "clownfish",
            394: "sturgeon",
            395: "gar fish",
            396: "lionfish",
            397: "pufferfish",
            398: "abacus",
            399: "abaya",
            400: "academic gown",
            401: "accordion",
            402: "acoustic guitar",
            403: "aircraft carrier",
            404: "airliner",
            405: "airship",
            406: "altar",
            407: "ambulance",
            408: "amphibious vehicle",
            409: "analog clock",
            410: "apiary",
            411: "apron",
            412: "trash can",
            413: "assault rifle",
            414: "backpack",
            415: "bakery",
            416: "balance beam",
            417: "balloon",
            418: "ballpoint pen",
            419: "Band-Aid",
            420: "banjo",
            421: "baluster / handrail",
            422: "barbell",
            423: "barber chair",
            424: "barbershop",
            425: "barn",
            426: "barometer",
            427: "barrel",
            428: "wheelbarrow",
            429: "baseball",
            430: "basketball",
            431: "bassinet",
            432: "bassoon",
            433: "swimming cap",
            434: "bath towel",
            435: "bathtub",
            436: "station wagon",
            437: "lighthouse",
            438: "beaker",
            439: "military hat (bearskin or shako)",
            440: "beer bottle",
            441: "beer glass",
            442: "bell tower",
            443: "baby bib",
            444: "tandem bicycle",
            445: "bikini",
            446: "ring binder",
            447: "binoculars",
            448: "birdhouse",
            449: "boathouse",
            450: "bobsleigh",
            451: "bolo tie",
            452: "poke bonnet",
            453: "bookcase",
            454: "bookstore",
            455: "bottle cap",
            456: "hunting bow",
            457: "bow tie",
            458: "brass memorial plaque",
            459: "bra",
            460: "breakwater",
            461: "breastplate",
            462: "broom",
            463: "bucket",
            464: "buckle",
            465: "bulletproof vest",
            466: "high-speed train",
            467: "butcher shop",
            468: "taxicab",
            469: "cauldron",
            470: "candle",
            471: "cannon",
            472: "canoe",
            473: "can opener",
            474: "cardigan",
            475: "car mirror",
            476: "carousel",
            477: "tool kit",
            478: "cardboard box / carton",
            479: "car wheel",
            480: "automated teller machine",
            481: "cassette",
            482: "cassette player",
            483: "castle",
            484: "catamaran",
            485: "CD player",
            486: "cello",
            487: "mobile phone",
            488: "chain",
            489: "chain-link fence",
            490: "chain mail",
            491: "chainsaw",
            492: "storage chest",
            493: "chiffonier",
            494: "bell or wind chime",
            495: "china cabinet",
            496: "Christmas stocking",
            497: "church",
            498: "movie theater",
            499: "cleaver",
            500: "cliff dwelling",
            501: "cloak",
            502: "clogs",
            503: "cocktail shaker",
            504: "coffee mug",
            505: "coffeemaker",
            506: "spiral or coil",
            507: "combination lock",
            508: "computer keyboard",
            509: "candy store",
            510: "container ship",
            511: "convertible",
            512: "corkscrew",
            513: "cornet",
            514: "cowboy boot",
            515: "cowboy hat",
            516: "cradle",
            517: "construction crane",
            518: "crash helmet",
            519: "crate",
            520: "infant bed",
            521: "Crock Pot",
            522: "croquet ball",
            523: "crutch",
            524: "cuirass",
            525: "dam",
            526: "desk",
            527: "desktop computer",
            528: "rotary dial telephone",
            529: "diaper",
            530: "digital clock",
            531: "digital watch",
            532: "dining table",
            533: "dishcloth",
            534: "dishwasher",
            535: "disc brake",
            536: "dock",
            537: "dog sled",
            538: "dome",
            539: "doormat",
            540: "drilling rig",
            541: "drum",
            542: "drumstick",
            543: "dumbbell",
            544: "Dutch oven",
            545: "electric fan",
            546: "electric guitar",
            547: "electric locomotive",
            548: "entertainment center",
            549: "envelope",
            550: "espresso machine",
            551: "face powder",
            552: "feather boa",
            553: "filing cabinet",
            554: "fireboat",
            555: "fire truck",
            556: "fire screen",
            557: "flagpole",
            558: "flute",
            559: "folding chair",
            560: "football helmet",
            561: "forklift",
            562: "fountain",
            563: "fountain pen",
            564: "four-poster bed",
            565: "freight car",
            566: "French horn",
            567: "frying pan",
            568: "fur coat",
            569: "garbage truck",
            570: "gas mask or respirator",
            571: "gas pump",
            572: "goblet",
            573: "go-kart",
            574: "golf ball",
            575: "golf cart",
            576: "gondola",
            577: "gong",
            578: "gown",
            579: "grand piano",
            580: "greenhouse",
            581: "radiator grille",
            582: "grocery store",
            583: "guillotine",
            584: "hair clip",
            585: "hair spray",
            586: "half-track",
            587: "hammer",
            588: "hamper",
            589: "hair dryer",
            590: "hand-held computer",
            591: "handkerchief",
            592: "hard disk drive",
            593: "harmonica",
            594: "harp",
            595: "combine harvester",
            596: "hatchet",
            597: "holster",
            598: "home theater",
            599: "honeycomb",
            600: "hook",
            601: "hoop skirt",
            602: "gymnastic horizontal bar",
            603: "horse-drawn vehicle",
            604: "hourglass",
            605: "iPod",
            606: "clothes iron",
            607: "carved pumpkin",
            608: "jeans",
            609: "jeep",
            610: "T-shirt",
            611: "jigsaw puzzle",
            612: "rickshaw",
            613: "joystick",
            614: "kimono",
            615: "knee pad",
            616: "knot",
            617: "lab coat",
            618: "ladle",
            619: "lampshade",
            620: "laptop computer",
            621: "lawn mower",
            622: "lens cap",
            623: "letter opener",
            624: "library",
            625: "lifeboat",
            626: "lighter",
            627: "limousine",
            628: "ocean liner",
            629: "lipstick",
            630: "slip-on shoe",
            631: "lotion",
            632: "music speaker",
            633: "loupe magnifying glass",
            634: "sawmill",
            635: "magnetic compass",
            636: "messenger bag",
            637: "mailbox",
            638: "tights",
            639: "one-piece bathing suit",
            640: "manhole cover",
            641: "maraca",
            642: "marimba",
            643: "mask",
            644: "matchstick",
            645: "maypole",
            646: "maze",
            647: "measuring cup",
            648: "medicine cabinet",
            649: "megalith",
            650: "microphone",
            651: "microwave oven",
            652: "military uniform",
            653: "milk can",
            654: "minibus",
            655: "miniskirt",
            656: "minivan",
            657: "missile",
            658: "mitten",
            659: "mixing bowl",
            660: "mobile home",
            661: "ford model t",
            662: "modem",
            663: "monastery",
            664: "monitor",
            665: "moped",
            666: "mortar and pestle",
            667: "graduation cap",
            668: "mosque",
            669: "mosquito net",
            670: "vespa",
            671: "mountain bike",
            672: "tent",
            673: "computer mouse",
            674: "mousetrap",
            675: "moving van",
            676: "muzzle",
            677: "metal nail",
            678: "neck brace",
            679: "necklace",
            680: "baby pacifier",
            681: "notebook computer",
            682: "obelisk",
            683: "oboe",
            684: "ocarina",
            685: "odometer",
            686: "oil filter",
            687: "pipe organ",
            688: "oscilloscope",
            689: "overskirt",
            690: "bullock cart",
            691: "oxygen mask",
            692: "product packet / packaging",
            693: "paddle",
            694: "paddle wheel",
            695: "padlock",
            696: "paintbrush",
            697: "pajamas",
            698: "palace",
            699: "pan flute",
            700: "paper towel",
            701: "parachute",
            702: "parallel bars",
            703: "park bench",
            704: "parking meter",
            705: "railroad car",
            706: "patio",
            707: "payphone",
            708: "pedestal",
            709: "pencil case",
            710: "pencil sharpener",
            711: "perfume",
            712: "Petri dish",
            713: "photocopier",
            714: "plectrum",
            715: "Pickelhaube",
            716: "picket fence",
            717: "pickup truck",
            718: "pier",
            719: "piggy bank",
            720: "pill bottle",
            721: "pillow",
            722: "ping-pong ball",
            723: "pinwheel",
            724: "pirate ship",
            725: "drink pitcher",
            726: "block plane",
            727: "planetarium",
            728: "plastic bag",
            729: "plate rack",
            730: "farm plow",
            731: "plunger",
            732: "Polaroid camera",
            733: "pole",
            734: "police van",
            735: "poncho",
            736: "pool table",
            737: "soda bottle",
            738: "plant pot",
            739: "potter's wheel",
            740: "power drill",
            741: "prayer rug",
            742: "printer",
            743: "prison",
            744: "missile",
            745: "projector",
            746: "hockey puck",
            747: "punching bag",
            748: "purse",
            749: "quill",
            750: "quilt",
            751: "race car",
            752: "racket",
            753: "radiator",
            754: "radio",
            755: "radio telescope",
            756: "rain barrel",
            757: "recreational vehicle",
            758: "fishing casting reel",
            759: "reflex camera",
            760: "refrigerator",
            761: "remote control",
            762: "restaurant",
            763: "revolver",
            764: "rifle",
            765: "rocking chair",
            766: "rotisserie",
            767: "eraser",
            768: "rugby ball",
            769: "ruler measuring stick",
            770: "sneaker",
            771: "safe",
            772: "safety pin",
            773: "salt shaker",
            774: "sandal",
            775: "sarong",
            776: "saxophone",
            777: "scabbard",
            778: "weighing scale",
            779: "school bus",
            780: "schooner",
            781: "scoreboard",
            782: "CRT monitor",
            783: "screw",
            784: "screwdriver",
            785: "seat belt",
            786: "sewing machine",
            787: "shield",
            788: "shoe store",
            789: "shoji screen / room divider",
            790: "shopping basket",
            791: "shopping cart",
            792: "shovel",
            793: "shower cap",
            794: "shower curtain",
            795: "ski",
            796: "balaclava ski mask",
            797: "sleeping bag",
            798: "slide rule",
            799: "sliding door",
            800: "slot machine",
            801: "snorkel",
            802: "snowmobile",
            803: "snowplow",
            804: "soap dispenser",
            805: "soccer ball",
            806: "sock",
            807: "solar thermal collector",
            808: "sombrero",
            809: "soup bowl",
            810: "keyboard space bar",
            811: "space heater",
            812: "space shuttle",
            813: "spatula",
            814: "motorboat",
            815: "spider web",
            816: "spindle",
            817: "sports car",
            818: "spotlight",
            819: "stage",
            820: "steam locomotive",
            821: "through arch bridge",
            822: "steel drum",
            823: "stethoscope",
            824: "scarf",
            825: "stone wall",
            826: "stopwatch",
            827: "stove",
            828: "strainer",
            829: "tram",
            830: "stretcher",
            831: "couch",
            832: "stupa",
            833: "submarine",
            834: "suit",
            835: "sundial",
            836: "sunglasses",
            837: "sunglasses",
            838: "sunscreen",
            839: "suspension bridge",
            840: "mop",
            841: "sweatshirt",
            842: "swim trunks / shorts",
            843: "swing",
            844: "electrical switch",
            845: "syringe",
            846: "table lamp",
            847: "tank",
            848: "tape player",
            849: "teapot",
            850: "teddy bear",
            851: "television",
            852: "tennis ball",
            853: "thatched roof",
            854: "front curtain",
            855: "thimble",
            856: "threshing machine",
            857: "throne",
            858: "tile roof",
            859: "toaster",
            860: "tobacco shop",
            861: "toilet seat",
            862: "torch",
            863: "totem pole",
            864: "tow truck",
            865: "toy store",
            866: "tractor",
            867: "semi-trailer truck",
            868: "tray",
            869: "trench coat",
            870: "tricycle",
            871: "trimaran",
            872: "tripod",
            873: "triumphal arch",
            874: "trolleybus",
            875: "trombone",
            876: "hot tub",
            877: "turnstile",
            878: "typewriter keyboard",
            879: "umbrella",
            880: "unicycle",
            881: "upright piano",
            882: "vacuum cleaner",
            883: "vase",
            884: "vaulted or arched ceiling",
            885: "velvet fabric",
            886: "vending machine",
            887: "vestment",
            888: "viaduct",
            889: "violin",
            890: "volleyball",
            891: "waffle iron",
            892: "wall clock",
            893: "wallet",
            894: "wardrobe",
            895: "military aircraft",
            896: "sink",
            897: "washing machine",
            898: "water bottle",
            899: "water jug",
            900: "water tower",
            901: "whiskey jug",
            902: "whistle",
            903: "hair wig",
            904: "window screen",
            905: "window shade",
            906: "Windsor tie",
            907: "wine bottle",
            908: "airplane wing",
            909: "wok",
            910: "wooden spoon",
            911: "wool",
            912: "split-rail fence",
            913: "shipwreck",
            914: "sailboat",
            915: "yurt",
            916: "website",
            917: "comic book",
            918: "crossword",
            919: "traffic or street sign",
            920: "traffic light",
            921: "dust jacket",
            922: "menu",
            923: "plate",
            924: "guacamole",
            925: "consomme",
            926: "hot pot",
            927: "trifle",
            928: "ice cream",
            929: "popsicle",
            930: "baguette",
            931: "bagel",
            932: "pretzel",
            933: "cheeseburger",
            934: "hot dog",
            935: "mashed potatoes",
            936: "cabbage",
            937: "broccoli",
            938: "cauliflower",
            939: "zucchini",
            940: "spaghetti squash",
            941: "acorn squash",
            942: "butternut squash",
            943: "cucumber",
            944: "artichoke",
            945: "bell pepper",
            946: "cardoon",
            947: "mushroom",
            948: "Granny Smith apple",
            949: "strawberry",
            950: "orange",
            951: "lemon",
            952: "fig",
            953: "pineapple",
            954: "banana",
            955: "jackfruit",
            956: "cherimoya (custard apple)",
            957: "pomegranate",
            958: "hay",
            959: "carbonara",
            960: "chocolate syrup",
            961: "dough",
            962: "meatloaf",
            963: "pizza",
            964: "pot pie",
            965: "burrito",
            966: "red wine",
            967: "espresso",
            968: "tea cup",
            969: "eggnog",
            970: "mountain",
            971: "bubble",
            972: "cliff",
            973: "coral reef",
            974: "geyser",
            975: "lakeshore",
            976: "promontory",
            977: "sandbar",
            978: "beach",
            979: "valley",
            980: "volcano",
            981: "baseball player",
            982: "bridegroom",
            983: "scuba diver",
            984: "rapeseed",
            985: "daisy",
            986: "yellow lady's slipper",
            987: "corn",
            988: "acorn",
            989: "rose hip",
            990: "horse chestnut seed",
            991: "coral fungus",
            992: "agaric",
            993: "gyromitra",
            994: "stinkhorn mushroom",
            995: "earth star fungus",
            996: "hen of the woods mushroom",
            997: "bolete",
            998: "corn cob",
            999: "toilet paper",
        }
zero_shot_prompt_templates property
zero_shot_prompt_templates

Return the zero-shot prompt templates.

id2label property
id2label

Return the label mapping.
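
The two properties above are intended to be used together for zero-shot evaluation. A minimal sketch, assuming dataset is an already instantiated ImageNet object (see the usage sketch after the __getitem__ entry below), expands every label name into every prompt template:

# Minimal sketch: one text prompt per (label, template) pair,
# i.e. 1000 labels x 80 templates = 80,000 candidate captions.
prompts = [
    template.format(label)
    for label in dataset.id2label.values()
    for template in dataset.zero_shot_prompt_templates
]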

__getitem__
__getitem__(index)

Get an example at the given index.

Source code in mmlearn/datasets/imagenet.py
def __getitem__(self, index: int) -> Example:
    """Get an example at the given index."""
    image, target = super().__getitem__(index)
    example = Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: target,
            EXAMPLE_INDEX_KEY: index,
        }
    )
    mask = self.mask_generator() if self.mask_generator else None
    if mask is not None:  # error will be raised during collation if `None`
        example[Modalities.RGB.mask] = mask
    return example
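
A hedged usage sketch (the root directory is a placeholder and must contain train/ and val/ class subfolders in the standard ImageFolder layout):

from torchvision import transforms

from mmlearn.datasets.imagenet import ImageNet

# Illustrative only: "<root_dir>/val/<class_name>/*.JPEG" must exist on disk.
dataset = ImageNet(
    root_dir="/path/to/imagenet",
    split="val",
    transform=transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
    ),
)
example = dataset[0]  # Example with the image tensor, target index and example index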

LibriSpeech

Bases: Dataset[Example]

LibriSpeech dataset.

This is a wrapper around torchaudio.datasets.LIBRISPEECH that assumes that the dataset is already downloaded and the top-level directory of the dataset in the root directory is librispeech.

Parameters:

Name Type Description Default
root_dir str

Root directory of dataset.

required
split (train-clean-100, train-clean-360, train-other-500, dev-clean, dev-other, test-clean, test-other)

Split of the dataset to use.

"train-clean-100"

Raises:

Type Description
ImportError

If torchaudio is not installed.

Notes

This dataset returns only the audio and the transcript from each example.

Source code in mmlearn/datasets/librispeech.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("LIBRISPEECH_ROOT_DIR", MISSING),
)
class LibriSpeech(Dataset[Example]):
    """LibriSpeech dataset.

    This is a wrapper around :py:class:`torchaudio.datasets.LIBRISPEECH` that assumes
    that the dataset is already downloaded and the top-level directory of the dataset
    in the root directory is `librispeech`.

    Parameters
    ----------
    root_dir : str
        Root directory of dataset.
    split : {"train-clean-100", "train-clean-360", "train-other-500", "dev-clean", "dev-other", "test-clean", "test-other"}, default="train-clean-100"
        Split of the dataset to use.

    Raises
    ------
    ImportError
        If ``torchaudio`` is not installed.

    Notes
    -----
    This dataset only returns the audio and transcript from the dataset.

    """  # noqa: W505

    def __init__(self, root_dir: str, split: str = "train-clean-100") -> None:
        super().__init__()
        if not _TORCHAUDIO_AVAILABLE:
            raise ImportError(
                "LibriSpeech dataset requires `torchaudio`, which is not installed."
            )
        from torchaudio.datasets import LIBRISPEECH

        self.dataset = LIBRISPEECH(
            root=root_dir,
            url=split,
            download=False,
            folder_in_archive="librispeech",
        )

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the dataset."""
        waveform, sample_rate, transcript, _, _, _ = self.dataset[idx]
        assert sample_rate == SAMPLE_RATE, (
            f"Expected sample rate to be `16000`, got {sample_rate}."
        )
        waveform = pad_or_trim(waveform.flatten())

        return Example(
            {
                Modalities.AUDIO.name: waveform,
                Modalities.TEXT.name: transcript,
                EXAMPLE_INDEX_KEY: idx,
            },
        )
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/librispeech.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.dataset)
__getitem__
__getitem__(idx)

Return an example from the dataset.

Source code in mmlearn/datasets/librispeech.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the dataset."""
    waveform, sample_rate, transcript, _, _, _ = self.dataset[idx]
    assert sample_rate == SAMPLE_RATE, (
        f"Expected sample rate to be `16000`, got {sample_rate}."
    )
    waveform = pad_or_trim(waveform.flatten())

    return Example(
        {
            Modalities.AUDIO.name: waveform,
            Modalities.TEXT.name: transcript,
            EXAMPLE_INDEX_KEY: idx,
        },
    )
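
A hedged usage sketch (assumes torchaudio is installed and the chosen split has already been extracted under <root_dir>/librispeech):

from mmlearn.datasets.librispeech import LibriSpeech

# Illustrative only: expects "<root_dir>/librispeech/dev-clean/..." on disk.
dataset = LibriSpeech(root_dir="/path/to/audio_data", split="dev-clean")
example = dataset[0]  # padded/trimmed 16 kHz waveform plus its transcript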

LLVIPDataset

Bases: Dataset[Example]

Low-Light Visible-Infrared Pair (LLVIP) dataset.

Loads pairs of RGB and THERMAL images from the LLVIP dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset. The directory should contain 'visible' and 'infrared' subdirectories.

required
train bool

Flag to indicate whether to load the training or test set.

True
transform Optional[Callable[[Image], Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor. This is applied to both RGB and thermal images.

None
Source code in mmlearn/datasets/llvip.py
@store(
    name="LLVIP",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("LLVIP_ROOT_DIR", MISSING),
)
class LLVIPDataset(Dataset[Example]):
    """Low-Light Visible-Infrared Pair (LLVIP) dataset.

    Loads pairs of `RGB` and `THERMAL` images from the LLVIP dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset. The directory should contain
        'visible' and 'infrared' subdirectories.
    train : bool, default=True
        Flag to indicate whether to load the training or test set.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor. This is applied to both RGB and thermal
        images.
    """

    def __init__(
        self,
        root_dir: str,
        train: bool = True,
        transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ):
        self.path_images_rgb = os.path.join(
            root_dir,
            "visible",
            "train" if train else "test",
        )
        self.path_images_ir = os.path.join(
            root_dir, "infrared", "train" if train else "test"
        )
        self.train = train
        self.transform = transform or transforms.ToTensor()

        self.rgb_images = sorted(glob.glob(os.path.join(self.path_images_rgb, "*.jpg")))
        self.ir_images = sorted(glob.glob(os.path.join(self.path_images_ir, "*.jpg")))

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.rgb_images)

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the dataset."""
        rgb_image_path = self.rgb_images[idx]
        ir_image_path = self.ir_images[idx]

        rgb_image = PILImage.open(rgb_image_path).convert("RGB")
        ir_image = PILImage.open(ir_image_path).convert("L")

        example = Example(
            {
                Modalities.RGB.name: self.transform(rgb_image),
                Modalities.THERMAL.name: self.transform(ir_image),
                EXAMPLE_INDEX_KEY: idx,
            },
        )

        if self.train:
            annot_path = (
                rgb_image_path.replace("visible", "Annotations")
                .replace(".jpg", ".xml")
                .replace("train", "")
            )
            annot = self._get_bbox(annot_path)
            example["annotation"] = {
                "bboxes": torch.from_numpy(annot["bboxes"]),
                "labels": torch.from_numpy(annot["labels"]),
            }
        return example

    def _get_bbox(self, filename: str) -> dict[str, np.ndarray]:
        """Parse the XML file to get bounding boxes and labels.

        Parameters
        ----------
        filename : str
            Path to the annotation XML file.

        Returns
        -------
        dict
            A dictionary containing bounding boxes and labels.
        """
        try:
            root = ET.parse(filename).getroot()

            bboxes, labels = [], []
            for obj in root.findall("object"):
                bbox_obj = obj.find("bndbox")
                bbox = [
                    int(bbox_obj.find(dim).text)  # type: ignore[union-attr,arg-type]
                    for dim in ["xmin", "ymin", "xmax", "ymax"]
                ]
                bboxes.append(bbox)
                labels.append(1)  # Assuming 'person' is the only label
            return {
                "bboxes": np.array(bboxes).astype("float"),
                "labels": np.array(labels).astype("int"),
            }
        except ET.ParseError as e:
            raise ValueError(f"Error parsing XML: {e}") from None
        except Exception as e:
            raise RuntimeError(
                f"Error processing annotation file {filename}: {e}",
            ) from None
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/llvip.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.rgb_images)
__getitem__
__getitem__(idx)

Return an example from the dataset.

Source code in mmlearn/datasets/llvip.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the dataset."""
    rgb_image_path = self.rgb_images[idx]
    ir_image_path = self.ir_images[idx]

    rgb_image = PILImage.open(rgb_image_path).convert("RGB")
    ir_image = PILImage.open(ir_image_path).convert("L")

    example = Example(
        {
            Modalities.RGB.name: self.transform(rgb_image),
            Modalities.THERMAL.name: self.transform(ir_image),
            EXAMPLE_INDEX_KEY: idx,
        },
    )

    if self.train:
        annot_path = (
            rgb_image_path.replace("visible", "Annotations")
            .replace(".jpg", ".xml")
            .replace("train", "")
        )
        annot = self._get_bbox(annot_path)
        example["annotation"] = {
            "bboxes": torch.from_numpy(annot["bboxes"]),
            "labels": torch.from_numpy(annot["labels"]),
        }
    return example

NIHCXR

Bases: Dataset[Example]

NIH Chest X-ray dataset.

Parameters:

Name Type Description Default
root_dir str

Directory which contains .json files stating all dataset entries.

required
split (train, test, bbox)

Dataset split. "bbox" is a subset of "test" which contains bounding box info.

"train"
transform Optional[Callable[[Image], Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
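
A minimal usage sketch (illustrative). The root directory is a placeholder and is assumed to contain the <split>_data.json entry files; no transform is passed, so the default resize/center-crop/to-tensor pipeline is used:

from mmlearn.datasets.nihcxr import NIHCXR

# Placeholder path; must contain 'train_data.json', 'test_data.json'
# and 'bbox_data.json'.
bbox_set = NIHCXR(root_dir="/path/to/nihcxr", split="bbox")
example = bbox_set[0]  # image tensor, label, qid, example index and,
                       # for this split, the "bbox" entry
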
Source code in mmlearn/datasets/nihcxr.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("NIH_CXR_DIR", MISSING),
    split="train",
)
class NIHCXR(Dataset[Example]):
    """NIH Chest X-ray dataset.

    Parameters
    ----------
    root_dir : str
        Directory which contains `.json` files stating all dataset entries.
    split : {"train", "test", "bbox"}
        Dataset split. "bbox" is a subset of "test" which contains bounding box info.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test", "bbox"],
        transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
    ) -> None:
        assert split in ["train", "test", "bbox"], f"split {split} is not available."
        assert callable(transform) or transform is None, (
            "transform is not callable or None."
        )

        data_path = os.path.join(root_dir, split + "_data.json")

        assert os.path.isfile(data_path), f"entries file does not exist: {data_path}."

        with open(data_path, "rb") as file:
            entries = json.load(file)
        self.entries = entries

        if transform is not None:
            self.transform = transform
        else:
            self.transform = Compose([Resize(224), CenterCrop(224), ToTensor()])

        self.bbox = split == "bbox"

    def __getitem__(self, idx: int) -> Example:
        """Return image-label or image-label-tabular(bbox)."""
        entry = self.entries[idx]
        image = Image.open(entry["image_path"]).convert("RGB")
        image = self.transform(image)
        label = torch.tensor(entry["label"])

        example = Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: label,
                "qid": entry["qid"],
                EXAMPLE_INDEX_KEY: idx,
            }
        )

        if self.bbox:
            example["bbox"] = entry["bbox"]

        return example

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.entries)
__getitem__
__getitem__(idx)

Return image-label or image-label-tabular(bbox).

Source code in mmlearn/datasets/nihcxr.py
def __getitem__(self, idx: int) -> Example:
    """Return image-label or image-label-tabular(bbox)."""
    entry = self.entries[idx]
    image = Image.open(entry["image_path"]).convert("RGB")
    image = self.transform(image)
    label = torch.tensor(entry["label"])

    example = Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: label,
            "qid": entry["qid"],
            EXAMPLE_INDEX_KEY: idx,
        }
    )

    if self.bbox:
        example["bbox"] = entry["bbox"]

    return example
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/nihcxr.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.entries)

NYUv2Dataset

Bases: Dataset[Example]

NYUv2 dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, test)

Split of the dataset to use.

"train"
return_type (disparity, image)

Return type of the depth images.

  • "disparity": Return the depth image as disparity map.
  • "image": Return the depth image as a 3-channel image.
"disparity"
rgb_transform Optional[Callable[[Image], Tensor]]

A callable that takes in an RGB PIL image and returns a transformed version of the image as a PyTorch tensor.

None
depth_transform Optional[Callable[[Image], Tensor]]

A callable that takes in a depth PIL image and returns a transformed version of the image as a PyTorch tensor.

None

Raises:

Type Description
ImportError

If opencv-python is not installed.
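
A minimal usage sketch (illustrative; requires opencv-python and the extracted NYUv2 data, with the root path below being a placeholder):

from torchvision import transforms
from mmlearn.datasets.nyuv2 import NYUv2Dataset

# Placeholder path; expects '<split>.txt' plus '<split>/rgb', '<split>/depth'
# and '<split>/scene_class' under the root directory.
dataset = NYUv2Dataset(
    root_dir="/path/to/nyuv2",
    split="train",
    return_type="image",                   # 3-channel depth images instead of disparity
    rgb_transform=transforms.ToTensor(),
    depth_transform=transforms.ToTensor(),
)
example = dataset[0]  # RGB and depth tensors, scene-class target and example index
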

Source code in mmlearn/datasets/nyuv2.py
@store(
    name="NYUv2",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("NYUV2_ROOT_DIR", MISSING),
)
class NYUv2Dataset(Dataset[Example]):
    """NYUv2 dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "test"}, default="train"
        Split of the dataset to use.
    return_type : {"disparity", "image"}, default="disparity"
        Return type of the depth images.

        - `"disparity"`: Return the depth image as disparity map.
        - `"image"`: Return the depth image as a 3-channel image.
    rgb_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in an RGB PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    depth_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in a depth PIL image and returns a transformed version
        of the image as a PyTorch tensor.

    Raises
    ------
    ImportError
        If `opencv-python` is not installed.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test"] = "train",
        return_type: Literal["disparity", "image"] = "disparity",
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ) -> None:
        super().__init__()
        if not _OPENCV_AVAILABLE:
            raise ImportError(
                "NYUv2 dataset requires `opencv-python` which is not installed.",
            )
        self._validate_args(root_dir, split, rgb_transform, depth_transform)
        self.return_type = return_type

        self.root_dir = root_dir
        with open(os.path.join(root_dir, f"{split}.txt"), "r") as f:
            file_ids = f.readlines()
        file_ids = [f.strip() for f in file_ids]

        root_dir = os.path.join(root_dir, split)
        depth_files = [os.path.join(root_dir, "depth", f"{f}.png") for f in file_ids]
        rgb_files = [os.path.join(root_dir, "rgb", f"{f}.png") for f in file_ids]

        label_files = [
            os.path.join(root_dir, "scene_class", f"{f}.txt") for f in file_ids
        ]
        labels = [str(open(f).read().strip()) for f in label_files]  # noqa: SIM115
        labels = [label.replace("_", " ") for label in labels]
        labels = [
            _LABELS.index(label) if label in _LABELS else len(_LABELS)  # type: ignore
            for label in labels
        ]

        # remove the samples with classes not in _LABELS
        # this is to follow the same classes used in ImageBind
        if split == "test":
            valid_indices = [
                i
                for i, label in enumerate(labels)
                if label < len(_LABELS)  # type: ignore
            ]
            rgb_files = [rgb_files[i] for i in valid_indices]
            depth_files = [depth_files[i] for i in valid_indices]
            labels = [labels[i] for i in valid_indices]

        self.samples = list(zip(rgb_files, depth_files, labels, strict=False))

        self.rgb_transform = rgb_transform
        self.depth_transform = depth_transform

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.samples)

    def _validate_args(
        self,
        root_dir: str,
        split: str,
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]],
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]],
    ) -> None:
        """Validate arguments."""
        if not os.path.isdir(root_dir):
            raise NotADirectoryError(
                f"The given `root_dir` {root_dir} is not a directory",
            )
        if split not in ["train", "test"]:
            raise ValueError(
                f"Expected `split` to be one of `'train'` or `'test'`, but got {split}",
            )
        if rgb_transform is not None and not callable(rgb_transform):
            raise TypeError(
                f"Expected argument `rgb_transform` to be callable, but got {type(rgb_transform)}",
            )
        if depth_transform is not None and not callable(depth_transform):
            raise TypeError(
                f"Expected `depth_transform` to be callable, but got {type(depth_transform)}",
            )

    def __getitem__(self, idx: int) -> Example:
        """Return RGB and depth images at index `idx`."""
        # Read images
        rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
        if self.rgb_transform is not None:
            rgb_image = self.rgb_transform(to_pil_image(rgb_image))

        if self.return_type == "disparity":
            depth_image = depth_normalize(
                self.samples[idx][1],
            )
        else:
            # Using cv2 instead of PIL Image since we use PNG grayscale images.
            depth_image = cv2.imread(
                self.samples[idx][1],
                cv2.IMREAD_GRAYSCALE,
            )
            # Make a 3-channel depth image to enable passing to a pretrained ViT.
            depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

        if self.depth_transform is not None:
            depth_image = self.depth_transform(to_pil_image(depth_image))

        return Example(
            {
                Modalities.RGB.name: rgb_image,
                Modalities.DEPTH.name: depth_image,
                EXAMPLE_INDEX_KEY: idx,
                Modalities.DEPTH.target: self.samples[idx][2],
            }
        )
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/nyuv2.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.samples)
__getitem__
__getitem__(idx)

Return RGB and depth images at index idx.

Source code in mmlearn/datasets/nyuv2.py
def __getitem__(self, idx: int) -> Example:
    """Return RGB and depth images at index `idx`."""
    # Read images
    rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
    if self.rgb_transform is not None:
        rgb_image = self.rgb_transform(to_pil_image(rgb_image))

    if self.return_type == "disparity":
        depth_image = depth_normalize(
            self.samples[idx][1],
        )
    else:
        # Using cv2 instead of PIL Image since we use PNG grayscale images.
        depth_image = cv2.imread(
            self.samples[idx][1],
            cv2.IMREAD_GRAYSCALE,
        )
        # Make a 3-channel depth image to enable passing to a pretrained ViT.
        depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

    if self.depth_transform is not None:
        depth_image = self.depth_transform(to_pil_image(depth_image))

    return Example(
        {
            Modalities.RGB.name: rgb_image,
            Modalities.DEPTH.name: depth_image,
            EXAMPLE_INDEX_KEY: idx,
            Modalities.DEPTH.target: self.samples[idx][2],
        }
    )

SUNRGBDDataset

Bases: Dataset[Example]

SUN RGB-D dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, test)

Split of the dataset to use.

"train"
return_type (disparity, image)

Return type of the depth images. If "disparity", the depth images are converted to disparity similar to the ImageBind implementation. Otherwise, return the depth image as a 3-channel image.

"disparity"
rgb_transform Optional[Callable[[Image], Tensor]]

A callable that takes in an RGB PIL image and returns a transformed version of the image as a PyTorch tensor.

None
depth_transform Optional[Callable[[Image], Tensor]]

A callable that takes in a depth PIL image and returns a transformed version of the image as a PyTorch tensor.

None
References

[1] Repo followed to extract the dataset: https://github.com/TUI-NICR/nicr-scene-analysis-datasets
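
A minimal usage sketch (illustrative; requires opencv-python and data extracted as in the repository referenced above, with the root path being a placeholder):

from torchvision import transforms
from mmlearn.datasets.sunrgbd import SUNRGBDDataset

# Placeholder path; expects '<split>.txt' plus '<split>/rgb', '<split>/depth',
# '<split>/intrinsics' and '<split>/scene_class' under the root directory.
dataset = SUNRGBDDataset(
    root_dir="/path/to/sunrgbd",
    split="test",
    return_type="disparity",              # depth converted to disparity via the stored intrinsics
    rgb_transform=transforms.ToTensor(),
)
example = dataset[0]  # RGB and disparity images, scene-class target and example index
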

Source code in mmlearn/datasets/sunrgbd.py
@store(
    name="SUNRGBD",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("SUNRGBD_ROOT_DIR", MISSING),
)
class SUNRGBDDataset(Dataset[Example]):
    """SUN RGB-D dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "test"}, default="train"
        Split of the dataset to use.
    return_type : {"disparity", "image"}, default="disparity"
        Return type of the depth images. If "disparity", the depth images are
        converted to disparity similar to the ImageBind implementation.
        Otherwise, return the depth image as a 3-channel image.
    rgb_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in an RGB PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    depth_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in a depth PIL image and returns a transformed version
        of the image as a PyTorch tensor.

    References
    ----------
    .. [1] Repo followed to extract the dataset: https://github.com/TUI-NICR/nicr-scene-analysis-datasets
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test"] = "train",
        return_type: Literal["disparity", "image"] = "disparity",
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ) -> None:
        super().__init__()
        if not _OPENCV_AVAILABLE:
            raise ImportError(
                "SUN RGB-D dataset requires `opencv-python` which is not installed.",
            )

        self._validate_args(root_dir, split, rgb_transform, depth_transform)
        self.return_type = return_type

        self.root_dir = root_dir
        with open(os.path.join(root_dir, f"{split}.txt"), "r") as f:
            file_ids = f.readlines()
        file_ids = [f.strip() for f in file_ids]

        root_dir = os.path.join(root_dir, split)
        depth_files = [os.path.join(root_dir, "depth", f"{f}.png") for f in file_ids]
        rgb_files = [os.path.join(root_dir, "rgb", f"{f}.jpg") for f in file_ids]
        intrinsic_files = [
            os.path.join(root_dir, "intrinsics", f"{f}.txt") for f in file_ids
        ]

        sensor_types = [
            file.removeprefix(os.path.join(root_dir, "depth")).split(os.sep)[1]
            for file in depth_files
        ]

        label_files = [
            os.path.join(root_dir, "scene_class", f"{f}.txt") for f in file_ids
        ]
        labels = []
        for label_file in label_files:
            with open(label_file, "r") as file:  # noqa: SIM115
                labels.append(file.read().strip())
        labels = [label.replace("_", " ") for label in labels]
        labels = [
            _LABELS.index(label) if label in _LABELS else len(_LABELS)  # type: ignore
            for label in labels
        ]

        # remove the samples with classes not in _LABELS
        # this is to follow the same classes used in ImageBind
        if split == "test":
            valid_indices = [
                i
                for i, label in enumerate(labels)
                if label < len(_LABELS)  # type: ignore
            ]
            rgb_files = [rgb_files[i] for i in valid_indices]
            depth_files = [depth_files[i] for i in valid_indices]
            labels = [labels[i] for i in valid_indices]
            intrinsic_files = [intrinsic_files[i] for i in valid_indices]
            sensor_types = [sensor_types[i] for i in valid_indices]

        self.samples = list(
            zip(
                rgb_files,
                depth_files,
                labels,
                intrinsic_files,
                sensor_types,
                strict=False,
            )
        )

        self.rgb_transform = rgb_transform
        self.depth_transform = depth_transform

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.samples)

    def _validate_args(
        self,
        root_dir: str,
        split: str,
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]],
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]],
    ) -> None:
        """Validate arguments."""
        if not os.path.isdir(root_dir):
            raise NotADirectoryError(
                f"The given `root_dir` {root_dir} is not a directory",
            )
        if split not in ["train", "test"]:
            raise ValueError(
                f"Expected `split` to be one of `'train'` or `'test'`, but got {split}",
            )
        if rgb_transform is not None and not callable(rgb_transform):
            raise TypeError(
                f"Expected argument `rgb_transform` to be callable, but got {type(rgb_transform)}",
            )
        if depth_transform is not None and not callable(depth_transform):
            raise TypeError(
                f"Expected `depth_transform` to be callable, but got {type(depth_transform)}",
            )

    def __getitem__(self, idx: int) -> Example:
        """Return RGB and depth images at index `idx`."""
        # Read images
        rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
        if self.rgb_transform is not None:
            rgb_image = self.rgb_transform(to_pil_image(rgb_image))

        if self.return_type == "disparity":
            depth_image = convert_depth_to_disparity(
                self.samples[idx][1],
                self.samples[idx][3],
                self.samples[idx][4],
            )
        else:
            # Using cv2 instead of PIL Image since we use PNG grayscale images.
            depth_image = cv2.imread(
                self.samples[idx][1],
                cv2.IMREAD_GRAYSCALE,
            )
            # Make a 3-channel depth image to enable passing to a pretrained ViT.
            depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

        if self.depth_transform is not None:
            depth_image = self.depth_transform(to_pil_image(depth_image))

        return Example(
            {
                Modalities.RGB.name: rgb_image,
                Modalities.DEPTH.name: depth_image,
                EXAMPLE_INDEX_KEY: idx,
                Modalities.DEPTH.target: self.samples[idx][2],
            }
        )
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/sunrgbd.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.samples)
__getitem__
__getitem__(idx)

Return RGB and depth images at index idx.

Source code in mmlearn/datasets/sunrgbd.py
def __getitem__(self, idx: int) -> Example:
    """Return RGB and depth images at index `idx`."""
    # Read images
    rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
    if self.rgb_transform is not None:
        rgb_image = self.rgb_transform(to_pil_image(rgb_image))

    if self.return_type == "disparity":
        depth_image = convert_depth_to_disparity(
            self.samples[idx][1],
            self.samples[idx][3],
            self.samples[idx][4],
        )
    else:
        # Using cv2 instead of PIL Image since we use PNG grayscale images.
        depth_image = cv2.imread(
            self.samples[idx][1],
            cv2.IMREAD_GRAYSCALE,
        )
        # Make a 3-channel depth image to enable passing to a pretrained ViT.
        depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

    if self.depth_transform is not None:
        depth_image = self.depth_transform(to_pil_image(depth_image))

    return Example(
        {
            Modalities.RGB.name: rgb_image,
            Modalities.DEPTH.name: depth_image,
            EXAMPLE_INDEX_KEY: idx,
            Modalities.DEPTH.target: self.samples[idx][2],
        }
    )

chexpert

CheXpert Dataset.

CheXpert

Bases: Dataset[Example]

CheXpert dataset.

Each datapoint is a pair of (image, target label).

Parameters:

Name Type Description Default
root_dir str

Directory which contains .json files stating all dataset entries.

required
split (train, valid)

Dataset split.

"train"
labeler Optional[{chexpert, chexbert, vchexbert}]

Labeler used to extract labels from the training images. The "valid" split has no labeler; its labels were assigned by human radiologists.

None
transform Optional[Callable[[PIL.Image], torch.Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
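
A minimal usage sketch (illustrative). The root directory is a placeholder and is assumed to contain the preprocessed .json entry files; no transform is passed, so the default resize/center-crop/to-tensor pipeline is used:

from mmlearn.datasets.chexpert import CheXpert

# Placeholder path; must contain '<labeler>_train_data.json' and 'valid_data.json'.
train_set = CheXpert(root_dir="/path/to/chexpert", split="train", labeler="chexbert")
valid_set = CheXpert(root_dir="/path/to/chexpert", split="valid")  # radiologist labels
example = train_set[0]  # image tensor, target label, qid and example index
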
Source code in mmlearn/datasets/chexpert.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("CHEXPERT_ROOT_DIR", MISSING),
    split="train",
)
class CheXpert(Dataset[Example]):
    """CheXpert dataset.

    Each datapoint is a pair of `(image, target label)`.

    Parameters
    ----------
    root_dir : str
        Directory which contains `.json` files stating all dataset entries.
    split : {"train", "valid"}
        Dataset split.
    labeler : Optional[{"chexpert", "chexbert", "vchexbert"}], optional, default=None
        Labeler used to extract labels from the training images. The "valid" split
        has no labeler; its labels were assigned by human radiologists.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "valid"],
        labeler: Optional[Literal["chexpert", "chexbert", "vchexbert"]] = None,
        transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
    ) -> None:
        assert split in ["train", "valid"], f"split {split} is not available."
        assert labeler in ["chexpert", "chexbert", "vchexbert"] or labeler is None, (
            f"labeler {labeler} is not available."
        )
        assert callable(transform) or transform is None, (
            "transform is not callable or None."
        )

        if split == "valid":
            data_file = f"{split}_data.json"
        elif split == "train":
            data_file = f"{labeler}_{split}_data.json"
        data_path = os.path.join(root_dir, data_file)

        assert os.path.isfile(data_path), f"entries file does not exist: {data_path}."

        with open(data_path, "rb") as file:
            entries = json.load(file)
        self.entries = entries

        if transform is not None:
            self.transform = transform
        else:
            self.transform = Compose([Resize(224), CenterCrop(224), ToTensor()])

    def __getitem__(self, idx: int) -> Example:
        """Return the idx'th datapoint."""
        entry = self.entries[idx]
        image = Image.open(entry["image_path"]).convert("RGB")
        image = self.transform(image)
        label = torch.tensor(entry["label"])

        return Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: label,
                "qid": entry["qid"],
                EXAMPLE_INDEX_KEY: idx,
            }
        )

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.entries)
__getitem__
__getitem__(idx)

Return the idx'th datapoint.

Source code in mmlearn/datasets/chexpert.py
def __getitem__(self, idx: int) -> Example:
    """Return the idx'th datapoint."""
    entry = self.entries[idx]
    image = Image.open(entry["image_path"]).convert("RGB")
    image = self.transform(image)
    label = torch.tensor(entry["label"])

    return Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: label,
            "qid": entry["qid"],
            EXAMPLE_INDEX_KEY: idx,
        }
    )
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/chexpert.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.entries)

core

Modules for core dataloading functionality.

CombinedDataset

Bases: Dataset[Example]

Combine multiple datasets into one.

This class is similar to torch.utils.data.ConcatDataset but allows for combining iterable-style datasets with map-style datasets. The iterable-style datasets must implement the __len__ method, which is used to determine the total length of the combined dataset. When an index is passed to the combined dataset, the dataset that contains the example at that index is determined and the example is retrieved from that dataset. Since iterable-style datasets do not support random access, the examples are retrieved sequentially from the iterable-style datasets. When the end of an iterable-style dataset is reached, the iterator is reset and the next example is retrieved from the beginning of the dataset.

Parameters:

Name Type Description Default
datasets Iterable[Union[Dataset, IterableDataset]]

Iterable of datasets to combine.

required

Raises:

Type Description
TypeError

If any of the datasets in the input iterable are not instances of torch.utils.data.Dataset or torch.utils.data.IterableDataset.

ValueError

If the input iterable of datasets is empty.
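
A minimal sketch of combining two toy map-style datasets (the ToyDataset class below is purely illustrative; any datasets that return Example objects work the same way):

import torch
from torch.utils.data import Dataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.example import Example

class ToyDataset(Dataset):
    """Tiny map-style dataset returning Example objects (illustration only)."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> Example:
        # "example_index" lets the combined dataset create example ids.
        return Example({"text": torch.tensor(idx), "example_index": idx})

combined = CombinedDataset([ToyDataset(3), ToyDataset(5)])
print(len(combined))  # 8
print(combined[4])    # index 4 falls in the second dataset (local index 1)
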

Source code in mmlearn/datasets/core/combined_dataset.py
class CombinedDataset(Dataset[Example]):
    """Combine multiple datasets into one.

    This class is similar to :py:class:`~torch.utils.data.ConcatDataset` but allows
    for combining iterable-style datasets with map-style datasets. The iterable-style
    datasets must implement the :meth:`__len__` method, which is used to determine the
    total length of the combined dataset. When an index is passed to the combined
    dataset, the dataset that contains the example at that index is determined and
    the example is retrieved from that dataset. Since iterable-style datasets do
    not support random access, the examples are retrieved sequentially from the
    iterable-style datasets. When the end of an iterable-style dataset is reached,
    the iterator is reset and the next example is retrieved from the beginning of
    the dataset.


    Parameters
    ----------
    datasets : Iterable[Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]]
        Iterable of datasets to combine.

    Raises
    ------
    TypeError
        If any of the datasets in the input iterable are not instances of
        :py:class:`~torch.utils.data.Dataset` or :py:class:`~torch.utils.data.IterableDataset`.
    ValueError
        If the input iterable of datasets is empty.

    """  # noqa: W505

    def __init__(
        self, datasets: Iterable[Union[Dataset[Example], IterableDataset[Example]]]
    ) -> None:
        self.datasets, _ = tree_flatten(datasets)
        if not all(
            isinstance(dataset, (Dataset, IterableDataset)) for dataset in self.datasets
        ):
            raise TypeError(
                "Expected argument `datasets` to be an iterable of `Dataset` or "
                f"`IterableDataset` instances, but found: {self.datasets}",
            )
        if len(self.datasets) == 0:
            raise ValueError(
                "Expected a non-empty iterable of datasets but found an empty iterable",
            )

        self._cumulative_sizes: list[int] = np.cumsum(
            [len(dataset) for dataset in self.datasets]
        ).tolist()
        self._iterators: list[Iterator[Example]] = []
        self._iter_dataset_mapping: dict[int, int] = {}

        # create iterators for iterable datasets and map dataset index to iterator index
        for idx, dataset in enumerate(self.datasets):
            if isinstance(dataset, IterableDataset):
                self._iterators.append(iter(dataset))
                self._iter_dataset_mapping[idx] = len(self._iterators) - 1

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the combined dataset."""
        if idx < 0:  # handle negative indices
            if -idx > len(self):
                raise IndexError(
                    f"Index {idx} is out of bounds for the combined dataset with "
                    f"length {len(self)}",
                )
            idx = len(self) + idx

        dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

        curr_dataset = self.datasets[dataset_idx]
        if isinstance(curr_dataset, IterableDataset):
            iter_idx = self._iter_dataset_mapping[dataset_idx]
            try:
                example = next(self._iterators[iter_idx])
            except StopIteration:
                self._iterators[iter_idx] = iter(curr_dataset)
                example = next(self._iterators[iter_idx])
        else:
            if dataset_idx == 0:
                example_idx = idx
            else:
                example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
            example = curr_dataset[example_idx]

        if not isinstance(example, Example):
            raise TypeError(
                "Expected dataset examples to be instances of `Example` "
                f"but found {type(example)}",
            )

        if not hasattr(example, "dataset_index"):
            example.dataset_index = dataset_idx
        if not hasattr(example, "example_ids"):
            example.create_ids()

        return example

    def __len__(self) -> int:
        """Return the total number of examples in the combined dataset."""
        return self._cumulative_sizes[-1]
__getitem__
__getitem__(idx)

Return an example from the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the combined dataset."""
    if idx < 0:  # handle negative indices
        if -idx > len(self):
            raise IndexError(
                f"Index {idx} is out of bounds for the combined dataset with "
                f"length {len(self)}",
            )
        idx = len(self) + idx

    dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

    curr_dataset = self.datasets[dataset_idx]
    if isinstance(curr_dataset, IterableDataset):
        iter_idx = self._iter_dataset_mapping[dataset_idx]
        try:
            example = next(self._iterators[iter_idx])
        except StopIteration:
            self._iterators[iter_idx] = iter(curr_dataset)
            example = next(self._iterators[iter_idx])
    else:
        if dataset_idx == 0:
            example_idx = idx
        else:
            example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
        example = curr_dataset[example_idx]

    if not isinstance(example, Example):
        raise TypeError(
            "Expected dataset examples to be instances of `Example` "
            f"but found {type(example)}",
        )

    if not hasattr(example, "dataset_index"):
        example.dataset_index = dataset_idx
    if not hasattr(example, "example_ids"):
        example.create_ids()

    return example
__len__
__len__()

Return the total number of examples in the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __len__(self) -> int:
    """Return the total number of examples in the combined dataset."""
    return self._cumulative_sizes[-1]
DefaultDataCollator dataclass

Default data collator for batching examples.

This data collator will collate a list of Example objects (mmlearn.datasets.core.example.Example) into a batch. It can also apply processing functions to specified keys in the batch before returning it.

Parameters:

Name Type Description Default
batch_processors Optional[dict[str, Callable[[Any], Any]]]

Dictionary of callables to apply to the batch before returning it.

None

Raises:

Type Description
ValueError

If the batch processor for a key does not return a dictionary with the key in it.
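
A minimal sketch of collating a list of Example objects, with an illustrative batch processor applied to the "text" key after collation:

import torch

from mmlearn.datasets.core.data_collator import DefaultDataCollator
from mmlearn.datasets.core.example import Example

examples = [Example({"text": torch.tensor(i)}) for i in range(4)]

collator = DefaultDataCollator(
    # Illustrative processor: shift the collated "text" tensor by one.
    batch_processors={"text": lambda batch_text: batch_text + 1}
)
batch = collator(examples)
print(batch["text"])  # the collated tensor with the processor applied
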

Source code in mmlearn/datasets/core/data_collator.py
@dataclass
class DefaultDataCollator:
    """Default data collator for batching examples.

    This data collator will collate a list of :py:class:`~mmlearn.datasets.core.example.Example`
    objects into a batch. It can also apply processing functions to specified keys
    in the batch before returning it.

    Parameters
    ----------
    batch_processors : Optional[dict[str, Callable[[Any], Any]]], optional, default=None
        Dictionary of callables to apply to the batch before returning it.

    Raises
    ------
    ValueError
        If the batch processor for a key does not return a dictionary with the
        key in it.
    """  # noqa: W505

    #: Dictionary of callables to apply to the batch before returning it.
    #: The key is the name of the key in the batch, and the value is the processing
    #: function to apply to the key. The processing function must take a single
    #: argument and return a single value. If the processing function returns
    #: a dictionary, it must contain the key that was processed in it (all the
    #: other keys will also be included in the batch).
    batch_processors: Optional[dict[str, Callable[[Any], Any]]] = None

    def __call__(self, examples: list[Example]) -> dict[str, Any]:
        """Collate a list of `Example` objects and apply processing functions."""
        batch = collate_example_list(examples)

        if self.batch_processors is not None:
            for key, processor in self.batch_processors.items():
                batch_key: str = key
                if Modalities.has_modality(key):
                    batch_key = Modalities.get_modality(key).name

                if batch_key in batch:
                    batch_processed = processor(batch[batch_key])
                    if isinstance(batch_processed, Mapping):
                        if batch_key not in batch_processed:
                            raise ValueError(
                                f"Batch processor for '{key}' key must return a dictionary "
                                f"with '{batch_key}' in it."
                            )
                        batch.update(batch_processed)
                    else:
                        batch[batch_key] = batch_processed

        return batch
__call__
__call__(examples)

Collate a list of Example objects and apply processing functions.

Source code in mmlearn/datasets/core/data_collator.py
def __call__(self, examples: list[Example]) -> dict[str, Any]:
    """Collate a list of `Example` objects and apply processing functions."""
    batch = collate_example_list(examples)

    if self.batch_processors is not None:
        for key, processor in self.batch_processors.items():
            batch_key: str = key
            if Modalities.has_modality(key):
                batch_key = Modalities.get_modality(key).name

            if batch_key in batch:
                batch_processed = processor(batch[batch_key])
                if isinstance(batch_processed, Mapping):
                    if batch_key not in batch_processed:
                        raise ValueError(
                            f"Batch processor for '{key}' key must return a dictionary "
                            f"with '{batch_key}' in it."
                        )
                    batch.update(batch_processed)
                else:
                    batch[batch_key] = batch_processed

    return batch
Example

Bases: OrderedDict[Any, Any]

A representation of a single example from a dataset.

This class is a subclass of collections.OrderedDict and provides attribute-style access. This means that example["text"] and example.text are equivalent. All datasets in this library return examples as Example objects.

Parameters:

Name Type Description Default
init_dict Optional[MutableMapping[Hashable, Any]]

Dictionary to init Example class with.

None

Examples:

>>> example = Example({"text": torch.tensor(2)})
>>> example.text.zero_()
tensor(0)
>>> example.context = torch.tensor(4)  # set custom attributes after initialization
Source code in mmlearn/datasets/core/example.py
class Example(OrderedDict[Any, Any]):
    """A representation of a single example from a dataset.

    This class is a subclass of :py:class:`~collections.OrderedDict` and provides
    attribute-style access. This means that `example["text"]` and `example.text`
    are equivalent. All datasets in this library return examples as
    :py:class:`~mmlearn.datasets.core.example.Example` objects.


    Parameters
    ----------
    init_dict : Optional[MutableMapping[Hashable, Any]], optional, default=None
        Dictionary to init `Example` class with.

    Examples
    --------
    >>> example = Example({"text": torch.tensor(2)})
    >>> example.text.zero_()
    tensor(0)
    >>> example.context = torch.tensor(4)  # set custom attributes after initialization
    """

    def __init__(
        self,
        init_dict: Optional[MutableMapping[Hashable, Any]] = None,
    ) -> None:
        if init_dict is None:
            init_dict = {}
        super().__init__(init_dict)

    def create_ids(self) -> None:
        """Create a unique id for the example from the dataset and example index.

        This method combines the dataset index and example index to create an
        attribute called `example_ids`, which is a dictionary of tensors. The
        dictionary keys are all the keys in the example except for `example_ids`,
        `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
        containing the tuple `(dataset_index, example_index)` for each key.
        The `example_ids` is used to (re-)identify pairs of examples from different
        modalities after they have been combined into a batch.

        Warns
        -----
        UserWarning
            If the `example_index` and `dataset_index` attributes are not set.

        Notes
        -----
        - The Example must have the following attributes set before calling
          this method: `example_index` (usually set/returned by the dataset) and
          `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
        - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
          function can be used to find matching examples given two tensors of example ids.

        """  # noqa: W505
        if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
            self.example_ids = {
                key: torch.tensor([self.dataset_index, self.example_index])
                for key in self.keys()
                if key not in ("example_ids", "example_index", "dataset_index")
            }
        else:
            rank_zero_warn(
                "Cannot create `example_ids` without `example_index` and `dataset_index` "
                "attributes. Set these attributes before calling `create_ids`. "
                "No `example_ids` was created.",
                stacklevel=2,
                category=UserWarning,
            )

    def __getattr__(self, key: str) -> Any:
        """Get attribute by key."""
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key: str, value: Any) -> None:
        """Set attribute by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        self[key] = value

    def __setitem__(self, key: Hashable, value: Any) -> None:
        """Set item by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        super().__setitem__(key, value)
create_ids
create_ids()

Create a unique id for the example from the dataset and example index.

This method combines the dataset index and example index to create an attribute called example_ids, which is a dictionary of tensors. The dictionary keys are all the keys in the example except for example_ids, example_index, and dataset_index. The values are tensors of shape (2,) containing the tuple (dataset_index, example_index) for each key. The example_ids is used to (re-)identify pairs of examples from different modalities after they have been combined into a batch.

Warns:

Type Description
UserWarning

If the example_index and dataset_index attributes are not set.

Notes
  • The Example must have the following attributes set before calling this method: example_index (usually set/returned by the dataset) and dataset_index (usually set by the CombinedDataset object)
  • The mmlearn.datasets.core.example.find_matching_indices function can be used to find matching examples given two tensors of example ids.
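
A short sketch of the flow described above (the index values are arbitrary; dataset_index is normally assigned by CombinedDataset and the example index by the dataset itself):

import torch

from mmlearn.datasets.core.example import Example

example = Example({"text": torch.tensor(2), "example_index": 7})
example.dataset_index = 1   # normally set by CombinedDataset
example.create_ids()
print(example.example_ids)  # maps "text" to tensor([1, 7]), i.e. (dataset_index, example_index)
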
Source code in mmlearn/datasets/core/example.py
def create_ids(self) -> None:
    """Create a unique id for the example from the dataset and example index.

    This method combines the dataset index and example index to create an
    attribute called `example_ids`, which is a dictionary of tensors. The
    dictionary keys are all the keys in the example except for `example_ids`,
    `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
    containing the tuple `(dataset_index, example_index)` for each key.
    The `example_ids` is used to (re-)identify pairs of examples from different
    modalities after they have been combined into a batch.

    Warns
    -----
    UserWarning
        If the `example_index` and `dataset_index` attributes are not set.

    Notes
    -----
    - The Example must have the following attributes set before calling
      this method: `example_index` (usually set/returned by the dataset) and
      `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
    - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
      function can be used to find matching examples given two tensors of example ids.

    """  # noqa: W505
    if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
        self.example_ids = {
            key: torch.tensor([self.dataset_index, self.example_index])
            for key in self.keys()
            if key not in ("example_ids", "example_index", "dataset_index")
        }
    else:
        rank_zero_warn(
            "Cannot create `example_ids` without `example_index` and `dataset_index` "
            "attributes. Set these attributes before calling `create_ids`. "
            "No `example_ids` was created.",
            stacklevel=2,
            category=UserWarning,
        )
__getattr__
__getattr__(key)

Get attribute by key.

Source code in mmlearn/datasets/core/example.py
def __getattr__(self, key: str) -> Any:
    """Get attribute by key."""
    try:
        return self[key]
    except KeyError:
        raise AttributeError(key) from None
__setattr__
__setattr__(key, value)

Set attribute by key.

Source code in mmlearn/datasets/core/example.py
def __setattr__(self, key: str, value: Any) -> None:
    """Set attribute by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    self[key] = value
__setitem__
__setitem__(key, value)

Set item by key.

Source code in mmlearn/datasets/core/example.py
def __setitem__(self, key: Hashable, value: Any) -> None:
    """Set item by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    super().__setitem__(key, value)
CombinedDatasetRatioSampler

Bases: Sampler[int]

Sampler for weighted sampling from a mmlearn.datasets.core.combined_dataset.CombinedDataset.

Parameters:

Name Type Description Default
dataset CombinedDataset

An instance of mmlearn.datasets.core.combined_dataset.CombinedDataset to sample from.

required
ratios Optional[Sequence[float]]

A sequence of ratios for sampling from each dataset in the combined dataset. The length of the sequence must be equal to the number of datasets in the combined dataset (dataset). If None, the length of each dataset in the combined dataset is used as the ratio. The ratios are normalized to sum to 1.

None
num_samples Optional[int]

The number of samples to draw from the combined dataset. If None, the sampler will draw as many samples as there are in the combined dataset. This number must yield at least one sample per dataset in the combined dataset, when multiplied by the corresponding ratio.

None
replacement bool

Whether to sample with replacement or not.

False
shuffle bool

Whether to shuffle the sampled indices or not. If False, the indices of each dataset will appear in the order they are stored in the combined dataset. This is similar to sequential sampling from each dataset. The datasets that make up the combined dataset are still sampled randomly.

True
rank Optional[int]

Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.

None
num_replicas Optional[int]

Number of processes participating in distributed training. By default, num_replicas is retrieved from the current distributed group.

None
drop_last bool

Whether to drop the last incomplete batch or not. If True, the sampler will drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

False
seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

0

Attributes:

Name Type Description
dataset CombinedDataset

The dataset to sample from.

num_samples int

The number of samples to draw from the combined dataset.

probs Tensor

The probabilities for sampling from each dataset in the combined dataset. This is computed from the ratios argument and is normalized to sum to 1.

replacement bool

Whether to sample with replacement or not.

shuffle bool

Whether to shuffle the sampled indices or not.

rank int

Rank of the current process within num_replicas.

num_replicas int

Number of processes participating in distributed training.

drop_last bool

Whether to drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

epoch int

Current epoch number. This is used to set the random seed. This is useful in distributed mode to ensure that each process receives a different random ordering of the samples.

total_size int

The total number of samples across all processes.
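
A minimal single-process sketch (rank and num_replicas are passed explicitly here instead of being read from a distributed group; the toy dataset is purely illustrative):

import torch
from torch.utils.data import DataLoader, Dataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.data_collator import DefaultDataCollator
from mmlearn.datasets.core.example import Example
from mmlearn.datasets.core.samplers import CombinedDatasetRatioSampler

class ToyDataset(Dataset):
    """Tiny map-style dataset returning Example objects (illustration only)."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> Example:
        return Example({"text": torch.tensor(idx), "example_index": idx})

combined = CombinedDataset([ToyDataset(70), ToyDataset(30)])
sampler = CombinedDatasetRatioSampler(
    combined,
    ratios=[0.7, 0.3],  # ~70% of drawn indices from the first dataset, ~30% from the second
    rank=0,             # single-process values; normally read from the distributed group
    num_replicas=1,
    seed=42,
)
sampler.set_epoch(0)  # call once per epoch so shuffling differs across epochs
loader = DataLoader(combined, sampler=sampler, batch_size=8, collate_fn=DefaultDataCollator())
batch = next(iter(loader))
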

Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class CombinedDatasetRatioSampler(Sampler[int]):
    """Sampler for weighted sampling from a :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`.

    Parameters
    ----------
    dataset : CombinedDataset
        An instance of :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`
        to sample from.
    ratios : Optional[Sequence[float]], optional, default=None
        A sequence of ratios for sampling from each dataset in the combined dataset.
        The length of the sequence must be equal to the number of datasets in the
        combined dataset (`dataset`). If `None`, the length of each dataset in the
        combined dataset is used as the ratio. The ratios are normalized to sum to 1.
    num_samples : Optional[int], optional, default=None
        The number of samples to draw from the combined dataset. If `None`, the
        sampler will draw as many samples as there are in the combined dataset.
        This number must yield at least one sample per dataset in the combined
        dataset, when multiplied by the corresponding ratio.
    replacement : bool, default=False
        Whether to sample with replacement or not.
    shuffle : bool, default=True
        Whether to shuffle the sampled indices or not. If `False`, the indices of
        each dataset will appear in the order they are stored in the combined dataset.
        This is similar to sequential sampling from each dataset. The datasets
        that make up the combined dataset are still sampled randomly.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    drop_last : bool, default=False
        Whether to drop the last incomplete batch or not. If `True`, the sampler will
        drop samples to make the number of samples evenly divisible by the number of
        replicas in distributed mode.
    seed : int, default=0
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.

    Attributes
    ----------
    dataset : CombinedDataset
        The dataset to sample from.
    num_samples : int
        The number of samples to draw from the combined dataset.
    probs : torch.Tensor
        The probabilities for sampling from each dataset in the combined dataset.
        This is computed from the `ratios` argument and is normalized to sum to 1.
    replacement : bool
        Whether to sample with replacement or not.
    shuffle : bool
        Whether to shuffle the sampled indices or not.
    rank : int
        Rank of the current process within :attr:`num_replicas`.
    num_replicas : int
        Number of processes participating in distributed training.
    drop_last : bool
        Whether to drop samples to make the number of samples evenly divisible by the
        number of replicas in distributed mode.
    seed : int
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.
    epoch : int
        Current epoch number. This is used to set the random seed. This is useful
        in distributed mode to ensure that each process receives a different random
        ordering of the samples.
    total_size : int
        The total number of samples across all processes.
    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        dataset: CombinedDataset,
        ratios: Optional[Sequence[float]] = None,
        num_samples: Optional[int] = None,
        replacement: bool = False,
        shuffle: bool = True,
        rank: Optional[int] = None,
        num_replicas: Optional[int] = None,
        drop_last: bool = False,
        seed: int = 0,
    ):
        if not isinstance(dataset, CombinedDataset):
            raise TypeError(
                "Expected argument `dataset` to be of type `CombinedDataset`, "
                f"but got {type(dataset)}.",
            )
        if not isinstance(seed, int):
            raise TypeError(
                f"Expected argument `seed` to be an integer, but got {type(seed)}.",
            )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.drop_last = drop_last
        self.replacement = replacement
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        self._num_samples = num_samples
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "Expected argument `num_samples` to be a positive integer, but got "
                f"{self.num_samples}.",
            )

        if ratios is None:
            ratios = [len(subset) for subset in self.dataset.datasets]

        num_datasets = len(self.dataset.datasets)
        if len(ratios) != num_datasets:
            raise ValueError(
                f"Expected argument `ratios` to be of length {num_datasets}, "
                f"but got length {len(ratios)}.",
            )
        prob_sum = sum(ratios)
        if not all(ratio >= 0 for ratio in ratios) and prob_sum > 0:
            raise ValueError(
                "Expected argument `ratios` to be a sequence of non-negative numbers. "
                f"Got {ratios}.",
            )
        self.probs = torch.tensor(
            [ratio / prob_sum for ratio in ratios],
            dtype=torch.double,
        )
        if any((prob * self.num_samples) <= 0 for prob in self.probs):
            raise ValueError(
                "Expected dataset ratio to result in at least one sample per dataset. "
                f"Got dataset sizes {self.probs * self.num_samples}.",
            )

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        # dataset size might change at runtime
        if self._num_samples is None:
            num_samples = len(self.dataset)
        else:
            num_samples = self._num_samples

        if self.drop_last and num_samples % self.num_replicas != 0:
            # split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            num_samples = math.ceil(
                (num_samples - self.num_replicas) / self.num_replicas,
            )
        else:
            num_samples = math.ceil(num_samples / self.num_replicas)
        return num_samples

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return self.num_samples * self.num_replicas

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that yields sample indices for the combined dataset."""
        generator = torch.Generator()
        seed = self.seed + self.epoch
        generator.manual_seed(seed)

        cumulative_sizes = [0] + self.dataset._cumulative_sizes
        num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
        indices = []
        for i in range(len(self.dataset.datasets)):
            per_dataset_indices: torch.Tensor = torch.multinomial(
                torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
                num_samples_per_dataset[i],
                replacement=self.replacement,
                generator=generator,
            )
            # adjust indices to reflect position in cumulative dataset
            per_dataset_indices += cumulative_sizes[i]
            assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
                f"Indices from dataset {i} exceed dataset size. "
                f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
            )
            indices.append(per_dataset_indices)

        indices = torch.cat(indices)
        if self.shuffle:
            rand_indices = torch.randperm(len(indices), generator=generator)
            indices = indices[rand_indices]

        indices = indices.tolist()  # type: ignore[attr-defined]
        num_indices = len(indices)

        if num_indices < self.total_size:
            padding_size = self.total_size - num_indices
            if padding_size <= num_indices:
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / num_indices))[
                    :padding_size
                ]
        elif num_indices > self.total_size:
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples, (
            f"Expected {self.num_samples} samples, but got {len(indices)}.",
        )

        yield from iter(indices)

    def __len__(self) -> int:
        """Return the total number of samples in the sampler."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch

        # some iterable datasets (especially huggingface iterable datasets) might
        # require setting the epoch to ensure shuffling works properly
        for dataset in self.dataset.datasets:
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
num_samples property
num_samples

Return the number of samples managed by the sampler.

total_size property
total_size

Return the total size of the dataset.

__iter__
__iter__()

Return an iterator that yields sample indices for the combined dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that yields sample indices for the combined dataset."""
    generator = torch.Generator()
    seed = self.seed + self.epoch
    generator.manual_seed(seed)

    cumulative_sizes = [0] + self.dataset._cumulative_sizes
    num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
    indices = []
    for i in range(len(self.dataset.datasets)):
        per_dataset_indices: torch.Tensor = torch.multinomial(
            torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
            num_samples_per_dataset[i],
            replacement=self.replacement,
            generator=generator,
        )
        # adjust indices to reflect position in cumulative dataset
        per_dataset_indices += cumulative_sizes[i]
        assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
            f"Indices from dataset {i} exceed dataset size. "
            f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
        )
        indices.append(per_dataset_indices)

    indices = torch.cat(indices)
    if self.shuffle:
        rand_indices = torch.randperm(len(indices), generator=generator)
        indices = indices[rand_indices]

    indices = indices.tolist()  # type: ignore[attr-defined]
    num_indices = len(indices)

    if num_indices < self.total_size:
        padding_size = self.total_size - num_indices
        if padding_size <= num_indices:
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / num_indices))[
                :padding_size
            ]
    elif num_indices > self.total_size:
        indices = indices[: self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples, (
        f"Expected {self.num_samples} samples, but got {len(indices)}.",
    )

    yield from iter(indices)
__len__
__len__()

Return the total number of samples in the sampler.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the total number of samples in the sampler."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

    # some iterable datasets (especially huggingface iterable datasets) might
    # require setting the epoch to ensure shuffling works properly
    for dataset in self.dataset.datasets:
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)
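
A minimal usage sketch of this sampler follows. The class name CombinedDatasetRatioSampler, the toy dataset, and the loader settings are assumptions made for illustration; only the module paths shown in this reference are taken from the library.

# Illustrative sketch only -- `CombinedDatasetRatioSampler` is the assumed name of
# the sampler documented above; the toy dataset and settings are made up.
import torch
from torch.utils.data import DataLoader, Dataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.data_collator import DefaultDataCollator
from mmlearn.datasets.core.example import Example
from mmlearn.datasets.core.samplers import CombinedDatasetRatioSampler


class ToyDataset(Dataset):
    """Map-style dataset returning `Example` objects, as `CombinedDataset` expects."""

    def __init__(self, size: int, key: str) -> None:
        self.size, self.key = size, key

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> Example:
        return Example({self.key: torch.tensor(idx), "example_index": idx})


combined = CombinedDataset([ToyDataset(1000, "image"), ToyDataset(200, "text")])

# draw ~80% of the samples from the first dataset and ~20% from the second
sampler = CombinedDatasetRatioSampler(
    combined, ratios=[0.8, 0.2], num_samples=500, shuffle=True, seed=0,
    num_replicas=1, rank=0,  # set explicitly so no distributed init is required
)
loader = DataLoader(combined, batch_size=32, sampler=sampler,
                    collate_fn=DefaultDataCollator())

for epoch in range(2):
    sampler.set_epoch(epoch)  # re-seeds shuffling (and iterable datasets) per epoch
    for batch in loader:
        pass  # training/evaluation step goes here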
DistributedEvalSampler

Bases: Sampler[int]

Sampler for distributed evaluation.

The main differences between this and :py:class:`torch.utils.data.DistributedSampler` are that this sampler does not add extra samples to make it evenly divisible and shuffling is disabled by default.

Parameters:

Name Type Description Default
dataset Dataset

Dataset used for sampling.

required
num_replicas Optional[int]

Number of processes participating in distributed training. By default, the world size is retrieved from the current distributed group.

None
rank Optional[int]

Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.

None
shuffle bool

If True, the sampler will shuffle the indices. Shuffling is disabled by default.

False
seed int

Random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group.

0
Warnings

DistributedEvalSampler should NOT be used for training; the distributed processes could hang forever. See [1]_ for details.

Notes
  • This sampler is for evaluation purposes, where synchronization does not happen every epoch. Synchronization should be done outside the dataloader loop. It is especially useful in conjunction with :py:class:`torch.nn.parallel.DistributedDataParallel` [2]_.
  • The input Dataset is assumed to be of constant size.
  • This implementation is adapted from [3]_.
References

.. [1] https://github.com/pytorch/pytorch/issues/22584
.. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
.. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py

Examples:

>>> def example():
...     start_epoch, n_epochs = 0, 2
...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
...     for epoch in range(start_epoch, n_epochs):
...         if is_distributed:
...             sampler.set_epoch(epoch)
...         evaluate(loader)
Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class DistributedEvalSampler(Sampler[int]):
    """Sampler for distributed evaluation.

    The main differences between this and :py:class:`torch.utils.data.DistributedSampler`
    are that this sampler does not add extra samples to make it evenly divisible and
    shuffling is disabled by default.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset used for sampling.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, the world size is retrieved from the current distributed group.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    shuffle : bool, optional, default=False
        If `True`, the sampler will shuffle the indices.
    seed : int, optional, default=0
        Random seed used to shuffle the sampler if :attr:`shuffle=True`.
        This number should be identical across all processes in the
        distributed group.

    Warnings
    --------
    DistributedEvalSampler should NOT be used for training. The distributed processes
    could hang forever. See [1]_ for details

    Notes
    -----
    - This sampler is for evaluation purpose where synchronization does not happen
      every epoch. Synchronization should be done outside the dataloader loop.
      It is especially useful in conjunction with
      :py:class:`torch.nn.parallel.DistributedDataParallel` [2]_.
    - The input Dataset is assumed to be of constant size.
    - This implementation is adapted from [3]_.

    References
    ----------
    .. [1] https://github.com/pytorch/pytorch/issues/22584
    .. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
    .. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py


    Examples
    --------
    >>> def example():
    ...     start_epoch, n_epochs = 0, 2
    ...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
    ...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
    ...     for epoch in range(start_epoch, n_epochs):
    ...         if is_distributed:
    ...             sampler.set_epoch(epoch)
    ...         evaluate(loader)

    """  # noqa: W505

    def __init__(
        self,
        dataset: Dataset[Sized],
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return len(self.dataset)

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        indices = list(range(self.total_size))[
            self.rank : self.total_size : self.num_replicas
        ]
        return len(indices)  # true value without extra samples

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that iterates over the indices of the dataset."""
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(self.total_size, generator=g).tolist()
        else:
            indices = list(range(self.total_size))

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch
total_size property
total_size

Return the total size of the dataset.

num_samples property
num_samples

Return the number of samples managed by the sampler.

__iter__
__iter__()

Return an iterator that iterates over the indices of the dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that iterates over the indices of the dataset."""
    if self.shuffle:
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()
    else:
        indices = list(range(self.total_size))

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples

    return iter(indices)
__len__
__len__()

Return the number of samples.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the number of samples."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch
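
As a complement to the example above, the sketch below shows the pattern the Notes describe: each rank iterates over its own (possibly unequal) shard and synchronization happens once, outside the dataloader loop. The model, batch layout, and metric are placeholders, not part of the library.

# Sketch (assumes an initialized process group, a `model`, and (input, target) batches).
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader

from mmlearn.datasets.core.samplers import DistributedEvalSampler


def distributed_accuracy(model, dataset, device, batch_size=64):
    sampler = DistributedEvalSampler(dataset)  # no padding; shuffle=False by default
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    correct = torch.zeros(1, device=device)
    total = torch.zeros(1, device=device)
    model.eval()
    with torch.no_grad():
        for inputs, targets in loader:
            preds = model(inputs.to(device)).argmax(dim=-1)
            correct += (preds == targets.to(device)).sum()
            total += targets.numel()

    # synchronize once, outside the dataloader loop, since ranks may have
    # processed different numbers of samples
    dist.all_reduce(correct)
    dist.all_reduce(total)
    return (correct / total).item()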
find_matching_indices
find_matching_indices(
    first_example_ids, second_example_ids
)

Find the indices of matching examples given two tensors of example ids.

Matching examples are defined as examples with the same value in both tensors. This method is useful for finding pairs of examples from different modalities that are related to each other in a batch.

Parameters:

Name Type Description Default
first_example_ids Tensor

A tensor of example ids of shape (N, 2), where N is the number of examples.

required
second_example_ids Tensor

A tensor of example ids of shape (M, 2), where M is the number of examples.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of tensors containing the indices of matching examples in the first and second tensor, respectively.

Raises:

Type Description
TypeError

If either first_example_ids or second_example_ids is not a tensor.

ValueError

If either first_example_ids or second_example_ids is not a 2D tensor with the second dimension having a size of 2.

Examples:

>>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
>>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
>>> find_matching_indices(img_example_ids, text_example_ids)
(tensor([2, 3]), tensor([0, 1]))
Source code in mmlearn/datasets/core/example.py
def find_matching_indices(
    first_example_ids: torch.Tensor, second_example_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find the indices of matching examples given two tensors of example ids.

    Matching examples are defined as examples with the same value in both tensors.
    This method is useful for finding pairs of examples from different modalities
    that are related to each other in a batch.

    Parameters
    ----------
    first_example_ids : torch.Tensor
        A tensor of example ids of shape `(N, 2)`, where `N` is the number of examples.
    second_example_ids : torch.Tensor
        A tensor of example ids of shape `(M, 2)`, where `M` is the number of examples.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple of tensors containing the indices of matching examples in the first and
        second tensor, respectively.

    Raises
    ------
    TypeError
        If either `first_example_ids` or `second_example_ids` is not a tensor.
    ValueError
        If either `first_example_ids` or `second_example_ids` is not a 2D tensor
        with the second dimension having a size of `2`.

    Examples
    --------
    >>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
    >>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
    >>> find_matching_indices(img_example_ids, text_example_ids)
    (tensor([2, 3]), tensor([0, 1]))


    """
    if not isinstance(first_example_ids, torch.Tensor) or not isinstance(
        second_example_ids,
        torch.Tensor,
    ):
        raise TypeError(
            f"Expected inputs to be tensors, but got {type(first_example_ids)} "
            f"and {type(second_example_ids)}.",
        )
    val = 2
    if not (first_example_ids.ndim == val and first_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `first_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {first_example_ids.shape}.",
        )
    if not (second_example_ids.ndim == val and second_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `second_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {second_example_ids.shape}.",
        )

    first_example_ids = first_example_ids.unsqueeze(1)  # shape=(N, 1, 2)
    second_example_ids = second_example_ids.unsqueeze(0)  # shape=(1, M, 2)

    # compare all elements; results in a shape (N, M) tensor
    matches = torch.all(first_example_ids == second_example_ids, dim=-1)
    first_indices, second_indices = torch.where(matches)
    return first_indices, second_indices
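
The indices returned by this function can be used directly to align per-modality tensors in a batch. The short sketch below uses hand-written (dataset_index, example_index) pairs; the embedding tensors are placeholders.

# Sketch: keep only the image/text rows that refer to the same underlying example.
import torch

from mmlearn.datasets.core.example import find_matching_indices

# each row is (dataset_index, example_index), as produced by `Example.create_ids`
image_ids = torch.tensor([[0, 0], [0, 1], [0, 2]])
text_ids = torch.tensor([[0, 2], [0, 3]])

img_idx, txt_idx = find_matching_indices(image_ids, text_ids)  # tensor([2]), tensor([0])

image_embeddings = torch.randn(3, 8)  # placeholder per-modality outputs
text_embeddings = torch.randn(2, 8)

paired_image = image_embeddings[img_idx]  # shape (1, 8)
paired_text = text_embeddings[txt_idx]    # shape (1, 8)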
combined_dataset

Wrapper for combining multiple datasets into one.

CombinedDataset

Bases: Dataset[Example]

Combine multiple datasets into one.

This class is similar to :py:class:`~torch.utils.data.ConcatDataset` but allows for combining iterable-style datasets with map-style datasets. The iterable-style datasets must implement the :meth:`__len__` method, which is used to determine the total length of the combined dataset. When an index is passed to the combined dataset, the dataset that contains the example at that index is determined and the example is retrieved from that dataset. Since iterable-style datasets do not support random access, the examples are retrieved sequentially from the iterable-style datasets. When the end of an iterable-style dataset is reached, the iterator is reset and the next example is retrieved from the beginning of the dataset.

Parameters:

Name Type Description Default
datasets Iterable[Union[Dataset, IterableDataset]]

Iterable of datasets to combine.

required

Raises:

Type Description
TypeError

If any of the datasets in the input iterable are not instances of :py:class:`~torch.utils.data.Dataset` or :py:class:`~torch.utils.data.IterableDataset`.

ValueError

If the input iterable of datasets is empty.

Source code in mmlearn/datasets/core/combined_dataset.py
class CombinedDataset(Dataset[Example]):
    """Combine multiple datasets into one.

    This class is similar to :py:class:`~torch.utils.data.ConcatDataset` but allows
    for combining iterable-style datasets with map-style datasets. The iterable-style
    datasets must implement the :meth:`__len__` method, which is used to determine the
    total length of the combined dataset. When an index is passed to the combined
    dataset, the dataset that contains the example at that index is determined and
    the example is retrieved from that dataset. Since iterable-style datasets do
    not support random access, the examples are retrieved sequentially from the
    iterable-style datasets. When the end of an iterable-style dataset is reached,
    the iterator is reset and the next example is retrieved from the beginning of
    the dataset.


    Parameters
    ----------
    datasets : Iterable[Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]]
        Iterable of datasets to combine.

    Raises
    ------
    TypeError
        If any of the datasets in the input iterable are not instances of
        :py:class:`~torch.utils.data.Dataset` or :py:class:`~torch.utils.data.IterableDataset`.
    ValueError
        If the input iterable of datasets is empty.

    """  # noqa: W505

    def __init__(
        self, datasets: Iterable[Union[Dataset[Example], IterableDataset[Example]]]
    ) -> None:
        self.datasets, _ = tree_flatten(datasets)
        if not all(
            isinstance(dataset, (Dataset, IterableDataset)) for dataset in self.datasets
        ):
            raise TypeError(
                "Expected argument `datasets` to be an iterable of `Dataset` or "
                f"`IterableDataset` instances, but found: {self.datasets}",
            )
        if len(self.datasets) == 0:
            raise ValueError(
                "Expected a non-empty iterable of datasets but found an empty iterable",
            )

        self._cumulative_sizes: list[int] = np.cumsum(
            [len(dataset) for dataset in self.datasets]
        ).tolist()
        self._iterators: list[Iterator[Example]] = []
        self._iter_dataset_mapping: dict[int, int] = {}

        # create iterators for iterable datasets and map dataset index to iterator index
        for idx, dataset in enumerate(self.datasets):
            if isinstance(dataset, IterableDataset):
                self._iterators.append(iter(dataset))
                self._iter_dataset_mapping[idx] = len(self._iterators) - 1

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the combined dataset."""
        if idx < 0:  # handle negative indices
            if -idx > len(self):
                raise IndexError(
                    f"Index {idx} is out of bounds for the combined dataset with "
                    f"length {len(self)}",
                )
            idx = len(self) + idx

        dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

        curr_dataset = self.datasets[dataset_idx]
        if isinstance(curr_dataset, IterableDataset):
            iter_idx = self._iter_dataset_mapping[dataset_idx]
            try:
                example = next(self._iterators[iter_idx])
            except StopIteration:
                self._iterators[iter_idx] = iter(curr_dataset)
                example = next(self._iterators[iter_idx])
        else:
            if dataset_idx == 0:
                example_idx = idx
            else:
                example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
            example = curr_dataset[example_idx]

        if not isinstance(example, Example):
            raise TypeError(
                "Expected dataset examples to be instances of `Example` "
                f"but found {type(example)}",
            )

        if not hasattr(example, "dataset_index"):
            example.dataset_index = dataset_idx
        if not hasattr(example, "example_ids"):
            example.create_ids()

        return example

    def __len__(self) -> int:
        """Return the total number of examples in the combined dataset."""
        return self._cumulative_sizes[-1]
__getitem__
__getitem__(idx)

Return an example from the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the combined dataset."""
    if idx < 0:  # handle negative indices
        if -idx > len(self):
            raise IndexError(
                f"Index {idx} is out of bounds for the combined dataset with "
                f"length {len(self)}",
            )
        idx = len(self) + idx

    dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

    curr_dataset = self.datasets[dataset_idx]
    if isinstance(curr_dataset, IterableDataset):
        iter_idx = self._iter_dataset_mapping[dataset_idx]
        try:
            example = next(self._iterators[iter_idx])
        except StopIteration:
            self._iterators[iter_idx] = iter(curr_dataset)
            example = next(self._iterators[iter_idx])
    else:
        if dataset_idx == 0:
            example_idx = idx
        else:
            example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
        example = curr_dataset[example_idx]

    if not isinstance(example, Example):
        raise TypeError(
            "Expected dataset examples to be instances of `Example` "
            f"but found {type(example)}",
        )

    if not hasattr(example, "dataset_index"):
        example.dataset_index = dataset_idx
    if not hasattr(example, "example_ids"):
        example.create_ids()

    return example
__len__
__len__()

Return the total number of examples in the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __len__(self) -> int:
    """Return the total number of examples in the combined dataset."""
    return self._cumulative_sizes[-1]
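
To illustrate the behaviour described above, here is a small, self-contained sketch that mixes a map-style and an iterable-style dataset; the toy dataset classes are assumptions made for the example.

# Sketch: combine a map-style dataset with an iterable-style dataset.
# Both yield `Example` objects, and the iterable dataset implements `__len__`.
import torch
from torch.utils.data import Dataset, IterableDataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.example import Example


class MapStyleText(Dataset):
    def __len__(self) -> int:
        return 3

    def __getitem__(self, idx: int) -> Example:
        return Example({"text": torch.tensor(idx), "example_index": idx})


class IterableImages(IterableDataset):
    def __len__(self) -> int:  # required so CombinedDataset can compute its length
        return 2

    def __iter__(self):
        for i in range(len(self)):
            yield Example({"image": torch.randn(3), "example_index": i})


combined = CombinedDataset([MapStyleText(), IterableImages()])
print(len(combined))              # 5
print(combined[0].text)           # tensor(0), from the map-style dataset
print(combined[4].dataset_index)  # 1; iterable examples are served sequentially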
data_collator

Data collators for batching examples.

DefaultDataCollator dataclass

Default data collator for batching examples.

This data collator will collate a list of :py:class:`~mmlearn.datasets.core.example.Example` objects into a batch. It can also apply processing functions to specified keys in the batch before returning it.

Parameters:

Name Type Description Default
batch_processors Optional[dict[str, Callable[[Any], Any]]]

Dictionary of callables to apply to the batch before returning it.

None

Raises:

Type Description
ValueError

If the batch processor for a key does not return a dictionary with the key in it.

Source code in mmlearn/datasets/core/data_collator.py
@dataclass
class DefaultDataCollator:
    """Default data collator for batching examples.

    This data collator will collate a list of :py:class:`~mmlearn.datasets.core.example.Example`
    objects into a batch. It can also apply processing functions to specified keys
    in the batch before returning it.

    Parameters
    ----------
    batch_processors : Optional[dict[str, Callable[[Any], Any]]], optional, default=None
        Dictionary of callables to apply to the batch before returning it.

    Raises
    ------
    ValueError
        If the batch processor for a key does not return a dictionary with the
        key in it.
    """  # noqa: W505

    #: Dictionary of callables to apply to the batch before returning it.
    #: The key is the name of the key in the batch, and the value is the processing
    #: function to apply to the key. The processing function must take a single
    #: argument and return a single value. If the processing function returns
    #: a dictionary, it must contain the key that was processed in it (all the
    #: other keys will also be included in the batch).
    batch_processors: Optional[dict[str, Callable[[Any], Any]]] = None

    def __call__(self, examples: list[Example]) -> dict[str, Any]:
        """Collate a list of `Example` objects and apply processing functions."""
        batch = collate_example_list(examples)

        if self.batch_processors is not None:
            for key, processor in self.batch_processors.items():
                batch_key: str = key
                if Modalities.has_modality(key):
                    batch_key = Modalities.get_modality(key).name

                if batch_key in batch:
                    batch_processed = processor(batch[batch_key])
                    if isinstance(batch_processed, Mapping):
                        if batch_key not in batch_processed:
                            raise ValueError(
                                f"Batch processor for '{key}' key must return a dictionary "
                                f"with '{batch_key}' in it."
                            )
                        batch.update(batch_processed)
                    else:
                        batch[batch_key] = batch_processed

        return batch
__call__
__call__(examples)

Collate a list of Example objects and apply processing functions.

Source code in mmlearn/datasets/core/data_collator.py
def __call__(self, examples: list[Example]) -> dict[str, Any]:
    """Collate a list of `Example` objects and apply processing functions."""
    batch = collate_example_list(examples)

    if self.batch_processors is not None:
        for key, processor in self.batch_processors.items():
            batch_key: str = key
            if Modalities.has_modality(key):
                batch_key = Modalities.get_modality(key).name

            if batch_key in batch:
                batch_processed = processor(batch[batch_key])
                if isinstance(batch_processed, Mapping):
                    if batch_key not in batch_processed:
                        raise ValueError(
                            f"Batch processor for '{key}' key must return a dictionary "
                            f"with '{batch_key}' in it."
                        )
                    batch.update(batch_processed)
                else:
                    batch[batch_key] = batch_processed

    return batch
collate_example_list
collate_example_list(examples)

Collate a list of :py:class:`~mmlearn.datasets.core.example.Example` objects into a batch.

Parameters:

Name Type Description Default
examples list[Example]

list of examples to collate.

required

Returns:

Type Description
dict[str, Any]

Dictionary of batched examples.

Source code in mmlearn/datasets/core/data_collator.py
def collate_example_list(examples: list[Example]) -> dict[str, Any]:
    """Collate a list of :py:class:`~mmlearn.datasets.core.example.Example` objects into a batch.

    Parameters
    ----------
    examples : list[Example]
        list of examples to collate.

    Returns
    -------
    dict[str, Any]
        Dictionary of batched examples.

    """  # noqa: W505
    return _collate_example_dict(_merge_examples(examples))
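
Putting the two pieces together, the sketch below collates a list of Example objects and post-processes the collated "text" key; the scaling function is purely illustrative.

# Sketch: collate `Example` objects and apply a batch processor to the "text" key.
import torch

from mmlearn.datasets.core.data_collator import DefaultDataCollator
from mmlearn.datasets.core.example import Example


def scale_text(batch_text: torch.Tensor) -> torch.Tensor:
    # illustrative processor: returns a tensor, so it replaces batch["text"]
    return batch_text.float() / 10.0


collator = DefaultDataCollator(batch_processors={"text": scale_text})

examples = [
    Example({"text": torch.tensor(3), "example_index": 0}),
    Example({"text": torch.tensor(7), "example_index": 1}),
]
batch = collator(examples)
print(batch["text"])  # expected: tensor([0.3000, 0.7000])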
example

Module for example-related classes and functions.

Example

Bases: OrderedDict[Any, Any]

A representation of a single example from a dataset.

This class is a subclass of :py:class:`~collections.OrderedDict` and provides attribute-style access. This means that example["text"] and example.text are equivalent. All datasets in this library return examples as :py:class:`~mmlearn.datasets.core.example.Example` objects.

Parameters:

Name Type Description Default
init_dict Optional[MutableMapping[Hashable, Any]]

Dictionary to init Example class with.

None

Examples:

>>> example = Example({"text": torch.tensor(2)})
>>> example.text.zero_()
tensor(0)
>>> example.context = torch.tensor(4)  # set custom attributes after initialization
Source code in mmlearn/datasets/core/example.py
class Example(OrderedDict[Any, Any]):
    """A representation of a single example from a dataset.

    This class is a subclass of :py:class:`~collections.OrderedDict` and provides
    attribute-style access. This means that `example["text"]` and `example.text`
    are equivalent. All datasets in this library return examples as
    :py:class:`~mmlearn.datasets.core.example.Example` objects.


    Parameters
    ----------
    init_dict : Optional[MutableMapping[Hashable, Any]], optional, default=None
        Dictionary to init `Example` class with.

    Examples
    --------
    >>> example = Example({"text": torch.tensor(2)})
    >>> example.text.zero_()
    tensor(0)
    >>> example.context = torch.tensor(4)  # set custom attributes after initialization
    """

    def __init__(
        self,
        init_dict: Optional[MutableMapping[Hashable, Any]] = None,
    ) -> None:
        if init_dict is None:
            init_dict = {}
        super().__init__(init_dict)

    def create_ids(self) -> None:
        """Create a unique id for the example from the dataset and example index.

        This method combines the dataset index and example index to create an
        attribute called `example_ids`, which is a dictionary of tensors. The
        dictionary keys are all the keys in the example except for `example_ids`,
        `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
        containing the tuple `(dataset_index, example_index)` for each key.
        The `example_ids` is used to (re-)identify pairs of examples from different
        modalities after they have been combined into a batch.

        Warns
        -----
        UserWarning
            If the `example_index` and `dataset_index` attributes are not set.

        Notes
        -----
        - The Example must have the following attributes set before calling
          this method: `example_index` (usually set/returned by the dataset) and
          `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
        - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
          function can be used to find matching examples given two tensors of example ids.

        """  # noqa: W505
        if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
            self.example_ids = {
                key: torch.tensor([self.dataset_index, self.example_index])
                for key in self.keys()
                if key not in ("example_ids", "example_index", "dataset_index")
            }
        else:
            rank_zero_warn(
                "Cannot create `example_ids` without `example_index` and `dataset_index` "
                "attributes. Set these attributes before calling `create_ids`. "
                "No `example_ids` was created.",
                stacklevel=2,
                category=UserWarning,
            )

    def __getattr__(self, key: str) -> Any:
        """Get attribute by key."""
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key: str, value: Any) -> None:
        """Set attribute by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        self[key] = value

    def __setitem__(self, key: Hashable, value: Any) -> None:
        """Set item by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        super().__setitem__(key, value)
create_ids
create_ids()

Create a unique id for the example from the dataset and example index.

This method combines the dataset index and example index to create an attribute called example_ids, which is a dictionary of tensors. The dictionary keys are all the keys in the example except for example_ids, example_index, and dataset_index. The values are tensors of shape (2,) containing the tuple (dataset_index, example_index) for each key. The example_ids is used to (re-)identify pairs of examples from different modalities after they have been combined into a batch.

Warns:

Type Description
UserWarning

If the example_index and dataset_index attributes are not set.

Notes
  • The Example must have the following attributes set before calling this method: example_index (usually set/returned by the dataset) and dataset_index (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
  • The :py:func:`~mmlearn.datasets.core.example.find_matching_indices` function can be used to find matching examples given two tensors of example ids.
Source code in mmlearn/datasets/core/example.py
def create_ids(self) -> None:
    """Create a unique id for the example from the dataset and example index.

    This method combines the dataset index and example index to create an
    attribute called `example_ids`, which is a dictionary of tensors. The
    dictionary keys are all the keys in the example except for `example_ids`,
    `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
    containing the tuple `(dataset_index, example_index)` for each key.
    The `example_ids` is used to (re-)identify pairs of examples from different
    modalities after they have been combined into a batch.

    Warns
    -----
    UserWarning
        If the `example_index` and `dataset_index` attributes are not set.

    Notes
    -----
    - The Example must have the following attributes set before calling
      this method: `example_index` (usually set/returned by the dataset) and
      `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
    - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
      function can be used to find matching examples given two tensors of example ids.

    """  # noqa: W505
    if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
        self.example_ids = {
            key: torch.tensor([self.dataset_index, self.example_index])
            for key in self.keys()
            if key not in ("example_ids", "example_index", "dataset_index")
        }
    else:
        rank_zero_warn(
            "Cannot create `example_ids` without `example_index` and `dataset_index` "
            "attributes. Set these attributes before calling `create_ids`. "
            "No `example_ids` was created.",
            stacklevel=2,
            category=UserWarning,
        )
__getattr__
__getattr__(key)

Get attribute by key.

Source code in mmlearn/datasets/core/example.py
def __getattr__(self, key: str) -> Any:
    """Get attribute by key."""
    try:
        return self[key]
    except KeyError:
        raise AttributeError(key) from None
__setattr__
__setattr__(key, value)

Set attribute by key.

Source code in mmlearn/datasets/core/example.py
def __setattr__(self, key: str, value: Any) -> None:
    """Set attribute by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    self[key] = value
__setitem__
__setitem__(key, value)

Set item by key.

Source code in mmlearn/datasets/core/example.py
def __setitem__(self, key: Hashable, value: Any) -> None:
    """Set item by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    super().__setitem__(key, value)
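
Beyond the doctest above, the following sketch shows the attributes create_ids relies on and the ids it produces; the index values are arbitrary.

# Sketch: attribute-style access and id creation on an `Example`.
import torch

from mmlearn.datasets.core.example import Example

example = Example({"text": torch.tensor([1, 2, 3])})
example.example_index = 7  # usually set/returned by the dataset
example.dataset_index = 0  # usually set by `CombinedDataset`
example.create_ids()

print(example.example_ids["text"])      # tensor([0, 7]) -> (dataset_index, example_index)
print(example["text"] is example.text)  # True: item access and attribute access agree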
find_matching_indices
find_matching_indices(
    first_example_ids, second_example_ids
)

Find the indices of matching examples given two tensors of example ids.

Matching examples are defined as examples with the same value in both tensors. This method is useful for finding pairs of examples from different modalities that are related to each other in a batch.

Parameters:

Name Type Description Default
first_example_ids Tensor

A tensor of example ids of shape (N, 2), where N is the number of examples.

required
second_example_ids Tensor

A tensor of example ids of shape (M, 2), where M is the number of examples.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of tensors containing the indices of matching examples in the first and second tensor, respectively.

Raises:

Type Description
TypeError

If either first_example_ids or second_example_ids is not a tensor.

ValueError

If either first_example_ids or second_example_ids is not a 2D tensor with the second dimension having a size of 2.

Examples:

>>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
>>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
>>> find_matching_indices(img_example_ids, text_example_ids)
(tensor([2, 3]), tensor([0, 1]))
Source code in mmlearn/datasets/core/example.py
def find_matching_indices(
    first_example_ids: torch.Tensor, second_example_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find the indices of matching examples given two tensors of example ids.

    Matching examples are defined as examples with the same value in both tensors.
    This method is useful for finding pairs of examples from different modalities
    that are related to each other in a batch.

    Parameters
    ----------
    first_example_ids : torch.Tensor
        A tensor of example ids of shape `(N, 2)`, where `N` is the number of examples.
    second_example_ids : torch.Tensor
        A tensor of example ids of shape `(M, 2)`, where `M` is the number of examples.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple of tensors containing the indices of matching examples in the first and
        second tensor, respectively.

    Raises
    ------
    TypeError
        If either `first_example_ids` or `second_example_ids` is not a tensor.
    ValueError
        If either `first_example_ids` or `second_example_ids` is not a 2D tensor
        with the second dimension having a size of `2`.

    Examples
    --------
    >>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
    >>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
    >>> find_matching_indices(img_example_ids, text_example_ids)
    (tensor([2, 3]), tensor([0, 1]))


    """
    if not isinstance(first_example_ids, torch.Tensor) or not isinstance(
        second_example_ids,
        torch.Tensor,
    ):
        raise TypeError(
            f"Expected inputs to be tensors, but got {type(first_example_ids)} "
            f"and {type(second_example_ids)}.",
        )
    val = 2
    if not (first_example_ids.ndim == val and first_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `first_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {first_example_ids.shape}.",
        )
    if not (second_example_ids.ndim == val and second_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `second_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {second_example_ids.shape}.",
        )

    first_example_ids = first_example_ids.unsqueeze(1)  # shape=(N, 1, 2)
    second_example_ids = second_example_ids.unsqueeze(0)  # shape=(1, M, 2)

    # compare all elements; results in a shape (N, M) tensor
    matches = torch.all(first_example_ids == second_example_ids, dim=-1)
    first_indices, second_indices = torch.where(matches)
    return first_indices, second_indices
modalities

Module for managing supported modalities in the library.

Modality dataclass

A representation of a modality in the library.

This class is used to represent a modality in the library. It contains the name of the modality and the properties that can be associated with it. The properties are dynamically generated based on the name of the modality and can be accessed as attributes of the class.

Parameters:

Name Type Description Default
name str

The name of the modality.

required
modality_specific_properties Optional[dict[str, str]]

Additional properties specific to the modality, by default None

None

Raises:

Type Description
ValueError

If the property already exists for the modality or if the format string is invalid.

Source code in mmlearn/datasets/core/modalities.py
@dataclass
class Modality:
    """A representation of a modality in the library.

    This class is used to represent a modality in the library. It contains the name of
    the modality and the properties that can be associated with it. The properties are
    dynamically generated based on the name of the modality and can be accessed as
    attributes of the class.

    Parameters
    ----------
    name : str
        The name of the modality.
    modality_specific_properties : Optional[dict[str, str]], optional, default=None
        Additional properties specific to the modality, by default None

    Raises
    ------
    ValueError
        If the property already exists for the modality or if the format string is
        invalid.
    """

    #: The name of the modality.
    name: str

    #: Target/label associated with the modality. This will return ``name_target``.
    target: str = field(init=False, repr=False)

    #: Attention mask associated with the modality. This will return
    # ``name_attention_mask``.
    attention_mask: str = field(init=False, repr=False)

    #: Input mask associated with the modality. This will return ``name_mask``.
    mask: str = field(init=False, repr=False)

    #: Embedding associated with the modality. This will return ``name_embedding``.
    embedding: str = field(init=False, repr=False)

    #: Masked embedding associated with the modality. This will return
    # ``name_masked_embedding``.
    masked_embedding: str = field(init=False, repr=False)

    #: Embedding from an Exponential Moving Average (EMA) encoder associated with
    #: the modality.
    ema_embedding: str = field(init=False, repr=False)

    #: Other properties specific to the modality.
    modality_specific_properties: Optional[dict[str, str]] = field(
        default=None, repr=False
    )

    def __post_init__(self) -> None:
        """Initialize the modality with the name and properties."""
        self.name = self.name.lower()
        self._properties = {}

        for field_name in self.__dataclass_fields__:
            if field_name not in ("name", "modality_specific_properties"):
                field_value = f"{self.name}_{field_name}"
                self._properties[field_name] = field_value
                setattr(self, field_name, field_value)

        if self.modality_specific_properties is not None:
            for (
                property_name,
                format_string,
            ) in self.modality_specific_properties.items():
                self.add_property(property_name, format_string)

    @property
    def properties(self) -> dict[str, str]:
        """Return the properties associated with the modality."""
        return self._properties

    def add_property(self, name: str, format_string: str) -> None:
        """Add a new property to the modality.

        Parameters
        ----------
        name : str
            The name of the property.
        format_string : str
            The format string for the property. The format string should contain a
            placeholder that will be replaced with the name of the modality when the
            property is accessed.

        Warns
        -----
        UserWarning
            If the property already exists for the modality. It will overwrite the
            existing property.

        Raises
        ------
        ValueError
            If `format_string` is invalid. A valid format string contains at least one
            placeholder enclosed in curly braces.
        """
        if name in self._properties:
            warnings.warn(
                f"Property '{name}' already exists for modality '{super().__str__()}'."
                "Will overwrite the existing property.",
                category=UserWarning,
                stacklevel=2,
            )

        if not _is_format_string(format_string):
            raise ValueError(
                f"Invalid format string '{format_string}' for property "
                f"'{name}' of modality '{super().__str__()}'."
            )

        self._properties[name] = format_string.format(self.name)
        setattr(self, name, self._properties[name])

    def __str__(self) -> str:
        """Return the object as a string."""
        return self.name.lower()
properties property
properties

Return the properties associated with the modality.

__post_init__
__post_init__()

Initialize the modality with the name and properties.

Source code in mmlearn/datasets/core/modalities.py
def __post_init__(self) -> None:
    """Initialize the modality with the name and properties."""
    self.name = self.name.lower()
    self._properties = {}

    for field_name in self.__dataclass_fields__:
        if field_name not in ("name", "modality_specific_properties"):
            field_value = f"{self.name}_{field_name}"
            self._properties[field_name] = field_value
            setattr(self, field_name, field_value)

    if self.modality_specific_properties is not None:
        for (
            property_name,
            format_string,
        ) in self.modality_specific_properties.items():
            self.add_property(property_name, format_string)
add_property
add_property(name, format_string)

Add a new property to the modality.

Parameters:

Name Type Description Default
name str

The name of the property.

required
format_string str

The format string for the property. The format string should contain a placeholder that will be replaced with the name of the modality when the property is accessed.

required

Warns:

Type Description
UserWarning

If the property already exists for the modality. It will overwrite the existing property.

Raises:

Type Description
ValueError

If format_string is invalid. A valid format string contains at least one placeholder enclosed in curly braces.

Source code in mmlearn/datasets/core/modalities.py
def add_property(self, name: str, format_string: str) -> None:
    """Add a new property to the modality.

    Parameters
    ----------
    name : str
        The name of the property.
    format_string : str
        The format string for the property. The format string should contain a
        placeholder that will be replaced with the name of the modality when the
        property is accessed.

    Warns
    -----
    UserWarning
        If the property already exists for the modality. It will overwrite the
        existing property.

    Raises
    ------
    ValueError
        If `format_string` is invalid. A valid format string contains at least one
        placeholder enclosed in curly braces.
    """
    if name in self._properties:
        warnings.warn(
            f"Property '{name}' already exists for modality '{super().__str__()}'."
            "Will overwrite the existing property.",
            category=UserWarning,
            stacklevel=2,
        )

    if not _is_format_string(format_string):
        raise ValueError(
            f"Invalid format string '{format_string}' for property "
            f"'{name}' of modality '{super().__str__()}'."
        )

    self._properties[name] = format_string.format(self.name)
    setattr(self, name, self._properties[name])
__str__
__str__()

Return the object as a string.

Source code in mmlearn/datasets/core/modalities.py
def __str__(self) -> str:
    """Return the object as a string."""
    return self.name.lower()
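
A short sketch of the generated property names follows; the modality name and the custom format string are arbitrary examples.

# Sketch: property names are derived from the (lowercased) modality name.
from mmlearn.datasets.core.modalities import Modality

rgb = Modality("RGB", modality_specific_properties={"caption": "{}_caption"})

print(rgb.name)        # "rgb"
print(rgb.embedding)   # "rgb_embedding"
print(rgb.target)      # "rgb_target"
print(rgb.caption)     # "rgb_caption", from the custom format string
print(rgb.properties)  # mapping of every generated property name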
ModalityRegistry

Modality registry.

A singleton class that manages the supported modalities (and their properties) in the library. The class provides methods to add new modalities and properties, and to access the existing modalities. The class is implemented as a singleton to ensure that there is only one instance of the registry in the library.

Source code in mmlearn/datasets/core/modalities.py
class ModalityRegistry:
    """Modality registry.

    A singleton class that manages the supported modalities (and their properties) in
    the library. The class provides methods to add new modalities and properties, and
    to access the existing modalities. The class is implemented as a singleton to
    ensure that there is only one instance of the registry in the library.
    """

    _instance: ClassVar[Any] = None
    _modality_registry: dict[str, Modality] = {}

    def __new__(cls) -> Self:
        """Create a new instance of the class if it does not exist."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._modality_registry = {}
        return cls._instance  # type: ignore[no-any-return]

    def register_modality(
        self, name: str, modality_specific_properties: Optional[dict[str, str]] = None
    ) -> None:
        """Add a new modality to the registry.

        Parameters
        ----------
        name : str
            The name of the modality.
        modality_specific_properties : Optional[dict[str, str]], optional, default=None
            Additional properties specific to the modality.

        Warns
        -----
        UserWarning
            If the modality already exists in the registry. It will overwrite the
            existing modality.

        """
        if name.lower() in self._modality_registry:
            warnings.warn(
                f"Modality '{name}' already exists in the registry. Overwriting...",
                category=UserWarning,
                stacklevel=2,
            )

        name = name.lower()
        modality = Modality(name, modality_specific_properties)
        self._modality_registry[name] = modality
        setattr(self, name, modality)

    def add_default_property(self, name: str, format_string: str) -> None:
        """Add a new property that is applicable to all modalities.

        Parameters
        ----------
        name : str
            The name of the property.
        format_string : str
            The format string for the property. The format string should contain a
            placeholder that will be replaced with the name of the modality when the
            property is accessed.

        Warns
        -----
        UserWarning
            If the property already exists for the default properties. It will
            overwrite the existing property.

        Raises
        ------
        ValueError
            If the format string is invalid. A valid format string contains at least one
            placeholder enclosed in curly braces.
        """
        for modality in self._modality_registry.values():
            modality.add_property(name, format_string)

    def has_modality(self, name: str) -> bool:
        """Check if the modality exists in the registry.

        Parameters
        ----------
        name : str
            The name of the modality.

        Returns
        -------
        bool
            True if the modality exists in the registry, False otherwise.
        """
        return name.lower() in self._modality_registry

    def get_modality(self, name: str) -> Modality:
        """Get the modality name from the registry.

        Parameters
        ----------
        name : str
            The name of the modality.

        Returns
        -------
        Modality
            The modality object from the registry.
        """
        return self._modality_registry[name.lower()]

    def get_modality_properties(self, name: str) -> dict[str, str]:
        """Get the properties of a modality from the registry.

        Parameters
        ----------
        name : str
            The name of the modality.

        Returns
        -------
        dict[str, str]
            The properties associated with the modality.
        """
        return self.get_modality(name).properties

    def list_modalities(self) -> list[Modality]:
        """Get the list of supported modalities in the registry.

        Returns
        -------
        list[Modality]
            The list of supported modalities in the registry.
        """
        return list(self._modality_registry.values())

    def __getattr__(self, name: str) -> Modality:
        """Access a modality as an attribute by its name."""
        if name.lower() in self._modality_registry:
            return self._modality_registry[name.lower()]
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )
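
The sketch below registers a hypothetical "depth" modality and looks it up through the singleton; the modality name and the extra property are assumptions made for illustration.

# Sketch: ModalityRegistry is a singleton, so every call returns the same registry.
from mmlearn.datasets.core.modalities import ModalityRegistry

registry = ModalityRegistry()
registry.register_modality("depth", modality_specific_properties={"scale": "{}_scale"})

print(registry.has_modality("depth"))                      # True
print(registry.get_modality("depth").embedding)            # "depth_embedding"
print(registry.get_modality_properties("depth")["scale"])  # "depth_scale"
print([str(m) for m in registry.list_modalities()])        # includes "depth"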
__new__
__new__()

Create a new instance of the class if it does not exist.

Source code in mmlearn/datasets/core/modalities.py
def __new__(cls) -> Self:
    """Create a new instance of the class if it does not exist."""
    if cls._instance is None:
        cls._instance = super().__new__(cls)
        cls._instance._modality_registry = {}
    return cls._instance  # type: ignore[no-any-return]
register_modality
register_modality(name, modality_specific_properties=None)

Add a new modality to the registry.

Parameters:

Name Type Description Default
name str

The name of the modality.

required
modality_specific_properties Optional[dict[str, str]]

Additional properties specific to the modality.

None

Warns:

Type Description
UserWarning

If the modality already exists in the registry, in which case the existing modality is overwritten.

Source code in mmlearn/datasets/core/modalities.py
def register_modality(
    self, name: str, modality_specific_properties: Optional[dict[str, str]] = None
) -> None:
    """Add a new modality to the registry.

    Parameters
    ----------
    name : str
        The name of the modality.
    modality_specific_properties : Optional[dict[str, str]], optional, default=None
        Additional properties specific to the modality.

    Warns
    -----
    UserWarning
        If the modality already exists in the registry, in which case the
        existing modality is overwritten.

    """
    if name.lower() in self._modality_registry:
        warnings.warn(
            f"Modality '{name}' already exists in the registry. Overwriting...",
            category=UserWarning,
            stacklevel=2,
        )

    name = name.lower()
    modality = Modality(name, modality_specific_properties)
    self._modality_registry[name] = modality
    setattr(self, name, modality)
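
A minimal usage sketch, assuming `Modalities` is the module-level instance of this registry exported by `mmlearn.datasets.core.modalities` (it is used that way by the ImageNet dataset later in this reference); the "depth" modality and its property are invented for illustration:

from mmlearn.datasets.core.modalities import Modalities

# Register a hypothetical "depth" modality; the name is lowercased internally
# and a UserWarning is emitted if the name is already registered.
Modalities.register_modality(
    "depth", modality_specific_properties={"embedding": "{}_embedding"}
)

# The new modality is exposed both through the lookup API and as an attribute.
depth = Modalities.get_modality("depth")
print(depth is Modalities.depth)  # True: attribute access returns the same object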
add_default_property
add_default_property(name, format_string)

Add a new property that is applicable to all modalities.

Parameters:

Name Type Description Default
name str

The name of the property.

required
format_string str

The format string for the property. The format string should contain a placeholder that will be replaced with the name of the modality when the property is accessed.

required

Warns:

Type Description
UserWarning

If the property already exists in the default properties, in which case the existing property is overwritten.

Raises:

Type Description
ValueError

If the format string is invalid. A valid format string contains at least one placeholder enclosed in curly braces.

Source code in mmlearn/datasets/core/modalities.py
def add_default_property(self, name: str, format_string: str) -> None:
    """Add a new property that is applicable to all modalities.

    Parameters
    ----------
    name : str
        The name of the property.
    format_string : str
        The format string for the property. The format string should contain a
        placeholder that will be replaced with the name of the modality when the
        property is accessed.

    Warns
    -----
    UserWarning
        If the property already exists in the default properties, in which case
        the existing property is overwritten.

    Raises
    ------
    ValueError
        If the format string is invalid. A valid format string contains at least one
        placeholder enclosed in curly braces.
    """
    for modality in self._modality_registry.values():
        modality.add_property(name, format_string)
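
A short sketch of adding a default property; the "embedding" property and its format string are invented for illustration. Note that, as written, the property is only added to the modalities registered at the time of the call:

from mmlearn.datasets.core.modalities import Modalities

# Hypothetical property: the "{}" placeholder is filled with each modality's
# name when the property is resolved (e.g. "rgb" -> "rgb_embedding").
Modalities.add_default_property("embedding", "{}_embedding")

# Every modality currently in the registry should now carry the new property.
print(Modalities.get_modality_properties("rgb"))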
has_modality
has_modality(name)

Check if the modality exists in the registry.

Parameters:

Name Type Description Default
name str

The name of the modality.

required

Returns:

Type Description
bool

True if the modality exists in the registry, False otherwise.

Source code in mmlearn/datasets/core/modalities.py
def has_modality(self, name: str) -> bool:
    """Check if the modality exists in the registry.

    Parameters
    ----------
    name : str
        The name of the modality.

    Returns
    -------
    bool
        True if the modality exists in the registry, False otherwise.
    """
    return name.lower() in self._modality_registry
get_modality
get_modality(name)

Get a modality from the registry by its name.

Parameters:

Name Type Description Default
name str

The name of the modality.

required

Returns:

Type Description
Modality

The modality object from the registry.

Source code in mmlearn/datasets/core/modalities.py
def get_modality(self, name: str) -> Modality:
    """Get the modality name from the registry.

    Parameters
    ----------
    name : str
        The name of the modality.

    Returns
    -------
    Modality
        The modality object from the registry.
    """
    return self._modality_registry[name.lower()]
get_modality_properties
get_modality_properties(name)

Get the properties of a modality from the registry.

Parameters:

Name Type Description Default
name str

The name of the modality.

required

Returns:

Type Description
dict[str, str]

The properties associated with the modality.

Source code in mmlearn/datasets/core/modalities.py
def get_modality_properties(self, name: str) -> dict[str, str]:
    """Get the properties of a modality from the registry.

    Parameters
    ----------
    name : str
        The name of the modality.

    Returns
    -------
    dict[str, str]
        The properties associated with the modality.
    """
    return self.get_modality(name).properties
list_modalities
list_modalities()

Get the list of supported modalities in the registry.

Returns:

Type Description
list[Modality]

The list of supported modalities in the registry.

Source code in mmlearn/datasets/core/modalities.py
def list_modalities(self) -> list[Modality]:
    """Get the list of supported modalities in the registry.

    Returns
    -------
    list[Modality]
        The list of supported modalities in the registry.
    """
    return list(self._modality_registry.values())
__getattr__
__getattr__(name)

Access a modality as an attribute by its name.

Source code in mmlearn/datasets/core/modalities.py
def __getattr__(self, name: str) -> Modality:
    """Access a modality as an attribute by its name."""
    if name.lower() in self._modality_registry:
        return self._modality_registry[name.lower()]
    raise AttributeError(
        f"'{self.__class__.__name__}' object has no attribute '{name}'"
    )
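
The lookup helpers compose naturally. A minimal sketch, again assuming the module-level `Modalities` instance and that an "rgb" modality is registered (it is used by the ImageNet dataset later in this reference):

from mmlearn.datasets.core.modalities import Modalities

# Membership checks and lookups are case-insensitive.
if Modalities.has_modality("RGB"):
    rgb = Modalities.get_modality("rgb")
    print(rgb.name, Modalities.get_modality_properties("rgb"))

# Enumerate everything currently registered.
for modality in Modalities.list_modalities():
    print(modality.name)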
samplers

Samplers for data loading.

CombinedDatasetRatioSampler

Bases: Sampler[int]

Sampler for weighted sampling from a :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset.

Parameters:

Name Type Description Default
dataset CombinedDataset

An instance of :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset to sample from.

required
ratios Optional[Sequence[float]]

A sequence of ratios for sampling from each dataset in the combined dataset. The length of the sequence must be equal to the number of datasets in the combined dataset (dataset). If None, the length of each dataset in the combined dataset is used as the ratio. The ratios are normalized to sum to 1.

None
num_samples Optional[int]

The number of samples to draw from the combined dataset. If None, the sampler will draw as many samples as there are in the combined dataset. This number must yield at least one sample per dataset in the combined dataset, when multiplied by the corresponding ratio.

None
replacement bool

Whether to sample with replacement or not.

False
shuffle bool

Whether to shuffle the sampled indices or not. If False, the indices of each dataset will appear in the order they are stored in the combined dataset. This is similar to sequential sampling from each dataset. The datasets that make up the combined dataset are still sampled randomly.

True
rank Optional[int]

Rank of the current process within :attr:num_replicas. By default, :attr:rank is retrieved from the current distributed group.

None
num_replicas Optional[int]

Number of processes participating in distributed training. By default, :attr:num_replicas is retrieved from the current distributed group.

None
drop_last bool

Whether to drop the last incomplete batch or not. If True, the sampler will drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

False
seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

0

Attributes:

Name Type Description
dataset CombinedDataset

The dataset to sample from.

num_samples int

The number of samples to draw from the combined dataset.

probs Tensor

The probabilities for sampling from each dataset in the combined dataset. This is computed from the ratios argument and is normalized to sum to 1.

replacement bool

Whether to sample with replacement or not.

shuffle bool

Whether to shuffle the sampled indices or not.

rank int

Rank of the current process within :attr:num_replicas.

num_replicas int

Number of processes participating in distributed training.

drop_last bool

Whether to drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

epoch int

Current epoch number. This is used to set the random seed. This is useful in distributed mode to ensure that each process receives a different random ordering of the samples.

total_size int

The total number of samples across all processes.

Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class CombinedDatasetRatioSampler(Sampler[int]):
    """Sampler for weighted sampling from a :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`.

    Parameters
    ----------
    dataset : CombinedDataset
        An instance of :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`
        to sample from.
    ratios : Optional[Sequence[float]], optional, default=None
        A sequence of ratios for sampling from each dataset in the combined dataset.
        The length of the sequence must be equal to the number of datasets in the
        combined dataset (`dataset`). If `None`, the length of each dataset in the
        combined dataset is used as the ratio. The ratios are normalized to sum to 1.
    num_samples : Optional[int], optional, default=None
        The number of samples to draw from the combined dataset. If `None`, the
        sampler will draw as many samples as there are in the combined dataset.
        This number must yield at least one sample per dataset in the combined
        dataset, when multiplied by the corresponding ratio.
    replacement : bool, default=False
        Whether to sample with replacement or not.
    shuffle : bool, default=True
        Whether to shuffle the sampled indices or not. If `False`, the indices of
        each dataset will appear in the order they are stored in the combined dataset.
        This is similar to sequential sampling from each dataset. The datasets
        that make up the combined dataset are still sampled randomly.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    drop_last : bool, default=False
        Whether to drop the last incomplete batch or not. If `True`, the sampler will
        drop samples to make the number of samples evenly divisible by the number of
        replicas in distributed mode.
    seed : int, default=0
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.

    Attributes
    ----------
    dataset : CombinedDataset
        The dataset to sample from.
    num_samples : int
        The number of samples to draw from the combined dataset.
    probs : torch.Tensor
        The probabilities for sampling from each dataset in the combined dataset.
        This is computed from the `ratios` argument and is normalized to sum to 1.
    replacement : bool
        Whether to sample with replacement or not.
    shuffle : bool
        Whether to shuffle the sampled indices or not.
    rank : int
        Rank of the current process within :attr:`num_replicas`.
    num_replicas : int
        Number of processes participating in distributed training.
    drop_last : bool
        Whether to drop samples to make the number of samples evenly divisible by the
        number of replicas in distributed mode.
    seed : int
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.
    epoch : int
        Current epoch number. This is used to set the random seed. This is useful
        in distributed mode to ensure that each process receives a different random
        ordering of the samples.
    total_size : int
        The total number of samples across all processes.
    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        dataset: CombinedDataset,
        ratios: Optional[Sequence[float]] = None,
        num_samples: Optional[int] = None,
        replacement: bool = False,
        shuffle: bool = True,
        rank: Optional[int] = None,
        num_replicas: Optional[int] = None,
        drop_last: bool = False,
        seed: int = 0,
    ):
        if not isinstance(dataset, CombinedDataset):
            raise TypeError(
                "Expected argument `dataset` to be of type `CombinedDataset`, "
                f"but got {type(dataset)}.",
            )
        if not isinstance(seed, int):
            raise TypeError(
                f"Expected argument `seed` to be an integer, but got {type(seed)}.",
            )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.drop_last = drop_last
        self.replacement = replacement
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        self._num_samples = num_samples
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "Expected argument `num_samples` to be a positive integer, but got "
                f"{self.num_samples}.",
            )

        if ratios is None:
            ratios = [len(subset) for subset in self.dataset.datasets]

        num_datasets = len(self.dataset.datasets)
        if len(ratios) != num_datasets:
            raise ValueError(
                f"Expected argument `ratios` to be of length {num_datasets}, "
                f"but got length {len(ratios)}.",
            )
        prob_sum = sum(ratios)
        if not (all(ratio >= 0 for ratio in ratios) and prob_sum > 0):
            raise ValueError(
                "Expected argument `ratios` to be a sequence of non-negative numbers. "
                f"Got {ratios}.",
            )
        self.probs = torch.tensor(
            [ratio / prob_sum for ratio in ratios],
            dtype=torch.double,
        )
        if any((prob * self.num_samples) <= 0 for prob in self.probs):
            raise ValueError(
                "Expected dataset ratio to result in at least one sample per dataset. "
                f"Got dataset sizes {self.probs * self.num_samples}.",
            )

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        # dataset size might change at runtime
        if self._num_samples is None:
            num_samples = len(self.dataset)
        else:
            num_samples = self._num_samples

        if self.drop_last and num_samples % self.num_replicas != 0:
            # split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            num_samples = math.ceil(
                (num_samples - self.num_replicas) / self.num_replicas,
            )
        else:
            num_samples = math.ceil(num_samples / self.num_replicas)
        return num_samples

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return self.num_samples * self.num_replicas

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that yields sample indices for the combined dataset."""
        generator = torch.Generator()
        seed = self.seed + self.epoch
        generator.manual_seed(seed)

        cumulative_sizes = [0] + self.dataset._cumulative_sizes
        num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
        indices = []
        for i in range(len(self.dataset.datasets)):
            per_dataset_indices: torch.Tensor = torch.multinomial(
                torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
                num_samples_per_dataset[i],
                replacement=self.replacement,
                generator=generator,
            )
            # adjust indices to reflect position in cumulative dataset
            per_dataset_indices += cumulative_sizes[i]
            assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
                f"Indices from dataset {i} exceed dataset size. "
                f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
            )
            indices.append(per_dataset_indices)

        indices = torch.cat(indices)
        if self.shuffle:
            rand_indices = torch.randperm(len(indices), generator=generator)
            indices = indices[rand_indices]

        indices = indices.tolist()  # type: ignore[attr-defined]
        num_indices = len(indices)

        if num_indices < self.total_size:
            padding_size = self.total_size - num_indices
            if padding_size <= num_indices:
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / num_indices))[
                    :padding_size
                ]
        elif num_indices > self.total_size:
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples, (
            f"Expected {self.num_samples} samples, but got {len(indices)}.",
        )

        yield from iter(indices)

    def __len__(self) -> int:
        """Return the total number of samples in the sampler."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch

        # some iterable datasets (especially huggingface iterable datasets) might
        # require setting the epoch to ensure shuffling works properly
        for dataset in self.dataset.datasets:
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
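
A minimal wiring sketch: the toy TensorDatasets, the 30/70 ratio and the explicit rank/num_replicas values are all made up for illustration (real mmlearn datasets return Example objects and would typically be collated with the library's collate function), and the sampler import path is assumed from the source file shown above:

import torch
from torch.utils.data import DataLoader, TensorDataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.samplers import CombinedDatasetRatioSampler

# Two toy map-style datasets of different sizes.
ds_a = TensorDataset(torch.arange(100))
ds_b = TensorDataset(torch.arange(1000))
combined = CombinedDataset([ds_a, ds_b])

# Draw roughly 30% of the samples from ds_a and 70% from ds_b. rank and
# num_replicas are passed explicitly because no distributed process group
# is assumed to be initialized in this sketch.
sampler = CombinedDatasetRatioSampler(
    combined,
    ratios=[0.3, 0.7],
    num_samples=512,
    replacement=True,
    rank=0,
    num_replicas=1,
)

loader = DataLoader(combined, sampler=sampler, batch_size=32)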
num_samples property
num_samples

Return the number of samples managed by the sampler.

total_size property
total_size

Return the total size of the dataset.

__iter__
__iter__()

Return an iterator that yields sample indices for the combined dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that yields sample indices for the combined dataset."""
    generator = torch.Generator()
    seed = self.seed + self.epoch
    generator.manual_seed(seed)

    cumulative_sizes = [0] + self.dataset._cumulative_sizes
    num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
    indices = []
    for i in range(len(self.dataset.datasets)):
        per_dataset_indices: torch.Tensor = torch.multinomial(
            torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
            num_samples_per_dataset[i],
            replacement=self.replacement,
            generator=generator,
        )
        # adjust indices to reflect position in cumulative dataset
        per_dataset_indices += cumulative_sizes[i]
        assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
            f"Indices from dataset {i} exceed dataset size. "
            f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
        )
        indices.append(per_dataset_indices)

    indices = torch.cat(indices)
    if self.shuffle:
        rand_indices = torch.randperm(len(indices), generator=generator)
        indices = indices[rand_indices]

    indices = indices.tolist()  # type: ignore[attr-defined]
    num_indices = len(indices)

    if num_indices < self.total_size:
        padding_size = self.total_size - num_indices
        if padding_size <= num_indices:
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / num_indices))[
                :padding_size
            ]
    elif num_indices > self.total_size:
        indices = indices[: self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples, (
        f"Expected {self.num_samples} samples, but got {len(indices)}.",
    )

    yield from iter(indices)
__len__
__len__()

Return the total number of samples in the sampler.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the total number of samples in the sampler."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When :attr:shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

    # some iterable datasets (especially huggingface iterable datasets) might
    # require setting the epoch to ensure shuffling works properly
    for dataset in self.dataset.datasets:
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)
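
In a training loop, the epoch should be set once before each pass over the DataLoader. A short sketch, reusing the `sampler` and `loader` names from the wiring example above and an invented epoch count:

num_epochs = 10  # hypothetical

for epoch in range(num_epochs):
    # Re-seeds the per-epoch sampling and forwards the epoch to any
    # constituent dataset that defines its own `set_epoch`.
    sampler.set_epoch(epoch)
    for batch in loader:
        ...  # training step goes here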
DistributedEvalSampler

Bases: Sampler[int]

Sampler for distributed evaluation.

The main differences between this and :py:class:torch.utils.data.DistributedSampler are that this sampler does not add extra samples to make it evenly divisible and shuffling is disabled by default.

Parameters:

Name Type Description Default
dataset Dataset

Dataset used for sampling.

required
num_replicas Optional[int]

Number of processes participating in distributed training. By default, :attr:num_replicas is retrieved from the current distributed group.

None
rank Optional[int]

Rank of the current process within :attr:num_replicas. By default, :attr:rank is retrieved from the current distributed group.

None
shuffle bool

If True, the sampler will shuffle the indices.

False
seed int

Random seed used to shuffle the sampler if :attr:shuffle=True. This number should be identical across all processes in the distributed group.

0
Warnings

DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever. See [1]_ for details.

Notes
  • This sampler is for evaluation purposes, where synchronization does not happen every epoch. Synchronization should be done outside the dataloader loop. It is especially useful in conjunction with :py:class:torch.nn.parallel.DistributedDataParallel [2]_.
  • The input Dataset is assumed to be of constant size.
  • This implementation is adapted from [3]_.
References

.. [1] https://github.com/pytorch/pytorch/issues/22584
.. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
.. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py

Examples:

>>> def example():
...     start_epoch, n_epochs = 0, 2
...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
...     for epoch in range(start_epoch, n_epochs):
...         if is_distributed:
...             sampler.set_epoch(epoch)
...         evaluate(loader)
Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class DistributedEvalSampler(Sampler[int]):
    """Sampler for distributed evaluation.

    The main differences between this and :py:class:`torch.utils.data.DistributedSampler`
    are that this sampler does not add extra samples to make it evenly divisible and
    shuffling is disabled by default.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset used for sampling.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    shuffle : bool, optional, default=False
        If `True`, the sampler will shuffle the indices.
    seed : int, optional, default=0
        Random seed used to shuffle the sampler if :attr:`shuffle=True`.
        This number should be identical across all processes in the
        distributed group.

    Warnings
    --------
    DistributedEvalSampler should NOT be used for training. The distributed processes
    could hang forever. See [1]_ for details.

    Notes
    -----
    - This sampler is for evaluation purposes, where synchronization does not happen
      every epoch. Synchronization should be done outside the dataloader loop.
      It is especially useful in conjunction with
      :py:class:`torch.nn.parallel.DistributedDataParallel` [2]_.
    - The input Dataset is assumed to be of constant size.
    - This implementation is adapted from [3]_.

    References
    ----------
    .. [1] https://github.com/pytorch/pytorch/issues/22584
    .. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
    .. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py


    Examples
    --------
    >>> def example():
    ...     start_epoch, n_epochs = 0, 2
    ...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
    ...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
    ...     for epoch in range(start_epoch, n_epochs):
    ...         if is_distributed:
    ...             sampler.set_epoch(epoch)
    ...         evaluate(loader)

    """  # noqa: W505

    def __init__(
        self,
        dataset: Dataset[Sized],
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return len(self.dataset)

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        indices = list(range(self.total_size))[
            self.rank : self.total_size : self.num_replicas
        ]
        return len(indices)  # true value without extra samples

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that iterates over the indices of the dataset."""
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(self.total_size, generator=g).tolist()
        else:
            indices = list(range(self.total_size))

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch
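
Because no padding samples are added, each rank may process a different number of examples, so metrics should be reduced across ranks once after the loop rather than per batch (as the Notes above recommend). A hedged sketch of that post-loop synchronization, with the accuracy bookkeeping invented for illustration:

import torch
import torch.distributed as dist

# Per-rank running totals, accumulated inside the evaluation loop shown in the
# Examples section above (move them to the appropriate device for the backend,
# e.g. the local GPU when using NCCL).
correct = torch.zeros(1)
total = torch.zeros(1)

# ... the evaluation loop updates `correct` and `total` on this rank ...

# Synchronize once, outside the dataloader loop.
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
dist.all_reduce(total, op=dist.ReduceOp.SUM)
accuracy = (correct / total).item()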
total_size property
total_size

Return the total size of the dataset.

num_samples property
num_samples

Return the number of samples managed by the sampler.

__iter__
__iter__()

Return an iterator that iterates over the indices of the dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that iterates over the indices of the dataset."""
    if self.shuffle:
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()
    else:
        indices = list(range(self.total_size))

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples

    return iter(indices)
__len__
__len__()

Return the number of samples.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the number of samples."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When :attr:shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

imagenet

ImageNet dataset.

ImageNet

Bases: ImageFolder

ImageNet dataset.

This is a wrapper around the :py:class:~torchvision.datasets.ImageFolder class that returns an :py:class:~mmlearn.datasets.core.example.Example object.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, val)

The split of the dataset to use.

"train"
transform Optional[Callable]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
target_transform Optional[Callable]

A callable that takes in the target and transforms it.

None
mask_generator Optional[Callable]

A callable that generates a mask for the image.

None
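
A minimal usage sketch; the root directory is a placeholder and the torchvision transform pipeline is only illustrative (the import path follows the source file shown below):

from torchvision import transforms

from mmlearn.datasets.core.modalities import Modalities
from mmlearn.datasets.imagenet import ImageNet

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

# "/path/to/imagenet" is a placeholder; the class appends the split name itself.
dataset = ImageNet(root_dir="/path/to/imagenet", split="val", transform=transform)

example = dataset[0]  # an Example holding the image, its target and the example index
print(len(dataset), example[Modalities.RGB.name].shape)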
Source code in mmlearn/datasets/imagenet.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("IMAGENET_ROOT_DIR", MISSING),
)
class ImageNet(ImageFolder):
    """ImageNet dataset.

    This is a wrapper around the :py:class:`~torchvision.datasets.ImageFolder` class
    that returns an :py:class:`~mmlearn.datasets.core.example.Example` object.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "val"}, default="train"
        The split of the dataset to use.
    transform : Optional[Callable], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    target_transform : Optional[Callable], optional, default=None
        A callable that takes in the target and transforms it.
    mask_generator : Optional[Callable], optional, default=None
        A callable that generates a mask for the image.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "val"] = "train",
        transform: Optional[Callable[..., Any]] = None,
        target_transform: Optional[Callable[..., Any]] = None,
        mask_generator: Optional[Callable[..., Any]] = None,
    ) -> None:
        split = "train" if split == "train" else "val"
        root_dir = os.path.join(root_dir, split)
        super().__init__(
            root=root_dir, transform=transform, target_transform=target_transform
        )
        self.mask_generator = mask_generator

    def __getitem__(self, index: int) -> Example:
        """Get an example at the given index."""
        image, target = super().__getitem__(index)
        example = Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: target,
                EXAMPLE_INDEX_KEY: index,
            }
        )
        mask = self.mask_generator() if self.mask_generator else None
        if mask is not None:  # error will be raised during collation if `None`
            example[Modalities.RGB.mask] = mask
        return example

    @property
    def zero_shot_prompt_templates(self) -> list[str]:
        """Return the zero-shot prompt templates."""
        return [
            "a bad photo of a {}.",
            "a photo of many {}.",
            "a sculpture of a {}.",
            "a photo of the hard to see {}.",
            "a low resolution photo of the {}.",
            "a rendering of a {}.",
            "graffiti of a {}.",
            "a bad photo of the {}.",
            "a cropped photo of the {}.",
            "a tattoo of a {}.",
            "the embroidered {}.",
            "a photo of a hard to see {}.",
            "a bright photo of a {}.",
            "a photo of a clean {}.",
            "a photo of a dirty {}.",
            "a dark photo of the {}.",
            "a drawing of a {}.",
            "a photo of my {}.",
            "the plastic {}.",
            "a photo of the cool {}.",
            "a close-up photo of a {}.",
            "a black and white photo of the {}.",
            "a painting of the {}.",
            "a painting of a {}.",
            "a pixelated photo of the {}.",
            "a sculpture of the {}.",
            "a bright photo of the {}.",
            "a cropped photo of a {}.",
            "a plastic {}.",
            "a photo of the dirty {}.",
            "a jpeg corrupted photo of a {}.",
            "a blurry photo of the {}.",
            "a photo of the {}.",
            "a good photo of the {}.",
            "a rendering of the {}.",
            "a {} in a video game.",
            "a photo of one {}.",
            "a doodle of a {}.",
            "a close-up photo of the {}.",
            "a photo of a {}.",
            "the origami {}.",
            "the {} in a video game.",
            "a sketch of a {}.",
            "a doodle of the {}.",
            "a origami {}.",
            "a low resolution photo of a {}.",
            "the toy {}.",
            "a rendition of the {}.",
            "a photo of the clean {}.",
            "a photo of a large {}.",
            "a rendition of a {}.",
            "a photo of a nice {}.",
            "a photo of a weird {}.",
            "a blurry photo of a {}.",
            "a cartoon {}.",
            "art of a {}.",
            "a sketch of the {}.",
            "a embroidered {}.",
            "a pixelated photo of a {}.",
            "itap of the {}.",
            "a jpeg corrupted photo of the {}.",
            "a good photo of a {}.",
            "a plushie {}.",
            "a photo of the nice {}.",
            "a photo of the small {}.",
            "a photo of the weird {}.",
            "the cartoon {}.",
            "art of the {}.",
            "a drawing of the {}.",
            "a photo of the large {}.",
            "a black and white photo of a {}.",
            "the plushie {}.",
            "a dark photo of a {}.",
            "itap of a {}.",
            "graffiti of the {}.",
            "a toy {}.",
            "itap of my {}.",
            "a photo of a cool {}.",
            "a photo of a small {}.",
            "a tattoo of the {}.",
        ]

    @property
    def id2label(self) -> dict[int, str]:
        """Return the label mapping."""
        return {
            0: "tench",
            1: "goldfish",
            2: "great white shark",
            3: "tiger shark",
            4: "hammerhead shark",
            5: "electric ray",
            6: "stingray",
            7: "rooster",
            8: "hen",
            9: "ostrich",
            10: "brambling",
            11: "goldfinch",
            12: "house finch",
            13: "junco",
            14: "indigo bunting",
            15: "American robin",
            16: "bulbul",
            17: "jay",
            18: "magpie",
            19: "chickadee",
            20: "American dipper",
            21: "kite (bird of prey)",
            22: "bald eagle",
            23: "vulture",
            24: "great grey owl",
            25: "fire salamander",
            26: "smooth newt",
            27: "newt",
            28: "spotted salamander",
            29: "axolotl",
            30: "American bullfrog",
            31: "tree frog",
            32: "tailed frog",
            33: "loggerhead sea turtle",
            34: "leatherback sea turtle",
            35: "mud turtle",
            36: "terrapin",
            37: "box turtle",
            38: "banded gecko",
            39: "green iguana",
            40: "Carolina anole",
            41: "desert grassland whiptail lizard",
            42: "agama",
            43: "frilled-necked lizard",
            44: "alligator lizard",
            45: "Gila monster",
            46: "European green lizard",
            47: "chameleon",
            48: "Komodo dragon",
            49: "Nile crocodile",
            50: "American alligator",
            51: "triceratops",
            52: "worm snake",
            53: "ring-necked snake",
            54: "eastern hog-nosed snake",
            55: "smooth green snake",
            56: "kingsnake",
            57: "garter snake",
            58: "water snake",
            59: "vine snake",
            60: "night snake",
            61: "boa constrictor",
            62: "African rock python",
            63: "Indian cobra",
            64: "green mamba",
            65: "sea snake",
            66: "Saharan horned viper",
            67: "eastern diamondback rattlesnake",
            68: "sidewinder rattlesnake",
            69: "trilobite",
            70: "harvestman",
            71: "scorpion",
            72: "yellow garden spider",
            73: "barn spider",
            74: "European garden spider",
            75: "southern black widow",
            76: "tarantula",
            77: "wolf spider",
            78: "tick",
            79: "centipede",
            80: "black grouse",
            81: "ptarmigan",
            82: "ruffed grouse",
            83: "prairie grouse",
            84: "peafowl",
            85: "quail",
            86: "partridge",
            87: "african grey parrot",
            88: "macaw",
            89: "sulphur-crested cockatoo",
            90: "lorikeet",
            91: "coucal",
            92: "bee eater",
            93: "hornbill",
            94: "hummingbird",
            95: "jacamar",
            96: "toucan",
            97: "duck",
            98: "red-breasted merganser",
            99: "goose",
            100: "black swan",
            101: "tusker",
            102: "echidna",
            103: "platypus",
            104: "wallaby",
            105: "koala",
            106: "wombat",
            107: "jellyfish",
            108: "sea anemone",
            109: "brain coral",
            110: "flatworm",
            111: "nematode",
            112: "conch",
            113: "snail",
            114: "slug",
            115: "sea slug",
            116: "chiton",
            117: "chambered nautilus",
            118: "Dungeness crab",
            119: "rock crab",
            120: "fiddler crab",
            121: "red king crab",
            122: "American lobster",
            123: "spiny lobster",
            124: "crayfish",
            125: "hermit crab",
            126: "isopod",
            127: "white stork",
            128: "black stork",
            129: "spoonbill",
            130: "flamingo",
            131: "little blue heron",
            132: "great egret",
            133: "bittern bird",
            134: "crane bird",
            135: "limpkin",
            136: "common gallinule",
            137: "American coot",
            138: "bustard",
            139: "ruddy turnstone",
            140: "dunlin",
            141: "common redshank",
            142: "dowitcher",
            143: "oystercatcher",
            144: "pelican",
            145: "king penguin",
            146: "albatross",
            147: "grey whale",
            148: "killer whale",
            149: "dugong",
            150: "sea lion",
            151: "Chihuahua",
            152: "Japanese Chin",
            153: "Maltese",
            154: "Pekingese",
            155: "Shih Tzu",
            156: "King Charles Spaniel",
            157: "Papillon",
            158: "toy terrier",
            159: "Rhodesian Ridgeback",
            160: "Afghan Hound",
            161: "Basset Hound",
            162: "Beagle",
            163: "Bloodhound",
            164: "Bluetick Coonhound",
            165: "Black and Tan Coonhound",
            166: "Treeing Walker Coonhound",
            167: "English foxhound",
            168: "Redbone Coonhound",
            169: "borzoi",
            170: "Irish Wolfhound",
            171: "Italian Greyhound",
            172: "Whippet",
            173: "Ibizan Hound",
            174: "Norwegian Elkhound",
            175: "Otterhound",
            176: "Saluki",
            177: "Scottish Deerhound",
            178: "Weimaraner",
            179: "Staffordshire Bull Terrier",
            180: "American Staffordshire Terrier",
            181: "Bedlington Terrier",
            182: "Border Terrier",
            183: "Kerry Blue Terrier",
            184: "Irish Terrier",
            185: "Norfolk Terrier",
            186: "Norwich Terrier",
            187: "Yorkshire Terrier",
            188: "Wire Fox Terrier",
            189: "Lakeland Terrier",
            190: "Sealyham Terrier",
            191: "Airedale Terrier",
            192: "Cairn Terrier",
            193: "Australian Terrier",
            194: "Dandie Dinmont Terrier",
            195: "Boston Terrier",
            196: "Miniature Schnauzer",
            197: "Giant Schnauzer",
            198: "Standard Schnauzer",
            199: "Scottish Terrier",
            200: "Tibetan Terrier",
            201: "Australian Silky Terrier",
            202: "Soft-coated Wheaten Terrier",
            203: "West Highland White Terrier",
            204: "Lhasa Apso",
            205: "Flat-Coated Retriever",
            206: "Curly-coated Retriever",
            207: "Golden Retriever",
            208: "Labrador Retriever",
            209: "Chesapeake Bay Retriever",
            210: "German Shorthaired Pointer",
            211: "Vizsla",
            212: "English Setter",
            213: "Irish Setter",
            214: "Gordon Setter",
            215: "Brittany dog",
            216: "Clumber Spaniel",
            217: "English Springer Spaniel",
            218: "Welsh Springer Spaniel",
            219: "Cocker Spaniel",
            220: "Sussex Spaniel",
            221: "Irish Water Spaniel",
            222: "Kuvasz",
            223: "Schipperke",
            224: "Groenendael dog",
            225: "Malinois",
            226: "Briard",
            227: "Australian Kelpie",
            228: "Komondor",
            229: "Old English Sheepdog",
            230: "Shetland Sheepdog",
            231: "collie",
            232: "Border Collie",
            233: "Bouvier des Flandres dog",
            234: "Rottweiler",
            235: "German Shepherd Dog",
            236: "Dobermann",
            237: "Miniature Pinscher",
            238: "Greater Swiss Mountain Dog",
            239: "Bernese Mountain Dog",
            240: "Appenzeller Sennenhund",
            241: "Entlebucher Sennenhund",
            242: "Boxer",
            243: "Bullmastiff",
            244: "Tibetan Mastiff",
            245: "French Bulldog",
            246: "Great Dane",
            247: "St. Bernard",
            248: "husky",
            249: "Alaskan Malamute",
            250: "Siberian Husky",
            251: "Dalmatian",
            252: "Affenpinscher",
            253: "Basenji",
            254: "pug",
            255: "Leonberger",
            256: "Newfoundland dog",
            257: "Great Pyrenees dog",
            258: "Samoyed",
            259: "Pomeranian",
            260: "Chow Chow",
            261: "Keeshond",
            262: "brussels griffon",
            263: "Pembroke Welsh Corgi",
            264: "Cardigan Welsh Corgi",
            265: "Toy Poodle",
            266: "Miniature Poodle",
            267: "Standard Poodle",
            268: "Mexican hairless dog (xoloitzcuintli)",
            269: "grey wolf",
            270: "Alaskan tundra wolf",
            271: "red wolf or maned wolf",
            272: "coyote",
            273: "dingo",
            274: "dhole",
            275: "African wild dog",
            276: "hyena",
            277: "red fox",
            278: "kit fox",
            279: "Arctic fox",
            280: "grey fox",
            281: "tabby cat",
            282: "tiger cat",
            283: "Persian cat",
            284: "Siamese cat",
            285: "Egyptian Mau",
            286: "cougar",
            287: "lynx",
            288: "leopard",
            289: "snow leopard",
            290: "jaguar",
            291: "lion",
            292: "tiger",
            293: "cheetah",
            294: "brown bear",
            295: "American black bear",
            296: "polar bear",
            297: "sloth bear",
            298: "mongoose",
            299: "meerkat",
            300: "tiger beetle",
            301: "ladybug",
            302: "ground beetle",
            303: "longhorn beetle",
            304: "leaf beetle",
            305: "dung beetle",
            306: "rhinoceros beetle",
            307: "weevil",
            308: "fly",
            309: "bee",
            310: "ant",
            311: "grasshopper",
            312: "cricket insect",
            313: "stick insect",
            314: "cockroach",
            315: "praying mantis",
            316: "cicada",
            317: "leafhopper",
            318: "lacewing",
            319: "dragonfly",
            320: "damselfly",
            321: "red admiral butterfly",
            322: "ringlet butterfly",
            323: "monarch butterfly",
            324: "small white butterfly",
            325: "sulphur butterfly",
            326: "gossamer-winged butterfly",
            327: "starfish",
            328: "sea urchin",
            329: "sea cucumber",
            330: "cottontail rabbit",
            331: "hare",
            332: "Angora rabbit",
            333: "hamster",
            334: "porcupine",
            335: "fox squirrel",
            336: "marmot",
            337: "beaver",
            338: "guinea pig",
            339: "common sorrel horse",
            340: "zebra",
            341: "pig",
            342: "wild boar",
            343: "warthog",
            344: "hippopotamus",
            345: "ox",
            346: "water buffalo",
            347: "bison",
            348: "ram (adult male sheep)",
            349: "bighorn sheep",
            350: "Alpine ibex",
            351: "hartebeest",
            352: "impala (antelope)",
            353: "gazelle",
            354: "arabian camel",
            355: "llama",
            356: "weasel",
            357: "mink",
            358: "European polecat",
            359: "black-footed ferret",
            360: "otter",
            361: "skunk",
            362: "badger",
            363: "armadillo",
            364: "three-toed sloth",
            365: "orangutan",
            366: "gorilla",
            367: "chimpanzee",
            368: "gibbon",
            369: "siamang",
            370: "guenon",
            371: "patas monkey",
            372: "baboon",
            373: "macaque",
            374: "langur",
            375: "black-and-white colobus",
            376: "proboscis monkey",
            377: "marmoset",
            378: "white-headed capuchin",
            379: "howler monkey",
            380: "titi monkey",
            381: "Geoffroy's spider monkey",
            382: "common squirrel monkey",
            383: "ring-tailed lemur",
            384: "indri",
            385: "Asian elephant",
            386: "African bush elephant",
            387: "red panda",
            388: "giant panda",
            389: "snoek fish",
            390: "eel",
            391: "silver salmon",
            392: "rock beauty fish",
            393: "clownfish",
            394: "sturgeon",
            395: "gar fish",
            396: "lionfish",
            397: "pufferfish",
            398: "abacus",
            399: "abaya",
            400: "academic gown",
            401: "accordion",
            402: "acoustic guitar",
            403: "aircraft carrier",
            404: "airliner",
            405: "airship",
            406: "altar",
            407: "ambulance",
            408: "amphibious vehicle",
            409: "analog clock",
            410: "apiary",
            411: "apron",
            412: "trash can",
            413: "assault rifle",
            414: "backpack",
            415: "bakery",
            416: "balance beam",
            417: "balloon",
            418: "ballpoint pen",
            419: "Band-Aid",
            420: "banjo",
            421: "baluster / handrail",
            422: "barbell",
            423: "barber chair",
            424: "barbershop",
            425: "barn",
            426: "barometer",
            427: "barrel",
            428: "wheelbarrow",
            429: "baseball",
            430: "basketball",
            431: "bassinet",
            432: "bassoon",
            433: "swimming cap",
            434: "bath towel",
            435: "bathtub",
            436: "station wagon",
            437: "lighthouse",
            438: "beaker",
            439: "military hat (bearskin or shako)",
            440: "beer bottle",
            441: "beer glass",
            442: "bell tower",
            443: "baby bib",
            444: "tandem bicycle",
            445: "bikini",
            446: "ring binder",
            447: "binoculars",
            448: "birdhouse",
            449: "boathouse",
            450: "bobsleigh",
            451: "bolo tie",
            452: "poke bonnet",
            453: "bookcase",
            454: "bookstore",
            455: "bottle cap",
            456: "hunting bow",
            457: "bow tie",
            458: "brass memorial plaque",
            459: "bra",
            460: "breakwater",
            461: "breastplate",
            462: "broom",
            463: "bucket",
            464: "buckle",
            465: "bulletproof vest",
            466: "high-speed train",
            467: "butcher shop",
            468: "taxicab",
            469: "cauldron",
            470: "candle",
            471: "cannon",
            472: "canoe",
            473: "can opener",
            474: "cardigan",
            475: "car mirror",
            476: "carousel",
            477: "tool kit",
            478: "cardboard box / carton",
            479: "car wheel",
            480: "automated teller machine",
            481: "cassette",
            482: "cassette player",
            483: "castle",
            484: "catamaran",
            485: "CD player",
            486: "cello",
            487: "mobile phone",
            488: "chain",
            489: "chain-link fence",
            490: "chain mail",
            491: "chainsaw",
            492: "storage chest",
            493: "chiffonier",
            494: "bell or wind chime",
            495: "china cabinet",
            496: "Christmas stocking",
            497: "church",
            498: "movie theater",
            499: "cleaver",
            500: "cliff dwelling",
            501: "cloak",
            502: "clogs",
            503: "cocktail shaker",
            504: "coffee mug",
            505: "coffeemaker",
            506: "spiral or coil",
            507: "combination lock",
            508: "computer keyboard",
            509: "candy store",
            510: "container ship",
            511: "convertible",
            512: "corkscrew",
            513: "cornet",
            514: "cowboy boot",
            515: "cowboy hat",
            516: "cradle",
            517: "construction crane",
            518: "crash helmet",
            519: "crate",
            520: "infant bed",
            521: "Crock Pot",
            522: "croquet ball",
            523: "crutch",
            524: "cuirass",
            525: "dam",
            526: "desk",
            527: "desktop computer",
            528: "rotary dial telephone",
            529: "diaper",
            530: "digital clock",
            531: "digital watch",
            532: "dining table",
            533: "dishcloth",
            534: "dishwasher",
            535: "disc brake",
            536: "dock",
            537: "dog sled",
            538: "dome",
            539: "doormat",
            540: "drilling rig",
            541: "drum",
            542: "drumstick",
            543: "dumbbell",
            544: "Dutch oven",
            545: "electric fan",
            546: "electric guitar",
            547: "electric locomotive",
            548: "entertainment center",
            549: "envelope",
            550: "espresso machine",
            551: "face powder",
            552: "feather boa",
            553: "filing cabinet",
            554: "fireboat",
            555: "fire truck",
            556: "fire screen",
            557: "flagpole",
            558: "flute",
            559: "folding chair",
            560: "football helmet",
            561: "forklift",
            562: "fountain",
            563: "fountain pen",
            564: "four-poster bed",
            565: "freight car",
            566: "French horn",
            567: "frying pan",
            568: "fur coat",
            569: "garbage truck",
            570: "gas mask or respirator",
            571: "gas pump",
            572: "goblet",
            573: "go-kart",
            574: "golf ball",
            575: "golf cart",
            576: "gondola",
            577: "gong",
            578: "gown",
            579: "grand piano",
            580: "greenhouse",
            581: "radiator grille",
            582: "grocery store",
            583: "guillotine",
            584: "hair clip",
            585: "hair spray",
            586: "half-track",
            587: "hammer",
            588: "hamper",
            589: "hair dryer",
            590: "hand-held computer",
            591: "handkerchief",
            592: "hard disk drive",
            593: "harmonica",
            594: "harp",
            595: "combine harvester",
            596: "hatchet",
            597: "holster",
            598: "home theater",
            599: "honeycomb",
            600: "hook",
            601: "hoop skirt",
            602: "gymnastic horizontal bar",
            603: "horse-drawn vehicle",
            604: "hourglass",
            605: "iPod",
            606: "clothes iron",
            607: "carved pumpkin",
            608: "jeans",
            609: "jeep",
            610: "T-shirt",
            611: "jigsaw puzzle",
            612: "rickshaw",
            613: "joystick",
            614: "kimono",
            615: "knee pad",
            616: "knot",
            617: "lab coat",
            618: "ladle",
            619: "lampshade",
            620: "laptop computer",
            621: "lawn mower",
            622: "lens cap",
            623: "letter opener",
            624: "library",
            625: "lifeboat",
            626: "lighter",
            627: "limousine",
            628: "ocean liner",
            629: "lipstick",
            630: "slip-on shoe",
            631: "lotion",
            632: "music speaker",
            633: "loupe magnifying glass",
            634: "sawmill",
            635: "magnetic compass",
            636: "messenger bag",
            637: "mailbox",
            638: "tights",
            639: "one-piece bathing suit",
            640: "manhole cover",
            641: "maraca",
            642: "marimba",
            643: "mask",
            644: "matchstick",
            645: "maypole",
            646: "maze",
            647: "measuring cup",
            648: "medicine cabinet",
            649: "megalith",
            650: "microphone",
            651: "microwave oven",
            652: "military uniform",
            653: "milk can",
            654: "minibus",
            655: "miniskirt",
            656: "minivan",
            657: "missile",
            658: "mitten",
            659: "mixing bowl",
            660: "mobile home",
            661: "ford model t",
            662: "modem",
            663: "monastery",
            664: "monitor",
            665: "moped",
            666: "mortar and pestle",
            667: "graduation cap",
            668: "mosque",
            669: "mosquito net",
            670: "vespa",
            671: "mountain bike",
            672: "tent",
            673: "computer mouse",
            674: "mousetrap",
            675: "moving van",
            676: "muzzle",
            677: "metal nail",
            678: "neck brace",
            679: "necklace",
            680: "baby pacifier",
            681: "notebook computer",
            682: "obelisk",
            683: "oboe",
            684: "ocarina",
            685: "odometer",
            686: "oil filter",
            687: "pipe organ",
            688: "oscilloscope",
            689: "overskirt",
            690: "bullock cart",
            691: "oxygen mask",
            692: "product packet / packaging",
            693: "paddle",
            694: "paddle wheel",
            695: "padlock",
            696: "paintbrush",
            697: "pajamas",
            698: "palace",
            699: "pan flute",
            700: "paper towel",
            701: "parachute",
            702: "parallel bars",
            703: "park bench",
            704: "parking meter",
            705: "railroad car",
            706: "patio",
            707: "payphone",
            708: "pedestal",
            709: "pencil case",
            710: "pencil sharpener",
            711: "perfume",
            712: "Petri dish",
            713: "photocopier",
            714: "plectrum",
            715: "Pickelhaube",
            716: "picket fence",
            717: "pickup truck",
            718: "pier",
            719: "piggy bank",
            720: "pill bottle",
            721: "pillow",
            722: "ping-pong ball",
            723: "pinwheel",
            724: "pirate ship",
            725: "drink pitcher",
            726: "block plane",
            727: "planetarium",
            728: "plastic bag",
            729: "plate rack",
            730: "farm plow",
            731: "plunger",
            732: "Polaroid camera",
            733: "pole",
            734: "police van",
            735: "poncho",
            736: "pool table",
            737: "soda bottle",
            738: "plant pot",
            739: "potter's wheel",
            740: "power drill",
            741: "prayer rug",
            742: "printer",
            743: "prison",
            744: "missile",
            745: "projector",
            746: "hockey puck",
            747: "punching bag",
            748: "purse",
            749: "quill",
            750: "quilt",
            751: "race car",
            752: "racket",
            753: "radiator",
            754: "radio",
            755: "radio telescope",
            756: "rain barrel",
            757: "recreational vehicle",
            758: "fishing casting reel",
            759: "reflex camera",
            760: "refrigerator",
            761: "remote control",
            762: "restaurant",
            763: "revolver",
            764: "rifle",
            765: "rocking chair",
            766: "rotisserie",
            767: "eraser",
            768: "rugby ball",
            769: "ruler measuring stick",
            770: "sneaker",
            771: "safe",
            772: "safety pin",
            773: "salt shaker",
            774: "sandal",
            775: "sarong",
            776: "saxophone",
            777: "scabbard",
            778: "weighing scale",
            779: "school bus",
            780: "schooner",
            781: "scoreboard",
            782: "CRT monitor",
            783: "screw",
            784: "screwdriver",
            785: "seat belt",
            786: "sewing machine",
            787: "shield",
            788: "shoe store",
            789: "shoji screen / room divider",
            790: "shopping basket",
            791: "shopping cart",
            792: "shovel",
            793: "shower cap",
            794: "shower curtain",
            795: "ski",
            796: "balaclava ski mask",
            797: "sleeping bag",
            798: "slide rule",
            799: "sliding door",
            800: "slot machine",
            801: "snorkel",
            802: "snowmobile",
            803: "snowplow",
            804: "soap dispenser",
            805: "soccer ball",
            806: "sock",
            807: "solar thermal collector",
            808: "sombrero",
            809: "soup bowl",
            810: "keyboard space bar",
            811: "space heater",
            812: "space shuttle",
            813: "spatula",
            814: "motorboat",
            815: "spider web",
            816: "spindle",
            817: "sports car",
            818: "spotlight",
            819: "stage",
            820: "steam locomotive",
            821: "through arch bridge",
            822: "steel drum",
            823: "stethoscope",
            824: "scarf",
            825: "stone wall",
            826: "stopwatch",
            827: "stove",
            828: "strainer",
            829: "tram",
            830: "stretcher",
            831: "couch",
            832: "stupa",
            833: "submarine",
            834: "suit",
            835: "sundial",
            836: "sunglasses",
            837: "sunglasses",
            838: "sunscreen",
            839: "suspension bridge",
            840: "mop",
            841: "sweatshirt",
            842: "swim trunks / shorts",
            843: "swing",
            844: "electrical switch",
            845: "syringe",
            846: "table lamp",
            847: "tank",
            848: "tape player",
            849: "teapot",
            850: "teddy bear",
            851: "television",
            852: "tennis ball",
            853: "thatched roof",
            854: "front curtain",
            855: "thimble",
            856: "threshing machine",
            857: "throne",
            858: "tile roof",
            859: "toaster",
            860: "tobacco shop",
            861: "toilet seat",
            862: "torch",
            863: "totem pole",
            864: "tow truck",
            865: "toy store",
            866: "tractor",
            867: "semi-trailer truck",
            868: "tray",
            869: "trench coat",
            870: "tricycle",
            871: "trimaran",
            872: "tripod",
            873: "triumphal arch",
            874: "trolleybus",
            875: "trombone",
            876: "hot tub",
            877: "turnstile",
            878: "typewriter keyboard",
            879: "umbrella",
            880: "unicycle",
            881: "upright piano",
            882: "vacuum cleaner",
            883: "vase",
            884: "vaulted or arched ceiling",
            885: "velvet fabric",
            886: "vending machine",
            887: "vestment",
            888: "viaduct",
            889: "violin",
            890: "volleyball",
            891: "waffle iron",
            892: "wall clock",
            893: "wallet",
            894: "wardrobe",
            895: "military aircraft",
            896: "sink",
            897: "washing machine",
            898: "water bottle",
            899: "water jug",
            900: "water tower",
            901: "whiskey jug",
            902: "whistle",
            903: "hair wig",
            904: "window screen",
            905: "window shade",
            906: "Windsor tie",
            907: "wine bottle",
            908: "airplane wing",
            909: "wok",
            910: "wooden spoon",
            911: "wool",
            912: "split-rail fence",
            913: "shipwreck",
            914: "sailboat",
            915: "yurt",
            916: "website",
            917: "comic book",
            918: "crossword",
            919: "traffic or street sign",
            920: "traffic light",
            921: "dust jacket",
            922: "menu",
            923: "plate",
            924: "guacamole",
            925: "consomme",
            926: "hot pot",
            927: "trifle",
            928: "ice cream",
            929: "popsicle",
            930: "baguette",
            931: "bagel",
            932: "pretzel",
            933: "cheeseburger",
            934: "hot dog",
            935: "mashed potatoes",
            936: "cabbage",
            937: "broccoli",
            938: "cauliflower",
            939: "zucchini",
            940: "spaghetti squash",
            941: "acorn squash",
            942: "butternut squash",
            943: "cucumber",
            944: "artichoke",
            945: "bell pepper",
            946: "cardoon",
            947: "mushroom",
            948: "Granny Smith apple",
            949: "strawberry",
            950: "orange",
            951: "lemon",
            952: "fig",
            953: "pineapple",
            954: "banana",
            955: "jackfruit",
            956: "cherimoya (custard apple)",
            957: "pomegranate",
            958: "hay",
            959: "carbonara",
            960: "chocolate syrup",
            961: "dough",
            962: "meatloaf",
            963: "pizza",
            964: "pot pie",
            965: "burrito",
            966: "red wine",
            967: "espresso",
            968: "tea cup",
            969: "eggnog",
            970: "mountain",
            971: "bubble",
            972: "cliff",
            973: "coral reef",
            974: "geyser",
            975: "lakeshore",
            976: "promontory",
            977: "sandbar",
            978: "beach",
            979: "valley",
            980: "volcano",
            981: "baseball player",
            982: "bridegroom",
            983: "scuba diver",
            984: "rapeseed",
            985: "daisy",
            986: "yellow lady's slipper",
            987: "corn",
            988: "acorn",
            989: "rose hip",
            990: "horse chestnut seed",
            991: "coral fungus",
            992: "agaric",
            993: "gyromitra",
            994: "stinkhorn mushroom",
            995: "earth star fungus",
            996: "hen of the woods mushroom",
            997: "bolete",
            998: "corn cob",
            999: "toilet paper",
        }
zero_shot_prompt_templates property
zero_shot_prompt_templates

Return the zero-shot prompt templates.

id2label property
id2label

Return the label mapping.

__getitem__
__getitem__(index)

Get an example at the given index.

Source code in mmlearn/datasets/imagenet.py
def __getitem__(self, index: int) -> Example:
    """Get an example at the given index."""
    image, target = super().__getitem__(index)
    example = Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: target,
            EXAMPLE_INDEX_KEY: index,
        }
    )
    mask = self.mask_generator() if self.mask_generator else None
    if mask is not None:  # error will be raised during collation if `None`
        example[Modalities.RGB.mask] = mask
    return example
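
Example (illustrative sketch, not part of the library source). It shows how the id2label mapping above can be combined with prompt templates for CLIP-style zero-shot classification; the template strings below are placeholders and not necessarily the ones returned by zero_shot_prompt_templates.

# Illustrative only: the class names come from the `id2label` mapping above; the
# templates here are placeholders and may differ from `zero_shot_prompt_templates`.
example_templates = ["a photo of a {}.", "a blurry photo of a {}."]
class_names = ["tiger", "grand piano", "espresso"]

prompts = [template.format(name) for template in example_templates for name in class_names]
print(prompts[0])  # "a photo of a tiger."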

librispeech

LibriSpeech dataset.

LibriSpeech

Bases: Dataset[Example]

LibriSpeech dataset.

This is a wrapper around torchaudio.datasets.LIBRISPEECH that assumes the dataset is already downloaded and that the top-level directory of the dataset inside the root directory is named librispeech.

Parameters:

Name Type Description Default
root_dir str

Root directory of dataset.

required
split (train-clean-100, train-clean-360, train-other-500, dev-clean, dev-other, test-clean, test-other)

Split of the dataset to use.

"train-clean-100"

Raises:

Type Description
ImportError

If torchaudio is not installed.

Notes

This dataset only returns the audio and transcript from the dataset.

Source code in mmlearn/datasets/librispeech.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("LIBRISPEECH_ROOT_DIR", MISSING),
)
class LibriSpeech(Dataset[Example]):
    """LibriSpeech dataset.

    This is a wrapper around :py:class:`torchaudio.datasets.LIBRISPEECH` that assumes
    that the dataset is already downloaded and the top-level directory of the dataset
    in the root directory is `librispeech`.

    Parameters
    ----------
    root_dir : str
        Root directory of dataset.
    split : {"train-clean-100", "train-clean-360", "train-other-500", "dev-clean", "dev-other", "test-clean", "test-other"}, default="train-clean-100"
        Split of the dataset to use.

    Raises
    ------
    ImportError
        If ``torchaudio`` is not installed.

    Notes
    -----
    This dataset only returns the audio and transcript from the dataset.

    """  # noqa: W505

    def __init__(self, root_dir: str, split: str = "train-clean-100") -> None:
        super().__init__()
        if not _TORCHAUDIO_AVAILABLE:
            raise ImportError(
                "LibriSpeech dataset requires `torchaudio`, which is not installed."
            )
        from torchaudio.datasets import LIBRISPEECH

        self.dataset = LIBRISPEECH(
            root=root_dir,
            url=split,
            download=False,
            folder_in_archive="librispeech",
        )

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the dataset."""
        waveform, sample_rate, transcript, _, _, _ = self.dataset[idx]
        assert sample_rate == SAMPLE_RATE, (
            f"Expected sample rate to be `16000`, got {sample_rate}."
        )
        waveform = pad_or_trim(waveform.flatten())

        return Example(
            {
                Modalities.AUDIO.name: waveform,
                Modalities.TEXT.name: transcript,
                EXAMPLE_INDEX_KEY: idx,
            },
        )
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/librispeech.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.dataset)
__getitem__
__getitem__(idx)

Return an example from the dataset.

Source code in mmlearn/datasets/librispeech.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the dataset."""
    waveform, sample_rate, transcript, _, _, _ = self.dataset[idx]
    assert sample_rate == SAMPLE_RATE, (
        f"Expected sample rate to be `16000`, got {sample_rate}."
    )
    waveform = pad_or_trim(waveform.flatten())

    return Example(
        {
            Modalities.AUDIO.name: waveform,
            Modalities.TEXT.name: transcript,
            EXAMPLE_INDEX_KEY: idx,
        },
    )
pad_or_trim
pad_or_trim(array, length=30 * SAMPLE_RATE, *, axis=-1)

Pad or trim the audio array to length along the given axis.

Parameters:

Name Type Description Default
array Tensor

Audio array.

required
length int

Length to pad or trim to. Defaults to 30 seconds at 16 kHz.

480000
axis int

Axis along which to pad or trim.

-1

Returns:

Name Type Description
array Tensor

Padded or trimmed audio array.

References

.. [1] https://github.com/openai/whisper/blob/main/whisper/audio.py#L65C1-L88C17

Source code in mmlearn/datasets/librispeech.py
def pad_or_trim(
    array: torch.Tensor, length: int = 30 * SAMPLE_RATE, *, axis: int = -1
) -> torch.Tensor:
    """Pad or trim the audio array to `length` along the given axis.

    Parameters
    ----------
    array : torch.Tensor
        Audio array.
    length : int, default=480000
        Length to pad or trim to. Defaults to 30 seconds at 16 kHz.
    axis : int, default=-1
        Axis along which to pad or trim.

    Returns
    -------
    array : torch.Tensor
        Padded or trimmed audio array.

    References
    ----------
    .. [1] https://github.com/openai/whisper/blob/main/whisper/audio.py#L65C1-L88C17

    """
    if array.shape[axis] > length:
        array = array.index_select(
            dim=axis,
            index=torch.arange(length, device=array.device),
        )

    if array.shape[axis] < length:
        pad_widths = [(0, 0)] * array.ndim
        pad_widths[axis] = (0, length - array.shape[axis])
        array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])

    return array
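
The padding/trimming behaviour can be checked on synthetic tensors. This small sketch only assumes that pad_or_trim and the module-level SAMPLE_RATE constant (16 000 Hz) are importable from mmlearn/datasets/librispeech.py.

import torch

from mmlearn.datasets.librispeech import SAMPLE_RATE, pad_or_trim  # SAMPLE_RATE assumed importable

short_clip = torch.randn(5 * SAMPLE_RATE)   # 5 seconds of audio
long_clip = torch.randn(45 * SAMPLE_RATE)   # 45 seconds of audio

# Both come out at exactly 30 seconds: zero-padded or trimmed along the last axis.
assert pad_or_trim(short_clip).shape[-1] == 30 * SAMPLE_RATE
assert pad_or_trim(long_clip).shape[-1] == 30 * SAMPLE_RATE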

llvip

LLVIP dataset.

LLVIPDataset

Bases: Dataset[Example]

Low-Light Visible-Infrared Pair (LLVIP) dataset.

Loads pairs of RGB and THERMAL images from the LLVIP dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset. The directory should contain 'visible' and 'infrared' subdirectories.

required
train bool

Flag to indicate whether to load the training or test set.

True
transform Optional[Callable[[Image], Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor. This is applied to both RGB and thermal images.

None
Source code in mmlearn/datasets/llvip.py
@store(
    name="LLVIP",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("LLVIP_ROOT_DIR", MISSING),
)
class LLVIPDataset(Dataset[Example]):
    """Low-Light Visible-Infrared Pair (LLVIP) dataset.

    Loads pairs of `RGB` and `THERMAL` images from the LLVIP dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset. The directory should contain
        'visible' and 'infrared' subdirectories.
    train : bool, default=True
        Flag to indicate whether to load the training or test set.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor. This is applied to both RGB and thermal
        images.
    """

    def __init__(
        self,
        root_dir: str,
        train: bool = True,
        transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ):
        self.path_images_rgb = os.path.join(
            root_dir,
            "visible",
            "train" if train else "test",
        )
        self.path_images_ir = os.path.join(
            root_dir, "infrared", "train" if train else "test"
        )
        self.train = train
        self.transform = transform or transforms.ToTensor()

        self.rgb_images = sorted(glob.glob(os.path.join(self.path_images_rgb, "*.jpg")))
        self.ir_images = sorted(glob.glob(os.path.join(self.path_images_ir, "*.jpg")))

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.rgb_images)

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the dataset."""
        rgb_image_path = self.rgb_images[idx]
        ir_image_path = self.ir_images[idx]

        rgb_image = PILImage.open(rgb_image_path).convert("RGB")
        ir_image = PILImage.open(ir_image_path).convert("L")

        example = Example(
            {
                Modalities.RGB.name: self.transform(rgb_image),
                Modalities.THERMAL.name: self.transform(ir_image),
                EXAMPLE_INDEX_KEY: idx,
            },
        )

        if self.train:
            annot_path = (
                rgb_image_path.replace("visible", "Annotations")
                .replace(".jpg", ".xml")
                .replace("train", "")
            )
            annot = self._get_bbox(annot_path)
            example["annotation"] = {
                "bboxes": torch.from_numpy(annot["bboxes"]),
                "labels": torch.from_numpy(annot["labels"]),
            }
        return example

    def _get_bbox(self, filename: str) -> dict[str, np.ndarray]:
        """Parse the XML file to get bounding boxes and labels.

        Parameters
        ----------
        filename : str
            Path to the annotation XML file.

        Returns
        -------
        dict
            A dictionary containing bounding boxes and labels.
        """
        try:
            root = ET.parse(filename).getroot()

            bboxes, labels = [], []
            for obj in root.findall("object"):
                bbox_obj = obj.find("bndbox")
                bbox = [
                    int(bbox_obj.find(dim).text)  # type: ignore[union-attr,arg-type]
                    for dim in ["xmin", "ymin", "xmax", "ymax"]
                ]
                bboxes.append(bbox)
                labels.append(1)  # Assuming 'person' is the only label
            return {
                "bboxes": np.array(bboxes).astype("float"),
                "labels": np.array(labels).astype("int"),
            }
        except ET.ParseError as e:
            raise ValueError(f"Error parsing XML: {e}") from None
        except Exception as e:
            raise RuntimeError(
                f"Error processing annotation file {filename}: {e}",
            ) from None
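
Example usage (illustrative sketch, not part of the library source). The root directory is a placeholder; it must contain the visible/ and infrared/ folders described above, and the Modalities import path is assumed.

from torchvision import transforms

from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.datasets.llvip import LLVIPDataset

dataset = LLVIPDataset(
    root_dir="/data/LLVIP",  # placeholder path
    train=True,
    transform=transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()]),
)

example = dataset[0]
rgb = example[Modalities.RGB.name]          # 3 x 256 x 256 tensor
thermal = example[Modalities.THERMAL.name]  # 1 x 256 x 256 tensor
bboxes = example["annotation"]["bboxes"]    # person bounding boxes, train split only
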
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/llvip.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.rgb_images)
__getitem__
__getitem__(idx)

Return an example from the dataset.

Source code in mmlearn/datasets/llvip.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the dataset."""
    rgb_image_path = self.rgb_images[idx]
    ir_image_path = self.ir_images[idx]

    rgb_image = PILImage.open(rgb_image_path).convert("RGB")
    ir_image = PILImage.open(ir_image_path).convert("L")

    example = Example(
        {
            Modalities.RGB.name: self.transform(rgb_image),
            Modalities.THERMAL.name: self.transform(ir_image),
            EXAMPLE_INDEX_KEY: idx,
        },
    )

    if self.train:
        annot_path = (
            rgb_image_path.replace("visible", "Annotations")
            .replace(".jpg", ".xml")
            .replace("train", "")
        )
        annot = self._get_bbox(annot_path)
        example["annotation"] = {
            "bboxes": torch.from_numpy(annot["bboxes"]),
            "labels": torch.from_numpy(annot["labels"]),
        }
    return example

nihcxr

NIH Chest X-ray Dataset.

NIHCXR

Bases: Dataset[Example]

NIH Chest X-ray dataset.

Parameters:

Name Type Description Default
root_dir str

Directory that contains the .json files listing all dataset entries.

required
split (train, test, bbox)

Dataset split. "bbox" is a subset of "test" which contains bounding box info.

"train"
transform Optional[Callable[[Image], Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
Source code in mmlearn/datasets/nihcxr.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("NIH_CXR_DIR", MISSING),
    split="train",
)
class NIHCXR(Dataset[Example]):
    """NIH Chest X-ray dataset.

    Parameters
    ----------
    root_dir : str
        Directory that contains the `.json` files listing all dataset entries.
    split : {"train", "test", "bbox"}
        Dataset split. "bbox" is a subset of "test" which contains bounding box info.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test", "bbox"],
        transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
    ) -> None:
        assert split in ["train", "test", "bbox"], f"split {split} is not available."
        assert callable(transform) or transform is None, (
            "transform is not callable or None."
        )

        data_path = os.path.join(root_dir, split + "_data.json")

        assert os.path.isfile(data_path), f"entries file does not exist: {data_path}."

        with open(data_path, "rb") as file:
            entries = json.load(file)
        self.entries = entries

        if transform is not None:
            self.transform = transform
        else:
            self.transform = Compose([Resize(224), CenterCrop(224), ToTensor()])

        self.bbox = split == "bbox"

    def __getitem__(self, idx: int) -> Example:
        """Return image-label or image-label-tabular(bbox)."""
        entry = self.entries[idx]
        image = Image.open(entry["image_path"]).convert("RGB")
        image = self.transform(image)
        label = torch.tensor(entry["label"])

        example = Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: label,
                "qid": entry["qid"],
                EXAMPLE_INDEX_KEY: idx,
            }
        )

        if self.bbox:
            example["bbox"] = entry["bbox"]

        return example

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.entries)
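
Example usage (illustrative sketch, not part of the library source). The root directory is a placeholder and must contain the <split>_data.json entry files; the Modalities import path is assumed.

from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.datasets.nihcxr import NIHCXR

dataset = NIHCXR(root_dir="/data/nihcxr", split="bbox")  # placeholder path

example = dataset[0]
image = example[Modalities.RGB.name]    # 3 x 224 x 224 tensor (default transform)
label = example[Modalities.RGB.target]  # label tensor from the entries file
bbox = example["bbox"]                  # only present for the "bbox" split
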
__getitem__
__getitem__(idx)

Return image-label or image-label-tabular(bbox).

Source code in mmlearn/datasets/nihcxr.py
def __getitem__(self, idx: int) -> Example:
    """Return image-label or image-label-tabular(bbox)."""
    entry = self.entries[idx]
    image = Image.open(entry["image_path"]).convert("RGB")
    image = self.transform(image)
    label = torch.tensor(entry["label"])

    example = Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: label,
            "qid": entry["qid"],
            EXAMPLE_INDEX_KEY: idx,
        }
    )

    if self.bbox:
        example["bbox"] = entry["bbox"]

    return example
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/nihcxr.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.entries)

nyuv2

NYUv2 dataset.

NYUv2Dataset

Bases: Dataset[Example]

NYUv2 dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, test)

Split of the dataset to use.

"train"
return_type (disparity, image)

Return type of the depth images.

  • "disparity": Return the depth image as disparity map.
  • "image": Return the depth image as a 3-channel image.
"disparity"
rgb_transform Optional[Callable[[Image], Tensor]]

A callable that takes in an RGB PIL image and returns a transformed version of the image as a PyTorch tensor.

None
depth_transform Optional[Callable[[Image], Tensor]]

A callable that takes in a depth PIL image and returns a transformed version of the image as a PyTorch tensor.

None

Raises:

Type Description
ImportError

If opencv-python is not installed.

Source code in mmlearn/datasets/nyuv2.py
@store(
    name="NYUv2",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("NYUV2_ROOT_DIR", MISSING),
)
class NYUv2Dataset(Dataset[Example]):
    """NYUv2 dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "test"}, default="train"
        Split of the dataset to use.
    return_type : {"disparity", "image"}, default="disparity"
        Return type of the depth images.

        - `"disparity"`: Return the depth image as disparity map.
        - `"image"`: Return the depth image as a 3-channel image.
    rgb_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in an RGB PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    depth_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in a depth PIL image and returns a transformed version
        of the image as a PyTorch tensor.

    Raises
    ------
    ImportError
        If `opencv-python` is not installed.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test"] = "train",
        return_type: Literal["disparity", "image"] = "disparity",
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ) -> None:
        super().__init__()
        if not _OPENCV_AVAILABLE:
            raise ImportError(
                "NYUv2 dataset requires `opencv-python` which is not installed.",
            )
        self._validate_args(root_dir, split, rgb_transform, depth_transform)
        self.return_type = return_type

        self.root_dir = root_dir
        with open(os.path.join(root_dir, f"{split}.txt"), "r") as f:
            file_ids = f.readlines()
        file_ids = [f.strip() for f in file_ids]

        root_dir = os.path.join(root_dir, split)
        depth_files = [os.path.join(root_dir, "depth", f"{f}.png") for f in file_ids]
        rgb_files = [os.path.join(root_dir, "rgb", f"{f}.png") for f in file_ids]

        label_files = [
            os.path.join(root_dir, "scene_class", f"{f}.txt") for f in file_ids
        ]
        labels = [str(open(f).read().strip()) for f in label_files]  # noqa: SIM115
        labels = [label.replace("_", " ") for label in labels]
        labels = [
            _LABELS.index(label) if label in _LABELS else len(_LABELS)  # type: ignore
            for label in labels
        ]

        # remove the samples with classes not in _LABELS
        # this is to follow the same classes used in ImageBind
        if split == "test":
            valid_indices = [
                i
                for i, label in enumerate(labels)
                if label < len(_LABELS)  # type: ignore
            ]
            rgb_files = [rgb_files[i] for i in valid_indices]
            depth_files = [depth_files[i] for i in valid_indices]
            labels = [labels[i] for i in valid_indices]

        self.samples = list(zip(rgb_files, depth_files, labels, strict=False))

        self.rgb_transform = rgb_transform
        self.depth_transform = depth_transform

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.samples)

    def _validate_args(
        self,
        root_dir: str,
        split: str,
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]],
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]],
    ) -> None:
        """Validate arguments."""
        if not os.path.isdir(root_dir):
            raise NotADirectoryError(
                f"The given `root_dir` {root_dir} is not a directory",
            )
        if split not in ["train", "test"]:
            raise ValueError(
                f"Expected `split` to be one of `'train'` or `'test'`, but got {split}",
            )
        if rgb_transform is not None and not callable(rgb_transform):
            raise TypeError(
                f"Expected argument `rgb_transform` to be callable, but got {type(rgb_transform)}",
            )
        if depth_transform is not None and not callable(depth_transform):
            raise TypeError(
                f"Expected `depth_transform` to be callable, but got {type(depth_transform)}",
            )

    def __getitem__(self, idx: int) -> Example:
        """Return RGB and depth images at index `idx`."""
        # Read images
        rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
        if self.rgb_transform is not None:
            rgb_image = self.rgb_transform(to_pil_image(rgb_image))

        if self.return_type == "disparity":
            depth_image = depth_normalize(
                self.samples[idx][1],
            )
        else:
            # Using cv2 instead of PIL Image since we use PNG grayscale images.
            depth_image = cv2.imread(
                self.samples[idx][1],
                cv2.IMREAD_GRAYSCALE,
            )
            # Make a 3-channel depth image to enable passing to a pretrained ViT.
            depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

        if self.depth_transform is not None:
            depth_image = self.depth_transform(to_pil_image(depth_image))

        return Example(
            {
                Modalities.RGB.name: rgb_image,
                Modalities.DEPTH.name: depth_image,
                EXAMPLE_INDEX_KEY: idx,
                Modalities.DEPTH.target: self.samples[idx][2],
            }
        )
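
Example usage (illustrative sketch, not part of the library source). The root directory is a placeholder; it must contain the train.txt/test.txt files and the rgb/, depth/ and scene_class/ folders read by the constructor, and the Modalities import path is assumed.

from torchvision import transforms

from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.datasets.nyuv2 import NYUv2Dataset

to_tensor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = NYUv2Dataset(
    root_dir="/data/nyuv2",  # placeholder path
    split="test",
    return_type="image",
    rgb_transform=to_tensor,
    depth_transform=to_tensor,
)

example = dataset[0]
rgb = example[Modalities.RGB.name]
depth = example[Modalities.DEPTH.name]          # 3-channel depth image when return_type="image"
scene_label = example[Modalities.DEPTH.target]  # scene-class index
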
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/nyuv2.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.samples)
__getitem__
__getitem__(idx)

Return RGB and depth images at index idx.

Source code in mmlearn/datasets/nyuv2.py
def __getitem__(self, idx: int) -> Example:
    """Return RGB and depth images at index `idx`."""
    # Read images
    rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
    if self.rgb_transform is not None:
        rgb_image = self.rgb_transform(to_pil_image(rgb_image))

    if self.return_type == "disparity":
        depth_image = depth_normalize(
            self.samples[idx][1],
        )
    else:
        # Using cv2 instead of PIL Image since we use PNG grayscale images.
        depth_image = cv2.imread(
            self.samples[idx][1],
            cv2.IMREAD_GRAYSCALE,
        )
        # Make a 3-channel depth image to enable passing to a pretrained ViT.
        depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

    if self.depth_transform is not None:
        depth_image = self.depth_transform(to_pil_image(depth_image))

    return Example(
        {
            Modalities.RGB.name: rgb_image,
            Modalities.DEPTH.name: depth_image,
            EXAMPLE_INDEX_KEY: idx,
            Modalities.DEPTH.target: self.samples[idx][2],
        }
    )
depth_normalize
depth_normalize(depth_file, min_depth=0.01, max_depth=50)

Load depth file and convert to disparity image.

Parameters:

Name Type Description Default
depth_file str

Path to the depth file.

required
min_depth float

Minimum depth value to clip the depth image.

0.01
max_depth int

Maximum depth value to clip the depth image.

50

Returns:

Type Description
Tensor

The normalized depth image.

Source code in mmlearn/datasets/nyuv2.py
def depth_normalize(
    depth_file: str, min_depth: float = 0.01, max_depth: int = 50
) -> torch.Tensor:
    """Load depth file and convert to disparity image.

    Parameters
    ----------
    depth_file : str
        Path to the depth file.
    min_depth : float, default=0.01
        Minimum depth value to clip the depth image.
    max_depth : int, default=50
        Maximum depth value to clip the depth image.

    Returns
    -------
    torch.Tensor
        The normalized depth image.
    """
    depth_image = np.array(PILImage.open(depth_file))
    depth = np.array(depth_image).astype(np.float32)
    depth_in_meters = depth / 1000.0

    if min_depth is not None:
        depth_in_meters = depth_in_meters.clip(min=min_depth, max=max_depth)

    return torch.from_numpy(depth_in_meters).float()
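
A quick way to see what depth_normalize produces is to run it on a synthetic 16-bit depth map; the file name below is arbitrary and only used for this sketch.

import numpy as np
from PIL import Image

from mmlearn.datasets.nyuv2 import depth_normalize

# Write a synthetic 16-bit depth map (values in millimetres) to disk.
depth_mm = np.random.randint(0, 10_000, size=(480, 640), dtype=np.uint16)
Image.fromarray(depth_mm).save("synthetic_depth.png")

depth_m = depth_normalize("synthetic_depth.png", min_depth=0.01, max_depth=50)
print(depth_m.shape, depth_m.min(), depth_m.max())  # depth in metres, clipped to [0.01, 50]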

processors

Data processors.

BlockwiseImagePatchMaskGenerator

Blockwise image patch mask generator.

This is primarily intended for the data2vec method.

Parameters:

Name Type Description Default
input_size Union[int, tuple[int, int]]

The size of the input image. If an integer is provided, the image is assumed to be square.

required
num_masking_patches int

The number of patches to mask.

required
min_num_patches int

The minimum number of patches to mask.

4
max_num_patches int

The maximum number of patches to mask.

None
min_aspect_ratio float

The minimum aspect ratio of the patch.

0.3
max_aspect_ratio float

The maximum aspect ratio of the patch.

None
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn")
class BlockwiseImagePatchMaskGenerator:
    """Blockwise image patch mask generator.

    This is primarily intended for the data2vec method.

    Parameters
    ----------
    input_size : Union[int, tuple[int, int]]
        The size of the input image. If an integer is provided, the image is assumed
        to be square.
    num_masking_patches : int
        The number of patches to mask.
    min_num_patches : int, default=4
        The minimum number of patches to mask.
    max_num_patches : int, default=None
        The maximum number of patches to mask.
    min_aspect_ratio : float, default=0.3
        The minimum aspect ratio of the patch.
    max_aspect_ratio : float, default=None
        The maximum aspect ratio of the patch.
    """

    def __init__(
        self,
        input_size: Union[int, tuple[int, int]],
        num_masking_patches: int,
        min_num_patches: int = 4,
        max_num_patches: Any = None,
        min_aspect_ratio: float = 0.3,
        max_aspect_ratio: Any = None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = (
            num_masking_patches if max_num_patches is None else max_num_patches
        )

        max_aspect_ratio = max_aspect_ratio or 1 / min_aspect_ratio
        self.log_aspect_ratio = (math.log(min_aspect_ratio), math.log(max_aspect_ratio))

    def __repr__(self) -> str:
        """Generate a printable representation.

        Returns
        -------
        str
            A printable representation of the object.

        """
        return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self.max_num_patches,
            self.num_masking_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )

    def get_shape(self) -> tuple[int, int]:
        """Get the shape of the input.

        Returns
        -------
        tuple[int, int]
            The shape of the input as a tuple `(height, width)`.
        """
        return self.height, self.width

    def _mask(self, mask: torch.Tensor, max_mask_patches: int) -> int:
        """Masking function.

        This function masks adjacent patches by first selecting a target area and
        aspect ratio. Since there might be overlap between selected areas, or the
        selected area might already be masked, it runs for a maximum of 10 attempts
        or until the specified number of patches (max_mask_patches) is masked.


        Parameters
        ----------
        mask: torch.Tensor
            Current mask. The mask to be updated.
        max_mask_patches: int
            The maximum number of patches to be masked.

        Returns
        -------
        delta: int
            The number of patches that were successfully masked.

        Notes
        -----
        - `target_area`: Randomly chosen target area for the patch.
        - `aspect_ratio`: Randomly chosen aspect ratio for the patch.
        - `h`: Height of the patch based on the target area and aspect ratio.
        - `w`: Width of the patch based on the target area and aspect ratio.
        - `top`: Randomly chosen top position for the patch.
        - `left`: Randomly chosen left position for the patch.
        - `num_masked`: Number of masked pixels within the proposed patch area.
        - `delta`: Accumulated count of modified pixels.
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top : top + h, left : left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self) -> torch.Tensor:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        mask: torch.Tensor
            A mask of shape (nb_patches, nb_patches)

        """
        mask = torch.zeros(self.get_shape(), dtype=torch.int)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask
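
Example usage (illustrative sketch, not part of the library source). The input_size is given here in units of patches: the returned mask has exactly this shape, so for a 224 x 224 image split into 16 x 16 patches the grid is 14 x 14.

from mmlearn.datasets.processors.masking import BlockwiseImagePatchMaskGenerator

mask_generator = BlockwiseImagePatchMaskGenerator(
    input_size=14,           # 14 x 14 patch grid (e.g. 224 x 224 image, 16 x 16 patches)
    num_masking_patches=75,  # target number of masked patches
)

mask = mask_generator()  # torch.int tensor of shape (14, 14)
# The blockwise procedure never overshoots the target, but may fall slightly
# short if it fails to place a block within the attempt limit.
print(mask.shape, int(mask.sum()))
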
__repr__
__repr__()

Generate a printable representation.

Returns:

Type Description
str

A printable representation of the object.

Source code in mmlearn/datasets/processors/masking.py
def __repr__(self) -> str:
    """Generate a printable representation.

    Returns
    -------
    str
        A printable representation of the object.

    """
    return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
        self.height,
        self.width,
        self.min_num_patches,
        self.max_num_patches,
        self.num_masking_patches,
        self.log_aspect_ratio[0],
        self.log_aspect_ratio[1],
    )
get_shape
get_shape()

Get the shape of the input.

Returns:

Type Description
tuple[int, int]

The shape of the input as a tuple (height, width).

Source code in mmlearn/datasets/processors/masking.py
def get_shape(self) -> tuple[int, int]:
    """Get the shape of the input.

    Returns
    -------
    tuple[int, int]
        The shape of the input as a tuple `(height, width)`.
    """
    return self.height, self.width
__call__
__call__()

Generate a random mask.

Returns a random mask of shape (nb_patches, nb_patches) based on the configuration where the number of patches to be masked is num_masking_patches.

Returns:

Name Type Description
mask Tensor

A mask of shape (nb_patches, nb_patches)

Source code in mmlearn/datasets/processors/masking.py
def __call__(self) -> torch.Tensor:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    mask: torch.Tensor
        A mask of shape (nb_patches, nb_patches)

    """
    mask = torch.zeros(self.get_shape(), dtype=torch.int)
    mask_count = 0
    while mask_count < self.num_masking_patches:
        max_mask_patches = self.num_masking_patches - mask_count
        max_mask_patches = min(max_mask_patches, self.max_num_patches)

        delta = self._mask(mask, max_mask_patches)
        if delta == 0:
            break
        mask_count += delta

    return mask
RandomMaskGenerator

Random mask generator.

Randomly selects tokens for masking with a given probability, following the standard masked-language-modeling recipe: 80% of the selected tokens are replaced with the mask token, 10% with a random token, and 10% are left unchanged. This is intended to be used for tasks like masked language modeling.

Parameters:

Name Type Description Default
probability float

Probability of masking a token.

required
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn", probability=0.15)
class RandomMaskGenerator:
    """Random mask generator.

    Randomly selects tokens for masking with a given probability, following the
    standard masked-language-modeling recipe: 80% of the selected tokens are
    replaced with the mask token, 10% with a random token, and 10% are left
    unchanged. **This is intended to be used for tasks like masked language
    modeling.**

    Parameters
    ----------
    probability : float
        Probability of masking a token.
    """

    def __init__(self, probability: float):
        self.probability = probability

    def __call__(
        self,
        inputs: torch.Tensor,
        tokenizer: PreTrainedTokenizerBase,
        special_tokens_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        inputs : torch.Tensor
            The encoded inputs.
        tokenizer : PreTrainedTokenizer
            The tokenizer.
        special_tokens_mask : Optional[torch.Tensor], default=None
            Mask for special tokens.
        """
        inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training
        # (with probability `self.probability`)
        probability_matrix = torch.full(labels.shape, self.probability)
        if special_tokens_mask is None:
            special_tokens_mask = tokenizer.get_special_tokens_mask(
                labels, already_has_special_tokens=True
            )
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = tokenizer.pad_token_id
        # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # Rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, masked_indices
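
Example usage (illustrative sketch, not part of the library source). It assumes the bert-base-uncased tokenizer can be downloaded (the model name is only an example); the special-tokens mask is passed explicitly so the generator does not have to recompute it.

from transformers import AutoTokenizer

from mmlearn.datasets.processors.masking import RandomMaskGenerator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example model name
mask_generator = RandomMaskGenerator(probability=0.15)

encoding = tokenizer(
    ["a photo of a cat", "two dogs playing in the park"],
    padding=True,
    return_special_tokens_mask=True,
    return_tensors="pt",
)
special_tokens_mask = encoding.pop("special_tokens_mask")

inputs, labels, masked_indices = mask_generator(
    encoding, tokenizer, special_tokens_mask=special_tokens_mask
)
# `inputs` contains [MASK]/random-token replacements, `labels` keeps the original
# ids at masked positions (padding id elsewhere), and `masked_indices` marks the
# positions selected for masked language modeling.
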
__call__
__call__(inputs, tokenizer, special_tokens_mask=None)

Apply random token masking for masked language modeling.

Tokens are selected for masking with the configured probability; 80% of the selected tokens are replaced with the mask token, 10% with a random token, and 10% are left unchanged.

Returns:

Name Type Description
inputs Tensor

The masked inputs.

labels Tensor

The original token ids at the masked positions and the padding token id elsewhere.

masked_indices Tensor

Boolean mask of the positions selected for masking.

Source code in mmlearn/datasets/processors/masking.py
def __call__(
    self,
    inputs: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
    special_tokens_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    inputs : torch.Tensor
        The encoded inputs.
    tokenizer : PreTrainedTokenizer
        The tokenizer.
    special_tokens_mask : Optional[torch.Tensor], default=None
        Mask for special tokens.
    """
    inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
    labels = inputs.clone()
    # We sample a few tokens in each sequence for MLM training
    # (with probability `self.probability`)
    probability_matrix = torch.full(labels.shape, self.probability)
    if special_tokens_mask is None:
        special_tokens_mask = tokenizer.get_special_tokens_mask(
            labels, already_has_special_tokens=True
        )
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
    else:
        special_tokens_mask = special_tokens_mask.bool()

    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = tokenizer.pad_token_id
    # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
    indices_replaced = (
        torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    )
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # Rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels, masked_indices
HFTokenizer

A wrapper for loading HuggingFace tokenizers.

This class wraps any huggingface tokenizer that can be initialized with :py:meth:transformers.AutoTokenizer.from_pretrained. It preprocesses the input text and returns a dictionary with the tokenized text and other relevant information like attention masks.

Parameters:

Name Type Description Default
model_name_or_path str

Pretrained model name or path - same as in :py:meth:transformers.AutoTokenizer.from_pretrained.

required
max_length Optional[int]

Maximum length of the tokenized sequence. This is passed to the tokenizer :meth:__call__ method.

None
padding bool or str

Padding strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

False
truncation Optional[Union[bool, str]]

Truncation strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

None
**kwargs Any

Additional arguments passed to :py:meth:transformers.AutoTokenizer.from_pretrained.

{}
Source code in mmlearn/datasets/processors/tokenizers.py
@store(group="datasets/tokenizers", provider="mmlearn")
class HFTokenizer:
    """A wrapper for loading HuggingFace tokenizers.

    This class wraps any huggingface tokenizer that can be initialized with
    :py:meth:`transformers.AutoTokenizer.from_pretrained`. It preprocesses the
    input text and returns a dictionary with the tokenized text and other
    relevant information like attention masks.

    Parameters
    ----------
    model_name_or_path : str
        Pretrained model name or path - same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    max_length : Optional[int], optional, default=None
        Maximum length of the tokenized sequence. This is passed to the tokenizer
        :meth:`__call__` method.
    padding : bool or str, default=False
        Padding strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    truncation : Optional[Union[bool, str]], optional, default=None
        Truncation strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    **kwargs : Any
        Additional arguments passed to :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    """  # noqa: W505

    def __init__(
        self,
        model_name_or_path: str,
        max_length: Optional[int] = None,
        padding: Union[bool, str] = False,
        truncation: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def __call__(
        self, sentence: Union[str, list[str]], **kwargs: Any
    ) -> dict[str, torch.Tensor]:
        """Tokenize a text or a list of texts using the HuggingFace tokenizer.

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be tokenized.
        **kwargs : Any
            Additional arguments passed to the tokenizer :meth:`__call__` method.

        Returns
        -------
        dict[str, torch.Tensor]
            Tokenized sentence(s).

        Notes
        -----
        The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
        """
        batch_encoding = self.tokenizer(
            sentence,
            max_length=self.max_length,
            padding=self.padding,
            truncation=self.truncation,
            return_tensors="pt",
            **kwargs,
        )

        if isinstance(
            sentence, str
        ):  # remove batch dimension if input is a single sentence
            for key, value in batch_encoding.items():
                if isinstance(value, torch.Tensor):
                    batch_encoding[key] = torch.squeeze(value, 0)

        # use 'Modalities.TEXT' key for input_ids for consistency
        batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
        return dict(batch_encoding)
__call__
__call__(sentence, **kwargs)

Tokenize a text or a list of texts using the HuggingFace tokenizer.

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be tokenized.

required
**kwargs Any

Additional arguments passed to the tokenizer :meth:__call__ method.

{}

Returns:

Type Description
dict[str, Tensor]

Tokenized sentence(s).

Notes

The input_ids key is replaced with Modalities.TEXT for consistency.

Source code in mmlearn/datasets/processors/tokenizers.py
def __call__(
    self, sentence: Union[str, list[str]], **kwargs: Any
) -> dict[str, torch.Tensor]:
    """Tokenize a text or a list of texts using the HuggingFace tokenizer.

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be tokenized.
    **kwargs : Any
        Additional arguments passed to the tokenizer :meth:`__call__` method.

    Returns
    -------
    dict[str, torch.Tensor]
        Tokenized sentence(s).

    Notes
    -----
    The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
    """
    batch_encoding = self.tokenizer(
        sentence,
        max_length=self.max_length,
        padding=self.padding,
        truncation=self.truncation,
        return_tensors="pt",
        **kwargs,
    )

    if isinstance(
        sentence, str
    ):  # remove batch dimension if input is a single sentence
        for key, value in batch_encoding.items():
            if isinstance(value, torch.Tensor):
                batch_encoding[key] = torch.squeeze(value, 0)

    # use 'Modalities.TEXT' key for input_ids for consistency
    batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
    return dict(batch_encoding)
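
Example (illustrative, not part of the library source): the sketch below assumes the bert-base-uncased checkpoint is available locally or can be downloaded; any checkpoint accepted by transformers.AutoTokenizer.from_pretrained should behave the same way.

from mmlearn.datasets.processors.tokenizers import HFTokenizer

# Hypothetical configuration: pad/truncate everything to 32 tokens.
tokenizer = HFTokenizer(
    "bert-base-uncased",
    max_length=32,
    padding="max_length",
    truncation=True,
)

# A single sentence comes back without a batch dimension ...
single = tokenizer("A photo of a cat.")
print(single["input_ids"].shape)  # torch.Size([32])

# ... while a list of sentences keeps it. The same tensor is also stored
# under the Modalities.TEXT key, as noted above.
batch = tokenizer(["A photo of a cat.", "A photo of a dog."])
print(batch["input_ids"].shape)  # torch.Size([2, 32])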
TrimText

Trim text strings as a preprocessing step before tokenization.

Parameters:

Name Type Description Default
trim_size int

The maximum length of the trimmed text.

required
Source code in mmlearn/datasets/processors/transforms.py
@store(group="datasets/transforms", provider="mmlearn")
class TrimText:
    """Trim text strings as a preprocessing step before tokenization.

    Parameters
    ----------
    trim_size : int
        The maximum length of the trimmed text.
    """

    def __init__(self, trim_size: int) -> None:
        self.trim_size = trim_size

    def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
        """Trim the given sentence(s).

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be trimmed.

        Returns
        -------
        Union[str, list[str]]
            Trimmed sentence(s).

        Raises
        ------
        TypeError
            If the input sentence is not a string or list of strings.
        """
        if not isinstance(sentence, (list, str)):
            raise TypeError(
                "Expected argument `sentence` to be a string or list of strings, "
                f"but got {type(sentence)}"
            )

        if isinstance(sentence, str):
            return sentence[: self.trim_size]

        for i, s in enumerate(sentence):
            sentence[i] = s[: self.trim_size]

        return sentence
__call__
__call__(sentence)

Trim the given sentence(s).

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be trimmed.

required

Returns:

Type Description
Union[str, list[str]]

Trimmed sentence(s).

Raises:

Type Description
TypeError

If the input sentence is not a string or list of strings.

Source code in mmlearn/datasets/processors/transforms.py
def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
    """Trim the given sentence(s).

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be trimmed.

    Returns
    -------
    Union[str, list[str]]
        Trimmed sentence(s).

    Raises
    ------
    TypeError
        If the input sentence is not a string or list of strings.
    """
    if not isinstance(sentence, (list, str)):
        raise TypeError(
            "Expected argument `sentence` to be a string or list of strings, "
            f"but got {type(sentence)}"
        )

    if isinstance(sentence, str):
        return sentence[: self.trim_size]

    for i, s in enumerate(sentence):
        sentence[i] = s[: self.trim_size]

    return sentence
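
Example (illustrative, mirroring the source above): note that trim_size counts characters, not tokens.

from mmlearn.datasets.processors.transforms import TrimText

trim = TrimText(trim_size=10)
print(trim("a fairly long sentence"))            # "a fairly l"
print(trim(["short", "another long sentence"]))  # ["short", "another lo"]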
masking

Token mask generators.

RandomMaskGenerator

Random mask generator.

Applies random token masking to tokenized text, where each token is selected for masking with probability probability. This is intended to be used for tasks like masked language modeling.

Parameters:

Name Type Description Default
probability float

Probability of masking a token.

required
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn", probability=0.15)
class RandomMaskGenerator:
    """Random mask generator.

    Applies random token masking to tokenized text, where each token is selected
    for masking with probability ``probability``. **This is intended to be used
    for tasks like masked language modeling.**

    Parameters
    ----------
    probability : float
        Probability of masking a token.
    """

    def __init__(self, probability: float):
        self.probability = probability

    def __call__(
        self,
        inputs: torch.Tensor,
        tokenizer: PreTrainedTokenizerBase,
        special_tokens_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a random mask.

        Each token is selected for masking with probability ``self.probability``;
        special tokens are never selected. Of the selected tokens, 80% are replaced
        with the mask token, 10% with a random token, and 10% are left unchanged.

        Parameters
        ----------
        inputs : torch.Tensor
            The encoded inputs.
        tokenizer : PreTrainedTokenizerBase
            The tokenizer.
        special_tokens_mask : Optional[torch.Tensor], default=None
            Mask for special tokens.

        Returns
        -------
        inputs : torch.Tensor
            The corrupted input token ids.
        labels : torch.Tensor
            The original token ids at masked positions and ``pad_token_id`` elsewhere.
        masked_indices : torch.Tensor
            Boolean mask of the positions selected for masking.
        """
        inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training
        # (with probability `self.probability`)
        probability_matrix = torch.full(labels.shape, self.probability)
        if special_tokens_mask is None:
            special_tokens_mask = tokenizer.get_special_tokens_mask(
                labels, already_has_special_tokens=True
            )
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = tokenizer.pad_token_id
        # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # Rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, masked_indices
__call__
__call__(inputs, tokenizer, special_tokens_mask=None)

Generate a random mask.

Selects tokens for masked language modeling with probability self.probability; special tokens are never selected. Of the selected tokens, 80% are replaced with the mask token, 10% with a random token, and 10% are left unchanged.

Parameters:

Name Type Description Default
inputs Tensor

The encoded inputs.

required
tokenizer PreTrainedTokenizerBase

The tokenizer.

required
special_tokens_mask Optional[torch.Tensor]

Mask for special tokens.

None

Returns:

Name Type Description
inputs Tensor

The corrupted input token ids.

labels Tensor

The original token ids at masked positions and pad_token_id elsewhere.

masked_indices Tensor

Boolean mask of the positions selected for masking.

Source code in mmlearn/datasets/processors/masking.py
def __call__(
    self,
    inputs: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
    special_tokens_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate a random mask.

    Each token is selected for masking with probability ``self.probability``;
    special tokens are never selected. Of the selected tokens, 80% are replaced
    with the mask token, 10% with a random token, and 10% are left unchanged.

    Parameters
    ----------
    inputs : torch.Tensor
        The encoded inputs.
    tokenizer : PreTrainedTokenizerBase
        The tokenizer.
    special_tokens_mask : Optional[torch.Tensor], default=None
        Mask for special tokens.

    Returns
    -------
    inputs : torch.Tensor
        The corrupted input token ids.
    labels : torch.Tensor
        The original token ids at masked positions and ``pad_token_id`` elsewhere.
    masked_indices : torch.Tensor
        Boolean mask of the positions selected for masking.
    """
    inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
    labels = inputs.clone()
    # We sample a few tokens in each sequence for MLM training
    # (with probability `self.probability`)
    probability_matrix = torch.full(labels.shape, self.probability)
    if special_tokens_mask is None:
        special_tokens_mask = tokenizer.get_special_tokens_mask(
            labels, already_has_special_tokens=True
        )
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
    else:
        special_tokens_mask = special_tokens_mask.bool()

    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = tokenizer.pad_token_id
    # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
    indices_replaced = (
        torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    )
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # Rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels, masked_indices
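
Example (illustrative, not part of the library source): the sketch assumes a HuggingFace tokenizer such as bert-base-uncased and passes the special-tokens mask explicitly; probability=0.15 mirrors the default registered with the store above.

from transformers import AutoTokenizer

from mmlearn.datasets.processors.masking import RandomMaskGenerator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
mask_generator = RandomMaskGenerator(probability=0.15)

encoding = tokenizer(
    ["a photo of a cat", "a photo of a dog"],
    padding=True,
    return_special_tokens_mask=True,
    return_tensors="pt",
)

inputs, labels, masked_indices = mask_generator(
    encoding,
    tokenizer,
    special_tokens_mask=encoding["special_tokens_mask"],
)
# `inputs` holds the corrupted token ids, `labels` the original ids at masked
# positions (pad_token_id elsewhere), and `masked_indices` the boolean mask.
print(inputs.shape, labels.shape, masked_indices.shape)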
BlockwiseImagePatchMaskGenerator

Blockwise image patch mask generator.

This is primarily intended for the data2vec method.

Parameters:

Name Type Description Default
input_size Union[int, tuple[int, int]]

The size of the input image. If an integer is provided, the image is assumed to be square.

required
num_masking_patches int

The number of patches to mask.

required
min_num_patches int

The minimum number of patches to mask.

4
max_num_patches int

The maximum number of patches to mask.

None
min_aspect_ratio float

The minimum aspect ratio of the patch.

0.3
max_aspect_ratio float

The maximum aspect ratio of the patch.

None
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn")
class BlockwiseImagePatchMaskGenerator:
    """Blockwise image patch mask generator.

    This is primarily intended for the data2vec method.

    Parameters
    ----------
    input_size : Union[int, tuple[int, int]]
        The size of the input image. If an integer is provided, the image is assumed
        to be square.
    num_masking_patches : int
        The number of patches to mask.
    min_num_patches : int, default=4
        The minimum number of patches to mask.
    max_num_patches : int, default=None
        The maximum number of patches to mask.
    min_aspect_ratio : float, default=0.3
        The minimum aspect ratio of the patch.
    max_aspect_ratio : float, default=None
        The maximum aspect ratio of the patch.
    """

    def __init__(
        self,
        input_size: Union[int, tuple[int, int]],
        num_masking_patches: int,
        min_num_patches: int = 4,
        max_num_patches: Any = None,
        min_aspect_ratio: float = 0.3,
        max_aspect_ratio: Any = None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = (
            num_masking_patches if max_num_patches is None else max_num_patches
        )

        max_aspect_ratio = max_aspect_ratio or 1 / min_aspect_ratio
        self.log_aspect_ratio = (math.log(min_aspect_ratio), math.log(max_aspect_ratio))

    def __repr__(self) -> str:
        """Generate a printable representation.

        Returns
        -------
        str
            A printable representation of the object.

        """
        return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self.max_num_patches,
            self.num_masking_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )

    def get_shape(self) -> tuple[int, int]:
        """Get the shape of the input.

        Returns
        -------
        tuple[int, int]
            The shape of the input as a tuple `(height, width)`.
        """
        return self.height, self.width

    def _mask(self, mask: torch.Tensor, max_mask_patches: int) -> int:
        """Masking function.

        This function masks adjacent patches by first selecting a target area and
        aspect ratio. Since there might be overlap between selected areas, or the
        selected area might already be masked, it runs for a maximum of 10 attempts
        or until the specified number of patches (max_mask_patches) is reached.


        Parameters
        ----------
        mask: torch.Tensor
            Current mask. The mask to be updated.
        max_mask_patches: int
            The maximum number of patches to be masked.

        Returns
        -------
        delta: int
            The number of patches that were successfully masked.

        Notes
        -----
        - `target_area`: Randomly chosen target area for the patch.
        - `aspect_ratio`: Randomly chosen aspect ratio for the patch.
        - `h`: Height of the patch based on the target area and aspect ratio.
        - `w`: Width of the patch based on the target area and aspect ratio.
        - `top`: Randomly chosen top position for the patch.
        - `left`: Randomly chosen left position for the patch.
        - `num_masked`: Number of masked pixels within the proposed patch area.
        - `delta`: Accumulated count of modified pixels.
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top : top + h, left : left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self) -> torch.Tensor:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        mask: torch.Tensor
            A mask of shape (nb_patches, nb_patches)

        """
        mask = torch.zeros(self.get_shape(), dtype=torch.int)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask
__repr__
__repr__()

Generate a printable representation.

Returns:

Type Description
str

A printable representation of the object.

Source code in mmlearn/datasets/processors/masking.py
def __repr__(self) -> str:
    """Generate a printable representation.

    Returns
    -------
    str
        A printable representation of the object.

    """
    return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
        self.height,
        self.width,
        self.min_num_patches,
        self.max_num_patches,
        self.num_masking_patches,
        self.log_aspect_ratio[0],
        self.log_aspect_ratio[1],
    )
get_shape
get_shape()

Get the shape of the input.

Returns:

Type Description
tuple[int, int]

The shape of the input as a tuple (height, width).

Source code in mmlearn/datasets/processors/masking.py
def get_shape(self) -> tuple[int, int]:
    """Get the shape of the input.

    Returns
    -------
    tuple[int, int]
        The shape of the input as a tuple `(height, width)`.
    """
    return self.height, self.width
__call__
__call__()

Generate a random mask.

Returns a random mask of shape (nb_patches, nb_patches) based on the configuration where the number of patches to be masked is num_masking_patches.

Returns:

Name Type Description
mask Tensor

A mask of shape (nb_patches, nb_patches)

Source code in mmlearn/datasets/processors/masking.py
def __call__(self) -> torch.Tensor:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    mask: torch.Tensor
        A mask of shape (nb_patches, nb_patches)

    """
    mask = torch.zeros(self.get_shape(), dtype=torch.int)
    mask_count = 0
    while mask_count < self.num_masking_patches:
        max_mask_patches = self.num_masking_patches - mask_count
        max_mask_patches = min(max_mask_patches, self.max_num_patches)

        delta = self._mask(mask, max_mask_patches)
        if delta == 0:
            break
        mask_count += delta

    return mask
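
Example (illustrative, not part of the library source). Note that input_size is used directly as the mask grid shape, so for patch-level masking you would pass the number of patches per side (e.g. 14 for a 224x224 image with 16x16 patches); the patch counts below are illustrative.

from mmlearn.datasets.processors.masking import BlockwiseImagePatchMaskGenerator

mask_generator = BlockwiseImagePatchMaskGenerator(
    input_size=14,          # 14x14 patch grid
    num_masking_patches=75,
)

mask = mask_generator()     # integer tensor of shape (14, 14)
# Roughly 75 entries are set to 1; the loop may stop early if no valid block fits.
print(mask.shape, int(mask.sum()))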
IJEPAMaskGenerator dataclass

Generates encoder and predictor masks for preprocessing.

This class generates masks dynamically for batches of examples.

Parameters:

Name Type Description Default
input_size tuple[int, int]

Input image size.

(224, 224)
patch_size int

Size of each patch.

16
min_keep int

Minimum number of patches to keep.

10
allow_overlap bool

Whether to allow overlap between encoder and predictor masks.

False
enc_mask_scale tuple[float, float]

Scale range for encoder mask.

(0.85, 1.0)
pred_mask_scale tuple[float, float]

Scale range for predictor mask.

(0.15, 0.2)
aspect_ratio tuple[float, float]

Aspect ratio range for mask blocks.

(0.75, 1.5)
nenc int

Number of encoder masks to generate.

1
npred int

Number of predictor masks to generate.

4
Source code in mmlearn/datasets/processors/masking.py
@dataclass
class IJEPAMaskGenerator:
    """Generates encoder and predictor masks for preprocessing.

    This class generates masks dynamically for batches of examples.

    Parameters
    ----------
    input_size : tuple[int, int], default=(224, 224)
        Input image size.
    patch_size : int, default=16
        Size of each patch.
    min_keep : int, default=10
        Minimum number of patches to keep.
    allow_overlap : bool, default=False
        Whether to allow overlap between encoder and predictor masks.
    enc_mask_scale : tuple[float, float], default=(0.85, 1.0)
        Scale range for encoder mask.
    pred_mask_scale : tuple[float, float], default=(0.15, 0.2)
        Scale range for predictor mask.
    aspect_ratio : tuple[float, float], default=(0.75, 1.5)
        Aspect ratio range for mask blocks.
    nenc : int, default=1
        Number of encoder masks to generate.
    npred : int, default=4
        Number of predictor masks to generate.
    """

    input_size: tuple[int, int] = (224, 224)
    patch_size: int = 16
    min_keep: int = 10
    allow_overlap: bool = False
    enc_mask_scale: tuple[float, float] = (0.85, 1.0)
    pred_mask_scale: tuple[float, float] = (0.15, 0.2)
    aspect_ratio: tuple[float, float] = (0.75, 1.5)
    nenc: int = 1
    npred: int = 4

    def __post_init__(self) -> None:
        """Initialize the mask generator."""
        self.height = self.input_size[0] // self.patch_size
        self.width = self.input_size[1] // self.patch_size

    def _sample_block_size(
        self,
        generator: torch.Generator,
        scale: tuple[float, float],
        aspect_ratio: tuple[float, float],
    ) -> tuple[int, int]:
        """Sample the size of the mask block based on scale and aspect ratio."""
        _rand = torch.rand(1, generator=generator).item()
        min_s, max_s = scale
        mask_scale = min_s + _rand * (max_s - min_s)
        max_keep = int(self.height * self.width * mask_scale)

        min_ar, max_ar = aspect_ratio
        aspect_ratio_val = min_ar + _rand * (max_ar - min_ar)

        h = int(round(math.sqrt(max_keep * aspect_ratio_val)))
        w = int(round(math.sqrt(max_keep / aspect_ratio_val)))

        h = min(h, self.height - 1)
        w = min(w, self.width - 1)

        return h, w

    def _sample_block_mask(
        self, b_size: tuple[int, int]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample a mask block."""
        h, w = b_size
        top = torch.randint(0, self.height - h, (1,)).item()
        left = torch.randint(0, self.width - w, (1,)).item()
        mask = torch.zeros((self.height, self.width), dtype=torch.int32)
        mask[top : top + h, left : left + w] = 1

        mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
        mask_complement[top : top + h, left : left + w] = 0

        return mask.flatten(), mask_complement.flatten()

    def __call__(self, batch_size: int = 1) -> dict[str, Any]:
        """Generate encoder and predictor masks for a batch of examples.

        Parameters
        ----------
        batch_size : int, default=1
            The batch size for which to generate masks.

        Returns
        -------
        dict[str, Any]
            A dictionary of encoder masks and predictor masks.
        """
        seed = torch.randint(
            0, 2**32, (1,)
        ).item()  # Sample random seed for reproducibility
        g = torch.Generator().manual_seed(seed)

        # Sample block sizes
        p_size = self._sample_block_size(
            generator=g, scale=self.pred_mask_scale, aspect_ratio=self.aspect_ratio
        )
        e_size = self._sample_block_size(
            generator=g, scale=self.enc_mask_scale, aspect_ratio=(1.0, 1.0)
        )

        # Generate predictor masks
        masks_pred, masks_enc = [], []
        for _ in range(self.npred):
            mask_p, _ = self._sample_block_mask(p_size)
            # Expand mask to match batch size
            mask_p = mask_p.unsqueeze(0).expand(batch_size, -1)
            masks_pred.append(mask_p)

        # Generate encoder masks
        for _ in range(self.nenc):
            mask_e, _ = self._sample_block_mask(e_size)
            # Expand mask to match batch size
            mask_e = mask_e.unsqueeze(0).expand(batch_size, -1)
            masks_enc.append(mask_e)

        return {
            "encoder_masks": masks_enc,  # list of tensors of shape (batch_size, N)
            "predictor_masks": masks_pred,  # list of tensors of shape (batch_size, N)
        }
__post_init__
__post_init__()

Initialize the mask generator.

Source code in mmlearn/datasets/processors/masking.py
def __post_init__(self) -> None:
    """Initialize the mask generator."""
    self.height = self.input_size[0] // self.patch_size
    self.width = self.input_size[1] // self.patch_size
__call__
__call__(batch_size=1)

Generate encoder and predictor masks for a batch of examples.

Parameters:

Name Type Description Default
batch_size int

The batch size for which to generate masks.

1

Returns:

Type Description
dict[str, Any]

A dictionary of encoder masks and predictor masks.

Source code in mmlearn/datasets/processors/masking.py
def __call__(self, batch_size: int = 1) -> dict[str, Any]:
    """Generate encoder and predictor masks for a batch of examples.

    Parameters
    ----------
    batch_size : int, default=1
        The batch size for which to generate masks.

    Returns
    -------
    dict[str, Any]
        A dictionary of encoder masks and predictor masks.
    """
    seed = torch.randint(
        0, 2**32, (1,)
    ).item()  # Sample random seed for reproducibility
    g = torch.Generator().manual_seed(seed)

    # Sample block sizes
    p_size = self._sample_block_size(
        generator=g, scale=self.pred_mask_scale, aspect_ratio=self.aspect_ratio
    )
    e_size = self._sample_block_size(
        generator=g, scale=self.enc_mask_scale, aspect_ratio=(1.0, 1.0)
    )

    # Generate predictor masks
    masks_pred, masks_enc = [], []
    for _ in range(self.npred):
        mask_p, _ = self._sample_block_mask(p_size)
        # Expand mask to match batch size
        mask_p = mask_p.unsqueeze(0).expand(batch_size, -1)
        masks_pred.append(mask_p)

    # Generate encoder masks
    for _ in range(self.nenc):
        mask_e, _ = self._sample_block_mask(e_size)
        # Expand mask to match batch size
        mask_e = mask_e.unsqueeze(0).expand(batch_size, -1)
        masks_enc.append(mask_e)

    return {
        "encoder_masks": masks_enc,  # list of tensors of shape (batch_size, N)
        "predictor_masks": masks_pred,  # list of tensors of shape (batch_size, N)
    }
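
Example (illustrative, not part of the library source) with the default configuration (224x224 input, 16x16 patches, i.e. a 14x14 = 196-patch grid):

from mmlearn.datasets.processors.masking import IJEPAMaskGenerator

mask_generator = IJEPAMaskGenerator()
masks = mask_generator(batch_size=8)

print(len(masks["encoder_masks"]), len(masks["predictor_masks"]))  # 1 4
print(masks["encoder_masks"][0].shape)                             # torch.Size([8, 196])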
apply_masks
apply_masks(x, masks)

Apply masks to the input tensor by selecting the patches to keep based on the masks.

This function is primarily intended to be used for the :py:class:i-JEPA <mmlearn.tasks.ijepa.IJEPA>.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, N, D).

required
masks Union[Tensor, list[Tensor]]

A list of mask tensors of shape (N,), (1, N), or (B, N).

required

Returns:

Type Description
Tensor

The masked tensor where only the patches indicated by the masks are kept. The output tensor has shape (B * num_masks, N', D), where N' is the number of patches kept.

Source code in mmlearn/datasets/processors/masking.py
def apply_masks(
    x: torch.Tensor, masks: Union[torch.Tensor, list[torch.Tensor]]
) -> torch.Tensor:
    """
    Apply masks to the input tensor by selecting the patches to keep based on the masks.

    This function is primarily intended to be used for the
    :py:class:`i-JEPA <mmlearn.tasks.ijepa.IJEPA>`.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(B, N, D)``.
    masks : Union[torch.Tensor, list[torch.Tensor]]
        A list of mask tensors of shape ``(N,)``, ``(1, N)``, or ``(B, N)``.

    Returns
    -------
    torch.Tensor
        The masked tensor where only the patches indicated by the masks are kept.
        The output tensor has shape ``(B * num_masks, N', D)``, where ``N'`` is
        the number of patches kept.
    """
    all_x = []
    batch_size = x.size(0)
    for m_ in masks:
        m = m_.to(x.device)

        # Ensure mask is at least 2D
        if m.dim() == 1:
            m = m.unsqueeze(0)  # Shape: (1, N)

        # Expand mask to match the batch size if needed
        if m.size(0) == 1 and batch_size > 1:
            m = m.expand(batch_size, -1)  # Shape: (B, N)

        # Expand mask to match x's dimensions
        m_expanded = (
            m.unsqueeze(-1).expand(-1, -1, x.size(-1)).bool()
        )  # Shape: (B, N, D)

        # Use boolean indexing
        selected_patches = x[m_expanded].view(batch_size, -1, x.size(-1))
        all_x.append(selected_patches)

    # Concatenate along the batch dimension
    return torch.cat(all_x, dim=0)
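
A small sketch of the selection semantics (shapes are illustrative): each mask is a binary keep-mask over the N patches, and every example in the batch must keep the same number of patches for the reshape to succeed.

import torch

from mmlearn.datasets.processors.masking import apply_masks

x = torch.randn(2, 16, 8)               # (B, N, D)
keep = torch.zeros(16, dtype=torch.int32)
keep[:4] = 1                             # keep patches 0..3 for every example

out = apply_masks(x, [keep])
print(out.shape)                         # torch.Size([2, 4, 8])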
tokenizers

Tokenizers - modules that convert raw input to sequences of tokens.

HFTokenizer

A wrapper for loading HuggingFace tokenizers.

This class wraps any huggingface tokenizer that can be initialized with :py:meth:transformers.AutoTokenizer.from_pretrained. It preprocesses the input text and returns a dictionary with the tokenized text and other relevant information like attention masks.

Parameters:

Name Type Description Default
model_name_or_path str

Pretrained model name or path - same as in :py:meth:transformers.AutoTokenizer.from_pretrained.

required
max_length Optional[int]

Maximum length of the tokenized sequence. This is passed to the tokenizer :meth:__call__ method.

None
padding bool or str

Padding strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

False
truncation Optional[Union[bool, str]]

Truncation strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

None
**kwargs Any

Additional arguments passed to :py:meth:transformers.AutoTokenizer.from_pretrained.

{}
Source code in mmlearn/datasets/processors/tokenizers.py
@store(group="datasets/tokenizers", provider="mmlearn")
class HFTokenizer:
    """A wrapper for loading HuggingFace tokenizers.

    This class wraps any huggingface tokenizer that can be initialized with
    :py:meth:`transformers.AutoTokenizer.from_pretrained`. It preprocesses the
    input text and returns a dictionary with the tokenized text and other
    relevant information like attention masks.

    Parameters
    ----------
    model_name_or_path : str
        Pretrained model name or path - same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    max_length : Optional[int], optional, default=None
        Maximum length of the tokenized sequence. This is passed to the tokenizer
        :meth:`__call__` method.
    padding : bool or str, default=False
        Padding strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    truncation : Optional[Union[bool, str]], optional, default=None
        Truncation strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    **kwargs : Any
        Additional arguments passed to :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    """  # noqa: W505

    def __init__(
        self,
        model_name_or_path: str,
        max_length: Optional[int] = None,
        padding: Union[bool, str] = False,
        truncation: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def __call__(
        self, sentence: Union[str, list[str]], **kwargs: Any
    ) -> dict[str, torch.Tensor]:
        """Tokenize a text or a list of texts using the HuggingFace tokenizer.

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be tokenized.
        **kwargs : Any
            Additional arguments passed to the tokenizer :meth:`__call__` method.

        Returns
        -------
        dict[str, torch.Tensor]
            Tokenized sentence(s).

        Notes
        -----
        The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
        """
        batch_encoding = self.tokenizer(
            sentence,
            max_length=self.max_length,
            padding=self.padding,
            truncation=self.truncation,
            return_tensors="pt",
            **kwargs,
        )

        if isinstance(
            sentence, str
        ):  # remove batch dimension if input is a single sentence
            for key, value in batch_encoding.items():
                if isinstance(value, torch.Tensor):
                    batch_encoding[key] = torch.squeeze(value, 0)

        # use 'Modalities.TEXT' key for input_ids for consistency
        batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
        return dict(batch_encoding)
__call__
__call__(sentence, **kwargs)

Tokenize a text or a list of texts using the HuggingFace tokenizer.

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be tokenized.

required
**kwargs Any

Additional arguments passed to the tokenizer :meth:__call__ method.

{}

Returns:

Type Description
dict[str, Tensor]

Tokenized sentence(s).

Notes

The input_ids key is replaced with Modalities.TEXT for consistency.

Source code in mmlearn/datasets/processors/tokenizers.py
def __call__(
    self, sentence: Union[str, list[str]], **kwargs: Any
) -> dict[str, torch.Tensor]:
    """Tokenize a text or a list of texts using the HuggingFace tokenizer.

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be tokenized.
    **kwargs : Any
        Additional arguments passed to the tokenizer :meth:`__call__` method.

    Returns
    -------
    dict[str, torch.Tensor]
        Tokenized sentence(s).

    Notes
    -----
    The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
    """
    batch_encoding = self.tokenizer(
        sentence,
        max_length=self.max_length,
        padding=self.padding,
        truncation=self.truncation,
        return_tensors="pt",
        **kwargs,
    )

    if isinstance(
        sentence, str
    ):  # remove batch dimension if input is a single sentence
        for key, value in batch_encoding.items():
            if isinstance(value, torch.Tensor):
                batch_encoding[key] = torch.squeeze(value, 0)

    # use 'Modalities.TEXT' key for input_ids for consistency
    batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
    return dict(batch_encoding)
Img2Seq

Bases: Module

Convert a batch of images to a batch of sequences.

Parameters:

Name Type Description Default
img_size tuple of int

The size of the input image.

required
patch_size tuple of int

The size of the patch.

required
n_channels int

The number of channels in the input image.

required
d_model int

The dimension of the output sequence.

required
Source code in mmlearn/datasets/processors/tokenizers.py
class Img2Seq(nn.Module):
    """Convert a batch of images to a batch of sequences.

    Parameters
    ----------
    img_size : tuple of int
        The size of the input image.
    patch_size : tuple of int
        The size of the patch.
    n_channels : int
        The number of channels in the input image.
    d_model : int
        The dimension of the output sequence.

    """

    def __init__(
        self,
        img_size: tuple[int, int],
        patch_size: tuple[int, int],
        n_channels: int,
        d_model: int,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size

        nh, nw = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        n_tokens = nh * nw

        token_dim = patch_size[0] * patch_size[1] * n_channels
        self.linear = nn.Linear(token_dim, d_model)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_emb = nn.Parameter(torch.randn(n_tokens, d_model))

    def __call__(self, batch: torch.Tensor) -> torch.Tensor:
        """Convert a batch of images to a batch of sequences.

        Parameters
        ----------
        batch : torch.Tensor
            Batch of images of shape ``(b, h, w, c)`` where ``b`` is the batch size,
            ``h`` is the height, ``w`` is the width, and ``c`` is the number of
            channels.

        Returns
        -------
        torch.Tensor
            Batch of sequences of shape ``(b, s, d)`` where ``b`` is the batch size,
            ``s`` is the sequence length, and ``d`` is the dimension of the output
            sequence.
        """
        batch = _patchify(batch, self.patch_size)

        b, c, nh, nw, ph, pw = batch.shape

        # Flattening the patches
        batch = torch.permute(batch, [0, 2, 3, 4, 5, 1])
        batch = torch.reshape(batch, [b, nh * nw, ph * pw * c])

        batch = self.linear(batch)
        cls: torch.Tensor = self.cls_token.expand([b, -1, -1])
        emb: torch.Tensor = batch + self.pos_emb

        return torch.cat([cls, emb], axis=1)
__call__
__call__(batch)

Convert a batch of images to a batch of sequences.

Parameters:

Name Type Description Default
batch Tensor

Batch of images of shape (b, h, w, c) where b is the batch size, h is the height, w is the width, and c is the number of channels.

required

Returns:

Type Description
Tensor

Batch of sequences of shape (b, s, d) where b is the batch size, s is the sequence length, and d is the dimension of the output sequence.

Source code in mmlearn/datasets/processors/tokenizers.py
def __call__(self, batch: torch.Tensor) -> torch.Tensor:
    """Convert a batch of images to a batch of sequences.

    Parameters
    ----------
    batch : torch.Tensor
        Batch of images of shape ``(b, h, w, c)`` where ``b`` is the batch size,
        ``h`` is the height, ``w`` is the width, and ``c`` is the number of
        channels.

    Returns
    -------
    torch.Tensor
        Batch of sequences of shape ``(b, s, d)`` where ``b`` is the batch size,
        ``s`` is the sequence length, and ``d`` is the dimension of the output
        sequence.
    """
    batch = _patchify(batch, self.patch_size)

    b, c, nh, nw, ph, pw = batch.shape

    # Flattening the patches
    batch = torch.permute(batch, [0, 2, 3, 4, 5, 1])
    batch = torch.reshape(batch, [b, nh * nw, ph * pw * c])

    batch = self.linear(batch)
    cls: torch.Tensor = self.cls_token.expand([b, -1, -1])
    emb: torch.Tensor = batch + self.pos_emb

    return torch.cat([cls, emb], axis=1)
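
Example (illustrative, not part of the library source): it assumes _patchify accepts the channels-last (b, h, w, c) layout documented above; shapes are for a 32x32 RGB image with 8x8 patches.

import torch

from mmlearn.datasets.processors.tokenizers import Img2Seq

img2seq = Img2Seq(img_size=(32, 32), patch_size=(8, 8), n_channels=3, d_model=64)

images = torch.randn(4, 32, 32, 3)  # (b, h, w, c)
tokens = img2seq(images)
print(tokens.shape)                 # torch.Size([4, 17, 64]) -> 16 patches + a [CLS] token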
transforms

Custom transforms for datasets/inputs.

TrimText

Trim text strings as a preprocessing step before tokenization.

Parameters:

Name Type Description Default
trim_size int

The maximum length of the trimmed text.

required
Source code in mmlearn/datasets/processors/transforms.py
@store(group="datasets/transforms", provider="mmlearn")
class TrimText:
    """Trim text strings as a preprocessing step before tokenization.

    Parameters
    ----------
    trim_size : int
        The maximum length of the trimmed text.
    """

    def __init__(self, trim_size: int) -> None:
        self.trim_size = trim_size

    def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
        """Trim the given sentence(s).

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be trimmed.

        Returns
        -------
        Union[str, list[str]]
            Trimmed sentence(s).

        Raises
        ------
        TypeError
            If the input sentence is not a string or list of strings.
        """
        if not isinstance(sentence, (list, str)):
            raise TypeError(
                "Expected argument `sentence` to be a string or list of strings, "
                f"but got {type(sentence)}"
            )

        if isinstance(sentence, str):
            return sentence[: self.trim_size]

        for i, s in enumerate(sentence):
            sentence[i] = s[: self.trim_size]

        return sentence
__call__
__call__(sentence)

Trim the given sentence(s).

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be trimmed.

required

Returns:

Type Description
Union[str, list[str]]

Trimmed sentence(s).

Raises:

Type Description
TypeError

If the input sentence is not a string or list of strings.

Source code in mmlearn/datasets/processors/transforms.py
def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
    """Trim the given sentence(s).

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be trimmed.

    Returns
    -------
    Union[str, list[str]]
        Trimmed sentence(s).

    Raises
    ------
    TypeError
        If the input sentence is not a string or list of strings.
    """
    if not isinstance(sentence, (list, str)):
        raise TypeError(
            "Expected argument `sentence` to be a string or list of strings, "
            f"but got {type(sentence)}"
        )

    if isinstance(sentence, str):
        return sentence[: self.trim_size]

    for i, s in enumerate(sentence):
        sentence[i] = s[: self.trim_size]

    return sentence
repeat_interleave_batch
repeat_interleave_batch(x, b, repeat)

Repeat and interleave a tensor across the batch dimension.

Parameters:

Name Type Description Default
x Tensor

Input tensor to be repeated.

required
b int

Size of the batch to be repeated.

required
repeat int

Number of times to repeat each batch.

required

Returns:

Type Description
Tensor

The repeated tensor with shape adjusted for the batch.

Source code in mmlearn/datasets/processors/transforms.py
def repeat_interleave_batch(x: torch.Tensor, b: int, repeat: int) -> torch.Tensor:
    """Repeat and interleave a tensor across the batch dimension.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor to be repeated.
    b : int
        Size of the batch to be repeated.
    repeat : int
        Number of times to repeat each batch.

    Returns
    -------
    torch.Tensor
        The repeated tensor with shape adjusted for the batch.
    """
    n = len(x) // b
    return torch.cat(
        [
            torch.cat([x[i * b : (i + 1) * b] for _ in range(repeat)], dim=0)
            for i in range(n)
        ],
        dim=0,
    )
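
A small numeric illustration: with x holding two batches of size b=2 (rows A, B, C, D) and repeat=3, the result is A, B, A, B, A, B, C, D, C, D, C, D.

import torch

from mmlearn.datasets.processors.transforms import repeat_interleave_batch

x = torch.arange(4).unsqueeze(-1)            # shape (4, 1); rows 0..3 stand in for A..D
out = repeat_interleave_batch(x, b=2, repeat=3)
print(out.squeeze(-1).tolist())              # [0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3]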

sunrgbd

SUN RGB-D dataset.

SUNRGBDDataset

Bases: Dataset[Example]

SUN RGB-D dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, test)

Split of the dataset to use.

"train"
return_type (disparity, image)

Return type of the depth images. If "disparity", the depth images are converted to disparity similar to the ImageBind implementation. Otherwise, return the depth image as a 3-channel image.

"disparity"
rgb_transform Optional[Callable[[Image], Tensor]]

A callable that takes in an RGB PIL image and returns a transformed version of the image as a PyTorch tensor.

None
depth_transform Optional[Callable[[Image], Tensor]]

A callable that takes in a depth PIL image and returns a transformed version of the image as a PyTorch tensor.

None
References

[1] Repository followed to extract the dataset: https://github.com/TUI-NICR/nicr-scene-analysis-datasets

Source code in mmlearn/datasets/sunrgbd.py
@store(
    name="SUNRGBD",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("SUNRGBD_ROOT_DIR", MISSING),
)
class SUNRGBDDataset(Dataset[Example]):
    """SUN RGB-D dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "test"}, default="train"
        Split of the dataset to use.
    return_type : {"disparity", "image"}, default="disparity"
        Return type of the depth images. If "disparity", the depth images are
        converted to disparity similar to the ImageBind implementation.
        Otherwise, return the depth image as a 3-channel image.
    rgb_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in an RGB PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    depth_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in a depth PIL image and returns a transformed version
        of the image as a PyTorch tensor.

    References
    ----------
    .. [1] Repo followed to extract the dataset: https://github.com/TUI-NICR/nicr-scene-analysis-datasets
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test"] = "train",
        return_type: Literal["disparity", "image"] = "disparity",
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ) -> None:
        super().__init__()
        if not _OPENCV_AVAILABLE:
            raise ImportError(
                "SUN RGB-D dataset requires `opencv-python` which is not installed.",
            )

        self._validate_args(root_dir, split, rgb_transform, depth_transform)
        self.return_type = return_type

        self.root_dir = root_dir
        with open(os.path.join(root_dir, f"{split}.txt"), "r") as f:
            file_ids = f.readlines()
        file_ids = [f.strip() for f in file_ids]

        root_dir = os.path.join(root_dir, split)
        depth_files = [os.path.join(root_dir, "depth", f"{f}.png") for f in file_ids]
        rgb_files = [os.path.join(root_dir, "rgb", f"{f}.jpg") for f in file_ids]
        intrinsic_files = [
            os.path.join(root_dir, "intrinsics", f"{f}.txt") for f in file_ids
        ]

        sensor_types = [
            file.removeprefix(os.path.join(root_dir, "depth")).split(os.sep)[1]
            for file in depth_files
        ]

        label_files = [
            os.path.join(root_dir, "scene_class", f"{f}.txt") for f in file_ids
        ]
        labels = []
        for label_file in label_files:
            with open(label_file, "r") as file:  # noqa: SIM115
                labels.append(file.read().strip())
        labels = [label.replace("_", " ") for label in labels]
        labels = [
            _LABELS.index(label) if label in _LABELS else len(_LABELS)  # type: ignore
            for label in labels
        ]

        # remove the samples with classes not in _LABELS
        # this is to follow the same classes used in ImageBind
        if split == "test":
            valid_indices = [
                i
                for i, label in enumerate(labels)
                if label < len(_LABELS)  # type: ignore
            ]
            rgb_files = [rgb_files[i] for i in valid_indices]
            depth_files = [depth_files[i] for i in valid_indices]
            labels = [labels[i] for i in valid_indices]
            intrinsic_files = [intrinsic_files[i] for i in valid_indices]
            sensor_types = [sensor_types[i] for i in valid_indices]

        self.samples = list(
            zip(
                rgb_files,
                depth_files,
                labels,
                intrinsic_files,
                sensor_types,
                strict=False,
            )
        )

        self.rgb_transform = rgb_transform
        self.depth_transform = depth_transform

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.samples)

    def _validate_args(
        self,
        root_dir: str,
        split: str,
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]],
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]],
    ) -> None:
        """Validate arguments."""
        if not os.path.isdir(root_dir):
            raise NotADirectoryError(
                f"The given `root_dir` {root_dir} is not a directory",
            )
        if split not in ["train", "test"]:
            raise ValueError(
                f"Expected `split` to be one of `'train'` or `'test'`, but got {split}",
            )
        if rgb_transform is not None and not callable(rgb_transform):
            raise TypeError(
                f"Expected argument `rgb_transform` to be callable, but got {type(rgb_transform)}",
            )
        if depth_transform is not None and not callable(depth_transform):
            raise TypeError(
                f"Expected `depth_transform` to be callable, but got {type(depth_transform)}",
            )

    def __getitem__(self, idx: int) -> Example:
        """Return RGB and depth images at index `idx`."""
        # Read images
        rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
        if self.rgb_transform is not None:
            rgb_image = self.rgb_transform(to_pil_image(rgb_image))

        if self.return_type == "disparity":
            depth_image = convert_depth_to_disparity(
                self.samples[idx][1],
                self.samples[idx][3],
                self.samples[idx][4],
            )
        else:
            # Using cv2 instead of PIL Image since we use PNG grayscale images.
            depth_image = cv2.imread(
                self.samples[idx][1],
                cv2.IMREAD_GRAYSCALE,
            )
            # Make a 3-channel depth image to enable passing to a pretrained ViT.
            depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

        if self.depth_transform is not None:
            depth_image = self.depth_transform(to_pil_image(depth_image))

        return Example(
            {
                Modalities.RGB.name: rgb_image,
                Modalities.DEPTH.name: depth_image,
                EXAMPLE_INDEX_KEY: idx,
                Modalities.DEPTH.target: self.samples[idx][2],
            }
        )
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/sunrgbd.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.samples)
__getitem__
__getitem__(idx)

Return RGB and depth images at index idx.

Source code in mmlearn/datasets/sunrgbd.py
def __getitem__(self, idx: int) -> Example:
    """Return RGB and depth images at index `idx`."""
    # Read images
    rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
    if self.rgb_transform is not None:
        rgb_image = self.rgb_transform(to_pil_image(rgb_image))

    if self.return_type == "disparity":
        depth_image = convert_depth_to_disparity(
            self.samples[idx][1],
            self.samples[idx][3],
            self.samples[idx][4],
        )
    else:
        # Using cv2 instead of PIL Image since we use PNG grayscale images.
        depth_image = cv2.imread(
            self.samples[idx][1],
            cv2.IMREAD_GRAYSCALE,
        )
        # Make a 3-channel depth image to enable passing to a pretrained ViT.
        depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

    if self.depth_transform is not None:
        depth_image = self.depth_transform(to_pil_image(depth_image))

    return Example(
        {
            Modalities.RGB.name: rgb_image,
            Modalities.DEPTH.name: depth_image,
            EXAMPLE_INDEX_KEY: idx,
            Modalities.DEPTH.target: self.samples[idx][2],
        }
    )
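
Example (illustrative, not part of the library source): it assumes SUNRGBD_ROOT_DIR points at a dataset extracted with the repository referenced above, i.e. a directory containing train.txt/test.txt plus per-split rgb/, depth/, intrinsics/ and scene_class/ folders, and that opencv-python is installed.

import os

from torchvision import transforms

from mmlearn.datasets.sunrgbd import SUNRGBDDataset

dataset = SUNRGBDDataset(
    root_dir=os.environ["SUNRGBD_ROOT_DIR"],
    split="train",
    return_type="disparity",
    rgb_transform=transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    ),
)

example = dataset[0]  # an Example holding the RGB image, the disparity image,
                      # the scene-class target and the example index
print(len(dataset))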
convert_depth_to_disparity
convert_depth_to_disparity(
    depth_file,
    intrinsics_file,
    sensor_type,
    min_depth=0.01,
    max_depth=50,
)

Load depth file and convert to disparity.

Parameters:

Name Type Description Default
depth_file str

Path to the depth file.

required
intrinsics_file str

Path to a txt file supplied with SUN RGB-D that contains sensor information; it can be found at os.path.join(root_dir, room_name, "intrinsics.txt").

required
sensor_type str

Sensor type of the depth file.

required
min_depth float

Minimum depth value to clip the depth image.

0.01
max_depth int

Maximum depth value to clip the depth image.

50

Returns:

Type Description
Tensor

Disparity image from the depth image following the ImageBind implementation.

Source code in mmlearn/datasets/sunrgbd.py
def convert_depth_to_disparity(
    depth_file: str,
    intrinsics_file: str,
    sensor_type: str,
    min_depth: float = 0.01,
    max_depth: int = 50,
) -> torch.Tensor:
    """Load depth file and convert to disparity.

    Parameters
    ----------
    depth_file : str
        Path to the depth file.
    intrinsics_file : str
        A text file supplied with SUNRGBD that contains the sensor information.
        It can be found at the path: os.path.join(root_dir, room_name, "intrinsics.txt")
    sensor_type : str
        Sensor type of the depth file.
    min_depth : float, default=0.01
        Minimum depth value to clip the depth image.
    max_depth : int, default=50
        Maximum depth value to clip the depth image.

    Returns
    -------
    torch.Tensor
        Disparity image from the depth image following the ImageBind implementation.
    """
    with open(intrinsics_file, "r") as fh:
        lines = fh.readlines()
        focal_length = float(lines[0].strip().split()[0])
    baseline = sensor_to_params[sensor_type]["baseline"]
    depth_image = np.array(PILImage.open(depth_file))
    depth = np.array(depth_image).astype(np.float32)
    depth_in_meters = depth / 1000.0
    if min_depth is not None:
        depth_in_meters = depth_in_meters.clip(min=min_depth, max=max_depth)
    disparity = baseline * focal_length / depth_in_meters
    return torch.from_numpy(disparity).float()
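
Example usage (a minimal sketch; the file paths and sensor key below are hypothetical placeholders, and sensor_to_params must contain a baseline entry for the chosen sensor type):

from mmlearn.datasets.sunrgbd import convert_depth_to_disparity

# Hypothetical SUN RGB-D files -- substitute real paths from the dataset.
disparity = convert_depth_to_disparity(
    depth_file="SUNRGBD/kv1/NYUdata/NYU0001/depth/depth.png",
    intrinsics_file="SUNRGBD/kv1/NYUdata/NYU0001/intrinsics.txt",  # focal length on the first line
    sensor_type="kv1",  # must be a key in `sensor_to_params`
)
print(disparity.shape, disparity.dtype)  # (H, W), torch.float32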

hf_utils

Utilities for loading components from the HuggingFace transformers library.

load_huggingface_model

load_huggingface_model(
    model_type,
    model_name_or_path,
    load_pretrained_weights=True,
    get_model_attr=None,
    model_config_kwargs=None,
    config_type=None,
)

Load a model from the HuggingFace transformers library.

Parameters:

Name Type Description Default
model_type Type[_BaseAutoModelClass]

The model class to instantiate e.g. transformers.AutoModel.

required
model_name_or_path str

The model name or path to load the model from.

required
load_pretrained_weights bool

Whether to load the pretrained weights or not. If false, the argument pretrained_model_name_or_path will be used to get the model configuration and the model will be initialized with random weights.

True
get_model_attr Optional[str]

If not None, the attribute of the model to return. For example, if the model is a transformers.AutoModel and get_model_attr='encoder', the encoder part of the model will be returned. If None, the full model will be returned.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration. The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

None
config_type Optional[Type[PretrainedConfig]]

The class of the configuration to use. If None, transformers.AutoConfig will be used.

None

Returns:

Type Description
Module

The instantiated model.

Source code in mmlearn/hf_utils.py
def load_huggingface_model(
    model_type: Type[_BaseAutoModelClass],
    model_name_or_path: str,
    load_pretrained_weights: bool = True,
    get_model_attr: Optional[str] = None,
    model_config_kwargs: Optional[dict[str, Any]] = None,
    config_type: Optional[Type[PretrainedConfig]] = None,
) -> nn.Module:
    """Load a model from the HuggingFace ``transformers`` library.

    Parameters
    ----------
    model_type : Type[_BaseAutoModelClass]
        The model class to instantiate e.g. ``transformers.AutoModel``.
    model_name_or_path : str
        The model name or path to load the model from.
    load_pretrained_weights : bool, optional, default=True
        Whether to load the pretrained weights or not. If false, the argument
        ``pretrained_model_name_or_path`` will be used to get the model configuration
        and the model will be initialized with random weights.
    get_model_attr : Optional[str], optional, default=None
        If not None, the attribute of the model to return. For example, if the model
        is a ``transformers.AutoModel`` and ``get_model_attr='encoder'``, the
        encoder part of the model will be returned. If ``None``, the full model
        will be returned.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.
        The values in kwargs of any keys which are configuration attributes will
        be used to override the loaded values. Behavior concerning key/value pairs
        whose keys are *not* configuration attributes is controlled by the
        ``return_unused_kwargs`` keyword parameter.
    config_type : Optional[Type[PretrainedConfig]], optional, default=None
        The class of the configuration to use. If None, ``transformers.AutoConfig``
        will be used.

    Returns
    -------
    torch.nn.Module
        The instantiated model.
    """
    model_config_kwargs = model_config_kwargs or {}
    if load_pretrained_weights:
        model = model_type.from_pretrained(model_name_or_path, **model_config_kwargs)
    else:
        if config_type is None:
            config_type = AutoConfig
        config, kwargs = config_type.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path,
            return_unused_kwargs=True,
            **model_config_kwargs,
        )
        model = model_type.from_config(config, **kwargs)

    if get_model_attr is not None and hasattr(model, get_model_attr):
        model = getattr(model, get_model_attr)

    return model
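
Example usage (an illustrative sketch; the checkpoint name is a placeholder, and get_model_attr="encoder" assumes the loaded architecture exposes an encoder submodule, as BERT-style models do):

import transformers
from mmlearn.hf_utils import load_huggingface_model

# Load pretrained weights and return only the `encoder` submodule of the model.
encoder = load_huggingface_model(
    transformers.AutoModel,
    "bert-base-uncased",  # illustrative checkpoint
    load_pretrained_weights=True,
    get_model_attr="encoder",
)

# Same architecture, but randomly initialized (only the config is fetched).
random_init = load_huggingface_model(
    transformers.AutoModel,
    "bert-base-uncased",
    load_pretrained_weights=False,
)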

modules

Reusable components for building tasks.

ema

Exponential Moving Average (EMA) module.

ExponentialMovingAverage

Exponential Moving Average (EMA) for the input model.

At each step, the parameters of the EMA model are updated as a weighted average of the model's parameters.

Parameters:

Name Type Description Default
model Module

The model to apply EMA to.

required
ema_decay float

The initial decay value for EMA.

required
ema_end_decay float

The final decay value for EMA.

required
ema_anneal_end_step int

The number of steps to anneal the decay from ema_decay to ema_end_decay.

required
skip_keys Optional[Union[list[str], Set[str]]]

The keys to skip in the EMA update. These parameters will be copied directly from the model to the EMA model.

None

Raises:

Type Description
RuntimeError

If a deep copy of the model cannot be created.

Source code in mmlearn/modules/ema.py
class ExponentialMovingAverage:
    """Exponential Moving Average (EMA) for the input model.

    At each step, the parameters of the EMA model are updated as a weighted average
    of the model's parameters.

    Parameters
    ----------
    model : torch.nn.Module
        The model to apply EMA to.
    ema_decay : float
        The initial decay value for EMA.
    ema_end_decay : float
        The final decay value for EMA.
    ema_anneal_end_step : int
        The number of steps to anneal the decay from ``ema_decay`` to ``ema_end_decay``.
    skip_keys : Optional[Union[list[str], Set[str]]], optional, default=None
        The keys to skip in the EMA update. These parameters will be copied directly
        from the model to the EMA model.

    Raises
    ------
    RuntimeError
        If a deep copy of the model cannot be created.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        ema_decay: float,
        ema_end_decay: float,
        ema_anneal_end_step: int,
        skip_keys: Optional[Union[list[str], Set[str]]] = None,
    ) -> None:
        self.model = self.deepcopy_model(model)

        self.skip_keys: Union[list[str], set[str]] = skip_keys or set()
        self.num_updates = 0
        self.decay = ema_decay  # stores the current decay value
        self.ema_decay = ema_decay
        self.ema_end_decay = ema_end_decay
        self.ema_anneal_end_step = ema_anneal_end_step

        self._model_configured = False

    @staticmethod
    def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
        """Deep copy the model.

        Parameters
        ----------
        model : torch.nn.Module
            The model to copy.

        Returns
        -------
        torch.nn.Module
            The copied model.

        Raises
        ------
        RuntimeError
            If the model cannot be copied.
        """
        try:
            return copy.deepcopy(model)
        except RuntimeError as e:
            raise RuntimeError("Unable to copy the model ", e) from e

    @staticmethod
    def get_annealed_rate(
        start: float,
        end: float,
        curr_step: int,
        total_steps: int,
    ) -> float:
        """Calculate EMA annealing rate."""
        r = end - start
        pct_remaining = 1 - curr_step / total_steps
        return end - r * pct_remaining

    def configure_model(self, device_id: Union[int, torch.device]) -> None:
        """Configure the model for EMA."""
        if self._model_configured:
            return

        self.model.requires_grad_(False)
        self.model.to(device_id)

        self._model_configured = True

    def step(self, new_model: torch.nn.Module) -> None:
        """Perform single EMA update step."""
        if not self._model_configured:
            raise RuntimeError(
                "Model is not configured for EMA. Call `configure_model` first."
            )

        self._update_weights(new_model)
        self._update_ema_decay()

    def restore(self, model: torch.nn.Module) -> torch.nn.Module:
        """Reassign weights from another model.

        Parameters
        ----------
        model : torch.nn.Module
            Model to load weights from.

        Returns
        -------
        torch.nn.Module
            model with new weights
        """
        d = self.model.state_dict()
        model.load_state_dict(d, strict=False)
        return model

    def state_dict(self) -> dict[str, Any]:
        """Return the state dict of the model."""
        return self.model.state_dict()  # type: ignore[no-any-return]

    @torch.no_grad()  # type: ignore[misc]
    def _update_weights(self, new_model: torch.nn.Module) -> None:
        if self.decay < 1:
            ema_state_dict = {}
            ema_params = self.model.state_dict()

            for key, param in new_model.state_dict().items():
                ema_param = ema_params[key].float()

                if param.shape != ema_param.shape:
                    raise ValueError(
                        "Incompatible tensor shapes between student param and teacher param"
                        + "{} vs. {}".format(param.shape, ema_param.shape)
                    )

                if key in self.skip_keys or not param.requires_grad:
                    ema_param = param.to(dtype=ema_param.dtype).clone()
                else:
                    ema_param.mul_(self.decay)
                    ema_param.add_(
                        param.to(dtype=ema_param.dtype),
                        alpha=1 - self.decay,
                    )
                ema_state_dict[key] = ema_param

            self.model.load_state_dict(ema_state_dict, strict=False)
            self.num_updates += 1
        else:
            rank_zero_warn(
                "Exponential Moving Average decay is 1.0, no update is applied to the model.",
                stacklevel=1,
                category=UserWarning,
            )

    def _update_ema_decay(self) -> None:
        if self.ema_decay != self.ema_end_decay:
            if self.num_updates >= self.ema_anneal_end_step:
                decay = self.ema_end_decay
            else:
                decay = self.get_annealed_rate(
                    self.ema_decay,
                    self.ema_end_decay,
                    self.num_updates,
                    self.ema_anneal_end_step,
                )
            self.decay = decay
deepcopy_model staticmethod
deepcopy_model(model)

Deep copy the model.

Parameters:

Name Type Description Default
model Module

The model to copy.

required

Returns:

Type Description
Module

The copied model.

Raises:

Type Description
RuntimeError

If the model cannot be copied.

Source code in mmlearn/modules/ema.py
@staticmethod
def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
    """Deep copy the model.

    Parameters
    ----------
    model : torch.nn.Module
        The model to copy.

    Returns
    -------
    torch.nn.Module
        The copied model.

    Raises
    ------
    RuntimeError
        If the model cannot be copied.
    """
    try:
        return copy.deepcopy(model)
    except RuntimeError as e:
        raise RuntimeError("Unable to copy the model ", e) from e
get_annealed_rate staticmethod
get_annealed_rate(start, end, curr_step, total_steps)

Calculate EMA annealing rate.

Source code in mmlearn/modules/ema.py
@staticmethod
def get_annealed_rate(
    start: float,
    end: float,
    curr_step: int,
    total_steps: int,
) -> float:
    """Calculate EMA annealing rate."""
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining
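
For illustration (arbitrary numbers), halfway through a 1000-step annealing schedule from a decay of 0.996 to 1.0:

from mmlearn.modules.ema import ExponentialMovingAverage

# r = 1.0 - 0.996 = 0.004; pct_remaining = 1 - 500/1000 = 0.5
# annealed decay = 1.0 - 0.004 * 0.5 = 0.998
ExponentialMovingAverage.get_annealed_rate(start=0.996, end=1.0, curr_step=500, total_steps=1000)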
configure_model
configure_model(device_id)

Configure the model for EMA.

Source code in mmlearn/modules/ema.py
def configure_model(self, device_id: Union[int, torch.device]) -> None:
    """Configure the model for EMA."""
    if self._model_configured:
        return

    self.model.requires_grad_(False)
    self.model.to(device_id)

    self._model_configured = True
step
step(new_model)

Perform single EMA update step.

Source code in mmlearn/modules/ema.py
def step(self, new_model: torch.nn.Module) -> None:
    """Perform single EMA update step."""
    if not self._model_configured:
        raise RuntimeError(
            "Model is not configured for EMA. Call `configure_model` first."
        )

    self._update_weights(new_model)
    self._update_ema_decay()
restore
restore(model)

Reassign weights from another model.

Parameters:

Name Type Description Default
model Module

Model to load weights from.

required

Returns:

Type Description
Module

model with new weights

Source code in mmlearn/modules/ema.py
def restore(self, model: torch.nn.Module) -> torch.nn.Module:
    """Reassign weights from another model.

    Parameters
    ----------
    model : torch.nn.Module
        Model to load weights from.

    Returns
    -------
    torch.nn.Module
        model with new weights
    """
    d = self.model.state_dict()
    model.load_state_dict(d, strict=False)
    return model
state_dict
state_dict()

Return the state dict of the model.

Source code in mmlearn/modules/ema.py
def state_dict(self) -> dict[str, Any]:
    """Return the state dict of the model."""
    return self.model.state_dict()  # type: ignore[no-any-return]
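
A minimal end-to-end sketch of the intended usage (the toy model, optimizer, and decay values below are placeholders, not recommended settings):

import torch
from mmlearn.modules.ema import ExponentialMovingAverage

model = torch.nn.Linear(16, 16)  # stand-in for a real encoder
ema = ExponentialMovingAverage(
    model, ema_decay=0.996, ema_end_decay=1.0, ema_anneal_end_step=1000
)
ema.configure_model(torch.device("cpu"))  # freeze the EMA copy and move it to a device

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(10):
    loss = model(torch.randn(4, 16)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.step(model)  # update the EMA weights after each optimizer step

# Load the EMA weights into another model instance for evaluation.
eval_model = ema.restore(torch.nn.Linear(16, 16))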

encoders

Encoders.

HFCLIPTextEncoder

Bases: Module

Wrapper around the CLIPTextModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",  # required for `peft_config` to be converted to a `PeftConfig` object
)
class HFCLIPTextEncoder(nn.Module):
    """Wrapper around the ``CLIPTextModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden
            states, and the attention weights, if ``output_attentions`` is set
            to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get(
                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
            ),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden
        states, and the attention weights, if ``output_attentions`` is set
        to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        ),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
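
Example usage (a sketch, not a prescribed workflow: the Modalities import path is assumed, and output_hidden_states=True is passed because forward reads outputs.hidden_states[-1]):

import torch
from transformers import CLIPTokenizerFast
from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.modules.encoders.clip import HFCLIPTextEncoder

encoder = HFCLIPTextEncoder(
    "openai/clip-vit-base-patch16",
    model_config_kwargs={"output_hidden_states": True},
)
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")

batch = tokenizer(["a photo of a cat"], return_tensors="pt")
inputs = {
    Modalities.TEXT.name: batch["input_ids"],
    Modalities.TEXT.attention_mask: batch["attention_mask"],
}
with torch.no_grad():
    output = encoder(inputs)
print(output.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)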
HFCLIPTextEncoderWithProjection

Bases: Module

Wrapper around the CLIPTextModelWithProjection from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the text. If False the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPTextEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPTextModelWithProjection`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the text. If ``False`` the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPTextConfig,
        )

        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The text embeddings. Will be a tuple with a single element.
        """
        input_ids = inputs[Modalities.TEXT.name]
        attention_mask: Optional[torch.Tensor] = inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        )
        position_ids = inputs.get("position_ids")

        if self.use_all_token_embeddings:
            text_outputs = self.model.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            )
            # TODO: add more options for pooling before projection
            text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
        else:
            text_embeds = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            ).text_embeds

        return (text_embeds,)
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
tuple[Tensor]

The text embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The text embeddings. Will be a tuple with a single element.
    """
    input_ids = inputs[Modalities.TEXT.name]
    attention_mask: Optional[torch.Tensor] = inputs.get(
        "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
    )
    position_ids = inputs.get("position_ids")

    if self.use_all_token_embeddings:
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        # TODO: add more options for pooling before projection
        text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
    else:
        text_embeds = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        ).text_embeds

    return (text_embeds,)
HFCLIPVisionEncoder

Bases: Module

Wrapper around the CLIPVisionModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias Optional[float]

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoder(nn.Module):
    """Wrapper around the ``CLIPVisionModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : Optional[float], optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model.vision_model
        self.pooling_layer = pooling_layer
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.

        """
        # FIXME: handle other vision modalities
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.encoder(
            inputs_embeds=hidden_states,
            output_attentions=inputs.get(
                "output_attentions", self.model.config.output_attentions
            ),
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = encoder_outputs[0]
        if self.pooling_layer is not None:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.

    """
    # FIXME: handle other vision modalities
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.encoder(
        inputs_embeds=hidden_states,
        output_attentions=inputs.get(
            "output_attentions", self.model.config.output_attentions
        ),
        output_hidden_states=True,
        return_dict=True,
    )

    last_hidden_state = encoder_outputs[0]
    if self.pooling_layer is not None:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
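
Example usage (a sketch; the Modalities import path is assumed and the random tensor stands in for images preprocessed by the CLIP image processor):

import torch
from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoder

encoder = HFCLIPVisionEncoder("openai/clip-vit-base-patch16")
# Two dummy 224x224 RGB images; real inputs come from the CLIP image processor.
inputs = {Modalities.RGB.name: torch.randn(2, 3, 224, 224)}
with torch.no_grad():
    output = encoder(inputs)
print(output.last_hidden_state.shape)  # (2, num_patches + 1, hidden_size)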
HFCLIPVisionEncoderWithProjection

Bases: Module

Wrapper around the CLIPVisionModelWithProjection class from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the image. If False, the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias float

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs dict[str, Any]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPVisionModelWithProjection`` class from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the image. If ``False`` the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : float, optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : dict[str, Any], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPVisionConfig,
        )

        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The image embeddings. Will be a tuple with a single element.
        """
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.vision_model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.vision_model.encoder(
            inputs_embeds=hidden_states, return_dict=True
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        if self.use_all_token_embeddings:
            pooled_output = last_hidden_state
        else:
            pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.model.vision_model.post_layernorm(pooled_output)

        return (self.model.visual_projection(pooled_output),)
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
tuple[Tensor]

The image embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The image embeddings. Will be a tuple with a single element.
    """
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.vision_model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.vision_model.encoder(
        inputs_embeds=hidden_states, return_dict=True
    )

    last_hidden_state = encoder_outputs.last_hidden_state
    if self.use_all_token_embeddings:
        pooled_output = last_hidden_state
    else:
        pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.model.vision_model.post_layernorm(pooled_output)

    return (self.model.visual_projection(pooled_output),)
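
The two projection encoders can be combined into a CLIP-style similarity computation. The sketch below is illustrative only; the Modalities import path is assumed and a random tensor stands in for a preprocessed image:

import torch
from transformers import CLIPTokenizerFast
from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.modules.encoders.clip import (
    HFCLIPTextEncoderWithProjection,
    HFCLIPVisionEncoderWithProjection,
)

text_encoder = HFCLIPTextEncoderWithProjection("openai/clip-vit-base-patch16")
vision_encoder = HFCLIPVisionEncoderWithProjection("openai/clip-vit-base-patch16")
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")

batch = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
with torch.no_grad():
    (text_embeds,) = text_encoder(
        {
            Modalities.TEXT.name: batch["input_ids"],
            Modalities.TEXT.attention_mask: batch["attention_mask"],
        }
    )
    (image_embeds,) = vision_encoder({Modalities.RGB.name: torch.randn(1, 3, 224, 224)})

# Cosine similarity between the projected image embedding and each text embedding.
similarities = torch.nn.functional.cosine_similarity(image_embeds, text_embeds)
print(similarities.shape)  # (2,)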
HFTextEncoder

Bases: Module

Wrapper around huggingface models in the AutoModelForTextEncoding class.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Raises:

Type Description
ValueError

If the model is a decoder model or if freezing individual layers is not supported for the model type.

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/text.py
@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
class HFTextEncoder(nn.Module):
    """Wrapper around huggingface models in the ``AutoModelForTextEncoding`` class.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Raises
    ------
    ValueError
        If the model is a decoder model or if freezing individual layers is not
        supported for the model type.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.


    """

    def __init__(  # noqa: PLR0912
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ):
        super().__init__()
        if model_config_kwargs is None:
            model_config_kwargs = {}
        model_config_kwargs["output_hidden_states"] = True
        model_config_kwargs["add_pooling_layer"] = False
        model = hf_utils.load_huggingface_model(
            AutoModelForTextEncoding,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        if hasattr(model.config, "is_decoder") and model.config.is_decoder:
            raise ValueError("Model is a decoder. Only encoder models are supported.")

        if not pretrained and freeze_layers:
            rank_zero_warn(
                "Freezing layers when loading a model with random weights may lead to "
                "unexpected behavior. Consider setting `freeze_layers=False` if "
                "`pretrained=False`.",
            )

        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "LayerNorm" in name else False
                )

        if isinstance(
            freeze_layers, (float, int, list)
        ) and model.config.model_type in ["flaubert", "xlm"]:
            # flaubert and xlm models have a different architecture that does not
            # support freezing individual layers in the same way as other models
            raise ValueError(
                f"Freezing individual layers is not supported for {model.config.model_type} "
                "models. Please use `freeze_layers=False` or `freeze_layers=True`."
            )

        # get list of layers
        embeddings = model.embeddings
        encoder = getattr(model, "encoder", None) or getattr(
            model, "transformer", model
        )
        encoder_layers = (
            getattr(encoder, "layer", None)
            or getattr(encoder, "layers", None)
            or getattr(encoder, "block", None)
        )
        if encoder_layers is None and hasattr(encoder, "albert_layer_groups"):
            encoder_layers = [
                layer
                for group in encoder.albert_layer_groups
                for layer in group.albert_layers
            ]
        modules = [embeddings]
        if encoder_layers is not None and isinstance(encoder_layers, list):
            modules.extend(encoder_layers)

        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "LayerNorm" in name else False
                        )

        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get(
                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
            ),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/text.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        ),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
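
Example usage (a sketch; the checkpoint name is a placeholder and the Modalities import path is assumed):

import torch
from transformers import AutoTokenizer
from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.modules.encoders.text import HFTextEncoder

encoder = HFTextEncoder("bert-base-uncased")  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

batch = tokenizer(["multimodal representation learning"], return_tensors="pt")
inputs = {
    Modalities.TEXT.name: batch["input_ids"],
    Modalities.TEXT.attention_mask: batch["attention_mask"],
}
with torch.no_grad():
    output = encoder(inputs)
print(output.last_hidden_state.shape)  # (1, sequence_length, hidden_size)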
TimmViT

Bases: Module

Vision Transformer model from timm.

Parameters:

Name Type Description Default
model_name str

The name of the model to use.

required
modality str

The modality of the input data. This allows the model to be used with different image modalities, e.g. RGB, Depth, etc.

"RGB"
projection_dim int

The dimension of the projection head.

768
pretrained bool

Whether to use the pretrained weights.

True
freeze_layers Union[int, float, list[int], bool]

Whether to freeze the layers.

False
freeze_layer_norm bool

Whether to freeze the layer norm.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_kwargs Optional[dict[str, Any]]

Additional keyword arguments for the model.

None
Source code in mmlearn/modules/encoders/vision.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name="vit_base_patch16_224",
    hydra_convert="object",
)
class TimmViT(nn.Module):
    """Vision Transformer model from timm.

    Parameters
    ----------
    model_name : str
        The name of the model to use.
    modality : str, default="RGB"
        The modality of the input data. This allows this model to be used with different
        image modalities e.g. RGB, Depth, etc.
    projection_dim : int, default=768
        The dimension of the projection head.
    pretrained : bool, default=True
        Whether to use the pretrained weights.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze the layers.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer norm.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_kwargs : Optional[dict[str, Any]], default=None
        Additional keyword arguments for the model.
    """

    def __init__(
        self,
        model_name: str,
        modality: str = "RGB",
        projection_dim: int = 768,
        pretrained: bool = True,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.modality = Modalities.get_modality(modality)
        if model_kwargs is None:
            model_kwargs = {}

        self.model: TimmVisionTransformer = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=projection_dim,
            **model_kwargs,
        )
        assert isinstance(self.model, TimmVisionTransformer), (
            f"Model {model_name} is not a Vision Transformer. "
            "Please provide a model name that corresponds to a Vision Transformer."
        )

        self._freeze_layers(freeze_layers, freeze_layer_norm)

        if peft_config is not None:
            self.model = hf_utils._wrap_peft_model(self.model, peft_config)

    def _freeze_layers(
        self, freeze_layers: Union[int, float, list[int], bool], freeze_layer_norm: bool
    ) -> None:
        """Freeze the layers of the model.

        Parameters
        ----------
        freeze_layers : Union[int, float, list[int], bool]
            Whether to freeze the layers.
        freeze_layer_norm : bool
            Whether to freeze the layer norm.
        """
        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in self.model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "norm" in name else False
                )

        modules = [self.model.patch_embed, *self.model.blocks, self.model.norm]
        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "norm" in name else False
                        )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model.
        """
        x = inputs[self.modality.name]
        last_hidden_state, hidden_states = self.model.forward_intermediates(
            x, output_fmt="NLC"
        )
        last_hidden_state = self.model.forward_head(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states
        )

    def get_intermediate_layers(
        self, inputs: dict[str, Any], n: int = 1
    ) -> list[torch.Tensor]:
        """Get the output of the intermediate layers.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the ``Modalities.RGB``
            key.
        n : int, default=1
            The number of intermediate layers to return.

        Returns
        -------
        list[torch.Tensor]
            The outputs of the last n intermediate layers.
        """
        return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore

    def get_patch_info(self) -> tuple[int, int]:
        """Get patch size and number of patches.

        Returns
        -------
        tuple[int, int]
            Patch size and number of patches.
        """
        patch_size = self.model.patch_embed.patch_size[0]
        num_patches = self.model.patch_embed.num_patches
        return patch_size, num_patches
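Example:

A minimal usage sketch (not part of the upstream reference). The Modalities import path is an assumption; the forward pass is keyed by the configured modality, exactly as in the source above.

import torch
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.vision import TimmViT

# Build a ViT-B/16 encoder with a 512-dimensional projection head.
encoder = TimmViT("vit_base_patch16_224", projection_dim=512)

# The forward pass expects a dict keyed by the configured modality.
batch = {Modalities.RGB.name: torch.randn(2, 3, 224, 224)}
output = encoder(batch)
print(output.last_hidden_state.shape)

# Patch geometry of the underlying model, e.g. (16, 196) for 224x224 inputs.
patch_size, num_patches = encoder.get_patch_info()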
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model.

Source code in mmlearn/modules/encoders/vision.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model.
    """
    x = inputs[self.modality.name]
    last_hidden_state, hidden_states = self.model.forward_intermediates(
        x, output_fmt="NLC"
    )
    last_hidden_state = self.model.forward_head(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state, hidden_states=hidden_states
    )
get_intermediate_layers
get_intermediate_layers(inputs, n=1)

Get the output of the intermediate layers.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required
n int

The number of intermediate layers to return.

1

Returns:

Type Description
list[Tensor]

The outputs of the last n intermediate layers.

Source code in mmlearn/modules/encoders/vision.py
def get_intermediate_layers(
    self, inputs: dict[str, Any], n: int = 1
) -> list[torch.Tensor]:
    """Get the output of the intermediate layers.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the ``Modalities.RGB``
        key.
    n : int, default=1
        The number of intermediate layers to return.

    Returns
    -------
    list[torch.Tensor]
        The outputs of the last n intermediate layers.
    """
    return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore
get_patch_info
get_patch_info()

Get patch size and number of patches.

Returns:

Type Description
tuple[int, int]

Patch size and number of patches.

Source code in mmlearn/modules/encoders/vision.py
def get_patch_info(self) -> tuple[int, int]:
    """Get patch size and number of patches.

    Returns
    -------
    tuple[int, int]
        Patch size and number of patches.
    """
    patch_size = self.model.patch_embed.patch_size[0]
    num_patches = self.model.patch_embed.num_patches
    return patch_size, num_patches
clip

Wrappers and interfaces for CLIP models.

HFCLIPTextEncoder

Bases: Module

Wrapper around the CLIPTextModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft (https://huggingface.co/docs/peft/index) library to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",  # required for `peft_config` to be converted to a `PeftConfig` object
)
class HFCLIPTextEncoder(nn.Module):
    """Wrapper around the ``CLIPTextModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden
            states, and the attention weights, if ``output_attentions`` is set
            to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            # dict lookup with fallback; `or` would call bool() on a tensor
            attention_mask=inputs.get(
                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
            ),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
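Example:

A minimal sketch pairing this encoder with a HuggingFace tokenizer (not part of the upstream reference); the Modalities import path is an assumption.

import torch
from transformers import AutoTokenizer
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPTextEncoder

encoder = HFCLIPTextEncoder("openai/clip-vit-base-patch16")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")

tokens = tokenizer(
    ["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt"
)
inputs = {
    Modalities.TEXT.name: tokens["input_ids"],
    "attention_mask": tokens["attention_mask"],
}
with torch.no_grad():
    output = encoder(inputs)
print(output.last_hidden_state.shape)  # (batch, seq_len, hidden_size)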
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden
        states, and the attention weights, if ``output_attentions`` is set
        to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        # dict lookup with fallback; `or` would call bool() on a tensor
        attention_mask=inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        ),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
HFCLIPVisionEncoder

Bases: Module

Wrapper around the CLIPVisionModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias Optional[float]

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft (https://huggingface.co/docs/peft/index) library to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoder(nn.Module):
    """Wrapper around the ``CLIPVisionModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : Optional[float], optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model.vision_model
        self.pooling_layer = pooling_layer
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.

        """
        # FIXME: handle other vision modalities
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.encoder(
            inputs_embeds=hidden_states,
            output_attentions=inputs.get(
                "output_attentions", self.model.config.output_attentions
            ),
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = encoder_outputs[0]
        if self.pooling_layer is not None:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
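Example:

A minimal sketch (not part of the upstream reference) showing patch dropout being enabled at construction time; the Modalities import path is an assumption.

import torch
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoder

# Drop 25% of patch embeddings during training as a form of regularization.
encoder = HFCLIPVisionEncoder(
    "openai/clip-vit-base-patch16", patch_dropout_rate=0.25
)

inputs = {Modalities.RGB.name: torch.randn(2, 3, 224, 224)}
output = encoder(inputs)
# Sequence length reflects any patches removed by patch dropout in training mode.
print(output.last_hidden_state.shape)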
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.

    """
    # FIXME: handle other vision modalities
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.encoder(
        inputs_embeds=hidden_states,
        output_attentions=inputs.get(
            "output_attentions", self.model.config.output_attentions
        ),
        output_hidden_states=True,
        return_dict=True,
    )

    last_hidden_state = encoder_outputs[0]
    if self.pooling_layer is not None:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
HFCLIPTextEncoderWithProjection

Bases: Module

Wrapper around the CLIPTextModelWithProjection from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the text. If False the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft (https://huggingface.co/docs/peft/index) library to use to wrap the model for parameter-efficient finetuning.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPTextEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPTextModelWithProjection`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the text. If ``False`` the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPTextConfig,
        )

        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The text embeddings. Will be a tuple with a single element.
        """
        input_ids = inputs[Modalities.TEXT.name]
        attention_mask: Optional[torch.Tensor] = inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        )
        position_ids = inputs.get("position_ids")

        if self.use_all_token_embeddings:
            text_outputs = self.model.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            )
            # TODO: add more options for pooling before projection
            text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
        else:
            text_embeds = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            ).text_embeds

        return (text_embeds,)
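Example:

A minimal sketch (not part of the upstream reference); note that, unlike the encoders above, the forward pass returns a 1-tuple of projected embeddings. The Modalities import path is an assumption.

import torch
from transformers import AutoTokenizer
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPTextEncoderWithProjection

encoder = HFCLIPTextEncoderWithProjection("openai/clip-vit-base-patch16")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")

tokens = tokenizer(["a diagram", "a photo"], padding=True, return_tensors="pt")
inputs = {
    Modalities.TEXT.name: tokens["input_ids"],
    "attention_mask": tokens["attention_mask"],
}
(text_embeds,) = encoder(inputs)  # 1-tuple of projected text embeddings
print(text_embeds.shape)  # (batch, projection_dim), e.g. (2, 512) for this checkpoint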
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
tuple[Tensor]

The text embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The text embeddings. Will be a tuple with a single element.
    """
    input_ids = inputs[Modalities.TEXT.name]
    attention_mask: Optional[torch.Tensor] = inputs.get(
        "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
    )
    position_ids = inputs.get("position_ids")

    if self.use_all_token_embeddings:
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        # TODO: add more options for pooling before projection
        text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
    else:
        text_embeds = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        ).text_embeds

    return (text_embeds,)
HFCLIPVisionEncoderWithProjection

Bases: Module

Wrapper around the CLIPVisionModelWithProjection class from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the image. If False, the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias float

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft (https://huggingface.co/docs/peft/index) library to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs dict[str, Any]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPVisionModelWithProjection`` class from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the image. If ``False``, the first
        token embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : float, optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : dict[str, Any], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPVisionConfig,
        )

        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The image embeddings. Will be a tuple with a single element.
        """
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.vision_model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.vision_model.encoder(
            inputs_embeds=hidden_states, return_dict=True
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        if self.use_all_token_embeddings:
            pooled_output = last_hidden_state
        else:
            pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.model.vision_model.post_layernorm(pooled_output)

        return (self.model.visual_projection(pooled_output),)
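Example:

A sketch (not part of the upstream reference) combining the text and vision projection encoders into a CLIP-style similarity matrix. It assumes both encoders are loaded from the same checkpoint so they share a projection space, and the Modalities import path is an assumption.

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import (
    HFCLIPTextEncoderWithProjection,
    HFCLIPVisionEncoderWithProjection,
)

text_encoder = HFCLIPTextEncoderWithProjection("openai/clip-vit-base-patch16")
image_encoder = HFCLIPVisionEncoderWithProjection("openai/clip-vit-base-patch16")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")

tokens = tokenizer(
    ["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt"
)
with torch.no_grad():
    (text_embeds,) = text_encoder({
        Modalities.TEXT.name: tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
    })
    (image_embeds,) = image_encoder({
        Modalities.RGB.name: torch.randn(2, 3, 224, 224)
    })

# Cosine similarity between every image and every caption.
similarity = F.normalize(image_embeds, dim=-1) @ F.normalize(text_embeds, dim=-1).T
print(similarity.shape)  # (num_images, num_captions)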
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
tuple[Tensor]

The image embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The image embeddings. Will be a tuple with a single element.
    """
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.vision_model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.vision_model.encoder(
        inputs_embeds=hidden_states, return_dict=True
    )

    last_hidden_state = encoder_outputs.last_hidden_state
    if self.use_all_token_embeddings:
        pooled_output = last_hidden_state
    else:
        pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.model.vision_model.post_layernorm(pooled_output)

    return (self.model.visual_projection(pooled_output),)
text

Huggingface text encoder model.

HFTextEncoder

Bases: Module

Wrapper around huggingface models in the AutoModelForTextEncoding class.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft (https://huggingface.co/docs/peft/index) library to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Raises:

Type Description
ValueError

If the model is a decoder model or if freezing individual layers is not supported for the model type.

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/text.py
@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
class HFTextEncoder(nn.Module):
    """Wrapper around huggingface models in the ``AutoModelForTextEncoding`` class.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Raises
    ------
    ValueError
        If the model is a decoder model or if freezing individual layers is not
        supported for the model type.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.


    """

    def __init__(  # noqa: PLR0912
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ):
        super().__init__()
        if model_config_kwargs is None:
            model_config_kwargs = {}
        model_config_kwargs["output_hidden_states"] = True
        model_config_kwargs["add_pooling_layer"] = False
        model = hf_utils.load_huggingface_model(
            AutoModelForTextEncoding,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        if hasattr(model.config, "is_decoder") and model.config.is_decoder:
            raise ValueError("Model is a decoder. Only encoder models are supported.")

        if not pretrained and freeze_layers:
            rank_zero_warn(
                "Freezing layers when loading a model with random weights may lead to "
                "unexpected behavior. Consider setting `freeze_layers=False` if "
                "`pretrained=False`.",
            )

        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "LayerNorm" in name else False
                )

        if isinstance(
            freeze_layers, (float, int, list)
        ) and model.config.model_type in ["flaubert", "xlm"]:
            # flaubert and xlm models have a different architecture that does not
            # support freezing individual layers in the same way as other models
            raise ValueError(
                f"Freezing individual layers is not supported for {model.config.model_type} "
                "models. Please use `freeze_layers=False` or `freeze_layers=True`."
            )

        # get list of layers
        embeddings = model.embeddings
        encoder = getattr(model, "encoder", None) or getattr(
            model, "transformer", model
        )
        encoder_layers = (
            getattr(encoder, "layer", None)
            or getattr(encoder, "layers", None)
            or getattr(encoder, "block", None)
        )
        if encoder_layers is None and hasattr(encoder, "albert_layer_groups"):
            encoder_layers = [
                layer
                for group in encoder.albert_layer_groups
                for layer in group.albert_layers
            ]
        modules = [embeddings]
        if encoder_layers is not None and isinstance(encoder_layers, list):
            modules.extend(encoder_layers)

        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "LayerNorm" in name else False
                        )

        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get(
                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
            ),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
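Example:

A minimal sketch (not part of the upstream reference) wrapping a BERT-style checkpoint; the model name is only a placeholder and the Modalities import path is an assumption.

import torch
from transformers import AutoTokenizer
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.text import HFTextEncoder

# freeze_layers=0.5 freezes roughly the first half of [embeddings, *encoder layers].
encoder = HFTextEncoder("bert-base-uncased", freeze_layers=0.5)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

tokens = tokenizer(["multimodal learning"], return_tensors="pt")
inputs = {
    Modalities.TEXT.name: tokens["input_ids"],
    "attention_mask": tokens["attention_mask"],
}
with torch.no_grad():
    output = encoder(inputs)
print(output.last_hidden_state.shape)  # (1, seq_len, hidden_size)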
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/text.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        ),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
vision

Vision encoder implementations.

TimmViT

Bases: Module

Vision Transformer model from timm.

Parameters:

Name Type Description Default
model_name str

The name of the model to use.

required
modality str

The modality of the input data. This allows the model to be used with different image modalities, such as RGB or Depth.

"RGB"
projection_dim int

The dimension of the projection head.

768
pretrained bool

Whether to use the pretrained weights.

True
freeze_layers Union[int, float, list[int], bool]

Whether to freeze the layers.

False
freeze_layer_norm bool

Whether to freeze the layer norm.

True
peft_config Optional[PeftConfig]

The configuration from the peft (https://huggingface.co/docs/peft/index) library to use to wrap the model for parameter-efficient finetuning.

None
model_kwargs Optional[dict[str, Any]]

Additional keyword arguments for the model.

None
Source code in mmlearn/modules/encoders/vision.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name="vit_base_patch16_224",
    hydra_convert="object",
)
class TimmViT(nn.Module):
    """Vision Transformer model from timm.

    Parameters
    ----------
    model_name : str
        The name of the model to use.
    modality : str, default="RGB"
        The modality of the input data. This allows the model to be used with
        different image modalities, such as RGB or Depth.
    projection_dim : int, default=768
        The dimension of the projection head.
    pretrained : bool, default=True
        Whether to use the pretrained weights.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze the layers.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer norm.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_kwargs : Optional[dict[str, Any]], default=None
        Additional keyword arguments for the model.
    """

    def __init__(
        self,
        model_name: str,
        modality: str = "RGB",
        projection_dim: int = 768,
        pretrained: bool = True,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.modality = Modalities.get_modality(modality)
        if model_kwargs is None:
            model_kwargs = {}

        self.model: TimmVisionTransformer = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=projection_dim,
            **model_kwargs,
        )
        assert isinstance(self.model, TimmVisionTransformer), (
            f"Model {model_name} is not a Vision Transformer. "
            "Please provide a model name that corresponds to a Vision Transformer."
        )

        self._freeze_layers(freeze_layers, freeze_layer_norm)

        if peft_config is not None:
            self.model = hf_utils._wrap_peft_model(self.model, peft_config)

    def _freeze_layers(
        self, freeze_layers: Union[int, float, list[int], bool], freeze_layer_norm: bool
    ) -> None:
        """Freeze the layers of the model.

        Parameters
        ----------
        freeze_layers : Union[int, float, list[int], bool]
            Whether to freeze the layers.
        freeze_layer_norm : bool
            Whether to freeze the layer norm.
        """
        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in self.model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "norm" in name else False
                )

        modules = [self.model.patch_embed, *self.model.blocks, self.model.norm]
        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "norm" in name else False
                        )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model.
        """
        x = inputs[self.modality.name]
        last_hidden_state, hidden_states = self.model.forward_intermediates(
            x, output_fmt="NLC"
        )
        last_hidden_state = self.model.forward_head(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states
        )

    def get_intermediate_layers(
        self, inputs: dict[str, Any], n: int = 1
    ) -> list[torch.Tensor]:
        """Get the output of the intermediate layers.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the ``Modalities.RGB``
            key.
        n : int, default=1
            The number of intermediate layers to return.

        Returns
        -------
        list[torch.Tensor]
            The outputs of the last n intermediate layers.
        """
        return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore

    def get_patch_info(self) -> tuple[int, int]:
        """Get patch size and number of patches.

        Returns
        -------
        tuple[int, int]
            Patch size and number of patches.
        """
        patch_size = self.model.patch_embed.patch_size[0]
        num_patches = self.model.patch_embed.num_patches
        return patch_size, num_patches
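Example:

A sketch (not part of the upstream reference) of the forms accepted by the freeze_layers argument, as handled by _freeze_layers above. The freezable module list is [patch_embed, *blocks, norm], so a ViT-B/16 with 12 blocks has 14 entries; pretrained=False is used only to keep the illustration light.

from mmlearn.modules.encoders.vision import TimmViT

# True: freeze everything; layer norms stay trainable because freeze_layer_norm=False.
enc_all = TimmViT("vit_base_patch16_224", pretrained=False,
                  freeze_layers=True, freeze_layer_norm=False)

# int: freeze the first 4 modules, i.e. patch_embed plus blocks 0-2.
enc_int = TimmViT("vit_base_patch16_224", pretrained=False, freeze_layers=4)

# float: converted to int(0.5 * 14) = 7 modules, i.e. patch_embed plus blocks 0-5.
enc_half = TimmViT("vit_base_patch16_224", pretrained=False, freeze_layers=0.5)

# list: explicit indices into [patch_embed, *blocks, norm]; 13 is the final norm.
enc_list = TimmViT("vit_base_patch16_224", pretrained=False, freeze_layers=[0, 1, 13])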
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model.

Source code in mmlearn/modules/encoders/vision.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model.
    """
    x = inputs[self.modality.name]
    last_hidden_state, hidden_states = self.model.forward_intermediates(
        x, output_fmt="NLC"
    )
    last_hidden_state = self.model.forward_head(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state, hidden_states=hidden_states
    )
get_intermediate_layers
get_intermediate_layers(inputs, n=1)

Get the output of the intermediate layers.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required
n int

The number of intermediate layers to return.

1

Returns:

Type Description
list[Tensor]

The outputs of the last n intermediate layers.

Source code in mmlearn/modules/encoders/vision.py
def get_intermediate_layers(
    self, inputs: dict[str, Any], n: int = 1
) -> list[torch.Tensor]:
    """Get the output of the intermediate layers.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the ``Modalities.RGB``
        key.
    n : int, default=1
        The number of intermediate layers to return.

    Returns
    -------
    list[torch.Tensor]
        The outputs of the last n intermediate layers.
    """
    return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore
get_patch_info
get_patch_info()

Get patch size and number of patches.

Returns:

Type Description
tuple[int, int]

Patch size and number of patches.

Source code in mmlearn/modules/encoders/vision.py
def get_patch_info(self) -> tuple[int, int]:
    """Get patch size and number of patches.

    Returns
    -------
    tuple[int, int]
        Patch size and number of patches.
    """
    patch_size = self.model.patch_embed.patch_size[0]
    num_patches = self.model.patch_embed.num_patches
    return patch_size, num_patches
VisionTransformer

Bases: Module

Vision Transformer.

This module implements a Vision Transformer that processes images using a series of transformer blocks and patch embeddings.

Parameters:

Name Type Description Default
modality str

The modality of the input data. This allows the model to be used with different image modalities, such as RGB or Depth.

"RGB"
img_size List[int]

List of input image sizes.

None
patch_size int

Size of each patch.

16
in_chans int

Number of input channels.

3
embed_dim int

Embedding dimension.

768
depth int

Number of transformer blocks.

12
num_heads int

Number of attention heads.

12
mlp_ratio float

Ratio of hidden dimension in the MLP.

4.0
qkv_bias bool

If True, add a learnable bias to the query, key, and value projections.

True
qk_scale Optional[float]

Override the default qk scale factor.

None
drop_rate float

Dropout rate for the transformer blocks.

0.0
attn_drop_rate float

Dropout rate for the attention mechanism.

0.0
drop_path_rate float

Dropout rate for stochastic depth.

0.0
norm_layer Callable[..., Module]

Normalization layer to use.

torch.nn.LayerNorm
init_std float

Standard deviation for weight initialization.

0.02
Source code in mmlearn/modules/encoders/vision.py
class VisionTransformer(nn.Module):
    """Vision Transformer.

    This module implements a Vision Transformer that processes images using a
    series of transformer blocks and patch embeddings.

    Parameters
    ----------
    modality : str, optional, default="RGB"
        The modality of the input data. This allows the model to be used with
        different image modalities, e.g. RGB or depth.
    img_size : List[int], optional, default=None
        List of input image sizes.
    patch_size : int, optional, default=16
        Size of each patch.
    in_chans : int, optional, default=3
        Number of input channels.
    embed_dim : int, optional, default=768
        Embedding dimension.
    depth : int, optional, default=12
        Number of transformer blocks.
    num_heads : int, optional, default=12
        Number of attention heads.
    mlp_ratio : float, optional, default=4.0
        Ratio of hidden dimension in the MLP.
    qkv_bias : bool, optional, default=True
        If True, add a learnable bias to the query, key, and value projections.
    qk_scale : Optional[float], optional, default=None
        Override the default qk scale factor.
    global_pool : Literal["", "avg", "avgmax", "max", "token"], optional, default=""
        Type of global pooling to apply to the output tokens.
    drop_rate : float, optional, default=0.0
        Dropout rate for the transformer blocks.
    attn_drop_rate : float, optional, default=0.0
        Dropout rate for the attention mechanism.
    drop_path_rate : float, optional, default=0.0
        Dropout rate for stochastic depth.
    norm_layer : Callable[..., torch.nn.Module], optional, default=torch.nn.LayerNorm
        Normalization layer to use.
    init_std : float, optional, default=0.02
        Standard deviation for weight initialization.
    """

    def __init__(
        self,
        modality: str = "RGB",
        img_size: Optional[list[int]] = None,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        global_pool: Literal["", "avg", "avgmax", "max", "token"] = "",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        init_std: float = 0.02,
    ) -> None:
        super().__init__()
        assert global_pool in ("", "avg", "avgmax", "max", "token")

        self.modality = Modalities.get_modality(modality)
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads
        img_size = [224, 224] if img_size is None else img_size

        # Patch Embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size[0],
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        # Positional Embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim), requires_grad=False
        )
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.patch_embed.num_patches**0.5),
            cls_token=False,
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Transformer Blocks
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        self.global_pool = global_pool

        # Weight Initialization
        self.init_std = init_std
        self.apply(self._init_weights)

    def fix_init_weight(self) -> None:
        """Fix initialization of weights by rescaling them according to layer depth."""

        def rescale(param: torch.Tensor, layer_id: int) -> None:
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp[-1].weight.data, layer_id + 1)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize weights for the layers."""
        if isinstance(m, nn.Linear):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self, inputs: dict[str, Any], return_hidden_states: bool = False
    ) -> tuple[torch.Tensor, Optional[list[torch.Tensor]]]:
        """Forward pass through the Vision Transformer."""
        masks = inputs.get(self.modality.mask)
        if masks is not None and not isinstance(masks, list):
            masks = [masks]

        x = inputs[self.modality.name]
        # -- Patchify x
        x = self.patch_embed(x)

        # -- Add positional embedding to x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed

        # -- Mask x
        if masks is not None:
            x = apply_masks(x, masks)

        # -- Initialize a list to store hidden states
        hidden_states: Optional[list[torch.Tensor]] = (
            [] if return_hidden_states else None
        )

        # -- Forward propagation through blocks
        for _i, blk in enumerate(self.blocks):
            x = blk(x)
            if return_hidden_states and hidden_states is not None:
                hidden_states.append(x)

        # -- Apply normalization if present
        if self.norm is not None:
            x = self.norm(x)

        # -- Apply global pooling
        x = global_pool_nlc(x, pool_type=self.global_pool)

        # -- Return both final output and hidden states if requested
        if return_hidden_states:
            return x, hidden_states
        return (x, None)

    def interpolate_pos_encoding(
        self, x: torch.Tensor, pos_embed: torch.Tensor
    ) -> torch.Tensor:
        """Interpolate positional encoding to match the size of the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        pos_embed : torch.Tensor
            Positional embedding tensor.

        Returns
        -------
        torch.Tensor
            Interpolated positional encoding.
        """
        npatch = x.shape[1] - 1
        n = pos_embed.shape[1] - 1
        if npatch == n:
            return pos_embed
        class_emb = pos_embed[:, 0]
        pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(n)), int(math.sqrt(n)), dim).permute(
                0, 3, 1, 2
            ),
            scale_factor=math.sqrt(npatch / n),
            mode="bicubic",
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
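A minimal usage sketch (illustrative, not part of the library documentation). It assumes the default empty global_pool, which keeps the full token sequence, and uses the model's own modality attribute to build the input key:

import torch

from mmlearn.modules.encoders.vision import VisionTransformer

model = VisionTransformer(img_size=[224, 224], patch_size=16, embed_dim=768, depth=12, num_heads=12)
images = torch.randn(2, 3, 224, 224)                    # batch of 2 RGB images
inputs = {model.modality.name: images}                  # key expected by forward()
features, hidden_states = model(inputs, return_hidden_states=True)
# features: (2, 196, 768) -- one embedding per 16x16 patch (no pooling by default)
# hidden_states: list of 12 tensors, one per transformer block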
fix_init_weight
fix_init_weight()

Fix initialization of weights by rescaling them according to layer depth.

Source code in mmlearn/modules/encoders/vision.py
def fix_init_weight(self) -> None:
    """Fix initialization of weights by rescaling them according to layer depth."""

    def rescale(param: torch.Tensor, layer_id: int) -> None:
        param.div_(math.sqrt(2.0 * layer_id))

    for layer_id, layer in enumerate(self.blocks):
        rescale(layer.attn.proj.weight.data, layer_id + 1)
        rescale(layer.mlp[-1].weight.data, layer_id + 1)
forward
forward(inputs, return_hidden_states=False)

Forward pass through the Vision Transformer.

Source code in mmlearn/modules/encoders/vision.py
def forward(
    self, inputs: dict[str, Any], return_hidden_states: bool = False
) -> tuple[torch.Tensor, Optional[list[torch.Tensor]]]:
    """Forward pass through the Vision Transformer."""
    masks = inputs.get(self.modality.mask)
    if masks is not None and not isinstance(masks, list):
        masks = [masks]

    x = inputs[self.modality.name]
    # -- Patchify x
    x = self.patch_embed(x)

    # -- Add positional embedding to x
    pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
    x = x + pos_embed

    # -- Mask x
    if masks is not None:
        x = apply_masks(x, masks)

    # -- Initialize a list to store hidden states
    hidden_states: Optional[list[torch.Tensor]] = (
        [] if return_hidden_states else None
    )

    # -- Forward propagation through blocks
    for _i, blk in enumerate(self.blocks):
        x = blk(x)
        if return_hidden_states and hidden_states is not None:
            hidden_states.append(x)

    # -- Apply normalization if present
    if self.norm is not None:
        x = self.norm(x)

    # -- Apply global pooling
    x = global_pool_nlc(x, pool_type=self.global_pool)

    # -- Return both final output and hidden states if requested
    if return_hidden_states:
        return x, hidden_states
    return (x, None)
interpolate_pos_encoding
interpolate_pos_encoding(x, pos_embed)

Interpolate positional encoding to match the size of the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
pos_embed Tensor

Positional embedding tensor.

required

Returns:

Type Description
Tensor

Interpolated positional encoding.

Source code in mmlearn/modules/encoders/vision.py
def interpolate_pos_encoding(
    self, x: torch.Tensor, pos_embed: torch.Tensor
) -> torch.Tensor:
    """Interpolate positional encoding to match the size of the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    pos_embed : torch.Tensor
        Positional embedding tensor.

    Returns
    -------
    torch.Tensor
        Interpolated positional encoding.
    """
    npatch = x.shape[1] - 1
    n = pos_embed.shape[1] - 1
    if npatch == n:
        return pos_embed
    class_emb = pos_embed[:, 0]
    pos_embed = pos_embed[:, 1:]
    dim = x.shape[-1]
    pos_embed = nn.functional.interpolate(
        pos_embed.reshape(1, int(math.sqrt(n)), int(math.sqrt(n)), dim).permute(
            0, 3, 1, 2
        ),
        scale_factor=math.sqrt(npatch / n),
        mode="bicubic",
    )
    pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
VisionTransformerPredictor

Bases: Module

Vision Transformer Predictor.

This module implements a Vision Transformer that predicts masked tokens using a series of transformer blocks.

Parameters:

Name Type Description Default
num_patches int

The number of patches in the input image.

196
embed_dim int

The embedding dimension.

768
predictor_embed_dim int

The embedding dimension for the predictor.

384
depth int

The number of transformer blocks.

6
num_heads int

The number of attention heads.

12
mlp_ratio float

Ratio of the hidden dimension in the MLP.

4.0
qkv_bias bool

If True, add a learnable bias to the query, key, and value projections.

True
qk_scale Optional[float]

Override the default qk scale factor.

None
drop_rate float

Dropout rate for the transformer blocks.

0.0
attn_drop_rate float

Dropout rate for the attention mechanism.

0.0
drop_path_rate float

Dropout rate for stochastic depth.

0.0
norm_layer Callable[..., Module]

Normalization layer to use.

torch.nn.LayerNorm
init_std float

Standard deviation for weight initialization.

0.02
Source code in mmlearn/modules/encoders/vision.py
class VisionTransformerPredictor(nn.Module):
    """Vision Transformer Predictor.

    This module implements a Vision Transformer that predicts masked tokens
    using a series of transformer blocks.

    Parameters
    ----------
    num_patches : int, optional, default=196
        The number of patches in the input image.
    embed_dim : int, optional, default=768
        The embedding dimension.
    predictor_embed_dim : int, optional, default=384
        The embedding dimension for the predictor.
    depth : int, optional, default=6
        The number of transformer blocks.
    num_heads : int, optional, default=12
        The number of attention heads.
    mlp_ratio : float, optional, default=4.0
        Ratio of the hidden dimension in the MLP.
    qkv_bias : bool, optional, default=True
        If True, add a learnable bias to the query, key, and value projections.
    qk_scale : Optional[float], optional, default=None
        Override the default qk scale factor.
    drop_rate : float, optional, default=0.0
        Dropout rate for the transformer blocks.
    attn_drop_rate : float, optional, default=0.0
        Dropout rate for the attention mechanism.
    drop_path_rate : float, optional, default=0.0
        Dropout rate for stochastic depth.
    norm_layer : Callable[..., torch.nn.Module], optional, default=torch.nn.LayerNorm
        Normalization layer to use.
    init_std : float, optional, default=0.02
        Standard deviation for weight initialization.
    """

    def __init__(
        self,
        num_patches: int = 196,
        embed_dim: int = 768,
        predictor_embed_dim: int = 384,
        depth: int = 6,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        init_std: float = 0.02,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.predictor_embed = nn.Linear(self.embed_dim, predictor_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        # Positional Embedding
        self.predictor_pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, predictor_embed_dim), requires_grad=False
        )
        predictor_pos_embed = get_2d_sincos_pos_embed(
            self.predictor_pos_embed.shape[-1],
            int(self.num_patches**0.5),
            cls_token=False,
        )
        self.predictor_pos_embed.data.copy_(
            torch.from_numpy(predictor_pos_embed).float().unsqueeze(0)
        )

        # Transformer Blocks
        self.predictor_blocks = nn.ModuleList(
            [
                Block(
                    dim=predictor_embed_dim,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        self.predictor_norm = norm_layer(predictor_embed_dim)
        self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)

        # Weight Initialization
        self.init_std = init_std
        _trunc_normal(self.mask_token, std=self.init_std)
        self.apply(self._init_weights)

    def fix_init_weight(self) -> None:
        """Fix initialization of weights by rescaling them according to layer depth."""

        def rescale(param: torch.Tensor, layer_id: int) -> None:
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.predictor_blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize weights for the layers."""
        if isinstance(m, nn.Linear):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
        masks_x: Union[torch.Tensor, list[torch.Tensor]],
        masks: Union[torch.Tensor, list[torch.Tensor]],
    ) -> torch.Tensor:
        """Forward pass through the Vision Transformer Predictor."""
        assert (masks is not None) and (masks_x is not None), (
            "Cannot run predictor without mask indices"
        )

        if not isinstance(masks_x, list):
            masks_x = [masks_x]

        if not isinstance(masks, list):
            masks = [masks]

        # -- Batch Size
        b = len(x) // len(masks_x)

        # -- Map from encoder-dim to predictor-dim
        x = self.predictor_embed(x)

        # -- Add positional embedding to x tokens
        x_pos_embed = self.predictor_pos_embed.repeat(b, 1, 1)
        x += apply_masks(x_pos_embed, masks_x)

        _, n_ctxt, d = x.shape

        # -- Concatenate mask tokens to x
        pos_embs = self.predictor_pos_embed.repeat(b, 1, 1)
        pos_embs = apply_masks(pos_embs, masks)
        pos_embs = repeat_interleave_batch(pos_embs, b, repeat=len(masks_x))
        pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
        pred_tokens += pos_embs
        x = x.repeat(len(masks), 1, 1)
        x = torch.cat([x, pred_tokens], dim=1)

        # -- Forward propagation
        for blk in self.predictor_blocks:
            x = blk(x)
        x = self.predictor_norm(x)

        # -- Return predictions for mask tokens
        x = x[:, n_ctxt:]
        return self.predictor_proj(x)
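An illustrative sketch (shapes and mask indices are made up) of predicting target-patch embeddings from context tokens, assuming a single context mask and a single target mask:

import torch

from mmlearn.modules.encoders.vision import VisionTransformerPredictor

predictor = VisionTransformerPredictor(num_patches=196, embed_dim=768, predictor_embed_dim=384)

batch_size, n_ctx, n_tgt = 2, 98, 49
ctx_tokens = torch.randn(batch_size, n_ctx, 768)        # encoder output for the context patches
ctx_idx = torch.randint(0, 196, (batch_size, n_ctx))    # indices of the context patches
tgt_idx = torch.randint(0, 196, (batch_size, n_tgt))    # indices of the patches to predict

preds = predictor(ctx_tokens, masks_x=[ctx_idx], masks=[tgt_idx])
# preds: (2, 49, 768) -- one predicted embedding per target patch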
fix_init_weight
fix_init_weight()

Fix initialization of weights by rescaling them according to layer depth.

Source code in mmlearn/modules/encoders/vision.py
def fix_init_weight(self) -> None:
    """Fix initialization of weights by rescaling them according to layer depth."""

    def rescale(param: torch.Tensor, layer_id: int) -> None:
        param.div_(math.sqrt(2.0 * layer_id))

    for layer_id, layer in enumerate(self.predictor_blocks):
        rescale(layer.attn.proj.weight.data, layer_id + 1)
        rescale(layer.mlp.fc2.weight.data, layer_id + 1)
forward
forward(x, masks_x, masks)

Forward pass through the Vision Transformer Predictor.

Source code in mmlearn/modules/encoders/vision.py
def forward(
    self,
    x: torch.Tensor,
    masks_x: Union[torch.Tensor, list[torch.Tensor]],
    masks: Union[torch.Tensor, list[torch.Tensor]],
) -> torch.Tensor:
    """Forward pass through the Vision Transformer Predictor."""
    assert (masks is not None) and (masks_x is not None), (
        "Cannot run predictor without mask indices"
    )

    if not isinstance(masks_x, list):
        masks_x = [masks_x]

    if not isinstance(masks, list):
        masks = [masks]

    # -- Batch Size
    b = len(x) // len(masks_x)

    # -- Map from encoder-dim to predictor-dim
    x = self.predictor_embed(x)

    # -- Add positional embedding to x tokens
    x_pos_embed = self.predictor_pos_embed.repeat(b, 1, 1)
    x += apply_masks(x_pos_embed, masks_x)

    _, n_ctxt, d = x.shape

    # -- Concatenate mask tokens to x
    pos_embs = self.predictor_pos_embed.repeat(b, 1, 1)
    pos_embs = apply_masks(pos_embs, masks)
    pos_embs = repeat_interleave_batch(pos_embs, b, repeat=len(masks_x))
    pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
    pred_tokens += pos_embs
    x = x.repeat(len(masks), 1, 1)
    x = torch.cat([x, pred_tokens], dim=1)

    # -- Forward propagation
    for blk in self.predictor_blocks:
        x = blk(x)
    x = self.predictor_norm(x)

    # -- Return predictions for mask tokens
    x = x[:, n_ctxt:]
    return self.predictor_proj(x)
vit_predictor
vit_predictor(kwargs=None)

Create a VisionTransformerPredictor model.

Parameters:

Name Type Description Default
kwargs dict[str, Any]

Keyword arguments for the predictor.

None

Returns:

Type Description
VisionTransformerPredictor

An instance of VisionTransformerPredictor.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformerPredictor,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_predictor(
    kwargs: Optional[dict[str, Any]] = None,
) -> VisionTransformerPredictor:
    """Create a VisionTransformerPredictor model.

    Parameters
    ----------
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the predictor.

    Returns
    -------
    VisionTransformerPredictor
        An instance of VisionTransformerPredictor.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformerPredictor(
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
    )
vit_tiny
vit_tiny(patch_size=16, kwargs=None)

Create a VisionTransformer model with tiny configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_tiny(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with tiny configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
vit_small
vit_small(patch_size=16, kwargs=None)

Create a VisionTransformer model with small configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_small(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with small configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
vit_base
vit_base(patch_size=16, kwargs=None)

Create a VisionTransformer model with base configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_base(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with base configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
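For example, a short sketch (illustrative) of encoding a batch of images with the base variant:

import torch

from mmlearn.modules.encoders.vision import vit_base

encoder = vit_base(patch_size=16)
inputs = {encoder.modality.name: torch.randn(2, 3, 224, 224)}
features, _ = encoder(inputs)
# features: (2, 196, 768) with the default (empty) global pooling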
vit_large
vit_large(patch_size=16, kwargs=None)

Create a VisionTransformer model with large configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_large(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with large configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
vit_huge
vit_huge(patch_size=16, kwargs=None)

Create a VisionTransformer model with huge configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_huge(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with huge configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
vit_giant
vit_giant(patch_size=16, kwargs=None)

Create a VisionTransformer model with giant configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_giant(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with giant configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48 / 11,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

layers

Custom, reusable layers for models and tasks.

LearnableLogitScaling

Bases: Module

Logit scaling layer.

Parameters:

Name Type Description Default
init_logit_scale float

Initial value of the logit scale.

1/0.07
learnable bool

If True, the logit scale is learnable. Otherwise, it is fixed.

True
max_logit_scale float

Maximum value of the logit scale.

100
Source code in mmlearn/modules/layers/logit_scaling.py
@store(group="modules/layers", provider="mmlearn")
class LearnableLogitScaling(torch.nn.Module):
    """Logit scaling layer.

    Parameters
    ----------
    init_logit_scale : float, optional, default=1/0.07
        Initial value of the logit scale.
    learnable : bool, optional, default=True
        If True, the logit scale is learnable. Otherwise, it is fixed.
    max_logit_scale : float, optional, default=100
        Maximum value of the logit scale.
    """

    def __init__(
        self,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.init_logit_scale = init_logit_scale
        self.learnable = learnable
        log_logit_scale = torch.ones([]) * np.log(self.init_logit_scale)
        if learnable:
            self.log_logit_scale = torch.nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the logit scaling to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self) -> str:
        """Return the string representation of the layer."""
        return (
            f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
            f" max_logit_scale={self.max_logit_scale}"
        )
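A short usage sketch (the similarity matrix below is random and purely illustrative):

import torch

from mmlearn.modules.layers.logit_scaling import LearnableLogitScaling

logit_scale = LearnableLogitScaling(init_logit_scale=1 / 0.07, learnable=True, max_logit_scale=100)
logits = torch.randn(8, 8)             # e.g. pairwise image-text similarities
scaled_logits = logit_scale(logits)    # multiplied by clip(exp(log_logit_scale), max=100)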
forward
forward(x)

Apply the logit scaling to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
Source code in mmlearn/modules/layers/logit_scaling.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the logit scaling to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
extra_repr
extra_repr()

Return the string representation of the layer.

Source code in mmlearn/modules/layers/logit_scaling.py
def extra_repr(self) -> str:
    """Return the string representation of the layer."""
    return (
        f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
        f" max_logit_scale={self.max_logit_scale}"
    )
MLP

Bases: Sequential

Multi-layer perceptron (MLP).

This module will create a block of Linear -> Normalization -> Activation -> Dropout layers.

Parameters:

Name Type Description Default
in_dim int

The input dimension.

required
out_dim Optional[int]

The output dimension. If not specified, it is set to in_dim.

None
hidden_dims Optional[list]

The dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with hidden_dims_multiplier.

None
hidden_dims_multiplier Optional[list]

The multipliers used to derive the dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with hidden_dims.

None
apply_multiplier_to_in_dim bool

Whether to apply the hidden_dims_multiplier to in_dim to get the dimensions of the hidden layers. If False, the multipliers will be applied to the dimensions of the previous hidden layer, starting from in_dim. This parameter is only relevant when hidden_dims_multiplier is specified.

False
norm_layer Optional[Callable[..., Module]]

The normalization layer to use. If not specified, no normalization is used. Partial functions can be used to specify the normalization layer with specific parameters.

None
activation_layer Optional[Callable[..., Module]]

The activation layer to use. If not specified, ReLU is used. Partial functions can be used to specify the activation layer with specific parameters.

torch.nn.ReLU
bias Union[bool, list[bool]]

Whether to use bias in the linear layers.

True
dropout Union[float, list[float]]

The dropout probability to use.

0.0

Raises:

Type Description
ValueError

If both hidden_dims and hidden_dims_multiplier are specified, or if the lengths of bias and hidden_dims do not match, or if the lengths of dropout and hidden_dims do not match.

Source code in mmlearn/modules/layers/mlp.py
@store(group="modules/layers", provider="mmlearn")
class MLP(torch.nn.Sequential):
    """Multi-layer perceptron (MLP).

    This module will create a block of ``Linear -> Normalization -> Activation -> Dropout``
    layers.

    Parameters
    ----------
    in_dim : int
        The input dimension.
    out_dim : Optional[int], optional, default=None
        The output dimension. If not specified, it is set to :attr:`in_dim`.
    hidden_dims : Optional[list], optional, default=None
        The dimensions of the hidden layers. The length of the list determines the
        number of hidden layers. This parameter is mutually exclusive with
        :attr:`hidden_dims_multiplier`.
    hidden_dims_multiplier : Optional[list], optional, default=None
        The multipliers to apply to the input dimension to get the dimensions of
        the hidden layers. The length of the list determines the number of hidden
        layers. The multipliers will be used to get the dimensions of the hidden
        layers. This parameter is mutually exclusive with `hidden_dims`.
    apply_multiplier_to_in_dim : bool, optional, default=False
        Whether to apply the :attr:`hidden_dims_multiplier` to :attr:`in_dim` to get the
        dimensions of the hidden layers. If ``False``, the multipliers will be applied
        to the dimensions of the previous hidden layer, starting from :attr:`in_dim`.
        This parameter is only relevant when :attr:`hidden_dims_multiplier` is
        specified.
    norm_layer : Optional[Callable[..., torch.nn.Module]], optional, default=None
        The normalization layer to use. If not specified, no normalization is used.
        Partial functions can be used to specify the normalization layer with specific
        parameters.
    activation_layer : Optional[Callable[..., torch.nn.Module]], optional, default=torch.nn.ReLU
        The activation layer to use. If not specified, ReLU is used. Partial functions
        can be used to specify the activation layer with specific parameters.
    bias : Union[bool, list[bool]], optional, default=True
        Whether to use bias in the linear layers.
    dropout : Union[float, list[float]], optional, default=0.0
        The dropout probability to use.

    Raises
    ------
    ValueError
        If both :attr:`hidden_dims` and :attr:`hidden_dims_multiplier` are specified
        or if the lengths of :attr:`bias` and :attr:`hidden_dims` do not match or if
        the lengths of :attr:`dropout` and :attr:`hidden_dims` do not match.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        hidden_dims: Optional[list[int]] = None,
        hidden_dims_multiplier: Optional[list[float]] = None,
        apply_multiplier_to_in_dim: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        bias: Union[bool, list[bool]] = True,
        dropout: Union[float, list[float]] = 0.0,
    ) -> None:
        if hidden_dims is None and hidden_dims_multiplier is None:
            hidden_dims = []
        if hidden_dims is not None and hidden_dims_multiplier is not None:
            raise ValueError(
                "Only one of `hidden_dims` or `hidden_dims_multiplier` must be specified."
            )

        if hidden_dims is None and hidden_dims_multiplier is not None:
            if apply_multiplier_to_in_dim:
                hidden_dims = [
                    int(in_dim * multiplier) for multiplier in hidden_dims_multiplier
                ]
            else:
                hidden_dims = [int(in_dim * hidden_dims_multiplier[0])]
                for multiplier in hidden_dims_multiplier[1:]:
                    hidden_dims.append(int(hidden_dims[-1] * multiplier))

        if isinstance(bias, bool):
            bias_list: list[bool] = [bias] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            bias_list = bias
        if len(bias_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `bias` to be a boolean or a list of booleans with length "
                "equal to the number of linear layers in the MLP."
            )

        if isinstance(dropout, float):
            dropout_list: list[float] = [dropout] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            dropout_list = dropout
        if len(dropout_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `dropout` to be a float or a list of floats with length "
                "equal to the number of linear layers in the MLP."
            )

        # construct list of dimensions for the layers
        dims = [in_dim] + hidden_dims  # type: ignore[operator]
        layers = []
        for layer_idx, (in_features, hidden_features) in enumerate(
            zip(dims[:-1], dims[1:], strict=False)
        ):
            layers.append(
                torch.nn.Linear(in_features, hidden_features, bias=bias_list[layer_idx])
            )
            if norm_layer is not None:
                layers.append(norm_layer(hidden_features))
            if activation_layer is not None:
                layers.append(activation_layer())
            layers.append(torch.nn.Dropout(dropout_list[layer_idx]))

        out_dim = out_dim or in_dim

        layers.append(torch.nn.Linear(dims[-1], out_dim, bias=bias_list[-1]))
        layers.append(torch.nn.Dropout(dropout_list[-1]))

        super().__init__(*layers)
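A usage sketch (values are illustrative) showing the two mutually exclusive ways of sizing the hidden layers:

import torch

from mmlearn.modules.layers.mlp import MLP

# explicit hidden sizes: 768 -> 1024 -> 1024 -> 512
mlp = MLP(in_dim=768, out_dim=512, hidden_dims=[1024, 1024], activation_layer=torch.nn.GELU)

# multipliers chained from the previous layer: 768 -> 3072 -> 1536 -> 768
mlp_mult = MLP(in_dim=768, hidden_dims_multiplier=[4.0, 0.5])

out = mlp(torch.randn(4, 768))         # shape: (4, 512)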
L2Norm

Bases: Module

L2 normalization.

Parameters:

Name Type Description Default
dim int

The dimension along which to normalize.

required
Source code in mmlearn/modules/layers/normalization.py
@store(group="modules/layers", provider="mmlearn")
class L2Norm(torch.nn.Module):
    """L2 normalization.

    Parameters
    ----------
    dim : int
        The dimension along which to normalize.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply L2 normalization to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)
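A minimal usage sketch:

import torch

from mmlearn.modules.layers.normalization import L2Norm

l2_norm = L2Norm(dim=-1)
x = torch.randn(2, 5, 8)
y = l2_norm(x)                         # every feature vector now has unit L2 norm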
forward
forward(x)

Apply L2 normalization to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Normalized tensor of shape (batch_sz, seq_len, dim).

Source code in mmlearn/modules/layers/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply L2 normalization to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.nn.functional.normalize(x, dim=self.dim, p=2)
PatchDropout

Bases: Module

Patch dropout layer.

Drops patch tokens (after embedding and adding CLS token) from the input tensor. Usually used in vision transformers to reduce the number of tokens. [1]

Parameters:

Name Type Description Default
keep_rate float

The proportion of tokens to keep.

0.5
bias Optional[float]

The bias to add to the random noise before sorting.

None
token_shuffling bool

If True, the tokens are shuffled.

False
References

.. [1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023). Patchdropout: Economizing vision transformers using patch dropout. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 3953-3962).

Source code in mmlearn/modules/layers/patch_dropout.py
class PatchDropout(torch.nn.Module):
    """Patch dropout layer.

    Drops patch tokens (after embedding and adding CLS token) from the input tensor.
    Usually used in vision transformers to reduce the number of tokens. [1]_

    Parameters
    ----------
    keep_rate : float, optional, default=0.5
        The proportion of tokens to keep.
    bias : Optional[float], optional, default=None
        The bias to add to the random noise before sorting.
    token_shuffling : bool, optional, default=False
        If True, the tokens are shuffled.

    References
    ----------
    .. [1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023).
       Patchdropout: Economizing vision transformers using patch dropout. In Proceedings
       of the IEEE/CVF Winter Conference on Applications of Computer Vision
       (pp. 3953-3962).
    """

    def __init__(
        self,
        keep_rate: float = 0.5,
        bias: Optional[float] = None,
        token_shuffling: bool = False,
    ):
        super().__init__()
        assert 0 < keep_rate <= 1, "The keep_rate must be in (0,1]"

        self.bias = bias
        self.keep_rate = keep_rate
        self.token_shuffling = token_shuffling

    def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
        """Drop tokens from the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        force_drop : bool, optional, default=False
            If True, the tokens are always dropped, even when the model is in
            evaluation mode.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
        """
        if (not self.training and not force_drop) or self.keep_rate == 1:
            return x

        batch_sz, _, dim = x.shape

        cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
            batch_sz, 1, dtype=torch.int64, device=x.device
        )
        patch_mask = self.uniform_mask(x)
        patch_mask = torch.hstack([cls_mask, patch_mask])

        return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))

    def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Generate token ids to keep from uniform random distribution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

        """
        batch_sz, seq_len, _ = x.shape
        seq_len = seq_len - 1  # patch length (without CLS)

        keep_len = int(seq_len * self.keep_rate)
        noise = torch.rand(batch_sz, seq_len, device=x.device)
        if self.bias is not None:
            noise += self.bias
        ids = torch.argsort(noise, dim=1)
        keep_ids = ids[:, :keep_len]
        if not self.token_shuffling:
            keep_ids = keep_ids.sort(1)[0]
        return keep_ids
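A usage sketch (illustrative, assuming a 196-patch sequence with a leading CLS token); note that tokens are only dropped in training mode unless force_drop=True:

import torch

from mmlearn.modules.layers.patch_dropout import PatchDropout

patch_dropout = PatchDropout(keep_rate=0.5)
patch_dropout.train()                      # dropping is skipped in eval mode
tokens = torch.randn(4, 1 + 196, 768)      # CLS token + 196 patch tokens
kept = patch_dropout(tokens)
# kept: (4, 1 + 98, 768) -- CLS plus a random half of the patch tokens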
forward
forward(x, force_drop=False)

Drop tokens from the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
force_drop bool

If True, the tokens are always dropped, even when the model is in evaluation mode.

False

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len, dim) containing the kept tokens.

Source code in mmlearn/modules/layers/patch_dropout.py
def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
    """Drop tokens from the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    force_drop : bool, optional, default=False
        If True, the tokens are always dropped, even when the model is in
        evaluation mode.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
    """
    if (not self.training and not force_drop) or self.keep_rate == 1:
        return x

    batch_sz, _, dim = x.shape

    cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
        batch_sz, 1, dtype=torch.int64, device=x.device
    )
    patch_mask = self.uniform_mask(x)
    patch_mask = torch.hstack([cls_mask, patch_mask])

    return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))
uniform_mask
uniform_mask(x)

Generate token ids to keep from uniform random distribution.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len) containing the token ids to keep.

Source code in mmlearn/modules/layers/patch_dropout.py
def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
    """Generate token ids to keep from uniform random distribution.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

    """
    batch_sz, seq_len, _ = x.shape
    seq_len = seq_len - 1  # patch length (without CLS)

    keep_len = int(seq_len * self.keep_rate)
    noise = torch.rand(batch_sz, seq_len, device=x.device)
    if self.bias is not None:
        noise += self.bias
    ids = torch.argsort(noise, dim=1)
    keep_ids = ids[:, :keep_len]
    if not self.token_shuffling:
        keep_ids = keep_ids.sort(1)[0]
    return keep_ids
attention

Attention modules for Vision Transformer (ViT) and related models.

Attention

Bases: Module

Multi-head Self-Attention Mechanism.

Parameters:

Name Type Description Default
dim int

Number of input dimensions.

required
num_heads int

Number of attention heads.

8
qkv_bias bool

If True, adds a learnable bias to the query, key, value projections.

False
qk_scale Optional[float]

Override the default scale factor for the dot-product attention.

None
attn_drop float

Dropout probability for the attention weights.

0.0
proj_drop float

Dropout probability for the output of the attention layer.

0.0
Source code in mmlearn/modules/layers/attention.py
class Attention(nn.Module):
    """Multi-head Self-Attention Mechanism.

    Parameters
    ----------
    dim : int
        Number of input dimensions.
    num_heads : int, optional, default=8
        Number of attention heads.
    qkv_bias : bool, optional, default=False
        If True, adds a learnable bias to the query, key, value projections.
    qk_scale : Optional[float], optional, default=None
        Override the default scale factor for the dot-product attention.
    attn_drop : float, optional, default=0.0
        Dropout probability for the attention weights.
    proj_drop : float, optional, default=0.0
        Dropout probability for the output of the attention layer.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the multi-head self-attention module.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The output tensor and the attention weights.
        """
        b, n, c = x.shape
        qkv = (
            self.qkv(x)
            .reshape(b, n, 3, self.num_heads, c // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
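A usage sketch on a ViT-sized token sequence (shapes are illustrative):

import torch

from mmlearn.modules.layers.attention import Attention

attention = Attention(dim=768, num_heads=12, qkv_bias=True)
x = torch.randn(2, 197, 768)               # CLS + 196 patch tokens
out, attn_weights = attention(x)
# out: (2, 197, 768); attn_weights: (2, 12, 197, 197)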
forward
forward(x)

Forward pass through the multi-head self-attention module.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
tuple[Tensor, Tensor]

The output tensor and the attention weights.

Source code in mmlearn/modules/layers/attention.py
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward pass through the multi-head self-attention module.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        The output tensor and the attention weights.
    """
    b, n, c = x.shape
    qkv = (
        self.qkv(x)
        .reshape(b, n, 3, self.num_heads, c // self.num_heads)
        .permute(2, 0, 3, 1, 4)
    )
    q, k, v = qkv[0], qkv[1], qkv[2]

    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(b, n, c)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x, attn
embedding

Embedding layers.

PatchEmbed

Bases: Module

Image to Patch Embedding.

This module divides an image into patches and embeds them as a sequence of vectors.

Parameters:

Name Type Description Default
img_size int

Size of the input image (assumed to be square).

224
patch_size int

Size of each image patch (assumed to be square).

16
in_chans int

Number of input channels in the image.

3
embed_dim int

Dimension of the output embeddings.

768
Source code in mmlearn/modules/layers/embedding.py
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    This module divides an image into patches and embeds them as a sequence of vectors.

    Parameters
    ----------
    img_size : int, optional, default=224
        Size of the input image (assumed to be square).
    patch_size : int, optional, default=16
        Size of each image patch (assumed to be square).
    in_chans : int, optional, default=3
        Number of input channels in the image.
    embed_dim : int, optional, default=768
        Dimension of the output embeddings.

    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass to convert an image into patch embeddings."""
        return self.proj(x).flatten(2).transpose(1, 2)
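A usage sketch: a 224x224 image becomes a sequence of 196 patch embeddings.

import torch

from mmlearn.modules.layers.embedding import PatchEmbed

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
img = torch.randn(1, 3, 224, 224)
patches = patch_embed(img)                 # shape: (1, 196, 768)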
forward
forward(x)

Forward pass to convert an image into patch embeddings.

Source code in mmlearn/modules/layers/embedding.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass to convert an image into patch embeddings."""
    return self.proj(x).flatten(2).transpose(1, 2)
ConvEmbed

Bases: Module

3x3 Convolution stems for ViT following ViTC models.

This module builds convolutional stems for Vision Transformers (ViT) with intermediate batch normalization and ReLU activation.

Parameters:

Name Type Description Default
channels list[int]

List of channel sizes for each convolution layer.

required
strides list[int]

List of stride sizes for each convolution layer.

required
img_size int

Size of the input image (assumed to be square).

224
in_chans int

Number of input channels in the image.

3
batch_norm bool

Whether to include batch normalization after each convolution layer.

True
Source code in mmlearn/modules/layers/embedding.py
class ConvEmbed(nn.Module):
    """3x3 Convolution stems for ViT following ViTC models.

    This module builds convolutional stems for Vision Transformers (ViT)
    with intermediate batch normalization and ReLU activation.

    Parameters
    ----------
    channels : list[int]
        list of channel sizes for each convolution layer.
    strides : list[int]
        list of stride sizes for each convolution layer.
    img_size : int, optional, default=224
        Size of the input image (assumed to be square).
    in_chans : int, optional, default=3
        Number of input channels in the image.
    batch_norm : bool, optional, default=True
        Whether to include batch normalization after each convolution layer.

    """

    def __init__(
        self,
        channels: list[int],
        strides: list[int],
        img_size: int = 224,
        in_chans: int = 3,
        batch_norm: bool = True,
    ) -> None:
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + channels
        for i in range(len(channels) - 2):
            stem += [
                nn.Conv2d(
                    channels[i],
                    channels[i + 1],
                    kernel_size=3,
                    stride=strides[i],
                    padding=1,
                    bias=(not batch_norm),
                )
            ]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i + 1])]
            stem += [nn.ReLU(inplace=True)]
        stem += [
            nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])
        ]
        self.stem = nn.Sequential(*stem)

        # Compute the number of patches
        stride_prod = int(np.prod(strides))
        self.num_patches = (img_size // stride_prod) ** 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the convolutional embedding layers."""
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)
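A usage sketch; the channel and stride choices below are illustrative, picked so that the overall stride is 16:

import torch

from mmlearn.modules.layers.embedding import ConvEmbed

conv_embed = ConvEmbed(channels=[64, 128, 768], strides=[4, 2, 2], img_size=224, in_chans=3)
img = torch.randn(1, 3, 224, 224)
tokens = conv_embed(img)                   # shape: (1, 196, 768)
print(conv_embed.num_patches)              # 196, since (224 // (4 * 2 * 2)) ** 2 == 196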
forward
forward(x)

Forward pass through the convolutional embedding layers.

Source code in mmlearn/modules/layers/embedding.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through the convolutional embedding layers."""
    p = self.stem(x)
    return p.flatten(2).transpose(1, 2)
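A short usage sketch (the channel and stride values below are illustrative, not defaults): with channels=[64, 128, 256] and strides=[4, 2, 2], the stem downsamples a 224x224 image by prod(strides) = 16, giving num_patches = (224 // 16) ** 2 = 196 tokens of dimension channels[-1] = 256.

import torch
from mmlearn.modules.layers.embedding import ConvEmbed

stem = ConvEmbed(channels=[64, 128, 256], strides=[4, 2, 2], img_size=224, in_chans=3)
print(stem.num_patches)              # 196

images = torch.randn(2, 3, 224, 224)
tokens = stem(images)                # flattened and transposed, like PatchEmbed
print(tokens.shape)                  # torch.Size([2, 196, 256])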
get_2d_sincos_pos_embed
get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False
)

Generate 2D sine-cosine positional embeddings.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
grid_size int

The size of the grid (both height and width).

required
cls_token bool

Whether to include a class token in the embeddings.

False

Returns:

Name Type Description
pos_embed ndarray

Positional embeddings with shape [grid_size*grid_size, embed_dim] or [1 + grid_size*grid_size, embed_dim] if cls_token is True.

Source code in mmlearn/modules/layers/embedding.py
def get_2d_sincos_pos_embed(
    embed_dim: int, grid_size: int, cls_token: bool = False
) -> np.ndarray:
    """
    Generate 2D sine-cosine positional embeddings.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    grid_size : int
        The size of the grid (both height and width).
    cls_token : bool, optional, default=False
        Whether to include a class token in the embeddings.

    Returns
    -------
    pos_embed : np.ndarray
        Positional embeddings with shape [grid_size*grid_size, embed_dim] or
        [1 + grid_size*grid_size, embed_dim] if cls_token is True.
    """
    grid_h = np.arange(grid_size, dtype=float)
    grid_w = np.arange(grid_size, dtype=float)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
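For example, a ViT with a 14x14 patch grid and embedding dimension 768 would call the function as follows; the resulting array is typically loaded into a fixed (non-learnable) positional-embedding buffer.

import torch
from mmlearn.modules.layers.embedding import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_embed.shape)                                         # (197, 768): 1 CLS + 14 * 14 patches

# convert to a torch tensor for use as a fixed positional embedding
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)   # (1, 197, 768)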
get_2d_sincos_pos_embed_from_grid
get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

Generate 2D sine-cosine positional embeddings from a grid.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
grid ndarray

The grid of positions with shape [2, 1, grid_size, grid_size].

required

Returns:

Name Type Description
emb ndarray

Positional embeddings with shape [grid_size*grid_size, embed_dim].

Source code in mmlearn/modules/layers/embedding.py
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray) -> np.ndarray:
    """
    Generate 2D sine-cosine positional embeddings from a grid.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    grid : np.ndarray
        The grid of positions with shape [2, 1, grid_size, grid_size].

    Returns
    -------
    emb : np.ndarray
        Positional embeddings with shape [grid_size*grid_size, embed_dim].
    """
    assert embed_dim % 2 == 0

    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)
get_1d_sincos_pos_embed
get_1d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False
)

Generate 1D sine-cosine positional embeddings.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
grid_size int

The size of the grid.

required
cls_token bool

Whether to include a class token in the embeddings.

False

Returns:

Name Type Description
pos_embed ndarray

Positional embeddings with shape [grid_size, embed_dim] or [1 + grid_size, embed_dim] if cls_token is True.

Source code in mmlearn/modules/layers/embedding.py
def get_1d_sincos_pos_embed(
    embed_dim: int, grid_size: int, cls_token: bool = False
) -> np.ndarray:
    """
    Generate 1D sine-cosine positional embeddings.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    grid_size : int
        The size of the grid.
    cls_token : bool, optional, default=False
        Whether to include a class token in the embeddings.

    Returns
    -------
    pos_embed : np.ndarray
        Positional embeddings with shape [grid_size, embed_dim] or
        [1 + grid_size, embed_dim] if cls_token is True.
    """
    grid = np.arange(grid_size, dtype=float)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
get_1d_sincos_pos_embed_from_grid
get_1d_sincos_pos_embed_from_grid(embed_dim, pos)

Generate 1D sine-cosine positional embeddings from a grid.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
pos ndarray

A list of positions to be encoded, with shape [M,].

required

Returns:

Name Type Description
emb ndarray

Positional embeddings with shape [M, embed_dim].

Source code in mmlearn/modules/layers/embedding.py
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """
    Generate 1D sine-cosine positional embeddings from a grid.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    pos : np.ndarray
        A list of positions to be encoded, with shape [M,].

    Returns
    -------
    emb : np.ndarray
        Positional embeddings with shape [M, embed_dim].
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)
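A small check of the output layout: for each position, the first embed_dim // 2 values are sines and the last embed_dim // 2 are cosines of the position scaled by geometrically spaced frequencies.

import numpy as np
from mmlearn.modules.layers.embedding import get_1d_sincos_pos_embed_from_grid

positions = np.arange(4, dtype=float)                  # positions 0..3
emb = get_1d_sincos_pos_embed_from_grid(8, positions)
print(emb.shape)                                        # (4, 8)

# the first half of each row is sin(pos * omega) with omega = 1 / 10000 ** (2i / embed_dim)
omega = 1.0 / 10000 ** (np.arange(4) / 4)
print(np.allclose(emb[:, :4], np.sin(np.outer(positions, omega))))  # True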
logit_scaling

Learnable logit scaling layer.

LearnableLogitScaling

Bases: Module

Logit scaling layer.

Parameters:

Name Type Description Default
init_logit_scale float

Initial value of the logit scale.

1/0.07
learnable bool

If True, the logit scale is learnable. Otherwise, it is fixed.

True
max_logit_scale float

Maximum value of the logit scale.

100
Source code in mmlearn/modules/layers/logit_scaling.py
@store(group="modules/layers", provider="mmlearn")
class LearnableLogitScaling(torch.nn.Module):
    """Logit scaling layer.

    Parameters
    ----------
    init_logit_scale : float, optional, default=1/0.07
        Initial value of the logit scale.
    learnable : bool, optional, default=True
        If True, the logit scale is learnable. Otherwise, it is fixed.
    max_logit_scale : float, optional, default=100
        Maximum value of the logit scale.
    """

    def __init__(
        self,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.init_logit_scale = init_logit_scale
        self.learnable = learnable
        log_logit_scale = torch.ones([]) * np.log(self.init_logit_scale)
        if learnable:
            self.log_logit_scale = torch.nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the logit scaling to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self) -> str:
        """Return the string representation of the layer."""
        return (
            f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
            f" max_logit_scale={self.max_logit_scale}"
        )
forward
forward(x)

Apply the logit scaling to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
Source code in mmlearn/modules/layers/logit_scaling.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the logit scaling to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
extra_repr
extra_repr()

Return the string representation of the layer.

Source code in mmlearn/modules/layers/logit_scaling.py
def extra_repr(self) -> str:
    """Return the string representation of the layer."""
    return (
        f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
        f" max_logit_scale={self.max_logit_scale}"
    )
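In CLIP-style training this layer is typically applied to the similarity logits before the contrastive loss; the scale is stored in log space and clipped at max_logit_scale. A short usage sketch with illustrative shapes:

import torch
from mmlearn.modules.layers.logit_scaling import LearnableLogitScaling

logit_scaling = LearnableLogitScaling(init_logit_scale=1 / 0.07, learnable=True)

similarities = torch.randn(8, 8)       # e.g. image-text cosine similarities
scaled = logit_scaling(similarities)   # multiplied by exp(log_logit_scale), clipped at max_logit_scale
print(logit_scaling)                   # extra_repr shows the initial scale, learnability and max scale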
mlp

Multi-layer perceptron (MLP).

MLP

Bases: Sequential

Multi-layer perceptron (MLP).

This module will create a block of Linear -> Normalization -> Activation -> Dropout layers.

Parameters:

Name Type Description Default
in_dim int

The input dimension.

required
out_dim Optional[int]

The output dimension. If not specified, it is set to in_dim.

None
hidden_dims Optional[list]

The dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with hidden_dims_multiplier.

None
hidden_dims_multiplier Optional[list]

The multipliers to apply to the input dimension to get the dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with hidden_dims.

None
apply_multiplier_to_in_dim bool

Whether to apply the hidden_dims_multiplier to in_dim to get the dimensions of the hidden layers. If False, the multipliers will be applied to the dimensions of the previous hidden layer, starting from in_dim. This parameter is only relevant when hidden_dims_multiplier is specified.

False
norm_layer Optional[Callable[..., Module]]

The normalization layer to use. If not specified, no normalization is used. Partial functions can be used to specify the normalization layer with specific parameters.

None
activation_layer Optional[Callable[..., Module]]

The activation layer to use. If not specified, ReLU is used. Partial functions can be used to specify the activation layer with specific parameters.

torch.nn.ReLU
bias Union[bool, list[bool]]

Whether to use bias in the linear layers.

True
dropout Union[float, list[float]]

The dropout probability to use.

0.0

Raises:

Type Description
ValueError

If both hidden_dims and hidden_dims_multiplier are specified, or if the lengths of bias and hidden_dims do not match, or if the lengths of dropout and hidden_dims do not match.

Source code in mmlearn/modules/layers/mlp.py
@store(group="modules/layers", provider="mmlearn")
class MLP(torch.nn.Sequential):
    """Multi-layer perceptron (MLP).

    This module will create a block of ``Linear -> Normalization -> Activation -> Dropout``
    layers.

    Parameters
    ----------
    in_dim : int
        The input dimension.
    out_dim : Optional[int], optional, default=None
        The output dimension. If not specified, it is set to :attr:`in_dim`.
    hidden_dims : Optional[list], optional, default=None
        The dimensions of the hidden layers. The length of the list determines the
        number of hidden layers. This parameter is mutually exclusive with
        :attr:`hidden_dims_multiplier`.
    hidden_dims_multiplier : Optional[list], optional, default=None
        The multipliers to apply to the input dimension to get the dimensions of
        the hidden layers. The length of the list determines the number of hidden
        layers. The multipliers will be used to get the dimensions of the hidden
        layers. This parameter is mutually exclusive with `hidden_dims`.
    apply_multiplier_to_in_dim : bool, optional, default=False
        Whether to apply the :attr:`hidden_dims_multiplier` to :attr:`in_dim` to get the
        dimensions of the hidden layers. If ``False``, the multipliers will be applied
        to the dimensions of the previous hidden layer, starting from :attr:`in_dim`.
        This parameter is only relevant when :attr:`hidden_dims_multiplier` is
        specified.
    norm_layer : Optional[Callable[..., torch.nn.Module]], optional, default=None
        The normalization layer to use. If not specified, no normalization is used.
        Partial functions can be used to specify the normalization layer with specific
        parameters.
    activation_layer : Optional[Callable[..., torch.nn.Module]], optional, default=torch.nn.ReLU
        The activation layer to use. If not specified, ReLU is used. Partial functions
        can be used to specify the activation layer with specific parameters.
    bias : Union[bool, list[bool]], optional, default=True
        Whether to use bias in the linear layers.
    dropout : Union[float, list[float]], optional, default=0.0
        The dropout probability to use.

    Raises
    ------
    ValueError
        If both :attr:`hidden_dims` and :attr:`hidden_dims_multiplier` are specified
        or if the lengths of :attr:`bias` and :attr:`hidden_dims` do not match or if
        the lengths of :attr:`dropout` and :attr:`hidden_dims` do not match.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        hidden_dims: Optional[list[int]] = None,
        hidden_dims_multiplier: Optional[list[float]] = None,
        apply_multiplier_to_in_dim: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        bias: Union[bool, list[bool]] = True,
        dropout: Union[float, list[float]] = 0.0,
    ) -> None:
        if hidden_dims is None and hidden_dims_multiplier is None:
            hidden_dims = []
        if hidden_dims is not None and hidden_dims_multiplier is not None:
            raise ValueError(
                "Only one of `hidden_dims` or `hidden_dims_multiplier` must be specified."
            )

        if hidden_dims is None and hidden_dims_multiplier is not None:
            if apply_multiplier_to_in_dim:
                hidden_dims = [
                    int(in_dim * multiplier) for multiplier in hidden_dims_multiplier
                ]
            else:
                hidden_dims = [int(in_dim * hidden_dims_multiplier[0])]
                for multiplier in hidden_dims_multiplier[1:]:
                    hidden_dims.append(int(hidden_dims[-1] * multiplier))

        if isinstance(bias, bool):
            bias_list: list[bool] = [bias] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            bias_list = bias
        if len(bias_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `bias` to be a boolean or a list of booleans with length "
                "equal to the number of linear layers in the MLP."
            )

        if isinstance(dropout, float):
            dropout_list: list[float] = [dropout] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            dropout_list = dropout
        if len(dropout_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `dropout` to be a float or a list of floats with length "
                "equal to the number of linear layers in the MLP."
            )

        # construct list of dimensions for the layers
        dims = [in_dim] + hidden_dims  # type: ignore[operator]
        layers = []
        for layer_idx, (in_features, hidden_features) in enumerate(
            zip(dims[:-1], dims[1:], strict=False)
        ):
            layers.append(
                torch.nn.Linear(in_features, hidden_features, bias=bias_list[layer_idx])
            )
            if norm_layer is not None:
                layers.append(norm_layer(hidden_features))
            if activation_layer is not None:
                layers.append(activation_layer())
            layers.append(torch.nn.Dropout(dropout_list[layer_idx]))

        out_dim = out_dim or in_dim

        layers.append(torch.nn.Linear(dims[-1], out_dim, bias=bias_list[-1]))
        layers.append(torch.nn.Dropout(dropout_list[-1]))

        super().__init__(*layers)
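A usage sketch: a two-hidden-layer MLP with layer normalization, GELU and dropout, built either from explicit hidden dimensions or from multipliers on the input dimension (partial functions supply layer-specific arguments). The sizes below are illustrative.

import functools
import torch
from mmlearn.modules.layers.mlp import MLP

mlp = MLP(
    in_dim=768,
    out_dim=768,
    hidden_dims=[3072, 3072],
    norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
    activation_layer=torch.nn.GELU,
    dropout=0.1,
)
out = mlp(torch.randn(4, 768))
print(out.shape)                                       # torch.Size([4, 768])

# equivalent hidden sizes via multipliers: 768 * 4 = 3072, then 3072 * 1 = 3072
mlp_mult = MLP(in_dim=768, hidden_dims_multiplier=[4.0, 1.0])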
normalization

Normalization layers.

L2Norm

Bases: Module

L2 normalization.

Parameters:

Name Type Description Default
dim int

The dimension along which to normalize.

required
Source code in mmlearn/modules/layers/normalization.py
@store(group="modules/layers", provider="mmlearn")
class L2Norm(torch.nn.Module):
    """L2 normalization.

    Parameters
    ----------
    dim : int
        The dimension along which to normalize.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply L2 normalization to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)
forward
forward(x)

Apply L2 normalization to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Normalized tensor of shape (batch_sz, seq_len, dim).

Source code in mmlearn/modules/layers/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply L2 normalization to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.nn.functional.normalize(x, dim=self.dim, p=2)
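Typically used to project embeddings onto the unit hypersphere before computing cosine similarities; the shapes below are illustrative.

import torch
from mmlearn.modules.layers.normalization import L2Norm

l2_norm = L2Norm(dim=-1)
x = torch.randn(2, 197, 768)
x_normed = l2_norm(x)
print(x_normed.norm(p=2, dim=-1))   # all values are (approximately) 1.0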
patch_dropout

Patch dropout layer.

PatchDropout

Bases: Module

Patch dropout layer.

Drops patch tokens (after embedding and adding CLS token) from the input tensor. Usually used in vision transformers to reduce the number of tokens [1].

Parameters:

Name Type Description Default
keep_rate float

The proportion of tokens to keep.

0.5
bias Optional[float]

The bias to add to the random noise before sorting.

None
token_shuffling bool

If True, the tokens are shuffled.

False
References

[1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023). Patchdropout: Economizing vision transformers using patch dropout. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 3953-3962).

Source code in mmlearn/modules/layers/patch_dropout.py
class PatchDropout(torch.nn.Module):
    """Patch dropout layer.

    Drops patch tokens (after embedding and adding CLS token) from the input tensor.
    Usually used in vision transformers to reduce the number of tokens. [1]_

    Parameters
    ----------
    keep_rate : float, optional, default=0.5
        The proportion of tokens to keep.
    bias : Optional[float], optional, default=None
        The bias to add to the random noise before sorting.
    token_shuffling : bool, optional, default=False
        If True, the tokens are shuffled.

    References
    ----------
    .. [1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023).
       Patchdropout: Economizing vision transformers using patch dropout. In Proceedings
       of the IEEE/CVF Winter Conference on Applications of Computer Vision
       (pp. 3953-3962).
    """

    def __init__(
        self,
        keep_rate: float = 0.5,
        bias: Optional[float] = None,
        token_shuffling: bool = False,
    ):
        super().__init__()
        assert 0 < keep_rate <= 1, "The keep_rate must be in (0,1]"

        self.bias = bias
        self.keep_rate = keep_rate
        self.token_shuffling = token_shuffling

    def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
        """Drop tokens from the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        force_drop : bool, optional, default=False
            If True, the tokens are always dropped, even when the model is in
            evaluation mode.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
        """
        if (not self.training and not force_drop) or self.keep_rate == 1:
            return x

        batch_sz, _, dim = x.shape

        cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
            batch_sz, 1, dtype=torch.int64, device=x.device
        )
        patch_mask = self.uniform_mask(x)
        patch_mask = torch.hstack([cls_mask, patch_mask])

        return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))

    def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Generate token ids to keep from uniform random distribution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

        """
        batch_sz, seq_len, _ = x.shape
        seq_len = seq_len - 1  # patch length (without CLS)

        keep_len = int(seq_len * self.keep_rate)
        noise = torch.rand(batch_sz, seq_len, device=x.device)
        if self.bias is not None:
            noise += self.bias
        ids = torch.argsort(noise, dim=1)
        keep_ids = ids[:, :keep_len]
        if not self.token_shuffling:
            keep_ids = keep_ids.sort(1)[0]
        return keep_ids
forward
forward(x, force_drop=False)

Drop tokens from the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
force_drop bool

If True, the tokens are always dropped, even when the model is in evaluation mode.

False

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len, dim) containing the kept tokens.

Source code in mmlearn/modules/layers/patch_dropout.py
def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
    """Drop tokens from the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    force_drop : bool, optional, default=False
        If True, the tokens are always dropped, even when the model is in
        evaluation mode.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
    """
    if (not self.training and not force_drop) or self.keep_rate == 1:
        return x

    batch_sz, _, dim = x.shape

    cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
        batch_sz, 1, dtype=torch.int64, device=x.device
    )
    patch_mask = self.uniform_mask(x)
    patch_mask = torch.hstack([cls_mask, patch_mask])

    return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))
uniform_mask
uniform_mask(x)

Generate token ids to keep from uniform random distribution.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len) containing the token ids to keep.

Source code in mmlearn/modules/layers/patch_dropout.py
def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
    """Generate token ids to keep from uniform random distribution.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

    """
    batch_sz, seq_len, _ = x.shape
    seq_len = seq_len - 1  # patch length (without CLS)

    keep_len = int(seq_len * self.keep_rate)
    noise = torch.rand(batch_sz, seq_len, device=x.device)
    if self.bias is not None:
        noise += self.bias
    ids = torch.argsort(noise, dim=1)
    keep_ids = ids[:, :keep_len]
    if not self.token_shuffling:
        keep_ids = keep_ids.sort(1)[0]
    return keep_ids
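A usage sketch: with keep_rate=0.5 and a sequence of 1 CLS token plus 196 patch tokens, a training-mode forward pass keeps the CLS token and int(196 * 0.5) = 98 randomly selected patches, i.e. 99 tokens in total.

import torch
from mmlearn.modules.layers.patch_dropout import PatchDropout

patch_dropout = PatchDropout(keep_rate=0.5)
patch_dropout.train()                 # tokens are only dropped in training mode (or with force_drop=True)

tokens = torch.randn(2, 197, 768)     # (batch, 1 CLS + 196 patches, dim)
kept = patch_dropout(tokens)
print(kept.shape)                     # torch.Size([2, 99, 768])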
transformer_block

Transformer Block and Embedding Modules for Vision Transformers (ViT).

DropPath

Bases: Module

Drop paths (Stochastic Depth) per sample.

Parameters:

Name Type Description Default
drop_prob float

Probability of dropping paths.

0.0
Source code in mmlearn/modules/layers/transformer_block.py
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample.

    Parameters
    ----------
    drop_prob : float, optional, default=0.0
        Probability of dropping paths.
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through DropPath module."""
        return drop_path(x, self.drop_prob, self.training)
forward
forward(x)

Forward pass through DropPath module.

Source code in mmlearn/modules/layers/transformer_block.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through DropPath module."""
    return drop_path(x, self.drop_prob, self.training)
Block

Bases: Module

Transformer Block.

This module represents a Transformer block that includes self-attention, normalization layers, and a feedforward multi-layer perceptron (MLP) network.

Parameters:

Name Type Description Default
dim int

The input and output dimension of the block.

required
num_heads int

Number of attention heads.

required
mlp_ratio float

Ratio of hidden dimension to the input dimension in the MLP.

4.0
qkv_bias bool

If True, add a learnable bias to the query, key, value projections.

False
qk_scale Optional[float]

Override default qk scale of head_dim ** -0.5 if set.

None
drop float

Dropout probability for the output of attention and MLP layers.

0.0
attn_drop float

Dropout probability for the attention scores.

0.0
drop_path float

Stochastic depth rate, a form of layer dropout.

0.0
act_layer Callable[..., Module]

Activation layer in the MLP.

nn.GELU
norm_layer Callable[..., Module]

Normalization layer.

torch.nn.LayerNorm
Source code in mmlearn/modules/layers/transformer_block.py
class Block(nn.Module):
    """Transformer Block.

    This module represents a Transformer block that includes self-attention,
    normalization layers, and a feedforward multi-layer perceptron (MLP) network.

    Parameters
    ----------
    dim : int
        The input and output dimension of the block.
    num_heads : int
        Number of attention heads.
    mlp_ratio : float, optional, default=4.0
        Ratio of hidden dimension to the input dimension in the MLP.
    qkv_bias : bool, optional, default=False
        If True, add a learnable bias to the query, key, value projections.
    qk_scale : Optional[float], optional, default=None
        Override default qk scale of head_dim ** -0.5 if set.
    drop : float, optional, default=0.0
        Dropout probability for the output of attention and MLP layers.
    attn_drop : float, optional, default=0.0
        Dropout probability for the attention scores.
    drop_path : float, optional, default=0.0
        Stochastic depth rate, a form of layer dropout.
    act_layer : Callable[..., torch.nn.Module], optional, default=nn.GELU
        Activation layer in the MLP.
    norm_layer : Callable[..., torch.nn.Module], optional, default=torch.nn.LayerNorm
        Normalization layer.

    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)

        self.mlp = MLP(
            in_dim=dim,
            hidden_dims_multiplier=[mlp_ratio],
            activation_layer=act_layer,
            bias=True,
            dropout=drop,
        )

    def forward(
        self, x: torch.Tensor, return_attention: bool = False
    ) -> Union[torch.Tensor, torch.Tensor]:
        """Forward pass through the Transformer Block."""
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        return x + self.drop_path(self.mlp(self.norm2(x)))
forward
forward(x, return_attention=False)

Forward pass through the Transformer Block.

Source code in mmlearn/modules/layers/transformer_block.py
def forward(
    self, x: torch.Tensor, return_attention: bool = False
) -> Union[torch.Tensor, torch.Tensor]:
    """Forward pass through the Transformer Block."""
    y, attn = self.attn(self.norm1(x))
    if return_attention:
        return attn
    x = x + self.drop_path(y)
    return x + self.drop_path(self.mlp(self.norm2(x)))
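A usage sketch of a single block with ViT-Base-like settings (illustrative values); passing return_attention=True returns the attention map from the self-attention layer instead of the block output.

import torch
from mmlearn.modules.layers.transformer_block import Block

block = Block(dim=768, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_path=0.1)

tokens = torch.randn(2, 197, 768)
out = block(tokens)
print(out.shape)                             # torch.Size([2, 197, 768])

attn = block(tokens, return_attention=True)  # attention weights from the internal Attention module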
drop_path
drop_path(x, drop_prob=0.0, training=False)

Drop paths (Stochastic Depth) for regularization during training.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
drop_prob float

Probability of dropping paths.

0.0
training bool

Whether the model is in training mode.

False

Returns:

Name Type Description
output Tensor

Output tensor after applying drop path.

Source code in mmlearn/modules/layers/transformer_block.py
def drop_path(
    x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """Drop paths (Stochastic Depth) for regularization during training.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    drop_prob : float, optional, default=0.0
        Probability of dropping paths.
    training : bool, optional, default=False
        Whether the model is in training mode.

    Returns
    -------
    output : torch.Tensor
        Output tensor after applying drop path.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    return x.div(keep_prob) * random_tensor
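The surviving samples are rescaled by 1 / keep_prob so that the expected value of the output matches the input; a quick numerical check:

import torch
from mmlearn.modules.layers.transformer_block import drop_path

x = torch.ones(10000, 1)
out = drop_path(x, drop_prob=0.2, training=True)

print((out == 0).float().mean())   # ~0.2: fraction of dropped samples
print(out.mean())                  # ~1.0: expectation preserved by the 1/keep_prob rescaling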

losses

Loss functions.

ContrastiveLoss

Bases: Module

Contrastive Loss.

Parameters:

Name Type Description Default
l2_normalize bool

Whether to L2 normalize the features.

False
local_loss bool

Whether to calculate the loss locally, i.e., local_features @ global_features.

False
gather_with_grad bool

Whether to gather tensors with gradients.

False
modality_alignment bool

Whether to include modality alignment loss. This loss considers all features from the same modality as positive pairs and all features from different modalities as negative pairs.

False
cache_labels bool

Whether to cache the labels.

False
Source code in mmlearn/modules/losses/contrastive.py
@store(group="modules/losses", provider="mmlearn")
class ContrastiveLoss(nn.Module):
    """Contrastive Loss.

    Parameters
    ----------
    l2_normalize : bool, optional, default=False
        Whether to L2 normalize the features.
    local_loss : bool, optional, default=False
        Whether to calculate the loss locally i.e. ``local_features@global_features``.
    gather_with_grad : bool, optional, default=False
        Whether to gather tensors with gradients.
    modality_alignment : bool, optional, default=False
        Whether to include modality alignment loss. This loss considers all features
        from the same modality as positive pairs and all features from different
        modalities as negative pairs.
    cache_labels : bool, optional, default=False
        Whether to cache the labels.

    """

    def __init__(
        self,
        l2_normalize: bool = False,
        local_loss: bool = False,
        gather_with_grad: bool = False,
        modality_alignment: bool = False,
        cache_labels: bool = False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.l2_normalize = l2_normalize
        self.modality_alignment = modality_alignment

        # cache state
        self._prev_num_logits = 0
        self._labels: dict[torch.device, torch.Tensor] = {}

    def forward(
        self,
        embeddings: dict[str, torch.Tensor],
        example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        modality_loss_pairs: list[LossPairSpec],
    ) -> torch.Tensor:
        """Calculate the contrastive loss.

        Parameters
        ----------
        embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        modality_loss_pairs : list[LossPairSpec]
            Specification of the modality pairs for which the loss should be calculated.

        Returns
        -------
        torch.Tensor
            The contrastive loss.
        """
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if world_size > 1 else 0

        if self.l2_normalize:
            embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

        if world_size > 1:  # gather embeddings and example_ids across all processes
            # NOTE: gathering dictionaries of tensors across all processes
            # (keys + values, as opposed to just values) is especially important
            # for the modality_alignment loss, which requires all embeddings
            all_embeddings = _gather_dicts(
                embeddings,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
            all_example_ids = _gather_dicts(
                example_ids,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
        else:
            all_embeddings = embeddings
            all_example_ids = example_ids

        losses = []
        for loss_pairs in modality_loss_pairs:
            logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
                loss_pairs.modalities,
                per_device_embeddings=embeddings,
                all_embeddings=all_embeddings,
                per_device_example_ids=example_ids,
                all_example_ids=all_example_ids,
                logit_scale=logit_scale,
                world_size=world_size,
            )
            if logits_per_feature_a is None or logits_per_feature_b is None:
                continue

            labels = self._get_ground_truth(
                logits_per_feature_a.shape,
                device=logits_per_feature_a.device,
                rank=rank,
                world_size=world_size,
                skipped_process=skip_flag,
            )

            if labels.numel() != 0:
                losses.append(
                    (
                        (
                            F.cross_entropy(logits_per_feature_a, labels)
                            + F.cross_entropy(logits_per_feature_b, labels)
                        )
                        / 2
                    )
                    * loss_pairs.weight
                )

        if self.modality_alignment:
            losses.append(
                self._compute_modality_alignment_loss(all_embeddings, logit_scale)
            )

        if not losses:  # no loss to compute (e.g. no paired data in batch)
            losses.append(
                torch.tensor(
                    0.0,
                    device=logit_scale.device,
                    dtype=next(iter(embeddings.values())).dtype,
                )
            )

        return torch.stack(losses).sum()

    def _get_ground_truth(
        self,
        logits_shape: tuple[int, int],
        device: torch.device,
        rank: int,
        world_size: int,
        skipped_process: bool,
    ) -> torch.Tensor:
        """Return the ground-truth labels.

        Parameters
        ----------
        logits_shape : tuple[int, int]
            Shape of the logits tensor.
        device : torch.device
            Device on which the labels should be created.
        rank : int
            Rank of the current process.
        world_size : int
            Number of processes.
        skipped_process : bool
            Whether the current process skipped the computation due to lack of data.

        Returns
        -------
        torch.Tensor
            Ground-truth labels.
        """
        num_logits = logits_shape[-1]

        # calculate ground-truth and cache if enabled
        if self._prev_num_logits != num_logits or device not in self._labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)

            if world_size > 1 and self.local_loss:
                local_size = torch.tensor(
                    0 if skipped_process else logits_shape[0], device=device
                )
                # NOTE: all processes must participate in the all_gather operation
                # even if they have no data to contribute.
                sizes = torch.stack(
                    _simple_gather_all_tensors(
                        local_size, group=dist.group.WORLD, world_size=world_size
                    )
                )
                sizes = torch.cat(
                    [torch.tensor([0], device=sizes.device), torch.cumsum(sizes, dim=0)]
                )
                labels = labels[
                    sizes[rank] : sizes[rank + 1] if rank + 1 < world_size else None
                ]

            if self.cache_labels:
                self._labels[device] = labels
                self._prev_num_logits = num_logits
        else:
            labels = self._labels[device]
        return labels

    def _get_logits(  # noqa: PLR0912
        self,
        modalities: tuple[str, str],
        per_device_embeddings: dict[str, torch.Tensor],
        all_embeddings: dict[str, torch.Tensor],
        per_device_example_ids: dict[str, torch.Tensor],
        all_example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        world_size: int,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
        """Calculate the logits for the given modalities.

        Parameters
        ----------
        modalities : tuple[str, str]
            Tuple of modality names.
        per_device_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor. In distributed mode, this contains
            embeddings from all processes.
        per_device_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        all_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index. In distributed
            mode, this contains example IDs from all processes.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        world_size : int
            Number of processes.

        Returns
        -------
        tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]
            Tuple of logits for the given modalities. If embeddings for the given
            modalities are not available, returns `None` for the logits. The last
            element is a flag indicating whether the process skipped the computation
            due to lack of data.
        """
        modality_a = Modalities.get_modality(modalities[0])
        modality_b = Modalities.get_modality(modalities[1])
        skip_flag = False

        if self.local_loss or world_size == 1:
            if not (
                modality_a.embedding in per_device_embeddings
                and modality_b.embedding in per_device_embeddings
            ):
                if world_size > 1:  # NOTE: not all processes exit here, hence skip_flag
                    skip_flag = True
                else:
                    return None, None, skip_flag

            if not skip_flag:
                indices_a, indices_b = find_matching_indices(
                    per_device_example_ids[modality_a.name],
                    per_device_example_ids[modality_b.name],
                )
                if indices_a.numel() == 0 or indices_b.numel() == 0:
                    if world_size > 1:  # not all processes exit here
                        skip_flag = True
                    else:
                        return None, None, skip_flag

            if not skip_flag:
                features_a = per_device_embeddings[modality_a.embedding][indices_a]
                features_b = per_device_embeddings[modality_b.embedding][indices_b]
            else:
                # all processes must participate in the all_gather operation
                # that follows, even if they have no data to contribute. So,
                # we create empty tensors here.
                features_a = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )
                features_b = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )

        if world_size > 1:
            if not (
                modality_a.embedding in all_embeddings
                and modality_b.embedding in all_embeddings
            ):  # all processes exit here
                return None, None, skip_flag

            indices_a, indices_b = find_matching_indices(
                all_example_ids[modality_a.name],
                all_example_ids[modality_b.name],
            )
            if indices_a.numel() == 0 or indices_b.numel() == 0:
                # all processes exit here
                return None, None, skip_flag

            all_features_a = all_embeddings[modality_a.embedding][indices_a]
            all_features_b = all_embeddings[modality_b.embedding][indices_b]

            if self.local_loss:
                if features_a.numel() == 0:
                    features_a = all_features_a
                if features_b.numel() == 0:
                    features_b = all_features_b

                logits_per_feature_a = logit_scale * _safe_matmul(
                    features_a, all_features_b
                )
                logits_per_feature_b = logit_scale * _safe_matmul(
                    features_b, all_features_a
                )
            else:
                logits_per_feature_a = logit_scale * _safe_matmul(
                    all_features_a, all_features_b
                )
                logits_per_feature_b = logits_per_feature_a.T
        else:
            logits_per_feature_a = logit_scale * _safe_matmul(features_a, features_b)
            logits_per_feature_b = logit_scale * _safe_matmul(features_b, features_a)

        return logits_per_feature_a, logits_per_feature_b, skip_flag

    def _compute_modality_alignment_loss(
        self, all_embeddings: dict[str, torch.Tensor], logit_scale: torch.Tensor
    ) -> torch.Tensor:
        """Compute the modality alignment loss.

        This loss considers all features from the same modality as positive pairs
        and all features from different modalities as negative pairs.

        Parameters
        ----------
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        logit_scale : torch.Tensor
            Scale factor for the logits.

        Returns
        -------
        torch.Tensor
            Modality alignment loss.

        Notes
        -----
        This loss does not support `local_loss=True`.
        """
        available_modalities = list(all_embeddings.keys())
        # TODO: support local_loss for modality_alignment?
        # if world_size == 1, all_embeddings == embeddings
        all_features = torch.cat(list(all_embeddings.values()), dim=0)

        positive_indices = torch.tensor(
            [
                (i, j)
                if idx == 0
                else (
                    i + all_embeddings[available_modalities[idx - 1]].size(0),
                    j + all_embeddings[available_modalities[idx - 1]].size(0),
                )
                for idx, k in enumerate(all_embeddings)
                for i, j in itertools.combinations(range(all_embeddings[k].size(0)), 2)
            ],
            device=all_features.device,
        )
        logits = logit_scale * _safe_matmul(all_features, all_features)

        target = torch.eye(all_features.size(0), device=all_features.device)
        target[positive_indices[:, 0], positive_indices[:, 1]] = 1

        modality_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, target, reduction="none"
        )

        target_pos = target.bool()
        target_neg = ~target_pos

        # loss_pos and loss_neg below contain non-zero values only for those
        # elements that are positive pairs and negative pairs respectively.
        loss_pos = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_pos, modality_loss[target_pos])
        loss_neg = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_neg, modality_loss[target_neg])

        loss_pos = loss_pos.sum(dim=1)
        loss_neg = loss_neg.sum(dim=1)
        num_pos = target.sum(dim=1)
        num_neg = logits.size(0) - num_pos

        return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()
forward
forward(
    embeddings,
    example_ids,
    logit_scale,
    modality_loss_pairs,
)

Calculate the contrastive loss.

Parameters:

Name Type Description Default
embeddings dict[str, Tensor]

Dictionary of embeddings, where the key is the modality name and the value is the corresponding embedding tensor.

required
example_ids dict[str, Tensor]

Dictionary of example IDs, where the key is the modality name and the value is a tensor tuple of the dataset index and the example index.

required
logit_scale Tensor

Scale factor for the logits.

required
modality_loss_pairs list[LossPairSpec]

Specification of the modality pairs for which the loss should be calculated.

required

Returns:

Type Description
Tensor

The contrastive loss.

Source code in mmlearn/modules/losses/contrastive.py
def forward(
    self,
    embeddings: dict[str, torch.Tensor],
    example_ids: dict[str, torch.Tensor],
    logit_scale: torch.Tensor,
    modality_loss_pairs: list[LossPairSpec],
) -> torch.Tensor:
    """Calculate the contrastive loss.

    Parameters
    ----------
    embeddings : dict[str, torch.Tensor]
        Dictionary of embeddings, where the key is the modality name and the value
        is the corresponding embedding tensor.
    example_ids : dict[str, torch.Tensor]
        Dictionary of example IDs, where the key is the modality name and the value
        is a tensor tuple of the dataset index and the example index.
    logit_scale : torch.Tensor
        Scale factor for the logits.
    modality_loss_pairs : list[LossPairSpec]
        Specification of the modality pairs for which the loss should be calculated.

    Returns
    -------
    torch.Tensor
        The contrastive loss.
    """
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if world_size > 1 else 0

    if self.l2_normalize:
        embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

    if world_size > 1:  # gather embeddings and example_ids across all processes
        # NOTE: gathering dictionaries of tensors across all processes
        # (keys + values, as opposed to just values) is especially important
        # for the modality_alignment loss, which requires all embeddings
        all_embeddings = _gather_dicts(
            embeddings,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
        all_example_ids = _gather_dicts(
            example_ids,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
    else:
        all_embeddings = embeddings
        all_example_ids = example_ids

    losses = []
    for loss_pairs in modality_loss_pairs:
        logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
            loss_pairs.modalities,
            per_device_embeddings=embeddings,
            all_embeddings=all_embeddings,
            per_device_example_ids=example_ids,
            all_example_ids=all_example_ids,
            logit_scale=logit_scale,
            world_size=world_size,
        )
        if logits_per_feature_a is None or logits_per_feature_b is None:
            continue

        labels = self._get_ground_truth(
            logits_per_feature_a.shape,
            device=logits_per_feature_a.device,
            rank=rank,
            world_size=world_size,
            skipped_process=skip_flag,
        )

        if labels.numel() != 0:
            losses.append(
                (
                    (
                        F.cross_entropy(logits_per_feature_a, labels)
                        + F.cross_entropy(logits_per_feature_b, labels)
                    )
                    / 2
                )
                * loss_pairs.weight
            )

    if self.modality_alignment:
        losses.append(
            self._compute_modality_alignment_loss(all_embeddings, logit_scale)
        )

    if not losses:  # no loss to compute (e.g. no paired data in batch)
        losses.append(
            torch.tensor(
                0.0,
                device=logit_scale.device,
                dtype=next(iter(embeddings.values())).dtype,
            )
        )

    return torch.stack(losses).sum()
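The core computation, per modality pair, is the symmetric InfoNCE objective used in CLIP: matched examples across the two modalities form the positives on the diagonal of the scaled similarity matrix. The standalone sketch below illustrates that math only; the layer itself additionally handles example-ID matching, distributed gathering, per-pair weights and the optional modality-alignment term, and it takes dictionaries keyed by modality rather than raw tensors.

import torch
import torch.nn.functional as F

def clip_style_contrastive(features_a, features_b, logit_scale):
    """Symmetric cross-entropy over a scaled similarity matrix (illustrative only)."""
    features_a = F.normalize(features_a, p=2, dim=-1)   # what l2_normalize=True does
    features_b = F.normalize(features_b, p=2, dim=-1)
    logits_a = logit_scale * features_a @ features_b.t()
    logits_b = logits_a.t()
    labels = torch.arange(features_a.size(0), device=features_a.device)
    return (F.cross_entropy(logits_a, labels) + F.cross_entropy(logits_b, labels)) / 2

loss = clip_style_contrastive(torch.randn(8, 512), torch.randn(8, 512), torch.tensor(1 / 0.07))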
Data2VecLoss

Bases: Module

Data2Vec loss function.

Parameters:

Name Type Description Default
beta float

Specifies the beta parameter for smooth L1 loss. If 0, MSE loss is used.

0
loss_scale Optional[float]

Scaling factor for the loss. If None, uses 1 / sqrt(embedding_dim).

None
reduction str

Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.

'none'

Raises:

Type Description
ValueError

If the reduction mode is not supported.

Source code in mmlearn/modules/losses/data2vec.py
@store(group="modules/losses", provider="mmlearn")
class Data2VecLoss(nn.Module):
    """Data2Vec loss function.

    Parameters
    ----------
    beta : float, optional, default=0
        Specifies the beta parameter for smooth L1 loss. If ``0``, MSE loss is used.
    loss_scale : Optional[float], optional, default=None
        Scaling factor for the loss. If None, uses ``1 / sqrt(embedding_dim)``.
    reduction : str, optional, default='none'
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``.

    Raises
    ------
    ValueError
        If the reduction mode is not supported.
    """

    def __init__(
        self,
        beta: float = 0,
        loss_scale: Optional[float] = None,
        reduction: str = "none",
    ) -> None:
        super().__init__()
        self.beta = beta
        self.loss_scale = loss_scale
        if reduction not in ["none", "mean", "sum"]:
            raise ValueError(f"Unsupported reduction mode: {reduction}")
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the Data2Vec loss.

        Parameters
        ----------
        x : torch.Tensor
            Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
        y : torch.Tensor
            Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

        Returns
        -------
        torch.Tensor
            Data2Vec loss value.

        Raises
        ------
        ValueError
            If the shapes of x and y do not match.
        """
        if x.shape != y.shape:
            raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, y.size(-1))

        if self.beta == 0:
            loss = mse_loss(x, y, reduction="none")
        else:
            loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        loss = loss * scale

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # 'none'
        return loss.view(x.size(0), -1).sum(1)
forward
forward(x, y)

Compute the Data2Vec loss.

Parameters:

Name Type Description Default
x Tensor

Predicted embeddings of shape (batch_size, num_patches, embedding_dim).

required
y Tensor

Target embeddings of shape (batch_size, num_patches, embedding_dim).

required

Returns:

Type Description
Tensor

Data2Vec loss value.

Raises:

Type Description
ValueError

If the shapes of x and y do not match.

Source code in mmlearn/modules/losses/data2vec.py
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the Data2Vec loss.

    Parameters
    ----------
    x : torch.Tensor
        Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
    y : torch.Tensor
        Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

    Returns
    -------
    torch.Tensor
        Data2Vec loss value.

    Raises
    ------
    ValueError
        If the shapes of x and y do not match.
    """
    if x.shape != y.shape:
        raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

    x = x.view(-1, x.size(-1)).float()
    y = y.view(-1, y.size(-1))

    if self.beta == 0:
        loss = mse_loss(x, y, reduction="none")
    else:
        loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

    if self.loss_scale is not None:
        scale = self.loss_scale
    else:
        scale = 1 / math.sqrt(x.size(-1))

    loss = loss * scale

    if self.reduction == "mean":
        return loss.mean()
    if self.reduction == "sum":
        return loss.sum()
    # 'none'
    return loss.view(x.size(0), -1).sum(1)
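A usage sketch: comparing student predictions against teacher targets for a batch of patch embeddings, with smooth L1 loss (beta > 0) and mean reduction. The shapes are illustrative.

import torch
from mmlearn.modules.losses.data2vec import Data2VecLoss

loss_fn = Data2VecLoss(beta=2.0, reduction="mean")

predictions = torch.randn(8, 196, 768)   # (batch_size, num_patches, embedding_dim)
targets = torch.randn(8, 196, 768)
loss = loss_fn(predictions, targets)     # scalar; scaled by 1 / sqrt(768) since loss_scale=None
print(loss)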
contrastive

Implementations of the contrastive loss and its variants.

ContrastiveLoss

Bases: Module

Contrastive Loss.

Parameters:

Name Type Description Default
l2_normalize bool

Whether to L2 normalize the features.

False
local_loss bool

Whether to calculate the loss locally, i.e., local_features @ global_features.

False
gather_with_grad bool

Whether to gather tensors with gradients.

False
modality_alignment bool

Whether to include modality alignment loss. This loss considers all features from the same modality as positive pairs and all features from different modalities as negative pairs.

False
cache_labels bool

Whether to cache the labels.

False
Source code in mmlearn/modules/losses/contrastive.py
@store(group="modules/losses", provider="mmlearn")
class ContrastiveLoss(nn.Module):
    """Contrastive Loss.

    Parameters
    ----------
    l2_normalize : bool, optional, default=False
        Whether to L2 normalize the features.
    local_loss : bool, optional, default=False
        Whether to calculate the loss locally i.e. ``local_features@global_features``.
    gather_with_grad : bool, optional, default=False
        Whether to gather tensors with gradients.
    modality_alignment : bool, optional, default=False
        Whether to include modality alignment loss. This loss considers all features
        from the same modality as positive pairs and all features from different
        modalities as negative pairs.
    cache_labels : bool, optional, default=False
        Whether to cache the labels.

    """

    def __init__(
        self,
        l2_normalize: bool = False,
        local_loss: bool = False,
        gather_with_grad: bool = False,
        modality_alignment: bool = False,
        cache_labels: bool = False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.l2_normalize = l2_normalize
        self.modality_alignment = modality_alignment

        # cache state
        self._prev_num_logits = 0
        self._labels: dict[torch.device, torch.Tensor] = {}

    def forward(
        self,
        embeddings: dict[str, torch.Tensor],
        example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        modality_loss_pairs: list[LossPairSpec],
    ) -> torch.Tensor:
        """Calculate the contrastive loss.

        Parameters
        ----------
        embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        modality_loss_pairs : list[LossPairSpec]
            Specification of the modality pairs for which the loss should be calculated.

        Returns
        -------
        torch.Tensor
            The contrastive loss.
        """
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if world_size > 1 else 0

        if self.l2_normalize:
            embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

        if world_size > 1:  # gather embeddings and example_ids across all processes
            # NOTE: gathering dictionaries of tensors across all processes
            # (keys + values, as opposed to just values) is especially important
            # for the modality_alignment loss, which requires all embeddings
            all_embeddings = _gather_dicts(
                embeddings,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
            all_example_ids = _gather_dicts(
                example_ids,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
        else:
            all_embeddings = embeddings
            all_example_ids = example_ids

        losses = []
        for loss_pairs in modality_loss_pairs:
            logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
                loss_pairs.modalities,
                per_device_embeddings=embeddings,
                all_embeddings=all_embeddings,
                per_device_example_ids=example_ids,
                all_example_ids=all_example_ids,
                logit_scale=logit_scale,
                world_size=world_size,
            )
            if logits_per_feature_a is None or logits_per_feature_b is None:
                continue

            labels = self._get_ground_truth(
                logits_per_feature_a.shape,
                device=logits_per_feature_a.device,
                rank=rank,
                world_size=world_size,
                skipped_process=skip_flag,
            )

            if labels.numel() != 0:
                losses.append(
                    (
                        (
                            F.cross_entropy(logits_per_feature_a, labels)
                            + F.cross_entropy(logits_per_feature_b, labels)
                        )
                        / 2
                    )
                    * loss_pairs.weight
                )

        if self.modality_alignment:
            losses.append(
                self._compute_modality_alignment_loss(all_embeddings, logit_scale)
            )

        if not losses:  # no loss to compute (e.g. no paired data in batch)
            losses.append(
                torch.tensor(
                    0.0,
                    device=logit_scale.device,
                    dtype=next(iter(embeddings.values())).dtype,
                )
            )

        return torch.stack(losses).sum()

    def _get_ground_truth(
        self,
        logits_shape: tuple[int, int],
        device: torch.device,
        rank: int,
        world_size: int,
        skipped_process: bool,
    ) -> torch.Tensor:
        """Return the ground-truth labels.

        Parameters
        ----------
        logits_shape : tuple[int, int]
            Shape of the logits tensor.
        device : torch.device
            Device on which the labels should be created.
        rank : int
            Rank of the current process.
        world_size : int
            Number of processes.
        skipped_process : bool
            Whether the current process skipped the computation due to lack of data.

        Returns
        -------
        torch.Tensor
            Ground-truth labels.
        """
        num_logits = logits_shape[-1]

        # calculate ground-truth and cache if enabled
        if self._prev_num_logits != num_logits or device not in self._labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)

            if world_size > 1 and self.local_loss:
                local_size = torch.tensor(
                    0 if skipped_process else logits_shape[0], device=device
                )
                # NOTE: all processes must participate in the all_gather operation
                # even if they have no data to contribute.
                sizes = torch.stack(
                    _simple_gather_all_tensors(
                        local_size, group=dist.group.WORLD, world_size=world_size
                    )
                )
                sizes = torch.cat(
                    [torch.tensor([0], device=sizes.device), torch.cumsum(sizes, dim=0)]
                )
                labels = labels[
                    sizes[rank] : sizes[rank + 1] if rank + 1 < world_size else None
                ]

            if self.cache_labels:
                self._labels[device] = labels
                self._prev_num_logits = num_logits
        else:
            labels = self._labels[device]
        return labels

    def _get_logits(  # noqa: PLR0912
        self,
        modalities: tuple[str, str],
        per_device_embeddings: dict[str, torch.Tensor],
        all_embeddings: dict[str, torch.Tensor],
        per_device_example_ids: dict[str, torch.Tensor],
        all_example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        world_size: int,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
        """Calculate the logits for the given modalities.

        Parameters
        ----------
        modalities : tuple[str, str]
            Tuple of modality names.
        per_device_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor. In distributed mode, this contains
            embeddings from all processes.
        per_device_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        all_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index. In distributed
            mode, this contains example IDs from all processes.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        world_size : int
            Number of processes.

        Returns
        -------
        tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]
            Tuple of logits for the given modalities. If embeddings for the given
            modalities are not available, returns `None` for the logits. The last
            element is a flag indicating whether the process skipped the computation
            due to lack of data.
        """
        modality_a = Modalities.get_modality(modalities[0])
        modality_b = Modalities.get_modality(modalities[1])
        skip_flag = False

        if self.local_loss or world_size == 1:
            if not (
                modality_a.embedding in per_device_embeddings
                and modality_b.embedding in per_device_embeddings
            ):
                if world_size > 1:  # NOTE: not all processes exit here, hence skip_flag
                    skip_flag = True
                else:
                    return None, None, skip_flag

            if not skip_flag:
                indices_a, indices_b = find_matching_indices(
                    per_device_example_ids[modality_a.name],
                    per_device_example_ids[modality_b.name],
                )
                if indices_a.numel() == 0 or indices_b.numel() == 0:
                    if world_size > 1:  # not all processes exit here
                        skip_flag = True
                    else:
                        return None, None, skip_flag

            if not skip_flag:
                features_a = per_device_embeddings[modality_a.embedding][indices_a]
                features_b = per_device_embeddings[modality_b.embedding][indices_b]
            else:
                # all processes must participate in the all_gather operation
                # that follows, even if they have no data to contribute. So,
                # we create empty tensors here.
                features_a = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )
                features_b = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )

        if world_size > 1:
            if not (
                modality_a.embedding in all_embeddings
                and modality_b.embedding in all_embeddings
            ):  # all processes exit here
                return None, None, skip_flag

            indices_a, indices_b = find_matching_indices(
                all_example_ids[modality_a.name],
                all_example_ids[modality_b.name],
            )
            if indices_a.numel() == 0 or indices_b.numel() == 0:
                # all processes exit here
                return None, None, skip_flag

            all_features_a = all_embeddings[modality_a.embedding][indices_a]
            all_features_b = all_embeddings[modality_b.embedding][indices_b]

            if self.local_loss:
                if features_a.numel() == 0:
                    features_a = all_features_a
                if features_b.numel() == 0:
                    features_b = all_features_b

                logits_per_feature_a = logit_scale * _safe_matmul(
                    features_a, all_features_b
                )
                logits_per_feature_b = logit_scale * _safe_matmul(
                    features_b, all_features_a
                )
            else:
                logits_per_feature_a = logit_scale * _safe_matmul(
                    all_features_a, all_features_b
                )
                logits_per_feature_b = logits_per_feature_a.T
        else:
            logits_per_feature_a = logit_scale * _safe_matmul(features_a, features_b)
            logits_per_feature_b = logit_scale * _safe_matmul(features_b, features_a)

        return logits_per_feature_a, logits_per_feature_b, skip_flag

    def _compute_modality_alignment_loss(
        self, all_embeddings: dict[str, torch.Tensor], logit_scale: torch.Tensor
    ) -> torch.Tensor:
        """Compute the modality alignment loss.

        This loss considers all features from the same modality as positive pairs
        and all features from different modalities as negative pairs.

        Parameters
        ----------
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        logit_scale : torch.Tensor
            Scale factor for the logits.

        Returns
        -------
        torch.Tensor
            Modality alignment loss.

        Notes
        -----
        This loss does not support `local_loss=True`.
        """
        available_modalities = list(all_embeddings.keys())
        # TODO: support local_loss for modality_alignment?
        # if world_size == 1, all_embeddings == embeddings
        all_features = torch.cat(list(all_embeddings.values()), dim=0)

        positive_indices = torch.tensor(
            [
                (i, j)
                if idx == 0
                else (
                    i + all_embeddings[available_modalities[idx - 1]].size(0),
                    j + all_embeddings[available_modalities[idx - 1]].size(0),
                )
                for idx, k in enumerate(all_embeddings)
                for i, j in itertools.combinations(range(all_embeddings[k].size(0)), 2)
            ],
            device=all_features.device,
        )
        logits = logit_scale * _safe_matmul(all_features, all_features)

        target = torch.eye(all_features.size(0), device=all_features.device)
        target[positive_indices[:, 0], positive_indices[:, 1]] = 1

        modality_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, target, reduction="none"
        )

        target_pos = target.bool()
        target_neg = ~target_pos

        # loss_pos and loss_neg below contain non-zero values only for those
        # elements that are positive pairs and negative pairs respectively.
        loss_pos = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_pos, modality_loss[target_pos])
        loss_neg = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_neg, modality_loss[target_neg])

        loss_pos = loss_pos.sum(dim=1)
        loss_neg = loss_neg.sum(dim=1)
        num_pos = target.sum(dim=1)
        num_neg = logits.size(0) - num_pos

        return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()
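
To make the modality-alignment term concrete, here is a small self-contained sketch (not the library code itself) of how the positive/negative target matrix and the per-row averaging work for two modalities with two samples each:

import itertools
import torch
import torch.nn.functional as F

emb = {"a": torch.randn(2, 4), "b": torch.randn(2, 4)}       # two modalities, 2 samples each
feats = F.normalize(torch.cat(list(emb.values())), dim=-1)   # (4, 4) stacked features
logits = 10.0 * (feats @ feats.T)                            # logit_scale * pairwise similarity

# positives: the diagonal plus every within-modality pair; everything else is negative
target = torch.eye(4)
offsets = {"a": 0, "b": 2}
for name, x in emb.items():
    for i, j in itertools.combinations(range(x.size(0)), 2):
        target[offsets[name] + i, offsets[name] + j] = 1.0

bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
pos = target.bool()
neg = ~pos
# average the positive and negative terms separately per row, then take the mean
loss = ((bce * pos).sum(1) / pos.sum(1) + (bce * neg).sum(1) / neg.sum(1)).mean()
print(loss.item())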
forward
forward(
    embeddings,
    example_ids,
    logit_scale,
    modality_loss_pairs,
)

Calculate the contrastive loss.

Parameters:

Name Type Description Default
embeddings dict[str, Tensor]

Dictionary of embeddings, where the key is the modality name and the value is the corresponding embedding tensor.

required
example_ids dict[str, Tensor]

Dictionary of example IDs, where the key is the modality name and the value is a tensor tuple of the dataset index and the example index.

required
logit_scale Tensor

Scale factor for the logits.

required
modality_loss_pairs list[LossPairSpec]

Specification of the modality pairs for which the loss should be calculated.

required

Returns:

Type Description
Tensor

The contrastive loss.

Source code in mmlearn/modules/losses/contrastive.py
def forward(
    self,
    embeddings: dict[str, torch.Tensor],
    example_ids: dict[str, torch.Tensor],
    logit_scale: torch.Tensor,
    modality_loss_pairs: list[LossPairSpec],
) -> torch.Tensor:
    """Calculate the contrastive loss.

    Parameters
    ----------
    embeddings : dict[str, torch.Tensor]
        Dictionary of embeddings, where the key is the modality name and the value
        is the corresponding embedding tensor.
    example_ids : dict[str, torch.Tensor]
        Dictionary of example IDs, where the key is the modality name and the value
        is a tensor tuple of the dataset index and the example index.
    logit_scale : torch.Tensor
        Scale factor for the logits.
    modality_loss_pairs : list[LossPairSpec]
        Specification of the modality pairs for which the loss should be calculated.

    Returns
    -------
    torch.Tensor
        The contrastive loss.
    """
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if world_size > 1 else 0

    if self.l2_normalize:
        embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

    if world_size > 1:  # gather embeddings and example_ids across all processes
        # NOTE: gathering dictionaries of tensors across all processes
        # (keys + values, as opposed to just values) is especially important
        # for the modality_alignment loss, which requires all embeddings
        all_embeddings = _gather_dicts(
            embeddings,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
        all_example_ids = _gather_dicts(
            example_ids,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
    else:
        all_embeddings = embeddings
        all_example_ids = example_ids

    losses = []
    for loss_pairs in modality_loss_pairs:
        logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
            loss_pairs.modalities,
            per_device_embeddings=embeddings,
            all_embeddings=all_embeddings,
            per_device_example_ids=example_ids,
            all_example_ids=all_example_ids,
            logit_scale=logit_scale,
            world_size=world_size,
        )
        if logits_per_feature_a is None or logits_per_feature_b is None:
            continue

        labels = self._get_ground_truth(
            logits_per_feature_a.shape,
            device=logits_per_feature_a.device,
            rank=rank,
            world_size=world_size,
            skipped_process=skip_flag,
        )

        if labels.numel() != 0:
            losses.append(
                (
                    (
                        F.cross_entropy(logits_per_feature_a, labels)
                        + F.cross_entropy(logits_per_feature_b, labels)
                    )
                    / 2
                )
                * loss_pairs.weight
            )

    if self.modality_alignment:
        losses.append(
            self._compute_modality_alignment_loss(all_embeddings, logit_scale)
        )

    if not losses:  # no loss to compute (e.g. no paired data in batch)
        losses.append(
            torch.tensor(
                0.0,
                device=logit_scale.device,
                dtype=next(iter(embeddings.values())).dtype,
            )
        )

    return torch.stack(losses).sum()
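
For a single modality pair on a single process, the computation above reduces to the familiar CLIP-style symmetric cross-entropy. The sketch below illustrates that core computation only; it does not call ContrastiveLoss directly, since the expected dictionary keys come from the Modalities registry:

import torch
import torch.nn.functional as F

a = F.normalize(torch.randn(16, 512), dim=-1)  # e.g. text embeddings (as with l2_normalize=True)
b = F.normalize(torch.randn(16, 512), dim=-1)  # e.g. image embeddings, paired row-for-row with `a`
logit_scale = torch.tensor(100.0)

logits_a = logit_scale * (a @ b.T)  # (16, 16) similarity logits
logits_b = logit_scale * (b @ a.T)
labels = torch.arange(16)           # the i-th query matches the i-th candidate

loss = (F.cross_entropy(logits_a, labels) + F.cross_entropy(logits_b, labels)) / 2
print(loss.item())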
data2vec

Implementation of Data2vec loss function.

Data2VecLoss

Bases: Module

Data2Vec loss function.

Parameters:

Name Type Description Default
beta float

Specifies the beta parameter for smooth L1 loss. If 0, MSE loss is used.

0
loss_scale Optional[float]

Scaling factor for the loss. If None, uses 1 / sqrt(embedding_dim).

None
reduction str

Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.

'none'

Raises:

Type Description
ValueError

If the reduction mode is not supported.

Source code in mmlearn/modules/losses/data2vec.py
@store(group="modules/losses", provider="mmlearn")
class Data2VecLoss(nn.Module):
    """Data2Vec loss function.

    Parameters
    ----------
    beta : float, optional, default=0
        Specifies the beta parameter for smooth L1 loss. If ``0``, MSE loss is used.
    loss_scale : Optional[float], optional, default=None
        Scaling factor for the loss. If None, uses ``1 / sqrt(embedding_dim)``.
    reduction : str, optional, default='none'
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``.

    Raises
    ------
    ValueError
        If the reduction mode is not supported.
    """

    def __init__(
        self,
        beta: float = 0,
        loss_scale: Optional[float] = None,
        reduction: str = "none",
    ) -> None:
        super().__init__()
        self.beta = beta
        self.loss_scale = loss_scale
        if reduction not in ["none", "mean", "sum"]:
            raise ValueError(f"Unsupported reduction mode: {reduction}")
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the Data2Vec loss.

        Parameters
        ----------
        x : torch.Tensor
            Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
        y : torch.Tensor
            Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

        Returns
        -------
        torch.Tensor
            Data2Vec loss value.

        Raises
        ------
        ValueError
            If the shapes of x and y do not match.
        """
        if x.shape != y.shape:
            raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, y.size(-1))

        if self.beta == 0:
            loss = mse_loss(x, y, reduction="none")
        else:
            loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        loss = loss * scale

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # 'none'
        return loss.view(x.size(0), -1).sum(1)
forward
forward(x, y)

Compute the Data2Vec loss.

Parameters:

Name Type Description Default
x Tensor

Predicted embeddings of shape (batch_size, num_patches, embedding_dim).

required
y Tensor

Target embeddings of shape (batch_size, num_patches, embedding_dim).

required

Returns:

Type Description
Tensor

Data2Vec loss value.

Raises:

Type Description
ValueError

If the shapes of x and y do not match.

Source code in mmlearn/modules/losses/data2vec.py
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the Data2Vec loss.

    Parameters
    ----------
    x : torch.Tensor
        Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
    y : torch.Tensor
        Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

    Returns
    -------
    torch.Tensor
        Data2Vec loss value.

    Raises
    ------
    ValueError
        If the shapes of x and y do not match.
    """
    if x.shape != y.shape:
        raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

    x = x.view(-1, x.size(-1)).float()
    y = y.view(-1, y.size(-1))

    if self.beta == 0:
        loss = mse_loss(x, y, reduction="none")
    else:
        loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

    if self.loss_scale is not None:
        scale = self.loss_scale
    else:
        scale = 1 / math.sqrt(x.size(-1))

    loss = loss * scale

    if self.reduction == "mean":
        return loss.mean()
    if self.reduction == "sum":
        return loss.sum()
    # 'none'
    return loss.view(x.size(0), -1).sum(1)

lr_schedulers

Learning rate schedulers for training models.

linear_warmup_cosine_annealing_lr
linear_warmup_cosine_annealing_lr(
    optimizer,
    warmup_steps,
    max_steps,
    start_factor=1 / 3,
    eta_min=0.0,
    last_epoch=-1,
)

Create a linear warmup cosine annealing learning rate scheduler.

Parameters:

Name Type Description Default
optimizer Optimizer

The optimizer for which to schedule the learning rate.

required
warmup_steps int

Maximum number of iterations for linear warmup.

required
max_steps int

Maximum number of iterations.

required
start_factor float

Multiplicative factor for the learning rate at the start of the warmup phase.

1/3
eta_min float

Minimum learning rate.

0
last_epoch int

The index of the last epoch. If set to -1, the learning rate is initialized to the base learning rate.

-1

Returns:

Type Description
LRScheduler

The learning rate scheduler.

Raises:

Type Description
ValueError

If warmup_steps is greater than or equal to max_steps or if warmup_steps is less than or equal to 0.

Source code in mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py
@store(  # type: ignore[misc]
    group="modules/lr_schedulers",
    provider="mmlearn",
    zen_partial=True,
    warmup_steps=MISSING,
    max_steps=MISSING,
)
def linear_warmup_cosine_annealing_lr(
    optimizer: Optimizer,
    warmup_steps: int,
    max_steps: int,
    start_factor: float = 1 / 3,
    eta_min: float = 0.0,
    last_epoch: int = -1,
) -> LRScheduler:
    """Create a linear warmup cosine annealing learning rate scheduler.

    Parameters
    ----------
    optimizer : Optimizer
        The optimizer for which to schedule the learning rate.
    warmup_steps : int
        Maximum number of iterations for linear warmup.
    max_steps : int
        Maximum number of iterations.
    start_factor : float, optional, default=1/3
        Multiplicative factor for the learning rate at the start of the warmup phase.
    eta_min : float, optional, default=0
        Minimum learning rate.
    last_epoch : int, optional, default=-1
        The index of last epoch. If set to ``-1``, it initializes the learning rate
        as the base learning rate

    Returns
    -------
    LRScheduler
        The learning rate scheduler.

    Raises
    ------
    ValueError
        If `warmup_steps` is greater than or equal to `max_steps` or if `warmup_steps`
        is less than or equal to 0.
    """
    if warmup_steps >= max_steps:
        raise ValueError(
            "Expected `warmup_steps` to be less than `max_steps` but got "
            f"`warmup_steps={warmup_steps}` and `max_steps={max_steps}`."
        )
    if warmup_steps <= 0:
        raise ValueError(
            "Expected `warmup_steps` to be positive but got "
            f"`warmup_steps={warmup_steps}`."
        )

    linear_lr = LinearLR(
        optimizer,
        start_factor=start_factor,
        total_iters=warmup_steps,
        last_epoch=last_epoch,
    )
    cosine_lr = CosineAnnealingLR(
        optimizer,
        T_max=max_steps - warmup_steps,
        eta_min=eta_min,
        last_epoch=last_epoch,
    )
    return SequentialLR(
        optimizer,
        schedulers=[linear_lr, cosine_lr],
        milestones=[warmup_steps],
        last_epoch=last_epoch,
    )
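
A short usage sketch (assuming the function is importable from mmlearn.modules.lr_schedulers; adjust the import to the linear_warmup_cosine_lr submodule shown in the source path if needed):

import torch
from mmlearn.modules.lr_schedulers import linear_warmup_cosine_annealing_lr  # assumed import path

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# warm up linearly from lr/3 to lr over 500 steps, then cosine-anneal towards 0 by step 10_000
scheduler = linear_warmup_cosine_annealing_lr(optimizer, warmup_steps=500, max_steps=10_000)

for step in range(10_000):
    optimizer.step()   # training step (loss computation and backward omitted)
    scheduler.step()   # advance the schedule once per optimization step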
linear_warmup_cosine_lr

Linear warmup cosine annealing learning rate scheduler.

linear_warmup_cosine_annealing_lr
linear_warmup_cosine_annealing_lr(
    optimizer,
    warmup_steps,
    max_steps,
    start_factor=1 / 3,
    eta_min=0.0,
    last_epoch=-1,
)

Create a linear warmup cosine annealing learning rate scheduler.

Parameters:

Name Type Description Default
optimizer Optimizer

The optimizer for which to schedule the learning rate.

required
warmup_steps int

Maximum number of iterations for linear warmup.

required
max_steps int

Maximum number of iterations.

required
start_factor float

Multiplicative factor for the learning rate at the start of the warmup phase.

1/3
eta_min float

Minimum learning rate.

0
last_epoch int

The index of the last epoch. If set to -1, the learning rate is initialized to the base learning rate.

-1

Returns:

Type Description
LRScheduler

The learning rate scheduler.

Raises:

Type Description
ValueError

If warmup_steps is greater than or equal to max_steps or if warmup_steps is less than or equal to 0.

Source code in mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py
@store(  # type: ignore[misc]
    group="modules/lr_schedulers",
    provider="mmlearn",
    zen_partial=True,
    warmup_steps=MISSING,
    max_steps=MISSING,
)
def linear_warmup_cosine_annealing_lr(
    optimizer: Optimizer,
    warmup_steps: int,
    max_steps: int,
    start_factor: float = 1 / 3,
    eta_min: float = 0.0,
    last_epoch: int = -1,
) -> LRScheduler:
    """Create a linear warmup cosine annealing learning rate scheduler.

    Parameters
    ----------
    optimizer : Optimizer
        The optimizer for which to schedule the learning rate.
    warmup_steps : int
        Maximum number of iterations for linear warmup.
    max_steps : int
        Maximum number of iterations.
    start_factor : float, optional, default=1/3
        Multiplicative factor for the learning rate at the start of the warmup phase.
    eta_min : float, optional, default=0
        Minimum learning rate.
    last_epoch : int, optional, default=-1
        The index of last epoch. If set to ``-1``, it initializes the learning rate
        as the base learning rate

    Returns
    -------
    LRScheduler
        The learning rate scheduler.

    Raises
    ------
    ValueError
        If `warmup_steps` is greater than or equal to `max_steps` or if `warmup_steps`
        is less than or equal to 0.
    """
    if warmup_steps >= max_steps:
        raise ValueError(
            "Expected `warmup_steps` to be less than `max_steps` but got "
            f"`warmup_steps={warmup_steps}` and `max_steps={max_steps}`."
        )
    if warmup_steps <= 0:
        raise ValueError(
            "Expected `warmup_steps` to be positive but got "
            f"`warmup_steps={warmup_steps}`."
        )

    linear_lr = LinearLR(
        optimizer,
        start_factor=start_factor,
        total_iters=warmup_steps,
        last_epoch=last_epoch,
    )
    cosine_lr = CosineAnnealingLR(
        optimizer,
        T_max=max_steps - warmup_steps,
        eta_min=eta_min,
        last_epoch=last_epoch,
    )
    return SequentialLR(
        optimizer,
        schedulers=[linear_lr, cosine_lr],
        milestones=[warmup_steps],
        last_epoch=last_epoch,
    )

metrics

Metrics for evaluating models.

RetrievalRecallAtK

Bases: Metric

Retrieval Recall@K metric.

Computes the Recall@K for retrieval tasks. The metric is computed as follows:

  1. Compute the cosine similarity between the query and the database.
  2. For each query, sort the database in decreasing order of similarity.
  3. Compute the Recall@K as the number of true positives among the top K elements.

Parameters:

Name Type Description Default
top_k int

The number of top elements to consider for computing the Recall@K.

required
reduction (mean, sum, none, None)

Specifies the reduction to apply after computing the pairwise cosine similarity scores.

"mean"
aggregation (mean, median, min, max)

Specifies the aggregation function to apply to the Recall@K values computed in batches. If a callable is provided, it should accept a tensor of values and a keyword argument 'dim' and return a single scalar value.

"mean"
kwargs Any

Additional arguments to be passed to the :py:class:torchmetrics.Metric class.

{}

Raises:

Type Description
ValueError
  • If the top_k is not a positive integer or None.
  • If the reduction is not one of {"mean", "sum", "none", None}.
  • If the aggregation is not one of {"mean", "median", "min", "max"} or a custom callable function.
Source code in mmlearn/modules/metrics/retrieval_recall.py
@store(group="modules/metrics", provider="mmlearn")
class RetrievalRecallAtK(Metric):
    """Retrieval Recall@K metric.

    Computes the Recall@K for retrieval tasks. The metric is computed as follows:

    1. Compute the cosine similarity between the query and the database.
    2. For each query, sort the database in decreasing order of similarity.
    3. Compute the Recall@K as the number of true positives among the top K elements.

    Parameters
    ----------
    top_k : int
        The number of top elements to consider for computing the Recall@K.
    reduction : {"mean", "sum", "none", None}, optional, default="sum"
        Specifies the reduction to apply after computing the pairwise cosine similarity
        scores.
    aggregation : {"mean", "median", "min", "max"} or callable, default="mean"
        Specifies the aggregation function to apply to the Recall@K values computed
        in batches. If a callable is provided, it should accept a tensor of values
        and a keyword argument ``'dim'`` and return a single scalar value.
    kwargs : Any
        Additional arguments to be passed to the :py:class:`torchmetrics.Metric` class.

    Raises
    ------
    ValueError

        - If the `top_k` is not a positive integer or None.
        - If the `reduction` is not one of {"mean", "sum", "none", None}.
        - If the `aggregation` is not one of {"mean", "median", "min", "max"} or a
          custom callable function.

    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    indexes: list[torch.Tensor]
    x: list[torch.Tensor]
    y: list[torch.Tensor]
    num_samples: torch.Tensor

    def __init__(
        self,
        top_k: int,
        reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
        aggregation: Union[
            Literal["mean", "median", "min", "max"],
            Callable[[torch.Tensor, int], torch.Tensor],
        ] = "mean",
        **kwargs: Any,
    ) -> None:
        """Initialize the metric."""
        super().__init__(**kwargs)

        if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
            raise ValueError("`top_k` has to be a positive integer or None")
        self.top_k = top_k

        allowed_reduction = ("sum", "mean", "none", None)
        if reduction not in allowed_reduction:
            raise ValueError(
                f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
            )
        self.reduction = reduction

        if not (
            aggregation in ("mean", "median", "min", "max") or callable(aggregation)
        ):
            raise ValueError(
                "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
                f"which takes tensor of values, but got {aggregation}."
            )
        self.aggregation = aggregation

        self.add_state("x", default=[], dist_reduce_fx="cat")
        self.add_state("y", default=[], dist_reduce_fx="cat")
        self.add_state("indexes", default=[], dist_reduce_fx="cat")
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

        self._batch_size: int = -1

        self.compute_on_cpu = True
        self.sync_on_compute = False
        self.dist_sync_on_step = False
        self._to_sync = self.sync_on_compute
        self._should_unsync = False

    def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
        """Check shape, convert dtypes and add to accumulators.

        Parameters
        ----------
        x : torch.Tensor
            Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
            of samples and `D` is the number of dimensions.
        y : torch.Tensor
            Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
            of samples and ``D`` is the number of dimensions.
        indexes : torch.Tensor
            Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
            This specifies which sample in ``y`` is the positive match for each
            sample in ``x``.

        Raises
        ------
        ValueError
            If `indexes` is None.

        """
        if indexes is None:
            raise ValueError("Argument `indexes` cannot be None")

        x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

        # offset batch indexes by the number of samples seen so far
        indexes += self.num_samples

        local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
        if self._is_distributed():
            x = dim_zero_cat(gather_all_tensors(x, self.process_group))
            y = dim_zero_cat(gather_all_tensors(y, self.process_group))
            indexes = dim_zero_cat(
                gather_all_tensors(indexes.clone(), self.process_group)
            )

            # offset indexes for each device
            bsz_per_device = dim_zero_cat(
                gather_all_tensors(local_batch_size, self.process_group)
            )
            cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
            for device_idx in range(1, bsz_per_device.numel()):
                indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                    cum_local_bsz[device_idx - 1]
                )

            # update the global sample count
            self.num_samples += cum_local_bsz[-1]

            self._is_synced = True
        else:
            self.num_samples += x.size(0)

        self.x.append(x)
        self.y.append(y)
        self.indexes.append(indexes)

        if self._batch_size == -1:
            self._batch_size = x.size(0)  # global batch size

    def compute(self) -> torch.Tensor:
        """Compute the metric.

        Returns
        -------
        torch.Tensor
            The computed metric.
        """
        x = dim_zero_cat(self.x)
        y = dim_zero_cat(self.y)

        # normalize embeddings
        x /= x.norm(dim=-1, p=2, keepdim=True)
        y /= y.norm(dim=-1, p=2, keepdim=True)

        # instantiate reduction function
        reduction_mapping: Dict[
            Optional[str], Callable[[torch.Tensor], torch.Tensor]
        ] = {
            "sum": partial(torch.sum, dim=-1),
            "mean": partial(torch.mean, dim=-1),
            "none": lambda x: x,
            None: lambda x: x,
        }

        # concatenate indexes of true pairs
        indexes = dim_zero_cat(self.indexes)

        results: list[torch.Tensor] = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=os.cpu_count() or 1  # use all available CPUs
        ) as executor:
            futures = [
                executor.submit(
                    self._process_batch,
                    start,
                    x,
                    y,
                    indexes,
                    reduction_mapping,
                    self.top_k,
                )
                for start in tqdm(
                    range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
                )
            ]
            for future in concurrent.futures.as_completed(futures):
                results.append(future.result())

        return _retrieval_aggregate(
            (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward method is not supported.

        Raises
        ------
        NotImplementedError
            The forward method is not supported for this metric.
        """
        raise NotImplementedError(
            "RetrievalRecallAtK metric does not support forward method"
        )

    def _is_distributed(self) -> bool:
        if self.distributed_available_fn is not None:
            distributed_available = self.distributed_available_fn

        return distributed_available() if callable(distributed_available) else False

    def _process_batch(
        self,
        start: int,
        x_norm: torch.Tensor,
        y_norm: torch.Tensor,
        indexes: torch.Tensor,
        reduction_mapping: Dict[Optional[str], Callable[[torch.Tensor], torch.Tensor]],
        top_k: int,
    ) -> torch.Tensor:
        """Compute the Recall@K for a batch of samples."""
        end = start + self._batch_size
        x_norm_batch = x_norm[start:end]
        indexes_batch = indexes[start:end]

        similarity = _safe_matmul(x_norm_batch, y_norm)
        scores: torch.Tensor = reduction_mapping[self.reduction](similarity)

        with torch.inference_mode():
            positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
            positive_pairs[torch.arange(len(scores)), indexes_batch] = True

        return _recall_at_k(scores, positive_pairs, top_k)
__init__
__init__(
    top_k, reduction="sum", aggregation="mean", **kwargs
)

Initialize the metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def __init__(
    self,
    top_k: int,
    reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
    aggregation: Union[
        Literal["mean", "median", "min", "max"],
        Callable[[torch.Tensor, int], torch.Tensor],
    ] = "mean",
    **kwargs: Any,
) -> None:
    """Initialize the metric."""
    super().__init__(**kwargs)

    if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
        raise ValueError("`top_k` has to be a positive integer or None")
    self.top_k = top_k

    allowed_reduction = ("sum", "mean", "none", None)
    if reduction not in allowed_reduction:
        raise ValueError(
            f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
        )
    self.reduction = reduction

    if not (
        aggregation in ("mean", "median", "min", "max") or callable(aggregation)
    ):
        raise ValueError(
            "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
            f"which takes tensor of values, but got {aggregation}."
        )
    self.aggregation = aggregation

    self.add_state("x", default=[], dist_reduce_fx="cat")
    self.add_state("y", default=[], dist_reduce_fx="cat")
    self.add_state("indexes", default=[], dist_reduce_fx="cat")
    self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

    self._batch_size: int = -1

    self.compute_on_cpu = True
    self.sync_on_compute = False
    self.dist_sync_on_step = False
    self._to_sync = self.sync_on_compute
    self._should_unsync = False
update
update(x, y, indexes)

Check shape, convert dtypes and add to accumulators.

Parameters:

Name Type Description Default
x Tensor

Embeddings (unnormalized) of shape (N, D) where N is the number of samples and D is the number of dimensions.

required
y Tensor

Embeddings (unnormalized) of shape (M, D) where M is the number of samples and D is the number of dimensions.

required
indexes Tensor

Index tensor of shape (N,) where N is the number of samples. This specifies which sample in y is the positive match for each sample in x.

required

Raises:

Type Description
ValueError

If indexes is None.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
    """Check shape, convert dtypes and add to accumulators.

    Parameters
    ----------
    x : torch.Tensor
        Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
        of samples and `D` is the number of dimensions.
    y : torch.Tensor
        Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
        of samples and ``D`` is the number of dimensions.
    indexes : torch.Tensor
        Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
        This specifies which sample in ``y`` is the positive match for each
        sample in ``x``.

    Raises
    ------
    ValueError
        If `indexes` is None.

    """
    if indexes is None:
        raise ValueError("Argument `indexes` cannot be None")

    x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

    # offset batch indexes by the number of samples seen so far
    indexes += self.num_samples

    local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
    if self._is_distributed():
        x = dim_zero_cat(gather_all_tensors(x, self.process_group))
        y = dim_zero_cat(gather_all_tensors(y, self.process_group))
        indexes = dim_zero_cat(
            gather_all_tensors(indexes.clone(), self.process_group)
        )

        # offset indexes for each device
        bsz_per_device = dim_zero_cat(
            gather_all_tensors(local_batch_size, self.process_group)
        )
        cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
        for device_idx in range(1, bsz_per_device.numel()):
            indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                cum_local_bsz[device_idx - 1]
            )

        # update the global sample count
        self.num_samples += cum_local_bsz[-1]

        self._is_synced = True
    else:
        self.num_samples += x.size(0)

    self.x.append(x)
    self.y.append(y)
    self.indexes.append(indexes)

    if self._batch_size == -1:
        self._batch_size = x.size(0)  # global batch size
compute
compute()

Compute the metric.

Returns:

Type Description
Tensor

The computed metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def compute(self) -> torch.Tensor:
    """Compute the metric.

    Returns
    -------
    torch.Tensor
        The computed metric.
    """
    x = dim_zero_cat(self.x)
    y = dim_zero_cat(self.y)

    # normalize embeddings
    x /= x.norm(dim=-1, p=2, keepdim=True)
    y /= y.norm(dim=-1, p=2, keepdim=True)

    # instantiate reduction function
    reduction_mapping: Dict[
        Optional[str], Callable[[torch.Tensor], torch.Tensor]
    ] = {
        "sum": partial(torch.sum, dim=-1),
        "mean": partial(torch.mean, dim=-1),
        "none": lambda x: x,
        None: lambda x: x,
    }

    # concatenate indexes of true pairs
    indexes = dim_zero_cat(self.indexes)

    results: list[torch.Tensor] = []
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count() or 1  # use all available CPUs
    ) as executor:
        futures = [
            executor.submit(
                self._process_batch,
                start,
                x,
                y,
                indexes,
                reduction_mapping,
                self.top_k,
            )
            for start in tqdm(
                range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
            )
        ]
        for future in concurrent.futures.as_completed(futures):
            results.append(future.result())

    return _retrieval_aggregate(
        (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
    )
forward
forward(*args, **kwargs)

Forward method is not supported.

Raises:

Type Description
NotImplementedError

The forward method is not supported for this metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Forward method is not supported.

    Raises
    ------
    NotImplementedError
        The forward method is not supported for this metric.
    """
    raise NotImplementedError(
        "RetrievalRecallAtK metric does not support forward method"
    )
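
A usage sketch for the metric documented above (assuming RetrievalRecallAtK can be imported from mmlearn.modules.metrics; the import path is an assumption). Note that forward is unsupported, so use update and compute:

import torch
from mmlearn.modules.metrics import RetrievalRecallAtK  # assumed import path

metric = RetrievalRecallAtK(top_k=5)

# paired embeddings: the i-th row of `queries` matches the i-th row of `candidates`
queries = torch.randn(32, 256)
candidates = torch.randn(32, 256)
indexes = torch.arange(32)

metric.update(queries, candidates, indexes)  # call once per batch to accumulate
recall_at_5 = metric.compute()
print(recall_at_5.item())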
retrieval_recall

Retrieval Recall@K metric.

RetrievalRecallAtK

Bases: Metric

Retrieval Recall@K metric.

Computes the Recall@K for retrieval tasks. The metric is computed as follows:

  1. Compute the cosine similarity between the query and the database.
  2. For each query, sort the database in decreasing order of similarity.
  3. Compute the Recall@K as the number of true positives among the top K elements.

Parameters:

Name Type Description Default
top_k int

The number of top elements to consider for computing the Recall@K.

required
reduction (mean, sum, none, None)

Specifies the reduction to apply after computing the pairwise cosine similarity scores.

"mean"
aggregation (mean, median, min, max)

Specifies the aggregation function to apply to the Recall@K values computed in batches. If a callable is provided, it should accept a tensor of values and a keyword argument 'dim' and return a single scalar value.

"mean"
kwargs Any

Additional arguments to be passed to the :py:class:torchmetrics.Metric class.

{}

Raises:

Type Description
ValueError
  • If the top_k is not a positive integer or None.
  • If the reduction is not one of {"mean", "sum", "none", None}.
  • If the aggregation is not one of {"mean", "median", "min", "max"} or a custom callable function.
Source code in mmlearn/modules/metrics/retrieval_recall.py
@store(group="modules/metrics", provider="mmlearn")
class RetrievalRecallAtK(Metric):
    """Retrieval Recall@K metric.

    Computes the Recall@K for retrieval tasks. The metric is computed as follows:

    1. Compute the cosine similarity between the query and the database.
    2. For each query, sort the database in decreasing order of similarity.
    3. Compute the Recall@K as the number of true positives among the top K elements.

    Parameters
    ----------
    top_k : int
        The number of top elements to consider for computing the Recall@K.
    reduction : {"mean", "sum", "none", None}, optional, default="sum"
        Specifies the reduction to apply after computing the pairwise cosine similarity
        scores.
    aggregation : {"mean", "median", "min", "max"} or callable, default="mean"
        Specifies the aggregation function to apply to the Recall@K values computed
        in batches. If a callable is provided, it should accept a tensor of values
        and a keyword argument ``'dim'`` and return a single scalar value.
    kwargs : Any
        Additional arguments to be passed to the :py:class:`torchmetrics.Metric` class.

    Raises
    ------
    ValueError

        - If the `top_k` is not a positive integer or None.
        - If the `reduction` is not one of {"mean", "sum", "none", None}.
        - If the `aggregation` is not one of {"mean", "median", "min", "max"} or a
          custom callable function.

    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    indexes: list[torch.Tensor]
    x: list[torch.Tensor]
    y: list[torch.Tensor]
    num_samples: torch.Tensor

    def __init__(
        self,
        top_k: int,
        reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
        aggregation: Union[
            Literal["mean", "median", "min", "max"],
            Callable[[torch.Tensor, int], torch.Tensor],
        ] = "mean",
        **kwargs: Any,
    ) -> None:
        """Initialize the metric."""
        super().__init__(**kwargs)

        if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
            raise ValueError("`top_k` has to be a positive integer or None")
        self.top_k = top_k

        allowed_reduction = ("sum", "mean", "none", None)
        if reduction not in allowed_reduction:
            raise ValueError(
                f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
            )
        self.reduction = reduction

        if not (
            aggregation in ("mean", "median", "min", "max") or callable(aggregation)
        ):
            raise ValueError(
                "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
                f"which takes tensor of values, but got {aggregation}."
            )
        self.aggregation = aggregation

        self.add_state("x", default=[], dist_reduce_fx="cat")
        self.add_state("y", default=[], dist_reduce_fx="cat")
        self.add_state("indexes", default=[], dist_reduce_fx="cat")
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

        self._batch_size: int = -1

        self.compute_on_cpu = True
        self.sync_on_compute = False
        self.dist_sync_on_step = False
        self._to_sync = self.sync_on_compute
        self._should_unsync = False

    def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
        """Check shape, convert dtypes and add to accumulators.

        Parameters
        ----------
        x : torch.Tensor
            Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
            of samples and `D` is the number of dimensions.
        y : torch.Tensor
            Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
            of samples and ``D`` is the number of dimensions.
        indexes : torch.Tensor
            Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
            This specifies which sample in ``y`` is the positive match for each
            sample in ``x``.

        Raises
        ------
        ValueError
            If `indexes` is None.

        """
        if indexes is None:
            raise ValueError("Argument `indexes` cannot be None")

        x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

        # offset batch indexes by the number of samples seen so far
        indexes += self.num_samples

        local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
        if self._is_distributed():
            x = dim_zero_cat(gather_all_tensors(x, self.process_group))
            y = dim_zero_cat(gather_all_tensors(y, self.process_group))
            indexes = dim_zero_cat(
                gather_all_tensors(indexes.clone(), self.process_group)
            )

            # offset indexes for each device
            bsz_per_device = dim_zero_cat(
                gather_all_tensors(local_batch_size, self.process_group)
            )
            cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
            for device_idx in range(1, bsz_per_device.numel()):
                indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                    cum_local_bsz[device_idx - 1]
                )

            # update the global sample count
            self.num_samples += cum_local_bsz[-1]

            self._is_synced = True
        else:
            self.num_samples += x.size(0)

        self.x.append(x)
        self.y.append(y)
        self.indexes.append(indexes)

        if self._batch_size == -1:
            self._batch_size = x.size(0)  # global batch size

    def compute(self) -> torch.Tensor:
        """Compute the metric.

        Returns
        -------
        torch.Tensor
            The computed metric.
        """
        x = dim_zero_cat(self.x)
        y = dim_zero_cat(self.y)

        # normalize embeddings
        x /= x.norm(dim=-1, p=2, keepdim=True)
        y /= y.norm(dim=-1, p=2, keepdim=True)

        # instantiate reduction function
        reduction_mapping: Dict[
            Optional[str], Callable[[torch.Tensor], torch.Tensor]
        ] = {
            "sum": partial(torch.sum, dim=-1),
            "mean": partial(torch.mean, dim=-1),
            "none": lambda x: x,
            None: lambda x: x,
        }

        # concatenate indexes of true pairs
        indexes = dim_zero_cat(self.indexes)

        results: list[torch.Tensor] = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=os.cpu_count() or 1  # use all available CPUs
        ) as executor:
            futures = [
                executor.submit(
                    self._process_batch,
                    start,
                    x,
                    y,
                    indexes,
                    reduction_mapping,
                    self.top_k,
                )
                for start in tqdm(
                    range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
                )
            ]
            for future in concurrent.futures.as_completed(futures):
                results.append(future.result())

        return _retrieval_aggregate(
            (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward method is not supported.

        Raises
        ------
        NotImplementedError
            The forward method is not supported for this metric.
        """
        raise NotImplementedError(
            "RetrievalRecallAtK metric does not support forward method"
        )

    def _is_distributed(self) -> bool:
        distributed_available = self.distributed_available_fn

        return distributed_available() if callable(distributed_available) else False

    def _process_batch(
        self,
        start: int,
        x_norm: torch.Tensor,
        y_norm: torch.Tensor,
        indexes: torch.Tensor,
        reduction_mapping: Dict[Optional[str], Callable[[torch.Tensor], torch.Tensor]],
        top_k: int,
    ) -> torch.Tensor:
        """Compute the Recall@K for a batch of samples."""
        end = start + self._batch_size
        x_norm_batch = x_norm[start:end]
        indexes_batch = indexes[start:end]

        similarity = _safe_matmul(x_norm_batch, y_norm)
        scores: torch.Tensor = reduction_mapping[self.reduction](similarity)

        with torch.inference_mode():
            positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
            positive_pairs[torch.arange(len(scores)), indexes_batch] = True

        return _recall_at_k(scores, positive_pairs, top_k)
__init__
__init__(
    top_k, reduction="sum", aggregation="mean", **kwargs
)

Initialize the metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def __init__(
    self,
    top_k: int,
    reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
    aggregation: Union[
        Literal["mean", "median", "min", "max"],
        Callable[[torch.Tensor, int], torch.Tensor],
    ] = "mean",
    **kwargs: Any,
) -> None:
    """Initialize the metric."""
    super().__init__(**kwargs)

    if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
        raise ValueError("`top_k` has to be a positive integer or None")
    self.top_k = top_k

    allowed_reduction = ("sum", "mean", "none", None)
    if reduction not in allowed_reduction:
        raise ValueError(
            f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
        )
    self.reduction = reduction

    if not (
        aggregation in ("mean", "median", "min", "max") or callable(aggregation)
    ):
        raise ValueError(
            "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a "
            f"custom callable function which takes a tensor of values, but got {aggregation}."
        )
    self.aggregation = aggregation

    self.add_state("x", default=[], dist_reduce_fx="cat")
    self.add_state("y", default=[], dist_reduce_fx="cat")
    self.add_state("indexes", default=[], dist_reduce_fx="cat")
    self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

    self._batch_size: int = -1

    self.compute_on_cpu = True
    self.sync_on_compute = False
    self.dist_sync_on_step = False
    self._to_sync = self.sync_on_compute
    self._should_unsync = False
update
update(x, y, indexes)

Check shape, convert dtypes and add to accumulators.

Parameters:

Name Type Description Default
x Tensor

Embeddings (unnormalized) of shape (N, D) where N is the number of samples and D is the number of dimensions.

required
y Tensor

Embeddings (unnormalized) of shape (M, D) where M is the number of samples and D is the number of dimensions.

required
indexes Tensor

Index tensor of shape (N,) where N is the number of samples. This specifies which sample in y is the positive match for each sample in x.

required

Raises:

Type Description
ValueError

If indexes is None.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
    """Check shape, convert dtypes and add to accumulators.

    Parameters
    ----------
    x : torch.Tensor
        Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
        of samples and ``D`` is the number of dimensions.
    y : torch.Tensor
        Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
        of samples and ``D`` is the number of dimensions.
    indexes : torch.Tensor
        Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
        This specifies which sample in ``y`` is the positive match for each
        sample in ``x``.

    Raises
    ------
    ValueError
        If `indexes` is None.

    """
    if indexes is None:
        raise ValueError("Argument `indexes` cannot be None")

    x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

    # offset batch indexes by the number of samples seen so far
    indexes += self.num_samples

    local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
    if self._is_distributed():
        x = dim_zero_cat(gather_all_tensors(x, self.process_group))
        y = dim_zero_cat(gather_all_tensors(y, self.process_group))
        indexes = dim_zero_cat(
            gather_all_tensors(indexes.clone(), self.process_group)
        )

        # offset indexes for each device
        bsz_per_device = dim_zero_cat(
            gather_all_tensors(local_batch_size, self.process_group)
        )
        cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
        for device_idx in range(1, bsz_per_device.numel()):
            indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                cum_local_bsz[device_idx - 1]
            )

        # update the global sample count
        self.num_samples += cum_local_bsz[-1]

        self._is_synced = True
    else:
        self.num_samples += x.size(0)

    self.x.append(x)
    self.y.append(y)
    self.indexes.append(indexes)

    if self._batch_size == -1:
        self._batch_size = x.size(0)  # global batch size
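
The index-offsetting loop above is easier to follow with concrete numbers. The following toy sketch (plain tensors, no distributed setup) reproduces the same shift for two devices that each contributed a local batch of size 4; after the shift, every gathered query points at the correct row of the globally concatenated embeddings:

import torch

# Hypothetical gathered values: each device produced local indexes 0-3.
bsz_per_device = torch.tensor([4, 4])
indexes = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])

cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
for device_idx in range(1, bsz_per_device.numel()):
    indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += cum_local_bsz[
        device_idx - 1
    ]

print(indexes)  # tensor([0, 1, 2, 3, 4, 5, 6, 7])
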
compute
compute()

Compute the metric.

Returns:

Type Description
Tensor

The computed metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def compute(self) -> torch.Tensor:
    """Compute the metric.

    Returns
    -------
    torch.Tensor
        The computed metric.
    """
    x = dim_zero_cat(self.x)
    y = dim_zero_cat(self.y)

    # normalize embeddings
    x /= x.norm(dim=-1, p=2, keepdim=True)
    y /= y.norm(dim=-1, p=2, keepdim=True)

    # instantiate reduction function
    reduction_mapping: Dict[
        Optional[str], Callable[[torch.Tensor], torch.Tensor]
    ] = {
        "sum": partial(torch.sum, dim=-1),
        "mean": partial(torch.mean, dim=-1),
        "none": lambda x: x,
        None: lambda x: x,
    }

    # concatenate indexes of true pairs
    indexes = dim_zero_cat(self.indexes)

    results: list[torch.Tensor] = []
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count() or 1  # use all available CPUs
    ) as executor:
        futures = [
            executor.submit(
                self._process_batch,
                start,
                x,
                y,
                indexes,
                reduction_mapping,
                self.top_k,
            )
            for start in tqdm(
                range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
            )
        ]
        for future in concurrent.futures.as_completed(futures):
            results.append(future.result())

    return _retrieval_aggregate(
        (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
    )
forward
forward(*args, **kwargs)

Forward method is not supported.

Raises:

Type Description
NotImplementedError

The forward method is not supported for this metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Forward method is not supported.

    Raises
    ------
    NotImplementedError
        The forward method is not supported for this metric.
    """
    raise NotImplementedError(
        "RetrievalRecallAtK metric does not support forward method"
    )
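
Putting the pieces together, a typical use of this metric accumulates unnormalized embeddings with update and then calls compute once at the end; forward is deliberately unsupported. A minimal sketch, assuming the class is importable from mmlearn.modules.metrics.retrieval_recall (the module shown above) and using random embeddings purely for illustration:

import torch

from mmlearn.modules.metrics.retrieval_recall import RetrievalRecallAtK

metric = RetrievalRecallAtK(top_k=5, reduction="sum", aggregation="mean")

# Unnormalized query embeddings (x) and candidate embeddings (y); here the i-th
# row of y is declared to be the positive match for the i-th row of x.
x = torch.randn(32, 128)
y = torch.randn(32, 128)
indexes = torch.arange(32)

metric.update(x, y, indexes)  # may be called once per batch
recall = metric.compute()     # normalizes, scores x against y, aggregates Recall@5
print(recall)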

tasks

Modules for pretraining, downstream and evaluation tasks.

ContrastivePretraining

Bases: TrainingTask

Contrastive pretraining task.

This class supports contrastive pretraining with N modalities of data. It allows the sharing of encoders, heads, and postprocessors across modalities. It also supports computing the contrastive loss between specified pairs of modalities, as well as training auxiliary tasks alongside the main contrastive pretraining task.

Parameters:

Name Type Description Default
encoders dict[str, Module]

A dictionary of encoders. The keys can be any string, including the names of any supported modalities. If the keys are not supported modalities, the modality_module_mapping parameter must be provided to map the encoders to specific modalities. The encoders are expected to take a dictionary of input values and return a list-like object with the first element being the encoded values. This first element is passed on to the heads or postprocessors and the remaining elements are ignored.

required
heads Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the encoder outputs, usually projection heads. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a :py:class:torch.nn.Sequential module. All head modules are expected to take a single input tensor and return a single output tensor.

None
postprocessors Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the head outputs. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a nn.Sequential module. All postprocessor modules are expected to take a single input tensor and return a single output tensor.

None
modality_module_mapping Optional[dict[str, ModuleKeySpec]]

A dictionary mapping modalities to encoders, heads, and postprocessors. Useful for reusing the same instance of a module across multiple modalities.

None
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a :py:func:~functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a :py:func:~functools.partial function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
init_logit_scale float

The initial value of the logit scale, i.e. the multiplicative factor applied to the logits before computing the contrastive loss. The task stores and clamps the logarithm of this value internally.

1 / 0.07
max_logit_scale float

The maximum value of the logit scale parameter. The logit scale parameter is clamped to the range [0, log(max_logit_scale)].

100
learnable_logit_scale bool

Whether the logit scale parameter is learnable. If set to False, the logit scale parameter is treated as a constant.

True
loss Optional[Module]

The loss function to use.

None
modality_loss_pairs Optional[list[LossPairSpec]]

A list of pairs of modalities to compute the contrastive loss between and the weight to apply to each pair.

None
auxiliary_tasks dict[str, AuxiliaryTaskSpec]

Auxiliary tasks to run alongside the main contrastive pretraining task.

  • The auxiliary task module is expected to be a partially-initialized instance of a :py:class:~lightning.pytorch.core.LightningModule created using :py:func:functools.partial, such that an initialized encoder can be passed as the only argument.
  • The modality parameter specifies the modality of the encoder to use for the auxiliary task. The loss_weight parameter specifies the weight to apply to the auxiliary task loss.
None
log_auxiliary_tasks_loss bool

Whether to log the loss of auxiliary tasks to the main logger.

False
compute_validation_loss bool

Whether to compute the validation loss if a validation dataloader is provided. The loss function must be provided to compute the validation loss.

True
compute_test_loss bool

Whether to compute the test loss if a test dataloader is provided. The loss function must be provided to compute the test loss.

True
evaluation_tasks Optional[dict[str, EvaluationSpec]]

Evaluation tasks to run during validation, while training, and during testing.

None

Raises:

Type Description
ValueError
  • If the loss function is not provided and either the validation or test loss needs to be computed.
  • If the given modality is not supported.
  • If the encoder, head, or postprocessor is not mapped to a modality.
  • If an unsupported modality is found in the loss pair specification.
  • If an unsupported modality is found in the auxiliary tasks.
  • If the auxiliary task is not a partial function.
  • If the evaluation task is not an instance of :py:class:~mmlearn.tasks.hooks.EvaluationHooks.
Source code in mmlearn/tasks/contrastive_pretraining.py
@store(group="task", provider="mmlearn")
class ContrastivePretraining(TrainingTask):
    """Contrastive pretraining task.

    This class supports contrastive pretraining with ``N`` modalities of data. It
    allows the sharing of encoders, heads, and postprocessors across modalities.
    It also supports computing the contrastive loss between specified pairs of
    modalities, as well as training auxiliary tasks alongside the main contrastive
    pretraining task.

    Parameters
    ----------
    encoders : dict[str, torch.nn.Module]
        A dictionary of encoders. The keys can be any string, including the names of
        any supported modalities. If the keys are not supported modalities, the
        ``modality_module_mapping`` parameter must be provided to map the encoders to
        specific modalities. The encoders are expected to take a dictionary of input
        values and return a list-like object with the first element being the encoded
        values. This first element is passed on to the heads or postprocessors and
        the remaining elements are ignored.
    heads : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the encoder outputs, usually projection
        heads. If the keys do not correspond to the name of a supported modality,
        the ``modality_module_mapping`` parameter must be provided. If any of the values
        are dictionaries, they will be wrapped in a :py:class:`torch.nn.Sequential`
        module. All head modules are expected to take a single input tensor and
        return a single output tensor.
    postprocessors : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the head outputs. If the keys do not
        correspond to the name of a supported modality, the `modality_module_mapping`
        parameter must be provided. If any of the values are dictionaries, they will
        be wrapped in a `nn.Sequential` module. All postprocessor modules are expected
        to take a single input tensor and return a single output tensor.
    modality_module_mapping : Optional[dict[str, ModuleKeySpec]], optional, default=None
        A dictionary mapping modalities to encoders, heads, and postprocessors.
        Useful for reusing the same instance of a module across multiple modalities.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    init_logit_scale : float, optional, default=1 / 0.07
        The initial value of the logit scale, i.e. the multiplicative factor applied
        to the logits before computing the contrastive loss. The logarithm of this
        value is what gets stored and clamped internally.
    max_logit_scale : float, optional, default=100
        The maximum value of the logit scale parameter. The logit scale parameter
        is clamped to the range ``[0, log(max_logit_scale)]``.
    learnable_logit_scale : bool, optional, default=True
        Whether the logit scale parameter is learnable. If set to False, the logit
        scale parameter is treated as a constant.
    loss : Optional[torch.nn.Module], optional, default=None
        The loss function to use.
    modality_loss_pairs : Optional[list[LossPairSpec]], optional, default=None
        A list of pairs of modalities to compute the contrastive loss between and
        the weight to apply to each pair.
    auxiliary_tasks : dict[str, AuxiliaryTaskSpec], optional, default=None
        Auxiliary tasks to run alongside the main contrastive pretraining task.

        - The auxiliary task module is expected to be a partially-initialized instance
          of a :py:class:`~lightning.pytorch.core.LightningModule` created using
          :py:func:`functools.partial`, such that an initialized encoder can be
          passed as the only argument.
        - The ``modality`` parameter specifies the modality of the encoder to use
          for the auxiliary task. The ``loss_weight`` parameter specifies the weight
          to apply to the auxiliary task loss.
    log_auxiliary_tasks_loss : bool, optional, default=False
        Whether to log the loss of auxiliary tasks to the main logger.
    compute_validation_loss : bool, optional, default=True
        Whether to compute the validation loss if a validation dataloader is provided.
        The loss function must be provided to compute the validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute the test loss if a test dataloader is provided. The loss
        function must be provided to compute the test loss.
    evaluation_tasks : Optional[dict[str, EvaluationSpec]], optional, default=None
        Evaluation tasks to run during validation, while training, and during testing.

    Raises
    ------
    ValueError

        - If the loss function is not provided and either the validation or test loss
          needs to be computed.
        - If the given modality is not supported.
        - If the encoder, head, or postprocessor is not mapped to a modality.
        - If an unsupported modality is found in the loss pair specification.
        - If an unsupported modality is found in the auxiliary tasks.
        - If the auxiliary task is not a partial function.
        - If the evaluation task is not an instance of :py:class:`~mmlearn.tasks.hooks.EvaluationHooks`.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912, PLR0915
        self,
        encoders: dict[str, nn.Module],
        heads: Optional[dict[str, Union[nn.Module, dict[str, nn.Module]]]] = None,
        postprocessors: Optional[
            dict[str, Union[nn.Module, dict[str, nn.Module]]]
        ] = None,
        modality_module_mapping: Optional[dict[str, ModuleKeySpec]] = None,
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable_logit_scale: bool = True,
        loss: Optional[nn.Module] = None,
        modality_loss_pairs: Optional[list[LossPairSpec]] = None,
        auxiliary_tasks: Optional[dict[str, AuxiliaryTaskSpec]] = None,
        log_auxiliary_tasks_loss: bool = False,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
        evaluation_tasks: Optional[dict[str, EvaluationSpec]] = None,
    ) -> None:
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )

        self.save_hyperparameters(
            ignore=[
                "encoders",
                "heads",
                "postprocessors",
                "modality_module_mapping",
                "loss",
                "auxiliary_tasks",
                "evaluation_tasks",
                "modality_loss_pairs",
            ]
        )

        if modality_module_mapping is None:
            # assume all the module dictionaries use the same keys corresponding
            # to modalities
            modality_module_mapping = {}
            for key in encoders:
                modality_module_mapping[key] = ModuleKeySpec(
                    encoder_key=key,
                    head_key=key,
                    postprocessor_key=key,
                )

        # match modalities to encoders, heads, and postprocessors
        modality_encoder_mapping: dict[str, Optional[str]] = {}
        modality_head_mapping: dict[str, Optional[str]] = {}
        modality_postprocessor_mapping: dict[str, Optional[str]] = {}
        for modality_key, module_mapping in modality_module_mapping.items():
            if not Modalities.has_modality(modality_key):
                raise ValueError(_unsupported_modality_error.format(modality_key))
            modality_encoder_mapping[modality_key] = module_mapping.encoder_key
            modality_head_mapping[modality_key] = module_mapping.head_key
            modality_postprocessor_mapping[modality_key] = (
                module_mapping.postprocessor_key
            )

        # ensure all modules are mapped to a modality
        for key in encoders:
            if key not in modality_encoder_mapping.values():
                if not Modalities.has_modality(key):
                    raise ValueError(_unsupported_modality_error.format(key))
                modality_encoder_mapping[key] = key

        if heads is not None:
            for key in heads:
                if key not in modality_head_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_head_mapping[key] = key

        if postprocessors is not None:
            for key in postprocessors:
                if key not in modality_postprocessor_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_postprocessor_mapping[key] = key

        self._available_modalities: list[Modality] = [
            Modalities.get_modality(modality_key)
            for modality_key in modality_encoder_mapping
        ]
        assert len(self._available_modalities) >= 2, (
            "Expected at least two modalities to be available. "
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the encoder modules.
        self.encoders = nn.ModuleDict(
            {
                Modalities.get_modality(modality_key).name: encoders[encoder_key]
                for modality_key, encoder_key in modality_encoder_mapping.items()
                if encoder_key is not None
            }
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the projection head modules. This can be
        #: ``None`` if no heads modules are provided.
        self.heads = None
        if heads is not None:
            self.heads = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: heads[head_key]
                    if isinstance(heads[head_key], nn.Module)
                    else nn.Sequential(*heads[head_key].values())
                    for modality_key, head_key in modality_head_mapping.items()
                    if head_key is not None and head_key in heads
                }
            )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the postprocessor modules. This can be
        #: ``None`` if no postprocessor modules are provided.
        self.postprocessors = None
        if postprocessors is not None:
            self.postprocessors = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: postprocessors[
                        postprocessor_key
                    ]
                    if isinstance(postprocessors[postprocessor_key], nn.Module)
                    else nn.Sequential(*postprocessors[postprocessor_key].values())
                    for modality_key, postprocessor_key in modality_postprocessor_mapping.items()
                    if postprocessor_key is not None
                    and postprocessor_key in postprocessors
                }
            )

        # set up logit scaling
        log_logit_scale = torch.ones([]) * np.log(init_logit_scale)
        self.max_logit_scale = max_logit_scale
        self.learnable_logit_scale = learnable_logit_scale

        if self.learnable_logit_scale:
            self.log_logit_scale = torch.nn.Parameter(
                log_logit_scale, requires_grad=True
            )
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

        # set up contrastive loss pairs
        if modality_loss_pairs is None:
            modality_loss_pairs = [
                LossPairSpec(modalities=(m1.name, m2.name))
                for m1, m2 in itertools.combinations(self._available_modalities, 2)
            ]

        for modality_pair in modality_loss_pairs:
            if not all(
                Modalities.get_modality(modality) in self._available_modalities
                for modality in modality_pair.modalities
            ):
                raise ValueError(
                    "Found unspecified modality in the loss pair specification "
                    f"{modality_pair.modalities}. Available modalities are "
                    f"{self._available_modalities}."
                )

        #: A list of :py:class:`LossPairSpec` instances specifying the pairs of
        #: modalities to compute the contrastive loss between and the weight to
        #: apply to each pair.
        self.modality_loss_pairs = modality_loss_pairs

        # set up auxiliary tasks
        self.aux_task_specs = auxiliary_tasks or {}
        self.auxiliary_tasks: nn.ModuleDict[str, L.LightningModule] = nn.ModuleDict()
        for task_name, task_spec in self.aux_task_specs.items():
            if not Modalities.has_modality(task_spec.modality):
                raise ValueError(
                    f"Found unsupported modality `{task_spec.modality}` in the auxiliary tasks. "
                    f"Available modalities are {self._available_modalities}."
                )
            if not isinstance(task_spec.task, partial):
                raise TypeError(
                    f"Expected auxiliary task to be a partial function, but got {type(task_spec.task)}."
                )

            self.auxiliary_tasks[task_name] = task_spec.task(
                self.encoders[Modalities.get_modality(task_spec.modality).name]
            )

        self.log_auxiliary_tasks_loss = log_auxiliary_tasks_loss

        if evaluation_tasks is not None:
            for eval_task_spec in evaluation_tasks.values():
                if not isinstance(eval_task_spec.task, EvaluationHooks):
                    raise TypeError(
                        f"Expected {eval_task_spec.task} to be an instance of `EvaluationHooks` "
                        f"but got {type(eval_task_spec.task)}."
                    )

        #: A dictionary of evaluation tasks to run during validation, while training,
        #: or during testing.
        self.evaluation_tasks = evaluation_tasks

    def configure_model(self) -> None:
        """Configure the model."""
        if self.auxiliary_tasks:
            for task_name in self.auxiliary_tasks:
                self.auxiliary_tasks[task_name].configure_model()

    def encode(
        self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
    ) -> torch.Tensor:
        """Encode the input values for the given modality.

        Parameters
        ----------
        inputs : dict[str, Any]
            Input values.
        modality : Modality
            The modality to encode.
        normalize : bool, optional, default=False
            Whether to apply L2 normalization to the output (after the head and
            postprocessor layers, if present).

        Returns
        -------
        torch.Tensor
            The encoded values for the specified modality.
        """
        output = self.encoders[modality.name](inputs)[0]

        if self.postprocessors and modality.name in self.postprocessors:
            output = self.postprocessors[modality.name](output)

        if self.heads and modality.name in self.heads:
            output = self.heads[modality.name](output)

        if normalize:
            output = torch.nn.functional.normalize(output, p=2, dim=-1)

        return output

    def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input tensors to encode.

        Returns
        -------
        dict[str, torch.Tensor]
            The encodings for each modality.
        """
        outputs = {
            modality.embedding: self.encode(inputs, modality, normalize=True)
            for modality in self._available_modalities
            if modality.name in inputs
        }

        if not all(
            output.size(-1) == list(outputs.values())[0].size(-1)
            for output in outputs.values()
        ):
            raise ValueError("Expected all model outputs to have the same dimension.")

        return outputs

    def on_train_epoch_start(self) -> None:
        """Prepare for the training epoch.

        This method sets the modules to training mode.
        """
        self.encoders.train()
        if self.heads:
            self.heads.train()
        if self.postprocessors:
            self.postprocessors.train()

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Compute the loss for the batch.

        Parameters
        ----------
        batch : dict[str, Any]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        torch.Tensor
            The loss for the batch.
        """
        outputs = self(batch)

        with torch.no_grad():
            self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

        loss = self._compute_loss(batch, batch_idx, outputs)

        if loss is None:
            raise ValueError("The loss function must be provided for training.")

        self.log("train/loss", loss, prog_bar=True, sync_dist=True)
        self.log(
            "train/logit_scale",
            self.log_logit_scale.exp(),
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )

        return loss

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Zero out the gradients of the model."""
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_before_zero_grad(optimizer)

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch.

        This method sets the modules to evaluation mode and calls the
        ``on_evaluation_epoch_start`` method of each evaluation task.
        """
        self._on_eval_epoch_start("val")

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "val")

    def on_validation_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def test_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "test")

    def on_test_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the model checkpoint after loading.

        The `on_load_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint after loading.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The loaded checkpoint.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_load_checkpoint(checkpoint)

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the checkpoint before saving.

        The `on_save_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint before saving.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The checkpoint to save.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_save_checkpoint(checkpoint)

    def _compute_loss(
        self, batch: dict[str, Any], batch_idx: int, outputs: dict[str, torch.Tensor]
    ) -> Optional[torch.Tensor]:
        if self.loss_fn is None:
            return None

        contrastive_loss = self.loss_fn(
            outputs,
            batch["example_ids"],
            self.log_logit_scale.exp(),
            self.modality_loss_pairs,
        )

        auxiliary_losses: list[torch.Tensor] = []
        if self.auxiliary_tasks:
            for task_name, task_spec in self.aux_task_specs.items():
                auxiliary_task_output = self.auxiliary_tasks[task_name].training_step(
                    batch, batch_idx
                )
                if isinstance(auxiliary_task_output, torch.Tensor):
                    auxiliary_task_loss = auxiliary_task_output
                elif isinstance(auxiliary_task_output, Mapping):
                    auxiliary_task_loss = auxiliary_task_output["loss"]
                else:
                    raise ValueError(
                        "Expected auxiliary task output to be a tensor or a mapping "
                        f"containing a 'loss' key, but got {type(auxiliary_task_output)}."
                    )

                auxiliary_task_loss *= task_spec.loss_weight
                auxiliary_losses.append(auxiliary_task_loss)
                if self.log_auxiliary_tasks_loss:
                    self.log(
                        f"train/{task_name}_loss", auxiliary_task_loss, sync_dist=True
                    )

        if not auxiliary_losses:
            return contrastive_loss

        return torch.stack(auxiliary_losses).sum() + contrastive_loss

    def _on_eval_epoch_start(self, eval_type: Literal["val", "test"]) -> None:
        """Prepare for the evaluation epoch."""
        self.encoders.eval()
        if self.heads:
            self.heads.eval()
        if self.postprocessors:
            self.postprocessors.eval()
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_start(self)

    def _shared_eval_step(
        self,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
        eval_type: Literal["val", "test"],
    ) -> Optional[torch.Tensor]:
        """Run a single evaluation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        loss: Optional[torch.Tensor] = None
        if (eval_type == "val" and self.compute_validation_loss) or (
            eval_type == "test" and self.compute_test_loss
        ):
            outputs = self(batch)
            loss = self._compute_loss(batch, batch_idx, outputs)
            if loss is not None and not self.trainer.sanity_checking:
                self.log(f"{eval_type}/loss", loss, prog_bar=True, sync_dist=True)

        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.evaluation_step(self, batch, batch_idx)

        return loss

    def _on_eval_epoch_end(self, eval_type: Literal["val", "test"]) -> None:
        """Compute and log epoch-level metrics at the end of the evaluation epoch."""
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_end(self)
configure_model
configure_model()

Configure the model.

Source code in mmlearn/tasks/contrastive_pretraining.py
def configure_model(self) -> None:
    """Configure the model."""
    if self.auxiliary_tasks:
        for task_name in self.auxiliary_tasks:
            self.auxiliary_tasks[task_name].configure_model()
encode
encode(inputs, modality, normalize=False)

Encode the input values for the given modality.

Parameters:

Name Type Description Default
inputs dict[str, Any]

Input values.

required
modality Modality

The modality to encode.

required
normalize bool

Whether to apply L2 normalization to the output (after the head and postprocessor layers, if present).

False

Returns:

Type Description
Tensor

The encoded values for the specified modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def encode(
    self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
) -> torch.Tensor:
    """Encode the input values for the given modality.

    Parameters
    ----------
    inputs : dict[str, Any]
        Input values.
    modality : Modality
        The modality to encode.
    normalize : bool, optional, default=False
        Whether to apply L2 normalization to the output (after the head and
        postprocessor layers, if present).

    Returns
    -------
    torch.Tensor
        The encoded values for the specified modality.
    """
    output = self.encoders[modality.name](inputs)[0]

    if self.postprocessors and modality.name in self.postprocessors:
        output = self.postprocessors[modality.name](output)

    if self.heads and modality.name in self.heads:
        output = self.heads[modality.name](output)

    if normalize:
        output = torch.nn.functional.normalize(output, p=2, dim=-1)

    return output
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input tensors to encode.

required

Returns:

Type Description
dict[str, Tensor]

The encodings for each modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input tensors to encode.

    Returns
    -------
    dict[str, torch.Tensor]
        The encodings for each modality.
    """
    outputs = {
        modality.embedding: self.encode(inputs, modality, normalize=True)
        for modality in self._available_modalities
        if modality.name in inputs
    }

    if not all(
        output.size(-1) == list(outputs.values())[0].size(-1)
        for output in outputs.values()
    ):
        raise ValueError("Expected all model outputs to have the same dimension.")

    return outputs
on_train_epoch_start
on_train_epoch_start()

Prepare for the training epoch.

This method sets the modules to training mode.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_train_epoch_start(self) -> None:
    """Prepare for the training epoch.

    This method sets the modules to training mode.
    """
    self.encoders.train()
    if self.heads:
        self.heads.train()
    if self.postprocessors:
        self.postprocessors.train()
training_step
training_step(batch, batch_idx)

Compute the loss for the batch.

Parameters:

Name Type Description Default
batch dict[str, Any]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Tensor

The loss for the batch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Compute the loss for the batch.

    Parameters
    ----------
    batch : dict[str, Any]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    torch.Tensor
        The loss for the batch.
    """
    outputs = self(batch)

    with torch.no_grad():
        self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

    loss = self._compute_loss(batch, batch_idx, outputs)

    if loss is None:
        raise ValueError("The loss function must be provided for training.")

    self.log("train/loss", loss, prog_bar=True, sync_dist=True)
    self.log(
        "train/logit_scale",
        self.log_logit_scale.exp(),
        prog_bar=True,
        on_step=True,
        on_epoch=False,
    )

    return loss
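
The clamp above keeps the stored log_logit_scale within [0, log(max_logit_scale)]; what actually multiplies the logits is its exponential. A standalone numeric sketch with the default values (plain arithmetic, not library code):

import math

import torch

init_logit_scale = 1 / 0.07   # default initial scale factor, ~14.29
max_logit_scale = 100

log_logit_scale = torch.tensor(math.log(init_logit_scale))  # ~2.659 is stored
log_logit_scale.clamp_(0, math.log(max_logit_scale))         # upper bound ~4.605

scale = log_logit_scale.exp()  # ~14.29 is applied to the logits
print(log_logit_scale.item(), scale.item())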
on_before_zero_grad
on_before_zero_grad(optimizer)

Zero out the gradients of the model.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Zero out the gradients of the model."""
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_before_zero_grad(optimizer)
on_validation_epoch_start
on_validation_epoch_start()

Prepare for the validation epoch.

This method sets the modules to evaluation mode and calls the on_evaluation_epoch_start method of each evaluation task.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch.

    This method sets the modules to evaluation mode and calls the
    ``on_evaluation_epoch_start`` method of each evaluation task.
    """
    self._on_eval_epoch_start("val")
validation_step
validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def validation_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "val")
on_validation_epoch_end
on_validation_epoch_end()

Compute and log epoch-level metrics at the end of the validation epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the validation epoch."""
    self._on_eval_epoch_end("val")
on_test_epoch_start
on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")
test_step
test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def test_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "test")
on_test_epoch_end
on_test_epoch_end()

Compute and log epoch-level metrics at the end of the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the test epoch."""
    self._on_eval_epoch_end("test")
on_load_checkpoint
on_load_checkpoint(checkpoint)

Modify the model checkpoint after loading.

The on_load_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint after loading.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The loaded checkpoint.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the model checkpoint after loading.

    The `on_load_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint after loading.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The loaded checkpoint.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_load_checkpoint(checkpoint)
on_save_checkpoint
on_save_checkpoint(checkpoint)

Modify the checkpoint before saving.

The on_save_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint before saving.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The checkpoint to save.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the checkpoint before saving.

    The `on_save_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint before saving.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The checkpoint to save.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_save_checkpoint(checkpoint)
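
Before moving on, here is a minimal, construction-only sketch of the task described above. It assumes that "text" and "rgb" are registered modality names, that the class is importable from mmlearn.tasks.contrastive_pretraining (the module shown above), and it uses toy encoders that follow the documented contract (take the batch dictionary, return a list-like object whose first element is the encoded tensor):

from functools import partial

import torch
from torch import nn

from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining


class ToyEncoder(nn.Module):
    """Toy encoder following the documented contract; for illustration only."""

    def __init__(self, input_key: str, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.input_key = input_key
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, inputs: dict) -> list:
        # Return a list-like object; only the first element is used by the task.
        return [self.proj(inputs[self.input_key])]


task = ContrastivePretraining(
    encoders={
        "text": ToyEncoder("text", in_dim=64, out_dim=32),
        "rgb": ToyEncoder("rgb", in_dim=256, out_dim=32),
    },
    optimizer=partial(torch.optim.AdamW, lr=1e-4),  # a partial, as documented
    init_logit_scale=1 / 0.07,
    learnable_logit_scale=True,
    loss=None,  # a contrastive loss module would be passed here for real training
)

With both toy encoders projecting to the same dimension (32), the default modality_loss_pairs covers the single ("text", "rgb") pair.
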

IJEPA

Bases: TrainingTask

Pretraining module for IJEPA.

This class implements the IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task using PyTorch Lightning. It trains an encoder and a predictor to reconstruct masked regions of an image based on its unmasked context.

Parameters:

Name Type Description Default
encoder VisionTransformer

Vision transformer encoder.

required
predictor VisionTransformerPredictor

Vision transformer predictor.

required
modality str

The name of the modality of the input images.

'RGB'
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a :py:func:~functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a :py:func:~functools.partial function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
ema_decay float

Initial momentum for EMA of target encoder.

0.996
ema_decay_end float

Final momentum for EMA of target encoder.

1.0
ema_anneal_end_step int

Number of steps to anneal EMA momentum to ema_decay_end.

1000
loss_fn Optional[Callable[[Tensor, Tensor], Tensor]]

Loss function to use. If not provided, defaults to :py:func:~torch.nn.functional.smooth_l1_loss.

None
compute_validation_loss bool

Whether to compute validation loss.

True
compute_test_loss bool

Whether to compute test loss.

True
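
A note on how the three EMA parameters relate: the target-encoder momentum starts at ema_decay and is annealed toward ema_decay_end over ema_anneal_end_step optimizer steps. The exact schedule is implemented by ExponentialMovingAverage, which is not reproduced here; the sketch below assumes a simple linear anneal and is meant for intuition only:

def annealed_momentum(
    step: int,
    ema_decay: float = 0.996,
    ema_decay_end: float = 1.0,
    ema_anneal_end_step: int = 1000,
) -> float:
    # Linear interpolation between the initial and final momentum (an assumption,
    # not necessarily the library's exact schedule).
    frac = min(step / ema_anneal_end_step, 1.0)
    return ema_decay + (ema_decay_end - ema_decay) * frac


print(annealed_momentum(0))     # 0.996
print(annealed_momentum(500))   # 0.998
print(annealed_momentum(2000))  # 1.0
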
Source code in mmlearn/tasks/ijepa.py
@store(group="task", provider="mmlearn", zen_partial=False)
class IJEPA(TrainingTask):
    """Pretraining module for IJEPA.

    This class implements the IJEPA (Image Joint-Embedding Predictive Architecture)
    pretraining task using PyTorch Lightning. It trains an encoder and a predictor to
    reconstruct masked regions of an image based on its unmasked context.

    Parameters
    ----------
    encoder : VisionTransformer
        Vision transformer encoder.
    predictor : VisionTransformerPredictor
        Vision transformer predictor.
    modality : str, optional, default="RGB"
        The name of the modality of the input images.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    ema_decay : float, optional, default=0.996
        Initial momentum for EMA of target encoder.
    ema_decay_end : float, optional, default=1.0
        Final momentum for EMA of target encoder.
    ema_anneal_end_step : int, optional, default=1000
        Number of steps to anneal EMA momentum to ``ema_decay_end``.
    loss_fn : Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional
        Loss function to use. If not provided, defaults to
        :py:func:`~torch.nn.functional.smooth_l1_loss`.
    compute_validation_loss : bool, optional, default=True
        Whether to compute validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute test loss.
    """  # noqa: W505

    def __init__(
        self,
        encoder: VisionTransformer,
        predictor: VisionTransformerPredictor,
        modality: str = "RGB",
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        ema_decay: float = 0.996,
        ema_decay_end: float = 1.0,
        ema_anneal_end_step: int = 1000,
        loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
    ):
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss_fn if loss_fn is not None else F.smooth_l1_loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )
        self.modality = Modalities.get_modality(modality)
        self.mask_generator = IJEPAMaskGenerator()

        self.encoder = encoder
        self.predictor = predictor

        self.predictor.num_patches = encoder.patch_embed.num_patches
        self.predictor.embed_dim = encoder.embed_dim
        self.predictor.num_heads = encoder.num_heads

        self.target_encoder = ExponentialMovingAverage(
            self.encoder, ema_decay, ema_decay_end, ema_anneal_end_step
        )

    def configure_model(self) -> None:
        """Configure the model."""
        self.target_encoder.configure_model(self.device)

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Perform exponential moving average update of target encoder.

        This is done right after the ``optimizer.step()``, which comes just before
        ``optimizer.zero_grad()`` to account for gradient accumulation.
        """
        if self.target_encoder is not None:
            self.target_encoder.step(self.encoder)

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Perform a single training step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        torch.Tensor
            Loss value.
        """
        return self._shared_step(batch, batch_idx, step_type="train")

    def validation_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed.
        """
        return self._shared_step(batch, batch_idx, step_type="val")

    def test_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed.
        """
        return self._shared_step(batch, batch_idx, step_type="test")

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch."""
        self._on_eval_epoch_start("val")

    def on_validation_epoch_end(self) -> None:
        """Actions at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def on_test_epoch_end(self) -> None:
        """Actions at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Add relevant EMA state to the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to save the EMA state to.
        """
        if self.target_encoder is not None:
            checkpoint["ema_params"] = {
                "decay": self.target_encoder.decay,
                "num_updates": self.target_encoder.num_updates,
            }

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Restore EMA state from the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to restore the EMA state from.
        """
        if "ema_params" in checkpoint and self.target_encoder is not None:
            ema_params = checkpoint.pop("ema_params")
            self.target_encoder.decay = ema_params["decay"]
            self.target_encoder.num_updates = ema_params["num_updates"]

            self.target_encoder.restore(self.encoder)

    def _shared_step(
        self, batch: dict[str, Any], batch_idx: int, step_type: str
    ) -> Optional[torch.Tensor]:
        images = batch[self.modality.name]

        # Generate masks
        batch_size = images.size(0)
        mask_info = self.mask_generator(batch_size=batch_size)

        # Extract masks and move to device
        device = images.device
        encoder_masks = [mask.to(device) for mask in mask_info["encoder_masks"]]
        predictor_masks = [mask.to(device) for mask in mask_info["predictor_masks"]]

        # Forward pass through target encoder to get h
        with torch.no_grad():
            h = self.target_encoder.model(batch)[0]
            h = F.layer_norm(h, h.size()[-1:])
            h_masked = apply_masks(h, predictor_masks)
            h_masked = repeat_interleave_batch(
                h_masked, images.size(0), repeat=len(encoder_masks)
            )

        # Forward pass through encoder with encoder_masks
        batch[self.modality.mask] = encoder_masks
        z = self.encoder(batch)[0]

        # Pass z through predictor with encoder_masks and predictor_masks
        z_pred = self.predictor(z, encoder_masks, predictor_masks)

        if step_type == "train":
            self.log("train/ema_decay", self.target_encoder.decay, prog_bar=True)

        if self.loss_fn is not None and (
            step_type == "train"
            or (step_type == "val" and self.compute_validation_loss)
            or (step_type == "test" and self.compute_test_loss)
        ):
            # Compute loss between z_pred and h_masked
            loss = self.loss_fn(z_pred, h_masked)

            # Log loss
            self.log(f"{step_type}/loss", loss, prog_bar=True, sync_dist=True)

            return loss

        return None

    def _on_eval_epoch_start(self, step_type: str) -> None:
        """Initialize states or configurations at the start of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (
            step_type == "val"
            and self.compute_validation_loss
            or step_type == "test"
            and self.compute_test_loss
        ):
            self.log(f"{step_type}/start", 1, prog_bar=True, sync_dist=True)

    def _on_eval_epoch_end(self, step_type: str) -> None:
        """Finalize states or logging at the end of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (
            step_type == "val"
            and self.compute_validation_loss
            or step_type == "test"
            and self.compute_test_loss
        ):
            self.log(f"{step_type}/end", 1, prog_bar=True, sync_dist=True)
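As a usage sketch only: the task can be constructed with a partial optimizer and an optional scheduler dictionary, as documented above. The encoder and predictor below are assumed to be already-built VisionTransformer and VisionTransformerPredictor instances, and the hyperparameter values are illustrative, not recommended defaults.

from functools import partial

import torch

# `encoder` and `predictor` are assumed to exist already; their construction is
# omitted because it depends on the chosen VisionTransformer configuration.
task = IJEPA(
    encoder=encoder,
    predictor=predictor,
    optimizer=partial(torch.optim.AdamW, lr=1.5e-4, weight_decay=0.05),
    lr_scheduler={
        "scheduler": partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=300),
        "extras": {"interval": "epoch"},
    },
    ema_decay=0.996,
    ema_decay_end=1.0,
    ema_anneal_end_step=1000,
)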
configure_model
configure_model()

Configure the model.

Source code in mmlearn/tasks/ijepa.py
def configure_model(self) -> None:
    """Configure the model."""
    self.target_encoder.configure_model(self.device)
on_before_zero_grad
on_before_zero_grad(optimizer)

Perform exponential moving average update of target encoder.

This is done right after optimizer.step(), which comes just before optimizer.zero_grad() to account for gradient accumulation.

Source code in mmlearn/tasks/ijepa.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Perform exponential moving average update of target encoder.

    This is done right after the ``optimizer.step()``, which comes just before
    ``optimizer.zero_grad()`` to account for gradient accumulation.
    """
    if self.target_encoder is not None:
        self.target_encoder.step(self.encoder)
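The momentum update itself lives inside the ExponentialMovingAverage wrapper. As a rough, generic illustration (not the wrapper's actual implementation), an EMA step with momentum m blends target and online parameters as follows:

import torch

@torch.no_grad()
def ema_step(target_params, online_params, momentum: float) -> None:
    # Generic EMA update: target <- momentum * target + (1 - momentum) * online.
    for t, o in zip(target_params, online_params):
        t.mul_(momentum).add_(o, alpha=1.0 - momentum)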
training_step
training_step(batch, batch_idx)

Perform a single training step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Tensor

Loss value.

Source code in mmlearn/tasks/ijepa.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Perform a single training step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    torch.Tensor
        Loss value.
    """
    return self._shared_step(batch, batch_idx, step_type="train")
validation_step
validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def validation_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="val")
test_step
test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def test_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="test")
on_validation_epoch_start
on_validation_epoch_start()

Prepare for the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch."""
    self._on_eval_epoch_start("val")
on_validation_epoch_end
on_validation_epoch_end()

Actions at the end of the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_end(self) -> None:
    """Actions at the end of the validation epoch."""
    self._on_eval_epoch_end("val")
on_test_epoch_start
on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")
on_test_epoch_end
on_test_epoch_end()

Actions at the end of the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_end(self) -> None:
    """Actions at the end of the test epoch."""
    self._on_eval_epoch_end("test")
on_save_checkpoint
on_save_checkpoint(checkpoint)

Add relevant EMA state to the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to save the EMA state to.

required
Source code in mmlearn/tasks/ijepa.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Add relevant EMA state to the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to save the EMA state to.
    """
    if self.target_encoder is not None:
        checkpoint["ema_params"] = {
            "decay": self.target_encoder.decay,
            "num_updates": self.target_encoder.num_updates,
        }
on_load_checkpoint
on_load_checkpoint(checkpoint)

Restore EMA state from the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to restore the EMA state from.

required
Source code in mmlearn/tasks/ijepa.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Restore EMA state from the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to restore the EMA state from.
    """
    if "ema_params" in checkpoint and self.target_encoder is not None:
        ema_params = checkpoint.pop("ema_params")
        self.target_encoder.decay = ema_params["decay"]
        self.target_encoder.num_updates = ema_params["num_updates"]

        self.target_encoder.restore(self.encoder)

ZeroShotClassification

Bases: EvaluationHooks

Zero-shot classification evaluation task.

This task evaluates the zero-shot classification performance.

Parameters:

Name Type Description Default
task_specs list[ClassificationTaskSpec]

A list of classification task specifications.

required
tokenizer Callable[[Union[str, list[str]]], Union[Tensor, dict[str, Tensor]]]

A function to tokenize text inputs.

required
Source code in mmlearn/tasks/zero_shot_classification.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotClassification(EvaluationHooks):
    """Zero-shot classification evaluation task.

    This task evaluates the zero-shot classification performance.

    Parameters
    ----------
    task_specs : list[ClassificationTaskSpec]
        A list of classification task specifications.
    tokenizer : Callable[[Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]]
        A function to tokenize text inputs.
    """  # noqa: W505

    def __init__(
        self,
        task_specs: list[ClassificationTaskSpec],
        tokenizer: Callable[
            [Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]
        ],
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.task_specs = task_specs
        for spec in self.task_specs:
            assert Modalities.has_modality(spec.query_modality)

        self.metrics: dict[tuple[str, int], MetricCollection] = {}
        self._embeddings_store: dict[int, torch.Tensor] = {}

    def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
        """Set up the evaluation task.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Raises
        ------
        ValueError
            - If the task is not being run for validation or testing.
            - If the dataset does not have the required attributes to perform zero-shot
              classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
        """
        if pl_module.trainer.validating:
            eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
        elif pl_module.trainer.testing:
            eval_dataset = pl_module.trainer.test_dataloaders.dataset
        else:
            raise ValueError(
                "ZeroShotClassification task is only supported for validation and testing."
            )

        self.all_dataset_info = {}

        # create metrics for each dataset/query_modality combination
        if not self.metrics:
            for dataset_index, dataset in enumerate(eval_dataset.datasets):
                dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
                try:
                    id2label: dict[int, str] = dataset.id2label
                except AttributeError:
                    raise ValueError(
                        f"Dataset '{dataset_name}' must have a `id2label` attribute "
                        "to perform zero-shot classification."
                    ) from None

                try:
                    zero_shot_prompt_templates: list[str] = (
                        dataset.zero_shot_prompt_templates
                    )
                except AttributeError:
                    raise ValueError(
                        "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                    ) from None

                num_classes = len(id2label)

                self.all_dataset_info[dataset_index] = {
                    "name": dataset_name,
                    "id2label": id2label,
                    "prompt_templates": zero_shot_prompt_templates,
                    "num_classes": num_classes,
                }

                for spec in self.task_specs:
                    query_modality = Modalities.get_modality(spec.query_modality).name
                    self.metrics[(query_modality, dataset_index)] = (
                        self._create_metrics(
                            num_classes,
                            spec.top_k,
                            prefix=f"{dataset_name}/{query_modality}_",
                            postfix="",
                        )
                    )

        for metric in self.metrics.values():
            metric.to(pl_module.device)

        for dataset_index, dataset_info in self.all_dataset_info.items():
            id2label = dataset_info["id2label"]
            prompt_templates: list[str] = dataset_info["prompt_templates"]
            labels = list(id2label.values())

            with torch.no_grad():
                chunk_size = 10
                all_embeddings = []

                for i in tqdm(
                    range(0, len(labels), chunk_size),
                    desc="Encoding class descriptions",
                ):
                    batch_labels = labels[i : min(i + chunk_size, len(labels))]
                    descriptions = [
                        template.format(label)
                        for label in batch_labels
                        for template in prompt_templates
                    ]
                    tokenized_descriptions = move_data_to_device(
                        self.tokenizer(descriptions),
                        pl_module.device,
                    )

                    # Encode the chunk using the pl_module's encode method
                    chunk_embeddings = pl_module.encode(
                        tokenized_descriptions, Modalities.TEXT
                    )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                    chunk_embeddings = chunk_embeddings.reshape(
                        len(batch_labels), len(prompt_templates), -1
                    ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                    # Append the chunk embeddings to the list
                    all_embeddings.append(chunk_embeddings)

                # Concatenate all chunk embeddings into a single tensor
                class_embeddings = torch.cat(all_embeddings, dim=0)

            self._embeddings_store[dataset_index] = class_embeddings

    def evaluation_step(
        self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> None:
        """Compute logits and update metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A batch of data.
        batch_idx : int
            The index of the batch.
        """
        if pl_module.trainer.sanity_checking:
            return

        for (query_modality, dataset_index), metric_collection in self.metrics.items():
            matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

            if not matching_indices.numel():
                continue

            class_embeddings = self._embeddings_store[dataset_index]
            query_embeddings: torch.Tensor = pl_module.encode(
                batch, Modalities.get_modality(query_modality)
            )
            query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
            query_embeddings = query_embeddings[matching_indices]

            if self.all_dataset_info[dataset_index]["num_classes"] == 2:
                softmax_output = _safe_matmul(
                    query_embeddings, class_embeddings
                ).softmax(dim=-1)
                logits = softmax_output[:, 1] - softmax_output[:, 0]
            else:
                logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
            targets = batch[Modalities.get_modality(query_modality).target][
                matching_indices
            ]

            metric_collection.update(logits, targets)

    def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
        """Compute and reset metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        dict[str, Any]
            The computed metrics.
        """
        results = {}
        for metric_collection in self.metrics.values():
            results.update(metric_collection.compute())
            metric_collection.reset()

        self._embeddings_store.clear()

        eval_type = "val" if pl_module.trainer.validating else "test"
        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results

    @staticmethod
    def _create_metrics(
        num_classes: int, top_k: list[int], prefix: str, postfix: str
    ) -> MetricCollection:
        """Create a collection of classification metrics."""
        task_type = "binary" if num_classes == 2 else "multiclass"
        acc_metrics = (
            {
                f"top{k}_accuracy": Accuracy(
                    task=task_type, num_classes=num_classes, top_k=k, average="micro"
                )
                for k in top_k
            }
            if num_classes > 2
            else {"accuracy": Accuracy(task=task_type, num_classes=num_classes)}
        )
        return MetricCollection(
            {
                "precision": Precision(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "recall": Recall(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "f1_score_macro": F1Score(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "aucroc": AUROC(task=task_type, num_classes=num_classes),
                **acc_metrics,
            },
            prefix=prefix,
            postfix=postfix,
            compute_groups=True,
        )
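As a setup sketch: the task takes the classification specs and a text tokenizer. ClassificationTaskSpec is assumed here to accept the query_modality and top_k fields that the code above reads, the modality name "rgb" is a placeholder, and the Hugging Face tokenizer is just one possible choice.

from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

def tokenize(texts):
    # Return a dict of tensors, matching the tokenizer signature documented above.
    return hf_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

eval_task = ZeroShotClassification(
    task_specs=[ClassificationTaskSpec(query_modality="rgb", top_k=[1, 5])],
    tokenizer=tokenize,
)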
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Set up the evaluation task.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Raises:

Type Description
ValueError
  • If the task is not being run for validation or testing.
  • If the dataset does not have the required attributes to perform zero-shot classification (i.e., id2label and zero_shot_prompt_templates).
Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
    """Set up the evaluation task.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Raises
    ------
    ValueError
        - If the task is not being run for validation or testing.
        - If the dataset does not have the required attributes to perform zero-shot
          classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
    """
    if pl_module.trainer.validating:
        eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
    elif pl_module.trainer.testing:
        eval_dataset = pl_module.trainer.test_dataloaders.dataset
    else:
        raise ValueError(
            "ZeroShotClassification task is only supported for validation and testing."
        )

    self.all_dataset_info = {}

    # create metrics for each dataset/query_modality combination
    if not self.metrics:
        for dataset_index, dataset in enumerate(eval_dataset.datasets):
            dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
            try:
                id2label: dict[int, str] = dataset.id2label
            except AttributeError:
                raise ValueError(
                    f"Dataset '{dataset_name}' must have a `id2label` attribute "
                    "to perform zero-shot classification."
                ) from None

            try:
                zero_shot_prompt_templates: list[str] = (
                    dataset.zero_shot_prompt_templates
                )
            except AttributeError:
                raise ValueError(
                    "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                ) from None

            num_classes = len(id2label)

            self.all_dataset_info[dataset_index] = {
                "name": dataset_name,
                "id2label": id2label,
                "prompt_templates": zero_shot_prompt_templates,
                "num_classes": num_classes,
            }

            for spec in self.task_specs:
                query_modality = Modalities.get_modality(spec.query_modality).name
                self.metrics[(query_modality, dataset_index)] = (
                    self._create_metrics(
                        num_classes,
                        spec.top_k,
                        prefix=f"{dataset_name}/{query_modality}_",
                        postfix="",
                    )
                )

    for metric in self.metrics.values():
        metric.to(pl_module.device)

    for dataset_index, dataset_info in self.all_dataset_info.items():
        id2label = dataset_info["id2label"]
        prompt_templates: list[str] = dataset_info["prompt_templates"]
        labels = list(id2label.values())

        with torch.no_grad():
            chunk_size = 10
            all_embeddings = []

            for i in tqdm(
                range(0, len(labels), chunk_size),
                desc="Encoding class descriptions",
            ):
                batch_labels = labels[i : min(i + chunk_size, len(labels))]
                descriptions = [
                    template.format(label)
                    for label in batch_labels
                    for template in prompt_templates
                ]
                tokenized_descriptions = move_data_to_device(
                    self.tokenizer(descriptions),
                    pl_module.device,
                )

                # Encode the chunk using the pl_module's encode method
                chunk_embeddings = pl_module.encode(
                    tokenized_descriptions, Modalities.TEXT
                )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                chunk_embeddings = chunk_embeddings.reshape(
                    len(batch_labels), len(prompt_templates), -1
                ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                # Append the chunk embeddings to the list
                all_embeddings.append(chunk_embeddings)

            # Concatenate all chunk embeddings into a single tensor
            class_embeddings = torch.cat(all_embeddings, dim=0)

        self._embeddings_store[dataset_index] = class_embeddings
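Viewed in isolation, the class-embedding step above is a form of prompt ensembling: every label is formatted into every template, the prompt embeddings are averaged per label, and the average is re-normalized. A stripped-down sketch of that logic, with a generic encode_text function standing in for pl_module.encode:

import torch

def build_class_embeddings(labels, templates, encode_text):
    # `encode_text` is assumed to map a list of strings to a [n, embed_dim] tensor.
    with torch.no_grad():
        prompts = [t.format(label) for label in labels for t in templates]
        emb = encode_text(prompts)
        emb = emb / emb.norm(p=2, dim=-1, keepdim=True)  # normalize each prompt
        emb = emb.reshape(len(labels), len(templates), -1).mean(dim=1)  # average templates
        return emb / emb.norm(p=2, dim=-1, keepdim=True)  # re-normalize per class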
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Compute logits and update metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A batch of data.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_classification.py
def evaluation_step(
    self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
) -> None:
    """Compute logits and update metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A batch of data.
    batch_idx : int
        The index of the batch.
    """
    if pl_module.trainer.sanity_checking:
        return

    for (query_modality, dataset_index), metric_collection in self.metrics.items():
        matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

        if not matching_indices.numel():
            continue

        class_embeddings = self._embeddings_store[dataset_index]
        query_embeddings: torch.Tensor = pl_module.encode(
            batch, Modalities.get_modality(query_modality)
        )
        query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
        query_embeddings = query_embeddings[matching_indices]

        if self.all_dataset_info[dataset_index]["num_classes"] == 2:
            softmax_output = _safe_matmul(
                query_embeddings, class_embeddings
            ).softmax(dim=-1)
            logits = softmax_output[:, 1] - softmax_output[:, 0]
        else:
            logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
        targets = batch[Modalities.get_modality(query_modality).target][
            matching_indices
        ]

        metric_collection.update(logits, targets)
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Compute and reset metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
dict[str, Any]

The computed metrics.

Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
    """Compute and reset metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    dict[str, Any]
        The computed metrics.
    """
    results = {}
    for metric_collection in self.metrics.values():
        results.update(metric_collection.compute())
        metric_collection.reset()

    self._embeddings_store.clear()

    eval_type = "val" if pl_module.trainer.validating else "test"
    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results

ZeroShotCrossModalRetrieval

Bases: EvaluationHooks

Zero-shot cross-modal retrieval evaluation task.

This task evaluates the retrieval performance of a model on a set of query-target pairs. The model is expected to produce embeddings for both the query and target modalities. The task computes the retrieval recall at k for each pair of modalities.

Parameters:

Name Type Description Default
task_specs list[RetrievalTaskSpec]

A list of retrieval task specifications. Each specification defines the query and target modalities, as well as the top-k values for which to compute the retrieval recall metrics.

required
Source code in mmlearn/tasks/zero_shot_retrieval.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotCrossModalRetrieval(EvaluationHooks):
    """Zero-shot cross-modal retrieval evaluation task.

    This task evaluates the retrieval performance of a model on a set of query-target
    pairs. The model is expected to produce embeddings for both the query and target
    modalities. The task computes the retrieval recall at `k` for each pair of
    modalities.

    Parameters
    ----------
    task_specs : list[RetrievalTaskSpec]
        A list of retrieval task specifications. Each specification defines the query
        and target modalities, as well as the top-k values for which to compute the
        retrieval recall metrics.

    """

    def __init__(self, task_specs: list[RetrievalTaskSpec]) -> None:
        super().__init__()

        self.task_specs = task_specs
        self.metrics: dict[tuple[str, str], MetricCollection] = {}
        self._available_modalities = set()

        for spec in self.task_specs:
            query_modality = spec.query_modality
            target_modality = spec.target_modality
            assert Modalities.has_modality(query_modality)
            assert Modalities.has_modality(target_modality)

            self.metrics[(query_modality, target_modality)] = MetricCollection(
                {
                    f"{query_modality}_to_{target_modality}_R@{k}": RetrievalRecallAtK(
                        top_k=k, aggregation="mean", reduction="none"
                    )
                    for k in spec.top_k
                }
            )
            self._available_modalities.add(query_modality)
            self._available_modalities.add(target_modality)

    def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
        """Move the metrics to the device of the Lightning module."""
        for metric in self.metrics.values():
            metric.to(pl_module.device)

    def evaluation_step(
        self,
        pl_module: pl.LightningModule,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
    ) -> None:
        """Run the forward pass and update retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A dictionary of batched input tensors.
        batch_idx : int
            The index of the batch.

        """
        if pl_module.trainer.sanity_checking:
            return

        outputs: dict[str, Any] = {}
        for modality_name in self._available_modalities:
            if modality_name in batch:
                outputs[modality_name] = pl_module.encode(
                    batch, Modalities.get_modality(modality_name), normalize=False
                )
        for (query_modality, target_modality), metric in self.metrics.items():
            if query_modality not in outputs or target_modality not in outputs:
                continue
            query_embeddings: torch.Tensor = outputs[query_modality]
            target_embeddings: torch.Tensor = outputs[target_modality]
            indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

            metric.update(query_embeddings, target_embeddings, indexes)

    def on_evaluation_epoch_end(
        self, pl_module: pl.LightningModule
    ) -> Optional[dict[str, Any]]:
        """Compute the retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        Optional[dict[str, Any]]
            A dictionary of evaluation results or `None` if no results are available.
        """
        if pl_module.trainer.sanity_checking:
            return None

        results = {}
        for metric in self.metrics.values():
            results.update(metric.compute())
            metric.reset()

        eval_type = "val" if pl_module.trainer.validating else "test"

        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results
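As a construction sketch: RetrievalTaskSpec is assumed to expose the query_modality, target_modality, and top_k fields read by the code above, and the modality names are placeholders.

retrieval_task = ZeroShotCrossModalRetrieval(
    task_specs=[
        RetrievalTaskSpec(query_modality="text", target_modality="rgb", top_k=[1, 5, 10]),
        RetrievalTaskSpec(query_modality="rgb", target_modality="text", top_k=[1, 5, 10]),
    ]
)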
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Move the metrics to the device of the Lightning module.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
    """Move the metrics to the device of the Lightning module."""
    for metric in self.metrics.values():
        metric.to(pl_module.device)
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Run the forward pass and update retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A dictionary of batched input tensors.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_retrieval.py
def evaluation_step(
    self,
    pl_module: pl.LightningModule,
    batch: dict[str, torch.Tensor],
    batch_idx: int,
) -> None:
    """Run the forward pass and update retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A dictionary of batched input tensors.
    batch_idx : int
        The index of the batch.

    """
    if pl_module.trainer.sanity_checking:
        return

    outputs: dict[str, Any] = {}
    for modality_name in self._available_modalities:
        if modality_name in batch:
            outputs[modality_name] = pl_module.encode(
                batch, Modalities.get_modality(modality_name), normalize=False
            )
    for (query_modality, target_modality), metric in self.metrics.items():
        if query_modality not in outputs or target_modality not in outputs:
            continue
        query_embeddings: torch.Tensor = outputs[query_modality]
        target_embeddings: torch.Tensor = outputs[target_modality]
        indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

        metric.update(query_embeddings, target_embeddings, indexes)
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Compute the retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
Optional[dict[str, Any]]

A dictionary of evaluation results or None if no results are available.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_end(
    self, pl_module: pl.LightningModule
) -> Optional[dict[str, Any]]:
    """Compute the retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    Optional[dict[str, Any]]
        A dictionary of evaluation results or `None` if no results are available.
    """
    if pl_module.trainer.sanity_checking:
        return None

    results = {}
    for metric in self.metrics.values():
        results.update(metric.compute())
        metric.reset()

    eval_type = "val" if pl_module.trainer.validating else "test"

    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results

base

Base class for all tasks in mmlearn that require training.

TrainingTask

Bases: LightningModule

Base class for all tasks in mmlearn that require training.

Parameters:

Name Type Description Default
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a partial function, created using functools.partial, that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a partial function that takes the optimizer as the only required argument or a dictionary with a scheduler key that specifies the scheduler and an optional extras key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
loss_fn Optional[Module]

Loss function to use for training.

None
compute_validation_loss bool

Whether to compute the validation loss if a validation dataloader is provided. The loss function must be provided to compute the validation loss.

True
compute_test_loss bool

Whether to compute the test loss if a test dataloader is provided. The loss function must be provided to compute the test loss.

True

Raises:

Type Description
ValueError

If the loss function is not provided and either the validation or test loss needs to be computed.

Source code in mmlearn/tasks/base.py
class TrainingTask(L.LightningModule):
    """Base class for all tasks in mmlearn that require training.

    Parameters
    ----------
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a partial function,
        created using `functools.partial`, that takes the model parameters as the
        only required argument. If not provided, training will continue without an
        optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a partial function
        that takes the optimizer as the only required argument or a dictionary with
        a `scheduler` key that specifies the scheduler and an optional `extras` key
        that specifies additional arguments to pass to the scheduler. If not provided,
        the learning rate will not be adjusted during training.
    loss_fn : Optional[torch.nn.Module], optional, default=None
        Loss function to use for training.
    compute_validation_loss : bool, optional, default=True
        Whether to compute the validation loss if a validation dataloader is provided.
        The loss function must be provided to compute the validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute the test loss if a test dataloader is provided. The loss
        function must be provided to compute the test loss.

    Raises
    ------
    ValueError
        If the loss function is not provided and either the validation or test loss
        needs to be computed.
    """  # noqa: W505

    def __init__(
        self,
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        loss_fn: Optional[torch.nn.Module] = None,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
    ):
        super().__init__()
        if loss_fn is None and (compute_validation_loss or compute_test_loss):
            raise ValueError(
                "Loss function must be provided to compute validation or test loss."
            )

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.loss_fn = loss_fn
        self.compute_validation_loss = compute_validation_loss
        self.compute_test_loss = compute_test_loss

    def configure_optimizers(self) -> OptimizerLRScheduler:  # noqa: PLR0912
        """Configure the optimizer and learning rate scheduler."""
        if self.optimizer is None:
            rank_zero_warn(
                "Optimizer not provided. Training will continue without an optimizer. "
                "LR scheduler will not be used.",
            )
            return None

        weight_decay: Optional[float] = self.optimizer.keywords.get(
            "weight_decay", None
        )
        if weight_decay is None:  # try getting default value
            kw_param = inspect.signature(self.optimizer.func).parameters.get(
                "weight_decay"
            )
            if kw_param is not None and kw_param.default != inspect.Parameter.empty:
                weight_decay = kw_param.default

        parameters = [param for param in self.parameters() if param.requires_grad]

        if weight_decay is not None:
            decay_params = []
            no_decay_params = []

            for param in self.parameters():
                if not param.requires_grad:
                    continue

                if param.ndim < 2:  # includes all bias and normalization parameters
                    no_decay_params.append(param)
                else:
                    decay_params.append(param)

            parameters = [
                {
                    "params": decay_params,
                    "weight_decay": weight_decay,
                    "name": "weight_decay_params",
                },
                {
                    "params": no_decay_params,
                    "weight_decay": 0.0,
                    "name": "no_weight_decay_params",
                },
            ]

        optimizer = self.optimizer(parameters)
        if not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError(
                "Expected optimizer to be an instance of `torch.optim.Optimizer`, "
                f"but got {type(optimizer)}.",
            )

        if self.lr_scheduler is not None:
            if isinstance(self.lr_scheduler, dict):
                if "scheduler" not in self.lr_scheduler:
                    raise ValueError(
                        "Expected 'scheduler' key in the learning rate scheduler dictionary.",
                    )

                lr_scheduler = self.lr_scheduler["scheduler"](optimizer)
                if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):
                    raise TypeError(
                        "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, "
                        f"but got {type(lr_scheduler)}.",
                    )
                lr_scheduler_dict: dict[
                    str, Union[torch.optim.lr_scheduler.LRScheduler, Any]
                ] = {"scheduler": lr_scheduler}

                if self.lr_scheduler.get("extras"):
                    lr_scheduler_dict.update(self.lr_scheduler["extras"])
                return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}

            lr_scheduler = self.lr_scheduler(optimizer)
            if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):
                raise TypeError(
                    "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, "
                    f"but got {type(lr_scheduler)}.",
                )
            return [optimizer], [lr_scheduler]

        return optimizer
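To make the optimizer and scheduler conventions concrete, here is a hedged sketch using standard torch classes; MyTask stands in for any TrainingTask subclass and is not an actual mmlearn class.

from functools import partial

import torch

# Form 1: a bare partial scheduler -- configure_optimizers returns ([optimizer], [scheduler]).
step_scheduler = partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5)

# Form 2: a dict with the required 'scheduler' key and optional 'extras' that are
# merged into Lightning's lr_scheduler configuration.
cosine_scheduler = {
    "scheduler": partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=100),
    "extras": {"interval": "step", "frequency": 1},
}

task = MyTask(  # hypothetical TrainingTask subclass
    optimizer=partial(torch.optim.AdamW, lr=5e-4, weight_decay=0.01),
    lr_scheduler=cosine_scheduler,
    loss_fn=torch.nn.MSELoss(),
)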
configure_optimizers
configure_optimizers()

Configure the optimizer and learning rate scheduler.

Source code in mmlearn/tasks/base.py
def configure_optimizers(self) -> OptimizerLRScheduler:  # noqa: PLR0912
    """Configure the optimizer and learning rate scheduler."""
    if self.optimizer is None:
        rank_zero_warn(
            "Optimizer not provided. Training will continue without an optimizer. "
            "LR scheduler will not be used.",
        )
        return None

    weight_decay: Optional[float] = self.optimizer.keywords.get(
        "weight_decay", None
    )
    if weight_decay is None:  # try getting default value
        kw_param = inspect.signature(self.optimizer.func).parameters.get(
            "weight_decay"
        )
        if kw_param is not None and kw_param.default != inspect.Parameter.empty:
            weight_decay = kw_param.default

    parameters = [param for param in self.parameters() if param.requires_grad]

    if weight_decay is not None:
        decay_params = []
        no_decay_params = []

        for param in self.parameters():
            if not param.requires_grad:
                continue

            if param.ndim < 2:  # includes all bias and normalization parameters
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        parameters = [
            {
                "params": decay_params,
                "weight_decay": weight_decay,
                "name": "weight_decay_params",
            },
            {
                "params": no_decay_params,
                "weight_decay": 0.0,
                "name": "no_weight_decay_params",
            },
        ]

    optimizer = self.optimizer(parameters)
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise TypeError(
            "Expected optimizer to be an instance of `torch.optim.Optimizer`, "
            f"but got {type(optimizer)}.",
        )

    if self.lr_scheduler is not None:
        if isinstance(self.lr_scheduler, dict):
            if "scheduler" not in self.lr_scheduler:
                raise ValueError(
                    "Expected 'scheduler' key in the learning rate scheduler dictionary.",
                )

            lr_scheduler = self.lr_scheduler["scheduler"](optimizer)
            if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):
                raise TypeError(
                    "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, "
                    f"but got {type(lr_scheduler)}.",
                )
            lr_scheduler_dict: dict[
                str, Union[torch.optim.lr_scheduler.LRScheduler, Any]
            ] = {"scheduler": lr_scheduler}

            if self.lr_scheduler.get("extras"):
                lr_scheduler_dict.update(self.lr_scheduler["extras"])
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}

        lr_scheduler = self.lr_scheduler(optimizer)
        if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):
            raise TypeError(
                "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, "
                f"but got {type(lr_scheduler)}.",
            )
        return [optimizer], [lr_scheduler]

    return optimizer

contrastive_pretraining

Contrastive pretraining task.

ModuleKeySpec dataclass

Module key specification for mapping modules to modalities.

Source code in mmlearn/tasks/contrastive_pretraining.py
@dataclass
class ModuleKeySpec:
    """Module key specification for mapping modules to modalities."""

    #: The key of the encoder module. If not provided, the modality name is used.
    encoder_key: Optional[str] = None

    #: The key of the head module. If not provided, the modality name is used.
    head_key: Optional[str] = None

    #: The key of the postprocessor module. If not provided, the modality name is used.
    postprocessor_key: Optional[str] = None
LossPairSpec dataclass

Specification for a pair of modalities to compute the contrastive loss.

Source code in mmlearn/tasks/contrastive_pretraining.py
@dataclass
class LossPairSpec:
    """Specification for a pair of modalities to compute the contrastive loss."""

    #: The pair of modalities to compute the contrastive loss between.
    modalities: tuple[str, str]

    #: The weight to apply to the contrastive loss for the pair of modalities.
    weight: float = 1.0
AuxiliaryTaskSpec dataclass

Specification for an auxiliary task to run alongside the main task.

Source code in mmlearn/tasks/contrastive_pretraining.py
@dataclass
class AuxiliaryTaskSpec:
    """Specification for an auxiliary task to run alongside the main task."""

    #: The modality of the encoder to use for the auxiliary task.
    modality: str

    #: The auxiliary task module. This is expected to be a partially-initialized
    #: instance of a :py:class:`~lightning.pytorch.core.LightningModule` created
    #: using :py:func:`functools.partial`, such that an initialized encoder can be
    #: passed as the only argument.
    task: Any  # `functools.partial[L.LightningModule]` expected

    #: The weight to apply to the auxiliary task loss.
    loss_weight: float = 1.0
EvaluationSpec dataclass

Specification for an evaluation task.

Source code in mmlearn/tasks/contrastive_pretraining.py
@dataclass
class EvaluationSpec:
    """Specification for an evaluation task."""

    #: The evaluation task module. This is expected to be an instance of
    #: :py:class:`~mmlearn.tasks.hooks.EvaluationHooks`.
    task: Any  # `EvaluationHooks` expected

    #: Whether to run the evaluation task during validation.
    run_on_validation: bool = True

    #: Whether to run the evaluation task during testing.
    run_on_test: bool = True
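Taken together, these spec dataclasses describe how modules, loss pairs, auxiliary tasks, and evaluation tasks are wired into the ContrastivePretraining task documented below. The instances here are purely illustrative; the modality names, module keys, and the auxiliary/evaluation objects are placeholders, not values mandated by mmlearn.

from functools import partial

# Map two modalities onto named encoder/head modules (keys are placeholders).
module_mapping = {
    "rgb": ModuleKeySpec(encoder_key="image_encoder", head_key="shared_head"),
    "text": ModuleKeySpec(encoder_key="text_encoder", head_key="shared_head"),
}

# Compute the contrastive loss between the two modalities with equal weight.
loss_pairs = [LossPairSpec(modalities=("rgb", "text"), weight=1.0)]

# Wrap the image encoder in an auxiliary task (the task class is hypothetical).
aux_tasks = {
    "image_aux": AuxiliaryTaskSpec(
        modality="rgb",
        task=partial(SomeAuxiliaryLightningModule),  # placeholder partial
        loss_weight=0.5,
    ),
}

# Run an EvaluationHooks instance (e.g. a retrieval task) during validation only.
eval_tasks = {
    "retrieval": EvaluationSpec(task=retrieval_task, run_on_validation=True, run_on_test=False),
}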
ContrastivePretraining

Bases: TrainingTask

Contrastive pretraining task.

This class supports contrastive pretraining with N modalities of data. It allows the sharing of encoders, heads, and postprocessors across modalities. It also supports computing the contrastive loss between specified pairs of modalities, as well as training auxiliary tasks alongside the main contrastive pretraining task.

Parameters:

Name Type Description Default
encoders dict[str, Module]

A dictionary of encoders. The keys can be any string, including the names of any supported modalities. If the keys are not supported modalities, the modality_module_mapping parameter must be provided to map the encoders to specific modalities. The encoders are expected to take a dictionary of input values and return a list-like object with the first element being the encoded values. This first element is passed on to the heads or postprocessors and the remaining elements are ignored.

required
heads Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the encoder outputs, usually projection heads. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a torch.nn.Sequential module. All head modules are expected to take a single input tensor and return a single output tensor.

None
postprocessors Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the head outputs. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a nn.Sequential module. All postprocessor modules are expected to take a single input tensor and return a single output tensor.

None
modality_module_mapping Optional[dict[str, ModuleKeySpec]]

A dictionary mapping modalities to encoders, heads, and postprocessors. Useful for reusing the same instance of a module across multiple modalities.

None
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a functools.partial function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
init_logit_scale float

The initial value of the logit scale parameter. This is the log of the scale factor applied to the logits before computing the contrastive loss.

1 / 0.07
max_logit_scale float

The maximum value of the logit scale parameter. The logit scale parameter is clamped to the range [0, log(max_logit_scale)].

100
learnable_logit_scale bool

Whether the logit scale parameter is learnable. If set to False, the logit scale parameter is treated as a constant.

True
loss Optional[Module]

The loss function to use.

None
modality_loss_pairs Optional[list[LossPairSpec]]

A list of pairs of modalities to compute the contrastive loss between and the weight to apply to each pair.

None
auxiliary_tasks dict[str, AuxiliaryTaskSpec]

Auxiliary tasks to run alongside the main contrastive pretraining task.

  • The auxiliary task module is expected to be a partially-initialized instance of a lightning.pytorch.core.LightningModule created using functools.partial, such that an initialized encoder can be passed as the only argument.
  • The modality parameter specifies the modality of the encoder to use for the auxiliary task. The loss_weight parameter specifies the weight to apply to the auxiliary task loss.
None
log_auxiliary_tasks_loss bool

Whether to log the loss of auxiliary tasks to the main logger.

False
compute_validation_loss bool

Whether to compute the validation loss if a validation dataloader is provided. The loss function must be provided to compute the validation loss.

True
compute_test_loss bool

Whether to compute the test loss if a test dataloader is provided. The loss function must be provided to compute the test loss.

True
evaluation_tasks Optional[dict[str, EvaluationSpec]]

Evaluation tasks to run during validation, while training, and during testing.

None

Raises:

Type Description
ValueError
  • If the loss function is not provided and either the validation or test loss needs to be computed.
  • If the given modality is not supported.
  • If the encoder, head, or postprocessor is not mapped to a modality.
  • If an unsupported modality is found in the loss pair specification.
  • If an unsupported modality is found in the auxiliary tasks.
  • If the auxiliary task is not a partial function.
  • If the evaluation task is not an instance of mmlearn.tasks.hooks.EvaluationHooks.
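
The following is a minimal construction sketch, not taken from the library's own documentation. It assumes that 'text' and 'rgb' are registered modality names and that ContrastivePretraining is importable from mmlearn.tasks.contrastive_pretraining; the placeholder encoders, projection dimensions, and optimizer settings are illustrative only.

from functools import partial

import torch
from torch import nn

from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining

# Placeholder encoders: a real encoder must accept a dictionary of inputs and
# return a list-like object whose first element is the encoded tensor.
text_encoder = nn.Identity()
rgb_encoder = nn.Identity()

task = ContrastivePretraining(
    encoders={"text": text_encoder, "rgb": rgb_encoder},
    # Projection heads mapping each modality into a shared 512-dimensional
    # space (the input dimensions are made up for this sketch).
    heads={"text": nn.Linear(768, 512), "rgb": nn.Linear(1024, 512)},
    optimizer=partial(torch.optim.AdamW, lr=1e-4),
    # Dictionary form of `lr_scheduler`: a 'scheduler' partial plus an optional
    # 'extras' entry with additional arguments forwarded to the scheduler,
    # per the parameter description above.
    lr_scheduler={
        "scheduler": partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=100),
        "extras": {"eta_min": 1e-6},
    },
    # A real run would pass a contrastive loss module via `loss=...`; without
    # one, validation and test loss computation must be disabled.
    compute_validation_loss=False,
    compute_test_loss=False,
)
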
Source code in mmlearn/tasks/contrastive_pretraining.py
@store(group="task", provider="mmlearn")
class ContrastivePretraining(TrainingTask):
    """Contrastive pretraining task.

    This class supports contrastive pretraining with ``N`` modalities of data. It
    allows the sharing of encoders, heads, and postprocessors across modalities.
    It also supports computing the contrastive loss between specified pairs of
    modalities, as well as training auxiliary tasks alongside the main contrastive
    pretraining task.

    Parameters
    ----------
    encoders : dict[str, torch.nn.Module]
        A dictionary of encoders. The keys can be any string, including the names of
        any supported modalities. If the keys are not supported modalities, the
        ``modality_module_mapping`` parameter must be provided to map the encoders to
        specific modalities. The encoders are expected to take a dictionary of input
        values and return a list-like object with the first element being the encoded
        values. This first element is passed on to the heads or postprocessors and
        the remaining elements are ignored.
    heads : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the encoder outputs, usually projection
        heads. If the keys do not correspond to the name of a supported modality,
        the ``modality_module_mapping`` parameter must be provided. If any of the values
        are dictionaries, they will be wrapped in a :py:class:`torch.nn.Sequential`
        module. All head modules are expected to take a single input tensor and
        return a single output tensor.
    postprocessors : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the head outputs. If the keys do not
        correspond to the name of a supported modality, the `modality_module_mapping`
        parameter must be provided. If any of the values are dictionaries, they will
        be wrapped in a `nn.Sequential` module. All postprocessor modules are expected
        to take a single input tensor and return a single output tensor.
    modality_module_mapping : Optional[dict[str, ModuleKeySpec]], optional, default=None
        A dictionary mapping modalities to encoders, heads, and postprocessors.
        Useful for reusing the same instance of a module across multiple modalities.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    init_logit_scale : float, optional, default=1 / 0.07
        The initial value of the logit scale, i.e. the multiplicative factor
        applied to the logits before computing the contrastive loss. The model
        stores and optimizes the logarithm of this value.
    max_logit_scale : float, optional, default=100
        The maximum value of the logit scale parameter. The logit scale parameter
        is clamped to the range ``[0, log(max_logit_scale)]``.
    learnable_logit_scale : bool, optional, default=True
        Whether the logit scale parameter is learnable. If set to False, the logit
        scale parameter is treated as a constant.
    loss : Optional[torch.nn.Module], optional, default=None
        The loss function to use.
    modality_loss_pairs : Optional[list[LossPairSpec]], optional, default=None
        A list of pairs of modalities to compute the contrastive loss between and
        the weight to apply to each pair.
    auxiliary_tasks : dict[str, AuxiliaryTaskSpec], optional, default=None
        Auxiliary tasks to run alongside the main contrastive pretraining task.

        - The auxiliary task module is expected to be a partially-initialized instance
          of a :py:class:`~lightning.pytorch.core.LightningModule` created using
          :py:func:`functools.partial`, such that an initialized encoder can be
          passed as the only argument.
        - The ``modality`` parameter specifies the modality of the encoder to use
          for the auxiliary task. The ``loss_weight`` parameter specifies the weight
          to apply to the auxiliary task loss.
    log_auxiliary_tasks_loss : bool, optional, default=False
        Whether to log the loss of auxiliary tasks to the main logger.
    compute_validation_loss : bool, optional, default=True
        Whether to compute the validation loss if a validation dataloader is provided.
        The loss function must be provided to compute the validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute the test loss if a test dataloader is provided. The loss
        function must be provided to compute the test loss.
    evaluation_tasks : Optional[dict[str, EvaluationSpec]], optional, default=None
        Evaluation tasks to run during validation, while training, and during testing.

    Raises
    ------
    ValueError

        - If the loss function is not provided and either the validation or test loss
          needs to be computed.
        - If the given modality is not supported.
        - If the encoder, head, or postprocessor is not mapped to a modality.
        - If an unsupported modality is found in the loss pair specification.
        - If an unsupported modality is found in the auxiliary tasks.
        - If the auxiliary task is not a partial function.
        - If the evaluation task is not an instance of :py:class:`~mmlearn.tasks.hooks.EvaluationHooks`.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912, PLR0915
        self,
        encoders: dict[str, nn.Module],
        heads: Optional[dict[str, Union[nn.Module, dict[str, nn.Module]]]] = None,
        postprocessors: Optional[
            dict[str, Union[nn.Module, dict[str, nn.Module]]]
        ] = None,
        modality_module_mapping: Optional[dict[str, ModuleKeySpec]] = None,
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable_logit_scale: bool = True,
        loss: Optional[nn.Module] = None,
        modality_loss_pairs: Optional[list[LossPairSpec]] = None,
        auxiliary_tasks: Optional[dict[str, AuxiliaryTaskSpec]] = None,
        log_auxiliary_tasks_loss: bool = False,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
        evaluation_tasks: Optional[dict[str, EvaluationSpec]] = None,
    ) -> None:
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )

        self.save_hyperparameters(
            ignore=[
                "encoders",
                "heads",
                "postprocessors",
                "modality_module_mapping",
                "loss",
                "auxiliary_tasks",
                "evaluation_tasks",
                "modality_loss_pairs",
            ]
        )

        if modality_module_mapping is None:
            # assume all the module dictionaries use the same keys corresponding
            # to modalities
            modality_module_mapping = {}
            for key in encoders:
                modality_module_mapping[key] = ModuleKeySpec(
                    encoder_key=key,
                    head_key=key,
                    postprocessor_key=key,
                )

        # match modalities to encoders, heads, and postprocessors
        modality_encoder_mapping: dict[str, Optional[str]] = {}
        modality_head_mapping: dict[str, Optional[str]] = {}
        modality_postprocessor_mapping: dict[str, Optional[str]] = {}
        for modality_key, module_mapping in modality_module_mapping.items():
            if not Modalities.has_modality(modality_key):
                raise ValueError(_unsupported_modality_error.format(modality_key))
            modality_encoder_mapping[modality_key] = module_mapping.encoder_key
            modality_head_mapping[modality_key] = module_mapping.head_key
            modality_postprocessor_mapping[modality_key] = (
                module_mapping.postprocessor_key
            )

        # ensure all modules are mapped to a modality
        for key in encoders:
            if key not in modality_encoder_mapping.values():
                if not Modalities.has_modality(key):
                    raise ValueError(_unsupported_modality_error.format(key))
                modality_encoder_mapping[key] = key

        if heads is not None:
            for key in heads:
                if key not in modality_head_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_head_mapping[key] = key

        if postprocessors is not None:
            for key in postprocessors:
                if key not in modality_postprocessor_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_postprocessor_mapping[key] = key

        self._available_modalities: list[Modality] = [
            Modalities.get_modality(modality_key)
            for modality_key in modality_encoder_mapping
        ]
        assert len(self._available_modalities) >= 2, (
            "Expected at least two modalities to be available. "
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the encoder modules.
        self.encoders = nn.ModuleDict(
            {
                Modalities.get_modality(modality_key).name: encoders[encoder_key]
                for modality_key, encoder_key in modality_encoder_mapping.items()
                if encoder_key is not None
            }
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the projection head modules. This can be
        #: ``None`` if no heads modules are provided.
        self.heads = None
        if heads is not None:
            self.heads = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: heads[head_key]
                    if isinstance(heads[head_key], nn.Module)
                    else nn.Sequential(*heads[head_key].values())
                    for modality_key, head_key in modality_head_mapping.items()
                    if head_key is not None and head_key in heads
                }
            )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the postprocessor modules. This can be
        #: ``None`` if no postprocessor modules are provided.
        self.postprocessors = None
        if postprocessors is not None:
            self.postprocessors = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: postprocessors[
                        postprocessor_key
                    ]
                    if isinstance(postprocessors[postprocessor_key], nn.Module)
                    else nn.Sequential(*postprocessors[postprocessor_key].values())
                    for modality_key, postprocessor_key in modality_postprocessor_mapping.items()
                    if postprocessor_key is not None
                    and postprocessor_key in postprocessors
                }
            )

        # set up logit scaling
        log_logit_scale = torch.ones([]) * np.log(init_logit_scale)
        self.max_logit_scale = max_logit_scale
        self.learnable_logit_scale = learnable_logit_scale

        if self.learnable_logit_scale:
            self.log_logit_scale = torch.nn.Parameter(
                log_logit_scale, requires_grad=True
            )
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

        # set up contrastive loss pairs
        if modality_loss_pairs is None:
            modality_loss_pairs = [
                LossPairSpec(modalities=(m1.name, m2.name))
                for m1, m2 in itertools.combinations(self._available_modalities, 2)
            ]

        for modality_pair in modality_loss_pairs:
            if not all(
                Modalities.get_modality(modality) in self._available_modalities
                for modality in modality_pair.modalities
            ):
                raise ValueError(
                    "Found unspecified modality in the loss pair specification "
                    f"{modality_pair.modalities}. Available modalities are "
                    f"{self._available_modalities}."
                )

        #: A list of :py:class:`LossPairSpec` instances specifying the pairs of
        #: modalities to compute the contrastive loss between and the weight to
        #: apply to each pair.
        self.modality_loss_pairs = modality_loss_pairs

        # set up auxiliary tasks
        self.aux_task_specs = auxiliary_tasks or {}
        self.auxiliary_tasks: nn.ModuleDict[str, L.LightningModule] = nn.ModuleDict()
        for task_name, task_spec in self.aux_task_specs.items():
            if not Modalities.has_modality(task_spec.modality):
                raise ValueError(
                    f"Found unsupported modality `{task_spec.modality}` in the auxiliary tasks. "
                    f"Available modalities are {self._available_modalities}."
                )
            if not isinstance(task_spec.task, partial):
                raise TypeError(
                    f"Expected auxiliary task to be a partial function, but got {type(task_spec.task)}."
                )

            self.auxiliary_tasks[task_name] = task_spec.task(
                self.encoders[Modalities.get_modality(task_spec.modality).name]
            )

        self.log_auxiliary_tasks_loss = log_auxiliary_tasks_loss

        if evaluation_tasks is not None:
            for eval_task_spec in evaluation_tasks.values():
                if not isinstance(eval_task_spec.task, EvaluationHooks):
                    raise TypeError(
                        f"Expected {eval_task_spec.task} to be an instance of `EvaluationHooks` "
                        f"but got {type(eval_task_spec.task)}."
                    )

        #: A dictionary of evaluation tasks to run during validation, while training,
        #: or during testing.
        self.evaluation_tasks = evaluation_tasks

    def configure_model(self) -> None:
        """Configure the model."""
        if self.auxiliary_tasks:
            for task_name in self.auxiliary_tasks:
                self.auxiliary_tasks[task_name].configure_model()

    def encode(
        self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
    ) -> torch.Tensor:
        """Encode the input values for the given modality.

        Parameters
        ----------
        inputs : dict[str, Any]
            Input values.
        modality : Modality
            The modality to encode.
        normalize : bool, optional, default=False
            Whether to apply L2 normalization to the output (after the head and
            postprocessor layers, if present).

        Returns
        -------
        torch.Tensor
            The encoded values for the specified modality.
        """
        output = self.encoders[modality.name](inputs)[0]

        if self.postprocessors and modality.name in self.postprocessors:
            output = self.postprocessors[modality.name](output)

        if self.heads and modality.name in self.heads:
            output = self.heads[modality.name](output)

        if normalize:
            output = torch.nn.functional.normalize(output, p=2, dim=-1)

        return output

    def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input tensors to encode.

        Returns
        -------
        dict[str, torch.Tensor]
            The encodings for each modality.
        """
        outputs = {
            modality.embedding: self.encode(inputs, modality, normalize=True)
            for modality in self._available_modalities
            if modality.name in inputs
        }

        if not all(
            output.size(-1) == list(outputs.values())[0].size(-1)
            for output in outputs.values()
        ):
            raise ValueError("Expected all model outputs to have the same dimension.")

        return outputs

    def on_train_epoch_start(self) -> None:
        """Prepare for the training epoch.

        This method sets the modules to training mode.
        """
        self.encoders.train()
        if self.heads:
            self.heads.train()
        if self.postprocessors:
            self.postprocessors.train()

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Compute the loss for the batch.

        Parameters
        ----------
        batch : dict[str, Any]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        torch.Tensor
            The loss for the batch.
        """
        outputs = self(batch)

        with torch.no_grad():
            self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

        loss = self._compute_loss(batch, batch_idx, outputs)

        if loss is None:
            raise ValueError("The loss function must be provided for training.")

        self.log("train/loss", loss, prog_bar=True, sync_dist=True)
        self.log(
            "train/logit_scale",
            self.log_logit_scale.exp(),
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )

        return loss

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Zero out the gradients of the model."""
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_before_zero_grad(optimizer)

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch.

        This method sets the modules to evaluation mode and calls the
        ``on_evaluation_epoch_start`` method of each evaluation task.
        """
        self._on_eval_epoch_start("val")

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "val")

    def on_validation_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def test_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "test")

    def on_test_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the model checkpoint after loading.

        The `on_load_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint after loading.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The loaded checkpoint.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_load_checkpoint(checkpoint)

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the checkpoint before saving.

        The `on_save_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint before saving.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The checkpoint to save.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_save_checkpoint(checkpoint)

    def _compute_loss(
        self, batch: dict[str, Any], batch_idx: int, outputs: dict[str, torch.Tensor]
    ) -> Optional[torch.Tensor]:
        if self.loss_fn is None:
            return None

        contrastive_loss = self.loss_fn(
            outputs,
            batch["example_ids"],
            self.log_logit_scale.exp(),
            self.modality_loss_pairs,
        )

        auxiliary_losses: list[torch.Tensor] = []
        if self.auxiliary_tasks:
            for task_name, task_spec in self.aux_task_specs.items():
                auxiliary_task_output = self.auxiliary_tasks[task_name].training_step(
                    batch, batch_idx
                )
                if isinstance(auxiliary_task_output, torch.Tensor):
                    auxiliary_task_loss = auxiliary_task_output
                elif isinstance(auxiliary_task_output, Mapping):
                    auxiliary_task_loss = auxiliary_task_output["loss"]
                else:
                    raise ValueError(
                        "Expected auxiliary task output to be a tensor or a mapping "
                        f"containing a 'loss' key, but got {type(auxiliary_task_output)}."
                    )

                auxiliary_task_loss *= task_spec.loss_weight
                auxiliary_losses.append(auxiliary_task_loss)
                if self.log_auxiliary_tasks_loss:
                    self.log(
                        f"train/{task_name}_loss", auxiliary_task_loss, sync_dist=True
                    )

        if not auxiliary_losses:
            return contrastive_loss

        return torch.stack(auxiliary_losses).sum() + contrastive_loss

    def _on_eval_epoch_start(self, eval_type: Literal["val", "test"]) -> None:
        """Prepare for the evaluation epoch."""
        self.encoders.eval()
        if self.heads:
            self.heads.eval()
        if self.postprocessors:
            self.postprocessors.eval()
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_start(self)

    def _shared_eval_step(
        self,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
        eval_type: Literal["val", "test"],
    ) -> Optional[torch.Tensor]:
        """Run a single evaluation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        loss: Optional[torch.Tensor] = None
        if (eval_type == "val" and self.compute_validation_loss) or (
            eval_type == "test" and self.compute_test_loss
        ):
            outputs = self(batch)
            loss = self._compute_loss(batch, batch_idx, outputs)
            if loss is not None and not self.trainer.sanity_checking:
                self.log(f"{eval_type}/loss", loss, prog_bar=True, sync_dist=True)

        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.evaluation_step(self, batch, batch_idx)

        return loss

    def _on_eval_epoch_end(self, eval_type: Literal["val", "test"]) -> None:
        """Compute and log epoch-level metrics at the end of the evaluation epoch."""
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_end(self)
configure_model
configure_model()

Configure the model.

Source code in mmlearn/tasks/contrastive_pretraining.py
def configure_model(self) -> None:
    """Configure the model."""
    if self.auxiliary_tasks:
        for task_name in self.auxiliary_tasks:
            self.auxiliary_tasks[task_name].configure_model()
encode
encode(inputs, modality, normalize=False)

Encode the input values for the given modality.

Parameters:

Name Type Description Default
inputs dict[str, Any]

Input values.

required
modality Modality

The modality to encode.

required
normalize bool

Whether to apply L2 normalization to the output (after the head and postprocessor layers, if present).

False

Returns:

Type Description
Tensor

The encoded values for the specified modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def encode(
    self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
) -> torch.Tensor:
    """Encode the input values for the given modality.

    Parameters
    ----------
    inputs : dict[str, Any]
        Input values.
    modality : Modality
        The modality to encode.
    normalize : bool, optional, default=False
        Whether to apply L2 normalization to the output (after the head and
        postprocessor layers, if present).

    Returns
    -------
    torch.Tensor
        The encoded values for the specified modality.
    """
    output = self.encoders[modality.name](inputs)[0]

    if self.postprocessors and modality.name in self.postprocessors:
        output = self.postprocessors[modality.name](output)

    if self.heads and modality.name in self.heads:
        output = self.heads[modality.name](output)

    if normalize:
        output = torch.nn.functional.normalize(output, p=2, dim=-1)

    return output
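
As a rough usage sketch (not from the library's docs): `task` is assumed to be an initialized ContrastivePretraining instance, `batch` a dictionary of model inputs keyed by modality name, and `Modalities` the same registry referenced in the source above (its import path is not shown on this page).

# Hedged sketch: `task`, `batch`, and `Modalities` are assumed to be in scope.
text_modality = Modalities.get_modality("text")   # look up a registered modality
text_embeddings = task.encode(batch, text_modality, normalize=True)
print(text_embeddings.shape)  # e.g. (batch_size, embedding_dim)
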
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input tensors to encode.

required

Returns:

Type Description
dict[str, Tensor]

The encodings for each modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input tensors to encode.

    Returns
    -------
    dict[str, torch.Tensor]
        The encodings for each modality.
    """
    outputs = {
        modality.embedding: self.encode(inputs, modality, normalize=True)
        for modality in self._available_modalities
        if modality.name in inputs
    }

    if not all(
        output.size(-1) == list(outputs.values())[0].size(-1)
        for output in outputs.values()
    ):
        raise ValueError("Expected all model outputs to have the same dimension.")

    return outputs
on_train_epoch_start
on_train_epoch_start()

Prepare for the training epoch.

This method sets the modules to training mode.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_train_epoch_start(self) -> None:
    """Prepare for the training epoch.

    This method sets the modules to training mode.
    """
    self.encoders.train()
    if self.heads:
        self.heads.train()
    if self.postprocessors:
        self.postprocessors.train()
training_step
training_step(batch, batch_idx)

Compute the loss for the batch.

Parameters:

Name Type Description Default
batch dict[str, Any]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Tensor

The loss for the batch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Compute the loss for the batch.

    Parameters
    ----------
    batch : dict[str, Any]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    torch.Tensor
        The loss for the batch.
    """
    outputs = self(batch)

    with torch.no_grad():
        self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

    loss = self._compute_loss(batch, batch_idx, outputs)

    if loss is None:
        raise ValueError("The loss function must be provided for training.")

    self.log("train/loss", loss, prog_bar=True, sync_dist=True)
    self.log(
        "train/logit_scale",
        self.log_logit_scale.exp(),
        prog_bar=True,
        on_step=True,
        on_epoch=False,
    )

    return loss
on_before_zero_grad
on_before_zero_grad(optimizer)

Zero out the gradients of the model.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Zero out the gradients of the model."""
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_before_zero_grad(optimizer)
on_validation_epoch_start
on_validation_epoch_start()

Prepare for the validation epoch.

This method sets the modules to evaluation mode and calls the on_evaluation_epoch_start method of each evaluation task.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch.

    This method sets the modules to evaluation mode and calls the
    ``on_evaluation_epoch_start`` method of each evaluation task.
    """
    self._on_eval_epoch_start("val")
validation_step
validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def validation_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "val")
on_validation_epoch_end
on_validation_epoch_end()

Compute and log epoch-level metrics at the end of the validation epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the validation epoch."""
    self._on_eval_epoch_end("val")
on_test_epoch_start
on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")
test_step
test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def test_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "test")
on_test_epoch_end
on_test_epoch_end()

Compute and log epoch-level metrics at the end of the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the test epoch."""
    self._on_eval_epoch_end("test")
on_load_checkpoint
on_load_checkpoint(checkpoint)

Modify the model checkpoint after loading.

The on_load_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint after loading.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The loaded checkpoint.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the model checkpoint after loading.

    The `on_load_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint after loading.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The loaded checkpoint.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_load_checkpoint(checkpoint)
on_save_checkpoint
on_save_checkpoint(checkpoint)

Modify the checkpoint before saving.

The on_save_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint before saving.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The checkpoint to save.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the checkpoint before saving.

    The `on_save_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint before saving.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The checkpoint to save.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_save_checkpoint(checkpoint)
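
As a brief, hedged illustration of how the task plugs into a standard Lightning loop (assuming it was constructed with a contrastive loss module, and that `train_loader` yields dictionaries containing the modality tensors plus the "example_ids" entry used by the loss computation):

import lightning.pytorch as pl

# `task` and `train_loader` are assumed to exist; the task is a LightningModule,
# so it can be driven directly by a Trainer.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(task, train_dataloaders=train_loader)
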

hooks

Task-related hooks for Lightning modules.

EvaluationHooks

Hooks for evaluation.

Source code in mmlearn/tasks/hooks.py
class EvaluationHooks:
    """Hooks for evaluation."""

    def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
        """Prepare the evaluation loop.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        """

    def evaluation_step(
        self, pl_module: pl.LightningModule, batch: Any, batch_idx: int
    ) -> Optional[Mapping[str, Any]]:
        """Run a single evaluation step.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : Any
            A batch of data.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[Mapping[str, Any]]
            A dictionary of evaluation results for the batch or ``None`` if no
            batch results are available.

        """
        rank_zero_warn(
            f"`evaluation_step` must be implemented to use {self.__class__.__name__} for evaluation."
        )
        return None

    def on_evaluation_epoch_end(
        self, pl_module: pl.LightningModule
    ) -> Optional[Union[Mapping[str, Any]]]:
        """Run after the evaluation epoch.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        Optional[Union[Mapping[str, Any]]]
            A dictionary of evaluation results or ``None`` if no results are available.
        """
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Prepare the evaluation loop.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
Source code in mmlearn/tasks/hooks.py
def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
    """Prepare the evaluation loop.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    """
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Run a single evaluation step.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch Any

A batch of data.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Mapping[str, Any]]

A dictionary of evaluation results for the batch or None if no batch results are available.

Source code in mmlearn/tasks/hooks.py
def evaluation_step(
    self, pl_module: pl.LightningModule, batch: Any, batch_idx: int
) -> Optional[Mapping[str, Any]]:
    """Run a single evaluation step.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : Any
        A batch of data.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[Mapping[str, Any]]
        A dictionary of evaluation results for the batch or ``None`` if no
        batch results are available.

    """
    rank_zero_warn(
        f"`evaluation_step` must be implemented to use {self.__class__.__name__} for evaluation."
    )
    return None
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Run after the evaluation epoch.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
Optional[Union[Mapping[str, Any]]]

A dictionary of evaluation results or None if no results are available.

Source code in mmlearn/tasks/hooks.py
def on_evaluation_epoch_end(
    self, pl_module: pl.LightningModule
) -> Optional[Union[Mapping[str, Any]]]:
    """Run after the evaluation epoch.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    Optional[Union[Mapping[str, Any]]]
        A dictionary of evaluation results or ``None`` if no results are available.
    """

ijepa

IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task.

IJEPA

Bases: TrainingTask

Pretraining module for IJEPA.

This class implements the IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task using PyTorch Lightning. It trains an encoder and a predictor to reconstruct masked regions of an image based on its unmasked context.

Parameters:

Name Type Description Default
encoder VisionTransformer

Vision transformer encoder.

required
predictor VisionTransformerPredictor

Vision transformer predictor.

required
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a functools.partial function that takes the optimizer as the only required argument, or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
ema_decay float

Initial momentum for EMA of target encoder.

0.996
ema_decay_end float

Final momentum for EMA of target encoder.

1.0
ema_anneal_end_step int

Number of steps to anneal EMA momentum to ema_decay_end.

1000
loss_fn Optional[Callable[[Tensor, Tensor], Tensor]]

Loss function to use. If not provided, defaults to torch.nn.functional.smooth_l1_loss.

None
compute_validation_loss bool

Whether to compute validation loss.

True
compute_test_loss bool

Whether to compute test loss.

True
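
A hedged construction sketch (not from the library's docs): `vit_encoder` and `vit_predictor` stand in for already-constructed VisionTransformer and VisionTransformerPredictor instances, whose constructors are documented elsewhere; the optimizer settings are illustrative only.

from functools import partial

import torch

from mmlearn.tasks.ijepa import IJEPA

# `vit_encoder` and `vit_predictor` are assumed to be pre-built
# VisionTransformer / VisionTransformerPredictor instances.
task = IJEPA(
    encoder=vit_encoder,
    predictor=vit_predictor,
    optimizer=partial(torch.optim.AdamW, lr=1.5e-4, weight_decay=0.05),
    ema_decay=0.996,           # initial EMA momentum for the target encoder
    ema_decay_end=1.0,         # final EMA momentum
    ema_anneal_end_step=1000,  # steps over which the momentum is annealed
)
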
Source code in mmlearn/tasks/ijepa.py
@store(group="task", provider="mmlearn", zen_partial=False)
class IJEPA(TrainingTask):
    """Pretraining module for IJEPA.

    This class implements the IJEPA (Image Joint-Embedding Predictive Architecture)
    pretraining task using PyTorch Lightning. It trains an encoder and a predictor to
    reconstruct masked regions of an image based on its unmasked context.

    Parameters
    ----------
    encoder : VisionTransformer
        Vision transformer encoder.
    predictor : VisionTransformerPredictor
        Vision transformer predictor.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    ema_decay : float, optional, default=0.996
        Initial momentum for EMA of target encoder.
    ema_decay_end : float, optional, default=1.0
        Final momentum for EMA of target encoder.
    ema_anneal_end_step : int, optional, default=1000
        Number of steps to anneal EMA momentum to ``ema_decay_end``.
    loss_fn : Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional
        Loss function to use. If not provided, defaults to
        :py:func:`~torch.nn.functional.smooth_l1_loss`.
    compute_validation_loss : bool, optional, default=True
        Whether to compute validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute test loss.
    """  # noqa: W505

    def __init__(
        self,
        encoder: VisionTransformer,
        predictor: VisionTransformerPredictor,
        modality: str = "RGB",
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        ema_decay: float = 0.996,
        ema_decay_end: float = 1.0,
        ema_anneal_end_step: int = 1000,
        loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
    ):
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss_fn if loss_fn is not None else F.smooth_l1_loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )
        self.modality = Modalities.get_modality(modality)
        self.mask_generator = IJEPAMaskGenerator()

        self.encoder = encoder
        self.predictor = predictor

        self.predictor.num_patches = encoder.patch_embed.num_patches
        self.predictor.embed_dim = encoder.embed_dim
        self.predictor.num_heads = encoder.num_heads

        self.target_encoder = ExponentialMovingAverage(
            self.encoder, ema_decay, ema_decay_end, ema_anneal_end_step
        )

    def configure_model(self) -> None:
        """Configure the model."""
        self.target_encoder.configure_model(self.device)

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Perform exponential moving average update of target encoder.

        This is done right after ``optimizer.step()``, which comes just before
        ``optimizer.zero_grad()`` to account for gradient accumulation.
        """
        if self.target_encoder is not None:
            self.target_encoder.step(self.encoder)

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Perform a single training step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        torch.Tensor
            Loss value.
        """
        return self._shared_step(batch, batch_idx, step_type="train")

    def validation_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed.
        """
        return self._shared_step(batch, batch_idx, step_type="val")

    def test_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed
        """
        return self._shared_step(batch, batch_idx, step_type="test")

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch."""
        self._on_eval_epoch_start("val")

    def on_validation_epoch_end(self) -> None:
        """Actions at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def on_test_epoch_end(self) -> None:
        """Actions at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Add relevant EMA state to the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to save the EMA state to.
        """
        if self.target_encoder is not None:
            checkpoint["ema_params"] = {
                "decay": self.target_encoder.decay,
                "num_updates": self.target_encoder.num_updates,
            }

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Restore EMA state from the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to restore the EMA state from.
        """
        if "ema_params" in checkpoint and self.target_encoder is not None:
            ema_params = checkpoint.pop("ema_params")
            self.target_encoder.decay = ema_params["decay"]
            self.target_encoder.num_updates = ema_params["num_updates"]

            self.target_encoder.restore(self.encoder)

    def _shared_step(
        self, batch: dict[str, Any], batch_idx: int, step_type: str
    ) -> Optional[torch.Tensor]:
        images = batch[self.modality.name]

        # Generate masks
        batch_size = images.size(0)
        mask_info = self.mask_generator(batch_size=batch_size)

        # Extract masks and move to device
        device = images.device
        encoder_masks = [mask.to(device) for mask in mask_info["encoder_masks"]]
        predictor_masks = [mask.to(device) for mask in mask_info["predictor_masks"]]

        # Forward pass through target encoder to get h
        with torch.no_grad():
            h = self.target_encoder.model(batch)[0]
            h = F.layer_norm(h, h.size()[-1:])
            h_masked = apply_masks(h, predictor_masks)
            h_masked = repeat_interleave_batch(
                h_masked, images.size(0), repeat=len(encoder_masks)
            )

        # Forward pass through encoder with encoder_masks
        batch[self.modality.mask] = encoder_masks
        z = self.encoder(batch)[0]

        # Pass z through predictor with encoder_masks and predictor_masks
        z_pred = self.predictor(z, encoder_masks, predictor_masks)

        if step_type == "train":
            self.log("train/ema_decay", self.target_encoder.decay, prog_bar=True)

        if self.loss_fn is not None and (
            step_type == "train"
            or (step_type == "val" and self.compute_validation_loss)
            or (step_type == "test" and self.compute_test_loss)
        ):
            # Compute loss between z_pred and h_masked
            loss = self.loss_fn(z_pred, h_masked)

            # Log loss
            self.log(f"{step_type}/loss", loss, prog_bar=True, sync_dist=True)

            return loss

        return None

    def _on_eval_epoch_start(self, step_type: str) -> None:
        """Initialize states or configurations at the start of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (
            step_type == "val"
            and self.compute_validation_loss
            or step_type == "test"
            and self.compute_test_loss
        ):
            self.log(f"{step_type}/start", 1, prog_bar=True, sync_dist=True)

    def _on_eval_epoch_end(self, step_type: str) -> None:
        """Finalize states or logging at the end of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (
            step_type == "val"
            and self.compute_validation_loss
            or step_type == "test"
            and self.compute_test_loss
        ):
            self.log(f"{step_type}/end", 1, prog_bar=True, sync_dist=True)
configure_model
configure_model()

Configure the model.

Source code in mmlearn/tasks/ijepa.py
def configure_model(self) -> None:
    """Configure the model."""
    self.target_encoder.configure_model(self.device)
on_before_zero_grad
on_before_zero_grad(optimizer)

Perform exponential moving average update of target encoder.

This is done right after optimizer.step(), which comes just before optimizer.zero_grad() to account for gradient accumulation.

Source code in mmlearn/tasks/ijepa.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Perform exponential moving average update of target encoder.

    This is done right after ``optimizer.step()``, which comes just before
    ``optimizer.zero_grad()`` to account for gradient accumulation.
    """
    if self.target_encoder is not None:
        self.target_encoder.step(self.encoder)
training_step
training_step(batch, batch_idx)

Perform a single training step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Tensor

Loss value.

Source code in mmlearn/tasks/ijepa.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Perform a single training step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    torch.Tensor
        Loss value.
    """
    return self._shared_step(batch, batch_idx, step_type="train")
validation_step
validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def validation_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="val")
test_step
test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def test_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="test")
on_validation_epoch_start
on_validation_epoch_start()

Prepare for the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch."""
    self._on_eval_epoch_start("val")
on_validation_epoch_end
on_validation_epoch_end()

Actions at the end of the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_end(self) -> None:
    """Actions at the end of the validation epoch."""
    self._on_eval_epoch_end("val")
on_test_epoch_start
on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")
on_test_epoch_end
on_test_epoch_end()

Actions at the end of the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_end(self) -> None:
    """Actions at the end of the test epoch."""
    self._on_eval_epoch_end("test")
on_save_checkpoint
on_save_checkpoint(checkpoint)

Add relevant EMA state to the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to save the EMA state to.

required
Source code in mmlearn/tasks/ijepa.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Add relevant EMA state to the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to save the EMA state to.
    """
    if self.target_encoder is not None:
        checkpoint["ema_params"] = {
            "decay": self.target_encoder.decay,
            "num_updates": self.target_encoder.num_updates,
        }
on_load_checkpoint
on_load_checkpoint(checkpoint)

Restore EMA state from the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to restore the EMA state from.

required
Source code in mmlearn/tasks/ijepa.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Restore EMA state from the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to restore the EMA state from.
    """
    if "ema_params" in checkpoint and self.target_encoder is not None:
        ema_params = checkpoint.pop("ema_params")
        self.target_encoder.decay = ema_params["decay"]
        self.target_encoder.num_updates = ema_params["num_updates"]

        self.target_encoder.restore(self.encoder)

zero_shot_classification

Zero-shot classification evaluation task.

ClassificationTaskSpec dataclass

Specification for a classification task.

Source code in mmlearn/tasks/zero_shot_classification.py
@dataclass
class ClassificationTaskSpec:
    """Specification for a classification task."""

    #: The modality of the query input.
    query_modality: str

    #: The top-k values for which to compute the classification metrics like accuracy.
    top_k: list[int]
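A minimal construction sketch (the modality name is illustrative and must be registered with Modalities; an image modality called "rgb" is assumed here):

from mmlearn.tasks.zero_shot_classification import ClassificationTaskSpec

# Evaluate zero-shot accuracy at top-1 and top-5 using image embeddings as queries.
spec = ClassificationTaskSpec(query_modality="rgb", top_k=[1, 5])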
ZeroShotClassification

Bases: EvaluationHooks

Zero-shot classification evaluation task.

This task evaluates the zero-shot classification performance.

Parameters:

Name Type Description Default
task_specs list[ClassificationTaskSpec]

A list of classification task specifications.

required
tokenizer Callable[[Union[str, list[str]]], Union[Tensor, dict[str, Tensor]]]

A function to tokenize text inputs.

required
Source code in mmlearn/tasks/zero_shot_classification.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotClassification(EvaluationHooks):
    """Zero-shot classification evaluation task.

    This task evaluates the zero-shot classification performance.

    Parameters
    ----------
    task_specs : list[ClassificationTaskSpec]
        A list of classification task specifications.
    tokenizer : Callable[[Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]]
        A function to tokenize text inputs.
    """  # noqa: W505

    def __init__(
        self,
        task_specs: list[ClassificationTaskSpec],
        tokenizer: Callable[
            [Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]
        ],
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.task_specs = task_specs
        for spec in self.task_specs:
            assert Modalities.has_modality(spec.query_modality)

        self.metrics: dict[tuple[str, int], MetricCollection] = {}
        self._embeddings_store: dict[int, torch.Tensor] = {}

    def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
        """Set up the evaluation task.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Raises
        ------
        ValueError
            - If the task is not being run for validation or testing.
            - If the dataset does not have the required attributes to perform zero-shot
              classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
        """
        if pl_module.trainer.validating:
            eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
        elif pl_module.trainer.testing:
            eval_dataset = pl_module.trainer.test_dataloaders.dataset
        else:
            raise ValueError(
                "ZeroShotClassification task is only supported for validation and testing."
            )

        self.all_dataset_info = {}

        # create metrics for each dataset/query_modality combination
        if not self.metrics:
            for dataset_index, dataset in enumerate(eval_dataset.datasets):
                dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
                try:
                    id2label: dict[int, str] = dataset.id2label
                except AttributeError:
                    raise ValueError(
                        f"Dataset '{dataset_name}' must have a `id2label` attribute "
                        "to perform zero-shot classification."
                    ) from None

                try:
                    zero_shot_prompt_templates: list[str] = (
                        dataset.zero_shot_prompt_templates
                    )
                except AttributeError:
                    raise ValueError(
                        "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                    ) from None

                num_classes = len(id2label)

                self.all_dataset_info[dataset_index] = {
                    "name": dataset_name,
                    "id2label": id2label,
                    "prompt_templates": zero_shot_prompt_templates,
                    "num_classes": num_classes,
                }

                for spec in self.task_specs:
                    query_modality = Modalities.get_modality(spec.query_modality).name
                    self.metrics[(query_modality, dataset_index)] = (
                        self._create_metrics(
                            num_classes,
                            spec.top_k,
                            prefix=f"{dataset_name}/{query_modality}_",
                            postfix="",
                        )
                    )

        for metric in self.metrics.values():
            metric.to(pl_module.device)

        for dataset_index, dataset_info in self.all_dataset_info.items():
            id2label = dataset_info["id2label"]
            prompt_templates: list[str] = dataset_info["prompt_templates"]
            labels = list(id2label.values())

            with torch.no_grad():
                chunk_size = 10
                all_embeddings = []

                for i in tqdm(
                    range(0, len(labels), chunk_size),
                    desc="Encoding class descriptions",
                ):
                    batch_labels = labels[i : min(i + chunk_size, len(labels))]
                    descriptions = [
                        template.format(label)
                        for label in batch_labels
                        for template in prompt_templates
                    ]
                    tokenized_descriptions = move_data_to_device(
                        self.tokenizer(descriptions),
                        pl_module.device,
                    )

                    # Encode the chunk using the pl_module's encode method
                    chunk_embeddings = pl_module.encode(
                        tokenized_descriptions, Modalities.TEXT
                    )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                    chunk_embeddings = chunk_embeddings.reshape(
                        len(batch_labels), len(prompt_templates), -1
                    ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                    # Append the chunk embeddings to the list
                    all_embeddings.append(chunk_embeddings)

                # Concatenate all chunk embeddings into a single tensor
                class_embeddings = torch.cat(all_embeddings, dim=0)

            self._embeddings_store[dataset_index] = class_embeddings

    def evaluation_step(
        self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> None:
        """Compute logits and update metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A batch of data.
        batch_idx : int
            The index of the batch.
        """
        if pl_module.trainer.sanity_checking:
            return

        for (query_modality, dataset_index), metric_collection in self.metrics.items():
            matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

            if not matching_indices.numel():
                continue

            class_embeddings = self._embeddings_store[dataset_index]
            query_embeddings: torch.Tensor = pl_module.encode(
                batch, Modalities.get_modality(query_modality)
            )
            query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
            query_embeddings = query_embeddings[matching_indices]

            if self.all_dataset_info[dataset_index]["num_classes"] == 2:
                softmax_output = _safe_matmul(
                    query_embeddings, class_embeddings
                ).softmax(dim=-1)
                logits = softmax_output[:, 1] - softmax_output[:, 0]
            else:
                logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
            targets = batch[Modalities.get_modality(query_modality).target][
                matching_indices
            ]

            metric_collection.update(logits, targets)

    def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
        """Compute and reset metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        dict[str, Any]
            The computed metrics.
        """
        results = {}
        for metric_collection in self.metrics.values():
            results.update(metric_collection.compute())
            metric_collection.reset()

        self._embeddings_store.clear()

        eval_type = "val" if pl_module.trainer.validating else "test"
        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results

    @staticmethod
    def _create_metrics(
        num_classes: int, top_k: list[int], prefix: str, postfix: str
    ) -> MetricCollection:
        """Create a collection of classification metrics."""
        task_type = "binary" if num_classes == 2 else "multiclass"
        acc_metrics = (
            {
                f"top{k}_accuracy": Accuracy(
                    task=task_type, num_classes=num_classes, top_k=k, average="micro"
                )
                for k in top_k
            }
            if num_classes > 2
            else {"accuracy": Accuracy(task=task_type, num_classes=num_classes)}
        )
        return MetricCollection(
            {
                "precision": Precision(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "recall": Recall(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "f1_score_macro": F1Score(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "aucroc": AUROC(task=task_type, num_classes=num_classes),
                **acc_metrics,
            },
            prefix=prefix,
            postfix=postfix,
            compute_groups=True,
        )
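Direct construction would look roughly like the sketch below; in practice the task is usually configured through Hydra. The tokenizer wrapper and the "rgb" modality name are assumptions for illustration, since any callable mapping a string or list of strings to tensors is accepted.

from transformers import AutoTokenizer

from mmlearn.tasks.zero_shot_classification import (
    ClassificationTaskSpec,
    ZeroShotClassification,
)

# Hypothetical tokenizer wrapper built on a CLIP tokenizer from Hugging Face.
hf_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")


def tokenize(text):
    return hf_tokenizer(text, padding=True, truncation=True, return_tensors="pt")


task = ZeroShotClassification(
    task_specs=[ClassificationTaskSpec(query_modality="rgb", top_k=[1, 5])],
    tokenizer=tokenize,
)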
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Set up the evaluation task.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Raises:

Type Description
ValueError
  • If the task is not being run for validation or testing.
  • If the dataset does not have the required attributes to perform zero-shot classification (i.e., id2label and zero_shot_prompt_templates).
Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
    """Set up the evaluation task.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Raises
    ------
    ValueError
        - If the task is not being run for validation or testing.
        - If the dataset does not have the required attributes to perform zero-shot
          classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
    """
    if pl_module.trainer.validating:
        eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
    elif pl_module.trainer.testing:
        eval_dataset = pl_module.trainer.test_dataloaders.dataset
    else:
        raise ValueError(
            "ZeroShotClassification task is only supported for validation and testing."
        )

    self.all_dataset_info = {}

    # create metrics for each dataset/query_modality combination
    if not self.metrics:
        for dataset_index, dataset in enumerate(eval_dataset.datasets):
            dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
            try:
                id2label: dict[int, str] = dataset.id2label
            except AttributeError:
                raise ValueError(
                    f"Dataset '{dataset_name}' must have a `id2label` attribute "
                    "to perform zero-shot classification."
                ) from None

            try:
                zero_shot_prompt_templates: list[str] = (
                    dataset.zero_shot_prompt_templates
                )
            except AttributeError:
                raise ValueError(
                    "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                ) from None

            num_classes = len(id2label)

            self.all_dataset_info[dataset_index] = {
                "name": dataset_name,
                "id2label": id2label,
                "prompt_templates": zero_shot_prompt_templates,
                "num_classes": num_classes,
            }

            for spec in self.task_specs:
                query_modality = Modalities.get_modality(spec.query_modality).name
                self.metrics[(query_modality, dataset_index)] = (
                    self._create_metrics(
                        num_classes,
                        spec.top_k,
                        prefix=f"{dataset_name}/{query_modality}_",
                        postfix="",
                    )
                )

    for metric in self.metrics.values():
        metric.to(pl_module.device)

    for dataset_index, dataset_info in self.all_dataset_info.items():
        id2label = dataset_info["id2label"]
        prompt_templates: list[str] = dataset_info["prompt_templates"]
        labels = list(id2label.values())

        with torch.no_grad():
            chunk_size = 10
            all_embeddings = []

            for i in tqdm(
                range(0, len(labels), chunk_size),
                desc="Encoding class descriptions",
            ):
                batch_labels = labels[i : min(i + chunk_size, len(labels))]
                descriptions = [
                    template.format(label)
                    for label in batch_labels
                    for template in prompt_templates
                ]
                tokenized_descriptions = move_data_to_device(
                    self.tokenizer(descriptions),
                    pl_module.device,
                )

                # Encode the chunk using the pl_module's encode method
                chunk_embeddings = pl_module.encode(
                    tokenized_descriptions, Modalities.TEXT
                )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                chunk_embeddings = chunk_embeddings.reshape(
                    len(batch_labels), len(prompt_templates), -1
                ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                # Append the chunk embeddings to the list
                all_embeddings.append(chunk_embeddings)

            # Concatenate all chunk embeddings into a single tensor
            class_embeddings = torch.cat(all_embeddings, dim=0)

        self._embeddings_store[dataset_index] = class_embeddings
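The chunked loop above is standard prompt ensembling: every class name is formatted into each template, the text embeddings are L2-normalized, averaged per class, and re-normalized. The standalone sketch below shows the same idea with a stand-in encode_text callable (an assumption, not part of the module's API).

import torch


def build_class_embeddings(labels, templates, encode_text):
    """Prompt-ensembled class embeddings: average normalized template embeddings per class."""
    per_class = []
    for label in labels:
        prompts = [template.format(label) for template in templates]  # e.g. "a photo of a {}."
        emb = encode_text(prompts)  # [num_templates, embed_dim]
        emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
        emb = emb.mean(dim=0)  # ensemble over templates
        per_class.append(emb / emb.norm(p=2))
    return torch.stack(per_class)  # [num_classes, embed_dim]


# Toy usage with a random "encoder" just to show the shapes involved.
fake_encode = lambda prompts: torch.randn(len(prompts), 512)
class_embeddings = build_class_embeddings(["cat", "dog"], ["a photo of a {}."], fake_encode)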
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Compute logits and update metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A batch of data.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_classification.py
def evaluation_step(
    self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
) -> None:
    """Compute logits and update metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A batch of data.
    batch_idx : int
        The index of the batch.
    """
    if pl_module.trainer.sanity_checking:
        return

    for (query_modality, dataset_index), metric_collection in self.metrics.items():
        matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

        if not matching_indices.numel():
            continue

        class_embeddings = self._embeddings_store[dataset_index]
        query_embeddings: torch.Tensor = pl_module.encode(
            batch, Modalities.get_modality(query_modality)
        )
        query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
        query_embeddings = query_embeddings[matching_indices]

        if self.all_dataset_info[dataset_index]["num_classes"] == 2:
            softmax_output = _safe_matmul(
                query_embeddings, class_embeddings
            ).softmax(dim=-1)
            logits = softmax_output[:, 1] - softmax_output[:, 0]
        else:
            logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
        targets = batch[Modalities.get_modality(query_modality).target][
            matching_indices
        ]

        metric_collection.update(logits, targets)
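For clarity, the logits branch above produces a single signed score per example when there are exactly two classes (the difference of the two softmaxed similarities); otherwise the cosine similarities are scaled by 100. A self-contained sketch of that logic, assuming L2-normalized embeddings:

import torch


def zero_shot_logits(query_emb: torch.Tensor, class_emb: torch.Tensor) -> torch.Tensor:
    """Cosine-similarity logits; both inputs are assumed to be L2-normalized."""
    sims = query_emb @ class_emb.t()  # [num_queries, num_classes]
    if class_emb.size(0) == 2:  # binary task: one signed score per example
        probs = sims.softmax(dim=-1)
        return probs[:, 1] - probs[:, 0]
    return 100.0 * sims  # multiclass task: scaled similarities


queries = torch.nn.functional.normalize(torch.randn(4, 512), dim=-1)
classes = torch.nn.functional.normalize(torch.randn(10, 512), dim=-1)
logits = zero_shot_logits(queries, classes)  # shape: [4, 10]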
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Compute and reset metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
dict[str, Any]

The computed metrics.

Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
    """Compute and reset metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    dict[str, Any]
        The computed metrics.
    """
    results = {}
    for metric_collection in self.metrics.values():
        results.update(metric_collection.compute())
        metric_collection.reset()

    self._embeddings_store.clear()

    eval_type = "val" if pl_module.trainer.validating else "test"
    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results

zero_shot_retrieval

Zero-shot cross-modal retrieval evaluation task.

RetrievalTaskSpec dataclass

Specification for a retrieval task.

Source code in mmlearn/tasks/zero_shot_retrieval.py
@dataclass
class RetrievalTaskSpec:
    """Specification for a retrieval task."""

    #: The query modality.
    query_modality: str

    #: The target modality.
    target_modality: str

    #: The top-k values for which to compute the retrieval recall metrics.
    top_k: list[int]
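A minimal construction sketch (modality names are assumptions and must be registered with Modalities): image-to-text and text-to-image retrieval, reporting Recall@1/5/10 in each direction.

from mmlearn.tasks.zero_shot_retrieval import RetrievalTaskSpec

specs = [
    RetrievalTaskSpec(query_modality="rgb", target_modality="text", top_k=[1, 5, 10]),
    RetrievalTaskSpec(query_modality="text", target_modality="rgb", top_k=[1, 5, 10]),
]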
ZeroShotCrossModalRetrieval

Bases: EvaluationHooks

Zero-shot cross-modal retrieval evaluation task.

This task evaluates the retrieval performance of a model on a set of query-target pairs. The model is expected to produce embeddings for both the query and target modalities. The task computes the retrieval recall at k for each pair of modalities.

Parameters:

Name Type Description Default
task_specs list[RetrievalTaskSpec]

A list of retrieval task specifications. Each specification defines the query and target modalities, as well as the top-k values for which to compute the retrieval recall metrics.

required
Source code in mmlearn/tasks/zero_shot_retrieval.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotCrossModalRetrieval(EvaluationHooks):
    """Zero-shot cross-modal retrieval evaluation task.

    This task evaluates the retrieval performance of a model on a set of query-target
    pairs. The model is expected to produce embeddings for both the query and target
    modalities. The task computes the retrieval recall at `k` for each pair of
    modalities.

    Parameters
    ----------
    task_specs : list[RetrievalTaskSpec]
        A list of retrieval task specifications. Each specification defines the query
        and target modalities, as well as the top-k values for which to compute the
        retrieval recall metrics.

    """

    def __init__(self, task_specs: list[RetrievalTaskSpec]) -> None:
        super().__init__()

        self.task_specs = task_specs
        self.metrics: dict[tuple[str, str], MetricCollection] = {}
        self._available_modalities = set()

        for spec in self.task_specs:
            query_modality = spec.query_modality
            target_modality = spec.target_modality
            assert Modalities.has_modality(query_modality)
            assert Modalities.has_modality(target_modality)

            self.metrics[(query_modality, target_modality)] = MetricCollection(
                {
                    f"{query_modality}_to_{target_modality}_R@{k}": RetrievalRecallAtK(
                        top_k=k, aggregation="mean", reduction="none"
                    )
                    for k in spec.top_k
                }
            )
            self._available_modalities.add(query_modality)
            self._available_modalities.add(target_modality)

    def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
        """Move the metrics to the device of the Lightning module."""
        for metric in self.metrics.values():
            metric.to(pl_module.device)

    def evaluation_step(
        self,
        pl_module: pl.LightningModule,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
    ) -> None:
        """Run the forward pass and update retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A dictionary of batched input tensors.
        batch_idx : int
            The index of the batch.

        """
        if pl_module.trainer.sanity_checking:
            return

        outputs: dict[str, Any] = {}
        for modality_name in self._available_modalities:
            if modality_name in batch:
                outputs[modality_name] = pl_module.encode(
                    batch, Modalities.get_modality(modality_name), normalize=False
                )
        for (query_modality, target_modality), metric in self.metrics.items():
            if query_modality not in outputs or target_modality not in outputs:
                continue
            query_embeddings: torch.Tensor = outputs[query_modality]
            target_embeddings: torch.Tensor = outputs[target_modality]
            indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

            metric.update(query_embeddings, target_embeddings, indexes)

    def on_evaluation_epoch_end(
        self, pl_module: pl.LightningModule
    ) -> Optional[dict[str, Any]]:
        """Compute the retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        Optional[dict[str, Any]]
            A dictionary of evaluation results or `None` if no results are available.
        """
        if pl_module.trainer.sanity_checking:
            return None

        results = {}
        for metric in self.metrics.values():
            results.update(metric.compute())
            metric.reset()

        eval_type = "val" if pl_module.trainer.validating else "test"

        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results
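As a mental model, recall@k asks whether the matching target (the one at the same batch position, as encoded by the indexes tensor) appears among the k most similar targets. The toy computation below illustrates that idea under this pairing assumption; the task itself relies on RetrievalRecallAtK, which also handles aggregation across batches and devices.

import torch


def recall_at_k(query_emb: torch.Tensor, target_emb: torch.Tensor, k: int) -> float:
    """Fraction of queries whose true (same-index) target ranks in the top-k by similarity."""
    sims = query_emb @ target_emb.t()  # [num_queries, num_targets]
    topk = sims.topk(k, dim=-1).indices  # indices of the k most similar targets
    truth = torch.arange(query_emb.size(0)).unsqueeze(-1)  # matching pairs share an index
    return (topk == truth).any(dim=-1).float().mean().item()


queries = torch.nn.functional.normalize(torch.randn(8, 128), dim=-1)
targets = torch.nn.functional.normalize(torch.randn(8, 128), dim=-1)
print(recall_at_k(queries, targets, k=5))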
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Move the metrics to the device of the Lightning module.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
    """Move the metrics to the device of the Lightning module."""
    for metric in self.metrics.values():
        metric.to(pl_module.device)
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Run the forward pass and update retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A dictionary of batched input tensors.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_retrieval.py
def evaluation_step(
    self,
    pl_module: pl.LightningModule,
    batch: dict[str, torch.Tensor],
    batch_idx: int,
) -> None:
    """Run the forward pass and update retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A dictionary of batched input tensors.
    batch_idx : int
        The index of the batch.

    """
    if pl_module.trainer.sanity_checking:
        return

    outputs: dict[str, Any] = {}
    for modality_name in self._available_modalities:
        if modality_name in batch:
            outputs[modality_name] = pl_module.encode(
                batch, Modalities.get_modality(modality_name), normalize=False
            )
    for (query_modality, target_modality), metric in self.metrics.items():
        if query_modality not in outputs or target_modality not in outputs:
            continue
        query_embeddings: torch.Tensor = outputs[query_modality]
        target_embeddings: torch.Tensor = outputs[target_modality]
        indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

        metric.update(query_embeddings, target_embeddings, indexes)
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Compute the retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
Optional[dict[str, Any]]

A dictionary of evaluation results or None if no results are available.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_end(
    self, pl_module: pl.LightningModule
) -> Optional[dict[str, Any]]:
    """Compute the retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    Optional[dict[str, Any]]
        A dictionary of evaluation results or `None` if no results are available.
    """
    if pl_module.trainer.sanity_checking:
        return None

    results = {}
    for metric in self.metrics.values():
        results.update(metric.compute())
        metric.reset()

    eval_type = "val" if pl_module.trainer.validating else "test"

    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results

CLI Module

mmlearn.cli

Command Line Interface for mmlearn.

run

Main entry point for training and evaluation.

CombinedDataset

Bases: Dataset[Example]

Combine multiple datasets into one.

This class is similar to torch.utils.data.ConcatDataset but allows for combining iterable-style datasets with map-style datasets. The iterable-style datasets must implement the __len__ method, which is used to determine the total length of the combined dataset. When an index is passed to the combined dataset, the dataset that contains the example at that index is determined and the example is retrieved from that dataset. Since iterable-style datasets do not support random access, the examples are retrieved sequentially from the iterable-style datasets. When the end of an iterable-style dataset is reached, the iterator is reset and the next example is retrieved from the beginning of the dataset.

Parameters:

Name Type Description Default
datasets Iterable[Union[Dataset, IterableDataset]]

Iterable of datasets to combine.

required

Raises:

Type Description
TypeError

If any of the datasets in the input iterable are not instances of torch.utils.data.Dataset or torch.utils.data.IterableDataset.

ValueError

If the input iterable of datasets is empty.

Source code in mmlearn/datasets/core/combined_dataset.py
class CombinedDataset(Dataset[Example]):
    """Combine multiple datasets into one.

    This class is similar to :py:class:`~torch.utils.data.ConcatDataset` but allows
    for combining iterable-style datasets with map-style datasets. The iterable-style
    datasets must implement the :meth:`__len__` method, which is used to determine the
    total length of the combined dataset. When an index is passed to the combined
    dataset, the dataset that contains the example at that index is determined and
    the example is retrieved from that dataset. Since iterable-style datasets do
    not support random access, the examples are retrieved sequentially from the
    iterable-style datasets. When the end of an iterable-style dataset is reached,
    the iterator is reset and the next example is retrieved from the beginning of
    the dataset.


    Parameters
    ----------
    datasets : Iterable[Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]]
        Iterable of datasets to combine.

    Raises
    ------
    TypeError
        If any of the datasets in the input iterable are not instances of
        :py:class:`~torch.utils.data.Dataset` or :py:class:`~torch.utils.data.IterableDataset`.
    ValueError
        If the input iterable of datasets is empty.

    """  # noqa: W505

    def __init__(
        self, datasets: Iterable[Union[Dataset[Example], IterableDataset[Example]]]
    ) -> None:
        self.datasets, _ = tree_flatten(datasets)
        if not all(
            isinstance(dataset, (Dataset, IterableDataset)) for dataset in self.datasets
        ):
            raise TypeError(
                "Expected argument `datasets` to be an iterable of `Dataset` or "
                f"`IterableDataset` instances, but found: {self.datasets}",
            )
        if len(self.datasets) == 0:
            raise ValueError(
                "Expected a non-empty iterable of datasets but found an empty iterable",
            )

        self._cumulative_sizes: list[int] = np.cumsum(
            [len(dataset) for dataset in self.datasets]
        ).tolist()
        self._iterators: list[Iterator[Example]] = []
        self._iter_dataset_mapping: dict[int, int] = {}

        # create iterators for iterable datasets and map dataset index to iterator index
        for idx, dataset in enumerate(self.datasets):
            if isinstance(dataset, IterableDataset):
                self._iterators.append(iter(dataset))
                self._iter_dataset_mapping[idx] = len(self._iterators) - 1

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the combined dataset."""
        if idx < 0:  # handle negative indices
            if -idx > len(self):
                raise IndexError(
                    f"Index {idx} is out of bounds for the combined dataset with "
                    f"length {len(self)}",
                )
            idx = len(self) + idx

        dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

        curr_dataset = self.datasets[dataset_idx]
        if isinstance(curr_dataset, IterableDataset):
            iter_idx = self._iter_dataset_mapping[dataset_idx]
            try:
                example = next(self._iterators[iter_idx])
            except StopIteration:
                self._iterators[iter_idx] = iter(curr_dataset)
                example = next(self._iterators[iter_idx])
        else:
            if dataset_idx == 0:
                example_idx = idx
            else:
                example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
            example = curr_dataset[example_idx]

        if not isinstance(example, Example):
            raise TypeError(
                "Expected dataset examples to be instances of `Example` "
                f"but found {type(example)}",
            )

        if not hasattr(example, "dataset_index"):
            example.dataset_index = dataset_idx
        if not hasattr(example, "example_ids"):
            example.create_ids()

        return example

    def __len__(self) -> int:
        """Return the total number of examples in the combined dataset."""
        return self._cumulative_sizes[-1]
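A small usage sketch with toy datasets that return Example objects, assuming the import paths implied by the file locations above. The map-style dataset is indexed directly; the iterable-style dataset is drained sequentially and restarted when exhausted.

import torch
from torch.utils.data import Dataset, IterableDataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.example import Example


class ToyMapDataset(Dataset):
    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> Example:
        return Example({"text": torch.tensor(idx), "example_index": idx})


class ToyIterableDataset(IterableDataset):
    def __len__(self) -> int:  # required so CombinedDataset can compute its total length
        return 2

    def __iter__(self):
        for i in range(2):
            yield Example({"text": torch.tensor(100 + i), "example_index": i})


combined = CombinedDataset([ToyMapDataset(), ToyIterableDataset()])
print(len(combined))  # 6
print(combined[5].dataset_index)  # 1 -- drawn sequentially from the iterable-style dataset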
__getitem__
__getitem__(idx)

Return an example from the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the combined dataset."""
    if idx < 0:  # handle negative indices
        if -idx > len(self):
            raise IndexError(
                f"Index {idx} is out of bounds for the combined dataset with "
                f"length {len(self)}",
            )
        idx = len(self) + idx

    dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

    curr_dataset = self.datasets[dataset_idx]
    if isinstance(curr_dataset, IterableDataset):
        iter_idx = self._iter_dataset_mapping[dataset_idx]
        try:
            example = next(self._iterators[iter_idx])
        except StopIteration:
            self._iterators[iter_idx] = iter(curr_dataset)
            example = next(self._iterators[iter_idx])
    else:
        if dataset_idx == 0:
            example_idx = idx
        else:
            example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
        example = curr_dataset[example_idx]

    if not isinstance(example, Example):
        raise TypeError(
            "Expected dataset examples to be instances of `Example` "
            f"but found {type(example)}",
        )

    if not hasattr(example, "dataset_index"):
        example.dataset_index = dataset_idx
    if not hasattr(example, "example_ids"):
        example.create_ids()

    return example
__len__
__len__()

Return the total number of examples in the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __len__(self) -> int:
    """Return the total number of examples in the combined dataset."""
    return self._cumulative_sizes[-1]

DefaultDataCollator dataclass

Default data collator for batching examples.

This data collator will collate a list of mmlearn.datasets.core.example.Example objects into a batch. It can also apply processing functions to specified keys in the batch before returning it.

Parameters:

Name Type Description Default
batch_processors Optional[dict[str, Callable[[Any], Any]]]

Dictionary of callables to apply to the batch before returning it.

None

Raises:

Type Description
ValueError

If the batch processor for a key does not return a dictionary with the key in it.

Source code in mmlearn/datasets/core/data_collator.py
@dataclass
class DefaultDataCollator:
    """Default data collator for batching examples.

    This data collator will collate a list of :py:class:`~mmlearn.datasets.core.example.Example`
    objects into a batch. It can also apply processing functions to specified keys
    in the batch before returning it.

    Parameters
    ----------
    batch_processors : Optional[dict[str, Callable[[Any], Any]]], optional, default=None
        Dictionary of callables to apply to the batch before returning it.

    Raises
    ------
    ValueError
        If the batch processor for a key does not return a dictionary with the
        key in it.
    """  # noqa: W505

    #: Dictionary of callables to apply to the batch before returning it.
    #: The key is the name of the key in the batch, and the value is the processing
    #: function to apply to the key. The processing function must take a single
    #: argument and return a single value. If the processing function returns
    #: a dictionary, it must contain the key that was processed in it (all the
    #: other keys will also be included in the batch).
    batch_processors: Optional[dict[str, Callable[[Any], Any]]] = None

    def __call__(self, examples: list[Example]) -> dict[str, Any]:
        """Collate a list of `Example` objects and apply processing functions."""
        batch = collate_example_list(examples)

        if self.batch_processors is not None:
            for key, processor in self.batch_processors.items():
                batch_key: str = key
                if Modalities.has_modality(key):
                    batch_key = Modalities.get_modality(key).name

                if batch_key in batch:
                    batch_processed = processor(batch[batch_key])
                    if isinstance(batch_processed, Mapping):
                        if batch_key not in batch_processed:
                            raise ValueError(
                                f"Batch processor for '{key}' key must return a dictionary "
                                f"with '{batch_key}' in it."
                            )
                        batch.update(batch_processed)
                    else:
                        batch[batch_key] = batch_processed

        return batch
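A short usage sketch. The "text" processor here is a hypothetical callable; because it returns a tensor rather than a mapping, the collated value is simply replaced.

import torch

from mmlearn.datasets.core.data_collator import DefaultDataCollator
from mmlearn.datasets.core.example import Example

# Hypothetical processor that rescales the collated "text" tensor.
collate_fn = DefaultDataCollator(batch_processors={"text": lambda t: t * 2})

examples = [Example({"text": torch.tensor(1)}), Example({"text": torch.tensor(2)})]
batch = collate_fn(examples)
print(batch["text"])  # the collated "text" values, doubled by the processor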
__call__
__call__(examples)

Collate a list of Example objects and apply processing functions.

Source code in mmlearn/datasets/core/data_collator.py
def __call__(self, examples: list[Example]) -> dict[str, Any]:
    """Collate a list of `Example` objects and apply processing functions."""
    batch = collate_example_list(examples)

    if self.batch_processors is not None:
        for key, processor in self.batch_processors.items():
            batch_key: str = key
            if Modalities.has_modality(key):
                batch_key = Modalities.get_modality(key).name

            if batch_key in batch:
                batch_processed = processor(batch[batch_key])
                if isinstance(batch_processed, Mapping):
                    if batch_key not in batch_processed:
                        raise ValueError(
                            f"Batch processor for '{key}' key must return a dictionary "
                            f"with '{batch_key}' in it."
                        )
                    batch.update(batch_processed)
                else:
                    batch[batch_key] = batch_processed

    return batch

Example

Bases: OrderedDict[Any, Any]

A representation of a single example from a dataset.

This class is a subclass of collections.OrderedDict and provides attribute-style access. This means that example["text"] and example.text are equivalent. All datasets in this library return examples as mmlearn.datasets.core.example.Example objects.

Parameters:

Name Type Description Default
init_dict Optional[MutableMapping[Hashable, Any]]

Dictionary to init Example class with.

None

Examples:

>>> example = Example({"text": torch.tensor(2)})
>>> example.text.zero_()
tensor(0)
>>> example.context = torch.tensor(4)  # set custom attributes after initialization
Source code in mmlearn/datasets/core/example.py
class Example(OrderedDict[Any, Any]):
    """A representation of a single example from a dataset.

    This class is a subclass of :py:class:`~collections.OrderedDict` and provides
    attribute-style access. This means that `example["text"]` and `example.text`
    are equivalent. All datasets in this library return examples as
    :py:class:`~mmlearn.datasets.core.example.Example` objects.


    Parameters
    ----------
    init_dict : Optional[MutableMapping[Hashable, Any]], optional, default=None
        Dictionary to init `Example` class with.

    Examples
    --------
    >>> example = Example({"text": torch.tensor(2)})
    >>> example.text.zero_()
    tensor(0)
    >>> example.context = torch.tensor(4)  # set custom attributes after initialization
    """

    def __init__(
        self,
        init_dict: Optional[MutableMapping[Hashable, Any]] = None,
    ) -> None:
        if init_dict is None:
            init_dict = {}
        super().__init__(init_dict)

    def create_ids(self) -> None:
        """Create a unique id for the example from the dataset and example index.

        This method combines the dataset index and example index to create an
        attribute called `example_ids`, which is a dictionary of tensors. The
        dictionary keys are all the keys in the example except for `example_ids`,
        `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
        containing the tuple `(dataset_index, example_index)` for each key.
        The `example_ids` is used to (re-)identify pairs of examples from different
        modalities after they have been combined into a batch.

        Warns
        -----
        UserWarning
            If the `example_index` and `dataset_index` attributes are not set.

        Notes
        -----
        - The Example must have the following attributes set before calling this
          method: `example_index` (usually set/returned by the dataset) and
          `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
        - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
          function can be used to find matching examples given two tensors of example ids.

        """  # noqa: W505
        if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
            self.example_ids = {
                key: torch.tensor([self.dataset_index, self.example_index])
                for key in self.keys()
                if key not in ("example_ids", "example_index", "dataset_index")
            }
        else:
            rank_zero_warn(
                "Cannot create `example_ids` without `example_index` and `dataset_index` "
                "attributes. Set these attributes before calling `create_ids`. "
                "No `example_ids` was created.",
                stacklevel=2,
                category=UserWarning,
            )

    def __getattr__(self, key: str) -> Any:
        """Get attribute by key."""
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key: str, value: Any) -> None:
        """Set attribute by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        self[key] = value

    def __setitem__(self, key: Hashable, value: Any) -> None:
        """Set item by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        super().__setitem__(key, value)
create_ids
create_ids()

Create a unique id for the example from the dataset and example index.

This method combines the dataset index and example index to create an attribute called example_ids, which is a dictionary of tensors. The dictionary keys are all the keys in the example except for example_ids, example_index, and dataset_index. The values are tensors of shape (2,) containing the tuple (dataset_index, example_index) for each key. The example_ids is used to (re-)identify pairs of examples from different modalities after they have been combined into a batch.

Warns:

Type Description
UserWarning

If the example_index and dataset_index attributes are not set.

Notes
  • The Example must have the following attributes set before calling this method: example_index (usually set/returned by the dataset) and dataset_index (usually set by the mmlearn.datasets.core.combined_dataset.CombinedDataset object)
  • The mmlearn.datasets.core.example.find_matching_indices function can be used to find matching examples given two tensors of example ids.
Source code in mmlearn/datasets/core/example.py
def create_ids(self) -> None:
    """Create a unique id for the example from the dataset and example index.

    This method combines the dataset index and example index to create an
    attribute called `example_ids`, which is a dictionary of tensors. The
    dictionary keys are all the keys in the example except for `example_ids`,
    `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
    containing the tuple `(dataset_index, example_index)` for each key.
    The `example_ids` is used to (re-)identify pairs of examples from different
    modalities after they have been combined into a batch.

    Warns
    -----
    UserWarning
        If the `example_index` and `dataset_index` attributes are not set.

    Notes
    -----
    - The Example must have the following attributes set before calling this
      method: `example_index` (usually set/returned by the dataset) and
      `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
    - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
      function can be used to find matching examples given two tensors of example ids.

    """  # noqa: W505
    if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
        self.example_ids = {
            key: torch.tensor([self.dataset_index, self.example_index])
            for key in self.keys()
            if key not in ("example_ids", "example_index", "dataset_index")
        }
    else:
        rank_zero_warn(
            "Cannot create `example_ids` without `example_index` and `dataset_index` "
            "attributes. Set these attributes before calling `create_ids`. "
            "No `example_ids` was created.",
            stacklevel=2,
            category=UserWarning,
        )
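A small doctest-style illustration of the resulting structure (output formatting approximate), with dataset_index set by hand the way CombinedDataset would normally set it:

>>> example = Example({"text": torch.tensor(2), "example_index": 3})
>>> example.dataset_index = 0  # normally set by CombinedDataset
>>> example.create_ids()
>>> example.example_ids
{'text': tensor([0, 3])}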
__getattr__
__getattr__(key)

Get attribute by key.

Source code in mmlearn/datasets/core/example.py
def __getattr__(self, key: str) -> Any:
    """Get attribute by key."""
    try:
        return self[key]
    except KeyError:
        raise AttributeError(key) from None
__setattr__
__setattr__(key, value)

Set attribute by key.

Source code in mmlearn/datasets/core/example.py
def __setattr__(self, key: str, value: Any) -> None:
    """Set attribute by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    self[key] = value
__setitem__
__setitem__(key, value)

Set item by key.

Source code in mmlearn/datasets/core/example.py
def __setitem__(self, key: Hashable, value: Any) -> None:
    """Set item by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    super().__setitem__(key, value)

CombinedDatasetRatioSampler

Bases: Sampler[int]

Sampler for weighted sampling from a mmlearn.datasets.core.combined_dataset.CombinedDataset.

Parameters:

Name Type Description Default
dataset CombinedDataset

An instance of mmlearn.datasets.core.combined_dataset.CombinedDataset to sample from.

required
ratios Optional[Sequence[float]]

A sequence of ratios for sampling from each dataset in the combined dataset. The length of the sequence must be equal to the number of datasets in the combined dataset (dataset). If None, the length of each dataset in the combined dataset is used as the ratio. The ratios are normalized to sum to 1.

None
num_samples Optional[int]

The number of samples to draw from the combined dataset. If None, the sampler will draw as many samples as there are in the combined dataset. This number must yield at least one sample per dataset in the combined dataset, when multiplied by the corresponding ratio.

None
replacement bool

Whether to sample with replacement or not.

False
shuffle bool

Whether to shuffle the sampled indices or not. If False, the indices of each dataset will appear in the order they are stored in the combined dataset. This is similar to sequential sampling from each dataset. The datasets that make up the combined dataset are still sampled randomly.

True
rank Optional[int]

Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.

None
num_replicas Optional[int]

Number of processes participating in distributed training. By default, num_replicas is retrieved from the current distributed group.

None
drop_last bool

Whether to drop the last incomplete batch or not. If True, the sampler will drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

False
seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

0

Attributes:

Name Type Description
dataset CombinedDataset

The dataset to sample from.

num_samples int

The number of samples to draw from the combined dataset.

probs Tensor

The probabilities for sampling from each dataset in the combined dataset. This is computed from the ratios argument and is normalized to sum to 1.

replacement bool

Whether to sample with replacement or not.

shuffle bool

Whether to shuffle the sampled indices or not.

rank int

Rank of the current process within num_replicas.

num_replicas int

Number of processes participating in distributed training.

drop_last bool

Whether to drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

epoch int

Current epoch number. This is used to set the random seed. This is useful in distributed mode to ensure that each process receives a different random ordering of the samples.

total_size int

The total number of samples across all processes.
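A hedged construction sketch with toy map-style datasets. The ratio values are illustrative, and rank/num_replicas are passed explicitly so the sketch also runs without an initialized distributed process group.

import torch
from torch.utils.data import DataLoader, Dataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.data_collator import DefaultDataCollator
from mmlearn.datasets.core.example import Example
from mmlearn.datasets.core.samplers import CombinedDatasetRatioSampler


class ToyDataset(Dataset):
    def __init__(self, offset: int, length: int) -> None:
        self.offset, self.length = offset, length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> Example:
        return Example({"text": torch.tensor(self.offset + idx), "example_index": idx})


combined = CombinedDataset([ToyDataset(0, 70), ToyDataset(1000, 30)])

# Draw roughly 70%/30% of each epoch's samples from the two datasets.
sampler = CombinedDatasetRatioSampler(
    combined, ratios=[0.7, 0.3], shuffle=True, rank=0, num_replicas=1, seed=0
)
loader = DataLoader(combined, sampler=sampler, batch_size=8, collate_fn=DefaultDataCollator())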

Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class CombinedDatasetRatioSampler(Sampler[int]):
    """Sampler for weighted sampling from a :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`.

    Parameters
    ----------
    dataset : CombinedDataset
        An instance of :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`
        to sample from.
    ratios : Optional[Sequence[float]], optional, default=None
        A sequence of ratios for sampling from each dataset in the combined dataset.
        The length of the sequence must be equal to the number of datasets in the
        combined dataset (`dataset`). If `None`, the length of each dataset in the
        combined dataset is used as the ratio. The ratios are normalized to sum to 1.
    num_samples : Optional[int], optional, default=None
        The number of samples to draw from the combined dataset. If `None`, the
        sampler will draw as many samples as there are in the combined dataset.
        This number must yield at least one sample per dataset in the combined
        dataset, when multiplied by the corresponding ratio.
    replacement : bool, default=False
        Whether to sample with replacement or not.
    shuffle : bool, default=True
        Whether to shuffle the sampled indices or not. If `False`, the indices of
        each dataset will appear in the order they are stored in the combined dataset.
        This is similar to sequential sampling from each dataset. The datasets
        that make up the combined dataset are still sampled randomly.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    drop_last : bool, default=False
        Whether to drop the last incomplete batch or not. If `True`, the sampler will
        drop samples to make the number of samples evenly divisible by the number of
        replicas in distributed mode.
    seed : int, default=0
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.

    Attributes
    ----------
    dataset : CombinedDataset
        The dataset to sample from.
    num_samples : int
        The number of samples to draw from the combined dataset.
    probs : torch.Tensor
        The probabilities for sampling from each dataset in the combined dataset.
        This is computed from the `ratios` argument and is normalized to sum to 1.
    replacement : bool
        Whether to sample with replacement or not.
    shuffle : bool
        Whether to shuffle the sampled indices or not.
    rank : int
        Rank of the current process within :attr:`num_replicas`.
    num_replicas : int
        Number of processes participating in distributed training.
    drop_last : bool
        Whether to drop samples to make the number of samples evenly divisible by the
        number of replicas in distributed mode.
    seed : int
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.
    epoch : int
        Current epoch number. This is used to set the random seed. This is useful
        in distributed mode to ensure that each process receives a different random
        ordering of the samples.
    total_size : int
        The total number of samples across all processes.
    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        dataset: CombinedDataset,
        ratios: Optional[Sequence[float]] = None,
        num_samples: Optional[int] = None,
        replacement: bool = False,
        shuffle: bool = True,
        rank: Optional[int] = None,
        num_replicas: Optional[int] = None,
        drop_last: bool = False,
        seed: int = 0,
    ):
        if not isinstance(dataset, CombinedDataset):
            raise TypeError(
                "Expected argument `dataset` to be of type `CombinedDataset`, "
                f"but got {type(dataset)}.",
            )
        if not isinstance(seed, int):
            raise TypeError(
                f"Expected argument `seed` to be an integer, but got {type(seed)}.",
            )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.drop_last = drop_last
        self.replacement = replacement
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        self._num_samples = num_samples
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "Expected argument `num_samples` to be a positive integer, but got "
                f"{self.num_samples}.",
            )

        if ratios is None:
            ratios = [len(subset) for subset in self.dataset.datasets]

        num_datasets = len(self.dataset.datasets)
        if len(ratios) != num_datasets:
            raise ValueError(
                f"Expected argument `ratios` to be of length {num_datasets}, "
                f"but got length {len(ratios)}.",
            )
        prob_sum = sum(ratios)
        if not all(ratio >= 0 for ratio in ratios) or prob_sum <= 0:
            raise ValueError(
                "Expected argument `ratios` to be a sequence of non-negative numbers "
                f"with a positive sum. Got {ratios}.",
            )
            )
        self.probs = torch.tensor(
            [ratio / prob_sum for ratio in ratios],
            dtype=torch.double,
        )
        if any((prob * self.num_samples) <= 0 for prob in self.probs):
            raise ValueError(
                "Expected dataset ratio to result in at least one sample per dataset. "
                f"Got dataset sizes {self.probs * self.num_samples}.",
            )

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        # dataset size might change at runtime
        if self._num_samples is None:
            num_samples = len(self.dataset)
        else:
            num_samples = self._num_samples

        if self.drop_last and num_samples % self.num_replicas != 0:
            # split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            num_samples = math.ceil(
                (num_samples - self.num_replicas) / self.num_replicas,
            )
        else:
            num_samples = math.ceil(num_samples / self.num_replicas)
        return num_samples

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return self.num_samples * self.num_replicas

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that yields sample indices for the combined dataset."""
        generator = torch.Generator()
        seed = self.seed + self.epoch
        generator.manual_seed(seed)

        cumulative_sizes = [0] + self.dataset._cumulative_sizes
        num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
        indices = []
        for i in range(len(self.dataset.datasets)):
            per_dataset_indices: torch.Tensor = torch.multinomial(
                torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
                num_samples_per_dataset[i],
                replacement=self.replacement,
                generator=generator,
            )
            # adjust indices to reflect position in cumulative dataset
            per_dataset_indices += cumulative_sizes[i]
            assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
                f"Indices from dataset {i} exceed dataset size. "
                f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
            )
            indices.append(per_dataset_indices)

        indices = torch.cat(indices)
        if self.shuffle:
            rand_indices = torch.randperm(len(indices), generator=generator)
            indices = indices[rand_indices]

        indices = indices.tolist()  # type: ignore[attr-defined]
        num_indices = len(indices)

        if num_indices < self.total_size:
            padding_size = self.total_size - num_indices
            if padding_size <= num_indices:
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / num_indices))[
                    :padding_size
                ]
        elif num_indices > self.total_size:
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples, (
            f"Expected {self.num_samples} samples, but got {len(indices)}.",
        )

        yield from iter(indices)

    def __len__(self) -> int:
        """Return the total number of samples in the sampler."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch

        # some iterable datasets (especially huggingface iterable datasets) might
        # require setting the epoch to ensure shuffling works properly
        for dataset in self.dataset.datasets:
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
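As a quick illustration of how this sampler is typically wired up, here is a minimal sketch (not taken from the mmlearn docs); the toy TensorDataset objects, the 0.6/0.4 ratios, and the explicit num_replicas/rank for a single-process run are illustrative assumptions.

import torch
from torch.utils.data import TensorDataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.samplers import CombinedDatasetRatioSampler

# Two toy map-style datasets of different sizes; in practice these would
# yield mmlearn `Example` objects.
combined = CombinedDataset(
    [TensorDataset(torch.arange(100)), TensorDataset(torch.arange(50))]
)

# Draw ~60%/40% of 50 samples from the two datasets. num_replicas and rank
# are given explicitly so the sketch also runs outside a distributed group.
sampler = CombinedDatasetRatioSampler(
    combined, ratios=[0.6, 0.4], num_samples=50, num_replicas=1, rank=0
)

sampler.set_epoch(0)                 # controls the per-epoch random seed
indices = list(sampler)              # indices into the combined dataset
assert len(indices) == len(sampler)  # == num_samples on a single process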
num_samples property
num_samples

Return the number of samples managed by the sampler.

total_size property
total_size

Return the total size of the dataset.

__iter__
__iter__()

Return an iterator that yields sample indices for the combined dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that yields sample indices for the combined dataset."""
    generator = torch.Generator()
    seed = self.seed + self.epoch
    generator.manual_seed(seed)

    cumulative_sizes = [0] + self.dataset._cumulative_sizes
    num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
    indices = []
    for i in range(len(self.dataset.datasets)):
        per_dataset_indices: torch.Tensor = torch.multinomial(
            torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
            num_samples_per_dataset[i],
            replacement=self.replacement,
            generator=generator,
        )
        # adjust indices to reflect position in cumulative dataset
        per_dataset_indices += cumulative_sizes[i]
        assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
            f"Indices from dataset {i} exceed dataset size. "
            f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
        )
        indices.append(per_dataset_indices)

    indices = torch.cat(indices)
    if self.shuffle:
        rand_indices = torch.randperm(len(indices), generator=generator)
        indices = indices[rand_indices]

    indices = indices.tolist()  # type: ignore[attr-defined]
    num_indices = len(indices)

    if num_indices < self.total_size:
        padding_size = self.total_size - num_indices
        if padding_size <= num_indices:
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / num_indices))[
                :padding_size
            ]
    elif num_indices > self.total_size:
        indices = indices[: self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples, (
        f"Expected {self.num_samples} samples, but got {len(indices)}.",
    )

    yield from iter(indices)
__len__
__len__()

Return the total number of samples in the sampler.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the total number of samples in the sampler."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When :attr:shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

    # some iterable datasets (especially huggingface iterable datasets) might
    # require setting the epoch to ensure shuffling works properly
    for dataset in self.dataset.datasets:
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)

DistributedEvalSampler

Bases: Sampler[int]

Sampler for distributed evaluation.

The main differences between this and :py:class:torch.utils.data.DistributedSampler are that this sampler does not add extra samples to make it evenly divisible and shuffling is disabled by default.

Parameters:

Name Type Description Default
dataset Dataset

Dataset used for sampling.

required
num_replicas Optional[int]

Number of processes participating in distributed training. By default, :attr:num_replicas is retrieved from the current distributed group.

None
rank Optional[int]

Rank of the current process within :attr:num_replicas. By default, :attr:rank is retrieved from the current distributed group.

None
shuffle bool

If True, the sampler will shuffle the indices; shuffling is disabled by default.

False
seed int

Random seed used to shuffle the sampler if :attr:shuffle=True. This number should be identical across all processes in the distributed group.

0
Warnings

DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever. See [1]_ for details.

Notes
  • This sampler is for evaluation purposes where synchronization does not happen every epoch. Synchronization should be done outside the dataloader loop. It is especially useful in conjunction with :py:class:torch.nn.parallel.DistributedDataParallel [2]_.
  • The input Dataset is assumed to be of constant size.
  • This implementation is adapted from [3]_.
References

.. [1] https://github.com/pytorch/pytorch/issues/22584
.. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
.. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py

Examples:

>>> def example():
...     start_epoch, n_epochs = 0, 2
...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
...     for epoch in range(start_epoch, n_epochs):
...         if is_distributed:
...             sampler.set_epoch(epoch)
...         evaluate(loader)
Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class DistributedEvalSampler(Sampler[int]):
    """Sampler for distributed evaluation.

    The main differences between this and :py:class:`torch.utils.data.DistributedSampler`
    are that this sampler does not add extra samples to make it evenly divisible and
    shuffling is disabled by default.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset used for sampling.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    shuffle : bool, optional, default=False
        If `True`, the sampler will shuffle the indices; shuffling is disabled by default.
    seed : int, optional, default=0
        Random seed used to shuffle the sampler if :attr:`shuffle=True`.
        This number should be identical across all processes in the
        distributed group.

    Warnings
    --------
    DistributedEvalSampler should NOT be used for training. The distributed processes
    could hang forever. See [1]_ for details.

    Notes
    -----
    - This sampler is for evaluation purposes where synchronization does not happen
      every epoch. Synchronization should be done outside the dataloader loop.
      It is especially useful in conjunction with
      :py:class:`torch.nn.parallel.DistributedDataParallel` [2]_.
    - The input Dataset is assumed to be of constant size.
    - This implementation is adapted from [3]_.

    References
    ----------
    .. [1] https://github.com/pytorch/pytorch/issues/22584
    .. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
    .. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py


    Examples
    --------
    >>> def example():
    ...     start_epoch, n_epochs = 0, 2
    ...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
    ...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
    ...     for epoch in range(start_epoch, n_epochs):
    ...         if is_distributed:
    ...             sampler.set_epoch(epoch)
    ...         evaluate(loader)

    """  # noqa: W505

    def __init__(
        self,
        dataset: Dataset[Sized],
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return len(self.dataset)

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        indices = list(range(self.total_size))[
            self.rank : self.total_size : self.num_replicas
        ]
        return len(indices)  # true value without extra samples

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that iterates over the indices of the dataset."""
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(self.total_size, generator=g).tolist()
        else:
            indices = list(range(self.total_size))

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch
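To illustrate the key difference from torch.utils.data.DistributedSampler, here is a small sketch (illustrative, not from the mmlearn docs) that simulates three ranks on a 10-sample dataset: no indices are padded or duplicated, so the per-rank shards partition the dataset exactly.

import torch
from torch.utils.data import TensorDataset

from mmlearn.datasets.core.samplers import DistributedEvalSampler

dataset = TensorDataset(torch.arange(10))  # 10 samples, 3 simulated "processes"

# One sampler per simulated rank; passing rank/num_replicas explicitly avoids
# needing an initialized process group for this sketch.
samplers = [
    DistributedEvalSampler(dataset, num_replicas=3, rank=r) for r in range(3)
]

per_rank = [list(s) for s in samplers]
# Ranks get 4, 3 and 3 samples: no padding, no duplicates.
assert sum(len(ix) for ix in per_rank) == len(dataset)
assert sorted(i for ix in per_rank for i in ix) == list(range(10))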
total_size property
total_size

Return the total size of the dataset.

num_samples property
num_samples

Return the number of samples managed by the sampler.

__iter__
__iter__()

Return an iterator that iterates over the indices of the dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that iterates over the indices of the dataset."""
    if self.shuffle:
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()
    else:
        indices = list(range(self.total_size))

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples

    return iter(indices)
__len__
__len__()

Return the number of samples.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the number of samples."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When :attr:shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

BlockwiseImagePatchMaskGenerator

Blockwise image patch mask generator.

This is primarily intended for the data2vec method.

Parameters:

Name Type Description Default
input_size Union[int, tuple[int, int]]

The size of the input image. If an integer is provided, the image is assumed to be square.

required
num_masking_patches int

The number of patches to mask.

required
min_num_patches int

The minimum number of patches to mask.

4
max_num_patches int

The maximum number of patches to mask.

None
min_aspect_ratio float

The minimum aspect ratio of the patch.

0.3
max_aspect_ratio float

The maximum aspect ratio of the patch.

None
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn")
class BlockwiseImagePatchMaskGenerator:
    """Blockwise image patch mask generator.

    This is primarily intended for the data2vec method.

    Parameters
    ----------
    input_size : Union[int, tuple[int, int]]
        The size of the input image. If an integer is provided, the image is assumed
        to be square.
    num_masking_patches : int
        The number of patches to mask.
    min_num_patches : int, default=4
        The minimum number of patches to mask.
    max_num_patches : int, default=None
        The maximum number of patches to mask.
    min_aspect_ratio : float, default=0.3
        The minimum aspect ratio of the patch.
    max_aspect_ratio : float, default=None
        The maximum aspect ratio of the patch.
    """

    def __init__(
        self,
        input_size: Union[int, tuple[int, int]],
        num_masking_patches: int,
        min_num_patches: int = 4,
        max_num_patches: Any = None,
        min_aspect_ratio: float = 0.3,
        max_aspect_ratio: Any = None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = (
            num_masking_patches if max_num_patches is None else max_num_patches
        )

        max_aspect_ratio = max_aspect_ratio or 1 / min_aspect_ratio
        self.log_aspect_ratio = (math.log(min_aspect_ratio), math.log(max_aspect_ratio))

    def __repr__(self) -> str:
        """Generate a printable representation.

        Returns
        -------
        str
            A printable representation of the object.

        """
        return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self.max_num_patches,
            self.num_masking_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )

    def get_shape(self) -> tuple[int, int]:
        """Get the shape of the input.

        Returns
        -------
        tuple[int, int]
            The shape of the input as a tuple `(height, width)`.
        """
        return self.height, self.width

    def _mask(self, mask: torch.Tensor, max_mask_patches: int) -> int:
        """Masking function.

        This function masks adjacent patches by first selecting a target area and
        aspect ratio. Since there might be overlap between selected areas, or the
        selected area might already be masked, it runs for a maximum of 10 attempts
        or until the specified number of patches (max_mask_patches) is achieved.


        Parameters
        ----------
        mask: torch.Tensor
            Current mask. The mask to be updated.
        max_mask_patches: int
            The maximum number of patches to be masked.

        Returns
        -------
        delta: int
            The number of patches that were successfully masked.

        Notes
        -----
        - `target_area`: Randomly chosen target area for the patch.
        - `aspect_ratio`: Randomly chosen aspect ratio for the patch.
        - `h`: Height of the patch based on the target area and aspect ratio.
        - `w`: Width of the patch based on the target area and aspect ratio.
        - `top`: Randomly chosen top position for the patch.
        - `left`: Randomly chosen left position for the patch.
        - `num_masked`: Number of already-masked patches within the proposed patch area.
        - `delta`: Accumulated count of newly masked patches.
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top : top + h, left : left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self) -> torch.Tensor:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        mask: torch.Tensor
            A mask of shape (nb_patches, nb_patches)

        """
        mask = torch.zeros(self.get_shape(), dtype=torch.int)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask
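A short usage sketch; the 14×14 patch grid (e.g. a 224×224 image with 16×16 patches) and the 75 masked patches are illustrative assumptions, not mmlearn defaults.

from mmlearn.datasets.processors.masking import BlockwiseImagePatchMaskGenerator

# 14x14 patch grid; ask for roughly 75 masked patches.
mask_generator = BlockwiseImagePatchMaskGenerator(
    input_size=14, num_masking_patches=75
)

mask = mask_generator()     # a new random mask on every call
print(mask.shape)           # torch.Size([14, 14])
print(int(mask.sum()))      # number of masked patches (at most 75)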
__repr__
__repr__()

Generate a printable representation.

Returns:

Type Description
str

A printable representation of the object.

Source code in mmlearn/datasets/processors/masking.py
def __repr__(self) -> str:
    """Generate a printable representation.

    Returns
    -------
    str
        A printable representation of the object.

    """
    return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
        self.height,
        self.width,
        self.min_num_patches,
        self.max_num_patches,
        self.num_masking_patches,
        self.log_aspect_ratio[0],
        self.log_aspect_ratio[1],
    )
get_shape
get_shape()

Get the shape of the input.

Returns:

Type Description
tuple[int, int]

The shape of the input as a tuple (height, width).

Source code in mmlearn/datasets/processors/masking.py
def get_shape(self) -> tuple[int, int]:
    """Get the shape of the input.

    Returns
    -------
    tuple[int, int]
        The shape of the input as a tuple `(height, width)`.
    """
    return self.height, self.width
__call__
__call__()

Generate a random mask.

Returns a random mask of shape (nb_patches, nb_patches) based on the configuration where the number of patches to be masked is num_masking_patches.

Returns:

Name Type Description
mask Tensor

A mask of shape (nb_patches, nb_patches)

Source code in mmlearn/datasets/processors/masking.py
def __call__(self) -> torch.Tensor:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    mask: torch.Tensor
        A mask of shape (nb_patches, nb_patches)

    """
    mask = torch.zeros(self.get_shape(), dtype=torch.int)
    mask_count = 0
    while mask_count < self.num_masking_patches:
        max_mask_patches = self.num_masking_patches - mask_count
        max_mask_patches = min(max_mask_patches, self.max_num_patches)

        delta = self._mask(mask, max_mask_patches)
        if delta == 0:
            break
        mask_count += delta

    return mask

RandomMaskGenerator

Random mask generator.

Generates a random token mask where each token is masked with the given probability. This is intended to be used for tasks like masked language modeling.

Parameters:

Name Type Description Default
probability float

Probability of masking a token.

required
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn", probability=0.15)
class RandomMaskGenerator:
    """Random mask generator.

    Generates a random token mask where each token is masked with the given
    probability. **This is intended to be used for tasks like masked language
    modeling.**

    Parameters
    ----------
    probability : float
        Probability of masking a token.
    """

    def __init__(self, probability: float):
        self.probability = probability

    def __call__(
        self,
        inputs: torch.Tensor,
        tokenizer: PreTrainedTokenizerBase,
        special_tokens_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        inputs : torch.Tensor
            The encoded inputs.
        tokenizer : PreTrainedTokenizer
            The tokenizer.
        special_tokens_mask : Optional[torch.Tensor], default=None
            Mask for special tokens.
        """
        inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training
        # (with probability `self.probability`)
        probability_matrix = torch.full(labels.shape, self.probability)
        if special_tokens_mask is None:
            special_tokens_mask = tokenizer.get_special_tokens_mask(
                labels, already_has_special_tokens=True
            )
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = tokenizer.pad_token_id
        # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # Rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, masked_indices
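A minimal sketch of the intended masked-language-modeling usage; the bert-base-uncased checkpoint and the example sentences are illustrative assumptions, and the special-tokens mask is precomputed by the tokenizer so it can be passed in explicitly.

from transformers import AutoTokenizer

from mmlearn.datasets.processors.masking import RandomMaskGenerator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
masker = RandomMaskGenerator(probability=0.15)

# Pre-compute the special-tokens mask and pass it explicitly.
encoding = tokenizer(
    ["a photo of a cat", "two dogs playing in the snow"],
    padding=True,
    return_special_tokens_mask=True,
    return_tensors="pt",
)
special_tokens_mask = encoding.pop("special_tokens_mask")

inputs, labels, masked = masker(
    encoding, tokenizer, special_tokens_mask=special_tokens_mask
)
# `inputs` has ~15% of non-special tokens replaced (mostly with [MASK]);
# `labels` keeps the original ids; `masked` is the boolean mask of positions.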
__call__
__call__(inputs, tokenizer, special_tokens_mask=None)

Generate a random token mask for masked language modeling.

Parameters:

Name Type Description Default
inputs Tensor

The encoded inputs.

required
tokenizer PreTrainedTokenizerBase

The tokenizer.

required
special_tokens_mask Optional[Tensor]

Mask for special tokens.

None

Returns:

Type Description
tuple[Tensor, Tensor, Tensor]

The masked inputs, the corresponding labels and the boolean mask of the masked token positions.

Source code in mmlearn/datasets/processors/masking.py
def __call__(
    self,
    inputs: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
    special_tokens_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    inputs : torch.Tensor
        The encoded inputs.
    tokenizer : PreTrainedTokenizer
        The tokenizer.
    special_tokens_mask : Optional[torch.Tensor], default=None
        Mask for special tokens.
    """
    inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
    labels = inputs.clone()
    # We sample a few tokens in each sequence for MLM training
    # (with probability `self.probability`)
    probability_matrix = torch.full(labels.shape, self.probability)
    if special_tokens_mask is None:
        special_tokens_mask = tokenizer.get_special_tokens_mask(
            labels, already_has_special_tokens=True
        )
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
    else:
        special_tokens_mask = special_tokens_mask.bool()

    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = tokenizer.pad_token_id
    # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
    indices_replaced = (
        torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    )
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # Rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels, masked_indices

HFTokenizer

A wrapper for loading HuggingFace tokenizers.

This class wraps any huggingface tokenizer that can be initialized with :py:meth:transformers.AutoTokenizer.from_pretrained. It preprocesses the input text and returns a dictionary with the tokenized text and other relevant information like attention masks.

Parameters:

Name Type Description Default
model_name_or_path str

Pretrained model name or path - same as in :py:meth:transformers.AutoTokenizer.from_pretrained.

required
max_length Optional[int]

Maximum length of the tokenized sequence. This is passed to the tokenizer :meth:__call__ method.

None
padding bool or str

Padding strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

False
truncation Optional[Union[bool, str]]

Truncation strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

None
**kwargs Any

Additional arguments passed to :py:meth:transformers.AutoTokenizer.from_pretrained.

{}
Source code in mmlearn/datasets/processors/tokenizers.py
@store(group="datasets/tokenizers", provider="mmlearn")
class HFTokenizer:
    """A wrapper for loading HuggingFace tokenizers.

    This class wraps any huggingface tokenizer that can be initialized with
    :py:meth:`transformers.AutoTokenizer.from_pretrained`. It preprocesses the
    input text and returns a dictionary with the tokenized text and other
    relevant information like attention masks.

    Parameters
    ----------
    model_name_or_path : str
        Pretrained model name or path - same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    max_length : Optional[int], optional, default=None
        Maximum length of the tokenized sequence. This is passed to the tokenizer
        :meth:`__call__` method.
    padding : bool or str, default=False
        Padding strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    truncation : Optional[Union[bool, str]], optional, default=None
        Truncation strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    **kwargs : Any
        Additional arguments passed to :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    """  # noqa: W505

    def __init__(
        self,
        model_name_or_path: str,
        max_length: Optional[int] = None,
        padding: Union[bool, str] = False,
        truncation: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def __call__(
        self, sentence: Union[str, list[str]], **kwargs: Any
    ) -> dict[str, torch.Tensor]:
        """Tokenize a text or a list of texts using the HuggingFace tokenizer.

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be tokenized.
        **kwargs : Any
            Additional arguments passed to the tokenizer :meth:`__call__` method.

        Returns
        -------
        dict[str, torch.Tensor]
            Tokenized sentence(s).

        Notes
        -----
        The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
        """
        batch_encoding = self.tokenizer(
            sentence,
            max_length=self.max_length,
            padding=self.padding,
            truncation=self.truncation,
            return_tensors="pt",
            **kwargs,
        )

        if isinstance(
            sentence, str
        ):  # remove batch dimension if input is a single sentence
            for key, value in batch_encoding.items():
                if isinstance(value, torch.Tensor):
                    batch_encoding[key] = torch.squeeze(value, 0)

        # use 'Modalities.TEXT' key for input_ids for consistency
        batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
        return dict(batch_encoding)
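A brief usage sketch; the CLIP checkpoint and the CLIP-style context length of 77 are illustrative assumptions.

from mmlearn.datasets.processors.tokenizers import HFTokenizer

tokenize = HFTokenizer(
    "openai/clip-vit-base-patch16",
    max_length=77,
    padding="max_length",
    truncation=True,
)

out = tokenize("a photo of a cat")
print(sorted(out.keys()))       # includes 'input_ids', 'attention_mask' and the
                                # Modalities.TEXT key read by downstream encoders
print(out["input_ids"].shape)   # torch.Size([77]); batch dim squeezed for a str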
__call__
__call__(sentence, **kwargs)

Tokenize a text or a list of texts using the HuggingFace tokenizer.

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be tokenized.

required
**kwargs Any

Additional arguments passed to the tokenizer :meth:__call__ method.

{}

Returns:

Type Description
dict[str, Tensor]

Tokenized sentence(s).

Notes

The input_ids key is replaced with Modalities.TEXT for consistency.

Source code in mmlearn/datasets/processors/tokenizers.py
def __call__(
    self, sentence: Union[str, list[str]], **kwargs: Any
) -> dict[str, torch.Tensor]:
    """Tokenize a text or a list of texts using the HuggingFace tokenizer.

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be tokenized.
    **kwargs : Any
        Additional arguments passed to the tokenizer :meth:`__call__` method.

    Returns
    -------
    dict[str, torch.Tensor]
        Tokenized sentence(s).

    Notes
    -----
    The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
    """
    batch_encoding = self.tokenizer(
        sentence,
        max_length=self.max_length,
        padding=self.padding,
        truncation=self.truncation,
        return_tensors="pt",
        **kwargs,
    )

    if isinstance(
        sentence, str
    ):  # remove batch dimension if input is a single sentence
        for key, value in batch_encoding.items():
            if isinstance(value, torch.Tensor):
                batch_encoding[key] = torch.squeeze(value, 0)

    # use 'Modalities.TEXT' key for input_ids for consistency
    batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
    return dict(batch_encoding)

TrimText

Trim text strings as a preprocessing step before tokenization.

Parameters:

Name Type Description Default
trim_size int

The maximum length of the trimmed text.

required
Source code in mmlearn/datasets/processors/transforms.py
@store(group="datasets/transforms", provider="mmlearn")
class TrimText:
    """Trim text strings as a preprocessing step before tokenization.

    Parameters
    ----------
    trim_size : int
        The maximum length of the trimmed text.
    """

    def __init__(self, trim_size: int) -> None:
        self.trim_size = trim_size

    def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
        """Trim the given sentence(s).

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be trimmed.

        Returns
        -------
        Union[str, list[str]]
            Trimmed sentence(s).

        Raises
        ------
        TypeError
            If the input sentence is not a string or list of strings.
        """
        if not isinstance(sentence, (list, str)):
            raise TypeError(
                "Expected argument `sentence` to be a string or list of strings, "
                f"but got {type(sentence)}"
            )

        if isinstance(sentence, str):
            return sentence[: self.trim_size]

        for i, s in enumerate(sentence):
            sentence[i] = s[: self.trim_size]

        return sentence
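A tiny usage sketch; the trim size is chosen arbitrarily.

from mmlearn.datasets.processors.transforms import TrimText

trim = TrimText(trim_size=16)
print(trim("a very long caption that will be cut"))  # 'a very long capt'
print(trim(["short", "another long caption here"]))  # each element trimmed in place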
__call__
__call__(sentence)

Trim the given sentence(s).

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be trimmed.

required

Returns:

Type Description
Union[str, list[str]]

Trimmed sentence(s).

Raises:

Type Description
TypeError

If the input sentence is not a string or list of strings.

Source code in mmlearn/datasets/processors/transforms.py
def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
    """Trim the given sentence(s).

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be trimmed.

    Returns
    -------
    Union[str, list[str]]
        Trimmed sentence(s).

    Raises
    ------
    TypeError
        If the input sentence is not a string or list of strings.
    """
    if not isinstance(sentence, (list, str)):
        raise TypeError(
            "Expected argument `sentence` to be a string or list of strings, "
            f"but got {type(sentence)}"
        )

    if isinstance(sentence, str):
        return sentence[: self.trim_size]

    for i, s in enumerate(sentence):
        sentence[i] = s[: self.trim_size]

    return sentence

HFCLIPTextEncoder

Bases: Module

Wrapper around the CLIPTextModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft <https://huggingface.co/docs/peft/index>_ library to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",  # required for `peft_config` to be converted to a `PeftConfig` object
)
class HFCLIPTextEncoder(nn.Module):
    """Wrapper around the ``CLIPTextModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden
            states, and the attention weights, if ``output_attentions`` is set
            to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get("attention_mask")
            or inputs.get(Modalities.TEXT.attention_mask),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
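A minimal end-to-end sketch; the checkpoint name matches the registered default, while passing output_hidden_states=True through model_config_kwargs is an assumption made here because forward reads outputs.hidden_states. The batch is passed exactly as returned by HFTokenizer.

import torch

from mmlearn.datasets.processors.tokenizers import HFTokenizer
from mmlearn.modules.encoders.clip import HFCLIPTextEncoder

# Hidden states are required because `forward` reads `outputs.hidden_states[-1]`.
encoder = HFCLIPTextEncoder(
    "openai/clip-vit-base-patch16",
    model_config_kwargs={"output_hidden_states": True},
)
tokenize = HFTokenizer(
    "openai/clip-vit-base-patch16",
    max_length=77,
    padding="max_length",
    truncation=True,
)

inputs = tokenize(["a photo of a cat", "a photo of a dog"])
with torch.no_grad():
    output = encoder(inputs)
print(output.last_hidden_state.shape)  # (batch_size, 77, hidden_size)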
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden
        states, and the attention weights, if ``output_attentions`` is set
        to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get("attention_mask")
        or inputs.get(Modalities.TEXT.attention_mask),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

HFCLIPTextEncoderWithProjection

Bases: Module

Wrapper around the CLIPTextModelWithProjection from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the text. If False, the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft <https://huggingface.co/docs/peft/index>_ library to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPTextEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPTextModelWithProjection`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the text. If ``False``, the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPTextConfig,
        )

        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The text embeddings. Will be a tuple with a single element.
        """
        input_ids = inputs[Modalities.TEXT.name]
        attention_mask: Optional[torch.Tensor] = inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        )
        position_ids = inputs.get("position_ids")

        if self.use_all_token_embeddings:
            text_outputs = self.model.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            )
            # TODO: add more options for pooling before projection
            text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
        else:
            text_embeds = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            ).text_embeds

        return (text_embeds,)
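A short sketch analogous to the one above; the checkpoint name matches the registered default, and the projection dimension noted in the comment is specific to that checkpoint.

import torch

from mmlearn.datasets.processors.tokenizers import HFTokenizer
from mmlearn.modules.encoders.clip import HFCLIPTextEncoderWithProjection

encoder = HFCLIPTextEncoderWithProjection("openai/clip-vit-base-patch16")
tokenize = HFTokenizer(
    "openai/clip-vit-base-patch16",
    max_length=77,
    padding="max_length",
    truncation=True,
)

with torch.no_grad():
    (text_embeds,) = encoder(tokenize(["a photo of a cat"]))
print(text_embeds.shape)  # (1, projection_dim); 512 for this checkpoint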
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
tuple[Tensor]

The text embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The text embeddings. Will be a tuple with a single element.
    """
    input_ids = inputs[Modalities.TEXT.name]
    attention_mask: Optional[torch.Tensor] = inputs.get(
        "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
    )
    position_ids = inputs.get("position_ids")

    if self.use_all_token_embeddings:
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        # TODO: add more options for pooling before projection
        text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
    else:
        text_embeds = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        ).text_embeds

    return (text_embeds,)
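
A minimal usage sketch for this encoder is shown below. The import paths are inferred from the source file locations shown on this page and the input keys from the forward pass; treat the paths, the checkpoint name, and the printed shapes as assumptions rather than guaranteed public API.

# Hedged sketch: assumes HFCLIPTextEncoderWithProjection and Modalities are importable
# from the paths below and that the "openai/clip-vit-base-patch16" checkpoint is available.
from transformers import AutoTokenizer

from mmlearn.datasets.core.modalities import Modalities
from mmlearn.modules.encoders.clip import HFCLIPTextEncoderWithProjection

encoder = HFCLIPTextEncoderWithProjection("openai/clip-vit-base-patch16")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")

batch = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")
inputs = {
    Modalities.TEXT.name: batch["input_ids"],
    "attention_mask": batch["attention_mask"],
}
(text_embeds,) = encoder(inputs)  # a single-element tuple; shape (1, projection_dim)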

HFCLIPVisionEncoder

Bases: Module

Wrapper around the CLIPVisionModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias Optional[float]

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoder(nn.Module):
    """Wrapper around the ``CLIPVisionModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : Optional[float], optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model.vision_model
        self.pooling_layer = pooling_layer
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.

        """
        # FIXME: handle other vision modalities
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.encoder(
            inputs_embeds=hidden_states,
            output_attentions=inputs.get(
                "output_attentions", self.model.config.output_attentions
            ),
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = encoder_outputs[0]
        if self.pooling_layer is not None:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.

    """
    # FIXME: handle other vision modalities
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.encoder(
        inputs_embeds=hidden_states,
        output_attentions=inputs.get(
            "output_attentions", self.model.config.output_attentions
        ),
        output_hidden_states=True,
        return_dict=True,
    )

    last_hidden_state = encoder_outputs[0]
    if self.pooling_layer is not None:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
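
The sketch below shows how the vision encoder is typically driven, including patch dropout during training. The import paths and the random tensor standing in for a preprocessed image batch are assumptions for illustration only.

# Hedged sketch: import paths assumed from the source file shown above.
import torch

from mmlearn.datasets.core.modalities import Modalities
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoder

encoder = HFCLIPVisionEncoder(
    "openai/clip-vit-base-patch16",
    patch_dropout_rate=0.25,  # drop 25% of patch tokens while training
)
pixel_values = torch.randn(2, 3, 224, 224)  # stand-in for a preprocessed image batch
outputs = encoder({Modalities.RGB.name: pixel_values})
print(outputs.last_hidden_state.shape)  # (2, num_kept_tokens, hidden_size)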

HFCLIPVisionEncoderWithProjection

Bases: Module

Wrapper around the CLIPVisionModelWithProjection class from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the image. If False, the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias float

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs dict[str, Any]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPVisionModelWithProjection`` class from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the image. If ``False``, the first
        token embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : float, optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : dict[str, Any], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPVisionConfig,
        )

        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The image embeddings. Will be a tuple with a single element.
        """
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.vision_model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.vision_model.encoder(
            inputs_embeds=hidden_states, return_dict=True
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        if self.use_all_token_embeddings:
            pooled_output = last_hidden_state
        else:
            pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.model.vision_model.post_layernorm(pooled_output)

        return (self.model.visual_projection(pooled_output),)
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
tuple[Tensor]

The image embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The image embeddings. Will be a tuple with a single element.
    """
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.vision_model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.vision_model.encoder(
        inputs_embeds=hidden_states, return_dict=True
    )

    last_hidden_state = encoder_outputs.last_hidden_state
    if self.use_all_token_embeddings:
        pooled_output = last_hidden_state
    else:
        pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.model.vision_model.post_layernorm(pooled_output)

    return (self.model.visual_projection(pooled_output),)
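
A short usage sketch follows, under the same import-path assumptions as the previous examples; the random input tensor is a placeholder for a preprocessed image batch.

# Hedged sketch: import paths and checkpoint name are assumptions.
import torch

from mmlearn.datasets.core.modalities import Modalities
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoderWithProjection

encoder = HFCLIPVisionEncoderWithProjection("openai/clip-vit-base-patch16").eval()
with torch.no_grad():
    (image_embeds,) = encoder({Modalities.RGB.name: torch.randn(2, 3, 224, 224)})
print(image_embeds.shape)  # (2, projection_dim)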

HFTextEncoder

Bases: Module

Wrapper around huggingface models in the AutoModelForTextEncoding class.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Raises:

Type Description
ValueError

If the model is a decoder model or if freezing individual layers is not supported for the model type.

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/text.py
@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
class HFTextEncoder(nn.Module):
    """Wrapper around huggingface models in the ``AutoModelForTextEncoding`` class.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Raises
    ------
    ValueError
        If the model is a decoder model or if freezing individual layers is not
        supported for the model type.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.


    """

    def __init__(  # noqa: PLR0912
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ):
        super().__init__()
        if model_config_kwargs is None:
            model_config_kwargs = {}
        model_config_kwargs["output_hidden_states"] = True
        model_config_kwargs["add_pooling_layer"] = False
        model = hf_utils.load_huggingface_model(
            AutoModelForTextEncoding,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        if hasattr(model.config, "is_decoder") and model.config.is_decoder:
            raise ValueError("Model is a decoder. Only encoder models are supported.")

        if not pretrained and freeze_layers:
            rank_zero_warn(
                "Freezing layers when loading a model with random weights may lead to "
                "unexpected behavior. Consider setting `freeze_layers=False` if "
                "`pretrained=False`.",
            )

        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "LayerNorm" in name else False
                )

        if isinstance(
            freeze_layers, (float, int, list)
        ) and model.config.model_type in ["flaubert", "xlm"]:
            # flaubert and xlm models have a different architecture that does not
            # support freezing individual layers in the same way as other models
            raise ValueError(
                f"Freezing individual layers is not supported for {model.config.model_type} "
                "models. Please use `freeze_layers=False` or `freeze_layers=True`."
            )

        # get list of layers
        embeddings = model.embeddings
        encoder = getattr(model, "encoder", None) or getattr(
            model, "transformer", model
        )
        encoder_layers = (
            getattr(encoder, "layer", None)
            or getattr(encoder, "layers", None)
            or getattr(encoder, "block", None)
        )
        if encoder_layers is None and hasattr(encoder, "albert_layer_groups"):
            encoder_layers = [
                layer
                for group in encoder.albert_layer_groups
                for layer in group.albert_layers
            ]
        modules = [embeddings]
        if encoder_layers is not None and isinstance(encoder_layers, list):
            modules.extend(encoder_layers)

        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "LayerNorm" in name else False
                        )

        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get(
                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
            ),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/text.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        ),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
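
The sketch below feeds a tokenized batch through the encoder and freezes the first half of the layers. The import paths and the "bert-base-uncased" checkpoint are assumptions chosen for illustration; any encoder-only HuggingFace model should work the same way.

# Hedged sketch: import paths assumed from mmlearn/modules/encoders/text.py.
from transformers import AutoTokenizer

from mmlearn.datasets.core.modalities import Modalities
from mmlearn.modules.encoders.text import HFTextEncoder

encoder = HFTextEncoder("bert-base-uncased", freeze_layers=0.5)  # freeze first half of the layers
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

batch = tokenizer(["multimodal learning"], return_tensors="pt")
outputs = encoder(
    {
        Modalities.TEXT.name: batch["input_ids"],
        "attention_mask": batch["attention_mask"],
    }
)
print(outputs.last_hidden_state.shape)  # (1, seq_len, hidden_size)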

TimmViT

Bases: Module

Vision Transformer model from timm.

Parameters:

Name Type Description Default
model_name str

The name of the model to use.

required
modality str

The modality of the input data. This allows this model to be used with different image modalities e.g. RGB, Depth, etc.

"RGB"
projection_dim int

The dimension of the projection head.

768
pretrained bool

Whether to use the pretrained weights.

True
freeze_layers Union[int, float, list[int], bool]

Whether to freeze the layers.

False
freeze_layer_norm bool

Whether to freeze the layer norm.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_kwargs Optional[dict[str, Any]]

Additional keyword arguments for the model.

None
Source code in mmlearn/modules/encoders/vision.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name="vit_base_patch16_224",
    hydra_convert="object",
)
class TimmViT(nn.Module):
    """Vision Transformer model from timm.

    Parameters
    ----------
    model_name : str
        The name of the model to use.
    modality : str, default="RGB"
        The modality of the input data. This allows this model to be used with different
        image modalities e.g. RGB, Depth, etc.
    projection_dim : int, default=768
        The dimension of the projection head.
    pretrained : bool, default=True
        Whether to use the pretrained weights.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze the layers.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer norm.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_kwargs : Optional[dict[str, Any]], default=None
        Additional keyword arguments for the model.
    """

    def __init__(
        self,
        model_name: str,
        modality: str = "RGB",
        projection_dim: int = 768,
        pretrained: bool = True,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.modality = Modalities.get_modality(modality)
        if model_kwargs is None:
            model_kwargs = {}

        self.model: TimmVisionTransformer = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=projection_dim,
            **model_kwargs,
        )
        assert isinstance(self.model, TimmVisionTransformer), (
            f"Model {model_name} is not a Vision Transformer. "
            "Please provide a model name that corresponds to a Vision Transformer."
        )

        self._freeze_layers(freeze_layers, freeze_layer_norm)

        if peft_config is not None:
            self.model = hf_utils._wrap_peft_model(self.model, peft_config)

    def _freeze_layers(
        self, freeze_layers: Union[int, float, list[int], bool], freeze_layer_norm: bool
    ) -> None:
        """Freeze the layers of the model.

        Parameters
        ----------
        freeze_layers : Union[int, float, list[int], bool]
            Whether to freeze the layers.
        freeze_layer_norm : bool
            Whether to freeze the layer norm.
        """
        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in self.model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "norm" in name else False
                )

        modules = [self.model.patch_embed, *self.model.blocks, self.model.norm]
        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "norm" in name else False
                        )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model.
        """
        x = inputs[self.modality.name]
        last_hidden_state, hidden_states = self.model.forward_intermediates(
            x, output_fmt="NLC"
        )
        last_hidden_state = self.model.forward_head(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states
        )

    def get_intermediate_layers(
        self, inputs: dict[str, Any], n: int = 1
    ) -> list[torch.Tensor]:
        """Get the output of the intermediate layers.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the ``Modalities.RGB``
            key.
        n : int, default=1
            The number of intermediate layers to return.

        Returns
        -------
        list[torch.Tensor]
            The outputs of the last n intermediate layers.
        """
        return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore

    def get_patch_info(self) -> tuple[int, int]:
        """Get patch size and number of patches.

        Returns
        -------
        tuple[int, int]
            Patch size and number of patches.
        """
        patch_size = self.model.patch_embed.patch_size[0]
        num_patches = self.model.patch_embed.num_patches
        return patch_size, num_patches
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model.

Source code in mmlearn/modules/encoders/vision.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model.
    """
    x = inputs[self.modality.name]
    last_hidden_state, hidden_states = self.model.forward_intermediates(
        x, output_fmt="NLC"
    )
    last_hidden_state = self.model.forward_head(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state, hidden_states=hidden_states
    )
get_intermediate_layers
get_intermediate_layers(inputs, n=1)

Get the output of the intermediate layers.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required
n int

The number of intermediate layers to return.

1

Returns:

Type Description
list[Tensor]

The outputs of the last n intermediate layers.

Source code in mmlearn/modules/encoders/vision.py
def get_intermediate_layers(
    self, inputs: dict[str, Any], n: int = 1
) -> list[torch.Tensor]:
    """Get the output of the intermediate layers.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the ``Modalities.RGB``
        key.
    n : int, default=1
        The number of intermediate layers to return.

    Returns
    -------
    list[torch.Tensor]
        The outputs of the last n intermediate layers.
    """
    return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore
get_patch_info
get_patch_info()

Get patch size and number of patches.

Returns:

Type Description
tuple[int, int]

Patch size and number of patches.

Source code in mmlearn/modules/encoders/vision.py
def get_patch_info(self) -> tuple[int, int]:
    """Get patch size and number of patches.

    Returns
    -------
    tuple[int, int]
        Patch size and number of patches.
    """
    patch_size = self.model.patch_embed.patch_size[0]
    num_patches = self.model.patch_embed.num_patches
    return patch_size, num_patches
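
A minimal sketch of using the timm-backed encoder is shown below. The import path is assumed from mmlearn/modules/encoders/vision.py, and pretrained=False is used only so the example runs without downloading weights; the output shape reflects the projection head.

# Hedged sketch: import paths and output shapes are assumptions.
import torch

from mmlearn.datasets.core.modalities import Modalities
from mmlearn.modules.encoders.vision import TimmViT

model = TimmViT("vit_base_patch16_224", projection_dim=512, pretrained=False)
outputs = model({Modalities.RGB.name: torch.randn(1, 3, 224, 224)})
print(outputs.last_hidden_state.shape)  # (1, 512) after the projection head

patch_size, num_patches = model.get_patch_info()  # (16, 196) for a 224x224 input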

LearnableLogitScaling

Bases: Module

Logit scaling layer.

Parameters:

Name Type Description Default
init_logit_scale float

Initial value of the logit scale.

1/0.07
learnable bool

If True, the logit scale is learnable. Otherwise, it is fixed.

True
max_logit_scale float

Maximum value of the logit scale.

100
Source code in mmlearn/modules/layers/logit_scaling.py
@store(group="modules/layers", provider="mmlearn")
class LearnableLogitScaling(torch.nn.Module):
    """Logit scaling layer.

    Parameters
    ----------
    init_logit_scale : float, optional, default=1/0.07
        Initial value of the logit scale.
    learnable : bool, optional, default=True
        If True, the logit scale is learnable. Otherwise, it is fixed.
    max_logit_scale : float, optional, default=100
        Maximum value of the logit scale.
    """

    def __init__(
        self,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.init_logit_scale = init_logit_scale
        self.learnable = learnable
        log_logit_scale = torch.ones([]) * np.log(self.init_logit_scale)
        if learnable:
            self.log_logit_scale = torch.nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the logit scaling to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self) -> str:
        """Return the string representation of the layer."""
        return (
            f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
            f" max_logit_scale={self.max_logit_scale}"
        )
forward
forward(x)

Apply the logit scaling to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
Source code in mmlearn/modules/layers/logit_scaling.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the logit scaling to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
extra_repr
extra_repr()

Return the string representation of the layer.

Source code in mmlearn/modules/layers/logit_scaling.py
def extra_repr(self) -> str:
    """Return the string representation of the layer."""
    return (
        f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
        f" max_logit_scale={self.max_logit_scale}"
    )
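
A minimal sketch of the scaling behaviour, assuming the layer is importable from mmlearn.modules.layers.logit_scaling (inferred from the source path above):

import torch

from mmlearn.modules.layers.logit_scaling import LearnableLogitScaling

scaling = LearnableLogitScaling(init_logit_scale=1 / 0.07, learnable=True)
logits = torch.randn(4, 4)
scaled = scaling(logits)  # logits * exp(log_logit_scale), clipped at max_logit_scale
# log_logit_scale is a Parameter here, so it is updated by the optimizer during training.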

MLP

Bases: Sequential

Multi-layer perceptron (MLP).

This module will create a block of Linear -> Normalization -> Activation -> Dropout layers.

Parameters:

Name Type Description Default
in_dim int

The input dimension.

required
out_dim Optional[int]

The output dimension. If not specified, it is set to in_dim.

None
hidden_dims Optional[list]

The dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with hidden_dims_multiplier.

None
hidden_dims_multiplier Optional[list]

The multipliers to apply to the input dimension to get the dimensions of the hidden layers. The length of the list determines the number of hidden layers. The multipliers will be used to get the dimensions of the hidden layers. This parameter is mutually exclusive with hidden_dims.

None
apply_multiplier_to_in_dim bool

Whether to apply the hidden_dims_multiplier to in_dim to get the dimensions of the hidden layers. If False, the multipliers will be applied to the dimensions of the previous hidden layer, starting from in_dim. This parameter is only relevant when hidden_dims_multiplier is specified.

False
norm_layer Optional[Callable[..., Module]]

The normalization layer to use. If not specified, no normalization is used. Partial functions can be used to specify the normalization layer with specific parameters.

None
activation_layer Optional[Callable[..., Module]]

The activation layer to use. If not specified, ReLU is used. Partial functions can be used to specify the activation layer with specific parameters.

torch.nn.ReLU
bias Union[bool, list[bool]]

Whether to use bias in the linear layers.

True
dropout Union[float, list[float]]

The dropout probability to use.

0.0

Raises:

Type Description
ValueError

If both hidden_dims and hidden_dims_multiplier are specified, or if the lengths of bias and hidden_dims do not match, or if the lengths of dropout and hidden_dims do not match.

Source code in mmlearn/modules/layers/mlp.py
@store(group="modules/layers", provider="mmlearn")
class MLP(torch.nn.Sequential):
    """Multi-layer perceptron (MLP).

    This module will create a block of ``Linear -> Normalization -> Activation -> Dropout``
    layers.

    Parameters
    ----------
    in_dim : int
        The input dimension.
    out_dim : Optional[int], optional, default=None
        The output dimension. If not specified, it is set to :attr:`in_dim`.
    hidden_dims : Optional[list], optional, default=None
        The dimensions of the hidden layers. The length of the list determines the
        number of hidden layers. This parameter is mutually exclusive with
        :attr:`hidden_dims_multiplier`.
    hidden_dims_multiplier : Optional[list], optional, default=None
        The multipliers to apply to the input dimension to get the dimensions of
        the hidden layers. The length of the list determines the number of hidden
        layers. The multipliers will be used to get the dimensions of the hidden
        layers. This parameter is mutually exclusive with `hidden_dims`.
    apply_multiplier_to_in_dim : bool, optional, default=False
        Whether to apply the :attr:`hidden_dims_multiplier` to :attr:`in_dim` to get the
        dimensions of the hidden layers. If ``False``, the multipliers will be applied
        to the dimensions of the previous hidden layer, starting from :attr:`in_dim`.
        This parameter is only relevant when :attr:`hidden_dims_multiplier` is
        specified.
    norm_layer : Optional[Callable[..., torch.nn.Module]], optional, default=None
        The normalization layer to use. If not specified, no normalization is used.
        Partial functions can be used to specify the normalization layer with specific
        parameters.
    activation_layer : Optional[Callable[..., torch.nn.Module]], optional, default=torch.nn.ReLU
        The activation layer to use. If not specified, ReLU is used. Partial functions
        can be used to specify the activation layer with specific parameters.
    bias : Union[bool, list[bool]], optional, default=True
        Whether to use bias in the linear layers.
    dropout : Union[float, list[float]], optional, default=0.0
        The dropout probability to use.

    Raises
    ------
    ValueError
        If both :attr:`hidden_dims` and :attr:`hidden_dims_multiplier` are specified
        or if the lengths of :attr:`bias` and :attr:`hidden_dims` do not match or if
        the lengths of :attr:`dropout` and :attr:`hidden_dims` do not match.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        hidden_dims: Optional[list[int]] = None,
        hidden_dims_multiplier: Optional[list[float]] = None,
        apply_multiplier_to_in_dim: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        bias: Union[bool, list[bool]] = True,
        dropout: Union[float, list[float]] = 0.0,
    ) -> None:
        if hidden_dims is None and hidden_dims_multiplier is None:
            hidden_dims = []
        if hidden_dims is not None and hidden_dims_multiplier is not None:
            raise ValueError(
                "Only one of `hidden_dims` or `hidden_dims_multiplier` must be specified."
            )

        if hidden_dims is None and hidden_dims_multiplier is not None:
            if apply_multiplier_to_in_dim:
                hidden_dims = [
                    int(in_dim * multiplier) for multiplier in hidden_dims_multiplier
                ]
            else:
                hidden_dims = [int(in_dim * hidden_dims_multiplier[0])]
                for multiplier in hidden_dims_multiplier[1:]:
                    hidden_dims.append(int(hidden_dims[-1] * multiplier))

        if isinstance(bias, bool):
            bias_list: list[bool] = [bias] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            bias_list = bias
        if len(bias_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `bias` to be a boolean or a list of booleans with length "
                "equal to the number of linear layers in the MLP."
            )

        if isinstance(dropout, float):
            dropout_list: list[float] = [dropout] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            dropout_list = dropout
        if len(dropout_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `dropout` to be a float or a list of floats with length "
                "equal to the number of linear layers in the MLP."
            )

        # construct list of dimensions for the layers
        dims = [in_dim] + hidden_dims  # type: ignore[operator]
        layers = []
        for layer_idx, (in_features, hidden_features) in enumerate(
            zip(dims[:-1], dims[1:], strict=False)
        ):
            layers.append(
                torch.nn.Linear(in_features, hidden_features, bias=bias_list[layer_idx])
            )
            if norm_layer is not None:
                layers.append(norm_layer(hidden_features))
            if activation_layer is not None:
                layers.append(activation_layer())
            layers.append(torch.nn.Dropout(dropout_list[layer_idx]))

        out_dim = out_dim or in_dim

        layers.append(torch.nn.Linear(dims[-1], out_dim, bias=bias_list[-1]))
        layers.append(torch.nn.Dropout(dropout_list[-1]))

        super().__init__(*layers)
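
A minimal sketch of building a projection head with this module, assuming it is importable from mmlearn.modules.layers.mlp (inferred from the source path above):

import torch

from mmlearn.modules.layers.mlp import MLP

projector = MLP(
    in_dim=768,
    out_dim=256,
    hidden_dims=[512, 384],          # two hidden layers
    norm_layer=torch.nn.LayerNorm,   # applied after each hidden Linear
    dropout=0.1,
)
features = torch.randn(8, 768)
print(projector(features).shape)  # (8, 256)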

L2Norm

Bases: Module

L2 normalization.

Parameters:

Name Type Description Default
dim int

The dimension along which to normalize.

required
Source code in mmlearn/modules/layers/normalization.py
@store(group="modules/layers", provider="mmlearn")
class L2Norm(torch.nn.Module):
    """L2 normalization.

    Parameters
    ----------
    dim : int
        The dimension along which to normalize.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply L2 normalization to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)
forward
forward(x)

Apply L2 normalization to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Normalized tensor of shape (batch_sz, seq_len, dim).

Source code in mmlearn/modules/layers/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply L2 normalization to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.nn.functional.normalize(x, dim=self.dim, p=2)
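
A minimal sketch, assuming the layer is importable from mmlearn.modules.layers.normalization (inferred from the source path above):

import torch

from mmlearn.modules.layers.normalization import L2Norm

l2_norm = L2Norm(dim=-1)
x = torch.randn(2, 5, 768)
y = l2_norm(x)
print(y.norm(p=2, dim=-1))  # all values are (approximately) 1.0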

PatchDropout

Bases: Module

Patch dropout layer.

Drops patch tokens (after embedding and adding the CLS token) from the input tensor. Usually used in vision transformers to reduce the number of tokens [1].

Parameters:

Name Type Description Default
keep_rate float

The proportion of tokens to keep.

0.5
bias Optional[float]

The bias to add to the random noise before sorting.

None
token_shuffling bool

If True, the tokens are shuffled.

False
References

[1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023). PatchDropout: Economizing vision transformers using patch dropout. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 3953-3962).

Source code in mmlearn/modules/layers/patch_dropout.py
class PatchDropout(torch.nn.Module):
    """Patch dropout layer.

    Drops patch tokens (after embedding and adding CLS token) from the input tensor.
    Usually used in vision transformers to reduce the number of tokens. [1]_

    Parameters
    ----------
    keep_rate : float, optional, default=0.5
        The proportion of tokens to keep.
    bias : Optional[float], optional, default=None
        The bias to add to the random noise before sorting.
    token_shuffling : bool, optional, default=False
        If True, the tokens are shuffled.

    References
    ----------
    .. [1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023).
       Patchdropout: Economizing vision transformers using patch dropout. In Proceedings
       of the IEEE/CVF Winter Conference on Applications of Computer Vision
       (pp. 3953-3962).
    """

    def __init__(
        self,
        keep_rate: float = 0.5,
        bias: Optional[float] = None,
        token_shuffling: bool = False,
    ):
        super().__init__()
        assert 0 < keep_rate <= 1, "The keep_rate must be in (0,1]"

        self.bias = bias
        self.keep_rate = keep_rate
        self.token_shuffling = token_shuffling

    def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
        """Drop tokens from the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        force_drop : bool, optional, default=False
            If True, the tokens are always dropped, even when the model is in
            evaluation mode.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
        """
        if (not self.training and not force_drop) or self.keep_rate == 1:
            return x

        batch_sz, _, dim = x.shape

        cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
            batch_sz, 1, dtype=torch.int64, device=x.device
        )
        patch_mask = self.uniform_mask(x)
        patch_mask = torch.hstack([cls_mask, patch_mask])

        return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))

    def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Generate token ids to keep from uniform random distribution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

        """
        batch_sz, seq_len, _ = x.shape
        seq_len = seq_len - 1  # patch length (without CLS)

        keep_len = int(seq_len * self.keep_rate)
        noise = torch.rand(batch_sz, seq_len, device=x.device)
        if self.bias is not None:
            noise += self.bias
        ids = torch.argsort(noise, dim=1)
        keep_ids = ids[:, :keep_len]
        if not self.token_shuffling:
            keep_ids = keep_ids.sort(1)[0]
        return keep_ids
forward
forward(x, force_drop=False)

Drop tokens from the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
force_drop bool

If True, the tokens are always dropped, even when the model is in evaluation mode.

False

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len, dim) containing the kept tokens.

Source code in mmlearn/modules/layers/patch_dropout.py
def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
    """Drop tokens from the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    force_drop : bool, optional, default=False
        If True, the tokens are always dropped, even when the model is in
        evaluation mode.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
    """
    if (not self.training and not force_drop) or self.keep_rate == 1:
        return x

    batch_sz, _, dim = x.shape

    cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
        batch_sz, 1, dtype=torch.int64, device=x.device
    )
    patch_mask = self.uniform_mask(x)
    patch_mask = torch.hstack([cls_mask, patch_mask])

    return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))
uniform_mask
uniform_mask(x)

Generate token ids to keep from uniform random distribution.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len) containing the token ids to keep.

Source code in mmlearn/modules/layers/patch_dropout.py
def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
    """Generate token ids to keep from uniform random distribution.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

    """
    batch_sz, seq_len, _ = x.shape
    seq_len = seq_len - 1  # patch length (without CLS)

    keep_len = int(seq_len * self.keep_rate)
    noise = torch.rand(batch_sz, seq_len, device=x.device)
    if self.bias is not None:
        noise += self.bias
    ids = torch.argsort(noise, dim=1)
    keep_ids = ids[:, :keep_len]
    if not self.token_shuffling:
        keep_ids = keep_ids.sort(1)[0]
    return keep_ids
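
A minimal sketch of the dropout behaviour, assuming the layer is importable from mmlearn.modules.layers.patch_dropout (inferred from the source path above); force_drop=True is used so the example also drops tokens outside of training mode:

import torch

from mmlearn.modules.layers.patch_dropout import PatchDropout

patch_dropout = PatchDropout(keep_rate=0.5)
tokens = torch.randn(2, 197, 768)  # CLS token + 196 patch tokens (e.g. ViT-B/16 at 224x224)
kept = patch_dropout(tokens, force_drop=True)
print(kept.shape)  # (2, 99, 768): the CLS token plus int(196 * 0.5) = 98 patch tokens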

ContrastiveLoss

Bases: Module

Contrastive Loss.

Parameters:

Name Type Description Default
l2_normalize bool

Whether to L2 normalize the features.

False
local_loss bool

Whether to calculate the loss locally i.e. local_features@global_features.

False
gather_with_grad bool

Whether to gather tensors with gradients.

False
modality_alignment bool

Whether to include modality alignment loss. This loss considers all features from the same modality as positive pairs and all features from different modalities as negative pairs.

False
cache_labels bool

Whether to cache the labels.

False
Source code in mmlearn/modules/losses/contrastive.py
@store(group="modules/losses", provider="mmlearn")
class ContrastiveLoss(nn.Module):
    """Contrastive Loss.

    Parameters
    ----------
    l2_normalize : bool, optional, default=False
        Whether to L2 normalize the features.
    local_loss : bool, optional, default=False
        Whether to calculate the loss locally i.e. ``local_features@global_features``.
    gather_with_grad : bool, optional, default=False
        Whether to gather tensors with gradients.
    modality_alignment : bool, optional, default=False
        Whether to include modality alignment loss. This loss considers all features
        from the same modality as positive pairs and all features from different
        modalities as negative pairs.
    cache_labels : bool, optional, default=False
        Whether to cache the labels.

    """

    def __init__(
        self,
        l2_normalize: bool = False,
        local_loss: bool = False,
        gather_with_grad: bool = False,
        modality_alignment: bool = False,
        cache_labels: bool = False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.l2_normalize = l2_normalize
        self.modality_alignment = modality_alignment

        # cache state
        self._prev_num_logits = 0
        self._labels: dict[torch.device, torch.Tensor] = {}

    def forward(
        self,
        embeddings: dict[str, torch.Tensor],
        example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        modality_loss_pairs: list[LossPairSpec],
    ) -> torch.Tensor:
        """Calculate the contrastive loss.

        Parameters
        ----------
        embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        modality_loss_pairs : list[LossPairSpec]
            Specification of the modality pairs for which the loss should be calculated.

        Returns
        -------
        torch.Tensor
            The contrastive loss.
        """
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if world_size > 1 else 0

        if self.l2_normalize:
            embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

        if world_size > 1:  # gather embeddings and example_ids across all processes
            # NOTE: gathering dictionaries of tensors across all processes
            # (keys + values, as opposed to just values) is especially important
            # for the modality_alignment loss, which requires all embeddings
            all_embeddings = _gather_dicts(
                embeddings,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
            all_example_ids = _gather_dicts(
                example_ids,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
        else:
            all_embeddings = embeddings
            all_example_ids = example_ids

        losses = []
        for loss_pairs in modality_loss_pairs:
            logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
                loss_pairs.modalities,
                per_device_embeddings=embeddings,
                all_embeddings=all_embeddings,
                per_device_example_ids=example_ids,
                all_example_ids=all_example_ids,
                logit_scale=logit_scale,
                world_size=world_size,
            )
            if logits_per_feature_a is None or logits_per_feature_b is None:
                continue

            labels = self._get_ground_truth(
                logits_per_feature_a.shape,
                device=logits_per_feature_a.device,
                rank=rank,
                world_size=world_size,
                skipped_process=skip_flag,
            )

            if labels.numel() != 0:
                losses.append(
                    (
                        (
                            F.cross_entropy(logits_per_feature_a, labels)
                            + F.cross_entropy(logits_per_feature_b, labels)
                        )
                        / 2
                    )
                    * loss_pairs.weight
                )

        if self.modality_alignment:
            losses.append(
                self._compute_modality_alignment_loss(all_embeddings, logit_scale)
            )

        if not losses:  # no loss to compute (e.g. no paired data in batch)
            losses.append(
                torch.tensor(
                    0.0,
                    device=logit_scale.device,
                    dtype=next(iter(embeddings.values())).dtype,
                )
            )

        return torch.stack(losses).sum()

    def _get_ground_truth(
        self,
        logits_shape: tuple[int, int],
        device: torch.device,
        rank: int,
        world_size: int,
        skipped_process: bool,
    ) -> torch.Tensor:
        """Return the ground-truth labels.

        Parameters
        ----------
        logits_shape : tuple[int, int]
            Shape of the logits tensor.
        device : torch.device
            Device on which the labels should be created.
        rank : int
            Rank of the current process.
        world_size : int
            Number of processes.
        skipped_process : bool
            Whether the current process skipped the computation due to lack of data.

        Returns
        -------
        torch.Tensor
            Ground-truth labels.
        """
        num_logits = logits_shape[-1]

        # calculate ground-truth and cache if enabled
        if self._prev_num_logits != num_logits or device not in self._labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)

            if world_size > 1 and self.local_loss:
                local_size = torch.tensor(
                    0 if skipped_process else logits_shape[0], device=device
                )
                # NOTE: all processes must participate in the all_gather operation
                # even if they have no data to contribute.
                sizes = torch.stack(
                    _simple_gather_all_tensors(
                        local_size, group=dist.group.WORLD, world_size=world_size
                    )
                )
                sizes = torch.cat(
                    [torch.tensor([0], device=sizes.device), torch.cumsum(sizes, dim=0)]
                )
                labels = labels[
                    sizes[rank] : sizes[rank + 1] if rank + 1 < world_size else None
                ]

            if self.cache_labels:
                self._labels[device] = labels
                self._prev_num_logits = num_logits
        else:
            labels = self._labels[device]
        return labels

    def _get_logits(  # noqa: PLR0912
        self,
        modalities: tuple[str, str],
        per_device_embeddings: dict[str, torch.Tensor],
        all_embeddings: dict[str, torch.Tensor],
        per_device_example_ids: dict[str, torch.Tensor],
        all_example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        world_size: int,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
        """Calculate the logits for the given modalities.

        Parameters
        ----------
        modalities : tuple[str, str]
            Tuple of modality names.
        per_device_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor. In distributed mode, this contains
            embeddings from all processes.
        per_device_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        all_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index. In distributed
            mode, this contains example IDs from all processes.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        world_size : int
            Number of processes.

        Returns
        -------
        tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]
            Tuple of logits for the given modalities. If embeddings for the given
            modalities are not available, returns `None` for the logits. The last
            element is a flag indicating whether the process skipped the computation
            due to lack of data.
        """
        modality_a = Modalities.get_modality(modalities[0])
        modality_b = Modalities.get_modality(modalities[1])
        skip_flag = False

        if self.local_loss or world_size == 1:
            if not (
                modality_a.embedding in per_device_embeddings
                and modality_b.embedding in per_device_embeddings
            ):
                if world_size > 1:  # NOTE: not all processes exit here, hence skip_flag
                    skip_flag = True
                else:
                    return None, None, skip_flag

            if not skip_flag:
                indices_a, indices_b = find_matching_indices(
                    per_device_example_ids[modality_a.name],
                    per_device_example_ids[modality_b.name],
                )
                if indices_a.numel() == 0 or indices_b.numel() == 0:
                    if world_size > 1:  # not all processes exit here
                        skip_flag = True
                    else:
                        return None, None, skip_flag

            if not skip_flag:
                features_a = per_device_embeddings[modality_a.embedding][indices_a]
                features_b = per_device_embeddings[modality_b.embedding][indices_b]
            else:
                # all processes must participate in the all_gather operation
                # that follows, even if they have no data to contribute. So,
                # we create empty tensors here.
                features_a = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )
                features_b = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )

        if world_size > 1:
            if not (
                modality_a.embedding in all_embeddings
                and modality_b.embedding in all_embeddings
            ):  # all processes exit here
                return None, None, skip_flag

            indices_a, indices_b = find_matching_indices(
                all_example_ids[modality_a.name],
                all_example_ids[modality_b.name],
            )
            if indices_a.numel() == 0 or indices_b.numel() == 0:
                # all processes exit here
                return None, None, skip_flag

            all_features_a = all_embeddings[modality_a.embedding][indices_a]
            all_features_b = all_embeddings[modality_b.embedding][indices_b]

            if self.local_loss:
                if features_a.numel() == 0:
                    features_a = all_features_a
                if features_b.numel() == 0:
                    features_b = all_features_b

                logits_per_feature_a = logit_scale * _safe_matmul(
                    features_a, all_features_b
                )
                logits_per_feature_b = logit_scale * _safe_matmul(
                    features_b, all_features_a
                )
            else:
                logits_per_feature_a = logit_scale * _safe_matmul(
                    all_features_a, all_features_b
                )
                logits_per_feature_b = logits_per_feature_a.T
        else:
            logits_per_feature_a = logit_scale * _safe_matmul(features_a, features_b)
            logits_per_feature_b = logit_scale * _safe_matmul(features_b, features_a)

        return logits_per_feature_a, logits_per_feature_b, skip_flag

    def _compute_modality_alignment_loss(
        self, all_embeddings: dict[str, torch.Tensor], logit_scale: torch.Tensor
    ) -> torch.Tensor:
        """Compute the modality alignment loss.

        This loss considers all features from the same modality as positive pairs
        and all features from different modalities as negative pairs.

        Parameters
        ----------
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        logit_scale : torch.Tensor
            Scale factor for the logits.

        Returns
        -------
        torch.Tensor
            Modality alignment loss.

        Notes
        -----
        This loss does not support `local_loss=True`.
        """
        available_modalities = list(all_embeddings.keys())
        # TODO: support local_loss for modality_alignment?
        # if world_size == 1, all_embeddings == embeddings
        all_features = torch.cat(list(all_embeddings.values()), dim=0)

        positive_indices = torch.tensor(
            [
                (i, j)
                if idx == 0
                else (
                    i + all_embeddings[available_modalities[idx - 1]].size(0),
                    j + all_embeddings[available_modalities[idx - 1]].size(0),
                )
                for idx, k in enumerate(all_embeddings)
                for i, j in itertools.combinations(range(all_embeddings[k].size(0)), 2)
            ],
            device=all_features.device,
        )
        logits = logit_scale * _safe_matmul(all_features, all_features)

        target = torch.eye(all_features.size(0), device=all_features.device)
        target[positive_indices[:, 0], positive_indices[:, 1]] = 1

        modality_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, target, reduction="none"
        )

        target_pos = target.bool()
        target_neg = ~target_pos

        # loss_pos and loss_neg below contain non-zero values only for those
        # elements that are positive pairs and negative pairs respectively.
        loss_pos = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_pos, modality_loss[target_pos])
        loss_neg = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_neg, modality_loss[target_neg])

        loss_pos = loss_pos.sum(dim=1)
        loss_neg = loss_neg.sum(dim=1)
        num_pos = target.sum(dim=1)
        num_neg = logits.size(0) - num_pos

        return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()
forward
forward(
    embeddings,
    example_ids,
    logit_scale,
    modality_loss_pairs,
)

Calculate the contrastive loss.

Parameters:

Name Type Description Default
embeddings dict[str, Tensor]

Dictionary of embeddings, where the key is the modality name and the value is the corresponding embedding tensor.

required
example_ids dict[str, Tensor]

Dictionary of example IDs, where the key is the modality name and the value is a tensor tuple of the dataset index and the example index.

required
logit_scale Tensor

Scale factor for the logits.

required
modality_loss_pairs list[LossPairSpec]

Specification of the modality pairs for which the loss should be calculated.

required

Returns:

Type Description
Tensor

The contrastive loss.

Source code in mmlearn/modules/losses/contrastive.py
def forward(
    self,
    embeddings: dict[str, torch.Tensor],
    example_ids: dict[str, torch.Tensor],
    logit_scale: torch.Tensor,
    modality_loss_pairs: list[LossPairSpec],
) -> torch.Tensor:
    """Calculate the contrastive loss.

    Parameters
    ----------
    embeddings : dict[str, torch.Tensor]
        Dictionary of embeddings, where the key is the modality name and the value
        is the corresponding embedding tensor.
    example_ids : dict[str, torch.Tensor]
        Dictionary of example IDs, where the key is the modality name and the value
        is a tensor tuple of the dataset index and the example index.
    logit_scale : torch.Tensor
        Scale factor for the logits.
    modality_loss_pairs : list[LossPairSpec]
        Specification of the modality pairs for which the loss should be calculated.

    Returns
    -------
    torch.Tensor
        The contrastive loss.
    """
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if world_size > 1 else 0

    if self.l2_normalize:
        embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

    if world_size > 1:  # gather embeddings and example_ids across all processes
        # NOTE: gathering dictionaries of tensors across all processes
        # (keys + values, as opposed to just values) is especially important
        # for the modality_alignment loss, which requires all embeddings
        all_embeddings = _gather_dicts(
            embeddings,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
        all_example_ids = _gather_dicts(
            example_ids,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
    else:
        all_embeddings = embeddings
        all_example_ids = example_ids

    losses = []
    for loss_pairs in modality_loss_pairs:
        logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
            loss_pairs.modalities,
            per_device_embeddings=embeddings,
            all_embeddings=all_embeddings,
            per_device_example_ids=example_ids,
            all_example_ids=all_example_ids,
            logit_scale=logit_scale,
            world_size=world_size,
        )
        if logits_per_feature_a is None or logits_per_feature_b is None:
            continue

        labels = self._get_ground_truth(
            logits_per_feature_a.shape,
            device=logits_per_feature_a.device,
            rank=rank,
            world_size=world_size,
            skipped_process=skip_flag,
        )

        if labels.numel() != 0:
            losses.append(
                (
                    (
                        F.cross_entropy(logits_per_feature_a, labels)
                        + F.cross_entropy(logits_per_feature_b, labels)
                    )
                    / 2
                )
                * loss_pairs.weight
            )

    if self.modality_alignment:
        losses.append(
            self._compute_modality_alignment_loss(all_embeddings, logit_scale)
        )

    if not losses:  # no loss to compute (e.g. no paired data in batch)
        losses.append(
            torch.tensor(
                0.0,
                device=logit_scale.device,
                dtype=next(iter(embeddings.values())).dtype,
            )
        )

    return torch.stack(losses).sum()

Data2VecLoss

Bases: Module

Data2Vec loss function.

Parameters:

Name Type Description Default
beta float

Specifies the beta parameter for smooth L1 loss. If 0, MSE loss is used.

0
loss_scale Optional[float]

Scaling factor for the loss. If None, uses 1 / sqrt(embedding_dim).

None
reduction str

Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.

'none'

Raises:

Type Description
ValueError

If the reduction mode is not supported.
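
A minimal usage sketch, assuming the import path mirrors the source file shown below; the shapes and argument values are illustrative.

import torch
from mmlearn.modules.losses.data2vec import Data2VecLoss  # path assumed from the source listing

loss_fn = Data2VecLoss(beta=2.0, reduction="mean")  # beta > 0 selects smooth L1 loss
pred = torch.randn(8, 196, 768)     # (batch_size, num_patches, embedding_dim)
target = torch.randn(8, 196, 768)
loss = loss_fn(pred, target)        # scalar, scaled by 1 / sqrt(embedding_dim)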

Source code in mmlearn/modules/losses/data2vec.py
@store(group="modules/losses", provider="mmlearn")
class Data2VecLoss(nn.Module):
    """Data2Vec loss function.

    Parameters
    ----------
    beta : float, optional, default=0
        Specifies the beta parameter for smooth L1 loss. If ``0``, MSE loss is used.
    loss_scale : Optional[float], optional, default=None
        Scaling factor for the loss. If None, uses ``1 / sqrt(embedding_dim)``.
    reduction : str, optional, default='none'
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``.

    Raises
    ------
    ValueError
        If the reduction mode is not supported.
    """

    def __init__(
        self,
        beta: float = 0,
        loss_scale: Optional[float] = None,
        reduction: str = "none",
    ) -> None:
        super().__init__()
        self.beta = beta
        self.loss_scale = loss_scale
        if reduction not in ["none", "mean", "sum"]:
            raise ValueError(f"Unsupported reduction mode: {reduction}")
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the Data2Vec loss.

        Parameters
        ----------
        x : torch.Tensor
            Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
        y : torch.Tensor
            Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

        Returns
        -------
        torch.Tensor
            Data2Vec loss value.

        Raises
        ------
        ValueError
            If the shapes of x and y do not match.
        """
        if x.shape != y.shape:
            raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, y.size(-1))

        if self.beta == 0:
            loss = mse_loss(x, y, reduction="none")
        else:
            loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        loss = loss * scale

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # 'none'
        return loss.view(x.size(0), -1).sum(1)
forward
forward(x, y)

Compute the Data2Vec loss.

Parameters:

Name Type Description Default
x Tensor

Predicted embeddings of shape (batch_size, num_patches, embedding_dim).

required
y Tensor

Target embeddings of shape (batch_size, num_patches, embedding_dim).

required

Returns:

Type Description
Tensor

Data2Vec loss value.

Raises:

Type Description
ValueError

If the shapes of x and y do not match.

Source code in mmlearn/modules/losses/data2vec.py
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the Data2Vec loss.

    Parameters
    ----------
    x : torch.Tensor
        Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
    y : torch.Tensor
        Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

    Returns
    -------
    torch.Tensor
        Data2Vec loss value.

    Raises
    ------
    ValueError
        If the shapes of x and y do not match.
    """
    if x.shape != y.shape:
        raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

    x = x.view(-1, x.size(-1)).float()
    y = y.view(-1, y.size(-1))

    if self.beta == 0:
        loss = mse_loss(x, y, reduction="none")
    else:
        loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

    if self.loss_scale is not None:
        scale = self.loss_scale
    else:
        scale = 1 / math.sqrt(x.size(-1))

    loss = loss * scale

    if self.reduction == "mean":
        return loss.mean()
    if self.reduction == "sum":
        return loss.sum()
    # 'none'
    return loss.view(x.size(0), -1).sum(1)

RetrievalRecallAtK

Bases: Metric

Retrieval Recall@K metric.

Computes the Recall@K for retrieval tasks. The metric is computed as follows:

  1. Compute the cosine similarity between the query and the database.
  2. For each query, sort the database in decreasing order of similarity.
  3. Compute the Recall@K as the number of true positives among the top K elements.
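
The same computation, sketched in plain PyTorch for reference (names and shapes are illustrative; this is not the mmlearn implementation):

import torch
import torch.nn.functional as F

def recall_at_k(queries, database, positives, k=5):
    # positives[i] is the index in `database` of the true match for queries[i]
    sim = F.normalize(queries, dim=-1) @ F.normalize(database, dim=-1).t()  # step 1
    topk = sim.topk(k, dim=-1).indices                                      # step 2
    hits = (topk == positives.unsqueeze(-1)).any(dim=-1)                    # step 3
    return hits.float().mean()

print(recall_at_k(torch.randn(32, 512), torch.randn(64, 512), torch.randint(0, 64, (32,))))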

Parameters:

Name Type Description Default
top_k int

The number of top elements to consider for computing the Recall@K.

required
reduction (mean, sum, none, None)

Specifies the reduction to apply after computing the pairwise cosine similarity scores.

"sum"
aggregation (mean, median, min, max)

Specifies the aggregation function to apply to the Recall@K values computed in batches. If a callable is provided, it should accept a tensor of values and a keyword argument 'dim' and return a single scalar value.

"mean"
kwargs Any

Additional arguments to be passed to the torchmetrics.Metric class.

{}

Raises:

Type Description
ValueError
  • If the top_k is not a positive integer or None.
  • If the reduction is not one of {"mean", "sum", "none", None}.
  • If the aggregation is not one of {"mean", "median", "min", "max"} or a custom callable function.
Source code in mmlearn/modules/metrics/retrieval_recall.py
@store(group="modules/metrics", provider="mmlearn")
class RetrievalRecallAtK(Metric):
    """Retrieval Recall@K metric.

    Computes the Recall@K for retrieval tasks. The metric is computed as follows:

    1. Compute the cosine similarity between the query and the database.
    2. For each query, sort the database in decreasing order of similarity.
    3. Compute the Recall@K as the number of true positives among the top K elements.

    Parameters
    ----------
    top_k : int
        The number of top elements to consider for computing the Recall@K.
    reduction : {"mean", "sum", "none", None}, optional, default="sum"
        Specifies the reduction to apply after computing the pairwise cosine similarity
        scores.
    aggregation : {"mean", "median", "min", "max"} or callable, default="mean"
        Specifies the aggregation function to apply to the Recall@K values computed
        in batches. If a callable is provided, it should accept a tensor of values
        and a keyword argument ``'dim'`` and return a single scalar value.
    kwargs : Any
        Additional arguments to be passed to the :py:class:`torchmetrics.Metric` class.

    Raises
    ------
    ValueError

        - If the `top_k` is not a positive integer or None.
        - If the `reduction` is not one of {"mean", "sum", "none", None}.
        - If the `aggregation` is not one of {"mean", "median", "min", "max"} or a
          custom callable function.

    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    indexes: list[torch.Tensor]
    x: list[torch.Tensor]
    y: list[torch.Tensor]
    num_samples: torch.Tensor

    def __init__(
        self,
        top_k: int,
        reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
        aggregation: Union[
            Literal["mean", "median", "min", "max"],
            Callable[[torch.Tensor, int], torch.Tensor],
        ] = "mean",
        **kwargs: Any,
    ) -> None:
        """Initialize the metric."""
        super().__init__(**kwargs)

        if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
            raise ValueError("`top_k` has to be a positive integer or None")
        self.top_k = top_k

        allowed_reduction = ("sum", "mean", "none", None)
        if reduction not in allowed_reduction:
            raise ValueError(
                f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
            )
        self.reduction = reduction

        if not (
            aggregation in ("mean", "median", "min", "max") or callable(aggregation)
        ):
            raise ValueError(
                "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
                f"which takes tensor of values, but got {aggregation}."
            )
        self.aggregation = aggregation

        self.add_state("x", default=[], dist_reduce_fx="cat")
        self.add_state("y", default=[], dist_reduce_fx="cat")
        self.add_state("indexes", default=[], dist_reduce_fx="cat")
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

        self._batch_size: int = -1

        self.compute_on_cpu = True
        self.sync_on_compute = False
        self.dist_sync_on_step = False
        self._to_sync = self.sync_on_compute
        self._should_unsync = False

    def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
        """Check shape, convert dtypes and add to accumulators.

        Parameters
        ----------
        x : torch.Tensor
            Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
            of samples and `D` is the number of dimensions.
        y : torch.Tensor
            Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
            of samples and ``D`` is the number of dimensions.
        indexes : torch.Tensor
            Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
            This specifies which sample in ``y`` is the positive match for each
            sample in ``x``.

        Raises
        ------
        ValueError
            If `indexes` is None.

        """
        if indexes is None:
            raise ValueError("Argument `indexes` cannot be None")

        x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

        # offset batch indexes by the number of samples seen so far
        indexes += self.num_samples

        local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
        if self._is_distributed():
            x = dim_zero_cat(gather_all_tensors(x, self.process_group))
            y = dim_zero_cat(gather_all_tensors(y, self.process_group))
            indexes = dim_zero_cat(
                gather_all_tensors(indexes.clone(), self.process_group)
            )

            # offset indexes for each device
            bsz_per_device = dim_zero_cat(
                gather_all_tensors(local_batch_size, self.process_group)
            )
            cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
            for device_idx in range(1, bsz_per_device.numel()):
                indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                    cum_local_bsz[device_idx - 1]
                )

            # update the global sample count
            self.num_samples += cum_local_bsz[-1]

            self._is_synced = True
        else:
            self.num_samples += x.size(0)

        self.x.append(x)
        self.y.append(y)
        self.indexes.append(indexes)

        if self._batch_size == -1:
            self._batch_size = x.size(0)  # global batch size

    def compute(self) -> torch.Tensor:
        """Compute the metric.

        Returns
        -------
        torch.Tensor
            The computed metric.
        """
        x = dim_zero_cat(self.x)
        y = dim_zero_cat(self.y)

        # normalize embeddings
        x /= x.norm(dim=-1, p=2, keepdim=True)
        y /= y.norm(dim=-1, p=2, keepdim=True)

        # instantiate reduction function
        reduction_mapping: Dict[
            Optional[str], Callable[[torch.Tensor], torch.Tensor]
        ] = {
            "sum": partial(torch.sum, dim=-1),
            "mean": partial(torch.mean, dim=-1),
            "none": lambda x: x,
            None: lambda x: x,
        }

        # concatenate indexes of true pairs
        indexes = dim_zero_cat(self.indexes)

        results: list[torch.Tensor] = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=os.cpu_count() or 1  # use all available CPUs
        ) as executor:
            futures = [
                executor.submit(
                    self._process_batch,
                    start,
                    x,
                    y,
                    indexes,
                    reduction_mapping,
                    self.top_k,
                )
                for start in tqdm(
                    range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
                )
            ]
            for future in concurrent.futures.as_completed(futures):
                results.append(future.result())

        return _retrieval_aggregate(
            (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward method is not supported.

        Raises
        ------
        NotImplementedError
            The forward method is not supported for this metric.
        """
        raise NotImplementedError(
            "RetrievalRecallAtK metric does not support forward method"
        )

    def _is_distributed(self) -> bool:
        if self.distributed_available_fn is not None:
            distributed_available = self.distributed_available_fn

        return distributed_available() if callable(distributed_available) else False

    def _process_batch(
        self,
        start: int,
        x_norm: torch.Tensor,
        y_norm: torch.Tensor,
        indexes: torch.Tensor,
        reduction_mapping: Dict[Optional[str], Callable[[torch.Tensor], torch.Tensor]],
        top_k: int,
    ) -> torch.Tensor:
        """Compute the Recall@K for a batch of samples."""
        end = start + self._batch_size
        x_norm_batch = x_norm[start:end]
        indexes_batch = indexes[start:end]

        similarity = _safe_matmul(x_norm_batch, y_norm)
        scores: torch.Tensor = reduction_mapping[self.reduction](similarity)

        with torch.inference_mode():
            positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
            positive_pairs[torch.arange(len(scores)), indexes_batch] = True

        return _recall_at_k(scores, positive_pairs, top_k)
__init__
__init__(
    top_k, reduction="sum", aggregation="mean", **kwargs
)

Initialize the metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def __init__(
    self,
    top_k: int,
    reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
    aggregation: Union[
        Literal["mean", "median", "min", "max"],
        Callable[[torch.Tensor, int], torch.Tensor],
    ] = "mean",
    **kwargs: Any,
) -> None:
    """Initialize the metric."""
    super().__init__(**kwargs)

    if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
        raise ValueError("`top_k` has to be a positive integer or None")
    self.top_k = top_k

    allowed_reduction = ("sum", "mean", "none", None)
    if reduction not in allowed_reduction:
        raise ValueError(
            f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
        )
    self.reduction = reduction

    if not (
        aggregation in ("mean", "median", "min", "max") or callable(aggregation)
    ):
        raise ValueError(
            "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
            f"which takes tensor of values, but got {aggregation}."
        )
    self.aggregation = aggregation

    self.add_state("x", default=[], dist_reduce_fx="cat")
    self.add_state("y", default=[], dist_reduce_fx="cat")
    self.add_state("indexes", default=[], dist_reduce_fx="cat")
    self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

    self._batch_size: int = -1

    self.compute_on_cpu = True
    self.sync_on_compute = False
    self.dist_sync_on_step = False
    self._to_sync = self.sync_on_compute
    self._should_unsync = False
update
update(x, y, indexes)

Check shape, convert dtypes and add to accumulators.

Parameters:

Name Type Description Default
x Tensor

Embeddings (unnormalized) of shape (N, D) where N is the number of samples and D is the number of dimensions.

required
y Tensor

Embeddings (unnormalized) of shape (M, D) where M is the number of samples and D is the number of dimensions.

required
indexes Tensor

Index tensor of shape (N,) where N is the number of samples. This specifies which sample in y is the positive match for each sample in x.

required

Raises:

Type Description
ValueError

If indexes is None.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
    """Check shape, convert dtypes and add to accumulators.

    Parameters
    ----------
    x : torch.Tensor
        Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
        of samples and `D` is the number of dimensions.
    y : torch.Tensor
        Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
        of samples and ``D`` is the number of dimensions.
    indexes : torch.Tensor
        Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
        This specifies which sample in ``y`` is the positive match for each
        sample in ``x``.

    Raises
    ------
    ValueError
        If `indexes` is None.

    """
    if indexes is None:
        raise ValueError("Argument `indexes` cannot be None")

    x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

    # offset batch indexes by the number of samples seen so far
    indexes += self.num_samples

    local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
    if self._is_distributed():
        x = dim_zero_cat(gather_all_tensors(x, self.process_group))
        y = dim_zero_cat(gather_all_tensors(y, self.process_group))
        indexes = dim_zero_cat(
            gather_all_tensors(indexes.clone(), self.process_group)
        )

        # offset indexes for each device
        bsz_per_device = dim_zero_cat(
            gather_all_tensors(local_batch_size, self.process_group)
        )
        cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
        for device_idx in range(1, bsz_per_device.numel()):
            indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                cum_local_bsz[device_idx - 1]
            )

        # update the global sample count
        self.num_samples += cum_local_bsz[-1]

        self._is_synced = True
    else:
        self.num_samples += x.size(0)

    self.x.append(x)
    self.y.append(y)
    self.indexes.append(indexes)

    if self._batch_size == -1:
        self._batch_size = x.size(0)  # global batch size
compute
compute()

Compute the metric.

Returns:

Type Description
Tensor

The computed metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def compute(self) -> torch.Tensor:
    """Compute the metric.

    Returns
    -------
    torch.Tensor
        The computed metric.
    """
    x = dim_zero_cat(self.x)
    y = dim_zero_cat(self.y)

    # normalize embeddings
    x /= x.norm(dim=-1, p=2, keepdim=True)
    y /= y.norm(dim=-1, p=2, keepdim=True)

    # instantiate reduction function
    reduction_mapping: Dict[
        Optional[str], Callable[[torch.Tensor], torch.Tensor]
    ] = {
        "sum": partial(torch.sum, dim=-1),
        "mean": partial(torch.mean, dim=-1),
        "none": lambda x: x,
        None: lambda x: x,
    }

    # concatenate indexes of true pairs
    indexes = dim_zero_cat(self.indexes)

    results: list[torch.Tensor] = []
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count() or 1  # use all available CPUs
    ) as executor:
        futures = [
            executor.submit(
                self._process_batch,
                start,
                x,
                y,
                indexes,
                reduction_mapping,
                self.top_k,
            )
            for start in tqdm(
                range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
            )
        ]
        for future in concurrent.futures.as_completed(futures):
            results.append(future.result())

    return _retrieval_aggregate(
        (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
    )
forward
forward(*args, **kwargs)

Forward method is not supported.

Raises:

Type Description
NotImplementedError

The forward method is not supported for this metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Forward method is not supported.

    Raises
    ------
    NotImplementedError
        The forward method is not supported for this metric.
    """
    raise NotImplementedError(
        "RetrievalRecallAtK metric does not support forward method"
    )

ContrastivePretraining

Bases: TrainingTask

Contrastive pretraining task.

This class supports contrastive pretraining with N modalities of data. It allows the sharing of encoders, heads, and postprocessors across modalities. It also supports computing the contrastive loss between specified pairs of modalities, as well as training auxiliary tasks alongside the main contrastive pretraining task.

Parameters:

Name Type Description Default
encoders dict[str, Module]

A dictionary of encoders. The keys can be any string, including the names of any supported modalities. If the keys are not supported modalities, the modality_module_mapping parameter must be provided to map the encoders to specific modalities. The encoders are expected to take a dictionary of input values and return a list-like object with the first element being the encoded values. This first element is passed on to the heads or postprocessors and the remaining elements are ignored. A toy encoder following this contract is sketched after this parameter list.

required
heads Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the encoder outputs, usually projection heads. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a torch.nn.Sequential module. All head modules are expected to take a single input tensor and return a single output tensor.

None
postprocessors Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the head outputs. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a nn.Sequential module. All postprocessor modules are expected to take a single input tensor and return a single output tensor.

None
modality_module_mapping Optional[dict[str, ModuleKeySpec]]

A dictionary mapping modalities to encoders, heads, and postprocessors. Useful for reusing the same instance of a module across multiple modalities.

None
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a functools.partial function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
init_logit_scale float

The initial value of the logit scale, i.e. the multiplicative factor applied to the logits before computing the contrastive loss. The log of this value is stored as the model's logit scale parameter (see the sketch after the Raises section below).

1 / 0.07
max_logit_scale float

The maximum value of the logit scale parameter. The logit scale parameter is clamped to the range [0, log(max_logit_scale)].

100
learnable_logit_scale bool

Whether the logit scale parameter is learnable. If set to False, the logit scale parameter is treated as a constant.

True
loss Optional[Module]

The loss function to use.

None
modality_loss_pairs Optional[list[LossPairSpec]]

A list of pairs of modalities to compute the contrastive loss between and the weight to apply to each pair.

None
auxiliary_tasks dict[str, AuxiliaryTaskSpec]

Auxiliary tasks to run alongside the main contrastive pretraining task.

  • The auxiliary task module is expected to be a partially-initialized instance of a lightning.pytorch.core.LightningModule created using functools.partial, such that an initialized encoder can be passed as the only argument.
  • The modality parameter specifies the modality of the encoder to use for the auxiliary task. The loss_weight parameter specifies the weight to apply to the auxiliary task loss.
None
log_auxiliary_tasks_loss bool

Whether to log the loss of auxiliary tasks to the main logger.

False
compute_validation_loss bool

Whether to compute the validation loss if a validation dataloader is provided. The loss function must be provided to compute the validation loss.

True
compute_test_loss bool

Whether to compute the test loss if a test dataloader is provided. The loss function must be provided to compute the test loss.

True
evaluation_tasks Optional[dict[str, EvaluationSpec]]

Evaluation tasks to run during validation, while training, and during testing.

None
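
As referenced in the encoders parameter above, each encoder must accept a dictionary of inputs and return a list-like object whose first element is the encoded tensor. A hypothetical minimal encoder satisfying that contract (the "text" key and sizes are assumptions for illustration, not part of mmlearn):

import torch
from torch import nn

class ToyTextEncoder(nn.Module):
    # Hypothetical encoder: takes a dict of inputs, returns a list whose first
    # element is the encoded tensor; any remaining elements would be ignored.
    def __init__(self, vocab_size: int = 1000, dim: int = 256) -> None:
        super().__init__()
        self.embed = nn.EmbeddingBag(vocab_size, dim)

    def forward(self, inputs: dict) -> list:
        return [self.embed(inputs["text"])]

encoder = ToyTextEncoder()
out = encoder({"text": torch.randint(0, 1000, (4, 16))})[0]  # (4, 256)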

Raises:

Type Description
ValueError
  • If the loss function is not provided and either the validation or test loss needs to be computed.
  • If the given modality is not supported.
  • If the encoder, head, or postprocessor is not mapped to a modality.
  • If an unsupported modality is found in the loss pair specification.
  • If an unsupported modality is found in the auxiliary tasks.
  • If the auxiliary task is not a partial function.
  • If the evaluation task is not an instance of mmlearn.tasks.hooks.EvaluationHooks.
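
The init_logit_scale / max_logit_scale handling described above can be sketched as follows (plain PyTorch, not the mmlearn implementation): the log of the initial scale is stored, optionally as a learnable parameter, and is clamped to [0, log(max_logit_scale)] before being exponentiated and applied to the logits.

import math
import torch

init_logit_scale, max_logit_scale = 1 / 0.07, 100.0
log_logit_scale = torch.nn.Parameter(torch.tensor(math.log(init_logit_scale)))

# at loss time: clamp in log space, then exponentiate to get the multiplier
logit_scale = log_logit_scale.clamp(0.0, math.log(max_logit_scale)).exp()
logits = logit_scale * torch.randn(8, 8)  # scaled similarity logits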
Source code in mmlearn/tasks/contrastive_pretraining.py
@store(group="task", provider="mmlearn")
class ContrastivePretraining(TrainingTask):
    """Contrastive pretraining task.

    This class supports contrastive pretraining with ``N`` modalities of data. It
    allows the sharing of encoders, heads, and postprocessors across modalities.
    It also supports computing the contrastive loss between specified pairs of
    modalities, as well as training auxiliary tasks alongside the main contrastive
    pretraining task.

    Parameters
    ----------
    encoders : dict[str, torch.nn.Module]
        A dictionary of encoders. The keys can be any string, including the names of
        any supported modalities. If the keys are not supported modalities, the
        ``modality_module_mapping`` parameter must be provided to map the encoders to
        specific modalities. The encoders are expected to take a dictionary of input
        values and return a list-like object with the first element being the encoded
        values. This first element is passed on to the heads or postprocessors and
        the remaining elements are ignored.
    heads : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the encoder outputs, usually projection
        heads. If the keys do not correspond to the name of a supported modality,
        the ``modality_module_mapping`` parameter must be provided. If any of the values
        are dictionaries, they will be wrapped in a :py:class:`torch.nn.Sequential`
        module. All head modules are expected to take a single input tensor and
        return a single output tensor.
    postprocessors : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the head outputs. If the keys do not
        correspond to the name of a supported modality, the `modality_module_mapping`
        parameter must be provided. If any of the values are dictionaries, they will
        be wrapped in a `nn.Sequential` module. All postprocessor modules are expected
        to take a single input tensor and return a single output tensor.
    modality_module_mapping : Optional[dict[str, ModuleKeySpec]], optional, default=None
        A dictionary mapping modalities to encoders, heads, and postprocessors.
        Useful for reusing the same instance of a module across multiple modalities.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    init_logit_scale : float, optional, default=1 / 0.07
        The initial value of the logit scale parameter. This is the log of the scale
        factor applied to the logits before computing the contrastive loss.
    max_logit_scale : float, optional, default=100
        The maximum value of the logit scale parameter. The logit scale parameter
        is clamped to the range ``[0, log(max_logit_scale)]``.
    learnable_logit_scale : bool, optional, default=True
        Whether the logit scale parameter is learnable. If set to False, the logit
        scale parameter is treated as a constant.
    loss : Optional[torch.nn.Module], optional, default=None
        The loss function to use.
    modality_loss_pairs : Optional[list[LossPairSpec]], optional, default=None
        A list of pairs of modalities to compute the contrastive loss between and
        the weight to apply to each pair.
    auxiliary_tasks : dict[str, AuxiliaryTaskSpec], optional, default=None
        Auxiliary tasks to run alongside the main contrastive pretraining task.

        - The auxiliary task module is expected to be a partially-initialized instance
          of a :py:class:`~lightning.pytorch.core.LightningModule` created using
          :py:func:`functools.partial`, such that an initialized encoder can be
          passed as the only argument.
        - The ``modality`` parameter specifies the modality of the encoder to use
          for the auxiliary task. The ``loss_weight`` parameter specifies the weight
          to apply to the auxiliary task loss.
    log_auxiliary_tasks_loss : bool, optional, default=False
        Whether to log the loss of auxiliary tasks to the main logger.
    compute_validation_loss : bool, optional, default=True
        Whether to compute the validation loss if a validation dataloader is provided.
        The loss function must be provided to compute the validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute the test loss if a test dataloader is provided. The loss
        function must be provided to compute the test loss.
    evaluation_tasks : Optional[dict[str, EvaluationSpec]], optional, default=None
        Evaluation tasks to run during validation, while training, and during testing.

    Raises
    ------
    ValueError

        - If the loss function is not provided and either the validation or test loss
          needs to be computed.
        - If the given modality is not supported.
        - If the encoder, head, or postprocessor is not mapped to a modality.
        - If an unsupported modality is found in the loss pair specification.
        - If an unsupported modality is found in the auxiliary tasks.
        - If the auxiliary task is not a partial function.
        - If the evaluation task is not an instance of :py:class:`~mmlearn.tasks.hooks.EvaluationHooks`.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912, PLR0915
        self,
        encoders: dict[str, nn.Module],
        heads: Optional[dict[str, Union[nn.Module, dict[str, nn.Module]]]] = None,
        postprocessors: Optional[
            dict[str, Union[nn.Module, dict[str, nn.Module]]]
        ] = None,
        modality_module_mapping: Optional[dict[str, ModuleKeySpec]] = None,
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable_logit_scale: bool = True,
        loss: Optional[nn.Module] = None,
        modality_loss_pairs: Optional[list[LossPairSpec]] = None,
        auxiliary_tasks: Optional[dict[str, AuxiliaryTaskSpec]] = None,
        log_auxiliary_tasks_loss: bool = False,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
        evaluation_tasks: Optional[dict[str, EvaluationSpec]] = None,
    ) -> None:
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )

        self.save_hyperparameters(
            ignore=[
                "encoders",
                "heads",
                "postprocessors",
                "modality_module_mapping",
                "loss",
                "auxiliary_tasks",
                "evaluation_tasks",
                "modality_loss_pairs",
            ]
        )

        if modality_module_mapping is None:
            # assume all the module dictionaries use the same keys corresponding
            # to modalities
            modality_module_mapping = {}
            for key in encoders:
                modality_module_mapping[key] = ModuleKeySpec(
                    encoder_key=key,
                    head_key=key,
                    postprocessor_key=key,
                )

        # match modalities to encoders, heads, and postprocessors
        modality_encoder_mapping: dict[str, Optional[str]] = {}
        modality_head_mapping: dict[str, Optional[str]] = {}
        modality_postprocessor_mapping: dict[str, Optional[str]] = {}
        for modality_key, module_mapping in modality_module_mapping.items():
            if not Modalities.has_modality(modality_key):
                raise ValueError(_unsupported_modality_error.format(modality_key))
            modality_encoder_mapping[modality_key] = module_mapping.encoder_key
            modality_head_mapping[modality_key] = module_mapping.head_key
            modality_postprocessor_mapping[modality_key] = (
                module_mapping.postprocessor_key
            )

        # ensure all modules are mapped to a modality
        for key in encoders:
            if key not in modality_encoder_mapping.values():
                if not Modalities.has_modality(key):
                    raise ValueError(_unsupported_modality_error.format(key))
                modality_encoder_mapping[key] = key

        if heads is not None:
            for key in heads:
                if key not in modality_head_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_head_mapping[key] = key

        if postprocessors is not None:
            for key in postprocessors:
                if key not in modality_postprocessor_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_postprocessor_mapping[key] = key

        self._available_modalities: list[Modality] = [
            Modalities.get_modality(modality_key)
            for modality_key in modality_encoder_mapping
        ]
        assert len(self._available_modalities) >= 2, (
            "Expected at least two modalities to be available. "
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the encoder modules.
        self.encoders = nn.ModuleDict(
            {
                Modalities.get_modality(modality_key).name: encoders[encoder_key]
                for modality_key, encoder_key in modality_encoder_mapping.items()
                if encoder_key is not None
            }
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the projection head modules. This can be
        #: ``None`` if no heads modules are provided.
        self.heads = None
        if heads is not None:
            self.heads = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: heads[head_key]
                    if isinstance(heads[head_key], nn.Module)
                    else nn.Sequential(*heads[head_key].values())
                    for modality_key, head_key in modality_head_mapping.items()
                    if head_key is not None and head_key in heads
                }
            )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the postprocessor modules. This can be
        #: ``None`` if no postprocessor modules are provided.
        self.postprocessors = None
        if postprocessors is not None:
            self.postprocessors = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: postprocessors[
                        postprocessor_key
                    ]
                    if isinstance(postprocessors[postprocessor_key], nn.Module)
                    else nn.Sequential(*postprocessors[postprocessor_key].values())
                    for modality_key, postprocessor_key in modality_postprocessor_mapping.items()
                    if postprocessor_key is not None
                    and postprocessor_key in postprocessors
                }
            )

        # set up logit scaling
        log_logit_scale = torch.ones([]) * np.log(init_logit_scale)
        self.max_logit_scale = max_logit_scale
        self.learnable_logit_scale = learnable_logit_scale

        if self.learnable_logit_scale:
            self.log_logit_scale = torch.nn.Parameter(
                log_logit_scale, requires_grad=True
            )
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

        # set up contrastive loss pairs
        if modality_loss_pairs is None:
            modality_loss_pairs = [
                LossPairSpec(modalities=(m1.name, m2.name))
                for m1, m2 in itertools.combinations(self._available_modalities, 2)
            ]

        for modality_pair in modality_loss_pairs:
            if not all(
                Modalities.get_modality(modality) in self._available_modalities
                for modality in modality_pair.modalities
            ):
                raise ValueError(
                    "Found unspecified modality in the loss pair specification "
                    f"{modality_pair.modalities}. Available modalities are "
                    f"{self._available_modalities}."
                )

        #: A list of :py:class:`LossPairSpec` instances specifying the pairs of
        #: modalities to compute the contrastive loss between and the weight to
        #: apply to each pair.
        self.modality_loss_pairs = modality_loss_pairs

        # set up auxiliary tasks
        self.aux_task_specs = auxiliary_tasks or {}
        self.auxiliary_tasks: nn.ModuleDict[str, L.LightningModule] = nn.ModuleDict()
        for task_name, task_spec in self.aux_task_specs.items():
            if not Modalities.has_modality(task_spec.modality):
                raise ValueError(
                    f"Found unsupported modality `{task_spec.modality}` in the auxiliary tasks. "
                    f"Available modalities are {self._available_modalities}."
                )
            if not isinstance(task_spec.task, partial):
                raise TypeError(
                    f"Expected auxiliary task to be a partial function, but got {type(task_spec.task)}."
                )

            self.auxiliary_tasks[task_name] = task_spec.task(
                self.encoders[Modalities.get_modality(task_spec.modality).name]
            )

        self.log_auxiliary_tasks_loss = log_auxiliary_tasks_loss

        if evaluation_tasks is not None:
            for eval_task_spec in evaluation_tasks.values():
                if not isinstance(eval_task_spec.task, EvaluationHooks):
                    raise TypeError(
                        f"Expected {eval_task_spec.task} to be an instance of `EvaluationHooks` "
                        f"but got {type(eval_task_spec.task)}."
                    )

        #: A dictionary of evaluation tasks to run during validation (while training)
        #: or during testing.
        self.evaluation_tasks = evaluation_tasks

    def configure_model(self) -> None:
        """Configure the model."""
        if self.auxiliary_tasks:
            for task_name in self.auxiliary_tasks:
                self.auxiliary_tasks[task_name].configure_model()

    def encode(
        self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
    ) -> torch.Tensor:
        """Encode the input values for the given modality.

        Parameters
        ----------
        inputs : dict[str, Any]
            Input values.
        modality : Modality
            The modality to encode.
        normalize : bool, optional, default=False
            Whether to apply L2 normalization to the output (after the head and
            postprocessor layers, if present).

        Returns
        -------
        torch.Tensor
            The encoded values for the specified modality.
        """
        output = self.encoders[modality.name](inputs)[0]

        if self.postprocessors and modality.name in self.postprocessors:
            output = self.postprocessors[modality.name](output)

        if self.heads and modality.name in self.heads:
            output = self.heads[modality.name](output)

        if normalize:
            output = torch.nn.functional.normalize(output, p=2, dim=-1)

        return output

    def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input tensors to encode.

        Returns
        -------
        dict[str, torch.Tensor]
            The encodings for each modality.
        """
        outputs = {
            modality.embedding: self.encode(inputs, modality, normalize=True)
            for modality in self._available_modalities
            if modality.name in inputs
        }

        if not all(
            output.size(-1) == list(outputs.values())[0].size(-1)
            for output in outputs.values()
        ):
            raise ValueError("Expected all model outputs to have the same dimension.")

        return outputs

    def on_train_epoch_start(self) -> None:
        """Prepare for the training epoch.

        This method sets the modules to training mode.
        """
        self.encoders.train()
        if self.heads:
            self.heads.train()
        if self.postprocessors:
            self.postprocessors.train()

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Compute the loss for the batch.

        Parameters
        ----------
        batch : dict[str, Any]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        torch.Tensor
            The loss for the batch.
        """
        outputs = self(batch)

        with torch.no_grad():
            self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

        loss = self._compute_loss(batch, batch_idx, outputs)

        if loss is None:
            raise ValueError("The loss function must be provided for training.")

        self.log("train/loss", loss, prog_bar=True, sync_dist=True)
        self.log(
            "train/logit_scale",
            self.log_logit_scale.exp(),
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )

        return loss

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Zero out the gradients of the model."""
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_before_zero_grad(optimizer)

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch.

        This method sets the modules to evaluation mode and calls the
        ``on_evaluation_epoch_start`` method of each evaluation task.
        """
        self._on_eval_epoch_start("val")

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "val")

    def on_validation_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def test_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "test")

    def on_test_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the model checkpoint after loading.

        The `on_load_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint after loading.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The loaded checkpoint.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_load_checkpoint(checkpoint)

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the checkpoint before saving.

        The `on_save_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint before saving.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The checkpoint to save.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_save_checkpoint(checkpoint)

    def _compute_loss(
        self, batch: dict[str, Any], batch_idx: int, outputs: dict[str, torch.Tensor]
    ) -> Optional[torch.Tensor]:
        if self.loss_fn is None:
            return None

        contrastive_loss = self.loss_fn(
            outputs,
            batch["example_ids"],
            self.log_logit_scale.exp(),
            self.modality_loss_pairs,
        )

        auxiliary_losses: list[torch.Tensor] = []
        if self.auxiliary_tasks:
            for task_name, task_spec in self.aux_task_specs.items():
                auxiliary_task_output = self.auxiliary_tasks[task_name].training_step(
                    batch, batch_idx
                )
                if isinstance(auxiliary_task_output, torch.Tensor):
                    auxiliary_task_loss = auxiliary_task_output
                elif isinstance(auxiliary_task_output, Mapping):
                    auxiliary_task_loss = auxiliary_task_output["loss"]
                else:
                    raise ValueError(
                        "Expected auxiliary task output to be a tensor or a mapping "
                        f"containing a 'loss' key, but got {type(auxiliary_task_output)}."
                    )

                auxiliary_task_loss *= task_spec.loss_weight
                auxiliary_losses.append(auxiliary_task_loss)
                if self.log_auxiliary_tasks_loss:
                    self.log(
                        f"train/{task_name}_loss", auxiliary_task_loss, sync_dist=True
                    )

        if not auxiliary_losses:
            return contrastive_loss

        return torch.stack(auxiliary_losses).sum() + contrastive_loss

    def _on_eval_epoch_start(self, eval_type: Literal["val", "test"]) -> None:
        """Prepare for the evaluation epoch."""
        self.encoders.eval()
        if self.heads:
            self.heads.eval()
        if self.postprocessors:
            self.postprocessors.eval()
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_start(self)

    def _shared_eval_step(
        self,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
        eval_type: Literal["val", "test"],
    ) -> Optional[torch.Tensor]:
        """Run a single evaluation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        loss: Optional[torch.Tensor] = None
        if (eval_type == "val" and self.compute_validation_loss) or (
            eval_type == "test" and self.compute_test_loss
        ):
            outputs = self(batch)
            loss = self._compute_loss(batch, batch_idx, outputs)
            if loss is not None and not self.trainer.sanity_checking:
                self.log(f"{eval_type}/loss", loss, prog_bar=True, sync_dist=True)

        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.evaluation_step(self, batch, batch_idx)

        return loss

    def _on_eval_epoch_end(self, eval_type: Literal["val", "test"]) -> None:
        """Compute and log epoch-level metrics at the end of the evaluation epoch."""
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_end(self)
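As a quick orientation, here is a minimal wiring sketch for a two-modality setup, assuming the text and rgb modality names are registered with Modalities. The DictEncoder class and its input keys are hypothetical stand-ins for real encoders, and the import path simply mirrors the source file named on this page; only the keyword arguments reflect the constructor shown above.

from functools import partial

import torch
from torch import nn

# Assumed import path (mirrors "Source code in mmlearn/tasks/contrastive_pretraining.py").
from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining


class DictEncoder(nn.Module):
    """Toy encoder: projects one tensor out of the batch dict and returns a 1-tuple.

    This mirrors how the task calls its encoders: ``self.encoders[modality.name](inputs)[0]``.
    """

    def __init__(self, input_key: str, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.input_key = input_key
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, inputs: dict) -> tuple:
        return (self.proj(inputs[self.input_key]),)


task = ContrastivePretraining(
    encoders={
        "text": DictEncoder("text", in_dim=32, out_dim=256),  # keys must be registered modality names
        "rgb": DictEncoder("rgb", in_dim=64, out_dim=256),
    },
    heads={
        "text": nn.Linear(256, 128),  # optional projection heads
        "rgb": nn.Linear(256, 128),
    },
    optimizer=partial(torch.optim.AdamW, lr=1e-4),
    # `loss` is omitted here; to train, pass an nn.Module called as
    # loss(outputs, batch["example_ids"], logit_scale, modality_loss_pairs),
    # matching the call in `_compute_loss` above.
)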
configure_model
configure_model()

Configure the model.

Source code in mmlearn/tasks/contrastive_pretraining.py
def configure_model(self) -> None:
    """Configure the model."""
    if self.auxiliary_tasks:
        for task_name in self.auxiliary_tasks:
            self.auxiliary_tasks[task_name].configure_model()
encode
encode(inputs, modality, normalize=False)

Encode the input values for the given modality.

Parameters:

Name Type Description Default
inputs dict[str, Any]

Input values.

required
modality Modality

The modality to encode.

required
normalize bool

Whether to apply L2 normalization to the output (after the head and postprocessor layers, if present).

False

Returns:

Type Description
Tensor

The encoded values for the specified modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def encode(
    self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
) -> torch.Tensor:
    """Encode the input values for the given modality.

    Parameters
    ----------
    inputs : dict[str, Any]
        Input values.
    modality : Modality
        The modality to encode.
    normalize : bool, optional, default=False
        Whether to apply L2 normalization to the output (after the head and
        postprocessor layers, if present).

    Returns
    -------
    torch.Tensor
        The encoded values for the specified modality.
    """
    output = self.encoders[modality.name](inputs)[0]

    if self.postprocessors and modality.name in self.postprocessors:
        output = self.postprocessors[modality.name](output)

    if self.heads and modality.name in self.heads:
        output = self.heads[modality.name](output)

    if normalize:
        output = torch.nn.functional.normalize(output, p=2, dim=-1)

    return output
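Continuing the hypothetical sketch shown after the class source above, a single modality can be embedded directly with encode. The batch layout (modality names as keys) and the Modalities import path are assumptions:

import torch

from mmlearn.datasets.core import Modalities  # assumed import path for the modality registry

batch = {"text": torch.randn(8, 32)}  # matches the toy DictEncoder("text", ...) above
with torch.no_grad():
    text_embeddings = task.encode(batch, Modalities.get_modality("text"), normalize=True)
print(text_embeddings.shape)         # torch.Size([8, 128]) after the projection head
print(text_embeddings.norm(dim=-1))  # all ones, since normalize=True L2-normalizes the rows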
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input tensors to encode.

required

Returns:

Type Description
dict[str, Tensor]

The encodings for each modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input tensors to encode.

    Returns
    -------
    dict[str, torch.Tensor]
        The encodings for each modality.
    """
    outputs = {
        modality.embedding: self.encode(inputs, modality, normalize=True)
        for modality in self._available_modalities
        if modality.name in inputs
    }

    if not all(
        output.size(-1) == list(outputs.values())[0].size(-1)
        for output in outputs.values()
    ):
        raise ValueError("Expected all model outputs to have the same dimension.")

    return outputs
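Calling the task itself runs encode for every available modality whose name appears in the batch, returning one L2-normalized embedding per modality (keyed by that modality's embedding name) and checking that all embeddings share the same dimension. Continuing the same hypothetical sketch:

batch = {"text": torch.randn(8, 32), "rgb": torch.randn(8, 64)}
outputs = task(batch)  # forward() -> dict of normalized embeddings
for key, embedding in outputs.items():
    print(key, tuple(embedding.shape))  # both tensors end in the shared embedding dim (128 here)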
on_train_epoch_start
on_train_epoch_start()

Prepare for the training epoch.

This method sets the modules to training mode.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_train_epoch_start(self) -> None:
    """Prepare for the training epoch.

    This method sets the modules to training mode.
    """
    self.encoders.train()
    if self.heads:
        self.heads.train()
    if self.postprocessors:
        self.postprocessors.train()
training_step
training_step(batch, batch_idx)

Compute the loss for the batch.

Parameters:

Name Type Description Default
batch dict[str, Any]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Tensor

The loss for the batch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Compute the loss for the batch.

    Parameters
    ----------
    batch : dict[str, Any]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    torch.Tensor
        The loss for the batch.
    """
    outputs = self(batch)

    with torch.no_grad():
        self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

    loss = self._compute_loss(batch, batch_idx, outputs)

    if loss is None:
        raise ValueError("The loss function must be provided for training.")

    self.log("train/loss", loss, prog_bar=True, sync_dist=True)
    self.log(
        "train/logit_scale",
        self.log_logit_scale.exp(),
        prog_bar=True,
        on_step=True,
        on_epoch=False,
    )

    return loss
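To make the temperature handling concrete: the scale is stored and (optionally) learned in log space, and the in-place clamp above keeps exp(log_logit_scale) within [1, max_logit_scale]. A quick check of the default values:

import math

init_logit_scale, max_logit_scale = 1 / 0.07, 100.0
log_logit_scale = math.log(init_logit_scale)  # ~2.659, the value that is stored/learned
print(round(math.exp(log_logit_scale), 2))    # 14.29, the multiplier applied to the similarities
# clamp_(0, log(max_logit_scale)) bounds the exponentiated scale to [1.0, 100.0]
print(math.exp(0.0), math.exp(math.log(max_logit_scale)))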
on_before_zero_grad
on_before_zero_grad(optimizer)

Forward the on_before_zero_grad call to the auxiliary tasks, if any.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Zero out the gradients of the model."""
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_before_zero_grad(optimizer)
on_validation_epoch_start
on_validation_epoch_start()

Prepare for the validation epoch.

This method sets the modules to evaluation mode and calls the on_evaluation_epoch_start method of each evaluation task.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch.

    This method sets the modules to evaluation mode and calls the
    ``on_evaluation_epoch_start`` method of each evaluation task.
    """
    self._on_eval_epoch_start("val")
validation_step
validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def validation_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "val")
on_validation_epoch_end
on_validation_epoch_end()

Compute and log epoch-level metrics at the end of the validation epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the validation epoch."""
    self._on_eval_epoch_end("val")
on_test_epoch_start
on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")
test_step
test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def test_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "test")
on_test_epoch_end
on_test_epoch_end()

Compute and log epoch-level metrics at the end of the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the test epoch."""
    self._on_eval_epoch_end("test")
on_load_checkpoint
on_load_checkpoint(checkpoint)

Modify the model checkpoint after loading.

The on_load_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint after loading.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The loaded checkpoint.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the model checkpoint after loading.

    The `on_load_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint after loading.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The loaded checkpoint.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_load_checkpoint(checkpoint)
on_save_checkpoint
on_save_checkpoint(checkpoint)

Modify the checkpoint before saving.

The on_save_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint before saving.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The checkpoint to save.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the checkpoint before saving.

    The `on_save_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint before saving.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The checkpoint to save.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_save_checkpoint(checkpoint)

IJEPA

Bases: TrainingTask

Pretraining module for IJEPA.

This class implements the IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task using PyTorch Lightning. It trains an encoder and a predictor to predict the target-encoder representations of masked regions of an image from its unmasked context.

Parameters:

Name Type Description Default
encoder VisionTransformer

Vision transformer encoder.

required
predictor VisionTransformerPredictor

Vision transformer predictor.

required
modality str

Name of the modality whose tensors in the batch are used as the input images.

'RGB'
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a functools.partial function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
ema_decay float

Initial momentum for EMA of target encoder.

0.996
ema_decay_end float

Final momentum for EMA of target encoder.

1.0
ema_anneal_end_step int

Number of steps to anneal EMA momentum to ema_decay_end.

1000
loss_fn Optional[Callable[[Tensor, Tensor], Tensor]]

Loss function to use. If not provided, defaults to torch.nn.functional.smooth_l1_loss.

None
compute_validation_loss bool

Whether to compute validation loss.

True
compute_test_loss bool

Whether to compute test loss.

True
Source code in mmlearn/tasks/ijepa.py
@store(group="task", provider="mmlearn", zen_partial=False)
class IJEPA(TrainingTask):
    """Pretraining module for IJEPA.

    This class implements the IJEPA (Image Joint-Embedding Predictive Architecture)
    pretraining task using PyTorch Lightning. It trains an encoder and a predictor to
    predict the target-encoder representations of masked regions of an image from
    its unmasked context.

    Parameters
    ----------
    encoder : VisionTransformer
        Vision transformer encoder.
    predictor : VisionTransformerPredictor
        Vision transformer predictor.
    modality : str, optional, default="RGB"
        Name of the modality whose tensors in the batch are used as the input images.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    ema_decay : float, optional, default=0.996
        Initial momentum for EMA of target encoder.
    ema_decay_end : float, optional, default=1.0
        Final momentum for EMA of target encoder.
    ema_anneal_end_step : int, optional, default=1000
        Number of steps to anneal EMA momentum to ``ema_decay_end``.
    loss_fn : Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional
        Loss function to use. If not provided, defaults to
        :py:func:`~torch.nn.functional.smooth_l1_loss`.
    compute_validation_loss : bool, optional, default=True
        Whether to compute validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute test loss.
    """  # noqa: W505

    def __init__(
        self,
        encoder: VisionTransformer,
        predictor: VisionTransformerPredictor,
        modality: str = "RGB",
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        ema_decay: float = 0.996,
        ema_decay_end: float = 1.0,
        ema_anneal_end_step: int = 1000,
        loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
    ):
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss_fn if loss_fn is not None else F.smooth_l1_loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )
        self.modality = Modalities.get_modality(modality)
        self.mask_generator = IJEPAMaskGenerator()

        self.encoder = encoder
        self.predictor = predictor

        self.predictor.num_patches = encoder.patch_embed.num_patches
        self.predictor.embed_dim = encoder.embed_dim
        self.predictor.num_heads = encoder.num_heads

        self.target_encoder = ExponentialMovingAverage(
            self.encoder, ema_decay, ema_decay_end, ema_anneal_end_step
        )

    def configure_model(self) -> None:
        """Configure the model."""
        self.target_encoder.configure_model(self.device)

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Perform exponential moving average update of target encoder.

        This is done right after ``optimizer.step()``, which comes just before
        ``optimizer.zero_grad()``, to account for gradient accumulation.
        """
        if self.target_encoder is not None:
            self.target_encoder.step(self.encoder)

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Perform a single training step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        torch.Tensor
            Loss value.
        """
        return self._shared_step(batch, batch_idx, step_type="train")

    def validation_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed.
        """
        return self._shared_step(batch, batch_idx, step_type="val")

    def test_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed.
        """
        return self._shared_step(batch, batch_idx, step_type="test")

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch."""
        self._on_eval_epoch_start("val")

    def on_validation_epoch_end(self) -> None:
        """Actions at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def on_test_epoch_end(self) -> None:
        """Actions at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Add relevant EMA state to the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to save the EMA state to.
        """
        if self.target_encoder is not None:
            checkpoint["ema_params"] = {
                "decay": self.target_encoder.decay,
                "num_updates": self.target_encoder.num_updates,
            }

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Restore EMA state from the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to restore the EMA state from.
        """
        if "ema_params" in checkpoint and self.target_encoder is not None:
            ema_params = checkpoint.pop("ema_params")
            self.target_encoder.decay = ema_params["decay"]
            self.target_encoder.num_updates = ema_params["num_updates"]

            self.target_encoder.restore(self.encoder)

    def _shared_step(
        self, batch: dict[str, Any], batch_idx: int, step_type: str
    ) -> Optional[torch.Tensor]:
        images = batch[self.modality.name]

        # Generate masks
        batch_size = images.size(0)
        mask_info = self.mask_generator(batch_size=batch_size)

        # Extract masks and move to device
        device = images.device
        encoder_masks = [mask.to(device) for mask in mask_info["encoder_masks"]]
        predictor_masks = [mask.to(device) for mask in mask_info["predictor_masks"]]

        # Forward pass through target encoder to get h
        with torch.no_grad():
            h = self.target_encoder.model(batch)[0]
            h = F.layer_norm(h, h.size()[-1:])
            h_masked = apply_masks(h, predictor_masks)
            h_masked = repeat_interleave_batch(
                h_masked, images.size(0), repeat=len(encoder_masks)
            )

        # Forward pass through encoder with encoder_masks
        batch[self.modality.mask] = encoder_masks
        z = self.encoder(batch)[0]

        # Pass z through predictor with encoder_masks and predictor_masks
        z_pred = self.predictor(z, encoder_masks, predictor_masks)

        if step_type == "train":
            self.log("train/ema_decay", self.target_encoder.decay, prog_bar=True)

        if self.loss_fn is not None and (
            step_type == "train"
            or (step_type == "val" and self.compute_validation_loss)
            or (step_type == "test" and self.compute_test_loss)
        ):
            # Compute loss between z_pred and h_masked
            loss = self.loss_fn(z_pred, h_masked)

            # Log loss
            self.log(f"{step_type}/loss", loss, prog_bar=True, sync_dist=True)

            return loss

        return None

    def _on_eval_epoch_start(self, step_type: str) -> None:
        """Initialize states or configurations at the start of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (
            step_type == "val"
            and self.compute_validation_loss
            or step_type == "test"
            and self.compute_test_loss
        ):
            self.log(f"{step_type}/start", 1, prog_bar=True, sync_dist=True)

    def _on_eval_epoch_end(self, step_type: str) -> None:
        """Finalize states or logging at the end of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (
            step_type == "val"
            and self.compute_validation_loss
            or step_type == "test"
            and self.compute_test_loss
        ):
            self.log(f"{step_type}/end", 1, prog_bar=True, sync_dist=True)
configure_model
configure_model()

Configure the model.

Source code in mmlearn/tasks/ijepa.py
def configure_model(self) -> None:
    """Configure the model."""
    self.target_encoder.configure_model(self.device)
on_before_zero_grad
on_before_zero_grad(optimizer)

Perform exponential moving average update of target encoder.

This is done right after optimizer.step(), which comes just before optimizer.zero_grad(), to account for gradient accumulation.

Source code in mmlearn/tasks/ijepa.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Perform exponential moving average update of target encoder.

    This is done right after ``optimizer.step()``, which comes just before
    ``optimizer.zero_grad()``, to account for gradient accumulation.
    """
    if self.target_encoder is not None:
        self.target_encoder.step(self.encoder)
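The momentum schedule itself lives inside ExponentialMovingAverage and is not shown on this page; purely as an illustration, a linear anneal from ema_decay to ema_decay_end over ema_anneal_end_step updates would behave like this:

def ema_momentum(step: int, start: float = 0.996, end: float = 1.0, end_step: int = 1000) -> float:
    """Illustrative linear schedule; the actual ExponentialMovingAverage implementation may differ."""
    if step >= end_step:
        return end
    return start + (end - start) * step / end_step

# Each update blends parameters as: target = m * target + (1 - m) * online, with m = ema_momentum(step).
print(ema_momentum(0), ema_momentum(500), ema_momentum(1000))  # 0.996 0.998 1.0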
training_step
training_step(batch, batch_idx)

Perform a single training step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Tensor

Loss value.

Source code in mmlearn/tasks/ijepa.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Perform a single training step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    torch.Tensor
        Loss value.
    """
    return self._shared_step(batch, batch_idx, step_type="train")
validation_step
validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def validation_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="val")
test_step
test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def test_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="test")
on_validation_epoch_start
on_validation_epoch_start()

Prepare for the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch."""
    self._on_eval_epoch_start("val")
on_validation_epoch_end
on_validation_epoch_end()

Actions at the end of the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_end(self) -> None:
    """Actions at the end of the validation epoch."""
    self._on_eval_epoch_end("val")
on_test_epoch_start
on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")
on_test_epoch_end
on_test_epoch_end()

Actions at the end of the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_end(self) -> None:
    """Actions at the end of the test epoch."""
    self._on_eval_epoch_end("test")
on_save_checkpoint
on_save_checkpoint(checkpoint)

Add relevant EMA state to the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to save the EMA state to.

required
Source code in mmlearn/tasks/ijepa.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Add relevant EMA state to the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to save the EMA state to.
    """
    if self.target_encoder is not None:
        checkpoint["ema_params"] = {
            "decay": self.target_encoder.decay,
            "num_updates": self.target_encoder.num_updates,
        }
on_load_checkpoint
on_load_checkpoint(checkpoint)

Restore EMA state from the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to restore the EMA state from.

required
Source code in mmlearn/tasks/ijepa.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Restore EMA state from the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to restore the EMA state from.
    """
    if "ema_params" in checkpoint and self.target_encoder is not None:
        ema_params = checkpoint.pop("ema_params")
        self.target_encoder.decay = ema_params["decay"]
        self.target_encoder.num_updates = ema_params["num_updates"]

        self.target_encoder.restore(self.encoder)

ZeroShotClassification

Bases: EvaluationHooks

Zero-shot classification evaluation task.

This task evaluates the zero-shot classification performance.

Parameters:

Name Type Description Default
task_specs list[ClassificationTaskSpec]

A list of classification task specifications.

required
tokenizer Callable[[Union[str, list[str]]], Union[Tensor, dict[str, Tensor]]]

A function to tokenize text inputs.

required
Source code in mmlearn/tasks/zero_shot_classification.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotClassification(EvaluationHooks):
    """Zero-shot classification evaluation task.

    This task evaluates the zero-shot classification performance.

    Parameters
    ----------
    task_specs : list[ClassificationTaskSpec]
        A list of classification task specifications.
    tokenizer : Callable[[Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]]
        A function to tokenize text inputs.
    """  # noqa: W505

    def __init__(
        self,
        task_specs: list[ClassificationTaskSpec],
        tokenizer: Callable[
            [Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]
        ],
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.task_specs = task_specs
        for spec in self.task_specs:
            assert Modalities.has_modality(spec.query_modality)

        self.metrics: dict[tuple[str, int], MetricCollection] = {}
        self._embeddings_store: dict[int, torch.Tensor] = {}

    def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
        """Set up the evaluation task.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Raises
        ------
        ValueError
            - If the task is not being run for validation or testing.
            - If the dataset does not have the required attributes to perform zero-shot
              classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
        """
        if pl_module.trainer.validating:
            eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
        elif pl_module.trainer.testing:
            eval_dataset = pl_module.trainer.test_dataloaders.dataset
        else:
            raise ValueError(
                "ZeroShotClassification task is only supported for validation and testing."
            )

        self.all_dataset_info = {}

        # create metrics for each dataset/query_modality combination
        if not self.metrics:
            for dataset_index, dataset in enumerate(eval_dataset.datasets):
                dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
                try:
                    id2label: dict[int, str] = dataset.id2label
                except AttributeError:
                    raise ValueError(
                        f"Dataset '{dataset_name}' must have a `id2label` attribute "
                        "to perform zero-shot classification."
                    ) from None

                try:
                    zero_shot_prompt_templates: list[str] = (
                        dataset.zero_shot_prompt_templates
                    )
                except AttributeError:
                    raise ValueError(
                        "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                    ) from None

                num_classes = len(id2label)

                self.all_dataset_info[dataset_index] = {
                    "name": dataset_name,
                    "id2label": id2label,
                    "prompt_templates": zero_shot_prompt_templates,
                    "num_classes": num_classes,
                }

                for spec in self.task_specs:
                    query_modality = Modalities.get_modality(spec.query_modality).name
                    self.metrics[(query_modality, dataset_index)] = (
                        self._create_metrics(
                            num_classes,
                            spec.top_k,
                            prefix=f"{dataset_name}/{query_modality}_",
                            postfix="",
                        )
                    )

        for metric in self.metrics.values():
            metric.to(pl_module.device)

        for dataset_index, dataset_info in self.all_dataset_info.items():
            id2label = dataset_info["id2label"]
            prompt_templates: list[str] = dataset_info["prompt_templates"]
            labels = list(id2label.values())

            with torch.no_grad():
                chunk_size = 10
                all_embeddings = []

                for i in tqdm(
                    range(0, len(labels), chunk_size),
                    desc="Encoding class descriptions",
                ):
                    batch_labels = labels[i : min(i + chunk_size, len(labels))]
                    descriptions = [
                        template.format(label)
                        for label in batch_labels
                        for template in prompt_templates
                    ]
                    tokenized_descriptions = move_data_to_device(
                        self.tokenizer(descriptions),
                        pl_module.device,
                    )

                    # Encode the chunk using the pl_module's encode method
                    chunk_embeddings = pl_module.encode(
                        tokenized_descriptions, Modalities.TEXT
                    )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                    chunk_embeddings = chunk_embeddings.reshape(
                        len(batch_labels), len(prompt_templates), -1
                    ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                    # Append the chunk embeddings to the list
                    all_embeddings.append(chunk_embeddings)

                # Concatenate all chunk embeddings into a single tensor
                class_embeddings = torch.cat(all_embeddings, dim=0)

            self._embeddings_store[dataset_index] = class_embeddings

    def evaluation_step(
        self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> None:
        """Compute logits and update metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A batch of data.
        batch_idx : int
            The index of the batch.
        """
        if pl_module.trainer.sanity_checking:
            return

        for (query_modality, dataset_index), metric_collection in self.metrics.items():
            matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

            if not matching_indices.numel():
                continue

            class_embeddings = self._embeddings_store[dataset_index]
            query_embeddings: torch.Tensor = pl_module.encode(
                batch, Modalities.get_modality(query_modality)
            )
            query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
            query_embeddings = query_embeddings[matching_indices]

            if self.all_dataset_info[dataset_index]["num_classes"] == 2:
                softmax_output = _safe_matmul(
                    query_embeddings, class_embeddings
                ).softmax(dim=-1)
                logits = softmax_output[:, 1] - softmax_output[:, 0]
            else:
                logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
            targets = batch[Modalities.get_modality(query_modality).target][
                matching_indices
            ]

            metric_collection.update(logits, targets)

    def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
        """Compute and reset metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        dict[str, Any]
            The computed metrics.
        """
        results = {}
        for metric_collection in self.metrics.values():
            results.update(metric_collection.compute())
            metric_collection.reset()

        self._embeddings_store.clear()

        eval_type = "val" if pl_module.trainer.validating else "test"
        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results

    @staticmethod
    def _create_metrics(
        num_classes: int, top_k: list[int], prefix: str, postfix: str
    ) -> MetricCollection:
        """Create a collection of classification metrics."""
        task_type = "binary" if num_classes == 2 else "multiclass"
        acc_metrics = (
            {
                f"top{k}_accuracy": Accuracy(
                    task=task_type, num_classes=num_classes, top_k=k, average="micro"
                )
                for k in top_k
            }
            if num_classes > 2
            else {"accuracy": Accuracy(task=task_type, num_classes=num_classes)}
        )
        return MetricCollection(
            {
                "precision": Precision(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "recall": Recall(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "f1_score_macro": F1Score(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "aucroc": AUROC(task=task_type, num_classes=num_classes),
                **acc_metrics,
            },
            prefix=prefix,
            postfix=postfix,
            compute_groups=True,
        )
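The evaluation task is typically attached to a contrastive model through its evaluation_tasks argument. Below is a hypothetical hookup: the spec classes are constructed with the field names referenced in the source on this page, the import paths mirror the source files named here, and my_tokenizer, my_text_encoder and my_image_encoder are placeholders:

# Assumed import paths, mirroring the "Source code in ..." lines on this page.
from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining, EvaluationSpec
from mmlearn.tasks.zero_shot_classification import ClassificationTaskSpec, ZeroShotClassification

zero_shot = ZeroShotClassification(
    task_specs=[ClassificationTaskSpec(query_modality="rgb", top_k=[1, 5])],
    tokenizer=my_tokenizer,  # Callable[[str | list[str]], Tensor | dict[str, Tensor]]
)

task = ContrastivePretraining(
    encoders={"text": my_text_encoder, "rgb": my_image_encoder},  # as in the earlier sketch
    evaluation_tasks={
        "zero_shot": EvaluationSpec(task=zero_shot, run_on_validation=True, run_on_test=True),
    },
)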
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Set up the evaluation task.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Raises:

Type Description
ValueError
  • If the task is not being run for validation or testing.
  • If the dataset does not have the required attributes to perform zero-shot classification (i.e., id2label and zero_shot_prompt_templates).
Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
    """Set up the evaluation task.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Raises
    ------
    ValueError
        - If the task is not being run for validation or testing.
        - If the dataset does not have the required attributes to perform zero-shot
          classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
    """
    if pl_module.trainer.validating:
        eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
    elif pl_module.trainer.testing:
        eval_dataset = pl_module.trainer.test_dataloaders.dataset
    else:
        raise ValueError(
            "ZeroShotClassification task is only supported for validation and testing."
        )

    self.all_dataset_info = {}

    # create metrics for each dataset/query_modality combination
    if not self.metrics:
        for dataset_index, dataset in enumerate(eval_dataset.datasets):
            dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
            try:
                id2label: dict[int, str] = dataset.id2label
            except AttributeError:
                raise ValueError(
                    f"Dataset '{dataset_name}' must have a `id2label` attribute "
                    "to perform zero-shot classification."
                ) from None

            try:
                zero_shot_prompt_templates: list[str] = (
                    dataset.zero_shot_prompt_templates
                )
            except AttributeError:
                raise ValueError(
                    "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                ) from None

            num_classes = len(id2label)

            self.all_dataset_info[dataset_index] = {
                "name": dataset_name,
                "id2label": id2label,
                "prompt_templates": zero_shot_prompt_templates,
                "num_classes": num_classes,
            }

            for spec in self.task_specs:
                query_modality = Modalities.get_modality(spec.query_modality).name
                self.metrics[(query_modality, dataset_index)] = (
                    self._create_metrics(
                        num_classes,
                        spec.top_k,
                        prefix=f"{dataset_name}/{query_modality}_",
                        postfix="",
                    )
                )

    for metric in self.metrics.values():
        metric.to(pl_module.device)

    for dataset_index, dataset_info in self.all_dataset_info.items():
        id2label = dataset_info["id2label"]
        prompt_templates: list[str] = dataset_info["prompt_templates"]
        labels = list(id2label.values())

        with torch.no_grad():
            chunk_size = 10
            all_embeddings = []

            for i in tqdm(
                range(0, len(labels), chunk_size),
                desc="Encoding class descriptions",
            ):
                batch_labels = labels[i : min(i + chunk_size, len(labels))]
                descriptions = [
                    template.format(label)
                    for label in batch_labels
                    for template in prompt_templates
                ]
                tokenized_descriptions = move_data_to_device(
                    self.tokenizer(descriptions),
                    pl_module.device,
                )

                # Encode the chunk using the pl_module's encode method
                chunk_embeddings = pl_module.encode(
                    tokenized_descriptions, Modalities.TEXT
                )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                chunk_embeddings = chunk_embeddings.reshape(
                    len(batch_labels), len(prompt_templates), -1
                ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                # Append the chunk embeddings to the list
                all_embeddings.append(chunk_embeddings)

            # Concatenate all chunk embeddings into a single tensor
            class_embeddings = torch.cat(all_embeddings, dim=0)

        self._embeddings_store[dataset_index] = class_embeddings
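
The setup above follows the standard prompt-ensembling recipe: each class label is formatted into every prompt template, the prompt embeddings are normalized and averaged per label, and the average is re-normalized. Below is a minimal, self-contained sketch of that recipe, independent of Lightning; encode_text is a hypothetical stand-in for the tokenizer plus pl_module.encode, and the 512-dimensional random encoder exists only to make the sketch runnable.

import torch

def build_class_embeddings(labels, templates, encode_text):
    # Prompt-ensembled class embeddings: embed every (template, label) prompt,
    # average over templates per label, then re-normalize.
    embeddings = []
    for label in labels:
        prompts = [template.format(label) for template in templates]
        emb = encode_text(prompts)                       # (len(templates), embed_dim)
        emb = emb / emb.norm(p=2, dim=-1, keepdim=True)  # normalize each prompt embedding
        emb = emb.mean(dim=0)                            # average over templates
        emb = emb / emb.norm(p=2)                        # re-normalize the class embedding
        embeddings.append(emb)
    return torch.stack(embeddings)                       # (num_classes, embed_dim)

fake_encoder = lambda prompts: torch.randn(len(prompts), 512)  # hypothetical text encoder
class_embeddings = build_class_embeddings(
    ["tench", "goldfish"], ["a photo of a {}."], fake_encoder
)
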
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Compute logits and update metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A batch of data.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_classification.py
def evaluation_step(
    self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
) -> None:
    """Compute logits and update metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A batch of data.
    batch_idx : int
        The index of the batch.
    """
    if pl_module.trainer.sanity_checking:
        return

    for (query_modality, dataset_index), metric_collection in self.metrics.items():
        matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

        if not matching_indices.numel():
            continue

        class_embeddings = self._embeddings_store[dataset_index]
        query_embeddings: torch.Tensor = pl_module.encode(
            batch, Modalities.get_modality(query_modality)
        )
        query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
        query_embeddings = query_embeddings[matching_indices]

        if self.all_dataset_info[dataset_index]["num_classes"] == 2:
            softmax_output = _safe_matmul(
                query_embeddings, class_embeddings
            ).softmax(dim=-1)
            logits = softmax_output[:, 1] - softmax_output[:, 0]
        else:
            logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
        targets = batch[Modalities.get_modality(query_modality).target][
            matching_indices
        ]

        metric_collection.update(logits, targets)
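
For intuition, the logits above are cosine similarities between the L2-normalized query embeddings and the precomputed class embeddings, scaled by 100 in the multi-class case; for binary datasets the score is the difference between the two softmaxed class probabilities. A small shape-level sketch with random tensors (here _safe_matmul is assumed to be an ordinary matrix product with the class matrix transposed):

import torch
import torch.nn.functional as F

# Hypothetical shapes: 8 query embeddings, 1000 classes, 512-dim embeddings.
query = F.normalize(torch.randn(8, 512), dim=-1)
class_embeddings = F.normalize(torch.randn(1000, 512), dim=-1)

# Multi-class case: cosine similarity scaled by 100.
logits = 100.0 * query @ class_embeddings.t()   # (8, 1000)
predictions = logits.argmax(dim=-1)

# Binary case: difference between the softmaxed class probabilities.
binary_class_embeddings = F.normalize(torch.randn(2, 512), dim=-1)
probs = (query @ binary_class_embeddings.t()).softmax(dim=-1)   # (8, 2)
binary_scores = probs[:, 1] - probs[:, 0]
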
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Compute and reset metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
dict[str, Any]

The computed metrics.

Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
    """Compute and reset metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    dict[str, Any]
        The computed metrics.
    """
    results = {}
    for metric_collection in self.metrics.values():
        results.update(metric_collection.compute())
        metric_collection.reset()

    self._embeddings_store.clear()

    eval_type = "val" if pl_module.trainer.validating else "test"
    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results

ZeroShotCrossModalRetrieval

Bases: EvaluationHooks

Zero-shot cross-modal retrieval evaluation task.

This task evaluates the retrieval performance of a model on a set of query-target pairs. The model is expected to produce embeddings for both the query and target modalities. The task computes the retrieval recall at k for each pair of modalities.

Parameters:

Name Type Description Default
task_specs list[RetrievalTaskSpec]

A list of retrieval task specifications. Each specification defines the query and target modalities, as well as the top-k values for which to compute the retrieval recall metrics.

required
Source code in mmlearn/tasks/zero_shot_retrieval.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotCrossModalRetrieval(EvaluationHooks):
    """Zero-shot cross-modal retrieval evaluation task.

    This task evaluates the retrieval performance of a model on a set of query-target
    pairs. The model is expected to produce embeddings for both the query and target
    modalities. The task computes the retrieval recall at `k` for each pair of
    modalities.

    Parameters
    ----------
    task_specs : list[RetrievalTaskSpec]
        A list of retrieval task specifications. Each specification defines the query
        and target modalities, as well as the top-k values for which to compute the
        retrieval recall metrics.

    """

    def __init__(self, task_specs: list[RetrievalTaskSpec]) -> None:
        super().__init__()

        self.task_specs = task_specs
        self.metrics: dict[tuple[str, str], MetricCollection] = {}
        self._available_modalities = set()

        for spec in self.task_specs:
            query_modality = spec.query_modality
            target_modality = spec.target_modality
            assert Modalities.has_modality(query_modality)
            assert Modalities.has_modality(target_modality)

            self.metrics[(query_modality, target_modality)] = MetricCollection(
                {
                    f"{query_modality}_to_{target_modality}_R@{k}": RetrievalRecallAtK(
                        top_k=k, aggregation="mean", reduction="none"
                    )
                    for k in spec.top_k
                }
            )
            self._available_modalities.add(query_modality)
            self._available_modalities.add(target_modality)

    def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
        """Move the metrics to the device of the Lightning module."""
        for metric in self.metrics.values():
            metric.to(pl_module.device)

    def evaluation_step(
        self,
        pl_module: pl.LightningModule,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
    ) -> None:
        """Run the forward pass and update retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A dictionary of batched input tensors.
        batch_idx : int
            The index of the batch.

        """
        if pl_module.trainer.sanity_checking:
            return

        outputs: dict[str, Any] = {}
        for modality_name in self._available_modalities:
            if modality_name in batch:
                outputs[modality_name] = pl_module.encode(
                    batch, Modalities.get_modality(modality_name), normalize=False
                )
        for (query_modality, target_modality), metric in self.metrics.items():
            if query_modality not in outputs or target_modality not in outputs:
                continue
            query_embeddings: torch.Tensor = outputs[query_modality]
            target_embeddings: torch.Tensor = outputs[target_modality]
            indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

            metric.update(query_embeddings, target_embeddings, indexes)

    def on_evaluation_epoch_end(
        self, pl_module: pl.LightningModule
    ) -> Optional[dict[str, Any]]:
        """Compute the retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        Optional[dict[str, Any]]
            A dictionary of evaluation results or `None` if no results are available.
        """
        if pl_module.trainer.sanity_checking:
            return None

        results = {}
        for metric in self.metrics.values():
            results.update(metric.compute())
            metric.reset()

        eval_type = "val" if pl_module.trainer.validating else "test"

        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results
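
A hedged construction sketch: the field names query_modality, target_modality and top_k mirror the attributes read from each spec in __init__ above, and RetrievalTaskSpec is assumed to be importable from the same module as the task; the modality names "text" and "rgb" are illustrative and must correspond to modalities registered with Modalities.

from mmlearn.tasks.zero_shot_retrieval import (
    RetrievalTaskSpec,
    ZeroShotCrossModalRetrieval,
)

task = ZeroShotCrossModalRetrieval(
    task_specs=[
        RetrievalTaskSpec(query_modality="text", target_modality="rgb", top_k=[10, 200]),
        RetrievalTaskSpec(query_modality="rgb", target_modality="text", top_k=[10, 200]),
    ]
)
# Creates metrics such as "text_to_rgb_R@10" and "rgb_to_text_R@200",
# one MetricCollection per (query, target) modality pair.
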
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Move the metrics to the device of the Lightning module.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
    """Move the metrics to the device of the Lightning module."""
    for metric in self.metrics.values():
        metric.to(pl_module.device)
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Run the forward pass and update retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A dictionary of batched input tensors.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_retrieval.py
def evaluation_step(
    self,
    pl_module: pl.LightningModule,
    batch: dict[str, torch.Tensor],
    batch_idx: int,
) -> None:
    """Run the forward pass and update retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A dictionary of batched input tensors.
    batch_idx : int
        The index of the batch.

    """
    if pl_module.trainer.sanity_checking:
        return

    outputs: dict[str, Any] = {}
    for modality_name in self._available_modalities:
        if modality_name in batch:
            outputs[modality_name] = pl_module.encode(
                batch, Modalities.get_modality(modality_name), normalize=False
            )
    for (query_modality, target_modality), metric in self.metrics.items():
        if query_modality not in outputs or target_modality not in outputs:
            continue
        query_embeddings: torch.Tensor = outputs[query_modality]
        target_embeddings: torch.Tensor = outputs[target_modality]
        indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

        metric.update(query_embeddings, target_embeddings, indexes)
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Compute the retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
Optional[dict[str, Any]]

A dictionary of evaluation results or None if no results are available.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_end(
    self, pl_module: pl.LightningModule
) -> Optional[dict[str, Any]]:
    """Compute the retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    Optional[dict[str, Any]]
        A dictionary of evaluation results or `None` if no results are available.
    """
    if pl_module.trainer.sanity_checking:
        return None

    results = {}
    for metric in self.metrics.values():
        results.update(metric.compute())
        metric.reset()

    eval_type = "val" if pl_module.trainer.validating else "test"

    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results

find_matching_indices

find_matching_indices(
    first_example_ids, second_example_ids
)

Find the indices of matching examples given two tensors of example ids.

Matching examples are defined as examples with the same value in both tensors. This method is useful for finding pairs of examples from different modalities that are related to each other in a batch.

Parameters:

Name Type Description Default
first_example_ids Tensor

A tensor of example ids of shape (N, 2), where N is the number of examples.

required
second_example_ids Tensor

A tensor of example ids of shape (M, 2), where M is the number of examples.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of tensors containing the indices of matching examples in the first and second tensor, respectively.

Raises:

Type Description
TypeError

If either first_example_ids or second_example_ids is not a tensor.

ValueError

If either first_example_ids or second_example_ids is not a 2D tensor with the second dimension having a size of 2.

Examples:

>>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
>>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
>>> find_matching_indices(img_example_ids, text_example_ids)
(tensor([2, 3]), tensor([0, 1]))
Source code in mmlearn/datasets/core/example.py
def find_matching_indices(
    first_example_ids: torch.Tensor, second_example_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find the indices of matching examples given two tensors of example ids.

    Matching examples are defined as examples with the same value in both tensors.
    This method is useful for finding pairs of examples from different modalities
    that are related to each other in a batch.

    Parameters
    ----------
    first_example_ids : torch.Tensor
        A tensor of example ids of shape `(N, 2)`, where `N` is the number of examples.
    second_example_ids : torch.Tensor
        A tensor of example ids of shape `(M, 2)`, where `M` is the number of examples.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple of tensors containing the indices of matching examples in the first and
        second tensor, respectively.

    Raises
    ------
    TypeError
        If either `first_example_ids` or `second_example_ids` is not a tensor.
    ValueError
        If either `first_example_ids` or `second_example_ids` is not a 2D tensor
        with the second dimension having a size of `2`.

    Examples
    --------
    >>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
    >>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
    >>> find_matching_indices(img_example_ids, text_example_ids)
    (tensor([2, 3]), tensor([0, 1]))


    """
    if not isinstance(first_example_ids, torch.Tensor) or not isinstance(
        second_example_ids,
        torch.Tensor,
    ):
        raise TypeError(
            f"Expected inputs to be tensors, but got {type(first_example_ids)} "
            f"and {type(second_example_ids)}.",
        )
    val = 2
    if not (first_example_ids.ndim == val and first_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `first_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {first_example_ids.shape}.",
        )
    if not (second_example_ids.ndim == val and second_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `second_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {second_example_ids.shape}.",
        )

    first_example_ids = first_example_ids.unsqueeze(1)  # shape=(N, 1, 2)
    second_example_ids = second_example_ids.unsqueeze(0)  # shape=(1, M, 2)

    # compare all elements; results in a shape (N, M) tensor
    matches = torch.all(first_example_ids == second_example_ids, dim=-1)
    first_indices, second_indices = torch.where(matches)
    return first_indices, second_indices
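
A short usage sketch extending the docstring example: the returned index tensors can be used to align per-modality embeddings within a batch. The embedding tensors below are random and purely illustrative.

import torch
from mmlearn.datasets.core.example import find_matching_indices

img_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
txt_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
img_idx, txt_idx = find_matching_indices(img_ids, txt_ids)   # (tensor([2, 3]), tensor([0, 1]))

img_emb = torch.randn(4, 512)   # hypothetical image embeddings, one row per image example
txt_emb = torch.randn(5, 512)   # hypothetical text embeddings, one row per text example
paired_img, paired_txt = img_emb[img_idx], txt_emb[txt_idx]  # aligned (2, 512) pairs
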

linear_warmup_cosine_annealing_lr

linear_warmup_cosine_annealing_lr(
    optimizer,
    warmup_steps,
    max_steps,
    start_factor=1 / 3,
    eta_min=0.0,
    last_epoch=-1,
)

Create a linear warmup cosine annealing learning rate scheduler.

Parameters:

Name Type Description Default
optimizer Optimizer

The optimizer for which to schedule the learning rate.

required
warmup_steps int

Maximum number of iterations for linear warmup.

required
max_steps int

Maximum number of iterations.

required
start_factor float

Multiplicative factor for the learning rate at the start of the warmup phase.

1/3
eta_min float

Minimum learning rate.

0
last_epoch int

The index of the last epoch. If set to -1, the learning rate is initialized to the base learning rate.

-1

Returns:

Type Description
LRScheduler

The learning rate scheduler.

Raises:

Type Description
ValueError

If warmup_steps is greater than or equal to max_steps or if warmup_steps is less than or equal to 0.

Source code in mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py
@store(  # type: ignore[misc]
    group="modules/lr_schedulers",
    provider="mmlearn",
    zen_partial=True,
    warmup_steps=MISSING,
    max_steps=MISSING,
)
def linear_warmup_cosine_annealing_lr(
    optimizer: Optimizer,
    warmup_steps: int,
    max_steps: int,
    start_factor: float = 1 / 3,
    eta_min: float = 0.0,
    last_epoch: int = -1,
) -> LRScheduler:
    """Create a linear warmup cosine annealing learning rate scheduler.

    Parameters
    ----------
    optimizer : Optimizer
        The optimizer for which to schedule the learning rate.
    warmup_steps : int
        Maximum number of iterations for linear warmup.
    max_steps : int
        Maximum number of iterations.
    start_factor : float, optional, default=1/3
        Multiplicative factor for the learning rate at the start of the warmup phase.
    eta_min : float, optional, default=0
        Minimum learning rate.
    last_epoch : int, optional, default=-1
        The index of last epoch. If set to ``-1``, it initializes the learning rate
        as the base learning rate

    Returns
    -------
    LRScheduler
        The learning rate scheduler.

    Raises
    ------
    ValueError
        If `warmup_steps` is greater than or equal to `max_steps` or if `warmup_steps`
        is less than or equal to 0.
    """
    if warmup_steps >= max_steps:
        raise ValueError(
            "Expected `warmup_steps` to be less than `max_steps` but got "
            f"`warmup_steps={warmup_steps}` and `max_steps={max_steps}`."
        )
    if warmup_steps <= 0:
        raise ValueError(
            "Expected `warmup_steps` to be positive but got "
            f"`warmup_steps={warmup_steps}`."
        )

    linear_lr = LinearLR(
        optimizer,
        start_factor=start_factor,
        total_iters=warmup_steps,
        last_epoch=last_epoch,
    )
    cosine_lr = CosineAnnealingLR(
        optimizer,
        T_max=max_steps - warmup_steps,
        eta_min=eta_min,
        last_epoch=last_epoch,
    )
    return SequentialLR(
        optimizer,
        schedulers=[linear_lr, cosine_lr],
        milestones=[warmup_steps],
        last_epoch=last_epoch,
    )
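
A minimal usage sketch with a toy model and optimizer; the import path follows the file path shown above, and the step counts are arbitrary example values.

import torch
from mmlearn.modules.lr_schedulers.linear_warmup_cosine_lr import (
    linear_warmup_cosine_annealing_lr,
)

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = linear_warmup_cosine_annealing_lr(
    optimizer, warmup_steps=1_000, max_steps=10_000, start_factor=1 / 3, eta_min=1e-6
)

for _ in range(10_000):
    optimizer.step()     # actual training step elided
    scheduler.step()     # LR ramps linearly for 1k steps, then follows a cosine decay to eta_min
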

main

main(cfg)

Entry point for training or evaluation.

Source code in mmlearn/cli/run.py
@_hydra_main(
    config_path="pkg://mmlearn.conf", config_name="base_config", version_base=None
)
def main(cfg: MMLearnConf) -> None:  # noqa: PLR0912
    """Entry point for training or evaluation."""
    cfg_copy = copy.deepcopy(cfg)  # copy of the config for logging

    L.seed_everything(cfg.seed, workers=True)

    if is_torch_tf32_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        if "16-mixed" in str(cfg.trainer.precision):
            cfg.trainer.precision = "bf16-mixed"

    # setup trainer first so that we can get some variables for distributed training
    callbacks = instantiate_callbacks(cfg.trainer.get("callbacks"))
    cfg.trainer["callbacks"] = None  # will be replaced with the instantiated object
    loggers = instantiate_loggers(cfg.trainer.get("logger"))
    cfg.trainer["logger"] = None
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=loggers, _convert_="all"
    )
    assert isinstance(trainer, Trainer), (
        "Trainer must be an instance of `lightning.pytorch.trainer.Trainer`"
    )

    if rank_zero_only.rank == 0 and loggers is not None:  # update wandb config
        for trainer_logger in loggers:
            if isinstance(trainer_logger, WandbLogger):
                trainer_logger.experiment.config.update(
                    OmegaConf.to_container(cfg_copy, resolve=True, enum_to_str=True),
                    allow_val_change=True,
                )
    trainer.print(OmegaConf.to_yaml(cfg_copy, resolve=True))

    requires_distributed_sampler = (
        trainer.distributed_sampler_kwargs is not None
        and trainer._accelerator_connector.use_distributed_sampler
    )
    if requires_distributed_sampler:  # we handle distributed samplers
        trainer._accelerator_connector.use_distributed_sampler = False

    # prepare dataloaders
    if cfg.job_type == JobType.train:
        train_dataset = instantiate_datasets(cfg.datasets.train)
        assert train_dataset is not None, (
            "Train dataset (`cfg.datasets.train`) is required for training."
        )

        train_sampler = instantiate_sampler(
            cfg.dataloader.train.get("sampler"),
            train_dataset,
            requires_distributed_sampler=requires_distributed_sampler,
            distributed_sampler_kwargs=trainer.distributed_sampler_kwargs,
        )
        cfg.dataloader.train["sampler"] = None  # replaced with the instantiated object
        train_loader: DataLoader = hydra.utils.instantiate(
            cfg.dataloader.train, dataset=train_dataset, sampler=train_sampler
        )

        val_loader: Optional[DataLoader] = None
        val_dataset = instantiate_datasets(cfg.datasets.val)
        if val_dataset is not None:
            val_sampler = instantiate_sampler(
                cfg.dataloader.val.get("sampler"),
                val_dataset,
                requires_distributed_sampler=requires_distributed_sampler,
                distributed_sampler_kwargs=trainer.distributed_sampler_kwargs,
            )
            cfg.dataloader.val["sampler"] = None
            val_loader = hydra.utils.instantiate(
                cfg.dataloader.val, dataset=val_dataset, sampler=val_sampler
            )
    else:
        test_dataset = instantiate_datasets(cfg.datasets.test)
        assert test_dataset is not None, (
            "Test dataset (`cfg.datasets.test`) is required for evaluation."
        )

        test_sampler = instantiate_sampler(
            cfg.dataloader.test.get("sampler"),
            test_dataset,
            requires_distributed_sampler=requires_distributed_sampler,
            distributed_sampler_kwargs=trainer.distributed_sampler_kwargs,
        )
        cfg.dataloader.test["sampler"] = None
        test_loader = hydra.utils.instantiate(
            cfg.dataloader.test, dataset=test_dataset, sampler=test_sampler
        )

    # setup task module
    if cfg.task is None or "_target_" not in cfg.task:
        raise ValueError(
            "Expected a non-empty config for `cfg.task` with a `_target_` key. "
            f"But got: {cfg.task}"
        )
    logger.info(f"Instantiating task module: {cfg.task['_target_']}")
    model: L.LightningModule = hydra.utils.instantiate(cfg.task, _convert_="partial")
    assert isinstance(model, L.LightningModule), "Task must be a `LightningModule`"
    model.strict_loading = cfg.strict_loading

    # compile model
    model = torch.compile(model, **OmegaConf.to_object(cfg.torch_compile_kwargs))

    if cfg.job_type == JobType.train:
        trainer.fit(
            model, train_loader, val_loader, ckpt_path=cfg.resume_from_checkpoint
        )
    elif cfg.job_type == JobType.eval:
        trainer.test(model, test_loader, ckpt_path=cfg.resume_from_checkpoint)

Configuration Module

mmlearn.conf

Hydra/Hydra-zen-based configurations.

JobType

Bases: str, Enum

Type of the job.

Source code in mmlearn/conf/__init__.py
class JobType(str, Enum):
    """Type of the job."""

    train = "train"
    eval = "eval"

DatasetConf dataclass

Configuration template for the datasets.

Source code in mmlearn/conf/__init__.py
@dataclass
class DatasetConf:
    """Configuration template for the datasets."""

    #: Configuration for the training dataset.
    train: Optional[Any] = field(
        default=None,
        metadata={"help": "Configuration for the training dataset."},
    )
    #: Configuration for the validation dataset.
    val: Optional[Any] = field(
        default=None, metadata={"help": "Configuration for the validation dataset."}
    )
    #: Configuration for the test dataset.
    test: Optional[Any] = field(
        default=None,
        metadata={"help": "Configuration for the test dataset."},
    )

DataLoaderConf dataclass

Configuration for the dataloader.

Source code in mmlearn/conf/__init__.py
@dataclass
class DataLoaderConf:
    """Configuration for the dataloader."""

    #: Configuration for the training dataloader.
    train: Any = field(
        default_factory=_DataLoaderConf,
        metadata={"help": "Configuration for the training dataloader."},
    )
    #: Configuration for the validation dataloader.
    val: Any = field(
        default_factory=_DataLoaderConf,
        metadata={"help": "Configuration for the validation dataloader."},
    )
    #: Configuration for the test dataloader.
    test: Any = field(
        default_factory=_DataLoaderConf,
        metadata={"help": "Configuration for the test dataloader."},
    )

MMLearnConf dataclass

Top-level configuration for mmlearn experiments.

Source code in mmlearn/conf/__init__.py
@dataclass
class MMLearnConf:
    """Top-level configuration for mmlearn experiments."""

    defaults: list[Any] = field(
        default_factory=lambda: [
            "_self_",  # See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/default_composition_order for more information
            {"task": MISSING},
            {"override hydra/launcher": "submitit_slurm"},
        ]
    )
    #: Name of the experiment. This must be specified for any experiment to run.
    experiment_name: str = field(default=MISSING)
    #: Type of the job.
    job_type: JobType = field(default=JobType.train)
    #: Seed for the random number generators. This is set for Python, Numpy and PyTorch,
    #: including the workers in PyTorch Dataloaders.
    seed: Optional[int] = field(default=None)
    #: Configuration for the datasets.
    datasets: DatasetConf = field(default_factory=DatasetConf)
    #: Configuration for the dataloaders.
    dataloader: DataLoaderConf = field(default_factory=DataLoaderConf)
    #: Configuration for the task. This is required to run any experiment.
    task: Any = field(default=MISSING)
    #: Configuration for the trainer. The options here are the same as in
    #: :py:class:`~lightning.pytorch.trainer.trainer.Trainer`
    trainer: Any = field(
        default_factory=builds(
            lightning_trainer.Trainer,
            populate_full_signature=True,
            enable_model_summary=True,
            enable_progress_bar=True,
            enable_checkpointing=True,
            default_root_dir=_get_default_ckpt_dir(),
        )
    )
    #: Tags for the experiment. This is useful for `wandb <https://docs.wandb.ai/ref/python/init>`_
    #: logging.
    tags: Optional[list[str]] = field(default_factory=lambda: [II("experiment_name")])
    #: Path to the checkpoint to resume training from.
    resume_from_checkpoint: Optional[Path] = field(default=None)
    #: Whether to strictly enforce loading of model weights i.e. `strict=True` in
    #: :py:meth:`~lightning.pytorch.core.module.LightningModule.load_from_checkpoint`.
    strict_loading: bool = field(default=True)
    #: Configuration for torch.compile. These are essentially the same as the
    #: arguments for :py:func:`torch.compile`.
    torch_compile_kwargs: dict[str, Any] = field(
        default_factory=lambda: {
            "disable": True,
            "fullgraph": False,
            "dynamic": None,
            "backend": "inductor",
            "mode": None,
            "options": None,
        }
    )
    #: Hydra configuration.
    hydra: HydraConf = field(
        default_factory=lambda: HydraConf(
            searchpath=["pkg://mmlearn.conf"],
            run=RunDir(
                dir=SI("./outputs/${experiment_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}")
            ),
            sweep=SweepDir(
                dir=SI("./outputs/${experiment_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}"),
                subdir=SI("${hydra.job.num}_${hydra.job.id}"),
            ),
            help=HelpConf(
                app_name="mmlearn",
                header="mmlearn: A modular framework for research on multimodal representation learning.",
            ),
            job=JobConf(
                name=II("experiment_name"),
                env_set={
                    "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
                    "HYDRA_FULL_ERROR": "1",
                },
            ),
        )
    )
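
For illustration only: in practice this config is composed by Hydra from experiment YAML files and command-line overrides, but the dataclass can also be constructed directly. The field values below are hypothetical, and the required task field is left at its MISSING default.

from mmlearn.conf import JobType, MMLearnConf

cfg = MMLearnConf(experiment_name="my_experiment", job_type=JobType.train, seed=42)
print(cfg.job_type)                          # JobType.train
print(cfg.torch_compile_kwargs["disable"])   # True -- torch.compile is disabled by default
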

register_external_modules

register_external_modules(
    module,
    group,
    name=None,
    package=None,
    provider=None,
    base_cls=None,
    ignore_cls=None,
    ignore_prefix=None,
    **kwargs_for_builds
)

Add all classes in an external module to a ZenStore.

Parameters:

Name Type Description Default
module ModuleType

The module to add classes from.

required
group str

The config group to add the classes to.

required
name Optional[str]

The name to give to the dynamically-generated configs. If None, the class name is used.

None
package Optional[str]

The package to add the configs to.

None
provider Optional[str]

The provider to add the configs to.

None
base_cls Optional[type]

The base class to filter classes by. The base class is also excluded from the configs.

None
ignore_cls Optional[list[type]]

list of classes to ignore.

None
ignore_prefix Optional[str]

Ignore classes whose names start with this prefix.

None
kwargs_for_builds Any

Additional keyword arguments to pass to hydra_zen.builds.

{}
Source code in mmlearn/conf/__init__.py
def register_external_modules(
    module: ModuleType,
    group: str,
    name: Optional[str] = None,
    package: Optional[str] = None,
    provider: Optional[str] = None,
    base_cls: Optional[type] = None,
    ignore_cls: Optional[list[type]] = None,
    ignore_prefix: Optional[str] = None,
    **kwargs_for_builds: Any,
) -> None:
    """Add all classes in an external module to a ZenStore.

    Parameters
    ----------
    module : ModuleType
        The module to add classes from.
    group : str
        The config group to add the classes to.
    name : Optional[str], optional, default=None
        The name to give to the dynamically-generated configs. If `None`, the
        class name is used.
    package : Optional[str], optional, default=None
        The package to add the configs to.
    provider : Optional[str], optional, default=None
        The provider to add the configs to.
    base_cls : Optional[type], optional, default=None
        The base class to filter classes by. The base class is also excluded from
        the configs.
    ignore_cls : Optional[list[type]], optional, default=None
        list of classes to ignore.
    ignore_prefix : Optional[str], optional, default=None
        Ignore classes whose names start with this prefix.
    kwargs_for_builds : Any
        Additional keyword arguments to pass to ``hydra_zen.builds``.

    """
    for key, cls in module.__dict__.items():
        if (
            isinstance(cls, type)
            and (base_cls is None or issubclass(cls, base_cls))
            and cls != base_cls
            and (ignore_cls is None or cls not in ignore_cls)
            and (ignore_prefix is None or not key.startswith(ignore_prefix))
        ):
            external_store(
                builds(cls, populate_full_signature=True, **kwargs_for_builds),
                name=name or key,
                group=group,
                package=package,
                provider=provider,
            )
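
A hedged usage sketch: auto-generating configs for the torch.nn.Module subclasses defined in torchvision.models and adding them to a config group. The group and provider names are illustrative choices, not values prescribed by mmlearn.

import torchvision.models as tv_models
from torch import nn

from mmlearn.conf import register_external_modules

register_external_modules(
    tv_models,
    group="modules/encoders",    # illustrative config group
    provider="torchvision",
    base_cls=nn.Module,          # only register nn.Module subclasses
    ignore_prefix="_",           # skip private names
)
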

Datasets Module

mmlearn.datasets

Datasets.

CheXpert

Bases: Dataset[Example]

CheXpert dataset.

Each datapoint is a pair of (image, target label).

Parameters:

Name Type Description Default
root_dir str

Directory containing the .json files that list all dataset entries.

required
split (train, valid)

Dataset split.

"train"
labeler Optional[{chexpert, chexbert, vchexbert}]

Labeler used to extract labels from the training images. The "valid" split has no labeler; its labels were assigned by human radiologists.

None
transform Optional[Callable[[PIL.Image], torch.Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
Source code in mmlearn/datasets/chexpert.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("CHEXPERT_ROOT_DIR", MISSING),
    split="train",
)
class CheXpert(Dataset[Example]):
    """CheXpert dataset.

    Each datapoint is a pair of `(image, target label)`.

    Parameters
    ----------
    root_dir : str
        Directory which contains `.json` files stating all dataset entries.
    split : {"train", "valid"}
        Dataset split.
    labeler : Optional[{"chexpert", "chexbert", "vchexbert"}], optional, default=None
        Labeler used to extract labels from the training images. "valid" split
        has no labeler, labeling for valid split was done by human radiologists.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "valid"],
        labeler: Optional[Literal["chexpert", "chexbert", "vchexbert"]] = None,
        transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
    ) -> None:
        assert split in ["train", "valid"], f"split {split} is not available."
        assert labeler in ["chexpert", "chexbert", "vchexbert"] or labeler is None, (
            f"labeler {labeler} is not available."
        )
        assert callable(transform) or transform is None, (
            "transform is not callable or None."
        )

        if split == "valid":
            data_file = f"{split}_data.json"
        elif split == "train":
            data_file = f"{labeler}_{split}_data.json"
        data_path = os.path.join(root_dir, data_file)

        assert os.path.isfile(data_path), f"entries file does not exist: {data_path}."

        with open(data_path, "rb") as file:
            entries = json.load(file)
        self.entries = entries

        if transform is not None:
            self.transform = transform
        else:
            self.transform = Compose([Resize(224), CenterCrop(224), ToTensor()])

    def __getitem__(self, idx: int) -> Example:
        """Return the idx'th datapoint."""
        entry = self.entries[idx]
        image = Image.open(entry["image_path"]).convert("RGB")
        image = self.transform(image)
        label = torch.tensor(entry["label"])

        return Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: label,
                "qid": entry["qid"],
                EXAMPLE_INDEX_KEY: idx,
            }
        )

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.entries)

__getitem__

__getitem__(idx)

Return the idx'th datapoint.

Source code in mmlearn/datasets/chexpert.py
def __getitem__(self, idx: int) -> Example:
    """Return the idx'th datapoint."""
    entry = self.entries[idx]
    image = Image.open(entry["image_path"]).convert("RGB")
    image = self.transform(image)
    label = torch.tensor(entry["label"])

    return Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: label,
            "qid": entry["qid"],
            EXAMPLE_INDEX_KEY: idx,
        }
    )

__len__

__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/chexpert.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.entries)
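
A minimal usage sketch, assuming CheXpert is exposed at the package level as listed here and that the root directory contains the expected entry files (e.g. chexbert_train_data.json for the train split, valid_data.json for the valid split). The path below is hypothetical.

from mmlearn.datasets import CheXpert

dataset = CheXpert(
    root_dir="/path/to/chexpert",   # hypothetical path containing the .json entry files
    split="train",
    labeler="chexbert",
)
example = dataset[0]    # an Example with the RGB image tensor, target label, "qid" and index
print(len(dataset))     # number of entries in the split
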

ImageNet

Bases: ImageFolder

ImageNet dataset.

This is a wrapper around the torchvision.datasets.ImageFolder class that returns an mmlearn.datasets.core.example.Example object.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, val)

The split of the dataset to use.

"train"
transform Optional[Callable]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
target_transform Optional[Callable]

A callable that takes in the target and transforms it.

None
mask_generator Optional[Callable]

A callable that generates a mask for the image.

None
Source code in mmlearn/datasets/imagenet.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("IMAGENET_ROOT_DIR", MISSING),
)
class ImageNet(ImageFolder):
    """ImageNet dataset.

    This is a wrapper around the :py:class:`~torchvision.datasets.ImageFolder` class
    that returns an :py:class:`~mmlearn.datasets.core.example.Example` object.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "val"}, default="train"
        The split of the dataset to use.
    transform : Optional[Callable], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    target_transform : Optional[Callable], optional, default=None
        A callable that takes in the target and transforms it.
    mask_generator : Optional[Callable], optional, default=None
        A callable that generates a mask for the image.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "val"] = "train",
        transform: Optional[Callable[..., Any]] = None,
        target_transform: Optional[Callable[..., Any]] = None,
        mask_generator: Optional[Callable[..., Any]] = None,
    ) -> None:
        split = "train" if split == "train" else "val"
        root_dir = os.path.join(root_dir, split)
        super().__init__(
            root=root_dir, transform=transform, target_transform=target_transform
        )
        self.mask_generator = mask_generator

    def __getitem__(self, index: int) -> Example:
        """Get an example at the given index."""
        image, target = super().__getitem__(index)
        example = Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: target,
                EXAMPLE_INDEX_KEY: index,
            }
        )
        mask = self.mask_generator() if self.mask_generator else None
        if mask is not None:  # error will be raised during collation if `None`
            example[Modalities.RGB.mask] = mask
        return example

    @property
    def zero_shot_prompt_templates(self) -> list[str]:
        """Return the zero-shot prompt templates."""
        return [
            "a bad photo of a {}.",
            "a photo of many {}.",
            "a sculpture of a {}.",
            "a photo of the hard to see {}.",
            "a low resolution photo of the {}.",
            "a rendering of a {}.",
            "graffiti of a {}.",
            "a bad photo of the {}.",
            "a cropped photo of the {}.",
            "a tattoo of a {}.",
            "the embroidered {}.",
            "a photo of a hard to see {}.",
            "a bright photo of a {}.",
            "a photo of a clean {}.",
            "a photo of a dirty {}.",
            "a dark photo of the {}.",
            "a drawing of a {}.",
            "a photo of my {}.",
            "the plastic {}.",
            "a photo of the cool {}.",
            "a close-up photo of a {}.",
            "a black and white photo of the {}.",
            "a painting of the {}.",
            "a painting of a {}.",
            "a pixelated photo of the {}.",
            "a sculpture of the {}.",
            "a bright photo of the {}.",
            "a cropped photo of a {}.",
            "a plastic {}.",
            "a photo of the dirty {}.",
            "a jpeg corrupted photo of a {}.",
            "a blurry photo of the {}.",
            "a photo of the {}.",
            "a good photo of the {}.",
            "a rendering of the {}.",
            "a {} in a video game.",
            "a photo of one {}.",
            "a doodle of a {}.",
            "a close-up photo of the {}.",
            "a photo of a {}.",
            "the origami {}.",
            "the {} in a video game.",
            "a sketch of a {}.",
            "a doodle of the {}.",
            "a origami {}.",
            "a low resolution photo of a {}.",
            "the toy {}.",
            "a rendition of the {}.",
            "a photo of the clean {}.",
            "a photo of a large {}.",
            "a rendition of a {}.",
            "a photo of a nice {}.",
            "a photo of a weird {}.",
            "a blurry photo of a {}.",
            "a cartoon {}.",
            "art of a {}.",
            "a sketch of the {}.",
            "a embroidered {}.",
            "a pixelated photo of a {}.",
            "itap of the {}.",
            "a jpeg corrupted photo of the {}.",
            "a good photo of a {}.",
            "a plushie {}.",
            "a photo of the nice {}.",
            "a photo of the small {}.",
            "a photo of the weird {}.",
            "the cartoon {}.",
            "art of the {}.",
            "a drawing of the {}.",
            "a photo of the large {}.",
            "a black and white photo of a {}.",
            "the plushie {}.",
            "a dark photo of a {}.",
            "itap of a {}.",
            "graffiti of the {}.",
            "a toy {}.",
            "itap of my {}.",
            "a photo of a cool {}.",
            "a photo of a small {}.",
            "a tattoo of the {}.",
        ]

    @property
    def id2label(self) -> dict[int, str]:
        """Return the label mapping."""
        return {
            0: "tench",
            1: "goldfish",
            2: "great white shark",
            3: "tiger shark",
            4: "hammerhead shark",
            5: "electric ray",
            6: "stingray",
            7: "rooster",
            8: "hen",
            9: "ostrich",
            10: "brambling",
            11: "goldfinch",
            12: "house finch",
            13: "junco",
            14: "indigo bunting",
            15: "American robin",
            16: "bulbul",
            17: "jay",
            18: "magpie",
            19: "chickadee",
            20: "American dipper",
            21: "kite (bird of prey)",
            22: "bald eagle",
            23: "vulture",
            24: "great grey owl",
            25: "fire salamander",
            26: "smooth newt",
            27: "newt",
            28: "spotted salamander",
            29: "axolotl",
            30: "American bullfrog",
            31: "tree frog",
            32: "tailed frog",
            33: "loggerhead sea turtle",
            34: "leatherback sea turtle",
            35: "mud turtle",
            36: "terrapin",
            37: "box turtle",
            38: "banded gecko",
            39: "green iguana",
            40: "Carolina anole",
            41: "desert grassland whiptail lizard",
            42: "agama",
            43: "frilled-necked lizard",
            44: "alligator lizard",
            45: "Gila monster",
            46: "European green lizard",
            47: "chameleon",
            48: "Komodo dragon",
            49: "Nile crocodile",
            50: "American alligator",
            51: "triceratops",
            52: "worm snake",
            53: "ring-necked snake",
            54: "eastern hog-nosed snake",
            55: "smooth green snake",
            56: "kingsnake",
            57: "garter snake",
            58: "water snake",
            59: "vine snake",
            60: "night snake",
            61: "boa constrictor",
            62: "African rock python",
            63: "Indian cobra",
            64: "green mamba",
            65: "sea snake",
            66: "Saharan horned viper",
            67: "eastern diamondback rattlesnake",
            68: "sidewinder rattlesnake",
            69: "trilobite",
            70: "harvestman",
            71: "scorpion",
            72: "yellow garden spider",
            73: "barn spider",
            74: "European garden spider",
            75: "southern black widow",
            76: "tarantula",
            77: "wolf spider",
            78: "tick",
            79: "centipede",
            80: "black grouse",
            81: "ptarmigan",
            82: "ruffed grouse",
            83: "prairie grouse",
            84: "peafowl",
            85: "quail",
            86: "partridge",
            87: "african grey parrot",
            88: "macaw",
            89: "sulphur-crested cockatoo",
            90: "lorikeet",
            91: "coucal",
            92: "bee eater",
            93: "hornbill",
            94: "hummingbird",
            95: "jacamar",
            96: "toucan",
            97: "duck",
            98: "red-breasted merganser",
            99: "goose",
            100: "black swan",
            101: "tusker",
            102: "echidna",
            103: "platypus",
            104: "wallaby",
            105: "koala",
            106: "wombat",
            107: "jellyfish",
            108: "sea anemone",
            109: "brain coral",
            110: "flatworm",
            111: "nematode",
            112: "conch",
            113: "snail",
            114: "slug",
            115: "sea slug",
            116: "chiton",
            117: "chambered nautilus",
            118: "Dungeness crab",
            119: "rock crab",
            120: "fiddler crab",
            121: "red king crab",
            122: "American lobster",
            123: "spiny lobster",
            124: "crayfish",
            125: "hermit crab",
            126: "isopod",
            127: "white stork",
            128: "black stork",
            129: "spoonbill",
            130: "flamingo",
            131: "little blue heron",
            132: "great egret",
            133: "bittern bird",
            134: "crane bird",
            135: "limpkin",
            136: "common gallinule",
            137: "American coot",
            138: "bustard",
            139: "ruddy turnstone",
            140: "dunlin",
            141: "common redshank",
            142: "dowitcher",
            143: "oystercatcher",
            144: "pelican",
            145: "king penguin",
            146: "albatross",
            147: "grey whale",
            148: "killer whale",
            149: "dugong",
            150: "sea lion",
            151: "Chihuahua",
            152: "Japanese Chin",
            153: "Maltese",
            154: "Pekingese",
            155: "Shih Tzu",
            156: "King Charles Spaniel",
            157: "Papillon",
            158: "toy terrier",
            159: "Rhodesian Ridgeback",
            160: "Afghan Hound",
            161: "Basset Hound",
            162: "Beagle",
            163: "Bloodhound",
            164: "Bluetick Coonhound",
            165: "Black and Tan Coonhound",
            166: "Treeing Walker Coonhound",
            167: "English foxhound",
            168: "Redbone Coonhound",
            169: "borzoi",
            170: "Irish Wolfhound",
            171: "Italian Greyhound",
            172: "Whippet",
            173: "Ibizan Hound",
            174: "Norwegian Elkhound",
            175: "Otterhound",
            176: "Saluki",
            177: "Scottish Deerhound",
            178: "Weimaraner",
            179: "Staffordshire Bull Terrier",
            180: "American Staffordshire Terrier",
            181: "Bedlington Terrier",
            182: "Border Terrier",
            183: "Kerry Blue Terrier",
            184: "Irish Terrier",
            185: "Norfolk Terrier",
            186: "Norwich Terrier",
            187: "Yorkshire Terrier",
            188: "Wire Fox Terrier",
            189: "Lakeland Terrier",
            190: "Sealyham Terrier",
            191: "Airedale Terrier",
            192: "Cairn Terrier",
            193: "Australian Terrier",
            194: "Dandie Dinmont Terrier",
            195: "Boston Terrier",
            196: "Miniature Schnauzer",
            197: "Giant Schnauzer",
            198: "Standard Schnauzer",
            199: "Scottish Terrier",
            200: "Tibetan Terrier",
            201: "Australian Silky Terrier",
            202: "Soft-coated Wheaten Terrier",
            203: "West Highland White Terrier",
            204: "Lhasa Apso",
            205: "Flat-Coated Retriever",
            206: "Curly-coated Retriever",
            207: "Golden Retriever",
            208: "Labrador Retriever",
            209: "Chesapeake Bay Retriever",
            210: "German Shorthaired Pointer",
            211: "Vizsla",
            212: "English Setter",
            213: "Irish Setter",
            214: "Gordon Setter",
            215: "Brittany dog",
            216: "Clumber Spaniel",
            217: "English Springer Spaniel",
            218: "Welsh Springer Spaniel",
            219: "Cocker Spaniel",
            220: "Sussex Spaniel",
            221: "Irish Water Spaniel",
            222: "Kuvasz",
            223: "Schipperke",
            224: "Groenendael dog",
            225: "Malinois",
            226: "Briard",
            227: "Australian Kelpie",
            228: "Komondor",
            229: "Old English Sheepdog",
            230: "Shetland Sheepdog",
            231: "collie",
            232: "Border Collie",
            233: "Bouvier des Flandres dog",
            234: "Rottweiler",
            235: "German Shepherd Dog",
            236: "Dobermann",
            237: "Miniature Pinscher",
            238: "Greater Swiss Mountain Dog",
            239: "Bernese Mountain Dog",
            240: "Appenzeller Sennenhund",
            241: "Entlebucher Sennenhund",
            242: "Boxer",
            243: "Bullmastiff",
            244: "Tibetan Mastiff",
            245: "French Bulldog",
            246: "Great Dane",
            247: "St. Bernard",
            248: "husky",
            249: "Alaskan Malamute",
            250: "Siberian Husky",
            251: "Dalmatian",
            252: "Affenpinscher",
            253: "Basenji",
            254: "pug",
            255: "Leonberger",
            256: "Newfoundland dog",
            257: "Great Pyrenees dog",
            258: "Samoyed",
            259: "Pomeranian",
            260: "Chow Chow",
            261: "Keeshond",
            262: "brussels griffon",
            263: "Pembroke Welsh Corgi",
            264: "Cardigan Welsh Corgi",
            265: "Toy Poodle",
            266: "Miniature Poodle",
            267: "Standard Poodle",
            268: "Mexican hairless dog (xoloitzcuintli)",
            269: "grey wolf",
            270: "Alaskan tundra wolf",
            271: "red wolf or maned wolf",
            272: "coyote",
            273: "dingo",
            274: "dhole",
            275: "African wild dog",
            276: "hyena",
            277: "red fox",
            278: "kit fox",
            279: "Arctic fox",
            280: "grey fox",
            281: "tabby cat",
            282: "tiger cat",
            283: "Persian cat",
            284: "Siamese cat",
            285: "Egyptian Mau",
            286: "cougar",
            287: "lynx",
            288: "leopard",
            289: "snow leopard",
            290: "jaguar",
            291: "lion",
            292: "tiger",
            293: "cheetah",
            294: "brown bear",
            295: "American black bear",
            296: "polar bear",
            297: "sloth bear",
            298: "mongoose",
            299: "meerkat",
            300: "tiger beetle",
            301: "ladybug",
            302: "ground beetle",
            303: "longhorn beetle",
            304: "leaf beetle",
            305: "dung beetle",
            306: "rhinoceros beetle",
            307: "weevil",
            308: "fly",
            309: "bee",
            310: "ant",
            311: "grasshopper",
            312: "cricket insect",
            313: "stick insect",
            314: "cockroach",
            315: "praying mantis",
            316: "cicada",
            317: "leafhopper",
            318: "lacewing",
            319: "dragonfly",
            320: "damselfly",
            321: "red admiral butterfly",
            322: "ringlet butterfly",
            323: "monarch butterfly",
            324: "small white butterfly",
            325: "sulphur butterfly",
            326: "gossamer-winged butterfly",
            327: "starfish",
            328: "sea urchin",
            329: "sea cucumber",
            330: "cottontail rabbit",
            331: "hare",
            332: "Angora rabbit",
            333: "hamster",
            334: "porcupine",
            335: "fox squirrel",
            336: "marmot",
            337: "beaver",
            338: "guinea pig",
            339: "common sorrel horse",
            340: "zebra",
            341: "pig",
            342: "wild boar",
            343: "warthog",
            344: "hippopotamus",
            345: "ox",
            346: "water buffalo",
            347: "bison",
            348: "ram (adult male sheep)",
            349: "bighorn sheep",
            350: "Alpine ibex",
            351: "hartebeest",
            352: "impala (antelope)",
            353: "gazelle",
            354: "arabian camel",
            355: "llama",
            356: "weasel",
            357: "mink",
            358: "European polecat",
            359: "black-footed ferret",
            360: "otter",
            361: "skunk",
            362: "badger",
            363: "armadillo",
            364: "three-toed sloth",
            365: "orangutan",
            366: "gorilla",
            367: "chimpanzee",
            368: "gibbon",
            369: "siamang",
            370: "guenon",
            371: "patas monkey",
            372: "baboon",
            373: "macaque",
            374: "langur",
            375: "black-and-white colobus",
            376: "proboscis monkey",
            377: "marmoset",
            378: "white-headed capuchin",
            379: "howler monkey",
            380: "titi monkey",
            381: "Geoffroy's spider monkey",
            382: "common squirrel monkey",
            383: "ring-tailed lemur",
            384: "indri",
            385: "Asian elephant",
            386: "African bush elephant",
            387: "red panda",
            388: "giant panda",
            389: "snoek fish",
            390: "eel",
            391: "silver salmon",
            392: "rock beauty fish",
            393: "clownfish",
            394: "sturgeon",
            395: "gar fish",
            396: "lionfish",
            397: "pufferfish",
            398: "abacus",
            399: "abaya",
            400: "academic gown",
            401: "accordion",
            402: "acoustic guitar",
            403: "aircraft carrier",
            404: "airliner",
            405: "airship",
            406: "altar",
            407: "ambulance",
            408: "amphibious vehicle",
            409: "analog clock",
            410: "apiary",
            411: "apron",
            412: "trash can",
            413: "assault rifle",
            414: "backpack",
            415: "bakery",
            416: "balance beam",
            417: "balloon",
            418: "ballpoint pen",
            419: "Band-Aid",
            420: "banjo",
            421: "baluster / handrail",
            422: "barbell",
            423: "barber chair",
            424: "barbershop",
            425: "barn",
            426: "barometer",
            427: "barrel",
            428: "wheelbarrow",
            429: "baseball",
            430: "basketball",
            431: "bassinet",
            432: "bassoon",
            433: "swimming cap",
            434: "bath towel",
            435: "bathtub",
            436: "station wagon",
            437: "lighthouse",
            438: "beaker",
            439: "military hat (bearskin or shako)",
            440: "beer bottle",
            441: "beer glass",
            442: "bell tower",
            443: "baby bib",
            444: "tandem bicycle",
            445: "bikini",
            446: "ring binder",
            447: "binoculars",
            448: "birdhouse",
            449: "boathouse",
            450: "bobsleigh",
            451: "bolo tie",
            452: "poke bonnet",
            453: "bookcase",
            454: "bookstore",
            455: "bottle cap",
            456: "hunting bow",
            457: "bow tie",
            458: "brass memorial plaque",
            459: "bra",
            460: "breakwater",
            461: "breastplate",
            462: "broom",
            463: "bucket",
            464: "buckle",
            465: "bulletproof vest",
            466: "high-speed train",
            467: "butcher shop",
            468: "taxicab",
            469: "cauldron",
            470: "candle",
            471: "cannon",
            472: "canoe",
            473: "can opener",
            474: "cardigan",
            475: "car mirror",
            476: "carousel",
            477: "tool kit",
            478: "cardboard box / carton",
            479: "car wheel",
            480: "automated teller machine",
            481: "cassette",
            482: "cassette player",
            483: "castle",
            484: "catamaran",
            485: "CD player",
            486: "cello",
            487: "mobile phone",
            488: "chain",
            489: "chain-link fence",
            490: "chain mail",
            491: "chainsaw",
            492: "storage chest",
            493: "chiffonier",
            494: "bell or wind chime",
            495: "china cabinet",
            496: "Christmas stocking",
            497: "church",
            498: "movie theater",
            499: "cleaver",
            500: "cliff dwelling",
            501: "cloak",
            502: "clogs",
            503: "cocktail shaker",
            504: "coffee mug",
            505: "coffeemaker",
            506: "spiral or coil",
            507: "combination lock",
            508: "computer keyboard",
            509: "candy store",
            510: "container ship",
            511: "convertible",
            512: "corkscrew",
            513: "cornet",
            514: "cowboy boot",
            515: "cowboy hat",
            516: "cradle",
            517: "construction crane",
            518: "crash helmet",
            519: "crate",
            520: "infant bed",
            521: "Crock Pot",
            522: "croquet ball",
            523: "crutch",
            524: "cuirass",
            525: "dam",
            526: "desk",
            527: "desktop computer",
            528: "rotary dial telephone",
            529: "diaper",
            530: "digital clock",
            531: "digital watch",
            532: "dining table",
            533: "dishcloth",
            534: "dishwasher",
            535: "disc brake",
            536: "dock",
            537: "dog sled",
            538: "dome",
            539: "doormat",
            540: "drilling rig",
            541: "drum",
            542: "drumstick",
            543: "dumbbell",
            544: "Dutch oven",
            545: "electric fan",
            546: "electric guitar",
            547: "electric locomotive",
            548: "entertainment center",
            549: "envelope",
            550: "espresso machine",
            551: "face powder",
            552: "feather boa",
            553: "filing cabinet",
            554: "fireboat",
            555: "fire truck",
            556: "fire screen",
            557: "flagpole",
            558: "flute",
            559: "folding chair",
            560: "football helmet",
            561: "forklift",
            562: "fountain",
            563: "fountain pen",
            564: "four-poster bed",
            565: "freight car",
            566: "French horn",
            567: "frying pan",
            568: "fur coat",
            569: "garbage truck",
            570: "gas mask or respirator",
            571: "gas pump",
            572: "goblet",
            573: "go-kart",
            574: "golf ball",
            575: "golf cart",
            576: "gondola",
            577: "gong",
            578: "gown",
            579: "grand piano",
            580: "greenhouse",
            581: "radiator grille",
            582: "grocery store",
            583: "guillotine",
            584: "hair clip",
            585: "hair spray",
            586: "half-track",
            587: "hammer",
            588: "hamper",
            589: "hair dryer",
            590: "hand-held computer",
            591: "handkerchief",
            592: "hard disk drive",
            593: "harmonica",
            594: "harp",
            595: "combine harvester",
            596: "hatchet",
            597: "holster",
            598: "home theater",
            599: "honeycomb",
            600: "hook",
            601: "hoop skirt",
            602: "gymnastic horizontal bar",
            603: "horse-drawn vehicle",
            604: "hourglass",
            605: "iPod",
            606: "clothes iron",
            607: "carved pumpkin",
            608: "jeans",
            609: "jeep",
            610: "T-shirt",
            611: "jigsaw puzzle",
            612: "rickshaw",
            613: "joystick",
            614: "kimono",
            615: "knee pad",
            616: "knot",
            617: "lab coat",
            618: "ladle",
            619: "lampshade",
            620: "laptop computer",
            621: "lawn mower",
            622: "lens cap",
            623: "letter opener",
            624: "library",
            625: "lifeboat",
            626: "lighter",
            627: "limousine",
            628: "ocean liner",
            629: "lipstick",
            630: "slip-on shoe",
            631: "lotion",
            632: "music speaker",
            633: "loupe magnifying glass",
            634: "sawmill",
            635: "magnetic compass",
            636: "messenger bag",
            637: "mailbox",
            638: "tights",
            639: "one-piece bathing suit",
            640: "manhole cover",
            641: "maraca",
            642: "marimba",
            643: "mask",
            644: "matchstick",
            645: "maypole",
            646: "maze",
            647: "measuring cup",
            648: "medicine cabinet",
            649: "megalith",
            650: "microphone",
            651: "microwave oven",
            652: "military uniform",
            653: "milk can",
            654: "minibus",
            655: "miniskirt",
            656: "minivan",
            657: "missile",
            658: "mitten",
            659: "mixing bowl",
            660: "mobile home",
            661: "ford model t",
            662: "modem",
            663: "monastery",
            664: "monitor",
            665: "moped",
            666: "mortar and pestle",
            667: "graduation cap",
            668: "mosque",
            669: "mosquito net",
            670: "vespa",
            671: "mountain bike",
            672: "tent",
            673: "computer mouse",
            674: "mousetrap",
            675: "moving van",
            676: "muzzle",
            677: "metal nail",
            678: "neck brace",
            679: "necklace",
            680: "baby pacifier",
            681: "notebook computer",
            682: "obelisk",
            683: "oboe",
            684: "ocarina",
            685: "odometer",
            686: "oil filter",
            687: "pipe organ",
            688: "oscilloscope",
            689: "overskirt",
            690: "bullock cart",
            691: "oxygen mask",
            692: "product packet / packaging",
            693: "paddle",
            694: "paddle wheel",
            695: "padlock",
            696: "paintbrush",
            697: "pajamas",
            698: "palace",
            699: "pan flute",
            700: "paper towel",
            701: "parachute",
            702: "parallel bars",
            703: "park bench",
            704: "parking meter",
            705: "railroad car",
            706: "patio",
            707: "payphone",
            708: "pedestal",
            709: "pencil case",
            710: "pencil sharpener",
            711: "perfume",
            712: "Petri dish",
            713: "photocopier",
            714: "plectrum",
            715: "Pickelhaube",
            716: "picket fence",
            717: "pickup truck",
            718: "pier",
            719: "piggy bank",
            720: "pill bottle",
            721: "pillow",
            722: "ping-pong ball",
            723: "pinwheel",
            724: "pirate ship",
            725: "drink pitcher",
            726: "block plane",
            727: "planetarium",
            728: "plastic bag",
            729: "plate rack",
            730: "farm plow",
            731: "plunger",
            732: "Polaroid camera",
            733: "pole",
            734: "police van",
            735: "poncho",
            736: "pool table",
            737: "soda bottle",
            738: "plant pot",
            739: "potter's wheel",
            740: "power drill",
            741: "prayer rug",
            742: "printer",
            743: "prison",
            744: "missile",
            745: "projector",
            746: "hockey puck",
            747: "punching bag",
            748: "purse",
            749: "quill",
            750: "quilt",
            751: "race car",
            752: "racket",
            753: "radiator",
            754: "radio",
            755: "radio telescope",
            756: "rain barrel",
            757: "recreational vehicle",
            758: "fishing casting reel",
            759: "reflex camera",
            760: "refrigerator",
            761: "remote control",
            762: "restaurant",
            763: "revolver",
            764: "rifle",
            765: "rocking chair",
            766: "rotisserie",
            767: "eraser",
            768: "rugby ball",
            769: "ruler measuring stick",
            770: "sneaker",
            771: "safe",
            772: "safety pin",
            773: "salt shaker",
            774: "sandal",
            775: "sarong",
            776: "saxophone",
            777: "scabbard",
            778: "weighing scale",
            779: "school bus",
            780: "schooner",
            781: "scoreboard",
            782: "CRT monitor",
            783: "screw",
            784: "screwdriver",
            785: "seat belt",
            786: "sewing machine",
            787: "shield",
            788: "shoe store",
            789: "shoji screen / room divider",
            790: "shopping basket",
            791: "shopping cart",
            792: "shovel",
            793: "shower cap",
            794: "shower curtain",
            795: "ski",
            796: "balaclava ski mask",
            797: "sleeping bag",
            798: "slide rule",
            799: "sliding door",
            800: "slot machine",
            801: "snorkel",
            802: "snowmobile",
            803: "snowplow",
            804: "soap dispenser",
            805: "soccer ball",
            806: "sock",
            807: "solar thermal collector",
            808: "sombrero",
            809: "soup bowl",
            810: "keyboard space bar",
            811: "space heater",
            812: "space shuttle",
            813: "spatula",
            814: "motorboat",
            815: "spider web",
            816: "spindle",
            817: "sports car",
            818: "spotlight",
            819: "stage",
            820: "steam locomotive",
            821: "through arch bridge",
            822: "steel drum",
            823: "stethoscope",
            824: "scarf",
            825: "stone wall",
            826: "stopwatch",
            827: "stove",
            828: "strainer",
            829: "tram",
            830: "stretcher",
            831: "couch",
            832: "stupa",
            833: "submarine",
            834: "suit",
            835: "sundial",
            836: "sunglasses",
            837: "sunglasses",
            838: "sunscreen",
            839: "suspension bridge",
            840: "mop",
            841: "sweatshirt",
            842: "swim trunks / shorts",
            843: "swing",
            844: "electrical switch",
            845: "syringe",
            846: "table lamp",
            847: "tank",
            848: "tape player",
            849: "teapot",
            850: "teddy bear",
            851: "television",
            852: "tennis ball",
            853: "thatched roof",
            854: "front curtain",
            855: "thimble",
            856: "threshing machine",
            857: "throne",
            858: "tile roof",
            859: "toaster",
            860: "tobacco shop",
            861: "toilet seat",
            862: "torch",
            863: "totem pole",
            864: "tow truck",
            865: "toy store",
            866: "tractor",
            867: "semi-trailer truck",
            868: "tray",
            869: "trench coat",
            870: "tricycle",
            871: "trimaran",
            872: "tripod",
            873: "triumphal arch",
            874: "trolleybus",
            875: "trombone",
            876: "hot tub",
            877: "turnstile",
            878: "typewriter keyboard",
            879: "umbrella",
            880: "unicycle",
            881: "upright piano",
            882: "vacuum cleaner",
            883: "vase",
            884: "vaulted or arched ceiling",
            885: "velvet fabric",
            886: "vending machine",
            887: "vestment",
            888: "viaduct",
            889: "violin",
            890: "volleyball",
            891: "waffle iron",
            892: "wall clock",
            893: "wallet",
            894: "wardrobe",
            895: "military aircraft",
            896: "sink",
            897: "washing machine",
            898: "water bottle",
            899: "water jug",
            900: "water tower",
            901: "whiskey jug",
            902: "whistle",
            903: "hair wig",
            904: "window screen",
            905: "window shade",
            906: "Windsor tie",
            907: "wine bottle",
            908: "airplane wing",
            909: "wok",
            910: "wooden spoon",
            911: "wool",
            912: "split-rail fence",
            913: "shipwreck",
            914: "sailboat",
            915: "yurt",
            916: "website",
            917: "comic book",
            918: "crossword",
            919: "traffic or street sign",
            920: "traffic light",
            921: "dust jacket",
            922: "menu",
            923: "plate",
            924: "guacamole",
            925: "consomme",
            926: "hot pot",
            927: "trifle",
            928: "ice cream",
            929: "popsicle",
            930: "baguette",
            931: "bagel",
            932: "pretzel",
            933: "cheeseburger",
            934: "hot dog",
            935: "mashed potatoes",
            936: "cabbage",
            937: "broccoli",
            938: "cauliflower",
            939: "zucchini",
            940: "spaghetti squash",
            941: "acorn squash",
            942: "butternut squash",
            943: "cucumber",
            944: "artichoke",
            945: "bell pepper",
            946: "cardoon",
            947: "mushroom",
            948: "Granny Smith apple",
            949: "strawberry",
            950: "orange",
            951: "lemon",
            952: "fig",
            953: "pineapple",
            954: "banana",
            955: "jackfruit",
            956: "cherimoya (custard apple)",
            957: "pomegranate",
            958: "hay",
            959: "carbonara",
            960: "chocolate syrup",
            961: "dough",
            962: "meatloaf",
            963: "pizza",
            964: "pot pie",
            965: "burrito",
            966: "red wine",
            967: "espresso",
            968: "tea cup",
            969: "eggnog",
            970: "mountain",
            971: "bubble",
            972: "cliff",
            973: "coral reef",
            974: "geyser",
            975: "lakeshore",
            976: "promontory",
            977: "sandbar",
            978: "beach",
            979: "valley",
            980: "volcano",
            981: "baseball player",
            982: "bridegroom",
            983: "scuba diver",
            984: "rapeseed",
            985: "daisy",
            986: "yellow lady's slipper",
            987: "corn",
            988: "acorn",
            989: "rose hip",
            990: "horse chestnut seed",
            991: "coral fungus",
            992: "agaric",
            993: "gyromitra",
            994: "stinkhorn mushroom",
            995: "earth star fungus",
            996: "hen of the woods mushroom",
            997: "bolete",
            998: "corn cob",
            999: "toilet paper",
        }

zero_shot_prompt_templates property

zero_shot_prompt_templates

Return the zero-shot prompt templates.

id2label property

id2label

Return the label mapping.
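
As a minimal, hypothetical sketch (not part of mmlearn), the two properties above can be combined to build text prompts for zero-shot classification, assuming each template is a format string with a single {} placeholder for the class name:

def build_zero_shot_prompts(dataset):
    """Build one list of text prompts per class index (hypothetical helper)."""
    return {
        idx: [template.format(label) for template in dataset.zero_shot_prompt_templates]
        for idx, label in dataset.id2label.items()
    }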

__getitem__

__getitem__(index)

Get an example at the given index.

Source code in mmlearn/datasets/imagenet.py
def __getitem__(self, index: int) -> Example:
    """Get an example at the given index."""
    image, target = super().__getitem__(index)
    example = Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: target,
            EXAMPLE_INDEX_KEY: index,
        }
    )
    mask = self.mask_generator() if self.mask_generator else None
    if mask is not None:  # error will be raised during collation if `None`
        example[Modalities.RGB.mask] = mask
    return example
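
A short usage sketch follows; the Modalities import path is an assumption, and imagenet stands for an instance of this dataset constructed elsewhere:

from mmlearn.datasets.core import Modalities  # import path assumed

example = imagenet[0]                    # `Example` built as in the source above
image = example[Modalities.RGB.name]     # transformed image
target = example[Modalities.RGB.target]  # integer class index into `id2label`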

LibriSpeech

Bases: Dataset[Example]

LibriSpeech dataset.

This is a wrapper around torchaudio.datasets.LIBRISPEECH that assumes the dataset is already downloaded and that the top-level directory of the dataset inside the root directory is librispeech.

Parameters:

Name Type Description Default
root_dir str

Root directory of dataset.

required
split (train-clean-100, train-clean-360, train-other-500, dev-clean, dev-other, test-clean, test-other)

Split of the dataset to use.

"train-clean-100"

Raises:

Type Description
ImportError

If torchaudio is not installed.

Notes

This dataset returns only the audio waveform and its transcript for each example.

Source code in mmlearn/datasets/librispeech.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("LIBRISPEECH_ROOT_DIR", MISSING),
)
class LibriSpeech(Dataset[Example]):
    """LibriSpeech dataset.

    This is a wrapper around :py:class:`torchaudio.datasets.LIBRISPEECH` that assumes
    that the dataset is already downloaded and the top-level directory of the dataset
    in the root directory is `librispeech`.

    Parameters
    ----------
    root_dir : str
        Root directory of dataset.
    split : {"train-clean-100", "train-clean-360", "train-other-500", "dev-clean", "dev-other", "test-clean", "test-other"}, default="train-clean-100"
        Split of the dataset to use.

    Raises
    ------
    ImportError
        If ``torchaudio`` is not installed.

    Notes
    -----
    This dataset only returns the audio and transcript from the dataset.

    """  # noqa: W505

    def __init__(self, root_dir: str, split: str = "train-clean-100") -> None:
        super().__init__()
        if not _TORCHAUDIO_AVAILABLE:
            raise ImportError(
                "LibriSpeech dataset requires `torchaudio`, which is not installed."
            )
        from torchaudio.datasets import LIBRISPEECH

        self.dataset = LIBRISPEECH(
            root=root_dir,
            url=split,
            download=False,
            folder_in_archive="librispeech",
        )

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the dataset."""
        waveform, sample_rate, transcript, _, _, _ = self.dataset[idx]
        assert sample_rate == SAMPLE_RATE, (
            f"Expected sample rate to be `16000`, got {sample_rate}."
        )
        waveform = pad_or_trim(waveform.flatten())

        return Example(
            {
                Modalities.AUDIO.name: waveform,
                Modalities.TEXT.name: transcript,
                EXAMPLE_INDEX_KEY: idx,
            },
        )

__len__

__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/librispeech.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.dataset)

__getitem__

__getitem__(idx)

Return an example from the dataset.

Source code in mmlearn/datasets/librispeech.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the dataset."""
    waveform, sample_rate, transcript, _, _, _ = self.dataset[idx]
    assert sample_rate == SAMPLE_RATE, (
        f"Expected sample rate to be `16000`, got {sample_rate}."
    )
    waveform = pad_or_trim(waveform.flatten())

    return Example(
        {
            Modalities.AUDIO.name: waveform,
            Modalities.TEXT.name: transcript,
            EXAMPLE_INDEX_KEY: idx,
        },
    )
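
A minimal usage sketch, assuming the archive has already been extracted under the placeholder root below (constructor arguments follow the source above; the Modalities import path is an assumption):

from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.datasets.librispeech import LibriSpeech

ds = LibriSpeech(root_dir="/path/to/librispeech_root", split="dev-clean")
example = ds[0]
waveform = example[Modalities.AUDIO.name]   # padded/trimmed waveform at 16 kHz
transcript = example[Modalities.TEXT.name]  # transcript string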

LLVIPDataset

Bases: Dataset[Example]

Low-Light Visible-Infrared Pair (LLVIP) dataset.

Loads pairs of RGB and THERMAL images from the LLVIP dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset. The directory should contain 'visible' and 'infrared' subdirectories.

required
train bool

Flag to indicate whether to load the training or test set.

True
transform Optional[Callable[[Image], Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor. This is applied to both RGB and thermal images.

None
Source code in mmlearn/datasets/llvip.py
@store(
    name="LLVIP",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("LLVIP_ROOT_DIR", MISSING),
)
class LLVIPDataset(Dataset[Example]):
    """Low-Light Visible-Infrared Pair (LLVIP) dataset.

    Loads pairs of `RGB` and `THERMAL` images from the LLVIP dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset. The directory should contain
        'visible' and 'infrared' subdirectories.
    train : bool, default=True
        Flag to indicate whether to load the training or test set.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor. This is applied to both RGB and thermal
        images.
    """

    def __init__(
        self,
        root_dir: str,
        train: bool = True,
        transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ):
        self.path_images_rgb = os.path.join(
            root_dir,
            "visible",
            "train" if train else "test",
        )
        self.path_images_ir = os.path.join(
            root_dir, "infrared", "train" if train else "test"
        )
        self.train = train
        self.transform = transform or transforms.ToTensor()

        self.rgb_images = sorted(glob.glob(os.path.join(self.path_images_rgb, "*.jpg")))
        self.ir_images = sorted(glob.glob(os.path.join(self.path_images_ir, "*.jpg")))

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.rgb_images)

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the dataset."""
        rgb_image_path = self.rgb_images[idx]
        ir_image_path = self.ir_images[idx]

        rgb_image = PILImage.open(rgb_image_path).convert("RGB")
        ir_image = PILImage.open(ir_image_path).convert("L")

        example = Example(
            {
                Modalities.RGB.name: self.transform(rgb_image),
                Modalities.THERMAL.name: self.transform(ir_image),
                EXAMPLE_INDEX_KEY: idx,
            },
        )

        if self.train:
            annot_path = (
                rgb_image_path.replace("visible", "Annotations")
                .replace(".jpg", ".xml")
                .replace("train", "")
            )
            annot = self._get_bbox(annot_path)
            example["annotation"] = {
                "bboxes": torch.from_numpy(annot["bboxes"]),
                "labels": torch.from_numpy(annot["labels"]),
            }
        return example

    def _get_bbox(self, filename: str) -> dict[str, np.ndarray]:
        """Parse the XML file to get bounding boxes and labels.

        Parameters
        ----------
        filename : str
            Path to the annotation XML file.

        Returns
        -------
        dict
            A dictionary containing bounding boxes and labels.
        """
        try:
            root = ET.parse(filename).getroot()

            bboxes, labels = [], []
            for obj in root.findall("object"):
                bbox_obj = obj.find("bndbox")
                bbox = [
                    int(bbox_obj.find(dim).text)  # type: ignore[union-attr,arg-type]
                    for dim in ["xmin", "ymin", "xmax", "ymax"]
                ]
                bboxes.append(bbox)
                labels.append(1)  # Assuming 'person' is the only label
            return {
                "bboxes": np.array(bboxes).astype("float"),
                "labels": np.array(labels).astype("int"),
            }
        except ET.ParseError as e:
            raise ValueError(f"Error parsing XML: {e}") from None
        except Exception as e:
            raise RuntimeError(
                f"Error processing annotation file {filename}: {e}",
            ) from None

__len__

__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/llvip.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.rgb_images)

__getitem__

__getitem__(idx)

Return an example from the dataset.

Source code in mmlearn/datasets/llvip.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the dataset."""
    rgb_image_path = self.rgb_images[idx]
    ir_image_path = self.ir_images[idx]

    rgb_image = PILImage.open(rgb_image_path).convert("RGB")
    ir_image = PILImage.open(ir_image_path).convert("L")

    example = Example(
        {
            Modalities.RGB.name: self.transform(rgb_image),
            Modalities.THERMAL.name: self.transform(ir_image),
            EXAMPLE_INDEX_KEY: idx,
        },
    )

    if self.train:
        annot_path = (
            rgb_image_path.replace("visible", "Annotations")
            .replace(".jpg", ".xml")
            .replace("train", "")
        )
        annot = self._get_bbox(annot_path)
        example["annotation"] = {
            "bboxes": torch.from_numpy(annot["bboxes"]),
            "labels": torch.from_numpy(annot["labels"]),
        }
    return example
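
Minimal usage sketch (the root path is a placeholder; the default transform is transforms.ToTensor(), as in the constructor above):

from mmlearn.datasets.llvip import LLVIPDataset

ds = LLVIPDataset(root_dir="/path/to/LLVIP", train=True)
example = ds[0]
# `example` holds the RGB tensor, the thermal tensor and the example index; for the
# training split it also carries an "annotation" dict with "bboxes" and "labels" tensors.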

NIHCXR

Bases: Dataset[Example]

NIH Chest X-ray dataset.

Parameters:

Name Type Description Default
root_dir str

Directory containing the .json files that list all dataset entries.

required
split (train, test, bbox)

Dataset split. "bbox" is a subset of "test" which contains bounding box info.

"train"
transform Optional[Callable[[Image], Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
Source code in mmlearn/datasets/nihcxr.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("NIH_CXR_DIR", MISSING),
    split="train",
)
class NIHCXR(Dataset[Example]):
    """NIH Chest X-ray dataset.

    Parameters
    ----------
    root_dir : str
        Directory which contains `.json` files stating all dataset entries.
    split : {"train", "test", "bbox"}
        Dataset split. "bbox" is a subset of "test" which contains bounding box info.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test", "bbox"],
        transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
    ) -> None:
        assert split in ["train", "test", "bbox"], f"split {split} is not available."
        assert callable(transform) or transform is None, (
            "transform is not callable or None."
        )

        data_path = os.path.join(root_dir, split + "_data.json")

        assert os.path.isfile(data_path), f"entries file does not exist: {data_path}."

        with open(data_path, "rb") as file:
            entries = json.load(file)
        self.entries = entries

        if transform is not None:
            self.transform = transform
        else:
            self.transform = Compose([Resize(224), CenterCrop(224), ToTensor()])

        self.bbox = split == "bbox"

    def __getitem__(self, idx: int) -> Example:
        """Return image-label or image-label-tabular(bbox)."""
        entry = self.entries[idx]
        image = Image.open(entry["image_path"]).convert("RGB")
        image = self.transform(image)
        label = torch.tensor(entry["label"])

        example = Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: label,
                "qid": entry["qid"],
                EXAMPLE_INDEX_KEY: idx,
            }
        )

        if self.bbox:
            example["bbox"] = entry["bbox"]

        return example

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.entries)

__getitem__

__getitem__(idx)

Return image-label or image-label-tabular(bbox).

Source code in mmlearn/datasets/nihcxr.py
def __getitem__(self, idx: int) -> Example:
    """Return image-label or image-label-tabular(bbox)."""
    entry = self.entries[idx]
    image = Image.open(entry["image_path"]).convert("RGB")
    image = self.transform(image)
    label = torch.tensor(entry["label"])

    example = Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: label,
            "qid": entry["qid"],
            EXAMPLE_INDEX_KEY: idx,
        }
    )

    if self.bbox:
        example["bbox"] = entry["bbox"]

    return example

__len__

__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/nihcxr.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.entries)
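
Usage sketch with a placeholder root directory; the "bbox" split additionally attaches bounding-box information, as in the source above:

from mmlearn.datasets.nihcxr import NIHCXR

ds = NIHCXR(root_dir="/path/to/nihcxr", split="bbox")
example = ds[0]
# Each example carries the image tensor, its label, a "qid" and the example index;
# for the "bbox" split a "bbox" entry is included as well.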

NYUv2Dataset

Bases: Dataset[Example]

NYUv2 dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, test)

Split of the dataset to use.

"train"
return_type (disparity, image)

Return type of the depth images.

  • "disparity": Return the depth image as disparity map.
  • "image": Return the depth image as a 3-channel image.
"disparity"
rgb_transform Optional[Callable[[Image], Tensor]]

A callable that takes in an RGB PIL image and returns a transformed version of the image as a PyTorch tensor.

None
depth_transform Optional[Callable[[Image], Tensor]]

A callable that takes in a depth PIL image and returns a transformed version of the image as a PyTorch tensor.

None

Raises:

Type Description
ImportError

If opencv-python is not installed.

Source code in mmlearn/datasets/nyuv2.py
@store(
    name="NYUv2",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("NYUV2_ROOT_DIR", MISSING),
)
class NYUv2Dataset(Dataset[Example]):
    """NYUv2 dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "test"}, default="train"
        Split of the dataset to use.
    return_type : {"disparity", "image"}, default="disparity"
        Return type of the depth images.

        - `"disparity"`: Return the depth image as disparity map.
        - `"image"`: Return the depth image as a 3-channel image.
    rgb_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in an RGB PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    depth_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in a depth PIL image and returns a transformed version
        of the image as a PyTorch tensor.

    Raises
    ------
    ImportError
        If `opencv-python` is not installed.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test"] = "train",
        return_type: Literal["disparity", "image"] = "disparity",
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ) -> None:
        super().__init__()
        if not _OPENCV_AVAILABLE:
            raise ImportError(
                "NYUv2 dataset requires `opencv-python` which is not installed.",
            )
        self._validate_args(root_dir, split, rgb_transform, depth_transform)
        self.return_type = return_type

        self.root_dir = root_dir
        with open(os.path.join(root_dir, f"{split}.txt"), "r") as f:
            file_ids = f.readlines()
        file_ids = [f.strip() for f in file_ids]

        root_dir = os.path.join(root_dir, split)
        depth_files = [os.path.join(root_dir, "depth", f"{f}.png") for f in file_ids]
        rgb_files = [os.path.join(root_dir, "rgb", f"{f}.png") for f in file_ids]

        label_files = [
            os.path.join(root_dir, "scene_class", f"{f}.txt") for f in file_ids
        ]
        labels = [str(open(f).read().strip()) for f in label_files]  # noqa: SIM115
        labels = [label.replace("_", " ") for label in labels]
        labels = [
            _LABELS.index(label) if label in _LABELS else len(_LABELS)  # type: ignore
            for label in labels
        ]

        # remove the samples with classes not in _LABELS
        # this is to follow the same classes used in ImageBind
        if split == "test":
            valid_indices = [
                i
                for i, label in enumerate(labels)
                if label < len(_LABELS)  # type: ignore
            ]
            rgb_files = [rgb_files[i] for i in valid_indices]
            depth_files = [depth_files[i] for i in valid_indices]
            labels = [labels[i] for i in valid_indices]

        self.samples = list(zip(rgb_files, depth_files, labels, strict=False))

        self.rgb_transform = rgb_transform
        self.depth_transform = depth_transform

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.samples)

    def _validate_args(
        self,
        root_dir: str,
        split: str,
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]],
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]],
    ) -> None:
        """Validate arguments."""
        if not os.path.isdir(root_dir):
            raise NotADirectoryError(
                f"The given `root_dir` {root_dir} is not a directory",
            )
        if split not in ["train", "test"]:
            raise ValueError(
                f"Expected `split` to be one of `'train'` or `'test'`, but got {split}",
            )
        if rgb_transform is not None and not callable(rgb_transform):
            raise TypeError(
                f"Expected argument `rgb_transform` to be callable, but got {type(rgb_transform)}",
            )
        if depth_transform is not None and not callable(depth_transform):
            raise TypeError(
                f"Expected `depth_transform` to be callable, but got {type(depth_transform)}",
            )

    def __getitem__(self, idx: int) -> Example:
        """Return RGB and depth images at index `idx`."""
        # Read images
        rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
        if self.rgb_transform is not None:
            rgb_image = self.rgb_transform(to_pil_image(rgb_image))

        if self.return_type == "disparity":
            depth_image = depth_normalize(
                self.samples[idx][1],
            )
        else:
            # Using cv2 instead of PIL Image since we use PNG grayscale images.
            depth_image = cv2.imread(
                self.samples[idx][1],
                cv2.IMREAD_GRAYSCALE,
            )
            # Make a 3-channel depth image to enable passing to a pretrained ViT.
            depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

        if self.depth_transform is not None:
            depth_image = self.depth_transform(to_pil_image(depth_image))

        return Example(
            {
                Modalities.RGB.name: rgb_image,
                Modalities.DEPTH.name: depth_image,
                EXAMPLE_INDEX_KEY: idx,
                Modalities.DEPTH.target: self.samples[idx][2],
            }
        )

__len__

__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/nyuv2.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.samples)

__getitem__

__getitem__(idx)

Return RGB and depth images at index idx.

Source code in mmlearn/datasets/nyuv2.py
def __getitem__(self, idx: int) -> Example:
    """Return RGB and depth images at index `idx`."""
    # Read images
    rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
    if self.rgb_transform is not None:
        rgb_image = self.rgb_transform(to_pil_image(rgb_image))

    if self.return_type == "disparity":
        depth_image = depth_normalize(
            self.samples[idx][1],
        )
    else:
        # Using cv2 instead of PIL Image since we use PNG grayscale images.
        depth_image = cv2.imread(
            self.samples[idx][1],
            cv2.IMREAD_GRAYSCALE,
        )
        # Make a 3-channel depth image to enable passing to a pretrained ViT.
        depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

    if self.depth_transform is not None:
        depth_image = self.depth_transform(to_pil_image(depth_image))

    return Example(
        {
            Modalities.RGB.name: rgb_image,
            Modalities.DEPTH.name: depth_image,
            EXAMPLE_INDEX_KEY: idx,
            Modalities.DEPTH.target: self.samples[idx][2],
        }
    )
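
Usage sketch with a placeholder root directory; return_type="image" yields a 3-channel depth image, as described above:

from torchvision import transforms
from mmlearn.datasets.nyuv2 import NYUv2Dataset

to_tensor = transforms.ToTensor()
ds = NYUv2Dataset(
    root_dir="/path/to/nyuv2",
    split="test",
    return_type="image",
    rgb_transform=to_tensor,
    depth_transform=to_tensor,
)
example = ds[0]  # RGB tensor, depth tensor, scene-class target and example index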

SUNRGBDDataset

Bases: Dataset[Example]

SUN RGB-D dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, test)

Split of the dataset to use.

"train"
return_type (disparity, image)

Return type of the depth images. If "disparity", the depth images are converted to disparity similar to the ImageBind implementation. Otherwise, return the depth image as a 3-channel image.

"disparity"
rgb_transform Optional[Callable[[Image], Tensor]]

A callable that takes in an RGB PIL image and returns a transformed version of the image as a PyTorch tensor.

None
depth_transform Optional[Callable[[Image], Tensor]]

A callable that takes in a depth PIL image and returns a transformed version of the image as a PyTorch tensor.

None
References

[1] Repository followed to extract the dataset: https://github.com/TUI-NICR/nicr-scene-analysis-datasets

Source code in mmlearn/datasets/sunrgbd.py
@store(
    name="SUNRGBD",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("SUNRGBD_ROOT_DIR", MISSING),
)
class SUNRGBDDataset(Dataset[Example]):
    """SUN RGB-D dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "test"}, default="train"
        Split of the dataset to use.
    return_type : {"disparity", "image"}, default="disparity"
        Return type of the depth images. If "disparity", the depth images are
        converted to disparity similar to the ImageBind implementation.
        Otherwise, return the depth image as a 3-channel image.
    rgb_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in an RGB PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    depth_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in a depth PIL image and returns a transformed version
        of the image as a PyTorch tensor.

    References
    ----------
    .. [1] Repo followed to extract the dataset: https://github.com/TUI-NICR/nicr-scene-analysis-datasets
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test"] = "train",
        return_type: Literal["disparity", "image"] = "disparity",
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ) -> None:
        super().__init__()
        if not _OPENCV_AVAILABLE:
            raise ImportError(
                "SUN RGB-D dataset requires `opencv-python` which is not installed.",
            )

        self._validate_args(root_dir, split, rgb_transform, depth_transform)
        self.return_type = return_type

        self.root_dir = root_dir
        with open(os.path.join(root_dir, f"{split}.txt"), "r") as f:
            file_ids = f.readlines()
        file_ids = [f.strip() for f in file_ids]

        root_dir = os.path.join(root_dir, split)
        depth_files = [os.path.join(root_dir, "depth", f"{f}.png") for f in file_ids]
        rgb_files = [os.path.join(root_dir, "rgb", f"{f}.jpg") for f in file_ids]
        intrinsic_files = [
            os.path.join(root_dir, "intrinsics", f"{f}.txt") for f in file_ids
        ]

        sensor_types = [
            file.removeprefix(os.path.join(root_dir, "depth")).split(os.sep)[1]
            for file in depth_files
        ]

        label_files = [
            os.path.join(root_dir, "scene_class", f"{f}.txt") for f in file_ids
        ]
        labels = []
        for label_file in label_files:
            with open(label_file, "r") as file:  # noqa: SIM115
                labels.append(file.read().strip())
        labels = [label.replace("_", " ") for label in labels]
        labels = [
            _LABELS.index(label) if label in _LABELS else len(_LABELS)  # type: ignore
            for label in labels
        ]

        # remove the samples with classes not in _LABELS
        # this is to follow the same classes used in ImageBind
        if split == "test":
            valid_indices = [
                i
                for i, label in enumerate(labels)
                if label < len(_LABELS)  # type: ignore
            ]
            rgb_files = [rgb_files[i] for i in valid_indices]
            depth_files = [depth_files[i] for i in valid_indices]
            labels = [labels[i] for i in valid_indices]
            intrinsic_files = [intrinsic_files[i] for i in valid_indices]
            sensor_types = [sensor_types[i] for i in valid_indices]

        self.samples = list(
            zip(
                rgb_files,
                depth_files,
                labels,
                intrinsic_files,
                sensor_types,
                strict=False,
            )
        )

        self.rgb_transform = rgb_transform
        self.depth_transform = depth_transform

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.samples)

    def _validate_args(
        self,
        root_dir: str,
        split: str,
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]],
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]],
    ) -> None:
        """Validate arguments."""
        if not os.path.isdir(root_dir):
            raise NotADirectoryError(
                f"The given `root_dir` {root_dir} is not a directory",
            )
        if split not in ["train", "test"]:
            raise ValueError(
                f"Expected `split` to be one of `'train'` or `'test'`, but got {split}",
            )
        if rgb_transform is not None and not callable(rgb_transform):
            raise TypeError(
                f"Expected argument `rgb_transform` to be callable, but got {type(rgb_transform)}",
            )
        if depth_transform is not None and not callable(depth_transform):
            raise TypeError(
                f"Expected `depth_transform` to be callable, but got {type(depth_transform)}",
            )

    def __getitem__(self, idx: int) -> Example:
        """Return RGB and depth images at index `idx`."""
        # Read images
        rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
        if self.rgb_transform is not None:
            rgb_image = self.rgb_transform(to_pil_image(rgb_image))

        if self.return_type == "disparity":
            depth_image = convert_depth_to_disparity(
                self.samples[idx][1],
                self.samples[idx][3],
                self.samples[idx][4],
            )
        else:
            # Using cv2 instead of PIL Image since we use PNG grayscale images.
            depth_image = cv2.imread(
                self.samples[idx][1],
                cv2.IMREAD_GRAYSCALE,
            )
            # Make a 3-channel depth image to enable passing to a pretrained ViT.
            depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

        if self.depth_transform is not None:
            depth_image = self.depth_transform(to_pil_image(depth_image))

        return Example(
            {
                Modalities.RGB.name: rgb_image,
                Modalities.DEPTH.name: depth_image,
                EXAMPLE_INDEX_KEY: idx,
                Modalities.DEPTH.target: self.samples[idx][2],
            }
        )

__len__

__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/sunrgbd.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.samples)

__getitem__

__getitem__(idx)

Return RGB and depth images at index idx.

Source code in mmlearn/datasets/sunrgbd.py
def __getitem__(self, idx: int) -> Example:
    """Return RGB and depth images at index `idx`."""
    # Read images
    rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
    if self.rgb_transform is not None:
        rgb_image = self.rgb_transform(to_pil_image(rgb_image))

    if self.return_type == "disparity":
        depth_image = convert_depth_to_disparity(
            self.samples[idx][1],
            self.samples[idx][3],
            self.samples[idx][4],
        )
    else:
        # Using cv2 instead of PIL Image since we use PNG grayscale images.
        depth_image = cv2.imread(
            self.samples[idx][1],
            cv2.IMREAD_GRAYSCALE,
        )
        # Make a 3-channel depth image to enable passing to a pretrained ViT.
        depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

    if self.depth_transform is not None:
        depth_image = self.depth_transform(to_pil_image(depth_image))

    return Example(
        {
            Modalities.RGB.name: rgb_image,
            Modalities.DEPTH.name: depth_image,
            EXAMPLE_INDEX_KEY: idx,
            Modalities.DEPTH.target: self.samples[idx][2],
        }
    )
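
Usage sketch with a placeholder root directory; with return_type="disparity" the depth entry is a disparity map computed from the depth image, its intrinsics file and the sensor type, as in the source above:

from mmlearn.datasets.sunrgbd import SUNRGBDDataset

ds = SUNRGBDDataset(root_dir="/path/to/sunrgbd", split="test", return_type="disparity")
example = ds[0]  # RGB image, disparity map, scene-class target and example index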

chexpert

CheXpert Dataset.

CheXpert

Bases: Dataset[Example]

CheXpert dataset.

Each datapoint is a pair of (image, target label).

Parameters:

Name Type Description Default
root_dir str

Directory containing the .json files that list all dataset entries.

required
split (train, valid)

Dataset split.

"train"
labeler Optional[{chexpert, chexbert, vchexbert}]

Labeler used to extract labels from the training images. The "valid" split has no labeler; its labels were assigned by human radiologists.

None
transform Optional[Callable[[Image], Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
Source code in mmlearn/datasets/chexpert.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("CHEXPERT_ROOT_DIR", MISSING),
    split="train",
)
class CheXpert(Dataset[Example]):
    """CheXpert dataset.

    Each datapoint is a pair of `(image, target label)`.

    Parameters
    ----------
    root_dir : str
        Directory which contains `.json` files stating all dataset entries.
    split : {"train", "valid"}
        Dataset split.
    labeler : Optional[{"chexpert", "chexbert", "vchexbert"}], optional, default=None
        Labeler used to extract labels from the training images. The "valid" split
        has no labeler; its labels were assigned by human radiologists.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "valid"],
        labeler: Optional[Literal["chexpert", "chexbert", "vchexbert"]] = None,
        transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
    ) -> None:
        assert split in ["train", "valid"], f"split {split} is not available."
        assert labeler in ["chexpert", "chexbert", "vchexbert"] or labeler is None, (
            f"labeler {labeler} is not available."
        )
        assert callable(transform) or transform is None, (
            "transform is not callable or None."
        )

        if split == "valid":
            data_file = f"{split}_data.json"
        elif split == "train":
            data_file = f"{labeler}_{split}_data.json"
        data_path = os.path.join(root_dir, data_file)

        assert os.path.isfile(data_path), f"entries file does not exist: {data_path}."

        with open(data_path, "rb") as file:
            entries = json.load(file)
        self.entries = entries

        if transform is not None:
            self.transform = transform
        else:
            self.transform = Compose([Resize(224), CenterCrop(224), ToTensor()])

    def __getitem__(self, idx: int) -> Example:
        """Return the idx'th datapoint."""
        entry = self.entries[idx]
        image = Image.open(entry["image_path"]).convert("RGB")
        image = self.transform(image)
        label = torch.tensor(entry["label"])

        return Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: label,
                "qid": entry["qid"],
                EXAMPLE_INDEX_KEY: idx,
            }
        )

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.entries)
__getitem__
__getitem__(idx)

Return the idx'th datapoint.

Source code in mmlearn/datasets/chexpert.py
def __getitem__(self, idx: int) -> Example:
    """Return the idx'th datapoint."""
    entry = self.entries[idx]
    image = Image.open(entry["image_path"]).convert("RGB")
    image = self.transform(image)
    label = torch.tensor(entry["label"])

    return Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: label,
            "qid": entry["qid"],
            EXAMPLE_INDEX_KEY: idx,
        }
    )
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/chexpert.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.entries)
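
A hypothetical usage sketch for the dataset above. The root directory path is an assumption; it must contain the entries file for the chosen split and labeler (e.g. chexpert_train_data.json), as described in the parameters.

import os

from mmlearn.datasets.chexpert import CheXpert

# CHEXPERT_ROOT_DIR is the same environment variable used by the config store;
# the fallback path here is purely illustrative.
root_dir = os.getenv("CHEXPERT_ROOT_DIR", "/data/chexpert")

dataset = CheXpert(root_dir=root_dir, split="train", labeler="chexpert")
print(len(dataset))

example = dataset[0]  # an Example holding the image tensor, target label, qid and example index
print(list(example.keys()))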

core

Modules for core dataloading functionality.

CombinedDataset

Bases: Dataset[Example]

Combine multiple datasets into one.

This class is similar to :py:class:~torch.utils.data.ConcatDataset but allows for combining iterable-style datasets with map-style datasets. The iterable-style datasets must implement the :meth:__len__ method, which is used to determine the total length of the combined dataset. When an index is passed to the combined dataset, the dataset that contains the example at that index is determined and the example is retrieved from that dataset. Since iterable-style datasets do not support random access, the examples are retrieved sequentially from the iterable-style datasets. When the end of an iterable-style dataset is reached, the iterator is reset and the next example is retrieved from the beginning of the dataset.

Parameters:

Name Type Description Default
datasets Iterable[Union[Dataset, IterableDataset]]

Iterable of datasets to combine.

required

Raises:

Type Description
TypeError

If any of the datasets in the input iterable are not instances of :py:class:~torch.utils.data.Dataset or :py:class:~torch.utils.data.IterableDataset.

ValueError

If the input iterable of datasets is empty.

Source code in mmlearn/datasets/core/combined_dataset.py
class CombinedDataset(Dataset[Example]):
    """Combine multiple datasets into one.

    This class is similar to :py:class:`~torch.utils.data.ConcatDataset` but allows
    for combining iterable-style datasets with map-style datasets. The iterable-style
    datasets must implement the :meth:`__len__` method, which is used to determine the
    total length of the combined dataset. When an index is passed to the combined
    dataset, the dataset that contains the example at that index is determined and
    the example is retrieved from that dataset. Since iterable-style datasets do
    not support random access, the examples are retrieved sequentially from the
    iterable-style datasets. When the end of an iterable-style dataset is reached,
    the iterator is reset and the next example is retrieved from the beginning of
    the dataset.


    Parameters
    ----------
    datasets : Iterable[Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]]
        Iterable of datasets to combine.

    Raises
    ------
    TypeError
        If any of the datasets in the input iterable are not instances of
        :py:class:`~torch.utils.data.Dataset` or :py:class:`~torch.utils.data.IterableDataset`.
    ValueError
        If the input iterable of datasets is empty.

    """  # noqa: W505

    def __init__(
        self, datasets: Iterable[Union[Dataset[Example], IterableDataset[Example]]]
    ) -> None:
        self.datasets, _ = tree_flatten(datasets)
        if not all(
            isinstance(dataset, (Dataset, IterableDataset)) for dataset in self.datasets
        ):
            raise TypeError(
                "Expected argument `datasets` to be an iterable of `Dataset` or "
                f"`IterableDataset` instances, but found: {self.datasets}",
            )
        if len(self.datasets) == 0:
            raise ValueError(
                "Expected a non-empty iterable of datasets but found an empty iterable",
            )

        self._cumulative_sizes: list[int] = np.cumsum(
            [len(dataset) for dataset in self.datasets]
        ).tolist()
        self._iterators: list[Iterator[Example]] = []
        self._iter_dataset_mapping: dict[int, int] = {}

        # create iterators for iterable datasets and map dataset index to iterator index
        for idx, dataset in enumerate(self.datasets):
            if isinstance(dataset, IterableDataset):
                self._iterators.append(iter(dataset))
                self._iter_dataset_mapping[idx] = len(self._iterators) - 1

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the combined dataset."""
        if idx < 0:  # handle negative indices
            if -idx > len(self):
                raise IndexError(
                    f"Index {idx} is out of bounds for the combined dataset with "
                    f"length {len(self)}",
                )
            idx = len(self) + idx

        dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

        curr_dataset = self.datasets[dataset_idx]
        if isinstance(curr_dataset, IterableDataset):
            iter_idx = self._iter_dataset_mapping[dataset_idx]
            try:
                example = next(self._iterators[iter_idx])
            except StopIteration:
                self._iterators[iter_idx] = iter(curr_dataset)
                example = next(self._iterators[iter_idx])
        else:
            if dataset_idx == 0:
                example_idx = idx
            else:
                example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
            example = curr_dataset[example_idx]

        if not isinstance(example, Example):
            raise TypeError(
                "Expected dataset examples to be instances of `Example` "
                f"but found {type(example)}",
            )

        if not hasattr(example, "dataset_index"):
            example.dataset_index = dataset_idx
        if not hasattr(example, "example_ids"):
            example.create_ids()

        return example

    def __len__(self) -> int:
        """Return the total number of examples in the combined dataset."""
        return self._cumulative_sizes[-1]
__getitem__
__getitem__(idx)

Return an example from the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the combined dataset."""
    if idx < 0:  # handle negative indices
        if -idx > len(self):
            raise IndexError(
                f"Index {idx} is out of bounds for the combined dataset with "
                f"length {len(self)}",
            )
        idx = len(self) + idx

    dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

    curr_dataset = self.datasets[dataset_idx]
    if isinstance(curr_dataset, IterableDataset):
        iter_idx = self._iter_dataset_mapping[dataset_idx]
        try:
            example = next(self._iterators[iter_idx])
        except StopIteration:
            self._iterators[iter_idx] = iter(curr_dataset)
            example = next(self._iterators[iter_idx])
    else:
        if dataset_idx == 0:
            example_idx = idx
        else:
            example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
        example = curr_dataset[example_idx]

    if not isinstance(example, Example):
        raise TypeError(
            "Expected dataset examples to be instances of `Example` "
            f"but found {type(example)}",
        )

    if not hasattr(example, "dataset_index"):
        example.dataset_index = dataset_idx
    if not hasattr(example, "example_ids"):
        example.create_ids()

    return example
__len__
__len__()

Return the total number of examples in the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __len__(self) -> int:
    """Return the total number of examples in the combined dataset."""
    return self._cumulative_sizes[-1]
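
A minimal sketch of combining a map-style dataset with an iterable-style dataset. The two toy dataset classes below are illustrative only and not part of mmlearn; note that the iterable-style dataset must implement __len__.

from torch.utils.data import Dataset, IterableDataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.example import Example

class ToyMapDataset(Dataset):
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        return Example({"value": idx, "example_index": idx})

class ToyIterableDataset(IterableDataset):
    def __len__(self):  # required so CombinedDataset can compute its total length
        return 2

    def __iter__(self):
        for i in range(2):
            yield Example({"value": 100 + i, "example_index": i})

combined = CombinedDataset([ToyMapDataset(), ToyIterableDataset()])
print(len(combined))  # 5
print(combined[0])    # served by the map-style dataset via random access
print(combined[3])    # served sequentially by the iterable-style dataset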

DefaultDataCollator dataclass

Default data collator for batching examples.

This data collator will collate a list of :py:class:~mmlearn.datasets.core.example.Example objects into a batch. It can also apply processing functions to specified keys in the batch before returning it.

Parameters:

Name Type Description Default
batch_processors Optional[dict[str, Callable[[Any], Any]]]

Dictionary of callables to apply to the batch before returning it.

None

Raises:

Type Description
ValueError

If the batch processor for a key does not return a dictionary with the key in it.

Source code in mmlearn/datasets/core/data_collator.py
@dataclass
class DefaultDataCollator:
    """Default data collator for batching examples.

    This data collator will collate a list of :py:class:`~mmlearn.datasets.core.example.Example`
    objects into a batch. It can also apply processing functions to specified keys
    in the batch before returning it.

    Parameters
    ----------
    batch_processors : Optional[dict[str, Callable[[Any], Any]]], optional, default=None
        Dictionary of callables to apply to the batch before returning it.

    Raises
    ------
    ValueError
        If the batch processor for a key does not return a dictionary with the
        key in it.
    """  # noqa: W505

    #: Dictionary of callables to apply to the batch before returning it.
    #: The key is the name of the key in the batch, and the value is the processing
    #: function to apply to the key. The processing function must take a single
    #: argument and return a single value. If the processing function returns
    #: a dictionary, it must contain the key that was processed in it (all the
    #: other keys will also be included in the batch).
    batch_processors: Optional[dict[str, Callable[[Any], Any]]] = None

    def __call__(self, examples: list[Example]) -> dict[str, Any]:
        """Collate a list of `Example` objects and apply processing functions."""
        batch = collate_example_list(examples)

        if self.batch_processors is not None:
            for key, processor in self.batch_processors.items():
                batch_key: str = key
                if Modalities.has_modality(key):
                    batch_key = Modalities.get_modality(key).name

                if batch_key in batch:
                    batch_processed = processor(batch[batch_key])
                    if isinstance(batch_processed, Mapping):
                        if batch_key not in batch_processed:
                            raise ValueError(
                                f"Batch processor for '{key}' key must return a dictionary "
                                f"with '{batch_key}' in it."
                            )
                        batch.update(batch_processed)
                    else:
                        batch[batch_key] = batch_processed

        return batch
__call__
__call__(examples)

Collate a list of Example objects and apply processing functions.

Source code in mmlearn/datasets/core/data_collator.py
def __call__(self, examples: list[Example]) -> dict[str, Any]:
    """Collate a list of `Example` objects and apply processing functions."""
    batch = collate_example_list(examples)

    if self.batch_processors is not None:
        for key, processor in self.batch_processors.items():
            batch_key: str = key
            if Modalities.has_modality(key):
                batch_key = Modalities.get_modality(key).name

            if batch_key in batch:
                batch_processed = processor(batch[batch_key])
                if isinstance(batch_processed, Mapping):
                    if batch_key not in batch_processed:
                        raise ValueError(
                            f"Batch processor for '{key}' key must return a dictionary "
                            f"with '{batch_key}' in it."
                        )
                    batch.update(batch_processed)
                else:
                    batch[batch_key] = batch_processed

    return batch
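
A sketch of attaching a batch processor. The "values" key, and the assumption that the default collation stacks tensors stored under that key into a single batch tensor, are illustrative rather than part of the documented API.

import torch

from mmlearn.datasets.core.data_collator import DefaultDataCollator
from mmlearn.datasets.core.example import Example

def scale(values: torch.Tensor) -> torch.Tensor:
    # A processor that returns a non-mapping value simply replaces batch["values"].
    return values * 2

collator = DefaultDataCollator(batch_processors={"values": scale})
batch = collator(
    [Example({"values": torch.tensor(1.0)}), Example({"values": torch.tensor(2.0)})]
)
print(batch["values"])  # the scaled, collated values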

Example

Bases: OrderedDict[Any, Any]

A representation of a single example from a dataset.

This class is a subclass of :py:class:~collections.OrderedDict and provides attribute-style access. This means that example["text"] and example.text are equivalent. All datasets in this library return examples as :py:class:~mmlearn.datasets.core.example.Example objects.

Parameters:

Name Type Description Default
init_dict Optional[MutableMapping[Hashable, Any]]

Dictionary to init Example class with.

None

Examples:

>>> example = Example({"text": torch.tensor(2)})
>>> example.text.zero_()
tensor(0)
>>> example.context = torch.tensor(4)  # set custom attributes after initialization
Source code in mmlearn/datasets/core/example.py
class Example(OrderedDict[Any, Any]):
    """A representation of a single example from a dataset.

    This class is a subclass of :py:class:`~collections.OrderedDict` and provides
    attribute-style access. This means that `example["text"]` and `example.text`
    are equivalent. All datasets in this library return examples as
    :py:class:`~mmlearn.datasets.core.example.Example` objects.


    Parameters
    ----------
    init_dict : Optional[MutableMapping[Hashable, Any]], optional, default=None
        Dictionary to init `Example` class with.

    Examples
    --------
    >>> example = Example({"text": torch.tensor(2)})
    >>> example.text.zero_()
    tensor(0)
    >>> example.context = torch.tensor(4)  # set custom attributes after initialization
    """

    def __init__(
        self,
        init_dict: Optional[MutableMapping[Hashable, Any]] = None,
    ) -> None:
        if init_dict is None:
            init_dict = {}
        super().__init__(init_dict)

    def create_ids(self) -> None:
        """Create a unique id for the example from the dataset and example index.

        This method combines the dataset index and example index to create an
        attribute called `example_ids`, which is a dictionary of tensors. The
        dictionary keys are all the keys in the example except for `example_ids`,
        `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
        containing the tuple `(dataset_index, example_index)` for each key.
        The `example_ids` is used to (re-)identify pairs of examples from different
        modalities after they have been combined into a batch.

        Warns
        -----
        UserWarning
            If the `example_index` and `dataset_index` attributes are not set.

        Notes
        -----
        - The Example must have the following attributes set before calling
          this method: `example_index` (usually set/returned by the dataset) and
          `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
        - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
          function can be used to find matching examples given two tensors of example ids.

        """  # noqa: W505
        if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
            self.example_ids = {
                key: torch.tensor([self.dataset_index, self.example_index])
                for key in self.keys()
                if key not in ("example_ids", "example_index", "dataset_index")
            }
        else:
            rank_zero_warn(
                "Cannot create `example_ids` without `example_index` and `dataset_index` "
                "attributes. Set these attributes before calling `create_ids`. "
                "No `example_ids` was created.",
                stacklevel=2,
                category=UserWarning,
            )

    def __getattr__(self, key: str) -> Any:
        """Get attribute by key."""
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key: str, value: Any) -> None:
        """Set attribute by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        self[key] = value

    def __setitem__(self, key: Hashable, value: Any) -> None:
        """Set item by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        super().__setitem__(key, value)
create_ids
create_ids()

Create a unique id for the example from the dataset and example index.

This method combines the dataset index and example index to create an attribute called example_ids, which is a dictionary of tensors. The dictionary keys are all the keys in the example except for example_ids, example_index, and dataset_index. The values are tensors of shape (2,) containing the tuple (dataset_index, example_index) for each key. The example_ids is used to (re-)identify pairs of examples from different modalities after they have been combined into a batch.

Warns:

Type Description
UserWarning

If the example_index and dataset_index attributes are not set.

Notes
  • The Example must have the following attributes set before calling this method: example_index (usually set/returned by the dataset) and dataset_index (usually set by the :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset object)
  • The :py:func:~mmlearn.datasets.core.example.find_matching_indices function can be used to find matching examples given two tensors of example ids.
Source code in mmlearn/datasets/core/example.py
def create_ids(self) -> None:
    """Create a unique id for the example from the dataset and example index.

    This method combines the dataset index and example index to create an
    attribute called `example_ids`, which is a dictionary of tensors. The
    dictionary keys are all the keys in the example except for `example_ids`,
    `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
    containing the tuple `(dataset_index, example_index)` for each key.
    The `example_ids` is used to (re-)identify pairs of examples from different
    modalities after they have been combined into a batch.

    Warns
    -----
    UserWarning
        If the `example_index` and `dataset_index` attributes are not set.

    Notes
    -----
    - The Example must have the following attributes set before calling
      this method: `example_index` (usually set/returned by the dataset) and
      `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
    - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
      function can be used to find matching examples given two tensors of example ids.

    """  # noqa: W505
    if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
        self.example_ids = {
            key: torch.tensor([self.dataset_index, self.example_index])
            for key in self.keys()
            if key not in ("example_ids", "example_index", "dataset_index")
        }
    else:
        rank_zero_warn(
            "Cannot create `example_ids` without `example_index` and `dataset_index` "
            "attributes. Set these attributes before calling `create_ids`. "
            "No `example_ids` was created.",
            stacklevel=2,
            category=UserWarning,
        )
__getattr__
__getattr__(key)

Get attribute by key.

Source code in mmlearn/datasets/core/example.py
def __getattr__(self, key: str) -> Any:
    """Get attribute by key."""
    try:
        return self[key]
    except KeyError:
        raise AttributeError(key) from None
__setattr__
__setattr__(key, value)

Set attribute by key.

Source code in mmlearn/datasets/core/example.py
def __setattr__(self, key: str, value: Any) -> None:
    """Set attribute by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    self[key] = value
__setitem__
__setitem__(key, value)

Set item by key.

Source code in mmlearn/datasets/core/example.py
def __setitem__(self, key: Hashable, value: Any) -> None:
    """Set item by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    super().__setitem__(key, value)
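
A short sketch of create_ids in isolation. Setting dataset_index by hand mimics what CombinedDataset normally does before this method is called.

import torch

from mmlearn.datasets.core.example import Example

example = Example({"text": "hello", "example_index": 7})
example.dataset_index = 0  # normally assigned by CombinedDataset
example.create_ids()
print(example.example_ids)  # {'text': tensor([0, 7])}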

CombinedDatasetRatioSampler

Bases: Sampler[int]

Sampler for weighted sampling from a :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset.

Parameters:

Name Type Description Default
dataset CombinedDataset

An instance of :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset to sample from.

required
ratios Optional[Sequence[float]]

A sequence of ratios for sampling from each dataset in the combined dataset. The length of the sequence must be equal to the number of datasets in the combined dataset (dataset). If None, the length of each dataset in the combined dataset is used as the ratio. The ratios are normalized to sum to 1.

None
num_samples Optional[int]

The number of samples to draw from the combined dataset. If None, the sampler will draw as many samples as there are in the combined dataset. This number must yield at least one sample per dataset in the combined dataset, when multiplied by the corresponding ratio.

None
replacement bool

Whether to sample with replacement or not.

False
shuffle bool

Whether to shuffle the sampled indices or not. If False, the indices of each dataset will appear in the order they are stored in the combined dataset. This is similar to sequential sampling from each dataset. The datasets that make up the combined dataset are still sampled randomly.

True
rank Optional[int]

Rank of the current process within :attr:num_replicas. By default, :attr:rank is retrieved from the current distributed group.

None
num_replicas Optional[int]

Number of processes participating in distributed training. By default, :attr:num_replicas is retrieved from the current distributed group.

None
drop_last bool

Whether to drop the last incomplete batch or not. If True, the sampler will drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

False
seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

0

Attributes:

Name Type Description
dataset CombinedDataset

The dataset to sample from.

num_samples int

The number of samples to draw from the combined dataset.

probs Tensor

The probabilities for sampling from each dataset in the combined dataset. This is computed from the ratios argument and is normalized to sum to 1.

replacement bool

Whether to sample with replacement or not.

shuffle bool

Whether to shuffle the sampled indices or not.

rank int

Rank of the current process within :attr:num_replicas.

num_replicas int

Number of processes participating in distributed training.

drop_last bool

Whether to drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

epoch int

Current epoch number. This is used to set the random seed. This is useful in distributed mode to ensure that each process receives a different random ordering of the samples.

total_size int

The total number of samples across all processes.

Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class CombinedDatasetRatioSampler(Sampler[int]):
    """Sampler for weighted sampling from a :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`.

    Parameters
    ----------
    dataset : CombinedDataset
        An instance of :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`
        to sample from.
    ratios : Optional[Sequence[float]], optional, default=None
        A sequence of ratios for sampling from each dataset in the combined dataset.
        The length of the sequence must be equal to the number of datasets in the
        combined dataset (`dataset`). If `None`, the length of each dataset in the
        combined dataset is used as the ratio. The ratios are normalized to sum to 1.
    num_samples : Optional[int], optional, default=None
        The number of samples to draw from the combined dataset. If `None`, the
        sampler will draw as many samples as there are in the combined dataset.
        This number must yield at least one sample per dataset in the combined
        dataset, when multiplied by the corresponding ratio.
    replacement : bool, default=False
        Whether to sample with replacement or not.
    shuffle : bool, default=True
        Whether to shuffle the sampled indices or not. If `False`, the indices of
        each dataset will appear in the order they are stored in the combined dataset.
        This is similar to sequential sampling from each dataset. The datasets
        that make up the combined dataset are still sampled randomly.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    drop_last : bool, default=False
        Whether to drop the last incomplete batch or not. If `True`, the sampler will
        drop samples to make the number of samples evenly divisible by the number of
        replicas in distributed mode.
    seed : int, default=0
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.

    Attributes
    ----------
    dataset : CombinedDataset
        The dataset to sample from.
    num_samples : int
        The number of samples to draw from the combined dataset.
    probs : torch.Tensor
        The probabilities for sampling from each dataset in the combined dataset.
        This is computed from the `ratios` argument and is normalized to sum to 1.
    replacement : bool
        Whether to sample with replacement or not.
    shuffle : bool
        Whether to shuffle the sampled indices or not.
    rank : int
        Rank of the current process within :attr:`num_replicas`.
    num_replicas : int
        Number of processes participating in distributed training.
    drop_last : bool
        Whether to drop samples to make the number of samples evenly divisible by the
        number of replicas in distributed mode.
    seed : int
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.
    epoch : int
        Current epoch number. This is used to set the random seed. This is useful
        in distributed mode to ensure that each process receives a different random
        ordering of the samples.
    total_size : int
        The total number of samples across all processes.
    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        dataset: CombinedDataset,
        ratios: Optional[Sequence[float]] = None,
        num_samples: Optional[int] = None,
        replacement: bool = False,
        shuffle: bool = True,
        rank: Optional[int] = None,
        num_replicas: Optional[int] = None,
        drop_last: bool = False,
        seed: int = 0,
    ):
        if not isinstance(dataset, CombinedDataset):
            raise TypeError(
                "Expected argument `dataset` to be of type `CombinedDataset`, "
                f"but got {type(dataset)}.",
            )
        if not isinstance(seed, int):
            raise TypeError(
                f"Expected argument `seed` to be an integer, but got {type(seed)}.",
            )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.drop_last = drop_last
        self.replacement = replacement
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        self._num_samples = num_samples
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "Expected argument `num_samples` to be a positive integer, but got "
                f"{self.num_samples}.",
            )

        if ratios is None:
            ratios = [len(subset) for subset in self.dataset.datasets]

        num_datasets = len(self.dataset.datasets)
        if len(ratios) != num_datasets:
            raise ValueError(
                f"Expected argument `ratios` to be of length {num_datasets}, "
                f"but got length {len(ratios)}.",
            )
        prob_sum = sum(ratios)
        if not all(ratio >= 0 for ratio in ratios) and prob_sum > 0:
            raise ValueError(
                "Expected argument `ratios` to be a sequence of non-negative numbers. "
                f"Got {ratios}.",
            )
        self.probs = torch.tensor(
            [ratio / prob_sum for ratio in ratios],
            dtype=torch.double,
        )
        if any((prob * self.num_samples) <= 0 for prob in self.probs):
            raise ValueError(
                "Expected dataset ratio to result in at least one sample per dataset. "
                f"Got dataset sizes {self.probs * self.num_samples}.",
            )

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        # dataset size might change at runtime
        if self._num_samples is None:
            num_samples = len(self.dataset)
        else:
            num_samples = self._num_samples

        if self.drop_last and num_samples % self.num_replicas != 0:
            # split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            num_samples = math.ceil(
                (num_samples - self.num_replicas) / self.num_replicas,
            )
        else:
            num_samples = math.ceil(num_samples / self.num_replicas)
        return num_samples

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return self.num_samples * self.num_replicas

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that yields sample indices for the combined dataset."""
        generator = torch.Generator()
        seed = self.seed + self.epoch
        generator.manual_seed(seed)

        cumulative_sizes = [0] + self.dataset._cumulative_sizes
        num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
        indices = []
        for i in range(len(self.dataset.datasets)):
            per_dataset_indices: torch.Tensor = torch.multinomial(
                torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
                num_samples_per_dataset[i],
                replacement=self.replacement,
                generator=generator,
            )
            # adjust indices to reflect position in cumulative dataset
            per_dataset_indices += cumulative_sizes[i]
            assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
                f"Indices from dataset {i} exceed dataset size. "
                f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
            )
            indices.append(per_dataset_indices)

        indices = torch.cat(indices)
        if self.shuffle:
            rand_indices = torch.randperm(len(indices), generator=generator)
            indices = indices[rand_indices]

        indices = indices.tolist()  # type: ignore[attr-defined]
        num_indices = len(indices)

        if num_indices < self.total_size:
            padding_size = self.total_size - num_indices
            if padding_size <= num_indices:
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / num_indices))[
                    :padding_size
                ]
        elif num_indices > self.total_size:
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples, (
            f"Expected {self.num_samples} samples, but got {len(indices)}.",
        )

        yield from iter(indices)

    def __len__(self) -> int:
        """Return the total number of samples in the sampler."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch

        # some iterable datasets (especially huggingface iterable datasets) might
        # require setting the epoch to ensure shuffling works properly
        for dataset in self.dataset.datasets:
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
num_samples property
num_samples

Return the number of samples managed by the sampler.

total_size property
total_size

Return the total size of the dataset.

__iter__
__iter__()

Return an iterator that yields sample indices for the combined dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that yields sample indices for the combined dataset."""
    generator = torch.Generator()
    seed = self.seed + self.epoch
    generator.manual_seed(seed)

    cumulative_sizes = [0] + self.dataset._cumulative_sizes
    num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
    indices = []
    for i in range(len(self.dataset.datasets)):
        per_dataset_indices: torch.Tensor = torch.multinomial(
            torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
            num_samples_per_dataset[i],
            replacement=self.replacement,
            generator=generator,
        )
        # adjust indices to reflect position in cumulative dataset
        per_dataset_indices += cumulative_sizes[i]
        assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
            f"Indices from dataset {i} exceed dataset size. "
            f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
        )
        indices.append(per_dataset_indices)

    indices = torch.cat(indices)
    if self.shuffle:
        rand_indices = torch.randperm(len(indices), generator=generator)
        indices = indices[rand_indices]

    indices = indices.tolist()  # type: ignore[attr-defined]
    num_indices = len(indices)

    if num_indices < self.total_size:
        padding_size = self.total_size - num_indices
        if padding_size <= num_indices:
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / num_indices))[
                :padding_size
            ]
    elif num_indices > self.total_size:
        indices = indices[: self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples, (
        f"Expected {self.num_samples} samples, but got {len(indices)}.",
    )

    yield from iter(indices)
__len__
__len__()

Return the total number of samples in the sampler.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the total number of samples in the sampler."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When :attr:shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

    # some iterable datasets (especially huggingface iterable datasets) might
    # require setting the epoch to ensure shuffling works properly
    for dataset in self.dataset.datasets:
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)
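
A sketch of drawing indices from the two toy datasets combined in the CombinedDataset sketch above, in an 80/20 ratio. Passing rank and num_replicas explicitly avoids the need for an initialized distributed process group; replacement is enabled because num_samples exceeds the sizes of the toy datasets.

from mmlearn.datasets.core.samplers import CombinedDatasetRatioSampler

# `combined` is the CombinedDataset built in the earlier sketch (sizes 3 and 2).
sampler = CombinedDatasetRatioSampler(
    combined,
    ratios=[0.8, 0.2],  # normalized internally into sampling probabilities
    num_samples=10,
    replacement=True,
    rank=0,
    num_replicas=1,
    seed=42,
)
print(len(sampler))             # 10 indices on this single replica
print(list(iter(sampler))[:5])  # indices into the combined dataset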

DistributedEvalSampler

Bases: Sampler[int]

Sampler for distributed evaluation.

The main differences between this and :py:class:torch.utils.data.DistributedSampler are that this sampler does not add extra samples to make it evenly divisible and shuffling is disabled by default.

Parameters:

Name Type Description Default
dataset Dataset

Dataset used for sampling.

required
num_replicas Optional[int]

Number of processes participating in distributed training. By default, :attr:num_replicas is retrieved from the current distributed group.

None
rank Optional[int]

Rank of the current process within :attr:num_replicas. By default, :attr:rank is retrieved from the current distributed group.

None
shuffle bool

If True, the sampler will shuffle the indices.

False
seed int

Random seed used to shuffle the sampler if :attr:shuffle=True. This number should be identical across all processes in the distributed group.

0
Warnings

DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever. See [1]_ for details.

Notes
  • This sampler is for evaluation purposes, where synchronization does not happen every epoch. Synchronization should be done outside the dataloader loop. It is especially useful in conjunction with :py:class:torch.nn.parallel.DistributedDataParallel [2]_.
  • The input Dataset is assumed to be of constant size.
  • This implementation is adapted from [3]_.
References

.. [1] https://github.com/pytorch/pytorch/issues/22584
.. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
.. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py

Examples:

>>> def example():
...     start_epoch, n_epochs = 0, 2
...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
...     for epoch in range(start_epoch, n_epochs):
...         if is_distributed:
...             sampler.set_epoch(epoch)
...         evaluate(loader)
Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class DistributedEvalSampler(Sampler[int]):
    """Sampler for distributed evaluation.

    The main differences between this and :py:class:`torch.utils.data.DistributedSampler`
    are that this sampler does not add extra samples to make it evenly divisible and
    shuffling is disabled by default.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset used for sampling.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    shuffle : bool, optional, default=False
        If `True`, the sampler will shuffle the indices.
    seed : int, optional, default=0
        Random seed used to shuffle the sampler if :attr:`shuffle=True`.
        This number should be identical across all processes in the
        distributed group.

    Warnings
    --------
    DistributedEvalSampler should NOT be used for training. The distributed processes
    could hang forever. See [1]_ for details.

    Notes
    -----
    - This sampler is for evaluation purposes, where synchronization does not happen
      every epoch. Synchronization should be done outside the dataloader loop.
      It is especially useful in conjunction with
      :py:class:`torch.nn.parallel.DistributedDataParallel` [2]_.
    - The input Dataset is assumed to be of constant size.
    - This implementation is adapted from [3]_.

    References
    ----------
    .. [1] https://github.com/pytorch/pytorch/issues/22584
    .. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
    .. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py


    Examples
    --------
    >>> def example():
    ...     start_epoch, n_epochs = 0, 2
    ...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
    ...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
    ...     for epoch in range(start_epoch, n_epochs):
    ...         if is_distributed:
    ...             sampler.set_epoch(epoch)
    ...         evaluate(loader)

    """  # noqa: W505

    def __init__(
        self,
        dataset: Dataset[Sized],
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return len(self.dataset)

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        indices = list(range(self.total_size))[
            self.rank : self.total_size : self.num_replicas
        ]
        return len(indices)  # true value without extra samples

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that iterates over the indices of the dataset."""
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(self.total_size, generator=g).tolist()
        else:
            indices = list(range(self.total_size))

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch
total_size property
total_size

Return the total size of the dataset.

num_samples property
num_samples

Return the number of samples managed by the sampler.

__iter__
__iter__()

Return an iterator that iterates over the indices of the dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that iterates over the indices of the dataset."""
    if self.shuffle:
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()
    else:
        indices = list(range(self.total_size))

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples

    return iter(indices)
__len__
__len__()

Return the number of samples.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the number of samples."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When :attr:shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

find_matching_indices

find_matching_indices(
    first_example_ids, second_example_ids
)

Find the indices of matching examples given two tensors of example ids.

Matching examples are defined as examples with the same value in both tensors. This method is useful for finding pairs of examples from different modalities that are related to each other in a batch.

Parameters:

Name Type Description Default
first_example_ids Tensor

A tensor of example ids of shape (N, 2), where N is the number of examples.

required
second_example_ids Tensor

A tensor of example ids of shape (M, 2), where M is the number of examples.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of tensors containing the indices of matching examples in the first and second tensor, respectively.

Raises:

Type Description
TypeError

If either first_example_ids or second_example_ids is not a tensor.

ValueError

If either first_example_ids or second_example_ids is not a 2D tensor with the second dimension having a size of 2.

Examples:

>>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
>>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
>>> find_matching_indices(img_example_ids, text_example_ids)
(tensor([2, 3]), tensor([0, 1]))
Source code in mmlearn/datasets/core/example.py
def find_matching_indices(
    first_example_ids: torch.Tensor, second_example_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find the indices of matching examples given two tensors of example ids.

    Matching examples are defined as examples with the same value in both tensors.
    This method is useful for finding pairs of examples from different modalities
    that are related to each other in a batch.

    Parameters
    ----------
    first_example_ids : torch.Tensor
        A tensor of example ids of shape `(N, 2)`, where `N` is the number of examples.
    second_example_ids : torch.Tensor
        A tensor of example ids of shape `(M, 2)`, where `M` is the number of examples.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple of tensors containing the indices of matching examples in the first and
        second tensor, respectively.

    Raises
    ------
    TypeError
        If either `first_example_ids` or `second_example_ids` is not a tensor.
    ValueError
        If either `first_example_ids` or `second_example_ids` is not a 2D tensor
        with the second dimension having a size of `2`.

    Examples
    --------
    >>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
    >>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
    >>> find_matching_indices(img_example_ids, text_example_ids)
    (tensor([2, 3]), tensor([0, 1]))


    """
    if not isinstance(first_example_ids, torch.Tensor) or not isinstance(
        second_example_ids,
        torch.Tensor,
    ):
        raise TypeError(
            f"Expected inputs to be tensors, but got {type(first_example_ids)} "
            f"and {type(second_example_ids)}.",
        )
    val = 2
    if not (first_example_ids.ndim == val and first_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `first_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {first_example_ids.shape}.",
        )
    if not (second_example_ids.ndim == val and second_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `second_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {second_example_ids.shape}.",
        )

    first_example_ids = first_example_ids.unsqueeze(1)  # shape=(N, 1, 2)
    second_example_ids = second_example_ids.unsqueeze(0)  # shape=(1, M, 2)

    # compare all elements; results in a shape (N, M) tensor
    matches = torch.all(first_example_ids == second_example_ids, dim=-1)
    first_indices, second_indices = torch.where(matches)
    return first_indices, second_indices

combined_dataset

Wrapper for combining multiple datasets into one.

CombinedDataset

Bases: Dataset[Example]

Combine multiple datasets into one.

This class is similar to :py:class:~torch.utils.data.ConcatDataset but allows for combining iterable-style datasets with map-style datasets. The iterable-style datasets must implement the :meth:__len__ method, which is used to determine the total length of the combined dataset. When an index is passed to the combined dataset, the dataset that contains the example at that index is determined and the example is retrieved from that dataset. Since iterable-style datasets do not support random access, the examples are retrieved sequentially from the iterable-style datasets. When the end of an iterable-style dataset is reached, the iterator is reset and the next example is retrieved from the beginning of the dataset.

Parameters:

Name Type Description Default
datasets Iterable[Union[Dataset, IterableDataset]]

Iterable of datasets to combine.

required

Raises:

Type Description
TypeError

If any of the datasets in the input iterable are not instances of :py:class:~torch.utils.data.Dataset or :py:class:~torch.utils.data.IterableDataset.

ValueError

If the input iterable of datasets is empty.

Source code in mmlearn/datasets/core/combined_dataset.py
class CombinedDataset(Dataset[Example]):
    """Combine multiple datasets into one.

    This class is similar to :py:class:`~torch.utils.data.ConcatDataset` but allows
    for combining iterable-style datasets with map-style datasets. The iterable-style
    datasets must implement the :meth:`__len__` method, which is used to determine the
    total length of the combined dataset. When an index is passed to the combined
    dataset, the dataset that contains the example at that index is determined and
    the example is retrieved from that dataset. Since iterable-style datasets do
    not support random access, the examples are retrieved sequentially from the
    iterable-style datasets. When the end of an iterable-style dataset is reached,
    the iterator is reset and the next example is retrieved from the beginning of
    the dataset.


    Parameters
    ----------
    datasets : Iterable[Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]]
        Iterable of datasets to combine.

    Raises
    ------
    TypeError
        If any of the datasets in the input iterable are not instances of
        :py:class:`~torch.utils.data.Dataset` or :py:class:`~torch.utils.data.IterableDataset`.
    ValueError
        If the input iterable of datasets is empty.

    """  # noqa: W505

    def __init__(
        self, datasets: Iterable[Union[Dataset[Example], IterableDataset[Example]]]
    ) -> None:
        self.datasets, _ = tree_flatten(datasets)
        if not all(
            isinstance(dataset, (Dataset, IterableDataset)) for dataset in self.datasets
        ):
            raise TypeError(
                "Expected argument `datasets` to be an iterable of `Dataset` or "
                f"`IterableDataset` instances, but found: {self.datasets}",
            )
        if len(self.datasets) == 0:
            raise ValueError(
                "Expected a non-empty iterable of datasets but found an empty iterable",
            )

        self._cumulative_sizes: list[int] = np.cumsum(
            [len(dataset) for dataset in self.datasets]
        ).tolist()
        self._iterators: list[Iterator[Example]] = []
        self._iter_dataset_mapping: dict[int, int] = {}

        # create iterators for iterable datasets and map dataset index to iterator index
        for idx, dataset in enumerate(self.datasets):
            if isinstance(dataset, IterableDataset):
                self._iterators.append(iter(dataset))
                self._iter_dataset_mapping[idx] = len(self._iterators) - 1

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the combined dataset."""
        if idx < 0:  # handle negative indices
            if -idx > len(self):
                raise IndexError(
                    f"Index {idx} is out of bounds for the combined dataset with "
                    f"length {len(self)}",
                )
            idx = len(self) + idx

        dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

        curr_dataset = self.datasets[dataset_idx]
        if isinstance(curr_dataset, IterableDataset):
            iter_idx = self._iter_dataset_mapping[dataset_idx]
            try:
                example = next(self._iterators[iter_idx])
            except StopIteration:
                self._iterators[iter_idx] = iter(curr_dataset)
                example = next(self._iterators[iter_idx])
        else:
            if dataset_idx == 0:
                example_idx = idx
            else:
                example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
            example = curr_dataset[example_idx]

        if not isinstance(example, Example):
            raise TypeError(
                "Expected dataset examples to be instances of `Example` "
                f"but found {type(example)}",
            )

        if not hasattr(example, "dataset_index"):
            example.dataset_index = dataset_idx
        if not hasattr(example, "example_ids"):
            example.create_ids()

        return example

    def __len__(self) -> int:
        """Return the total number of examples in the combined dataset."""
        return self._cumulative_sizes[-1]
__getitem__
__getitem__(idx)

Return an example from the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the combined dataset."""
    if idx < 0:  # handle negative indices
        if -idx > len(self):
            raise IndexError(
                f"Index {idx} is out of bounds for the combined dataset with "
                f"length {len(self)}",
            )
        idx = len(self) + idx

    dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

    curr_dataset = self.datasets[dataset_idx]
    if isinstance(curr_dataset, IterableDataset):
        iter_idx = self._iter_dataset_mapping[dataset_idx]
        try:
            example = next(self._iterators[iter_idx])
        except StopIteration:
            self._iterators[iter_idx] = iter(curr_dataset)
            example = next(self._iterators[iter_idx])
    else:
        if dataset_idx == 0:
            example_idx = idx
        else:
            example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
        example = curr_dataset[example_idx]

    if not isinstance(example, Example):
        raise TypeError(
            "Expected dataset examples to be instances of `Example` "
            f"but found {type(example)}",
        )

    if not hasattr(example, "dataset_index"):
        example.dataset_index = dataset_idx
    if not hasattr(example, "example_ids"):
        example.create_ids()

    return example
__len__
__len__()

Return the total number of examples in the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __len__(self) -> int:
    """Return the total number of examples in the combined dataset."""
    return self._cumulative_sizes[-1]
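
To make the behaviour above concrete, the following sketch combines a toy map-style dataset with a toy iterable-style dataset. The two dataset classes are hypothetical and exist only for illustration; CombinedDataset and Example are imported from the modules shown in the source listings.

from torch.utils.data import Dataset, IterableDataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.example import Example


class ToyMapDataset(Dataset):
    """Hypothetical map-style dataset returning `Example` objects."""

    def __len__(self) -> int:
        return 3

    def __getitem__(self, idx: int) -> Example:
        return Example({"value": idx, "example_index": idx})


class ToyIterableDataset(IterableDataset):
    """Hypothetical iterable-style dataset; `__len__` is required by CombinedDataset."""

    def __len__(self) -> int:
        return 2

    def __iter__(self):
        for i in range(len(self)):
            yield Example({"value": 100 + i, "example_index": i})


combined = CombinedDataset([ToyMapDataset(), ToyIterableDataset()])
print(len(combined))  # 5 (3 + 2)

example = combined[4]  # the index falls in the iterable-style dataset...
print(example.value)   # ...but its examples are served sequentially, so this is 100
print(example.dataset_index)  # 1, set by CombinedDataset
print(example.example_ids)    # {'value': tensor([1, 0])}

Indexing into the iterable-style portion again simply returns its next example, and the iterator is restarted once it is exhausted, as described above.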

data_collator

Data collators for batching examples.

DefaultDataCollator dataclass

Default data collator for batching examples.

This data collator will collate a list of :py:class:~mmlearn.datasets.core.example.Example objects into a batch. It can also apply processing functions to specified keys in the batch before returning it.

Parameters:

Name Type Description Default
batch_processors Optional[dict[str, Callable[[Any], Any]]]

Dictionary of callables to apply to the batch before returning it.

None

Raises:

Type Description
ValueError

If the batch processor for a key does not return a dictionary with the key in it.

Source code in mmlearn/datasets/core/data_collator.py
@dataclass
class DefaultDataCollator:
    """Default data collator for batching examples.

    This data collator will collate a list of :py:class:`~mmlearn.datasets.core.example.Example`
    objects into a batch. It can also apply processing functions to specified keys
    in the batch before returning it.

    Parameters
    ----------
    batch_processors : Optional[dict[str, Callable[[Any], Any]]], optional, default=None
        Dictionary of callables to apply to the batch before returning it.

    Raises
    ------
    ValueError
        If the batch processor for a key does not return a dictionary with the
        key in it.
    """  # noqa: W505

    #: Dictionary of callables to apply to the batch before returning it.
    #: The key is the name of the key in the batch, and the value is the processing
    #: function to apply to the key. The processing function must take a single
    #: argument and return a single value. If the processing function returns
    #: a dictionary, it must contain the key that was processed in it (all the
    #: other keys will also be included in the batch).
    batch_processors: Optional[dict[str, Callable[[Any], Any]]] = None

    def __call__(self, examples: list[Example]) -> dict[str, Any]:
        """Collate a list of `Example` objects and apply processing functions."""
        batch = collate_example_list(examples)

        if self.batch_processors is not None:
            for key, processor in self.batch_processors.items():
                batch_key: str = key
                if Modalities.has_modality(key):
                    batch_key = Modalities.get_modality(key).name

                if batch_key in batch:
                    batch_processed = processor(batch[batch_key])
                    if isinstance(batch_processed, Mapping):
                        if batch_key not in batch_processed:
                            raise ValueError(
                                f"Batch processor for '{key}' key must return a dictionary "
                                f"with '{batch_key}' in it."
                            )
                        batch.update(batch_processed)
                    else:
                        batch[batch_key] = batch_processed

        return batch
__call__
__call__(examples)

Collate a list of Example objects and apply processing functions.

Source code in mmlearn/datasets/core/data_collator.py
def __call__(self, examples: list[Example]) -> dict[str, Any]:
    """Collate a list of `Example` objects and apply processing functions."""
    batch = collate_example_list(examples)

    if self.batch_processors is not None:
        for key, processor in self.batch_processors.items():
            batch_key: str = key
            if Modalities.has_modality(key):
                batch_key = Modalities.get_modality(key).name

            if batch_key in batch:
                batch_processed = processor(batch[batch_key])
                if isinstance(batch_processed, Mapping):
                    if batch_key not in batch_processed:
                        raise ValueError(
                            f"Batch processor for '{key}' key must return a dictionary "
                            f"with '{batch_key}' in it."
                        )
                    batch.update(batch_processed)
                else:
                    batch[batch_key] = batch_processed

    return batch
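
A minimal sketch of using the collator with a batch processor. The "text" key is assumed to correspond to a registered modality, and the add_batch_size processor and the "text_batch_size" entry it adds are made up for illustration; when a processor returns a mapping, that mapping must contain the processed key and its entries are merged into the batch.

import torch

from mmlearn.datasets.core.data_collator import DefaultDataCollator
from mmlearn.datasets.core.example import Example


def add_batch_size(text_batch):
    # Hypothetical processor: returns a mapping that keeps the processed
    # "text" key and adds a derived entry alongside it.
    return {"text": text_batch, "text_batch_size": len(text_batch)}


collator = DefaultDataCollator(batch_processors={"text": add_batch_size})

examples = []
for i in range(2):
    ex = Example({"text": torch.tensor(i), "example_index": i})
    ex.dataset_index = 0  # normally set by CombinedDataset
    ex.create_ids()       # normally called by CombinedDataset
    examples.append(ex)

batch = collator(examples)
# `batch` is a dictionary of batched values; the processor's extra
# "text_batch_size" entry has been merged into it.
print(sorted(batch.keys()))
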
collate_example_list
collate_example_list(examples)

Collate a list of :py:class:~mmlearn.datasets.core.example.Example objects into a batch.

Parameters:

Name Type Description Default
examples list[Example]

list of examples to collate.

required

Returns:

Type Description
dict[str, Any]

Dictionary of batched examples.

Source code in mmlearn/datasets/core/data_collator.py
def collate_example_list(examples: list[Example]) -> dict[str, Any]:
    """Collate a list of :py:class:`~mmlearn.datasets.core.example.Example` objects into a batch.

    Parameters
    ----------
    examples : list[Example]
        list of examples to collate.

    Returns
    -------
    dict[str, Any]
        Dictionary of batched examples.

    """  # noqa: W505
    return _collate_example_dict(_merge_examples(examples))

example

Module for example-related classes and functions.

Example

Bases: OrderedDict[Any, Any]

A representation of a single example from a dataset.

This class is a subclass of :py:class:~collections.OrderedDict and provides attribute-style access. This means that example["text"] and example.text are equivalent. All datasets in this library return examples as :py:class:~mmlearn.datasets.core.example.Example objects.

Parameters:

Name Type Description Default
init_dict Optional[MutableMapping[Hashable, Any]]

Dictionary to init Example class with.

None

Examples:

>>> example = Example({"text": torch.tensor(2)})
>>> example.text.zero_()
tensor(0)
>>> example.context = torch.tensor(4)  # set custom attributes after initialization
Source code in mmlearn/datasets/core/example.py
class Example(OrderedDict[Any, Any]):
    """A representation of a single example from a dataset.

    This class is a subclass of :py:class:`~collections.OrderedDict` and provides
    attribute-style access. This means that `example["text"]` and `example.text`
    are equivalent. All datasets in this library return examples as
    :py:class:`~mmlearn.datasets.core.example.Example` objects.


    Parameters
    ----------
    init_dict : Optional[MutableMapping[Hashable, Any]], optional, default=None
        Dictionary to init `Example` class with.

    Examples
    --------
    >>> example = Example({"text": torch.tensor(2)})
    >>> example.text.zero_()
    tensor(0)
    >>> example.context = torch.tensor(4)  # set custom attributes after initialization
    """

    def __init__(
        self,
        init_dict: Optional[MutableMapping[Hashable, Any]] = None,
    ) -> None:
        if init_dict is None:
            init_dict = {}
        super().__init__(init_dict)

    def create_ids(self) -> None:
        """Create a unique id for the example from the dataset and example index.

        This method combines the dataset index and example index to create an
        attribute called `example_ids`, which is a dictionary of tensors. The
        dictionary keys are all the keys in the example except for `example_ids`,
        `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
        containing the tuple `(dataset_index, example_index)` for each key.
        The `example_ids` is used to (re-)identify pairs of examples from different
        modalities after they have been combined into a batch.

        Warns
        -----
        UserWarning
            If the `example_index` and `dataset_index` attributes are not set.

        Notes
        -----
        - The Example must have the following attributes set before calling this
          method: `example_index` (usually set/returned by the dataset) and
          `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
        - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
          function can be used to find matching examples given two tensors of example ids.

        """  # noqa: W505
        if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
            self.example_ids = {
                key: torch.tensor([self.dataset_index, self.example_index])
                for key in self.keys()
                if key not in ("example_ids", "example_index", "dataset_index")
            }
        else:
            rank_zero_warn(
                "Cannot create `example_ids` without `example_index` and `dataset_index` "
                "attributes. Set these attributes before calling `create_ids`. "
                "No `example_ids` was created.",
                stacklevel=2,
                category=UserWarning,
            )

    def __getattr__(self, key: str) -> Any:
        """Get attribute by key."""
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key: str, value: Any) -> None:
        """Set attribute by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        self[key] = value

    def __setitem__(self, key: Hashable, value: Any) -> None:
        """Set item by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        super().__setitem__(key, value)
create_ids
create_ids()

Create a unique id for the example from the dataset and example index.

This method combines the dataset index and example index to create an attribute called example_ids, which is a dictionary of tensors. The dictionary keys are all the keys in the example except for example_ids, example_index, and dataset_index. The values are tensors of shape (2,) containing the tuple (dataset_index, example_index) for each key. The example_ids is used to (re-)identify pairs of examples from different modalities after they have been combined into a batch.

Warns:

Type Description
UserWarning

If the example_index and dataset_index attributes are not set.

Notes
  • The Example must have the following attributes set before calling this method: example_index (usually set/returned by the dataset) and dataset_index (usually set by the :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset object)
  • The :py:func:~mmlearn.datasets.core.example.find_matching_indices function can be used to find matching examples given two tensors of example ids.
Source code in mmlearn/datasets/core/example.py
def create_ids(self) -> None:
    """Create a unique id for the example from the dataset and example index.

    This method combines the dataset index and example index to create an
    attribute called `example_ids`, which is a dictionary of tensors. The
    dictionary keys are all the keys in the example except for `example_ids`,
    `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
    containing the tuple `(dataset_index, example_index)` for each key.
    The `example_ids` is used to (re-)identify pairs of examples from different
    modalities after they have been combined into a batch.

    Warns
    -----
    UserWarning
        If the `example_index` and `dataset_index` attributes are not set.

    Notes
    -----
    - The Example must have the following attributes set before calling this
      method: `example_index` (usually set/returned by the dataset) and
      `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
    - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
      function can be used to find matching examples given two tensors of example ids.

    """  # noqa: W505
    if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
        self.example_ids = {
            key: torch.tensor([self.dataset_index, self.example_index])
            for key in self.keys()
            if key not in ("example_ids", "example_index", "dataset_index")
        }
    else:
        rank_zero_warn(
            "Cannot create `example_ids` without `example_index` and `dataset_index` "
            "attributes. Set these attributes before calling `create_ids`. "
            "No `example_ids` was created.",
            stacklevel=2,
            category=UserWarning,
        )
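
A short sketch of what create_ids produces. The key names "image" and "text" and the index values are illustrative; the attributes are set by hand here, whereas in practice they come from the dataset and from CombinedDataset.

import torch

from mmlearn.datasets.core.example import Example

example = Example({"image": torch.rand(3, 2, 2), "text": torch.tensor([1, 2, 3])})
example.example_index = 7  # usually set/returned by the dataset
example.dataset_index = 0  # usually set by CombinedDataset
example.create_ids()

# One (dataset_index, example_index) id tensor per data key:
print(example.example_ids)
# {'image': tensor([0, 7]), 'text': tensor([0, 7])}
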
__getattr__
__getattr__(key)

Get attribute by key.

Source code in mmlearn/datasets/core/example.py
def __getattr__(self, key: str) -> Any:
    """Get attribute by key."""
    try:
        return self[key]
    except KeyError:
        raise AttributeError(key) from None
__setattr__
__setattr__(key, value)

Set attribute by key.

Source code in mmlearn/datasets/core/example.py
def __setattr__(self, key: str, value: Any) -> None:
    """Set attribute by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    self[key] = value
__setitem__
__setitem__(key, value)

Set item by key.

Source code in mmlearn/datasets/core/example.py
def __setitem__(self, key: Hashable, value: Any) -> None:
    """Set item by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    super().__setitem__(key, value)
find_matching_indices
find_matching_indices(
    first_example_ids, second_example_ids
)

Find the indices of matching examples given two tensors of example ids.

Matching examples are defined as examples with the same value in both tensors. This method is useful for finding pairs of examples from different modalities that are related to each other in a batch.

Parameters:

Name Type Description Default
first_example_ids Tensor

A tensor of example ids of shape (N, 2), where N is the number of examples.

required
second_example_ids Tensor

A tensor of example ids of shape (M, 2), where M is the number of examples.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of tensors containing the indices of matching examples in the first and second tensor, respectively.

Raises:

Type Description
TypeError

If either first_example_ids or second_example_ids is not a tensor.

ValueError

If either first_example_ids or second_example_ids is not a 2D tensor with the second dimension having a size of 2.

Examples:

>>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
>>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
>>> find_matching_indices(img_example_ids, text_example_ids)
(tensor([2, 3]), tensor([0, 1]))
Source code in mmlearn/datasets/core/example.py
def find_matching_indices(
    first_example_ids: torch.Tensor, second_example_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find the indices of matching examples given two tensors of example ids.

    Matching examples are defined as examples with the same value in both tensors.
    This method is useful for finding pairs of examples from different modalities
    that are related to each other in a batch.

    Parameters
    ----------
    first_example_ids : torch.Tensor
        A tensor of example ids of shape `(N, 2)`, where `N` is the number of examples.
    second_example_ids : torch.Tensor
        A tensor of example ids of shape `(M, 2)`, where `M` is the number of examples.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple of tensors containing the indices of matching examples in the first and
        second tensor, respectively.

    Raises
    ------
    TypeError
        If either `first_example_ids` or `second_example_ids` is not a tensor.
    ValueError
        If either `first_example_ids` or `second_example_ids` is not a 2D tensor
        with the second dimension having a size of `2`.

    Examples
    --------
    >>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
    >>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
    >>> find_matching_indices(img_example_ids, text_example_ids)
    (tensor([2, 3]), tensor([0, 1]))


    """
    if not isinstance(first_example_ids, torch.Tensor) or not isinstance(
        second_example_ids,
        torch.Tensor,
    ):
        raise TypeError(
            f"Expected inputs to be tensors, but got {type(first_example_ids)} "
            f"and {type(second_example_ids)}.",
        )
    val = 2
    if not (first_example_ids.ndim == val and first_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `first_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {first_example_ids.shape}.",
        )
    if not (second_example_ids.ndim == val and second_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `second_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {second_example_ids.shape}.",
        )

    first_example_ids = first_example_ids.unsqueeze(1)  # shape=(N, 1, 2)
    second_example_ids = second_example_ids.unsqueeze(0)  # shape=(1, M, 2)

    # compare all elements; results in a shape (N, M) tensor
    matches = torch.all(first_example_ids == second_example_ids, dim=-1)
    first_indices, second_indices = torch.where(matches)
    return first_indices, second_indices
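
Building on the docstring example, the sketch below shows the typical pattern of aligning paired modalities in a batch using the returned indices. The id tensors and embedding shapes are made up; only find_matching_indices comes from the listing above.

import torch

from mmlearn.datasets.core.example import find_matching_indices

# (dataset_index, example_index) pairs for each modality, stacked into (N, 2) tensors
image_ids = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
text_ids = torch.tensor([[1, 0], [1, 1], [2, 0]])

img_idx, txt_idx = find_matching_indices(image_ids, text_ids)
# img_idx -> tensor([2, 3]); txt_idx -> tensor([0, 1])

# The indices can then be used to align paired embeddings, e.g. for a contrastive loss:
image_emb = torch.randn(4, 8)
text_emb = torch.randn(3, 8)
paired_image_emb = image_emb[img_idx]  # shape (2, 8)
paired_text_emb = text_emb[txt_idx]    # shape (2, 8)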

modalities

Module for managing supported modalities in the library.

Modality dataclass

A representation of a modality in the library.

This class is used to represent a modality in the library. It contains the name of the modality and the properties that can be associated with it. The properties are dynamically generated based on the name of the modality and can be accessed as attributes of the class.

Parameters:

Name Type Description Default
name str

The name of the modality.

required
modality_specific_properties Optional[dict[str, str]]

Additional properties specific to the modality, by default None

None

Raises:

Type Description
ValueError

If the property already exists for the modality or if the format string is invalid.

Source code in mmlearn/datasets/core/modalities.py
@dataclass
class Modality:
    """A representation of a modality in the library.

    This class is used to represent a modality in the library. It contains the name of
    the modality and the properties that can be associated with it. The properties are
    dynamically generated based on the name of the modality and can be accessed as
    attributes of the class.

    Parameters
    ----------
    name : str
        The name of the modality.
    modality_specific_properties : Optional[dict[str, str]], optional, default=None
        Additional properties specific to the modality, by default None

    Raises
    ------
    ValueError
        If the property already exists for the modality or if the format string is
        invalid.
    """

    #: The name of the modality.
    name: str

    #: Target/label associated with the modality. This will return ``name_target``.
    target: str = field(init=False, repr=False)

    #: Attention mask associated with the modality. This will return
    # ``name_attention_mask``.
    attention_mask: str = field(init=False, repr=False)

    #: Input mask associated with the modality. This will return ``name_mask``.
    mask: str = field(init=False, repr=False)

    #: Embedding associated with the modality. This will return ``name_embedding``.
    embedding: str = field(init=False, repr=False)

    #: Masked embedding associated with the modality. This will return
    # ``name_masked_embedding``.
    masked_embedding: str = field(init=False, repr=False)

    #: Embedding from an Exponential Moving Average (EMA) encoder associated with
    #: the modality.
    ema_embedding: str = field(init=False, repr=False)

    #: Other properties specific to the modality.
    modality_specific_properties: Optional[dict[str, str]] = field(
        default=None, repr=False
    )

    def __post_init__(self) -> None:
        """Initialize the modality with the name and properties."""
        self.name = self.name.lower()
        self._properties = {}

        for field_name in self.__dataclass_fields__:
            if field_name not in ("name", "modality_specific_properties"):
                field_value = f"{self.name}_{field_name}"
                self._properties[field_name] = field_value
                setattr(self, field_name, field_value)

        if self.modality_specific_properties is not None:
            for (
                property_name,
                format_string,
            ) in self.modality_specific_properties.items():
                self.add_property(property_name, format_string)

    @property
    def properties(self) -> dict[str, str]:
        """Return the properties associated with the modality."""
        return self._properties

    def add_property(self, name: str, format_string: str) -> None:
        """Add a new property to the modality.

        Parameters
        ----------
        name : str
            The name of the property.
        format_string : str
            The format string for the property. The format string should contain a
            placeholder that will be replaced with the name of the modality when the
            property is accessed.

        Warns
        -----
        UserWarning
            If the property already exists for the modality. It will overwrite the
            existing property.

        Raises
        ------
        ValueError
            If `format_string` is invalid. A valid format string contains at least one
            placeholder enclosed in curly braces.
        """
        if name in self._properties:
            warnings.warn(
                f"Property '{name}' already exists for modality '{super().__str__()}'."
                "Will overwrite the existing property.",
                category=UserWarning,
                stacklevel=2,
            )

        if not _is_format_string(format_string):
            raise ValueError(
                f"Invalid format string '{format_string}' for property "
                f"'{name}' of modality '{super().__str__()}'."
            )

        self._properties[name] = format_string.format(self.name)
        setattr(self, name, self._properties[name])

    def __str__(self) -> str:
        """Return the object as a string."""
        return self.name.lower()
properties property
properties

Return the properties associated with the modality.

__post_init__
__post_init__()

Initialize the modality with the name and properties.

Source code in mmlearn/datasets/core/modalities.py
def __post_init__(self) -> None:
    """Initialize the modality with the name and properties."""
    self.name = self.name.lower()
    self._properties = {}

    for field_name in self.__dataclass_fields__:
        if field_name not in ("name", "modality_specific_properties"):
            field_value = f"{self.name}_{field_name}"
            self._properties[field_name] = field_value
            setattr(self, field_name, field_value)

    if self.modality_specific_properties is not None:
        for (
            property_name,
            format_string,
        ) in self.modality_specific_properties.items():
            self.add_property(property_name, format_string)
add_property
add_property(name, format_string)

Add a new property to the modality.

Parameters:

Name Type Description Default
name str

The name of the property.

required
format_string str

The format string for the property. The format string should contain a placeholder that will be replaced with the name of the modality when the property is accessed.

required

Warns:

Type Description
UserWarning

If the property already exists for the modality. It will overwrite the existing property.

Raises:

Type Description
ValueError

If format_string is invalid. A valid format string contains at least one placeholder enclosed in curly braces.

Source code in mmlearn/datasets/core/modalities.py
def add_property(self, name: str, format_string: str) -> None:
    """Add a new property to the modality.

    Parameters
    ----------
    name : str
        The name of the property.
    format_string : str
        The format string for the property. The format string should contain a
        placeholder that will be replaced with the name of the modality when the
        property is accessed.

    Warns
    -----
    UserWarning
        If the property already exists for the modality. It will overwrite the
        existing property.

    Raises
    ------
    ValueError
        If `format_string` is invalid. A valid format string contains at least one
        placeholder enclosed in curly braces.
    """
    if name in self._properties:
        warnings.warn(
            f"Property '{name}' already exists for modality '{super().__str__()}'."
            "Will overwrite the existing property.",
            category=UserWarning,
            stacklevel=2,
        )

    if not _is_format_string(format_string):
        raise ValueError(
            f"Invalid format string '{format_string}' for property "
            f"'{name}' of modality '{super().__str__()}'."
        )

    self._properties[name] = format_string.format(self.name)
    setattr(self, name, self._properties[name])
__str__
__str__()

Return the object as a string.

Source code in mmlearn/datasets/core/modalities.py
def __str__(self) -> str:
    """Return the object as a string."""
    return self.name.lower()
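
As a quick illustration of the dynamically generated properties, here is a sketch; the "rgb" modality name and the extra "bounding_box" property are made up.

from mmlearn.datasets.core.modalities import Modality

rgb = Modality(
    name="RGB",  # lower-cased internally
    modality_specific_properties={"bounding_box": "{}_bounding_box"},
)

print(str(rgb))          # "rgb"
print(rgb.target)        # "rgb_target"
print(rgb.embedding)     # "rgb_embedding"
print(rgb.bounding_box)  # "rgb_bounding_box"
print(rgb.properties)    # mapping of every property name to its formatted string
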
ModalityRegistry

Modality registry.

A singleton class that manages the supported modalities (and their properties) in the library. The class provides methods to add new modalities and properties, and to access the existing modalities. The class is implemented as a singleton to ensure that there is only one instance of the registry in the library.

Source code in mmlearn/datasets/core/modalities.py
class ModalityRegistry:
    """Modality registry.

    A singleton class that manages the supported modalities (and their properties) in
    the library. The class provides methods to add new modalities and properties, and
    to access the existing modalities. The class is implemented as a singleton to
    ensure that there is only one instance of the registry in the library.
    """

    _instance: ClassVar[Any] = None
    _modality_registry: dict[str, Modality] = {}

    def __new__(cls) -> Self:
        """Create a new instance of the class if it does not exist."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._modality_registry = {}
        return cls._instance  # type: ignore[no-any-return]

    def register_modality(
        self, name: str, modality_specific_properties: Optional[dict[str, str]] = None
    ) -> None:
        """Add a new modality to the registry.

        Parameters
        ----------
        name : str
            The name of the modality.
        modality_specific_properties : Optional[dict[str, str]], optional, default=None
            Additional properties specific to the modality.

        Warns
        -----
        UserWarning
            If the modality already exists in the registry. It will overwrite the
            existing modality.

        """
        if name.lower() in self._modality_registry:
            warnings.warn(
                f"Modality '{name}' already exists in the registry. Overwriting...",
                category=UserWarning,
                stacklevel=2,
            )

        name = name.lower()
        modality = Modality(name, modality_specific_properties)
        self._modality_registry[name] = modality
        setattr(self, name, modality)

    def add_default_property(self, name: str, format_string: str) -> None:
        """Add a new property that is applicable to all modalities.

        Parameters
        ----------
        name : str
            The name of the property.
        format_string : str
            The format string for the property. The format string should contain a
            placeholder that will be replaced with the name of the modality when the
            property is accessed.

        Warns
        -----
        UserWarning
            If the property already exists for the default properties. It will
            overwrite the existing property.

        Raises
        ------
        ValueError
            If the format string is invalid. A valid format string contains at least one
            placeholder enclosed in curly braces.
        """
        for modality in self._modality_registry.values():
            modality.add_property(name, format_string)

    def has_modality(self, name: str) -> bool:
        """Check if the modality exists in the registry.

        Parameters
        ----------
        name : str
            The name of the modality.

        Returns
        -------
        bool
            True if the modality exists in the registry, False otherwise.
        """
        return name.lower() in self._modality_registry

    def get_modality(self, name: str) -> Modality:
        """Get the modality name from the registry.

        Parameters
        ----------
        name : str
            The name of the modality.

        Returns
        -------
        Modality
            The modality object from the registry.
        """
        return self._modality_registry[name.lower()]

    def get_modality_properties(self, name: str) -> dict[str, str]:
        """Get the properties of a modality from the registry.

        Parameters
        ----------
        name : str
            The name of the modality.

        Returns
        -------
        dict[str, str]
            The properties associated with the modality.
        """
        return self.get_modality(name).properties

    def list_modalities(self) -> list[Modality]:
        """Get the list of supported modalities in the registry.

        Returns
        -------
        list[Modality]
            The list of supported modalities in the registry.
        """
        return list(self._modality_registry.values())

    def __getattr__(self, name: str) -> Modality:
        """Access a modality as an attribute by its name."""
        if name.lower() in self._modality_registry:
            return self._modality_registry[name.lower()]
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )
__new__
__new__()

Create a new instance of the class if it does not exist.

Source code in mmlearn/datasets/core/modalities.py
def __new__(cls) -> Self:
    """Create a new instance of the class if it does not exist."""
    if cls._instance is None:
        cls._instance = super().__new__(cls)
        cls._instance._modality_registry = {}
    return cls._instance  # type: ignore[no-any-return]
register_modality
register_modality(name, modality_specific_properties=None)

Add a new modality to the registry.

Parameters:

Name Type Description Default
name str

The name of the modality.

required
modality_specific_properties Optional[dict[str, str]]

Additional properties specific to the modality.

None

Warns:

Type Description
UserWarning

If the modality already exists in the registry. It will overwrite the existing modality.

Source code in mmlearn/datasets/core/modalities.py
def register_modality(
    self, name: str, modality_specific_properties: Optional[dict[str, str]] = None
) -> None:
    """Add a new modality to the registry.

    Parameters
    ----------
    name : str
        The name of the modality.
    modality_specific_properties : Optional[dict[str, str]], optional, default=None
        Additional properties specific to the modality.

    Warns
    -----
    UserWarning
        If the modality already exists in the registry. It will overwrite the
        existing modality.

    """
    if name.lower() in self._modality_registry:
        warnings.warn(
            f"Modality '{name}' already exists in the registry. Overwriting...",
            category=UserWarning,
            stacklevel=2,
        )

    name = name.lower()
    modality = Modality(name, modality_specific_properties)
    self._modality_registry[name] = modality
    setattr(self, name, modality)
add_default_property
add_default_property(name, format_string)

Add a new property that is applicable to all modalities.

Parameters:

Name Type Description Default
name str

The name of the property.

required
format_string str

The format string for the property. The format string should contain a placeholder that will be replaced with the name of the modality when the property is accessed.

required

Warns:

Type Description
UserWarning

If the property already exists for the default properties. It will overwrite the existing property.

Raises:

Type Description
ValueError

If the format string is invalid. A valid format string contains at least one placeholder enclosed in curly braces.

Source code in mmlearn/datasets/core/modalities.py
def add_default_property(self, name: str, format_string: str) -> None:
    """Add a new property that is applicable to all modalities.

    Parameters
    ----------
    name : str
        The name of the property.
    format_string : str
        The format string for the property. The format string should contain a
        placeholder that will be replaced with the name of the modality when the
        property is accessed.

    Warns
    -----
    UserWarning
        If the property already exists for the default properties. It will
        overwrite the existing property.

    Raises
    ------
    ValueError
        If the format string is invalid. A valid format string contains at least one
        placeholder enclosed in curly braces.
    """
    for modality in self._modality_registry.values():
        modality.add_property(name, format_string)
has_modality
has_modality(name)

Check if the modality exists in the registry.

Parameters:

Name Type Description Default
name str

The name of the modality.

required

Returns:

Type Description
bool

True if the modality exists in the registry, False otherwise.

Source code in mmlearn/datasets/core/modalities.py
def has_modality(self, name: str) -> bool:
    """Check if the modality exists in the registry.

    Parameters
    ----------
    name : str
        The name of the modality.

    Returns
    -------
    bool
        True if the modality exists in the registry, False otherwise.
    """
    return name.lower() in self._modality_registry
get_modality
get_modality(name)

Get the modality name from the registry.

Parameters:

Name Type Description Default
name str

The name of the modality.

required

Returns:

Type Description
Modality

The modality object from the registry.

Source code in mmlearn/datasets/core/modalities.py
def get_modality(self, name: str) -> Modality:
    """Get the modality name from the registry.

    Parameters
    ----------
    name : str
        The name of the modality.

    Returns
    -------
    Modality
        The modality object from the registry.
    """
    return self._modality_registry[name.lower()]
get_modality_properties
get_modality_properties(name)

Get the properties of a modality from the registry.

Parameters:

Name Type Description Default
name str

The name of the modality.

required

Returns:

Type Description
dict[str, str]

The properties associated with the modality.

Source code in mmlearn/datasets/core/modalities.py
def get_modality_properties(self, name: str) -> dict[str, str]:
    """Get the properties of a modality from the registry.

    Parameters
    ----------
    name : str
        The name of the modality.

    Returns
    -------
    dict[str, str]
        The properties associated with the modality.
    """
    return self.get_modality(name).properties
list_modalities
list_modalities()

Get the list of supported modalities in the registry.

Returns:

Type Description
list[Modality]

The list of supported modalities in the registry.

Source code in mmlearn/datasets/core/modalities.py
def list_modalities(self) -> list[Modality]:
    """Get the list of supported modalities in the registry.

    Returns
    -------
    list[Modality]
        The list of supported modalities in the registry.
    """
    return list(self._modality_registry.values())
__getattr__
__getattr__(name)

Access a modality as an attribute by its name.

Source code in mmlearn/datasets/core/modalities.py
def __getattr__(self, name: str) -> Modality:
    """Access a modality as an attribute by its name."""
    if name.lower() in self._modality_registry:
        return self._modality_registry[name.lower()]
    raise AttributeError(
        f"'{self.__class__.__name__}' object has no attribute '{name}'"
    )
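
A sketch of how the registry might be used. Because ModalityRegistry is a singleton, instantiating it anywhere returns the same shared registry (the library registers its built-in modalities on this instance); the "thermal" and "depth_map" names below are made up.

from mmlearn.datasets.core.modalities import ModalityRegistry

registry = ModalityRegistry()  # singleton: always the same instance

registry.register_modality("thermal")
registry.register_modality(
    "depth_map", modality_specific_properties={"valid_mask": "{}_valid_mask"}
)

print(registry.has_modality("THERMAL"))           # True (lookup is case-insensitive)
print(registry.get_modality("depth_map").target)  # "depth_map_target"
print(registry.thermal.embedding)                 # attribute-style access -> "thermal_embedding"

# Add a property to every registered modality at once:
registry.add_default_property("logits", "{}_logits")
print(registry.get_modality_properties("thermal")["logits"])  # "thermal_logits"

print([str(m) for m in registry.list_modalities()])  # includes any built-in modalities too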

samplers

Samplers for data loading.

CombinedDatasetRatioSampler

Bases: Sampler[int]

Sampler for weighted sampling from a :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset.

Parameters:

Name Type Description Default
dataset CombinedDataset

An instance of :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset to sample from.

required
ratios Optional[Sequence[float]]

A sequence of ratios for sampling from each dataset in the combined dataset. The length of the sequence must be equal to the number of datasets in the combined dataset (dataset). If None, the length of each dataset in the combined dataset is used as the ratio. The ratios are normalized to sum to 1.

None
num_samples Optional[int]

The number of samples to draw from the combined dataset. If None, the sampler will draw as many samples as there are in the combined dataset. This number must yield at least one sample per dataset in the combined dataset, when multiplied by the corresponding ratio.

None
replacement bool

Whether to sample with replacement or not.

False
shuffle bool

Whether to shuffle the sampled indices or not. If False, the indices of each dataset will appear in the order they are stored in the combined dataset. This is similar to sequential sampling from each dataset. The datasets that make up the combined dataset are still sampled randomly.

True
rank Optional[int]

Rank of the current process within :attr:num_replicas. By default, :attr:rank is retrieved from the current distributed group.

None
num_replicas Optional[int]

Number of processes participating in distributed training. By default, :attr:num_replicas is retrieved from the current distributed group.

None
drop_last bool

Whether to drop the last incomplete batch or not. If True, the sampler will drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

False
seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

0

Attributes:

Name Type Description
dataset CombinedDataset

The dataset to sample from.

num_samples int

The number of samples to draw from the combined dataset.

probs Tensor

The probabilities for sampling from each dataset in the combined dataset. This is computed from the ratios argument and is normalized to sum to 1.

replacement bool

Whether to sample with replacement or not.

shuffle bool

Whether to shuffle the sampled indices or not.

rank int

Rank of the current process within :attr:num_replicas.

num_replicas int

Number of processes participating in distributed training.

drop_last bool

Whether to drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

epoch int

Current epoch number. This is used to set the random seed. This is useful in distributed mode to ensure that each process receives a different random ordering of the samples.

total_size int

The total number of samples across all processes.

Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class CombinedDatasetRatioSampler(Sampler[int]):
    """Sampler for weighted sampling from a :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`.

    Parameters
    ----------
    dataset : CombinedDataset
        An instance of :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`
        to sample from.
    ratios : Optional[Sequence[float]], optional, default=None
        A sequence of ratios for sampling from each dataset in the combined dataset.
        The length of the sequence must be equal to the number of datasets in the
        combined dataset (`dataset`). If `None`, the length of each dataset in the
        combined dataset is used as the ratio. The ratios are normalized to sum to 1.
    num_samples : Optional[int], optional, default=None
        The number of samples to draw from the combined dataset. If `None`, the
        sampler will draw as many samples as there are in the combined dataset.
        This number must yield at least one sample per dataset in the combined
        dataset, when multiplied by the corresponding ratio.
    replacement : bool, default=False
        Whether to sample with replacement or not.
    shuffle : bool, default=True
        Whether to shuffle the sampled indices or not. If `False`, the indices of
        each dataset will appear in the order they are stored in the combined dataset.
        This is similar to sequential sampling from each dataset. The datasets
        that make up the combined dataset are still sampled randomly.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    drop_last : bool, default=False
        Whether to drop the last incomplete batch or not. If `True`, the sampler will
        drop samples to make the number of samples evenly divisible by the number of
        replicas in distributed mode.
    seed : int, default=0
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.

    Attributes
    ----------
    dataset : CombinedDataset
        The dataset to sample from.
    num_samples : int
        The number of samples to draw from the combined dataset.
    probs : torch.Tensor
        The probabilities for sampling from each dataset in the combined dataset.
        This is computed from the `ratios` argument and is normalized to sum to 1.
    replacement : bool
        Whether to sample with replacement or not.
    shuffle : bool
        Whether to shuffle the sampled indices or not.
    rank : int
        Rank of the current process within :attr:`num_replicas`.
    num_replicas : int
        Number of processes participating in distributed training.
    drop_last : bool
        Whether to drop samples to make the number of samples evenly divisible by the
        number of replicas in distributed mode.
    seed : int
        Random seed used when sampling from the combined dataset and shuffling
        the sampled indices.
    epoch : int
        Current epoch number. This is used to set the random seed. This is useful
        in distributed mode to ensure that each process receives a different random
        ordering of the samples.
    total_size : int
        The total number of samples across all processes.
    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        dataset: CombinedDataset,
        ratios: Optional[Sequence[float]] = None,
        num_samples: Optional[int] = None,
        replacement: bool = False,
        shuffle: bool = True,
        rank: Optional[int] = None,
        num_replicas: Optional[int] = None,
        drop_last: bool = False,
        seed: int = 0,
    ):
        if not isinstance(dataset, CombinedDataset):
            raise TypeError(
                "Expected argument `dataset` to be of type `CombinedDataset`, "
                f"but got {type(dataset)}.",
            )
        if not isinstance(seed, int):
            raise TypeError(
                f"Expected argument `seed` to be an integer, but got {type(seed)}.",
            )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.drop_last = drop_last
        self.replacement = replacement
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        self._num_samples = num_samples
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "Expected argument `num_samples` to be a positive integer, but got "
                f"{self.num_samples}.",
            )

        if ratios is None:
            ratios = [len(subset) for subset in self.dataset.datasets]

        num_datasets = len(self.dataset.datasets)
        if len(ratios) != num_datasets:
            raise ValueError(
                f"Expected argument `ratios` to be of length {num_datasets}, "
                f"but got length {len(ratios)}.",
            )
        prob_sum = sum(ratios)
        if not (all(ratio >= 0 for ratio in ratios) and prob_sum > 0):
            raise ValueError(
                "Expected argument `ratios` to be a sequence of non-negative numbers. "
                f"Got {ratios}.",
            )
        self.probs = torch.tensor(
            [ratio / prob_sum for ratio in ratios],
            dtype=torch.double,
        )
        if any((prob * self.num_samples) <= 0 for prob in self.probs):
            raise ValueError(
                "Expected dataset ratio to result in at least one sample per dataset. "
                f"Got dataset sizes {self.probs * self.num_samples}.",
            )

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        # dataset size might change at runtime
        if self._num_samples is None:
            num_samples = len(self.dataset)
        else:
            num_samples = self._num_samples

        if self.drop_last and num_samples % self.num_replicas != 0:
            # split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            num_samples = math.ceil(
                (num_samples - self.num_replicas) / self.num_replicas,
            )
        else:
            num_samples = math.ceil(num_samples / self.num_replicas)
        return num_samples

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return self.num_samples * self.num_replicas

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that yields sample indices for the combined dataset."""
        generator = torch.Generator()
        seed = self.seed + self.epoch
        generator.manual_seed(seed)

        cumulative_sizes = [0] + self.dataset._cumulative_sizes
        num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
        indices = []
        for i in range(len(self.dataset.datasets)):
            per_dataset_indices: torch.Tensor = torch.multinomial(
                torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
                num_samples_per_dataset[i],
                replacement=self.replacement,
                generator=generator,
            )
            # adjust indices to reflect position in cumulative dataset
            per_dataset_indices += cumulative_sizes[i]
            assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
                f"Indices from dataset {i} exceed dataset size. "
                f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
            )
            indices.append(per_dataset_indices)

        indices = torch.cat(indices)
        if self.shuffle:
            rand_indices = torch.randperm(len(indices), generator=generator)
            indices = indices[rand_indices]

        indices = indices.tolist()  # type: ignore[attr-defined]
        num_indices = len(indices)

        if num_indices < self.total_size:
            padding_size = self.total_size - num_indices
            if padding_size <= num_indices:
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / num_indices))[
                    :padding_size
                ]
        elif num_indices > self.total_size:
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples, (
            f"Expected {self.num_samples} samples, but got {len(indices)}.",
        )

        yield from iter(indices)

    def __len__(self) -> int:
        """Return the total number of samples in the sampler."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch

        # some iterable datasets (especially huggingface iterable datasets) might
        # require setting the epoch to ensure shuffling works properly
        for dataset in self.dataset.datasets:
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
num_samples property
num_samples

Return the number of samples managed by the sampler.

total_size property
total_size

Return the total size of the dataset.
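
A quick worked example of the per-replica arithmetic implemented above (the numbers are arbitrary): with num_samples=10 and num_replicas=3, drop_last=True shrinks the per-replica count to ceil((10 - 3) / 3) = 3, giving total_size = 9, while drop_last=False rounds up to ceil(10 / 3) = 4 per replica, giving total_size = 12, so the sampled indices are padded before being split across replicas.

import math

num_samples, num_replicas = 10, 3

per_replica_drop = math.ceil((num_samples - num_replicas) / num_replicas)  # 3
per_replica_pad = math.ceil(num_samples / num_replicas)                    # 4

print(per_replica_drop * num_replicas)  # total_size = 9  (drop_last=True)
print(per_replica_pad * num_replicas)   # total_size = 12 (drop_last=False)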

__iter__
__iter__()

Return an iterator that yields sample indices for the combined dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that yields sample indices for the combined dataset."""
    generator = torch.Generator()
    seed = self.seed + self.epoch
    generator.manual_seed(seed)

    cumulative_sizes = [0] + self.dataset._cumulative_sizes
    num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
    indices = []
    for i in range(len(self.dataset.datasets)):
        per_dataset_indices: torch.Tensor = torch.multinomial(
            torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
            num_samples_per_dataset[i],
            replacement=self.replacement,
            generator=generator,
        )
        # adjust indices to reflect position in cumulative dataset
        per_dataset_indices += cumulative_sizes[i]
        assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
            f"Indices from dataset {i} exceed dataset size. "
            f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
        )
        indices.append(per_dataset_indices)

    indices = torch.cat(indices)
    if self.shuffle:
        rand_indices = torch.randperm(len(indices), generator=generator)
        indices = indices[rand_indices]

    indices = indices.tolist()  # type: ignore[attr-defined]
    num_indices = len(indices)

    if num_indices < self.total_size:
        padding_size = self.total_size - num_indices
        if padding_size <= num_indices:
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / num_indices))[
                :padding_size
            ]
    elif num_indices > self.total_size:
        indices = indices[: self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples, (
        f"Expected {self.num_samples} samples, but got {len(indices)}.",
    )

    yield from iter(indices)
__len__
__len__()

Return the total number of samples in the sampler.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the total number of samples in the sampler."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When :attr:shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

    # some iterable datasets (especially huggingface iterable datasets) might
    # require setting the epoch to ensure shuffling works properly
    for dataset in self.dataset.datasets:
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)
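
A sketch of wiring the sampler into a DataLoader. The ToyDataset class is hypothetical; rank and num_replicas are passed explicitly so the sketch does not require an initialized distributed process group, and replacement=True is used because the weighted draw may request more samples than a dataset contains.

import torch
from torch.utils.data import DataLoader, Dataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.data_collator import DefaultDataCollator
from mmlearn.datasets.core.example import Example
from mmlearn.datasets.core.samplers import CombinedDatasetRatioSampler


class ToyDataset(Dataset):
    """Hypothetical map-style dataset returning `Example` objects."""

    def __init__(self, size: int, offset: int = 0) -> None:
        self.size, self.offset = size, offset

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> Example:
        return Example({"value": torch.tensor(self.offset + idx), "example_index": idx})


combined = CombinedDataset([ToyDataset(800), ToyDataset(200, offset=1000)])

sampler = CombinedDatasetRatioSampler(
    combined,
    ratios=[0.75, 0.25],  # normalized to per-dataset sampling probabilities
    num_samples=1000,     # must yield at least one sample per dataset after weighting
    replacement=True,
    rank=0,               # single-process values; under DDP these come from the group
    num_replicas=1,
    seed=42,
)

loader = DataLoader(
    combined, batch_size=32, sampler=sampler, collate_fn=DefaultDataCollator()
)

for epoch in range(2):
    sampler.set_epoch(epoch)  # re-seed so each epoch draws a different ordering
    for batch in loader:
        ...
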
DistributedEvalSampler

Bases: Sampler[int]

Sampler for distributed evaluation.

The main differences between this and :py:class:torch.utils.data.DistributedSampler are that this sampler does not add extra samples to make it evenly divisible and shuffling is disabled by default.

Parameters:

Name Type Description Default
dataset Dataset

Dataset used for sampling.

required
num_replicas Optional[int]

Number of processes participating in distributed training. By default, :attr:num_replicas is retrieved from the current distributed group.

None
rank Optional[int]

Rank of the current process within :attr:num_replicas. By default, :attr:rank is retrieved from the current distributed group.

None
shuffle bool

If True, sampler will shuffle the indices.

False
seed int

Random seed used to shuffle the sampler if :attr:shuffle=True. This number should be identical across all processes in the distributed group.

0
Warnings

DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever. See [1]_ for details.

Notes
  • This sampler is for evaluation purposes where synchronization does not happen every epoch. Synchronization should be done outside the dataloader loop. It is especially useful in conjunction with :py:class:torch.nn.parallel.DistributedDataParallel [2]_.
  • The input Dataset is assumed to be of constant size.
  • This implementation is adapted from [3]_.
References

.. [1] https://github.com/pytorch/pytorch/issues/22584
.. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
.. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py

Examples:

>>> def example():
...     start_epoch, n_epochs = 0, 2
...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
...     for epoch in range(start_epoch, n_epochs):
...         if is_distributed:
...             sampler.set_epoch(epoch)
...         evaluate(loader)
Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class DistributedEvalSampler(Sampler[int]):
    """Sampler for distributed evaluation.

    The main differences between this and :py:class:`torch.utils.data.DistributedSampler`
    are that this sampler does not add extra samples to make it evenly divisible and
    shuffling is disabled by default.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset used for sampling.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, the world size is retrieved from the current distributed group.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    shuffle : bool, optional, default=False
        If `True`, sampler will shuffle the indices. Default is `False`.
    seed : int, optional, default=0
        Random seed used to shuffle the sampler if :attr:`shuffle=True`.
        This number should be identical across all processes in the
        distributed group.

    Warnings
    --------
    DistributedEvalSampler should NOT be used for training, as the distributed
    processes could hang forever. See [1]_ for details.

    Notes
    -----
    - This sampler is for evaluation purposes, where synchronization does not happen
      every epoch. Synchronization should be done outside the dataloader loop.
      It is especially useful in conjunction with
      :py:class:`torch.nn.parallel.DistributedDataParallel` [2]_.
    - The input Dataset is assumed to be of constant size.
    - This implementation is adapted from [3]_.

    References
    ----------
    .. [1] https://github.com/pytorch/pytorch/issues/22584
    .. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
    .. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py


    Examples
    --------
    >>> def example():
    ...     start_epoch, n_epochs = 0, 2
    ...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
    ...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
    ...     for epoch in range(start_epoch, n_epochs):
    ...         if is_distributed:
    ...             sampler.set_epoch(epoch)
    ...         evaluate(loader)

    """  # noqa: W505

    def __init__(
        self,
        dataset: Dataset[Sized],
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return len(self.dataset)

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        indices = list(range(self.total_size))[
            self.rank : self.total_size : self.num_replicas
        ]
        return len(indices)  # true value without extra samples

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that iterates over the indices of the dataset."""
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(self.total_size, generator=g).tolist()
        else:
            indices = list(range(self.total_size))

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch
total_size property
total_size

Return the total size of the dataset.

num_samples property
num_samples

Return the number of samples managed by the sampler.

__iter__
__iter__()

Return an iterator that iterates over the indices of the dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that iterates over the indices of the dataset."""
    if self.shuffle:
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()
    else:
        indices = list(range(self.total_size))

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples

    return iter(indices)
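
To make the strided, non-padded subsampling concrete, the following illustration (independent of `torch.distributed`) reproduces the index split for a dataset of 10 items across 3 replicas. The shards are uneven, which is why this sampler is only safe for evaluation.

total_size, num_replicas = 10, 3
indices = list(range(total_size))

for rank in range(num_replicas):
    shard = indices[rank:total_size:num_replicas]
    print(rank, shard)
# 0 [0, 3, 6, 9]  -> 4 samples
# 1 [1, 4, 7]     -> 3 samples
# 2 [2, 5, 8]     -> 3 samples
# No indices are repeated to even out the shards, so replicas may run a
# different number of batches.
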
__len__
__len__()

Return the number of samples.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the number of samples."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When `shuffle=True`, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

imagenet

ImageNet dataset.

ImageNet

Bases: ImageFolder

ImageNet dataset.

This is a wrapper around the `torchvision.datasets.ImageFolder` class that returns an `mmlearn.datasets.core.example.Example` object.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, val)

The split of the dataset to use.

"train"
transform Optional[Callable]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
target_transform Optional[Callable]

A callable that takes in the target and transforms it.

None
mask_generator Optional[Callable]

A callable that generates a mask for the image.

None
Source code in mmlearn/datasets/imagenet.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("IMAGENET_ROOT_DIR", MISSING),
)
class ImageNet(ImageFolder):
    """ImageNet dataset.

    This is a wrapper around the :py:class:`~torchvision.datasets.ImageFolder` class
    that returns an :py:class:`~mmlearn.datasets.core.example.Example` object.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "val"}, default="train"
        The split of the dataset to use.
    transform : Optional[Callable], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    target_transform : Optional[Callable], optional, default=None
        A callable that takes in the target and transforms it.
    mask_generator : Optional[Callable], optional, default=None
        A callable that generates a mask for the image.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "val"] = "train",
        transform: Optional[Callable[..., Any]] = None,
        target_transform: Optional[Callable[..., Any]] = None,
        mask_generator: Optional[Callable[..., Any]] = None,
    ) -> None:
        split = "train" if split == "train" else "val"
        root_dir = os.path.join(root_dir, split)
        super().__init__(
            root=root_dir, transform=transform, target_transform=target_transform
        )
        self.mask_generator = mask_generator

    def __getitem__(self, index: int) -> Example:
        """Get an example at the given index."""
        image, target = super().__getitem__(index)
        example = Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: target,
                EXAMPLE_INDEX_KEY: index,
            }
        )
        mask = self.mask_generator() if self.mask_generator else None
        if mask is not None:  # error will be raised during collation if `None`
            example[Modalities.RGB.mask] = mask
        return example

    @property
    def zero_shot_prompt_templates(self) -> list[str]:
        """Return the zero-shot prompt templates."""
        return [
            "a bad photo of a {}.",
            "a photo of many {}.",
            "a sculpture of a {}.",
            "a photo of the hard to see {}.",
            "a low resolution photo of the {}.",
            "a rendering of a {}.",
            "graffiti of a {}.",
            "a bad photo of the {}.",
            "a cropped photo of the {}.",
            "a tattoo of a {}.",
            "the embroidered {}.",
            "a photo of a hard to see {}.",
            "a bright photo of a {}.",
            "a photo of a clean {}.",
            "a photo of a dirty {}.",
            "a dark photo of the {}.",
            "a drawing of a {}.",
            "a photo of my {}.",
            "the plastic {}.",
            "a photo of the cool {}.",
            "a close-up photo of a {}.",
            "a black and white photo of the {}.",
            "a painting of the {}.",
            "a painting of a {}.",
            "a pixelated photo of the {}.",
            "a sculpture of the {}.",
            "a bright photo of the {}.",
            "a cropped photo of a {}.",
            "a plastic {}.",
            "a photo of the dirty {}.",
            "a jpeg corrupted photo of a {}.",
            "a blurry photo of the {}.",
            "a photo of the {}.",
            "a good photo of the {}.",
            "a rendering of the {}.",
            "a {} in a video game.",
            "a photo of one {}.",
            "a doodle of a {}.",
            "a close-up photo of the {}.",
            "a photo of a {}.",
            "the origami {}.",
            "the {} in a video game.",
            "a sketch of a {}.",
            "a doodle of the {}.",
            "a origami {}.",
            "a low resolution photo of a {}.",
            "the toy {}.",
            "a rendition of the {}.",
            "a photo of the clean {}.",
            "a photo of a large {}.",
            "a rendition of a {}.",
            "a photo of a nice {}.",
            "a photo of a weird {}.",
            "a blurry photo of a {}.",
            "a cartoon {}.",
            "art of a {}.",
            "a sketch of the {}.",
            "a embroidered {}.",
            "a pixelated photo of a {}.",
            "itap of the {}.",
            "a jpeg corrupted photo of the {}.",
            "a good photo of a {}.",
            "a plushie {}.",
            "a photo of the nice {}.",
            "a photo of the small {}.",
            "a photo of the weird {}.",
            "the cartoon {}.",
            "art of the {}.",
            "a drawing of the {}.",
            "a photo of the large {}.",
            "a black and white photo of a {}.",
            "the plushie {}.",
            "a dark photo of a {}.",
            "itap of a {}.",
            "graffiti of the {}.",
            "a toy {}.",
            "itap of my {}.",
            "a photo of a cool {}.",
            "a photo of a small {}.",
            "a tattoo of the {}.",
        ]

    @property
    def id2label(self) -> dict[int, str]:
        """Return the label mapping."""
        return {
            0: "tench",
            1: "goldfish",
            2: "great white shark",
            3: "tiger shark",
            4: "hammerhead shark",
            5: "electric ray",
            6: "stingray",
            7: "rooster",
            8: "hen",
            9: "ostrich",
            10: "brambling",
            11: "goldfinch",
            12: "house finch",
            13: "junco",
            14: "indigo bunting",
            15: "American robin",
            16: "bulbul",
            17: "jay",
            18: "magpie",
            19: "chickadee",
            20: "American dipper",
            21: "kite (bird of prey)",
            22: "bald eagle",
            23: "vulture",
            24: "great grey owl",
            25: "fire salamander",
            26: "smooth newt",
            27: "newt",
            28: "spotted salamander",
            29: "axolotl",
            30: "American bullfrog",
            31: "tree frog",
            32: "tailed frog",
            33: "loggerhead sea turtle",
            34: "leatherback sea turtle",
            35: "mud turtle",
            36: "terrapin",
            37: "box turtle",
            38: "banded gecko",
            39: "green iguana",
            40: "Carolina anole",
            41: "desert grassland whiptail lizard",
            42: "agama",
            43: "frilled-necked lizard",
            44: "alligator lizard",
            45: "Gila monster",
            46: "European green lizard",
            47: "chameleon",
            48: "Komodo dragon",
            49: "Nile crocodile",
            50: "American alligator",
            51: "triceratops",
            52: "worm snake",
            53: "ring-necked snake",
            54: "eastern hog-nosed snake",
            55: "smooth green snake",
            56: "kingsnake",
            57: "garter snake",
            58: "water snake",
            59: "vine snake",
            60: "night snake",
            61: "boa constrictor",
            62: "African rock python",
            63: "Indian cobra",
            64: "green mamba",
            65: "sea snake",
            66: "Saharan horned viper",
            67: "eastern diamondback rattlesnake",
            68: "sidewinder rattlesnake",
            69: "trilobite",
            70: "harvestman",
            71: "scorpion",
            72: "yellow garden spider",
            73: "barn spider",
            74: "European garden spider",
            75: "southern black widow",
            76: "tarantula",
            77: "wolf spider",
            78: "tick",
            79: "centipede",
            80: "black grouse",
            81: "ptarmigan",
            82: "ruffed grouse",
            83: "prairie grouse",
            84: "peafowl",
            85: "quail",
            86: "partridge",
            87: "african grey parrot",
            88: "macaw",
            89: "sulphur-crested cockatoo",
            90: "lorikeet",
            91: "coucal",
            92: "bee eater",
            93: "hornbill",
            94: "hummingbird",
            95: "jacamar",
            96: "toucan",
            97: "duck",
            98: "red-breasted merganser",
            99: "goose",
            100: "black swan",
            101: "tusker",
            102: "echidna",
            103: "platypus",
            104: "wallaby",
            105: "koala",
            106: "wombat",
            107: "jellyfish",
            108: "sea anemone",
            109: "brain coral",
            110: "flatworm",
            111: "nematode",
            112: "conch",
            113: "snail",
            114: "slug",
            115: "sea slug",
            116: "chiton",
            117: "chambered nautilus",
            118: "Dungeness crab",
            119: "rock crab",
            120: "fiddler crab",
            121: "red king crab",
            122: "American lobster",
            123: "spiny lobster",
            124: "crayfish",
            125: "hermit crab",
            126: "isopod",
            127: "white stork",
            128: "black stork",
            129: "spoonbill",
            130: "flamingo",
            131: "little blue heron",
            132: "great egret",
            133: "bittern bird",
            134: "crane bird",
            135: "limpkin",
            136: "common gallinule",
            137: "American coot",
            138: "bustard",
            139: "ruddy turnstone",
            140: "dunlin",
            141: "common redshank",
            142: "dowitcher",
            143: "oystercatcher",
            144: "pelican",
            145: "king penguin",
            146: "albatross",
            147: "grey whale",
            148: "killer whale",
            149: "dugong",
            150: "sea lion",
            151: "Chihuahua",
            152: "Japanese Chin",
            153: "Maltese",
            154: "Pekingese",
            155: "Shih Tzu",
            156: "King Charles Spaniel",
            157: "Papillon",
            158: "toy terrier",
            159: "Rhodesian Ridgeback",
            160: "Afghan Hound",
            161: "Basset Hound",
            162: "Beagle",
            163: "Bloodhound",
            164: "Bluetick Coonhound",
            165: "Black and Tan Coonhound",
            166: "Treeing Walker Coonhound",
            167: "English foxhound",
            168: "Redbone Coonhound",
            169: "borzoi",
            170: "Irish Wolfhound",
            171: "Italian Greyhound",
            172: "Whippet",
            173: "Ibizan Hound",
            174: "Norwegian Elkhound",
            175: "Otterhound",
            176: "Saluki",
            177: "Scottish Deerhound",
            178: "Weimaraner",
            179: "Staffordshire Bull Terrier",
            180: "American Staffordshire Terrier",
            181: "Bedlington Terrier",
            182: "Border Terrier",
            183: "Kerry Blue Terrier",
            184: "Irish Terrier",
            185: "Norfolk Terrier",
            186: "Norwich Terrier",
            187: "Yorkshire Terrier",
            188: "Wire Fox Terrier",
            189: "Lakeland Terrier",
            190: "Sealyham Terrier",
            191: "Airedale Terrier",
            192: "Cairn Terrier",
            193: "Australian Terrier",
            194: "Dandie Dinmont Terrier",
            195: "Boston Terrier",
            196: "Miniature Schnauzer",
            197: "Giant Schnauzer",
            198: "Standard Schnauzer",
            199: "Scottish Terrier",
            200: "Tibetan Terrier",
            201: "Australian Silky Terrier",
            202: "Soft-coated Wheaten Terrier",
            203: "West Highland White Terrier",
            204: "Lhasa Apso",
            205: "Flat-Coated Retriever",
            206: "Curly-coated Retriever",
            207: "Golden Retriever",
            208: "Labrador Retriever",
            209: "Chesapeake Bay Retriever",
            210: "German Shorthaired Pointer",
            211: "Vizsla",
            212: "English Setter",
            213: "Irish Setter",
            214: "Gordon Setter",
            215: "Brittany dog",
            216: "Clumber Spaniel",
            217: "English Springer Spaniel",
            218: "Welsh Springer Spaniel",
            219: "Cocker Spaniel",
            220: "Sussex Spaniel",
            221: "Irish Water Spaniel",
            222: "Kuvasz",
            223: "Schipperke",
            224: "Groenendael dog",
            225: "Malinois",
            226: "Briard",
            227: "Australian Kelpie",
            228: "Komondor",
            229: "Old English Sheepdog",
            230: "Shetland Sheepdog",
            231: "collie",
            232: "Border Collie",
            233: "Bouvier des Flandres dog",
            234: "Rottweiler",
            235: "German Shepherd Dog",
            236: "Dobermann",
            237: "Miniature Pinscher",
            238: "Greater Swiss Mountain Dog",
            239: "Bernese Mountain Dog",
            240: "Appenzeller Sennenhund",
            241: "Entlebucher Sennenhund",
            242: "Boxer",
            243: "Bullmastiff",
            244: "Tibetan Mastiff",
            245: "French Bulldog",
            246: "Great Dane",
            247: "St. Bernard",
            248: "husky",
            249: "Alaskan Malamute",
            250: "Siberian Husky",
            251: "Dalmatian",
            252: "Affenpinscher",
            253: "Basenji",
            254: "pug",
            255: "Leonberger",
            256: "Newfoundland dog",
            257: "Great Pyrenees dog",
            258: "Samoyed",
            259: "Pomeranian",
            260: "Chow Chow",
            261: "Keeshond",
            262: "brussels griffon",
            263: "Pembroke Welsh Corgi",
            264: "Cardigan Welsh Corgi",
            265: "Toy Poodle",
            266: "Miniature Poodle",
            267: "Standard Poodle",
            268: "Mexican hairless dog (xoloitzcuintli)",
            269: "grey wolf",
            270: "Alaskan tundra wolf",
            271: "red wolf or maned wolf",
            272: "coyote",
            273: "dingo",
            274: "dhole",
            275: "African wild dog",
            276: "hyena",
            277: "red fox",
            278: "kit fox",
            279: "Arctic fox",
            280: "grey fox",
            281: "tabby cat",
            282: "tiger cat",
            283: "Persian cat",
            284: "Siamese cat",
            285: "Egyptian Mau",
            286: "cougar",
            287: "lynx",
            288: "leopard",
            289: "snow leopard",
            290: "jaguar",
            291: "lion",
            292: "tiger",
            293: "cheetah",
            294: "brown bear",
            295: "American black bear",
            296: "polar bear",
            297: "sloth bear",
            298: "mongoose",
            299: "meerkat",
            300: "tiger beetle",
            301: "ladybug",
            302: "ground beetle",
            303: "longhorn beetle",
            304: "leaf beetle",
            305: "dung beetle",
            306: "rhinoceros beetle",
            307: "weevil",
            308: "fly",
            309: "bee",
            310: "ant",
            311: "grasshopper",
            312: "cricket insect",
            313: "stick insect",
            314: "cockroach",
            315: "praying mantis",
            316: "cicada",
            317: "leafhopper",
            318: "lacewing",
            319: "dragonfly",
            320: "damselfly",
            321: "red admiral butterfly",
            322: "ringlet butterfly",
            323: "monarch butterfly",
            324: "small white butterfly",
            325: "sulphur butterfly",
            326: "gossamer-winged butterfly",
            327: "starfish",
            328: "sea urchin",
            329: "sea cucumber",
            330: "cottontail rabbit",
            331: "hare",
            332: "Angora rabbit",
            333: "hamster",
            334: "porcupine",
            335: "fox squirrel",
            336: "marmot",
            337: "beaver",
            338: "guinea pig",
            339: "common sorrel horse",
            340: "zebra",
            341: "pig",
            342: "wild boar",
            343: "warthog",
            344: "hippopotamus",
            345: "ox",
            346: "water buffalo",
            347: "bison",
            348: "ram (adult male sheep)",
            349: "bighorn sheep",
            350: "Alpine ibex",
            351: "hartebeest",
            352: "impala (antelope)",
            353: "gazelle",
            354: "arabian camel",
            355: "llama",
            356: "weasel",
            357: "mink",
            358: "European polecat",
            359: "black-footed ferret",
            360: "otter",
            361: "skunk",
            362: "badger",
            363: "armadillo",
            364: "three-toed sloth",
            365: "orangutan",
            366: "gorilla",
            367: "chimpanzee",
            368: "gibbon",
            369: "siamang",
            370: "guenon",
            371: "patas monkey",
            372: "baboon",
            373: "macaque",
            374: "langur",
            375: "black-and-white colobus",
            376: "proboscis monkey",
            377: "marmoset",
            378: "white-headed capuchin",
            379: "howler monkey",
            380: "titi monkey",
            381: "Geoffroy's spider monkey",
            382: "common squirrel monkey",
            383: "ring-tailed lemur",
            384: "indri",
            385: "Asian elephant",
            386: "African bush elephant",
            387: "red panda",
            388: "giant panda",
            389: "snoek fish",
            390: "eel",
            391: "silver salmon",
            392: "rock beauty fish",
            393: "clownfish",
            394: "sturgeon",
            395: "gar fish",
            396: "lionfish",
            397: "pufferfish",
            398: "abacus",
            399: "abaya",
            400: "academic gown",
            401: "accordion",
            402: "acoustic guitar",
            403: "aircraft carrier",
            404: "airliner",
            405: "airship",
            406: "altar",
            407: "ambulance",
            408: "amphibious vehicle",
            409: "analog clock",
            410: "apiary",
            411: "apron",
            412: "trash can",
            413: "assault rifle",
            414: "backpack",
            415: "bakery",
            416: "balance beam",
            417: "balloon",
            418: "ballpoint pen",
            419: "Band-Aid",
            420: "banjo",
            421: "baluster / handrail",
            422: "barbell",
            423: "barber chair",
            424: "barbershop",
            425: "barn",
            426: "barometer",
            427: "barrel",
            428: "wheelbarrow",
            429: "baseball",
            430: "basketball",
            431: "bassinet",
            432: "bassoon",
            433: "swimming cap",
            434: "bath towel",
            435: "bathtub",
            436: "station wagon",
            437: "lighthouse",
            438: "beaker",
            439: "military hat (bearskin or shako)",
            440: "beer bottle",
            441: "beer glass",
            442: "bell tower",
            443: "baby bib",
            444: "tandem bicycle",
            445: "bikini",
            446: "ring binder",
            447: "binoculars",
            448: "birdhouse",
            449: "boathouse",
            450: "bobsleigh",
            451: "bolo tie",
            452: "poke bonnet",
            453: "bookcase",
            454: "bookstore",
            455: "bottle cap",
            456: "hunting bow",
            457: "bow tie",
            458: "brass memorial plaque",
            459: "bra",
            460: "breakwater",
            461: "breastplate",
            462: "broom",
            463: "bucket",
            464: "buckle",
            465: "bulletproof vest",
            466: "high-speed train",
            467: "butcher shop",
            468: "taxicab",
            469: "cauldron",
            470: "candle",
            471: "cannon",
            472: "canoe",
            473: "can opener",
            474: "cardigan",
            475: "car mirror",
            476: "carousel",
            477: "tool kit",
            478: "cardboard box / carton",
            479: "car wheel",
            480: "automated teller machine",
            481: "cassette",
            482: "cassette player",
            483: "castle",
            484: "catamaran",
            485: "CD player",
            486: "cello",
            487: "mobile phone",
            488: "chain",
            489: "chain-link fence",
            490: "chain mail",
            491: "chainsaw",
            492: "storage chest",
            493: "chiffonier",
            494: "bell or wind chime",
            495: "china cabinet",
            496: "Christmas stocking",
            497: "church",
            498: "movie theater",
            499: "cleaver",
            500: "cliff dwelling",
            501: "cloak",
            502: "clogs",
            503: "cocktail shaker",
            504: "coffee mug",
            505: "coffeemaker",
            506: "spiral or coil",
            507: "combination lock",
            508: "computer keyboard",
            509: "candy store",
            510: "container ship",
            511: "convertible",
            512: "corkscrew",
            513: "cornet",
            514: "cowboy boot",
            515: "cowboy hat",
            516: "cradle",
            517: "construction crane",
            518: "crash helmet",
            519: "crate",
            520: "infant bed",
            521: "Crock Pot",
            522: "croquet ball",
            523: "crutch",
            524: "cuirass",
            525: "dam",
            526: "desk",
            527: "desktop computer",
            528: "rotary dial telephone",
            529: "diaper",
            530: "digital clock",
            531: "digital watch",
            532: "dining table",
            533: "dishcloth",
            534: "dishwasher",
            535: "disc brake",
            536: "dock",
            537: "dog sled",
            538: "dome",
            539: "doormat",
            540: "drilling rig",
            541: "drum",
            542: "drumstick",
            543: "dumbbell",
            544: "Dutch oven",
            545: "electric fan",
            546: "electric guitar",
            547: "electric locomotive",
            548: "entertainment center",
            549: "envelope",
            550: "espresso machine",
            551: "face powder",
            552: "feather boa",
            553: "filing cabinet",
            554: "fireboat",
            555: "fire truck",
            556: "fire screen",
            557: "flagpole",
            558: "flute",
            559: "folding chair",
            560: "football helmet",
            561: "forklift",
            562: "fountain",
            563: "fountain pen",
            564: "four-poster bed",
            565: "freight car",
            566: "French horn",
            567: "frying pan",
            568: "fur coat",
            569: "garbage truck",
            570: "gas mask or respirator",
            571: "gas pump",
            572: "goblet",
            573: "go-kart",
            574: "golf ball",
            575: "golf cart",
            576: "gondola",
            577: "gong",
            578: "gown",
            579: "grand piano",
            580: "greenhouse",
            581: "radiator grille",
            582: "grocery store",
            583: "guillotine",
            584: "hair clip",
            585: "hair spray",
            586: "half-track",
            587: "hammer",
            588: "hamper",
            589: "hair dryer",
            590: "hand-held computer",
            591: "handkerchief",
            592: "hard disk drive",
            593: "harmonica",
            594: "harp",
            595: "combine harvester",
            596: "hatchet",
            597: "holster",
            598: "home theater",
            599: "honeycomb",
            600: "hook",
            601: "hoop skirt",
            602: "gymnastic horizontal bar",
            603: "horse-drawn vehicle",
            604: "hourglass",
            605: "iPod",
            606: "clothes iron",
            607: "carved pumpkin",
            608: "jeans",
            609: "jeep",
            610: "T-shirt",
            611: "jigsaw puzzle",
            612: "rickshaw",
            613: "joystick",
            614: "kimono",
            615: "knee pad",
            616: "knot",
            617: "lab coat",
            618: "ladle",
            619: "lampshade",
            620: "laptop computer",
            621: "lawn mower",
            622: "lens cap",
            623: "letter opener",
            624: "library",
            625: "lifeboat",
            626: "lighter",
            627: "limousine",
            628: "ocean liner",
            629: "lipstick",
            630: "slip-on shoe",
            631: "lotion",
            632: "music speaker",
            633: "loupe magnifying glass",
            634: "sawmill",
            635: "magnetic compass",
            636: "messenger bag",
            637: "mailbox",
            638: "tights",
            639: "one-piece bathing suit",
            640: "manhole cover",
            641: "maraca",
            642: "marimba",
            643: "mask",
            644: "matchstick",
            645: "maypole",
            646: "maze",
            647: "measuring cup",
            648: "medicine cabinet",
            649: "megalith",
            650: "microphone",
            651: "microwave oven",
            652: "military uniform",
            653: "milk can",
            654: "minibus",
            655: "miniskirt",
            656: "minivan",
            657: "missile",
            658: "mitten",
            659: "mixing bowl",
            660: "mobile home",
            661: "ford model t",
            662: "modem",
            663: "monastery",
            664: "monitor",
            665: "moped",
            666: "mortar and pestle",
            667: "graduation cap",
            668: "mosque",
            669: "mosquito net",
            670: "vespa",
            671: "mountain bike",
            672: "tent",
            673: "computer mouse",
            674: "mousetrap",
            675: "moving van",
            676: "muzzle",
            677: "metal nail",
            678: "neck brace",
            679: "necklace",
            680: "baby pacifier",
            681: "notebook computer",
            682: "obelisk",
            683: "oboe",
            684: "ocarina",
            685: "odometer",
            686: "oil filter",
            687: "pipe organ",
            688: "oscilloscope",
            689: "overskirt",
            690: "bullock cart",
            691: "oxygen mask",
            692: "product packet / packaging",
            693: "paddle",
            694: "paddle wheel",
            695: "padlock",
            696: "paintbrush",
            697: "pajamas",
            698: "palace",
            699: "pan flute",
            700: "paper towel",
            701: "parachute",
            702: "parallel bars",
            703: "park bench",
            704: "parking meter",
            705: "railroad car",
            706: "patio",
            707: "payphone",
            708: "pedestal",
            709: "pencil case",
            710: "pencil sharpener",
            711: "perfume",
            712: "Petri dish",
            713: "photocopier",
            714: "plectrum",
            715: "Pickelhaube",
            716: "picket fence",
            717: "pickup truck",
            718: "pier",
            719: "piggy bank",
            720: "pill bottle",
            721: "pillow",
            722: "ping-pong ball",
            723: "pinwheel",
            724: "pirate ship",
            725: "drink pitcher",
            726: "block plane",
            727: "planetarium",
            728: "plastic bag",
            729: "plate rack",
            730: "farm plow",
            731: "plunger",
            732: "Polaroid camera",
            733: "pole",
            734: "police van",
            735: "poncho",
            736: "pool table",
            737: "soda bottle",
            738: "plant pot",
            739: "potter's wheel",
            740: "power drill",
            741: "prayer rug",
            742: "printer",
            743: "prison",
            744: "missile",
            745: "projector",
            746: "hockey puck",
            747: "punching bag",
            748: "purse",
            749: "quill",
            750: "quilt",
            751: "race car",
            752: "racket",
            753: "radiator",
            754: "radio",
            755: "radio telescope",
            756: "rain barrel",
            757: "recreational vehicle",
            758: "fishing casting reel",
            759: "reflex camera",
            760: "refrigerator",
            761: "remote control",
            762: "restaurant",
            763: "revolver",
            764: "rifle",
            765: "rocking chair",
            766: "rotisserie",
            767: "eraser",
            768: "rugby ball",
            769: "ruler measuring stick",
            770: "sneaker",
            771: "safe",
            772: "safety pin",
            773: "salt shaker",
            774: "sandal",
            775: "sarong",
            776: "saxophone",
            777: "scabbard",
            778: "weighing scale",
            779: "school bus",
            780: "schooner",
            781: "scoreboard",
            782: "CRT monitor",
            783: "screw",
            784: "screwdriver",
            785: "seat belt",
            786: "sewing machine",
            787: "shield",
            788: "shoe store",
            789: "shoji screen / room divider",
            790: "shopping basket",
            791: "shopping cart",
            792: "shovel",
            793: "shower cap",
            794: "shower curtain",
            795: "ski",
            796: "balaclava ski mask",
            797: "sleeping bag",
            798: "slide rule",
            799: "sliding door",
            800: "slot machine",
            801: "snorkel",
            802: "snowmobile",
            803: "snowplow",
            804: "soap dispenser",
            805: "soccer ball",
            806: "sock",
            807: "solar thermal collector",
            808: "sombrero",
            809: "soup bowl",
            810: "keyboard space bar",
            811: "space heater",
            812: "space shuttle",
            813: "spatula",
            814: "motorboat",
            815: "spider web",
            816: "spindle",
            817: "sports car",
            818: "spotlight",
            819: "stage",
            820: "steam locomotive",
            821: "through arch bridge",
            822: "steel drum",
            823: "stethoscope",
            824: "scarf",
            825: "stone wall",
            826: "stopwatch",
            827: "stove",
            828: "strainer",
            829: "tram",
            830: "stretcher",
            831: "couch",
            832: "stupa",
            833: "submarine",
            834: "suit",
            835: "sundial",
            836: "sunglasses",
            837: "sunglasses",
            838: "sunscreen",
            839: "suspension bridge",
            840: "mop",
            841: "sweatshirt",
            842: "swim trunks / shorts",
            843: "swing",
            844: "electrical switch",
            845: "syringe",
            846: "table lamp",
            847: "tank",
            848: "tape player",
            849: "teapot",
            850: "teddy bear",
            851: "television",
            852: "tennis ball",
            853: "thatched roof",
            854: "front curtain",
            855: "thimble",
            856: "threshing machine",
            857: "throne",
            858: "tile roof",
            859: "toaster",
            860: "tobacco shop",
            861: "toilet seat",
            862: "torch",
            863: "totem pole",
            864: "tow truck",
            865: "toy store",
            866: "tractor",
            867: "semi-trailer truck",
            868: "tray",
            869: "trench coat",
            870: "tricycle",
            871: "trimaran",
            872: "tripod",
            873: "triumphal arch",
            874: "trolleybus",
            875: "trombone",
            876: "hot tub",
            877: "turnstile",
            878: "typewriter keyboard",
            879: "umbrella",
            880: "unicycle",
            881: "upright piano",
            882: "vacuum cleaner",
            883: "vase",
            884: "vaulted or arched ceiling",
            885: "velvet fabric",
            886: "vending machine",
            887: "vestment",
            888: "viaduct",
            889: "violin",
            890: "volleyball",
            891: "waffle iron",
            892: "wall clock",
            893: "wallet",
            894: "wardrobe",
            895: "military aircraft",
            896: "sink",
            897: "washing machine",
            898: "water bottle",
            899: "water jug",
            900: "water tower",
            901: "whiskey jug",
            902: "whistle",
            903: "hair wig",
            904: "window screen",
            905: "window shade",
            906: "Windsor tie",
            907: "wine bottle",
            908: "airplane wing",
            909: "wok",
            910: "wooden spoon",
            911: "wool",
            912: "split-rail fence",
            913: "shipwreck",
            914: "sailboat",
            915: "yurt",
            916: "website",
            917: "comic book",
            918: "crossword",
            919: "traffic or street sign",
            920: "traffic light",
            921: "dust jacket",
            922: "menu",
            923: "plate",
            924: "guacamole",
            925: "consomme",
            926: "hot pot",
            927: "trifle",
            928: "ice cream",
            929: "popsicle",
            930: "baguette",
            931: "bagel",
            932: "pretzel",
            933: "cheeseburger",
            934: "hot dog",
            935: "mashed potatoes",
            936: "cabbage",
            937: "broccoli",
            938: "cauliflower",
            939: "zucchini",
            940: "spaghetti squash",
            941: "acorn squash",
            942: "butternut squash",
            943: "cucumber",
            944: "artichoke",
            945: "bell pepper",
            946: "cardoon",
            947: "mushroom",
            948: "Granny Smith apple",
            949: "strawberry",
            950: "orange",
            951: "lemon",
            952: "fig",
            953: "pineapple",
            954: "banana",
            955: "jackfruit",
            956: "cherimoya (custard apple)",
            957: "pomegranate",
            958: "hay",
            959: "carbonara",
            960: "chocolate syrup",
            961: "dough",
            962: "meatloaf",
            963: "pizza",
            964: "pot pie",
            965: "burrito",
            966: "red wine",
            967: "espresso",
            968: "tea cup",
            969: "eggnog",
            970: "mountain",
            971: "bubble",
            972: "cliff",
            973: "coral reef",
            974: "geyser",
            975: "lakeshore",
            976: "promontory",
            977: "sandbar",
            978: "beach",
            979: "valley",
            980: "volcano",
            981: "baseball player",
            982: "bridegroom",
            983: "scuba diver",
            984: "rapeseed",
            985: "daisy",
            986: "yellow lady's slipper",
            987: "corn",
            988: "acorn",
            989: "rose hip",
            990: "horse chestnut seed",
            991: "coral fungus",
            992: "agaric",
            993: "gyromitra",
            994: "stinkhorn mushroom",
            995: "earth star fungus",
            996: "hen of the woods mushroom",
            997: "bolete",
            998: "corn cob",
            999: "toilet paper",
        }
zero_shot_prompt_templates property
zero_shot_prompt_templates

Return the zero-shot prompt templates.

id2label property
id2label

Return the label mapping.

__getitem__
__getitem__(index)

Get an example at the given index.

Source code in mmlearn/datasets/imagenet.py
def __getitem__(self, index: int) -> Example:
    """Get an example at the given index."""
    image, target = super().__getitem__(index)
    example = Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: target,
            EXAMPLE_INDEX_KEY: index,
        }
    )
    mask = self.mask_generator() if self.mask_generator else None
    if mask is not None:  # error will be raised during collation if `None`
        example[Modalities.RGB.mask] = mask
    return example
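
As a hypothetical usage sketch (the root path is a placeholder, and prompt encoding is model-specific, so it is not shown), the two properties above can be combined to build CLIP-style zero-shot prompts for every ImageNet class:

# Build one list of prompt strings per class id, assuming `dataset` points to a
# locally available ImageNet directory (path below is a placeholder).
dataset = ImageNet(root_dir="/path/to/imagenet", split="val")

prompts_per_class = {
    idx: [template.format(label) for template in dataset.zero_shot_prompt_templates]
    for idx, label in dataset.id2label.items()
}
print(prompts_per_class[1][0])  # "a bad photo of a goldfish."
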

librispeech

LibriSpeech dataset.

LibriSpeech

Bases: Dataset[Example]

LibriSpeech dataset.

This is a wrapper around `torchaudio.datasets.LIBRISPEECH` that assumes the dataset is already downloaded and that the top-level directory of the dataset inside the root directory is `librispeech`.

Parameters:

Name Type Description Default
root_dir str

Root directory of dataset.

required
split {"train-clean-100", "train-clean-360", "train-other-500", "dev-clean", "dev-other", "test-clean", "test-other"}

Split of the dataset to use.

"train-clean-100"

Raises:

Type Description
ImportError

If torchaudio is not installed.

Notes

This dataset only returns the audio and transcript from the dataset.

Source code in mmlearn/datasets/librispeech.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("LIBRISPEECH_ROOT_DIR", MISSING),
)
class LibriSpeech(Dataset[Example]):
    """LibriSpeech dataset.

    This is a wrapper around :py:class:`torchaudio.datasets.LIBRISPEECH` that assumes
    that the dataset is already downloaded and the top-level directory of the dataset
    in the root directory is `librispeech`.

    Parameters
    ----------
    root_dir : str
        Root directory of dataset.
    split : {"train-clean-100", "train-clean-360", "train-other-500", "dev-clean", "dev-other", "test-clean", "test-other"}, default="train-clean-100"
        Split of the dataset to use.

    Raises
    ------
    ImportError
        If ``torchaudio`` is not installed.

    Notes
    -----
    This dataset only returns the audio and transcript from the dataset.

    """  # noqa: W505

    def __init__(self, root_dir: str, split: str = "train-clean-100") -> None:
        super().__init__()
        if not _TORCHAUDIO_AVAILABLE:
            raise ImportError(
                "LibriSpeech dataset requires `torchaudio`, which is not installed."
            )
        from torchaudio.datasets import LIBRISPEECH

        self.dataset = LIBRISPEECH(
            root=root_dir,
            url=split,
            download=False,
            folder_in_archive="librispeech",
        )

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the dataset."""
        waveform, sample_rate, transcript, _, _, _ = self.dataset[idx]
        assert sample_rate == SAMPLE_RATE, (
            f"Expected sample rate to be `16000`, got {sample_rate}."
        )
        waveform = pad_or_trim(waveform.flatten())

        return Example(
            {
                Modalities.AUDIO.name: waveform,
                Modalities.TEXT.name: transcript,
                EXAMPLE_INDEX_KEY: idx,
            },
        )
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/librispeech.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.dataset)
__getitem__
__getitem__(idx)

Return an example from the dataset.

Source code in mmlearn/datasets/librispeech.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the dataset."""
    waveform, sample_rate, transcript, _, _, _ = self.dataset[idx]
    assert sample_rate == SAMPLE_RATE, (
        f"Expected sample rate to be `16000`, got {sample_rate}."
    )
    waveform = pad_or_trim(waveform.flatten())

    return Example(
        {
            Modalities.AUDIO.name: waveform,
            Modalities.TEXT.name: transcript,
            EXAMPLE_INDEX_KEY: idx,
        },
    )
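
A minimal usage sketch, assuming the LibriSpeech archives are already extracted under `<root>/librispeech` (the root path below is a placeholder):

dataset = LibriSpeech(root_dir="/data/audio", split="dev-clean")

example = dataset[0]
waveform = example[Modalities.AUDIO.name]   # padded/trimmed to 30 s at 16 kHz
transcript = example[Modalities.TEXT.name]  # plain-text transcript
print(waveform.shape)                       # torch.Size([480000])
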

pad_or_trim

pad_or_trim(array, length=30 * SAMPLE_RATE, *, axis=-1)

Pad or trim the audio array to length along the given axis.

Parameters:

Name Type Description Default
array Tensor

Audio array.

required
length int

Length to pad or trim to. Defaults to 30 seconds at 16 kHz.

480000
axis int

Axis along which to pad or trim.

-1

Returns:

Name Type Description
array Tensor

Padded or trimmed audio array.

References

.. [1] https://github.com/openai/whisper/blob/main/whisper/audio.py#L65C1-L88C17

Source code in mmlearn/datasets/librispeech.py
def pad_or_trim(
    array: torch.Tensor, length: int = 30 * SAMPLE_RATE, *, axis: int = -1
) -> torch.Tensor:
    """Pad or trim the audio array to `length` along the given axis.

    Parameters
    ----------
    array : torch.Tensor
        Audio array.
    length : int, default=480000
        Length to pad or trim to. Defaults to 30 seconds at 16 kHz.
    axis : int, default=-1
        Axis along which to pad or trim.

    Returns
    -------
    array : torch.Tensor
        Padded or trimmed audio array.

    References
    ----------
    .. [1] https://github.com/openai/whisper/blob/main/whisper/audio.py#L65C1-L88C17

    """
    if array.shape[axis] > length:
        array = array.index_select(
            dim=axis,
            index=torch.arange(length, device=array.device),
        )

    if array.shape[axis] < length:
        pad_widths = [(0, 0)] * array.ndim
        pad_widths[axis] = (0, length - array.shape[axis])
        array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])

    return array
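
A quick sketch of both branches of `pad_or_trim` with synthetic waveforms (with `SAMPLE_RATE` equal to 16000, the default target length is 480000 samples):

import torch

short = torch.randn(16000 * 5)   # 5 s of audio -> zero-padded on the right
long = torch.randn(16000 * 40)   # 40 s of audio -> truncated

print(pad_or_trim(short).shape)  # torch.Size([480000])
print(pad_or_trim(long).shape)   # torch.Size([480000])
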

llvip

LLVIP dataset.

LLVIPDataset

Bases: Dataset[Example]

Low-Light Visible-Infrared Pair (LLVIP) dataset.

Loads pairs of RGB and THERMAL images from the LLVIP dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset. The directory should contain 'visible' and 'infrared' subdirectories.

required
train bool

Flag to indicate whether to load the training or test set.

True
transform Optional[Callable[[Image], Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor. This is applied to both RGB and thermal images.

None
Source code in mmlearn/datasets/llvip.py
@store(
    name="LLVIP",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("LLVIP_ROOT_DIR", MISSING),
)
class LLVIPDataset(Dataset[Example]):
    """Low-Light Visible-Infrared Pair (LLVIP) dataset.

    Loads pairs of `RGB` and `THERMAL` images from the LLVIP dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset. The directory should contain
        'visible' and 'infrared' subdirectories.
    train : bool, default=True
        Flag to indicate whether to load the training or test set.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor. This is applied to both RGB and thermal
        images.
    """

    def __init__(
        self,
        root_dir: str,
        train: bool = True,
        transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ):
        self.path_images_rgb = os.path.join(
            root_dir,
            "visible",
            "train" if train else "test",
        )
        self.path_images_ir = os.path.join(
            root_dir, "infrared", "train" if train else "test"
        )
        self.train = train
        self.transform = transform or transforms.ToTensor()

        self.rgb_images = sorted(glob.glob(os.path.join(self.path_images_rgb, "*.jpg")))
        self.ir_images = sorted(glob.glob(os.path.join(self.path_images_ir, "*.jpg")))

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.rgb_images)

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the dataset."""
        rgb_image_path = self.rgb_images[idx]
        ir_image_path = self.ir_images[idx]

        rgb_image = PILImage.open(rgb_image_path).convert("RGB")
        ir_image = PILImage.open(ir_image_path).convert("L")

        example = Example(
            {
                Modalities.RGB.name: self.transform(rgb_image),
                Modalities.THERMAL.name: self.transform(ir_image),
                EXAMPLE_INDEX_KEY: idx,
            },
        )

        if self.train:
            annot_path = (
                rgb_image_path.replace("visible", "Annotations")
                .replace(".jpg", ".xml")
                .replace("train", "")
            )
            annot = self._get_bbox(annot_path)
            example["annotation"] = {
                "bboxes": torch.from_numpy(annot["bboxes"]),
                "labels": torch.from_numpy(annot["labels"]),
            }
        return example

    def _get_bbox(self, filename: str) -> dict[str, np.ndarray]:
        """Parse the XML file to get bounding boxes and labels.

        Parameters
        ----------
        filename : str
            Path to the annotation XML file.

        Returns
        -------
        dict
            A dictionary containing bounding boxes and labels.
        """
        try:
            root = ET.parse(filename).getroot()

            bboxes, labels = [], []
            for obj in root.findall("object"):
                bbox_obj = obj.find("bndbox")
                bbox = [
                    int(bbox_obj.find(dim).text)  # type: ignore[union-attr,arg-type]
                    for dim in ["xmin", "ymin", "xmax", "ymax"]
                ]
                bboxes.append(bbox)
                labels.append(1)  # Assuming 'person' is the only label
            return {
                "bboxes": np.array(bboxes).astype("float"),
                "labels": np.array(labels).astype("int"),
            }
        except ET.ParseError as e:
            raise ValueError(f"Error parsing XML: {e}") from None
        except Exception as e:
            raise RuntimeError(
                f"Error processing annotation file {filename}: {e}",
            ) from None
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/llvip.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.rgb_images)
__getitem__
__getitem__(idx)

Return an example from the dataset.

Source code in mmlearn/datasets/llvip.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the dataset."""
    rgb_image_path = self.rgb_images[idx]
    ir_image_path = self.ir_images[idx]

    rgb_image = PILImage.open(rgb_image_path).convert("RGB")
    ir_image = PILImage.open(ir_image_path).convert("L")

    example = Example(
        {
            Modalities.RGB.name: self.transform(rgb_image),
            Modalities.THERMAL.name: self.transform(ir_image),
            EXAMPLE_INDEX_KEY: idx,
        },
    )

    if self.train:
        annot_path = (
            rgb_image_path.replace("visible", "Annotations")
            .replace(".jpg", ".xml")
            .replace("train", "")
        )
        annot = self._get_bbox(annot_path)
        example["annotation"] = {
            "bboxes": torch.from_numpy(annot["bboxes"]),
            "labels": torch.from_numpy(annot["labels"]),
        }
    return example
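
A minimal usage sketch for LLVIPDataset, assuming a hypothetical local copy of the dataset at /path/to/LLVIP with the expected 'visible' and 'infrared' subdirectories; the resize size is an arbitrary choice for illustration.

from torchvision import transforms

dataset = LLVIPDataset(
    root_dir="/path/to/LLVIP",  # hypothetical path
    train=True,
    transform=transforms.Compose(
        [transforms.Resize((256, 256)), transforms.ToTensor()]
    ),
)

example = dataset[0]
print(example[Modalities.RGB.name].shape)      # e.g. torch.Size([3, 256, 256])
print(example[Modalities.THERMAL.name].shape)  # e.g. torch.Size([1, 256, 256])
print(example["annotation"]["bboxes"].shape)   # (num_boxes, 4) in the training split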

nihcxr

NIH Chest X-ray Dataset.

NIHCXR

Bases: Dataset[Example]

NIH Chest X-ray dataset.

Parameters:

Name Type Description Default
root_dir str

Directory which contains .json files stating all dataset entries.

required
split (train, test, bbox)

Dataset split. "bbox" is a subset of "test" which contains bounding box info.

"train"
transform Optional[Callable[[Image], Tensor]]

A callable that takes in a PIL image and returns a transformed version of the image as a PyTorch tensor.

None
Source code in mmlearn/datasets/nihcxr.py
@store(
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("NIH_CXR_DIR", MISSING),
    split="train",
)
class NIHCXR(Dataset[Example]):
    """NIH Chest X-ray dataset.

    Parameters
    ----------
    root_dir : str
        Directory which contains `.json` files stating all dataset entries.
    split : {"train", "test", "bbox"}
        Dataset split. "bbox" is a subset of "test" which contains bounding box info.
    transform : Optional[Callable[[PIL.Image], torch.Tensor]], optional, default=None
        A callable that takes in a PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test", "bbox"],
        transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
    ) -> None:
        assert split in ["train", "test", "bbox"], f"split {split} is not available."
        assert callable(transform) or transform is None, (
            "transform is not callable or None."
        )

        data_path = os.path.join(root_dir, split + "_data.json")

        assert os.path.isfile(data_path), f"entries file does not exist: {data_path}."

        with open(data_path, "rb") as file:
            entries = json.load(file)
        self.entries = entries

        if transform is not None:
            self.transform = transform
        else:
            self.transform = Compose([Resize(224), CenterCrop(224), ToTensor()])

        self.bbox = split == "bbox"

    def __getitem__(self, idx: int) -> Example:
        """Return image-label or image-label-tabular(bbox)."""
        entry = self.entries[idx]
        image = Image.open(entry["image_path"]).convert("RGB")
        image = self.transform(image)
        label = torch.tensor(entry["label"])

        example = Example(
            {
                Modalities.RGB.name: image,
                Modalities.RGB.target: label,
                "qid": entry["qid"],
                EXAMPLE_INDEX_KEY: idx,
            }
        )

        if self.bbox:
            example["bbox"] = entry["bbox"]

        return example

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.entries)
__getitem__
__getitem__(idx)

Return image-label or image-label-tabular(bbox).

Source code in mmlearn/datasets/nihcxr.py
def __getitem__(self, idx: int) -> Example:
    """Return image-label or image-label-tabular(bbox)."""
    entry = self.entries[idx]
    image = Image.open(entry["image_path"]).convert("RGB")
    image = self.transform(image)
    label = torch.tensor(entry["label"])

    example = Example(
        {
            Modalities.RGB.name: image,
            Modalities.RGB.target: label,
            "qid": entry["qid"],
            EXAMPLE_INDEX_KEY: idx,
        }
    )

    if self.bbox:
        example["bbox"] = entry["bbox"]

    return example
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/nihcxr.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.entries)
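
The entries files are plain JSON lists; the keys below are the ones read in __getitem__ ("image_path", "label", "qid", plus "bbox" for the "bbox" split). The concrete paths and label values here are hypothetical, shown only to illustrate the expected structure.

# Hypothetical contents of <root_dir>/train_data.json
entries = [
    {
        "image_path": "/path/to/images/00000001_000.png",  # hypothetical path
        "label": [0, 1, 0, 0],  # converted with torch.tensor(entry["label"])
        "qid": "00000001_000",
    },
    # ... one entry per image
]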

nyuv2

NYUv2 dataset.

NYUv2Dataset

Bases: Dataset[Example]

NYUv2 dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, test)

Split of the dataset to use.

"train"
return_type (disparity, image)

Return type of the depth images.

  • "disparity": Return the depth image as disparity map.
  • "image": Return the depth image as a 3-channel image.
"disparity"
rgb_transform Optional[Callable[[Image], Tensor]]

A callable that takes in an RGB PIL image and returns a transformed version of the image as a PyTorch tensor.

None
depth_transform Optional[Callable[[Image], Tensor]]

A callable that takes in a depth PIL image and returns a transformed version of the image as a PyTorch tensor.

None

Raises:

Type Description
ImportError

If opencv-python is not installed.

Source code in mmlearn/datasets/nyuv2.py
@store(
    name="NYUv2",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("NYUV2_ROOT_DIR", MISSING),
)
class NYUv2Dataset(Dataset[Example]):
    """NYUv2 dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "test"}, default="train"
        Split of the dataset to use.
    return_type : {"disparity", "image"}, default="disparity"
        Return type of the depth images.

        - `"disparity"`: Return the depth image as disparity map.
        - `"image"`: Return the depth image as a 3-channel image.
    rgb_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in an RGB PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    depth_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in a depth PIL image and returns a transformed version
        of the image as a PyTorch tensor.

    Raises
    ------
    ImportError
        If `opencv-python` is not installed.
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test"] = "train",
        return_type: Literal["disparity", "image"] = "disparity",
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ) -> None:
        super().__init__()
        if not _OPENCV_AVAILABLE:
            raise ImportError(
                "NYUv2 dataset requires `opencv-python` which is not installed.",
            )
        self._validate_args(root_dir, split, rgb_transform, depth_transform)
        self.return_type = return_type

        self.root_dir = root_dir
        with open(os.path.join(root_dir, f"{split}.txt"), "r") as f:
            file_ids = f.readlines()
        file_ids = [f.strip() for f in file_ids]

        root_dir = os.path.join(root_dir, split)
        depth_files = [os.path.join(root_dir, "depth", f"{f}.png") for f in file_ids]
        rgb_files = [os.path.join(root_dir, "rgb", f"{f}.png") for f in file_ids]

        label_files = [
            os.path.join(root_dir, "scene_class", f"{f}.txt") for f in file_ids
        ]
        labels = [str(open(f).read().strip()) for f in label_files]  # noqa: SIM115
        labels = [label.replace("_", " ") for label in labels]
        labels = [
            _LABELS.index(label) if label in _LABELS else len(_LABELS)  # type: ignore
            for label in labels
        ]

        # remove the samples with classes not in _LABELS
        # this is to follow the same classes used in ImageBind
        if split == "test":
            valid_indices = [
                i
                for i, label in enumerate(labels)
                if label < len(_LABELS)  # type: ignore
            ]
            rgb_files = [rgb_files[i] for i in valid_indices]
            depth_files = [depth_files[i] for i in valid_indices]
            labels = [labels[i] for i in valid_indices]

        self.samples = list(zip(rgb_files, depth_files, labels, strict=False))

        self.rgb_transform = rgb_transform
        self.depth_transform = depth_transform

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.samples)

    def _validate_args(
        self,
        root_dir: str,
        split: str,
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]],
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]],
    ) -> None:
        """Validate arguments."""
        if not os.path.isdir(root_dir):
            raise NotADirectoryError(
                f"The given `root_dir` {root_dir} is not a directory",
            )
        if split not in ["train", "test"]:
            raise ValueError(
                f"Expected `split` to be one of `'train'` or `'test'`, but got {split}",
            )
        if rgb_transform is not None and not callable(rgb_transform):
            raise TypeError(
                f"Expected argument `rgb_transform` to be callable, but got {type(rgb_transform)}",
            )
        if depth_transform is not None and not callable(depth_transform):
            raise TypeError(
                f"Expected `depth_transform` to be callable, but got {type(depth_transform)}",
            )

    def __getitem__(self, idx: int) -> Example:
        """Return RGB and depth images at index `idx`."""
        # Read images
        rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
        if self.rgb_transform is not None:
            rgb_image = self.rgb_transform(to_pil_image(rgb_image))

        if self.return_type == "disparity":
            depth_image = depth_normalize(
                self.samples[idx][1],
            )
        else:
            # Using cv2 instead of PIL Image since we use PNG grayscale images.
            depth_image = cv2.imread(
                self.samples[idx][1],
                cv2.IMREAD_GRAYSCALE,
            )
            # Make a 3-channel depth image to enable passing to a pretrained ViT.
            depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

        if self.depth_transform is not None:
            depth_image = self.depth_transform(to_pil_image(depth_image))

        return Example(
            {
                Modalities.RGB.name: rgb_image,
                Modalities.DEPTH.name: depth_image,
                EXAMPLE_INDEX_KEY: idx,
                Modalities.DEPTH.target: self.samples[idx][2],
            }
        )
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/nyuv2.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.samples)
__getitem__
__getitem__(idx)

Return RGB and depth images at index idx.

Source code in mmlearn/datasets/nyuv2.py
def __getitem__(self, idx: int) -> Example:
    """Return RGB and depth images at index `idx`."""
    # Read images
    rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
    if self.rgb_transform is not None:
        rgb_image = self.rgb_transform(to_pil_image(rgb_image))

    if self.return_type == "disparity":
        depth_image = depth_normalize(
            self.samples[idx][1],
        )
    else:
        # Using cv2 instead of PIL Image since we use PNG grayscale images.
        depth_image = cv2.imread(
            self.samples[idx][1],
            cv2.IMREAD_GRAYSCALE,
        )
        # Make a 3-channel depth image to enable passing to a pretrained ViT.
        depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

    if self.depth_transform is not None:
        depth_image = self.depth_transform(to_pil_image(depth_image))

    return Example(
        {
            Modalities.RGB.name: rgb_image,
            Modalities.DEPTH.name: depth_image,
            EXAMPLE_INDEX_KEY: idx,
            Modalities.DEPTH.target: self.samples[idx][2],
        }
    )
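
A minimal usage sketch for NYUv2Dataset, assuming a hypothetical local copy of the preprocessed dataset at /path/to/nyuv2 laid out as read in __init__ above (train.txt/test.txt plus rgb/, depth/ and scene_class/ folders per split); the transform is an arbitrary choice for illustration.

from torchvision import transforms

to_tensor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

dataset = NYUv2Dataset(
    root_dir="/path/to/nyuv2",  # hypothetical path
    split="train",
    return_type="disparity",    # depth loaded via depth_normalize below
    rgb_transform=to_tensor,
    depth_transform=None,
)

sample = dataset[0]
print(sample[Modalities.RGB.name].shape)  # e.g. torch.Size([3, 224, 224])
print(sample[Modalities.DEPTH.target])    # integer scene-class label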

depth_normalize

depth_normalize(depth_file, min_depth=0.01, max_depth=50)

Load depth file and convert to disparity image.

Parameters:

Name Type Description Default
depth_file str

Path to the depth file.

required
min_depth float

Minimum depth value to clip the depth image.

0.01
max_depth int

Maximum depth value to clip the depth image.

50

Returns:

Type Description
Tensor

The normalized depth image.

Source code in mmlearn/datasets/nyuv2.py
def depth_normalize(
    depth_file: str, min_depth: float = 0.01, max_depth: int = 50
) -> torch.Tensor:
    """Load depth file and convert to disparity image.

    Parameters
    ----------
    depth_file : str
        Path to the depth file.
    min_depth : float, default=0.01
        Minimum depth value to clip the depth image.
    max_depth : int, default=50
        Maximum depth value to clip the depth image.

    Returns
    -------
    torch.Tensor
        The normalized depth image.
    """
    depth_image = np.array(PILImage.open(depth_file))
    depth = np.array(depth_image).astype(np.float32)
    depth_in_meters = depth / 1000.0

    if min_depth is not None:
        depth_in_meters = depth_in_meters.clip(min=min_depth, max=max_depth)

    return torch.from_numpy(depth_in_meters).float()
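
A small, self-contained sketch of the conversion depth_normalize performs. A toy depth map is written to a temporary PNG (the real dataset stores 16-bit depth PNGs; an 8-bit image is used here only to keep the example simple), then loaded, converted from millimetres to metres, and clipped.

import numpy as np
from PIL import Image

depth_mm = np.array([[5, 200], [40, 255]], dtype=np.uint8)  # toy depth values in millimetres
Image.fromarray(depth_mm).save("/tmp/depth_example.png")    # hypothetical temporary path

print(depth_normalize("/tmp/depth_example.png"))
# tensor([[0.0100, 0.2000],
#         [0.0400, 0.2550]])  -- metres, clipped below at min_depth=0.01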

processors

Data processors.

BlockwiseImagePatchMaskGenerator

Blockwise image patch mask generator.

This is primarily intended for the data2vec method.

Parameters:

Name Type Description Default
input_size Union[int, tuple[int, int]]

The size of the input image. If an integer is provided, the image is assumed to be square.

required
num_masking_patches int

The number of patches to mask.

required
min_num_patches int

The minimum number of patches to mask.

4
max_num_patches int

The maximum number of patches to mask.

None
min_aspect_ratio float

The minimum aspect ratio of the patch.

0.3
max_aspect_ratio float

The maximum aspect ratio of the patch.

None
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn")
class BlockwiseImagePatchMaskGenerator:
    """Blockwise image patch mask generator.

    This is primarily intended for the data2vec method.

    Parameters
    ----------
    input_size : Union[int, tuple[int, int]]
        The size of the input image. If an integer is provided, the image is assumed
        to be square.
    num_masking_patches : int
        The number of patches to mask.
    min_num_patches : int, default=4
        The minimum number of patches to mask.
    max_num_patches : int, default=None
        The maximum number of patches to mask.
    min_aspect_ratio : float, default=0.3
        The minimum aspect ratio of the patch.
    max_aspect_ratio : float, default=None
        The maximum aspect ratio of the patch.
    """

    def __init__(
        self,
        input_size: Union[int, tuple[int, int]],
        num_masking_patches: int,
        min_num_patches: int = 4,
        max_num_patches: Any = None,
        min_aspect_ratio: float = 0.3,
        max_aspect_ratio: Any = None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = (
            num_masking_patches if max_num_patches is None else max_num_patches
        )

        max_aspect_ratio = max_aspect_ratio or 1 / min_aspect_ratio
        self.log_aspect_ratio = (math.log(min_aspect_ratio), math.log(max_aspect_ratio))

    def __repr__(self) -> str:
        """Generate a printable representation.

        Returns
        -------
        str
            A printable representation of the object.

        """
        return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self.max_num_patches,
            self.num_masking_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )

    def get_shape(self) -> tuple[int, int]:
        """Get the shape of the input.

        Returns
        -------
        tuple[int, int]
            The shape of the input as a tuple `(height, width)`.
        """
        return self.height, self.width

    def _mask(self, mask: torch.Tensor, max_mask_patches: int) -> int:
        """Masking function.

        This function masks adjacent patches by first selecting a target area and
        aspect ratio. Since there might be overlap between selected areas or the
        selected area might already be masked, it runs for a maximum of 10 attempts
        or until the specified number of patches (max_mask_patches) is reached.


        Parameters
        ----------
        mask: torch.Tensor
            Current mask. The mask to be updated.
        max_mask_patches: int
            The maximum number of patches to be masked.

        Returns
        -------
        delta: int
            The number of patches that were successfully masked.

        Notes
        -----
        - `target_area`: Randomly chosen target area for the patch.
        - `aspect_ratio`: Randomly chosen aspect ratio for the patch.
        - `h`: Height of the patch based on the target area and aspect ratio.
        - `w`: Width of the patch based on the target area and aspect ratio.
        - `top`: Randomly chosen top position for the patch.
        - `left`: Randomly chosen left position for the patch.
        - `num_masked`: Number of already-masked patches within the proposed block.
        - `delta`: Accumulated count of newly masked patches.
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top : top + h, left : left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self) -> torch.Tensor:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        mask: torch.Tensor
            A mask of shape (nb_patches, nb_patches)

        """
        mask = torch.zeros(self.get_shape(), dtype=torch.int)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask
__repr__
__repr__()

Generate a printable representation.

Returns:

Type Description
str

A printable representation of the object.

Source code in mmlearn/datasets/processors/masking.py
def __repr__(self) -> str:
    """Generate a printable representation.

    Returns
    -------
    str
        A printable representation of the object.

    """
    return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
        self.height,
        self.width,
        self.min_num_patches,
        self.max_num_patches,
        self.num_masking_patches,
        self.log_aspect_ratio[0],
        self.log_aspect_ratio[1],
    )
get_shape
get_shape()

Get the shape of the input.

Returns:

Type Description
tuple[int, int]

The shape of the input as a tuple (height, width).

Source code in mmlearn/datasets/processors/masking.py
def get_shape(self) -> tuple[int, int]:
    """Get the shape of the input.

    Returns
    -------
    tuple[int, int]
        The shape of the input as a tuple `(height, width)`.
    """
    return self.height, self.width
__call__
__call__()

Generate a random mask.

Returns a random mask of shape (nb_patches, nb_patches) based on the configuration where the number of patches to be masked is num_masking_patches.

Returns:

Name Type Description
mask Tensor

A mask of shape (nb_patches, nb_patches)

Source code in mmlearn/datasets/processors/masking.py
def __call__(self) -> torch.Tensor:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    mask: torch.Tensor
        A mask of shape (nb_patches, nb_patches)

    """
    mask = torch.zeros(self.get_shape(), dtype=torch.int)
    mask_count = 0
    while mask_count < self.num_masking_patches:
        max_mask_patches = self.num_masking_patches - mask_count
        max_mask_patches = min(max_mask_patches, self.max_num_patches)

        delta = self._mask(mask, max_mask_patches)
        if delta == 0:
            break
        mask_count += delta

    return mask
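
A minimal usage sketch, assuming a ViT-style setup where a 224-pixel image with 16-pixel patches yields a 14x14 patch grid; the number of masked patches is an arbitrary choice for illustration.

mask_generator = BlockwiseImagePatchMaskGenerator(
    input_size=14,           # 14x14 patch grid
    num_masking_patches=75,  # roughly 40% of the 196 patches
)

mask = mask_generator()              # torch.int tensor of shape (14, 14)
print(mask.shape, int(mask.sum()))   # the sum is at most num_masking_patches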

RandomMaskGenerator

Random mask generator.

Randomly selects tokens for masking with a given probability. This is intended to be used for tasks like masked language modeling.

Parameters:

Name Type Description Default
probability float

Probability of masking a token.

required
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn", probability=0.15)
class RandomMaskGenerator:
    """Random mask generator.

    Randomly selects tokens for masking with probability ``probability``.
    **This is intended to be used for tasks like masked language modeling.**

    Parameters
    ----------
    probability : float
        Probability of masking a token.
    """

    def __init__(self, probability: float):
        self.probability = probability

    def __call__(
        self,
        inputs: torch.Tensor,
        tokenizer: PreTrainedTokenizerBase,
        special_tokens_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        inputs : torch.Tensor
            The encoded inputs.
        tokenizer : PreTrainedTokenizer
            The tokenizer.
        special_tokens_mask : Optional[torch.Tensor], default=None
            Mask for special tokens.
        """
        inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training
        # (with probability `self.probability`)
        probability_matrix = torch.full(labels.shape, self.probability)
        if special_tokens_mask is None:
            special_tokens_mask = tokenizer.get_special_tokens_mask(
                labels, already_has_special_tokens=True
            )
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = tokenizer.pad_token_id
        # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # Rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, masked_indices
__call__
__call__(inputs, tokenizer, special_tokens_mask=None)

Apply random masking to tokenized text for masked language modeling.

Tokens are selected for masking with probability self.probability; of the selected tokens, 80% are replaced with the mask token, 10% with a random token and 10% are left unchanged.

Returns:

Name Type Description
inputs Tensor

The input token ids, with masked positions replaced according to the scheme above.

labels Tensor

The labels for masked language modeling; positions that were not selected for masking are set to the pad token id.

masked_indices Tensor

Boolean mask indicating which tokens were selected for masking.

Source code in mmlearn/datasets/processors/masking.py
def __call__(
    self,
    inputs: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
    special_tokens_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    inputs : torch.Tensor
        The encoded inputs.
    tokenizer : PreTrainedTokenizer
        The tokenizer.
    special_tokens_mask : Optional[torch.Tensor], default=None
        Mask for special tokens.
    """
    inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
    labels = inputs.clone()
    # We sample a few tokens in each sequence for MLM training
    # (with probability `self.probability`)
    probability_matrix = torch.full(labels.shape, self.probability)
    if special_tokens_mask is None:
        special_tokens_mask = tokenizer.get_special_tokens_mask(
            labels, already_has_special_tokens=True
        )
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
    else:
        special_tokens_mask = special_tokens_mask.bool()

    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = tokenizer.pad_token_id
    # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
    indices_replaced = (
        torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    )
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # Rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels, masked_indices
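
A minimal usage sketch with a HuggingFace tokenizer. The "bert-base-uncased" checkpoint is an assumption chosen for illustration; any tokenizer with a mask token and a pad token should work. The special-tokens mask is passed in explicitly here, so the inputs are padded up front.

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
mask_generator = RandomMaskGenerator(probability=0.15)

encoding = tokenizer(
    ["a short example", "another slightly longer example sentence"],
    padding=True,
    return_special_tokens_mask=True,
)
special_tokens_mask = torch.tensor(
    encoding.pop("special_tokens_mask"), dtype=torch.bool
)

inputs, labels, masked_indices = mask_generator(
    encoding, tokenizer, special_tokens_mask=special_tokens_mask
)
print(inputs.shape, masked_indices.sum())  # (2, seq_len), number of masked tokens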

HFTokenizer

A wrapper for loading HuggingFace tokenizers.

This class wraps any huggingface tokenizer that can be initialized with :py:meth:transformers.AutoTokenizer.from_pretrained. It preprocesses the input text and returns a dictionary with the tokenized text and other relevant information like attention masks.

Parameters:

Name Type Description Default
model_name_or_path str

Pretrained model name or path - same as in :py:meth:transformers.AutoTokenizer.from_pretrained.

required
max_length Optional[int]

Maximum length of the tokenized sequence. This is passed to the tokenizer :meth:__call__ method.

None
padding bool or str

Padding strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

False
truncation Optional[Union[bool, str]]

Truncation strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

None
**kwargs Any

Additional arguments passed to :py:meth:transformers.AutoTokenizer.from_pretrained.

{}
Source code in mmlearn/datasets/processors/tokenizers.py
@store(group="datasets/tokenizers", provider="mmlearn")
class HFTokenizer:
    """A wrapper for loading HuggingFace tokenizers.

    This class wraps any huggingface tokenizer that can be initialized with
    :py:meth:`transformers.AutoTokenizer.from_pretrained`. It preprocesses the
    input text and returns a dictionary with the tokenized text and other
    relevant information like attention masks.

    Parameters
    ----------
    model_name_or_path : str
        Pretrained model name or path - same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    max_length : Optional[int], optional, default=None
        Maximum length of the tokenized sequence. This is passed to the tokenizer
        :meth:`__call__` method.
    padding : bool or str, default=False
        Padding strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    truncation : Optional[Union[bool, str]], optional, default=None
        Truncation strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    **kwargs : Any
        Additional arguments passed to :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    """  # noqa: W505

    def __init__(
        self,
        model_name_or_path: str,
        max_length: Optional[int] = None,
        padding: Union[bool, str] = False,
        truncation: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def __call__(
        self, sentence: Union[str, list[str]], **kwargs: Any
    ) -> dict[str, torch.Tensor]:
        """Tokenize a text or a list of texts using the HuggingFace tokenizer.

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be tokenized.
        **kwargs : Any
            Additional arguments passed to the tokenizer :meth:`__call__` method.

        Returns
        -------
        dict[str, torch.Tensor]
            Tokenized sentence(s).

        Notes
        -----
        The ``input_ids`` are also stored under the ``Modalities.TEXT`` key for consistency.
        """
        batch_encoding = self.tokenizer(
            sentence,
            max_length=self.max_length,
            padding=self.padding,
            truncation=self.truncation,
            return_tensors="pt",
            **kwargs,
        )

        if isinstance(
            sentence, str
        ):  # remove batch dimension if input is a single sentence
            for key, value in batch_encoding.items():
                if isinstance(value, torch.Tensor):
                    batch_encoding[key] = torch.squeeze(value, 0)

        # use 'Modalities.TEXT' key for input_ids for consistency
        batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
        return dict(batch_encoding)
__call__
__call__(sentence, **kwargs)

Tokenize a text or a list of texts using the HuggingFace tokenizer.

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be tokenized.

required
**kwargs Any

Additional arguments passed to the tokenizer :meth:__call__ method.

{}

Returns:

Type Description
dict[str, Tensor]

Tokenized sentence(s).

Notes

The input_ids are also stored under the Modalities.TEXT key for consistency.

Source code in mmlearn/datasets/processors/tokenizers.py
def __call__(
    self, sentence: Union[str, list[str]], **kwargs: Any
) -> dict[str, torch.Tensor]:
    """Tokenize a text or a list of texts using the HuggingFace tokenizer.

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be tokenized.
    **kwargs : Any
        Additional arguments passed to the tokenizer :meth:`__call__` method.

    Returns
    -------
    dict[str, torch.Tensor]
        Tokenized sentence(s).

    Notes
    -----
    The ``input_ids`` are also stored under the ``Modalities.TEXT`` key for consistency.
    """
    batch_encoding = self.tokenizer(
        sentence,
        max_length=self.max_length,
        padding=self.padding,
        truncation=self.truncation,
        return_tensors="pt",
        **kwargs,
    )

    if isinstance(
        sentence, str
    ):  # remove batch dimension if input is a single sentence
        for key, value in batch_encoding.items():
            if isinstance(value, torch.Tensor):
                batch_encoding[key] = torch.squeeze(value, 0)

    # use 'Modalities.TEXT' key for input_ids for consistency
    batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
    return dict(batch_encoding)
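
A minimal usage sketch. The "bert-base-uncased" checkpoint and the max_length of 77 are assumptions chosen for illustration.

tokenizer = HFTokenizer(
    "bert-base-uncased",  # assumed pretrained tokenizer
    max_length=77,
    padding="max_length",
    truncation=True,
)

encoded = tokenizer("a photo of a cat")
print(encoded[Modalities.TEXT.name].shape)  # torch.Size([77]); batch dim squeezed for a single string
print(encoded["attention_mask"].shape)      # torch.Size([77])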

TrimText

Trim text strings as a preprocessing step before tokenization.

Parameters:

Name Type Description Default
trim_size int

The maximum length of the trimmed text.

required
Source code in mmlearn/datasets/processors/transforms.py
@store(group="datasets/transforms", provider="mmlearn")
class TrimText:
    """Trim text strings as a preprocessing step before tokenization.

    Parameters
    ----------
    trim_size : int
        The maximum length of the trimmed text.
    """

    def __init__(self, trim_size: int) -> None:
        self.trim_size = trim_size

    def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
        """Trim the given sentence(s).

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be trimmed.

        Returns
        -------
        Union[str, list[str]]
            Trimmed sentence(s).

        Raises
        ------
        TypeError
            If the input sentence is not a string or list of strings.
        """
        if not isinstance(sentence, (list, str)):
            raise TypeError(
                "Expected argument `sentence` to be a string or list of strings, "
                f"but got {type(sentence)}"
            )

        if isinstance(sentence, str):
            return sentence[: self.trim_size]

        for i, s in enumerate(sentence):
            sentence[i] = s[: self.trim_size]

        return sentence
__call__
__call__(sentence)

Trim the given sentence(s).

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be trimmed.

required

Returns:

Type Description
Union[str, list[str]]

Trimmed sentence(s).

Raises:

Type Description
TypeError

If the input sentence is not a string or list of strings.

Source code in mmlearn/datasets/processors/transforms.py
def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
    """Trim the given sentence(s).

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be trimmed.

    Returns
    -------
    Union[str, list[str]]
        Trimmed sentence(s).

    Raises
    ------
    TypeError
        If the input sentence is not a string or list of strings.
    """
    if not isinstance(sentence, (list, str)):
        raise TypeError(
            "Expected argument `sentence` to be a string or list of strings, "
            f"but got {type(sentence)}"
        )

    if isinstance(sentence, str):
        return sentence[: self.trim_size]

    for i, s in enumerate(sentence):
        sentence[i] = s[: self.trim_size]

    return sentence
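
A small usage sketch. Note that trimming is applied at the character level, not the token level, so it is meant to run before tokenization.

trim = TrimText(trim_size=10)

print(trim("a fairly long sentence"))            # 'a fairly l'
print(trim(["short", "another long sentence"]))  # ['short', 'another lo']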

masking

Token mask generators.

RandomMaskGenerator

Random mask generator.

Randomly selects tokens for masking with a given probability. This is intended to be used for tasks like masked language modeling.

Parameters:

Name Type Description Default
probability float

Probability of masking a token.

required
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn", probability=0.15)
class RandomMaskGenerator:
    """Random mask generator.

    Randomly selects tokens for masking with probability ``probability``.
    **This is intended to be used for tasks like masked language modeling.**

    Parameters
    ----------
    probability : float
        Probability of masking a token.
    """

    def __init__(self, probability: float):
        self.probability = probability

    def __call__(
        self,
        inputs: torch.Tensor,
        tokenizer: PreTrainedTokenizerBase,
        special_tokens_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        inputs : torch.Tensor
            The encoded inputs.
        tokenizer : PreTrainedTokenizer
            The tokenizer.
        special_tokens_mask : Optional[torch.Tensor], default=None
            Mask for special tokens.
        """
        inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training
        # (with probability `self.probability`)
        probability_matrix = torch.full(labels.shape, self.probability)
        if special_tokens_mask is None:
            special_tokens_mask = tokenizer.get_special_tokens_mask(
                labels, already_has_special_tokens=True
            )
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = tokenizer.pad_token_id
        # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # Rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, masked_indices
__call__
__call__(inputs, tokenizer, special_tokens_mask=None)

Apply random masking to tokenized text for masked language modeling.

Tokens are selected for masking with probability self.probability; of the selected tokens, 80% are replaced with the mask token, 10% with a random token and 10% are left unchanged.

Returns:

Name Type Description
inputs Tensor

The input token ids, with masked positions replaced according to the scheme above.

labels Tensor

The labels for masked language modeling; positions that were not selected for masking are set to the pad token id.

masked_indices Tensor

Boolean mask indicating which tokens were selected for masking.

Source code in mmlearn/datasets/processors/masking.py
def __call__(
    self,
    inputs: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
    special_tokens_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    inputs : torch.Tensor
        The encoded inputs.
    tokenizer : PreTrainedTokenizer
        The tokenizer.
    special_tokens_mask : Optional[torch.Tensor], default=None
        Mask for special tokens.
    """
    inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
    labels = inputs.clone()
    # We sample a few tokens in each sequence for MLM training
    # (with probability `self.probability`)
    probability_matrix = torch.full(labels.shape, self.probability)
    if special_tokens_mask is None:
        special_tokens_mask = tokenizer.get_special_tokens_mask(
            labels, already_has_special_tokens=True
        )
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
    else:
        special_tokens_mask = special_tokens_mask.bool()

    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = tokenizer.pad_token_id
    # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
    indices_replaced = (
        torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    )
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # Rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels, masked_indices
BlockwiseImagePatchMaskGenerator

Blockwise image patch mask generator.

This is primarily intended for the data2vec method.

Parameters:

Name Type Description Default
input_size Union[int, tuple[int, int]]

The size of the input image. If an integer is provided, the image is assumed to be square.

required
num_masking_patches int

The number of patches to mask.

required
min_num_patches int

The minimum number of patches to mask.

4
max_num_patches int

The maximum number of patches to mask.

None
min_aspect_ratio float

The minimum aspect ratio of the patch.

0.3
max_aspect_ratio float

The maximum aspect ratio of the patch.

None
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn")
class BlockwiseImagePatchMaskGenerator:
    """Blockwise image patch mask generator.

    This is primarily intended for the data2vec method.

    Parameters
    ----------
    input_size : Union[int, tuple[int, int]]
        The size of the input image. If an integer is provided, the image is assumed
        to be square.
    num_masking_patches : int
        The number of patches to mask.
    min_num_patches : int, default=4
        The minimum number of patches to mask.
    max_num_patches : int, default=None
        The maximum number of patches to mask.
    min_aspect_ratio : float, default=0.3
        The minimum aspect ratio of the patch.
    max_aspect_ratio : float, default=None
        The maximum aspect ratio of the patch.
    """

    def __init__(
        self,
        input_size: Union[int, tuple[int, int]],
        num_masking_patches: int,
        min_num_patches: int = 4,
        max_num_patches: Any = None,
        min_aspect_ratio: float = 0.3,
        max_aspect_ratio: Any = None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = (
            num_masking_patches if max_num_patches is None else max_num_patches
        )

        max_aspect_ratio = max_aspect_ratio or 1 / min_aspect_ratio
        self.log_aspect_ratio = (math.log(min_aspect_ratio), math.log(max_aspect_ratio))

    def __repr__(self) -> str:
        """Generate a printable representation.

        Returns
        -------
        str
            A printable representation of the object.

        """
        return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self.max_num_patches,
            self.num_masking_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )

    def get_shape(self) -> tuple[int, int]:
        """Get the shape of the input.

        Returns
        -------
        tuple[int, int]
            The shape of the input as a tuple `(height, width)`.
        """
        return self.height, self.width

    def _mask(self, mask: torch.Tensor, max_mask_patches: int) -> int:
        """Masking function.

        This function masks adjacent patches by first selecting a target area and
        aspect ratio. Since there might be overlap between selected areas or the
        selected area might already be masked, it runs for a maximum of 10 attempts
        or until the specified number of patches (max_mask_patches) is reached.


        Parameters
        ----------
        mask: torch.Tensor
            Current mask. The mask to be updated.
        max_mask_patches: int
            The maximum number of patches to be masked.

        Returns
        -------
        delta: int
            The number of patches that were successfully masked.

        Notes
        -----
        - `target_area`: Randomly chosen target area for the patch.
        - `aspect_ratio`: Randomly chosen aspect ratio for the patch.
        - `h`: Height of the patch based on the target area and aspect ratio.
        - `w`: Width of the patch based on the target area and aspect ratio.
        - `top`: Randomly chosen top position for the patch.
        - `left`: Randomly chosen left position for the patch.
        - `num_masked`: Number of already-masked patches within the proposed block.
        - `delta`: Accumulated count of newly masked patches.
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top : top + h, left : left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self) -> torch.Tensor:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        mask: torch.Tensor
            A mask of shape (nb_patches, nb_patches)

        """
        mask = torch.zeros(self.get_shape(), dtype=torch.int)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask
__repr__
__repr__()

Generate a printable representation.

Returns:

Type Description
str

A printable representation of the object.

Source code in mmlearn/datasets/processors/masking.py
def __repr__(self) -> str:
    """Generate a printable representation.

    Returns
    -------
    str
        A printable representation of the object.

    """
    return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
        self.height,
        self.width,
        self.min_num_patches,
        self.max_num_patches,
        self.num_masking_patches,
        self.log_aspect_ratio[0],
        self.log_aspect_ratio[1],
    )
get_shape
get_shape()

Get the shape of the input.

Returns:

Type Description
tuple[int, int]

The shape of the input as a tuple (height, width).

Source code in mmlearn/datasets/processors/masking.py
def get_shape(self) -> tuple[int, int]:
    """Get the shape of the input.

    Returns
    -------
    tuple[int, int]
        The shape of the input as a tuple `(height, width)`.
    """
    return self.height, self.width
__call__
__call__()

Generate a random mask.

Returns a random mask of shape (nb_patches, nb_patches) based on the configuration where the number of patches to be masked is num_masking_patches.

Returns:

Name Type Description
mask Tensor

A mask of shape (nb_patches, nb_patches)

Source code in mmlearn/datasets/processors/masking.py
def __call__(self) -> torch.Tensor:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    mask: torch.Tensor
        A mask of shape (nb_patches, nb_patches)

    """
    mask = torch.zeros(self.get_shape(), dtype=torch.int)
    mask_count = 0
    while mask_count < self.num_masking_patches:
        max_mask_patches = self.num_masking_patches - mask_count
        max_mask_patches = min(max_mask_patches, self.max_num_patches)

        delta = self._mask(mask, max_mask_patches)
        if delta == 0:
            break
        mask_count += delta

    return mask
IJEPAMaskGenerator dataclass

Generates encoder and predictor masks for preprocessing.

This class generates masks dynamically for batches of examples.

Parameters:

Name Type Description Default
input_size tuple[int, int]

Input image size.

(224, 224)
patch_size int

Size of each patch.

16
min_keep int

Minimum number of patches to keep.

10
allow_overlap bool

Whether to allow overlap between encoder and predictor masks.

False
enc_mask_scale tuple[float, float]

Scale range for encoder mask.

(0.85, 1.0)
pred_mask_scale tuple[float, float]

Scale range for predictor mask.

(0.15, 0.2)
aspect_ratio tuple[float, float]

Aspect ratio range for mask blocks.

(0.75, 1.5)
nenc int

Number of encoder masks to generate.

1
npred int

Number of predictor masks to generate.

4
Source code in mmlearn/datasets/processors/masking.py
@dataclass
class IJEPAMaskGenerator:
    """Generates encoder and predictor masks for preprocessing.

    This class generates masks dynamically for batches of examples.

    Parameters
    ----------
    input_size : tuple[int, int], default=(224, 224)
        Input image size.
    patch_size : int, default=16
        Size of each patch.
    min_keep : int, default=10
        Minimum number of patches to keep.
    allow_overlap : bool, default=False
        Whether to allow overlap between encoder and predictor masks.
    enc_mask_scale : tuple[float, float], default=(0.85, 1.0)
        Scale range for encoder mask.
    pred_mask_scale : tuple[float, float], default=(0.15, 0.2)
        Scale range for predictor mask.
    aspect_ratio : tuple[float, float], default=(0.75, 1.0)
        Aspect ratio range for mask blocks.
    nenc : int, default=1
        Number of encoder masks to generate.
    npred : int, default=4
        Number of predictor masks to generate.
    """

    input_size: tuple[int, int] = (224, 224)
    patch_size: int = 16
    min_keep: int = 10
    allow_overlap: bool = False
    enc_mask_scale: tuple[float, float] = (0.85, 1.0)
    pred_mask_scale: tuple[float, float] = (0.15, 0.2)
    aspect_ratio: tuple[float, float] = (0.75, 1.5)
    nenc: int = 1
    npred: int = 4

    def __post_init__(self) -> None:
        """Initialize the mask generator."""
        self.height = self.input_size[0] // self.patch_size
        self.width = self.input_size[1] // self.patch_size

    def _sample_block_size(
        self,
        generator: torch.Generator,
        scale: tuple[float, float],
        aspect_ratio: tuple[float, float],
    ) -> tuple[int, int]:
        """Sample the size of the mask block based on scale and aspect ratio."""
        _rand = torch.rand(1, generator=generator).item()
        min_s, max_s = scale
        mask_scale = min_s + _rand * (max_s - min_s)
        max_keep = int(self.height * self.width * mask_scale)

        min_ar, max_ar = aspect_ratio
        aspect_ratio_val = min_ar + _rand * (max_ar - min_ar)

        h = int(round(math.sqrt(max_keep * aspect_ratio_val)))
        w = int(round(math.sqrt(max_keep / aspect_ratio_val)))

        h = min(h, self.height - 1)
        w = min(w, self.width - 1)

        return h, w

    def _sample_block_mask(
        self, b_size: tuple[int, int]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample a mask block."""
        h, w = b_size
        top = torch.randint(0, self.height - h, (1,)).item()
        left = torch.randint(0, self.width - w, (1,)).item()
        mask = torch.zeros((self.height, self.width), dtype=torch.int32)
        mask[top : top + h, left : left + w] = 1

        mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
        mask_complement[top : top + h, left : left + w] = 0

        return mask.flatten(), mask_complement.flatten()

    def __call__(self, batch_size: int = 1) -> dict[str, Any]:
        """Generate encoder and predictor masks for a batch of examples.

        Parameters
        ----------
        batch_size : int, default=1
            The batch size for which to generate masks.

        Returns
        -------
        dict[str, Any]
            A dictionary of encoder masks and predictor masks.
        """
        seed = torch.randint(
            0, 2**32, (1,)
        ).item()  # Sample random seed for reproducibility
        g = torch.Generator().manual_seed(seed)

        # Sample block sizes
        p_size = self._sample_block_size(
            generator=g, scale=self.pred_mask_scale, aspect_ratio=self.aspect_ratio
        )
        e_size = self._sample_block_size(
            generator=g, scale=self.enc_mask_scale, aspect_ratio=(1.0, 1.0)
        )

        # Generate predictor masks
        masks_pred, masks_enc = [], []
        for _ in range(self.npred):
            mask_p, _ = self._sample_block_mask(p_size)
            # Expand mask to match batch size
            mask_p = mask_p.unsqueeze(0).expand(batch_size, -1)
            masks_pred.append(mask_p)

        # Generate encoder masks
        for _ in range(self.nenc):
            mask_e, _ = self._sample_block_mask(e_size)
            # Expand mask to match batch size
            mask_e = mask_e.unsqueeze(0).expand(batch_size, -1)
            masks_enc.append(mask_e)

        return {
            "encoder_masks": masks_enc,  # list of tensors of shape (batch_size, N)
            "predictor_masks": masks_pred,  # list of tensors of shape (batch_size, N)
        }
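
Example: a minimal usage sketch of the mask generator with its default configuration (the shapes assume the defaults shown above: 224x224 images and 16x16 patches, i.e. 14x14 = 196 patches per image).

from mmlearn.datasets.processors.masking import IJEPAMaskGenerator

mask_generator = IJEPAMaskGenerator()  # defaults: 224x224 input, 16x16 patches -> 196 patches

masks = mask_generator(batch_size=8)
encoder_masks = masks["encoder_masks"]      # list of `nenc` (=1) tensors, each of shape (8, 196)
predictor_masks = masks["predictor_masks"]  # list of `npred` (=4) tensors, each of shape (8, 196)
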
__post_init__
__post_init__()

Initialize the mask generator.

Source code in mmlearn/datasets/processors/masking.py
def __post_init__(self) -> None:
    """Initialize the mask generator."""
    self.height = self.input_size[0] // self.patch_size
    self.width = self.input_size[1] // self.patch_size
__call__
__call__(batch_size=1)

Generate encoder and predictor masks for a batch of examples.

Parameters:

Name Type Description Default
batch_size int

The batch size for which to generate masks.

1

Returns:

Type Description
dict[str, Any]

A dictionary of encoder masks and predictor masks.

Source code in mmlearn/datasets/processors/masking.py
def __call__(self, batch_size: int = 1) -> dict[str, Any]:
    """Generate encoder and predictor masks for a batch of examples.

    Parameters
    ----------
    batch_size : int, default=1
        The batch size for which to generate masks.

    Returns
    -------
    dict[str, Any]
        A dictionary of encoder masks and predictor masks.
    """
    seed = torch.randint(
        0, 2**32, (1,)
    ).item()  # Sample random seed for reproducibility
    g = torch.Generator().manual_seed(seed)

    # Sample block sizes
    p_size = self._sample_block_size(
        generator=g, scale=self.pred_mask_scale, aspect_ratio=self.aspect_ratio
    )
    e_size = self._sample_block_size(
        generator=g, scale=self.enc_mask_scale, aspect_ratio=(1.0, 1.0)
    )

    # Generate predictor masks
    masks_pred, masks_enc = [], []
    for _ in range(self.npred):
        mask_p, _ = self._sample_block_mask(p_size)
        # Expand mask to match batch size
        mask_p = mask_p.unsqueeze(0).expand(batch_size, -1)
        masks_pred.append(mask_p)

    # Generate encoder masks
    for _ in range(self.nenc):
        mask_e, _ = self._sample_block_mask(e_size)
        # Expand mask to match batch size
        mask_e = mask_e.unsqueeze(0).expand(batch_size, -1)
        masks_enc.append(mask_e)

    return {
        "encoder_masks": masks_enc,  # list of tensors of shape (batch_size, N)
        "predictor_masks": masks_pred,  # list of tensors of shape (batch_size, N)
    }
apply_masks
apply_masks(x, masks)

Apply masks to the input tensor by selecting the patches to keep based on the masks.

This function is primarily intended to be used for the :py:class:i-JEPA <mmlearn.tasks.ijepa.IJEPA>.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, N, D).

required
masks Union[Tensor, list[Tensor]]

A list of mask tensors of shape (N,), (1, N), or (B, N).

required

Returns:

Type Description
Tensor

The masked tensor where only the patches indicated by the masks are kept. The output tensor has shape (B * num_masks, N', D), where N' is the number of patches kept.

Source code in mmlearn/datasets/processors/masking.py
def apply_masks(
    x: torch.Tensor, masks: Union[torch.Tensor, list[torch.Tensor]]
) -> torch.Tensor:
    """
    Apply masks to the input tensor by selecting the patches to keep based on the masks.

    This function is primarily intended to be used for the
    :py:class:`i-JEPA <mmlearn.tasks.ijepa.IJEPA>`.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(B, N, D)``.
    masks : Union[torch.Tensor, list[torch.Tensor]]
        A list of mask tensors of shape ``(N,)``, ``(1, N)``, or ``(B, N)``.

    Returns
    -------
    torch.Tensor
        The masked tensor where only the patches indicated by the masks are kept.
        The output tensor has shape ``(B * num_masks, N', D)``, where ``N'`` is
        the number of patches kept.
    """
    all_x = []
    batch_size = x.size(0)
    for m_ in masks:
        m = m_.to(x.device)

        # Ensure mask is at least 2D
        if m.dim() == 1:
            m = m.unsqueeze(0)  # Shape: (1, N)

        # Expand mask to match the batch size if needed
        if m.size(0) == 1 and batch_size > 1:
            m = m.expand(batch_size, -1)  # Shape: (B, N)

        # Expand mask to match x's dimensions
        m_expanded = (
            m.unsqueeze(-1).expand(-1, -1, x.size(-1)).bool()
        )  # Shape: (B, N, D)

        # Use boolean indexing
        selected_patches = x[m_expanded].view(batch_size, -1, x.size(-1))
        all_x.append(selected_patches)

    # Concatenate along the batch dimension
    return torch.cat(all_x, dim=0)
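
Example: a minimal sketch of apply_masks on a toy tensor; the shapes follow the docstring above, and a mask value of 1 marks a patch to keep.

import torch

from mmlearn.datasets.processors.masking import apply_masks

x = torch.randn(2, 6, 4)                 # (B, N, D): 2 samples, 6 patches, 4-dim features
mask = torch.tensor([1, 1, 0, 0, 1, 0])  # keep patches 0, 1 and 4
out = apply_masks(x, [mask])             # the (1, N) mask is broadcast over the batch
print(out.shape)                         # torch.Size([2, 3, 4])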

tokenizers

Tokenizers - modules that convert raw input to sequences of tokens.

HFTokenizer

A wrapper for loading HuggingFace tokenizers.

This class wraps any HuggingFace tokenizer that can be initialized with :py:meth:transformers.AutoTokenizer.from_pretrained. It preprocesses the input text and returns a dictionary with the tokenized text and other relevant information, such as attention masks.

Parameters:

Name Type Description Default
model_name_or_path str

Pretrained model name or path - same as in :py:meth:transformers.AutoTokenizer.from_pretrained.

required
max_length Optional[int]

Maximum length of the tokenized sequence. This is passed to the tokenizer :meth:__call__ method.

None
padding bool or str

Padding strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

False
truncation Optional[Union[bool, str]]

Truncation strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

None
**kwargs Any

Additional arguments passed to :py:meth:transformers.AutoTokenizer.from_pretrained.

{}
Source code in mmlearn/datasets/processors/tokenizers.py
@store(group="datasets/tokenizers", provider="mmlearn")
class HFTokenizer:
    """A wrapper for loading HuggingFace tokenizers.

    This class wraps any huggingface tokenizer that can be initialized with
    :py:meth:`transformers.AutoTokenizer.from_pretrained`. It preprocesses the
    input text and returns a dictionary with the tokenized text and other
    relevant information like attention masks.

    Parameters
    ----------
    model_name_or_path : str
        Pretrained model name or path - same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    max_length : Optional[int], optional, default=None
        Maximum length of the tokenized sequence. This is passed to the tokenizer
        :meth:`__call__` method.
    padding : bool or str, default=False
        Padding strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    truncation : Optional[Union[bool, str]], optional, default=None
        Truncation strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    **kwargs : Any
        Additional arguments passed to :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    """  # noqa: W505

    def __init__(
        self,
        model_name_or_path: str,
        max_length: Optional[int] = None,
        padding: Union[bool, str] = False,
        truncation: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def __call__(
        self, sentence: Union[str, list[str]], **kwargs: Any
    ) -> dict[str, torch.Tensor]:
        """Tokenize a text or a list of texts using the HuggingFace tokenizer.

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be tokenized.
        **kwargs : Any
            Additional arguments passed to the tokenizer :meth:`__call__` method.

        Returns
        -------
        dict[str, torch.Tensor]
            Tokenized sentence(s).

        Notes
        -----
        The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
        """
        batch_encoding = self.tokenizer(
            sentence,
            max_length=self.max_length,
            padding=self.padding,
            truncation=self.truncation,
            return_tensors="pt",
            **kwargs,
        )

        if isinstance(
            sentence, str
        ):  # remove batch dimension if input is a single sentence
            for key, value in batch_encoding.items():
                if isinstance(value, torch.Tensor):
                    batch_encoding[key] = torch.squeeze(value, 0)

        # use 'Modalities.TEXT' key for input_ids for consistency
        batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
        return dict(batch_encoding)
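
Example: an illustrative sketch; "bert-base-uncased" is only an example checkpoint, and loading it requires access to the HuggingFace Hub (or a local copy).

from mmlearn.datasets.processors.tokenizers import HFTokenizer

tokenizer = HFTokenizer(
    "bert-base-uncased",   # any name/path accepted by AutoTokenizer.from_pretrained
    max_length=77,
    padding="max_length",
    truncation=True,
)
encoding = tokenizer("a photo of a cat")
# `encoding` holds the usual tokenizer outputs ("input_ids", "attention_mask", ...)
# plus the same ids under the `Modalities.TEXT` key; the batch dimension is
# squeezed out because a single sentence was passed.
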
__call__
__call__(sentence, **kwargs)

Tokenize a text or a list of texts using the HuggingFace tokenizer.

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be tokenized.

required
**kwargs Any

Additional arguments passed to the tokenizer :meth:__call__ method.

{}

Returns:

Type Description
dict[str, Tensor]

Tokenized sentence(s).

Notes

The input_ids key is replaced with Modalities.TEXT for consistency.

Source code in mmlearn/datasets/processors/tokenizers.py
def __call__(
    self, sentence: Union[str, list[str]], **kwargs: Any
) -> dict[str, torch.Tensor]:
    """Tokenize a text or a list of texts using the HuggingFace tokenizer.

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be tokenized.
    **kwargs : Any
        Additional arguments passed to the tokenizer :meth:`__call__` method.

    Returns
    -------
    dict[str, torch.Tensor]
        Tokenized sentence(s).

    Notes
    -----
    The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
    """
    batch_encoding = self.tokenizer(
        sentence,
        max_length=self.max_length,
        padding=self.padding,
        truncation=self.truncation,
        return_tensors="pt",
        **kwargs,
    )

    if isinstance(
        sentence, str
    ):  # remove batch dimension if input is a single sentence
        for key, value in batch_encoding.items():
            if isinstance(value, torch.Tensor):
                batch_encoding[key] = torch.squeeze(value, 0)

    # use 'Modalities.TEXT' key for input_ids for consistency
    batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
    return dict(batch_encoding)
Img2Seq

Bases: Module

Convert a batch of images to a batch of sequences.

Parameters:

Name Type Description Default
img_size tuple of int

The size of the input image.

required
patch_size tuple of int

The size of the patch.

required
n_channels int

The number of channels in the input image.

required
d_model int

The dimension of the output sequence.

required
Source code in mmlearn/datasets/processors/tokenizers.py
class Img2Seq(nn.Module):
    """Convert a batch of images to a batch of sequences.

    Parameters
    ----------
    img_size : tuple of int
        The size of the input image.
    patch_size : tuple of int
        The size of the patch.
    n_channels : int
        The number of channels in the input image.
    d_model : int
        The dimension of the output sequence.

    """

    def __init__(
        self,
        img_size: tuple[int, int],
        patch_size: tuple[int, int],
        n_channels: int,
        d_model: int,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size

        nh, nw = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        n_tokens = nh * nw

        token_dim = patch_size[0] * patch_size[1] * n_channels
        self.linear = nn.Linear(token_dim, d_model)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_emb = nn.Parameter(torch.randn(n_tokens, d_model))

    def __call__(self, batch: torch.Tensor) -> torch.Tensor:
        """Convert a batch of images to a batch of sequences.

        Parameters
        ----------
        batch : torch.Tensor
            Batch of images of shape ``(b, h, w, c)`` where ``b`` is the batch size,
            ``h`` is the height, ``w`` is the width, and ``c`` is the number of
            channels.

        Returns
        -------
        torch.Tensor
            Batch of sequences of shape ``(b, s, d)`` where ``b`` is the batch size,
            ``s`` is the sequence length, and ``d`` is the dimension of the output
            sequence.
        """
        batch = _patchify(batch, self.patch_size)

        b, c, nh, nw, ph, pw = batch.shape

        # Flattening the patches
        batch = torch.permute(batch, [0, 2, 3, 4, 5, 1])
        batch = torch.reshape(batch, [b, nh * nw, ph * pw * c])

        batch = self.linear(batch)
        cls: torch.Tensor = self.cls_token.expand([b, -1, -1])
        emb: torch.Tensor = batch + self.pos_emb

        return torch.cat([cls, emb], axis=1)
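
Example: a minimal sketch; the channels-last input layout (b, h, w, c) follows the __call__ docstring, and the output length is the number of patches plus one class token.

import torch

from mmlearn.datasets.processors.tokenizers import Img2Seq

img2seq = Img2Seq(img_size=(224, 224), patch_size=(16, 16), n_channels=3, d_model=768)
images = torch.randn(4, 224, 224, 3)  # (b, h, w, c), as described in the __call__ docstring
tokens = img2seq(images)              # (4, 197, 768): 196 patch tokens plus one class token
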
__call__
__call__(batch)

Convert a batch of images to a batch of sequences.

Parameters:

Name Type Description Default
batch Tensor

Batch of images of shape (b, h, w, c) where b is the batch size, h is the height, w is the width, and c is the number of channels.

required

Returns:

Type Description
Tensor

Batch of sequences of shape (b, s, d) where b is the batch size, s is the sequence length, and d is the dimension of the output sequence.

Source code in mmlearn/datasets/processors/tokenizers.py
def __call__(self, batch: torch.Tensor) -> torch.Tensor:
    """Convert a batch of images to a batch of sequences.

    Parameters
    ----------
    batch : torch.Tensor
        Batch of images of shape ``(b, h, w, c)`` where ``b`` is the batch size,
        ``h`` is the height, ``w`` is the width, and ``c`` is the number of
        channels.

    Returns
    -------
    torch.Tensor
        Batch of sequences of shape ``(b, s, d)`` where ``b`` is the batch size,
        ``s`` is the sequence length, and ``d`` is the dimension of the output
        sequence.
    """
    batch = _patchify(batch, self.patch_size)

    b, c, nh, nw, ph, pw = batch.shape

    # Flattening the patches
    batch = torch.permute(batch, [0, 2, 3, 4, 5, 1])
    batch = torch.reshape(batch, [b, nh * nw, ph * pw * c])

    batch = self.linear(batch)
    cls: torch.Tensor = self.cls_token.expand([b, -1, -1])
    emb: torch.Tensor = batch + self.pos_emb

    return torch.cat([cls, emb], axis=1)

transforms

Custom transforms for datasets/inputs.

TrimText

Trim text strings as a preprocessing step before tokenization.

Parameters:

Name Type Description Default
trim_size int

The maximum length of the trimmed text.

required
Source code in mmlearn/datasets/processors/transforms.py
@store(group="datasets/transforms", provider="mmlearn")
class TrimText:
    """Trim text strings as a preprocessing step before tokenization.

    Parameters
    ----------
    trim_size : int
        The maximum length of the trimmed text.
    """

    def __init__(self, trim_size: int) -> None:
        self.trim_size = trim_size

    def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
        """Trim the given sentence(s).

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be trimmed.

        Returns
        -------
        Union[str, list[str]]
            Trimmed sentence(s).

        Raises
        ------
        TypeError
            If the input sentence is not a string or list of strings.
        """
        if not isinstance(sentence, (list, str)):
            raise TypeError(
                "Expected argument `sentence` to be a string or list of strings, "
                f"but got {type(sentence)}"
            )

        if isinstance(sentence, str):
            return sentence[: self.trim_size]

        for i, s in enumerate(sentence):
            sentence[i] = s[: self.trim_size]

        return sentence
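
Example:

>>> from mmlearn.datasets.processors.transforms import TrimText
>>> trim = TrimText(trim_size=5)
>>> trim("a very long sentence")
'a ver'
>>> trim(["short", "a very long sentence"])
['short', 'a ver']
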
__call__
__call__(sentence)

Trim the given sentence(s).

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be trimmed.

required

Returns:

Type Description
Union[str, list[str]]

Trimmed sentence(s).

Raises:

Type Description
TypeError

If the input sentence is not a string or list of strings.

Source code in mmlearn/datasets/processors/transforms.py
def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
    """Trim the given sentence(s).

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be trimmed.

    Returns
    -------
    Union[str, list[str]]
        Trimmed sentence(s).

    Raises
    ------
    TypeError
        If the input sentence is not a string or list of strings.
    """
    if not isinstance(sentence, (list, str)):
        raise TypeError(
            "Expected argument `sentence` to be a string or list of strings, "
            f"but got {type(sentence)}"
        )

    if isinstance(sentence, str):
        return sentence[: self.trim_size]

    for i, s in enumerate(sentence):
        sentence[i] = s[: self.trim_size]

    return sentence
repeat_interleave_batch
repeat_interleave_batch(x, b, repeat)

Repeat and interleave a tensor across the batch dimension.

Parameters:

Name Type Description Default
x Tensor

Input tensor to be repeated.

required
b int

Size of the batch to be repeated.

required
repeat int

Number of times to repeat each batch.

required

Returns:

Type Description
Tensor

The repeated tensor with shape adjusted for the batch.

Source code in mmlearn/datasets/processors/transforms.py
def repeat_interleave_batch(x: torch.Tensor, b: int, repeat: int) -> torch.Tensor:
    """Repeat and interleave a tensor across the batch dimension.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor to be repeated.
    b : int
        Size of the batch to be repeated.
    repeat : int
        Number of times to repeat each batch.

    Returns
    -------
    torch.Tensor
        The repeated tensor with shape adjusted for the batch.
    """
    n = len(x) // b
    return torch.cat(
        [
            torch.cat([x[i * b : (i + 1) * b] for _ in range(repeat)], dim=0)
            for i in range(n)
        ],
        dim=0,
    )
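
Example: each consecutive chunk of b rows is repeated repeat times before moving on to the next chunk.

>>> import torch
>>> from mmlearn.datasets.processors.transforms import repeat_interleave_batch
>>> x = torch.arange(4)
>>> repeat_interleave_batch(x, b=2, repeat=2)
tensor([0, 1, 0, 1, 2, 3, 2, 3])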

sunrgbd

SUN RGB-D dataset.

SUNRGBDDataset

Bases: Dataset[Example]

SUN RGB-D dataset.

Parameters:

Name Type Description Default
root_dir str

Path to the root directory of the dataset.

required
split (train, test)

Split of the dataset to use.

"train"
return_type (disparity, image)

Return type of the depth images. If "disparity", the depth images are converted to disparity, following the ImageBind implementation. Otherwise, the depth image is returned as a 3-channel image.

"disparity"
rgb_transform Optional[Callable[[Image], Tensor]]

A callable that takes in an RGB PIL image and returns a transformed version of the image as a PyTorch tensor.

None
depth_transform Optional[Callable[[Image], Tensor]]

A callable that takes in a depth PIL image and returns a transformed version of the image as a PyTorch tensor.

None
References

[1] Repository followed to extract the dataset: https://github.com/TUI-NICR/nicr-scene-analysis-datasets

Source code in mmlearn/datasets/sunrgbd.py
@store(
    name="SUNRGBD",
    group="datasets",
    provider="mmlearn",
    root_dir=os.getenv("SUNRGBD_ROOT_DIR", MISSING),
)
class SUNRGBDDataset(Dataset[Example]):
    """SUN RGB-D dataset.

    Parameters
    ----------
    root_dir : str
        Path to the root directory of the dataset.
    split : {"train", "test"}, default="train"
        Split of the dataset to use.
    return_type : {"disparity", "image"}, default="disparity"
        Return type of the depth images. If "disparity", the depth images are
        converted to disparity similar to the ImageBind implementation.
        Otherwise, return the depth image as a 3-channel image.
    rgb_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in an RGB PIL image and returns a transformed version
        of the image as a PyTorch tensor.
    depth_transform: Callable[[PIL.Image], torch.Tensor], default=None
        A callable that takes in a depth PIL image and returns a transformed version
        of the image as a PyTorch tensor.

    References
    ----------
    .. [1] Repo followed to extract the dataset: https://github.com/TUI-NICR/nicr-scene-analysis-datasets
    """

    def __init__(
        self,
        root_dir: str,
        split: Literal["train", "test"] = "train",
        return_type: Literal["disparity", "image"] = "disparity",
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
    ) -> None:
        super().__init__()
        if not _OPENCV_AVAILABLE:
            raise ImportError(
                "SUN RGB-D dataset requires `opencv-python` which is not installed.",
            )

        self._validate_args(root_dir, split, rgb_transform, depth_transform)
        self.return_type = return_type

        self.root_dir = root_dir
        with open(os.path.join(root_dir, f"{split}.txt"), "r") as f:
            file_ids = f.readlines()
        file_ids = [f.strip() for f in file_ids]

        root_dir = os.path.join(root_dir, split)
        depth_files = [os.path.join(root_dir, "depth", f"{f}.png") for f in file_ids]
        rgb_files = [os.path.join(root_dir, "rgb", f"{f}.jpg") for f in file_ids]
        intrinsic_files = [
            os.path.join(root_dir, "intrinsics", f"{f}.txt") for f in file_ids
        ]

        sensor_types = [
            file.removeprefix(os.path.join(root_dir, "depth")).split(os.sep)[1]
            for file in depth_files
        ]

        label_files = [
            os.path.join(root_dir, "scene_class", f"{f}.txt") for f in file_ids
        ]
        labels = []
        for label_file in label_files:
            with open(label_file, "r") as file:  # noqa: SIM115
                labels.append(file.read().strip())
        labels = [label.replace("_", " ") for label in labels]
        labels = [
            _LABELS.index(label) if label in _LABELS else len(_LABELS)  # type: ignore
            for label in labels
        ]

        # remove the samples with classes not in _LABELS
        # this is to follow the same classes used in ImageBind
        if split == "test":
            valid_indices = [
                i
                for i, label in enumerate(labels)
                if label < len(_LABELS)  # type: ignore
            ]
            rgb_files = [rgb_files[i] for i in valid_indices]
            depth_files = [depth_files[i] for i in valid_indices]
            labels = [labels[i] for i in valid_indices]
            intrinsic_files = [intrinsic_files[i] for i in valid_indices]
            sensor_types = [sensor_types[i] for i in valid_indices]

        self.samples = list(
            zip(
                rgb_files,
                depth_files,
                labels,
                intrinsic_files,
                sensor_types,
                strict=False,
            )
        )

        self.rgb_transform = rgb_transform
        self.depth_transform = depth_transform

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.samples)

    def _validate_args(
        self,
        root_dir: str,
        split: str,
        rgb_transform: Optional[Callable[[PILImage], torch.Tensor]],
        depth_transform: Optional[Callable[[PILImage], torch.Tensor]],
    ) -> None:
        """Validate arguments."""
        if not os.path.isdir(root_dir):
            raise NotADirectoryError(
                f"The given `root_dir` {root_dir} is not a directory",
            )
        if split not in ["train", "test"]:
            raise ValueError(
                f"Expected `split` to be one of `'train'` or `'test'`, but got {split}",
            )
        if rgb_transform is not None and not callable(rgb_transform):
            raise TypeError(
                f"Expected argument `rgb_transform` to be callable, but got {type(rgb_transform)}",
            )
        if depth_transform is not None and not callable(depth_transform):
            raise TypeError(
                f"Expected `depth_transform` to be callable, but got {type(depth_transform)}",
            )

    def __getitem__(self, idx: int) -> Example:
        """Return RGB and depth images at index `idx`."""
        # Read images
        rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
        if self.rgb_transform is not None:
            rgb_image = self.rgb_transform(to_pil_image(rgb_image))

        if self.return_type == "disparity":
            depth_image = convert_depth_to_disparity(
                self.samples[idx][1],
                self.samples[idx][3],
                self.samples[idx][4],
            )
        else:
            # Using cv2 instead of PIL Image since we use PNG grayscale images.
            depth_image = cv2.imread(
                self.samples[idx][1],
                cv2.IMREAD_GRAYSCALE,
            )
            # Make a 3-channel depth image to enable passing to a pretrained ViT.
            depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

        if self.depth_transform is not None:
            depth_image = self.depth_transform(to_pil_image(depth_image))

        return Example(
            {
                Modalities.RGB.name: rgb_image,
                Modalities.DEPTH.name: depth_image,
                EXAMPLE_INDEX_KEY: idx,
                Modalities.DEPTH.target: self.samples[idx][2],
            }
        )
__len__
__len__()

Return the length of the dataset.

Source code in mmlearn/datasets/sunrgbd.py
def __len__(self) -> int:
    """Return the length of the dataset."""
    return len(self.samples)
__getitem__
__getitem__(idx)

Return RGB and depth images at index idx.

Source code in mmlearn/datasets/sunrgbd.py
def __getitem__(self, idx: int) -> Example:
    """Return RGB and depth images at index `idx`."""
    # Read images
    rgb_image = cv2.imread(self.samples[idx][0], cv2.IMREAD_UNCHANGED)
    if self.rgb_transform is not None:
        rgb_image = self.rgb_transform(to_pil_image(rgb_image))

    if self.return_type == "disparity":
        depth_image = convert_depth_to_disparity(
            self.samples[idx][1],
            self.samples[idx][3],
            self.samples[idx][4],
        )
    else:
        # Using cv2 instead of PIL Image since we use PNG grayscale images.
        depth_image = cv2.imread(
            self.samples[idx][1],
            cv2.IMREAD_GRAYSCALE,
        )
        # Make a 3-channel depth image to enable passing to a pretrained ViT.
        depth_image = np.repeat(depth_image[:, :, np.newaxis], 3, axis=-1)

    if self.depth_transform is not None:
        depth_image = self.depth_transform(to_pil_image(depth_image))

    return Example(
        {
            Modalities.RGB.name: rgb_image,
            Modalities.DEPTH.name: depth_image,
            EXAMPLE_INDEX_KEY: idx,
            Modalities.DEPTH.target: self.samples[idx][2],
        }
    )

convert_depth_to_disparity

convert_depth_to_disparity(
    depth_file,
    intrinsics_file,
    sensor_type,
    min_depth=0.01,
    max_depth=50,
)

Load depth file and convert to disparity.

Parameters:

Name Type Description Default
depth_file str

Path to the depth file.

required
intrinsics_file str

A text file supplied with SUN RGB-D containing sensor information. It can be found at the path os.path.join(root_dir, room_name, "intrinsics.txt").

required
sensor_type str

Sensor type of the depth file.

required
min_depth float

Minimum depth value to clip the depth image.

0.01
max_depth int

Maximum depth value to clip the depth image.

50

Returns:

Type Description
Tensor

Disparity image from the depth image following the ImageBind implementation.

Source code in mmlearn/datasets/sunrgbd.py
def convert_depth_to_disparity(
    depth_file: str,
    intrinsics_file: str,
    sensor_type: str,
    min_depth: float = 0.01,
    max_depth: int = 50,
) -> torch.Tensor:
    """Load depth file and convert to disparity.

    Parameters
    ----------
    depth_file : str
        Path to the depth file.
    intrinsics_file : str
        Intrinsics_file is a txt file supplied in SUNRGBD with sensor information
        Can be found at the path: os.path.join(root_dir, room_name, "intrinsics.txt")
    sensor_type : str
        Sensor type of the depth file.
    min_depth : float, default=0.01
        Minimum depth value to clip the depth image.
    max_depth : int, default=50
        Maximum depth value to clip the depth image.

    Returns
    -------
    torch.Tensor
        Disparity image from the depth image following the ImageBind implementation.
    """
    with open(intrinsics_file, "r") as fh:
        lines = fh.readlines()
        focal_length = float(lines[0].strip().split()[0])
    baseline = sensor_to_params[sensor_type]["baseline"]
    depth_image = np.array(PILImage.open(depth_file))
    depth = np.array(depth_image).astype(np.float32)
    depth_in_meters = depth / 1000.0
    if min_depth is not None:
        depth_in_meters = depth_in_meters.clip(min=min_depth, max=max_depth)
    disparity = baseline * focal_length / depth_in_meters
    return torch.from_numpy(disparity).float()
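
Example: an illustrative sketch that assumes the dataset has been extracted to the directory pointed to by SUNRGBD_ROOT_DIR in the layout expected above; the torchvision transforms are only placeholders.

import os

from torchvision import transforms

from mmlearn.datasets.sunrgbd import SUNRGBDDataset

to_tensor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = SUNRGBDDataset(
    root_dir=os.environ["SUNRGBD_ROOT_DIR"],
    split="train",
    return_type="disparity",   # depth is converted to disparity, as in ImageBind
    rgb_transform=to_tensor,
    depth_transform=to_tensor,
)
example = dataset[0]  # Example with RGB and depth tensors, the scene label and the example index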

Core Datasets Components

mmlearn.datasets.core

Modules for core dataloading functionality.

CombinedDataset

Bases: Dataset[Example]

Combine multiple datasets into one.

This class is similar to :py:class:~torch.utils.data.ConcatDataset but allows for combining iterable-style datasets with map-style datasets. The iterable-style datasets must implement the :meth:__len__ method, which is used to determine the total length of the combined dataset. When an index is passed to the combined dataset, the dataset that contains the example at that index is determined and the example is retrieved from that dataset. Since iterable-style datasets do not support random access, the examples are retrieved sequentially from the iterable-style datasets. When the end of an iterable-style dataset is reached, the iterator is reset and the next example is retrieved from the beginning of the dataset.

Parameters:

Name Type Description Default
datasets Iterable[Union[Dataset, IterableDataset]]

Iterable of datasets to combine.

required

Raises:

Type Description
TypeError

If any of the datasets in the input iterable are not instances of :py:class:~torch.utils.data.Dataset or :py:class:~torch.utils.data.IterableDataset.

ValueError

If the input iterable of datasets is empty.

Source code in mmlearn/datasets/core/combined_dataset.py
class CombinedDataset(Dataset[Example]):
    """Combine multiple datasets into one.

    This class is similar to :py:class:`~torch.utils.data.ConcatDataset` but allows
    for combining iterable-style datasets with map-style datasets. The iterable-style
    datasets must implement the :meth:`__len__` method, which is used to determine the
    total length of the combined dataset. When an index is passed to the combined
    dataset, the dataset that contains the example at that index is determined and
    the example is retrieved from that dataset. Since iterable-style datasets do
    not support random access, the examples are retrieved sequentially from the
    iterable-style datasets. When the end of an iterable-style dataset is reached,
    the iterator is reset and the next example is retrieved from the beginning of
    the dataset.


    Parameters
    ----------
    datasets : Iterable[Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]]
        Iterable of datasets to combine.

    Raises
    ------
    TypeError
        If any of the datasets in the input iterable are not instances of
        :py:class:`~torch.utils.data.Dataset` or :py:class:`~torch.utils.data.IterableDataset`.
    ValueError
        If the input iterable of datasets is empty.

    """  # noqa: W505

    def __init__(
        self, datasets: Iterable[Union[Dataset[Example], IterableDataset[Example]]]
    ) -> None:
        self.datasets, _ = tree_flatten(datasets)
        if not all(
            isinstance(dataset, (Dataset, IterableDataset)) for dataset in self.datasets
        ):
            raise TypeError(
                "Expected argument `datasets` to be an iterable of `Dataset` or "
                f"`IterableDataset` instances, but found: {self.datasets}",
            )
        if len(self.datasets) == 0:
            raise ValueError(
                "Expected a non-empty iterable of datasets but found an empty iterable",
            )

        self._cumulative_sizes: list[int] = np.cumsum(
            [len(dataset) for dataset in self.datasets]
        ).tolist()
        self._iterators: list[Iterator[Example]] = []
        self._iter_dataset_mapping: dict[int, int] = {}

        # create iterators for iterable datasets and map dataset index to iterator index
        for idx, dataset in enumerate(self.datasets):
            if isinstance(dataset, IterableDataset):
                self._iterators.append(iter(dataset))
                self._iter_dataset_mapping[idx] = len(self._iterators) - 1

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the combined dataset."""
        if idx < 0:  # handle negative indices
            if -idx > len(self):
                raise IndexError(
                    f"Index {idx} is out of bounds for the combined dataset with "
                    f"length {len(self)}",
                )
            idx = len(self) + idx

        dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

        curr_dataset = self.datasets[dataset_idx]
        if isinstance(curr_dataset, IterableDataset):
            iter_idx = self._iter_dataset_mapping[dataset_idx]
            try:
                example = next(self._iterators[iter_idx])
            except StopIteration:
                self._iterators[iter_idx] = iter(curr_dataset)
                example = next(self._iterators[iter_idx])
        else:
            if dataset_idx == 0:
                example_idx = idx
            else:
                example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
            example = curr_dataset[example_idx]

        if not isinstance(example, Example):
            raise TypeError(
                "Expected dataset examples to be instances of `Example` "
                f"but found {type(example)}",
            )

        if not hasattr(example, "dataset_index"):
            example.dataset_index = dataset_idx
        if not hasattr(example, "example_ids"):
            example.create_ids()

        return example

    def __len__(self) -> int:
        """Return the total number of examples in the combined dataset."""
        return self._cumulative_sizes[-1]

__getitem__

__getitem__(idx)

Return an example from the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the combined dataset."""
    if idx < 0:  # handle negative indices
        if -idx > len(self):
            raise IndexError(
                f"Index {idx} is out of bounds for the combined dataset with "
                f"length {len(self)}",
            )
        idx = len(self) + idx

    dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

    curr_dataset = self.datasets[dataset_idx]
    if isinstance(curr_dataset, IterableDataset):
        iter_idx = self._iter_dataset_mapping[dataset_idx]
        try:
            example = next(self._iterators[iter_idx])
        except StopIteration:
            self._iterators[iter_idx] = iter(curr_dataset)
            example = next(self._iterators[iter_idx])
    else:
        if dataset_idx == 0:
            example_idx = idx
        else:
            example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
        example = curr_dataset[example_idx]

    if not isinstance(example, Example):
        raise TypeError(
            "Expected dataset examples to be instances of `Example` "
            f"but found {type(example)}",
        )

    if not hasattr(example, "dataset_index"):
        example.dataset_index = dataset_idx
    if not hasattr(example, "example_ids"):
        example.create_ids()

    return example

__len__

__len__()

Return the total number of examples in the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __len__(self) -> int:
    """Return the total number of examples in the combined dataset."""
    return self._cumulative_sizes[-1]
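
Example: a minimal sketch combining two toy map-style datasets. The import path assumes CombinedDataset and Example are exposed from mmlearn.datasets.core, as this section suggests; the toy dataset is purely illustrative.

from torch.utils.data import Dataset

from mmlearn.datasets.core import CombinedDataset, Example


class ToyTextDataset(Dataset):
    """Map-style dataset whose items are `Example` objects."""

    def __init__(self, sentences):
        self.sentences = sentences

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        return Example({"text": self.sentences[idx], "example_index": idx})


combined = CombinedDataset([ToyTextDataset(["a", "b"]), ToyTextDataset(["c", "d", "e"])])
print(len(combined))  # 5
item = combined[3]    # second item of the second dataset; `dataset_index` is set to 1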

DefaultDataCollator dataclass

Default data collator for batching examples.

This data collator will collate a list of :py:class:~mmlearn.datasets.core.example.Example objects into a batch. It can also apply processing functions to specified keys in the batch before returning it.

Parameters:

Name Type Description Default
batch_processors Optional[dict[str, Callable[[Any], Any]]]

Dictionary of callables to apply to the batch before returning it.

None

Raises:

Type Description
ValueError

If the batch processor for a key does not return a dictionary with the key in it.

Source code in mmlearn/datasets/core/data_collator.py
@dataclass
class DefaultDataCollator:
    """Default data collator for batching examples.

    This data collator will collate a list of :py:class:`~mmlearn.datasets.core.example.Example`
    objects into a batch. It can also apply processing functions to specified keys
    in the batch before returning it.

    Parameters
    ----------
    batch_processors : Optional[dict[str, Callable[[Any], Any]]], optional, default=None
        Dictionary of callables to apply to the batch before returning it.

    Raises
    ------
    ValueError
        If the batch processor for a key does not return a dictionary with the
        key in it.
    """  # noqa: W505

    #: Dictionary of callables to apply to the batch before returning it.
    #: The key is the name of the key in the batch, and the value is the processing
    #: function to apply to the key. The processing function must take a single
    #: argument and return a single value. If the processing function returns
    #: a dictionary, it must contain the key that was processed in it (all the
    #: other keys will also be included in the batch).
    batch_processors: Optional[dict[str, Callable[[Any], Any]]] = None

    def __call__(self, examples: list[Example]) -> dict[str, Any]:
        """Collate a list of `Example` objects and apply processing functions."""
        batch = collate_example_list(examples)

        if self.batch_processors is not None:
            for key, processor in self.batch_processors.items():
                batch_key: str = key
                if Modalities.has_modality(key):
                    batch_key = Modalities.get_modality(key).name

                if batch_key in batch:
                    batch_processed = processor(batch[batch_key])
                    if isinstance(batch_processed, Mapping):
                        if batch_key not in batch_processed:
                            raise ValueError(
                                f"Batch processor for '{key}' key must return a dictionary "
                                f"with '{batch_key}' in it."
                            )
                        batch.update(batch_processed)
                    else:
                        batch[batch_key] = batch_processed

        return batch

__call__

__call__(examples)

Collate a list of Example objects and apply processing functions.

Source code in mmlearn/datasets/core/data_collator.py
def __call__(self, examples: list[Example]) -> dict[str, Any]:
    """Collate a list of `Example` objects and apply processing functions."""
    batch = collate_example_list(examples)

    if self.batch_processors is not None:
        for key, processor in self.batch_processors.items():
            batch_key: str = key
            if Modalities.has_modality(key):
                batch_key = Modalities.get_modality(key).name

            if batch_key in batch:
                batch_processed = processor(batch[batch_key])
                if isinstance(batch_processed, Mapping):
                    if batch_key not in batch_processed:
                        raise ValueError(
                            f"Batch processor for '{key}' key must return a dictionary "
                            f"with '{batch_key}' in it."
                        )
                    batch.update(batch_processed)
                else:
                    batch[batch_key] = batch_processed

    return batch
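
Example: a minimal sketch; the import path assumes DefaultDataCollator and Example are exposed from mmlearn.datasets.core.

import torch

from mmlearn.datasets.core import DefaultDataCollator, Example

collator = DefaultDataCollator()
batch = collator(
    [
        Example({"text": torch.tensor([1, 2]), "example_index": 0}),
        Example({"text": torch.tensor([3, 4]), "example_index": 1}),
    ]
)
# `batch` is a dictionary with the same keys as the examples, with the values
# collated across the batch. Passing batch_processors={"text": fn} would apply
# `fn` to batch["text"] before the batch is returned.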

Example

Bases: OrderedDict[Any, Any]

A representation of a single example from a dataset.

This class is a subclass of :py:class:~collections.OrderedDict and provides attribute-style access. This means that example["text"] and example.text are equivalent. All datasets in this library return examples as :py:class:~mmlearn.datasets.core.example.Example objects.

Parameters:

Name Type Description Default
init_dict Optional[MutableMapping[Hashable, Any]]

Dictionary to init Example class with.

None

Examples:

>>> example = Example({"text": torch.tensor(2)})
>>> example.text.zero_()
tensor(0)
>>> example.context = torch.tensor(4)  # set custom attributes after initialization
Source code in mmlearn/datasets/core/example.py
class Example(OrderedDict[Any, Any]):
    """A representation of a single example from a dataset.

    This class is a subclass of :py:class:`~collections.OrderedDict` and provides
    attribute-style access. This means that `example["text"]` and `example.text`
    are equivalent. All datasets in this library return examples as
    :py:class:`~mmlearn.datasets.core.example.Example` objects.


    Parameters
    ----------
    init_dict : Optional[MutableMapping[Hashable, Any]], optional, default=None
        Dictionary to init `Example` class with.

    Examples
    --------
    >>> example = Example({"text": torch.tensor(2)})
    >>> example.text.zero_()
    tensor(0)
    >>> example.context = torch.tensor(4)  # set custom attributes after initialization
    """

    def __init__(
        self,
        init_dict: Optional[MutableMapping[Hashable, Any]] = None,
    ) -> None:
        if init_dict is None:
            init_dict = {}
        super().__init__(init_dict)

    def create_ids(self) -> None:
        """Create a unique id for the example from the dataset and example index.

        This method combines the dataset index and example index to create an
        attribute called `example_ids`, which is a dictionary of tensors. The
        dictionary keys are all the keys in the example except for `example_ids`,
        `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
        containing the tuple `(dataset_index, example_index)` for each key.
        The `example_ids` is used to (re-)identify pairs of examples from different
        modalities after they have been combined into a batch.

        Warns
        -----
        UserWarning
            If the `example_index` and `dataset_index` attributes are not set.

        Notes
        -----
        - The Example must have the following attributes set before calling this
          this method: `example_index` (usually set/returned by the dataset) and
          `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
        - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
          function can be used to find matching examples given two tensors of example ids.

        """  # noqa: W505
        if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
            self.example_ids = {
                key: torch.tensor([self.dataset_index, self.example_index])
                for key in self.keys()
                if key not in ("example_ids", "example_index", "dataset_index")
            }
        else:
            rank_zero_warn(
                "Cannot create `example_ids` without `example_index` and `dataset_index` "
                "attributes. Set these attributes before calling `create_ids`. "
                "No `example_ids` was created.",
                stacklevel=2,
                category=UserWarning,
            )

    def __getattr__(self, key: str) -> Any:
        """Get attribute by key."""
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key: str, value: Any) -> None:
        """Set attribute by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        self[key] = value

    def __setitem__(self, key: Hashable, value: Any) -> None:
        """Set item by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        super().__setitem__(key, value)

create_ids

create_ids()

Create a unique id for the example from the dataset and example index.

This method combines the dataset index and example index to create an attribute called example_ids, which is a dictionary of tensors. The dictionary keys are all the keys in the example except for example_ids, example_index, and dataset_index. The values are tensors of shape (2,) containing the tuple (dataset_index, example_index) for each key. The example_ids is used to (re-)identify pairs of examples from different modalities after they have been combined into a batch.

Warns:

Type Description
UserWarning

If the example_index and dataset_index attributes are not set.

Notes
  • The Example must have the following attributes set before calling this method: example_index (usually set/returned by the dataset) and dataset_index (usually set by the :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset object)
  • The :py:func:~mmlearn.datasets.core.example.find_matching_indices function can be used to find matching examples given two tensors of example ids.
Source code in mmlearn/datasets/core/example.py
def create_ids(self) -> None:
    """Create a unique id for the example from the dataset and example index.

    This method combines the dataset index and example index to create an
    attribute called `example_ids`, which is a dictionary of tensors. The
    dictionary keys are all the keys in the example except for `example_ids`,
    `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
    containing the tuple `(dataset_index, example_index)` for each key.
    The `example_ids` is used to (re-)identify pairs of examples from different
    modalities after they have been combined into a batch.

    Warns
    -----
    UserWarning
        If the `example_index` and `dataset_index` attributes are not set.

    Notes
    -----
    - The Example must have the following attributes set before calling this
      this method: `example_index` (usually set/returned by the dataset) and
      `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
    - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
      function can be used to find matching examples given two tensors of example ids.

    """  # noqa: W505
    if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
        self.example_ids = {
            key: torch.tensor([self.dataset_index, self.example_index])
            for key in self.keys()
            if key not in ("example_ids", "example_index", "dataset_index")
        }
    else:
        rank_zero_warn(
            "Cannot create `example_ids` without `example_index` and `dataset_index` "
            "attributes. Set these attributes before calling `create_ids`. "
            "No `example_ids` was created.",
            stacklevel=2,
            category=UserWarning,
        )
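
Example: creating example_ids once both indices are available (dataset_index is normally set by CombinedDataset).

>>> example = Example({"text": "a photo of a cat", "example_index": 7})
>>> example.dataset_index = 0
>>> example.create_ids()
>>> example.example_ids["text"]
tensor([0, 7])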

__getattr__

__getattr__(key)

Get attribute by key.

Source code in mmlearn/datasets/core/example.py
def __getattr__(self, key: str) -> Any:
    """Get attribute by key."""
    try:
        return self[key]
    except KeyError:
        raise AttributeError(key) from None

__setattr__

__setattr__(key, value)

Set attribute by key.

Source code in mmlearn/datasets/core/example.py
def __setattr__(self, key: str, value: Any) -> None:
    """Set attribute by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    self[key] = value

__setitem__

__setitem__(key, value)

Set item by key.

Source code in mmlearn/datasets/core/example.py
def __setitem__(self, key: Hashable, value: Any) -> None:
    """Set item by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    super().__setitem__(key, value)

CombinedDatasetRatioSampler

Bases: Sampler[int]

Sampler for weighted sampling from a :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset.

Parameters:

Name Type Description Default
dataset CombinedDataset

An instance of :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset to sample from.

required
ratios Optional[Sequence[float]]

A sequence of ratios for sampling from each dataset in the combined dataset. The length of the sequence must be equal to the number of datasets in the combined dataset (dataset). If None, the length of each dataset in the combined dataset is used as the ratio. The ratios are normalized to sum to 1.

None
num_samples Optional[int]

The number of samples to draw from the combined dataset. If None, the sampler will draw as many samples as there are in the combined dataset. This number must yield at least one sample per dataset in the combined dataset, when multiplied by the corresponding ratio.

None
replacement bool

Whether to sample with replacement or not.

False
shuffle bool

Whether to shuffle the sampled indices or not. If False, the indices of each dataset will appear in the order they are stored in the combined dataset. This is similar to sequential sampling from each dataset. The datasets that make up the combined dataset are still sampled randomly.

True
rank Optional[int]

Rank of the current process within :attr:num_replicas. By default, :attr:rank is retrieved from the current distributed group.

None
num_replicas Optional[int]

Number of processes participating in distributed training. By default, :attr:num_replicas is retrieved from the current distributed group.

None
drop_last bool

Whether to drop the last incomplete batch or not. If True, the sampler will drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

False
seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

0

Attributes:

Name Type Description
dataset CombinedDataset

The dataset to sample from.

num_samples int

The number of samples to draw from the combined dataset.

probs Tensor

The probabilities for sampling from each dataset in the combined dataset. This is computed from the ratios argument and is normalized to sum to 1.

replacement bool

Whether to sample with replacement or not.

shuffle bool

Whether to shuffle the sampled indices or not.

rank int

Rank of the current process within :attr:num_replicas.

num_replicas int

Number of processes participating in distributed training.

drop_last bool

Whether to drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

epoch int

Current epoch number. This is used to set the random seed. This is useful in distributed mode to ensure that each process receives a different random ordering of the samples.

total_size int

The total number of samples across all processes.

Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class CombinedDatasetRatioSampler(Sampler[int]):
    """Sampler for weighted sampling from a :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`.

    Parameters
    ----------
    dataset : CombinedDataset
        An instance of :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`
        to sample from.
    ratios : Optional[Sequence[float]], optional, default=None
        A sequence of ratios for sampling from each dataset in the combined dataset.
        The length of the sequence must be equal to the number of datasets in the
        combined dataset (`dataset`). If `None`, the length of each dataset in the
        combined dataset is used as the ratio. The ratios are normalized to sum to 1.
    num_samples : Optional[int], optional, default=None
        The number of samples to draw from the combined dataset. If `None`, the
        sampler will draw as many samples as there are in the combined dataset.
        This number must yield at least one sample per dataset in the combined
        dataset, when multiplied by the corresponding ratio.
    replacement : bool, default=False
        Whether to sample with replacement or not.
    shuffle : bool, default=True
        Whether to shuffle the sampled indices or not. If `False`, the indices of
        each dataset will appear in the order they are stored in the combined dataset.
        This is similar to sequential sampling from each dataset. The datasets
        that make up the combined dataset are still sampled randomly.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    drop_last : bool, default=False
        Whether to drop the last incomplete batch or not. If `True`, the sampler will
        drop samples to make the number of samples evenly divisible by the number of
        replicas in distributed mode.
    seed : int, default=0
        Random seed used to when sampling from the combined dataset and shuffling
        the sampled indices.

    Attributes
    ----------
    dataset : CombinedDataset
        The dataset to sample from.
    num_samples : int
        The number of samples to draw from the combined dataset.
    probs : torch.Tensor
        The probabilities for sampling from each dataset in the combined dataset.
        This is computed from the `ratios` argument and is normalized to sum to 1.
    replacement : bool
        Whether to sample with replacement or not.
    shuffle : bool
        Whether to shuffle the sampled indices or not.
    rank : int
        Rank of the current process within :attr:`num_replicas`.
    num_replicas : int
        Number of processes participating in distributed training.
    drop_last : bool
        Whether to drop samples to make the number of samples evenly divisible by the
        number of replicas in distributed mode.
    seed : int
        Random seed used to when sampling from the combined dataset and shuffling
        the sampled indices.
    epoch : int
        Current epoch number. This is used to set the random seed. This is useful
        in distributed mode to ensure that each process receives a different random
        ordering of the samples.
    total_size : int
        The total number of samples across all processes.
    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        dataset: CombinedDataset,
        ratios: Optional[Sequence[float]] = None,
        num_samples: Optional[int] = None,
        replacement: bool = False,
        shuffle: bool = True,
        rank: Optional[int] = None,
        num_replicas: Optional[int] = None,
        drop_last: bool = False,
        seed: int = 0,
    ):
        if not isinstance(dataset, CombinedDataset):
            raise TypeError(
                "Expected argument `dataset` to be of type `CombinedDataset`, "
                f"but got {type(dataset)}.",
            )
        if not isinstance(seed, int):
            raise TypeError(
                f"Expected argument `seed` to be an integer, but got {type(seed)}.",
            )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.drop_last = drop_last
        self.replacement = replacement
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        self._num_samples = num_samples
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "Expected argument `num_samples` to be a positive integer, but got "
                f"{self.num_samples}.",
            )

        if ratios is None:
            ratios = [len(subset) for subset in self.dataset.datasets]

        num_datasets = len(self.dataset.datasets)
        if len(ratios) != num_datasets:
            raise ValueError(
                f"Expected argument `ratios` to be of length {num_datasets}, "
                f"but got length {len(ratios)}.",
            )
        prob_sum = sum(ratios)
        if not all(ratio >= 0 for ratio in ratios) and prob_sum > 0:
            raise ValueError(
                "Expected argument `ratios` to be a sequence of non-negative numbers. "
                f"Got {ratios}.",
            )
        self.probs = torch.tensor(
            [ratio / prob_sum for ratio in ratios],
            dtype=torch.double,
        )
        if any((prob * self.num_samples) <= 0 for prob in self.probs):
            raise ValueError(
                "Expected dataset ratio to result in at least one sample per dataset. "
                f"Got dataset sizes {self.probs * self.num_samples}.",
            )

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        # dataset size might change at runtime
        if self._num_samples is None:
            num_samples = len(self.dataset)
        else:
            num_samples = self._num_samples

        if self.drop_last and num_samples % self.num_replicas != 0:
            # split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            num_samples = math.ceil(
                (num_samples - self.num_replicas) / self.num_replicas,
            )
        else:
            num_samples = math.ceil(num_samples / self.num_replicas)
        return num_samples

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return self.num_samples * self.num_replicas

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that yields sample indices for the combined dataset."""
        generator = torch.Generator()
        seed = self.seed + self.epoch
        generator.manual_seed(seed)

        cumulative_sizes = [0] + self.dataset._cumulative_sizes
        num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
        indices = []
        for i in range(len(self.dataset.datasets)):
            per_dataset_indices: torch.Tensor = torch.multinomial(
                torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
                num_samples_per_dataset[i],
                replacement=self.replacement,
                generator=generator,
            )
            # adjust indices to reflect position in cumulative dataset
            per_dataset_indices += cumulative_sizes[i]
            assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
                f"Indices from dataset {i} exceed dataset size. "
                f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
            )
            indices.append(per_dataset_indices)

        indices = torch.cat(indices)
        if self.shuffle:
            rand_indices = torch.randperm(len(indices), generator=generator)
            indices = indices[rand_indices]

        indices = indices.tolist()  # type: ignore[attr-defined]
        num_indices = len(indices)

        if num_indices < self.total_size:
            padding_size = self.total_size - num_indices
            if padding_size <= num_indices:
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / num_indices))[
                    :padding_size
                ]
        elif num_indices > self.total_size:
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples, (
            f"Expected {self.num_samples} samples, but got {len(indices)}.",
        )

        yield from iter(indices)

    def __len__(self) -> int:
        """Return the total number of samples in the sampler."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch

        # some iterable datasets (especially huggingface iterable datasets) might
        # require setting the epoch to ensure shuffling works properly
        for dataset in self.dataset.datasets:
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)

num_samples property

num_samples

Return the number of samples managed by the sampler.

total_size property

total_size

Return the total size of the dataset.

__iter__

__iter__()

Return an iterator that yields sample indices for the combined dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that yields sample indices for the combined dataset."""
    generator = torch.Generator()
    seed = self.seed + self.epoch
    generator.manual_seed(seed)

    cumulative_sizes = [0] + self.dataset._cumulative_sizes
    num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
    indices = []
    for i in range(len(self.dataset.datasets)):
        per_dataset_indices: torch.Tensor = torch.multinomial(
            torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
            num_samples_per_dataset[i],
            replacement=self.replacement,
            generator=generator,
        )
        # adjust indices to reflect position in cumulative dataset
        per_dataset_indices += cumulative_sizes[i]
        assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
            f"Indices from dataset {i} exceed dataset size. "
            f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
        )
        indices.append(per_dataset_indices)

    indices = torch.cat(indices)
    if self.shuffle:
        rand_indices = torch.randperm(len(indices), generator=generator)
        indices = indices[rand_indices]

    indices = indices.tolist()  # type: ignore[attr-defined]
    num_indices = len(indices)

    if num_indices < self.total_size:
        padding_size = self.total_size - num_indices
        if padding_size <= num_indices:
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / num_indices))[
                :padding_size
            ]
    elif num_indices > self.total_size:
        indices = indices[: self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples, (
        f"Expected {self.num_samples} samples, but got {len(indices)}.",
    )

    yield from iter(indices)

__len__

__len__()

Return the total number of samples in the sampler.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the total number of samples in the sampler."""
    return self.num_samples

set_epoch

set_epoch(epoch)

Set the epoch for this sampler.

When :attr:shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

    # some iterable datasets (especially huggingface iterable datasets) might
    # require setting the epoch to ensure shuffling works properly
    for dataset in self.dataset.datasets:
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)

DistributedEvalSampler

Bases: Sampler[int]

Sampler for distributed evaluation.

The main differences between this and :py:class:torch.utils.data.DistributedSampler are that this sampler does not add extra samples to make it evenly divisible and shuffling is disabled by default.

Parameters:

Name Type Description Default
dataset Dataset

Dataset used for sampling.

required
num_replicas Optional[int]

Number of processes participating in distributed training. By default, :attr:num_replicas is retrieved from the current distributed group.

None
rank Optional[int]

Rank of the current process within :attr:num_replicas. By default, :attr:rank is retrieved from the current distributed group.

None
shuffle bool

If True, the sampler will shuffle the indices.

False
seed int

Random seed used to shuffle the sampler if :attr:shuffle=True. This number should be identical across all processes in the distributed group.

0
Warnings

DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever. See [1]_ for details.

Notes
  • This sampler is for evaluation purposes, where synchronization does not happen every epoch. Synchronization should be done outside the dataloader loop (see the sketch after the example below). It is especially useful in conjunction with :py:class:torch.nn.parallel.DistributedDataParallel [2]_.
  • The input Dataset is assumed to be of constant size.
  • This implementation is adapted from [3]_.
References

.. [1] https://github.com/pytorch/pytorch/issues/22584
.. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
.. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py

Examples:

>>> def example():
...     start_epoch, n_epochs = 0, 2
...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
...     for epoch in range(start_epoch, n_epochs):
...         if is_distributed:
...             sampler.set_epoch(epoch)
...         evaluate(loader)
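
The "synchronize outside the dataloader loop" note above can look like the following hedged sketch. Here dataset, model, and the per-rank evaluation logic are placeholders, and an initialized process group (e.g. via torchrun) is assumed; only the sampler and the torch.distributed calls are real APIs.

import torch.distributed as dist
from torch.utils.data import DataLoader

from mmlearn.datasets.core.samplers import DistributedEvalSampler

sampler = DistributedEvalSampler(dataset)            # `dataset` is a placeholder
loader = DataLoader(dataset, sampler=sampler, batch_size=32)

# each rank evaluates only its own shard of the data
local_outputs = [model(batch) for batch in loader]   # `model` is a placeholder

# synchronize *after* the loop: gather per-rank results on every process
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, local_outputs)
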
Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class DistributedEvalSampler(Sampler[int]):
    """Sampler for distributed evaluation.

    The main differences between this and :py:class:`torch.utils.data.DistributedSampler`
    are that this sampler does not add extra samples to make it evenly divisible and
    shuffling is disabled by default.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset used for sampling.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`rank` is retrieved from the current distributed group.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    shuffle : bool, optional, default=False
        If `True` (default), sampler will shuffle the indices.
    seed : int, optional, default=0
        Random seed used to shuffle the sampler if :attr:`shuffle=True`.
        This number should be identical across all processes in the
        distributed group.

    Warnings
    --------
    DistributedEvalSampler should NOT be used for training. The distributed processes
    could hang forever. See [1]_ for details

    Notes
    -----
    - This sampler is for evaluation purpose where synchronization does not happen
      every epoch. Synchronization should be done outside the dataloader loop.
      It is especially useful in conjunction with
      :py:class:`torch.nn.parallel.DistributedDataParallel` [2]_.
    - The input Dataset is assumed to be of constant size.
    - This implementation is adapted from [3]_.

    References
    ----------
    .. [1] https://github.com/pytorch/pytorch/issues/22584
    .. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
    .. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py


    Examples
    --------
    >>> def example():
    ...     start_epoch, n_epochs = 0, 2
    ...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
    ...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
    ...     for epoch in range(start_epoch, n_epochs):
    ...         if is_distributed:
    ...             sampler.set_epoch(epoch)
    ...         evaluate(loader)

    """  # noqa: W505

    def __init__(
        self,
        dataset: Dataset[Sized],
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return len(self.dataset)

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        indices = list(range(self.total_size))[
            self.rank : self.total_size : self.num_replicas
        ]
        return len(indices)  # true value without extra samples

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that iterates over the indices of the dataset."""
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(self.total_size, generator=g).tolist()
        else:
            indices = list(range(self.total_size))

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch

total_size property

total_size

Return the total size of the dataset.

num_samples property

num_samples

Return the number of samples managed by the sampler.

__iter__

__iter__()

Return an iterator that iterates over the indices of the dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that iterates over the indices of the dataset."""
    if self.shuffle:
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()
    else:
        indices = list(range(self.total_size))

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples

    return iter(indices)

__len__

__len__()

Return the number of samples.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the number of samples."""
    return self.num_samples

set_epoch

set_epoch(epoch)

Set the epoch for this sampler.

When :attr:shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

find_matching_indices

find_matching_indices(
    first_example_ids, second_example_ids
)

Find the indices of matching examples given two tensors of example ids.

Matching examples are defined as examples with the same value in both tensors. This method is useful for finding pairs of examples from different modalities that are related to each other in a batch.

Parameters:

Name Type Description Default
first_example_ids Tensor

A tensor of example ids of shape (N, 2), where N is the number of examples.

required
second_example_ids Tensor

A tensor of example ids of shape (M, 2), where M is the number of examples.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of tensors containing the indices of matching examples in the first and second tensor, respectively.

Raises:

Type Description
TypeError

If either first_example_ids or second_example_ids is not a tensor.

ValueError

If either first_example_ids or second_example_ids is not a 2D tensor with the second dimension having a size of 2.

Examples:

>>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
>>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
>>> find_matching_indices(img_example_ids, text_example_ids)
(tensor([2, 3]), tensor([0, 1]))
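
Continuing the example above, the returned index tensors can be used to align paired rows across modalities in a batch. The embedding tensors below are random placeholders introduced only for illustration.

import torch

from mmlearn.datasets.core.example import find_matching_indices

img_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
txt_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
img_emb = torch.randn(4, 8)  # one placeholder embedding per image example
txt_emb = torch.randn(5, 8)  # one placeholder embedding per text example

img_idx, txt_idx = find_matching_indices(img_ids, txt_ids)
paired_img, paired_txt = img_emb[img_idx], txt_emb[txt_idx]  # both have shape (2, 8)
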
Source code in mmlearn/datasets/core/example.py
def find_matching_indices(
    first_example_ids: torch.Tensor, second_example_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find the indices of matching examples given two tensors of example ids.

    Matching examples are defined as examples with the same value in both tensors.
    This method is useful for finding pairs of examples from different modalities
    that are related to each other in a batch.

    Parameters
    ----------
    first_example_ids : torch.Tensor
        A tensor of example ids of shape `(N, 2)`, where `N` is the number of examples.
    second_example_ids : torch.Tensor
        A tensor of example ids of shape `(M, 2)`, where `M` is the number of examples.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple of tensors containing the indices of matching examples in the first and
        second tensor, respectively.

    Raises
    ------
    TypeError
        If either `first_example_ids` or `second_example_ids` is not a tensor.
    ValueError
        If either `first_example_ids` or `second_example_ids` is not a 2D tensor
        with the second dimension having a size of `2`.

    Examples
    --------
    >>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
    >>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
    >>> find_matching_indices(img_example_ids, text_example_ids)
    (tensor([2, 3]), tensor([0, 1]))


    """
    if not isinstance(first_example_ids, torch.Tensor) or not isinstance(
        second_example_ids,
        torch.Tensor,
    ):
        raise TypeError(
            f"Expected inputs to be tensors, but got {type(first_example_ids)} "
            f"and {type(second_example_ids)}.",
        )
    val = 2
    if not (first_example_ids.ndim == val and first_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `first_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {first_example_ids.shape}.",
        )
    if not (second_example_ids.ndim == val and second_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `second_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {second_example_ids.shape}.",
        )

    first_example_ids = first_example_ids.unsqueeze(1)  # shape=(N, 1, 2)
    second_example_ids = second_example_ids.unsqueeze(0)  # shape=(1, M, 2)

    # compare all elements; results in a shape (N, M) tensor
    matches = torch.all(first_example_ids == second_example_ids, dim=-1)
    first_indices, second_indices = torch.where(matches)
    return first_indices, second_indices

combined_dataset

Wrapper for combining multiple datasets into one.

CombinedDataset

Bases: Dataset[Example]

Combine multiple datasets into one.

This class is similar to :py:class:~torch.utils.data.ConcatDataset but allows for combining iterable-style datasets with map-style datasets. The iterable-style datasets must implement the :meth:__len__ method, which is used to determine the total length of the combined dataset. When an index is passed to the combined dataset, the dataset that contains the example at that index is determined and the example is retrieved from that dataset. Since iterable-style datasets do not support random access, the examples are retrieved sequentially from the iterable-style datasets. When the end of an iterable-style dataset is reached, the iterator is reset and the next example is retrieved from the beginning of the dataset.

Parameters:

Name Type Description Default
datasets Iterable[Union[Dataset, IterableDataset]]

Iterable of datasets to combine.

required

Raises:

Type Description
TypeError

If any of the datasets in the input iterable are not instances of :py:class:~torch.utils.data.Dataset or :py:class:~torch.utils.data.IterableDataset.

ValueError

If the input iterable of datasets is empty.
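
The sketch below is not from the library docs: it combines a toy map-style dataset with a toy iterable-style dataset. Both return Example objects, and the iterable one implements __len__ as required.

import torch
from torch.utils.data import Dataset, IterableDataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.example import Example


class MapToy(Dataset):
    def __len__(self) -> int:
        return 3

    def __getitem__(self, idx: int) -> Example:
        return Example({"value": torch.tensor(idx), "example_index": idx})


class IterToy(IterableDataset):
    def __len__(self) -> int:  # needed so CombinedDataset can compute its length
        return 2

    def __iter__(self):
        for i in range(len(self)):
            yield Example({"value": torch.tensor(100 + i), "example_index": i})


combined = CombinedDataset([MapToy(), IterToy()])
print(len(combined))      # 5
print(combined[0].value)  # tensor(0)   -- random access into the map-style dataset
print(combined[3].value)  # tensor(100) -- drawn sequentially from the iterable dataset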

Source code in mmlearn/datasets/core/combined_dataset.py
class CombinedDataset(Dataset[Example]):
    """Combine multiple datasets into one.

    This class is similar to :py:class:`~torch.utils.data.ConcatDataset` but allows
    for combining iterable-style datasets with map-style datasets. The iterable-style
    datasets must implement the :meth:`__len__` method, which is used to determine the
    total length of the combined dataset. When an index is passed to the combined
    dataset, the dataset that contains the example at that index is determined and
    the example is retrieved from that dataset. Since iterable-style datasets do
    not support random access, the examples are retrieved sequentially from the
    iterable-style datasets. When the end of an iterable-style dataset is reached,
    the iterator is reset and the next example is retrieved from the beginning of
    the dataset.


    Parameters
    ----------
    datasets : Iterable[Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]]
        Iterable of datasets to combine.

    Raises
    ------
    TypeError
        If any of the datasets in the input iterable are not instances of
        :py:class:`~torch.utils.data.Dataset` or :py:class:`~torch.utils.data.IterableDataset`.
    ValueError
        If the input iterable of datasets is empty.

    """  # noqa: W505

    def __init__(
        self, datasets: Iterable[Union[Dataset[Example], IterableDataset[Example]]]
    ) -> None:
        self.datasets, _ = tree_flatten(datasets)
        if not all(
            isinstance(dataset, (Dataset, IterableDataset)) for dataset in self.datasets
        ):
            raise TypeError(
                "Expected argument `datasets` to be an iterable of `Dataset` or "
                f"`IterableDataset` instances, but found: {self.datasets}",
            )
        if len(self.datasets) == 0:
            raise ValueError(
                "Expected a non-empty iterable of datasets but found an empty iterable",
            )

        self._cumulative_sizes: list[int] = np.cumsum(
            [len(dataset) for dataset in self.datasets]
        ).tolist()
        self._iterators: list[Iterator[Example]] = []
        self._iter_dataset_mapping: dict[int, int] = {}

        # create iterators for iterable datasets and map dataset index to iterator index
        for idx, dataset in enumerate(self.datasets):
            if isinstance(dataset, IterableDataset):
                self._iterators.append(iter(dataset))
                self._iter_dataset_mapping[idx] = len(self._iterators) - 1

    def __getitem__(self, idx: int) -> Example:
        """Return an example from the combined dataset."""
        if idx < 0:  # handle negative indices
            if -idx > len(self):
                raise IndexError(
                    f"Index {idx} is out of bounds for the combined dataset with "
                    f"length {len(self)}",
                )
            idx = len(self) + idx

        dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

        curr_dataset = self.datasets[dataset_idx]
        if isinstance(curr_dataset, IterableDataset):
            iter_idx = self._iter_dataset_mapping[dataset_idx]
            try:
                example = next(self._iterators[iter_idx])
            except StopIteration:
                self._iterators[iter_idx] = iter(curr_dataset)
                example = next(self._iterators[iter_idx])
        else:
            if dataset_idx == 0:
                example_idx = idx
            else:
                example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
            example = curr_dataset[example_idx]

        if not isinstance(example, Example):
            raise TypeError(
                "Expected dataset examples to be instances of `Example` "
                f"but found {type(example)}",
            )

        if not hasattr(example, "dataset_index"):
            example.dataset_index = dataset_idx
        if not hasattr(example, "example_ids"):
            example.create_ids()

        return example

    def __len__(self) -> int:
        """Return the total number of examples in the combined dataset."""
        return self._cumulative_sizes[-1]
__getitem__
__getitem__(idx)

Return an example from the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __getitem__(self, idx: int) -> Example:
    """Return an example from the combined dataset."""
    if idx < 0:  # handle negative indices
        if -idx > len(self):
            raise IndexError(
                f"Index {idx} is out of bounds for the combined dataset with "
                f"length {len(self)}",
            )
        idx = len(self) + idx

    dataset_idx = bisect.bisect_right(self._cumulative_sizes, idx)

    curr_dataset = self.datasets[dataset_idx]
    if isinstance(curr_dataset, IterableDataset):
        iter_idx = self._iter_dataset_mapping[dataset_idx]
        try:
            example = next(self._iterators[iter_idx])
        except StopIteration:
            self._iterators[iter_idx] = iter(curr_dataset)
            example = next(self._iterators[iter_idx])
    else:
        if dataset_idx == 0:
            example_idx = idx
        else:
            example_idx = idx - self._cumulative_sizes[dataset_idx - 1]
        example = curr_dataset[example_idx]

    if not isinstance(example, Example):
        raise TypeError(
            "Expected dataset examples to be instances of `Example` "
            f"but found {type(example)}",
        )

    if not hasattr(example, "dataset_index"):
        example.dataset_index = dataset_idx
    if not hasattr(example, "example_ids"):
        example.create_ids()

    return example
__len__
__len__()

Return the total number of examples in the combined dataset.

Source code in mmlearn/datasets/core/combined_dataset.py
def __len__(self) -> int:
    """Return the total number of examples in the combined dataset."""
    return self._cumulative_sizes[-1]

data_collator

Data collators for batching examples.

DefaultDataCollator dataclass

Default data collator for batching examples.

This data collator will collate a list of :py:class:~mmlearn.datasets.core.example.Example objects into a batch. It can also apply processing functions to specified keys in the batch before returning it.

Parameters:

Name Type Description Default
batch_processors Optional[dict[str, Callable[[Any], Any]]]

Dictionary of callables to apply to the batch before returning it.

None

Raises:

Type Description
ValueError

If the batch processor for a key does not return a dictionary with the key in it.
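
A small sketch of wiring batch_processors into a DataLoader. The "rgb" key, the normalization lambda, and my_dataset are placeholders introduced for illustration, not part of the library.

from torch.utils.data import DataLoader

from mmlearn.datasets.core.data_collator import DefaultDataCollator

# apply a post-collation transform to the (hypothetical) "rgb" key of the batch
collator = DefaultDataCollator(
    batch_processors={"rgb": lambda batch_rgb: batch_rgb.float() / 255.0}
)
loader = DataLoader(my_dataset, batch_size=32, collate_fn=collator)  # `my_dataset` is a placeholder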

Source code in mmlearn/datasets/core/data_collator.py
@dataclass
class DefaultDataCollator:
    """Default data collator for batching examples.

    This data collator will collate a list of :py:class:`~mmlearn.datasets.core.example.Example`
    objects into a batch. It can also apply processing functions to specified keys
    in the batch before returning it.

    Parameters
    ----------
    batch_processors : Optional[dict[str, Callable[[Any], Any]]], optional, default=None
        Dictionary of callables to apply to the batch before returning it.

    Raises
    ------
    ValueError
        If the batch processor for a key does not return a dictionary with the
        key in it.
    """  # noqa: W505

    #: Dictionary of callables to apply to the batch before returning it.
    #: The key is the name of the key in the batch, and the value is the processing
    #: function to apply to the key. The processing function must take a single
    #: argument and return a single value. If the processing function returns
    #: a dictionary, it must contain the key that was processed in it (all the
    #: other keys will also be included in the batch).
    batch_processors: Optional[dict[str, Callable[[Any], Any]]] = None

    def __call__(self, examples: list[Example]) -> dict[str, Any]:
        """Collate a list of `Example` objects and apply processing functions."""
        batch = collate_example_list(examples)

        if self.batch_processors is not None:
            for key, processor in self.batch_processors.items():
                batch_key: str = key
                if Modalities.has_modality(key):
                    batch_key = Modalities.get_modality(key).name

                if batch_key in batch:
                    batch_processed = processor(batch[batch_key])
                    if isinstance(batch_processed, Mapping):
                        if batch_key not in batch_processed:
                            raise ValueError(
                                f"Batch processor for '{key}' key must return a dictionary "
                                f"with '{batch_key}' in it."
                            )
                        batch.update(batch_processed)
                    else:
                        batch[batch_key] = batch_processed

        return batch
__call__
__call__(examples)

Collate a list of Example objects and apply processing functions.

Source code in mmlearn/datasets/core/data_collator.py
def __call__(self, examples: list[Example]) -> dict[str, Any]:
    """Collate a list of `Example` objects and apply processing functions."""
    batch = collate_example_list(examples)

    if self.batch_processors is not None:
        for key, processor in self.batch_processors.items():
            batch_key: str = key
            if Modalities.has_modality(key):
                batch_key = Modalities.get_modality(key).name

            if batch_key in batch:
                batch_processed = processor(batch[batch_key])
                if isinstance(batch_processed, Mapping):
                    if batch_key not in batch_processed:
                        raise ValueError(
                            f"Batch processor for '{key}' key must return a dictionary "
                            f"with '{batch_key}' in it."
                        )
                    batch.update(batch_processed)
                else:
                    batch[batch_key] = batch_processed

    return batch

collate_example_list

collate_example_list(examples)

Collate a list of :py:class:~mmlearn.datasets.core.example.Example objects into a batch.

Parameters:

Name Type Description Default
examples list[Example]

list of examples to collate.

required

Returns:

Type Description
dict[str, Any]

Dictionary of batched examples.

Source code in mmlearn/datasets/core/data_collator.py
def collate_example_list(examples: list[Example]) -> dict[str, Any]:
    """Collate a list of :py:class:`~mmlearn.datasets.core.example.Example` objects into a batch.

    Parameters
    ----------
    examples : list[Example]
        list of examples to collate.

    Returns
    -------
    dict[str, Any]
        Dictionary of batched examples.

    """  # noqa: W505
    return _collate_example_dict(_merge_examples(examples))

example

Module for example-related classes and functions.

Example

Bases: OrderedDict[Any, Any]

A representation of a single example from a dataset.

This class is a subclass of :py:class:~collections.OrderedDict and provides attribute-style access. This means that example["text"] and example.text are equivalent. All datasets in this library return examples as :py:class:~mmlearn.datasets.core.example.Example objects.

Parameters:

Name Type Description Default
init_dict Optional[MutableMapping[Hashable, Any]]

Dictionary to init Example class with.

None

Examples:

>>> example = Example({"text": torch.tensor(2)})
>>> example.text.zero_()
tensor(0)
>>> example.context = torch.tensor(4)  # set custom attributes after initialization
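
Building on the example above, here is a short sketch of create_ids with toy values; in normal use the two indices are set by the dataset and by CombinedDataset rather than by hand.

import torch

from mmlearn.datasets.core.example import Example

ex = Example({"text": torch.tensor([1, 2, 3])})
ex.example_index = 7   # normally set/returned by the dataset
ex.dataset_index = 0   # normally set by CombinedDataset
ex.create_ids()
print(ex.example_ids["text"])  # tensor([0, 7]) -> (dataset_index, example_index)
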
Source code in mmlearn/datasets/core/example.py
class Example(OrderedDict[Any, Any]):
    """A representation of a single example from a dataset.

    This class is a subclass of :py:class:`~collections.OrderedDict` and provides
    attribute-style access. This means that `example["text"]` and `example.text`
    are equivalent. All datasets in this library return examples as
    :py:class:`~mmlearn.datasets.core.example.Example` objects.


    Parameters
    ----------
    init_dict : Optional[MutableMapping[Hashable, Any]], optional, default=None
        Dictionary to init `Example` class with.

    Examples
    --------
    >>> example = Example({"text": torch.tensor(2)})
    >>> example.text.zero_()
    tensor(0)
    >>> example.context = torch.tensor(4)  # set custom attributes after initialization
    """

    def __init__(
        self,
        init_dict: Optional[MutableMapping[Hashable, Any]] = None,
    ) -> None:
        if init_dict is None:
            init_dict = {}
        super().__init__(init_dict)

    def create_ids(self) -> None:
        """Create a unique id for the example from the dataset and example index.

        This method combines the dataset index and example index to create an
        attribute called `example_ids`, which is a dictionary of tensors. The
        dictionary keys are all the keys in the example except for `example_ids`,
        `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
        containing the tuple `(dataset_index, example_index)` for each key.
        The `example_ids` is used to (re-)identify pairs of examples from different
        modalities after they have been combined into a batch.

        Warns
        -----
        UserWarning
            If the `example_index` and `dataset_index` attributes are not set.

        Notes
        -----
        - The Example must have the following attributes set before calling this
          this method: `example_index` (usually set/returned by the dataset) and
          `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
        - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
          function can be used to find matching examples given two tensors of example ids.

        """  # noqa: W505
        if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
            self.example_ids = {
                key: torch.tensor([self.dataset_index, self.example_index])
                for key in self.keys()
                if key not in ("example_ids", "example_index", "dataset_index")
            }
        else:
            rank_zero_warn(
                "Cannot create `example_ids` without `example_index` and `dataset_index` "
                "attributes. Set these attributes before calling `create_ids`. "
                "No `example_ids` was created.",
                stacklevel=2,
                category=UserWarning,
            )

    def __getattr__(self, key: str) -> Any:
        """Get attribute by key."""
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key: str, value: Any) -> None:
        """Set attribute by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        self[key] = value

    def __setitem__(self, key: Hashable, value: Any) -> None:
        """Set item by key."""
        if isinstance(value, MutableMapping):
            value = Example(value)
        super().__setitem__(key, value)
create_ids
create_ids()

Create a unique id for the example from the dataset and example index.

This method combines the dataset index and example index to create an attribute called example_ids, which is a dictionary of tensors. The dictionary keys are all the keys in the example except for example_ids, example_index, and dataset_index. The values are tensors of shape (2,) containing the tuple (dataset_index, example_index) for each key. The example_ids attribute is used to (re-)identify pairs of examples from different modalities after they have been combined into a batch.

Warns:

Type Description
UserWarning

If the example_index and dataset_index attributes are not set.

Notes
  • The Example must have the following attributes set before calling this method: example_index (usually set/returned by the dataset) and dataset_index (usually set by the :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset object)
  • The :py:func:~mmlearn.datasets.core.example.find_matching_indices function can be used to find matching examples given two tensors of example ids.
Source code in mmlearn/datasets/core/example.py
def create_ids(self) -> None:
    """Create a unique id for the example from the dataset and example index.

    This method combines the dataset index and example index to create an
    attribute called `example_ids`, which is a dictionary of tensors. The
    dictionary keys are all the keys in the example except for `example_ids`,
    `example_index`, and `dataset_index`. The values are tensors of shape `(2,)`
    containing the tuple `(dataset_index, example_index)` for each key.
    The `example_ids` is used to (re-)identify pairs of examples from different
    modalities after they have been combined into a batch.

    Warns
    -----
    UserWarning
        If the `example_index` and `dataset_index` attributes are not set.

    Notes
    -----
    - The Example must have the following attributes set before calling this
      this method: `example_index` (usually set/returned by the dataset) and
      `dataset_index` (usually set by the :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset` object)
    - The :py:func:`~mmlearn.datasets.core.example.find_matching_indices`
      function can be used to find matching examples given two tensors of example ids.

    """  # noqa: W505
    if hasattr(self, "example_index") and hasattr(self, "dataset_index"):
        self.example_ids = {
            key: torch.tensor([self.dataset_index, self.example_index])
            for key in self.keys()
            if key not in ("example_ids", "example_index", "dataset_index")
        }
    else:
        rank_zero_warn(
            "Cannot create `example_ids` without `example_index` and `dataset_index` "
            "attributes. Set these attributes before calling `create_ids`. "
            "No `example_ids` was created.",
            stacklevel=2,
            category=UserWarning,
        )
__getattr__
__getattr__(key)

Get attribute by key.

Source code in mmlearn/datasets/core/example.py
def __getattr__(self, key: str) -> Any:
    """Get attribute by key."""
    try:
        return self[key]
    except KeyError:
        raise AttributeError(key) from None
__setattr__
__setattr__(key, value)

Set attribute by key.

Source code in mmlearn/datasets/core/example.py
def __setattr__(self, key: str, value: Any) -> None:
    """Set attribute by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    self[key] = value
__setitem__
__setitem__(key, value)

Set item by key.

Source code in mmlearn/datasets/core/example.py
def __setitem__(self, key: Hashable, value: Any) -> None:
    """Set item by key."""
    if isinstance(value, MutableMapping):
        value = Example(value)
    super().__setitem__(key, value)

find_matching_indices

find_matching_indices(
    first_example_ids, second_example_ids
)

Find the indices of matching examples given two tensors of example ids.

Matching examples are defined as examples with the same value in both tensors. This method is useful for finding pairs of examples from different modalities that are related to each other in a batch.

Parameters:

Name Type Description Default
first_example_ids Tensor

A tensor of example ids of shape (N, 2), where N is the number of examples.

required
second_example_ids Tensor

A tensor of example ids of shape (M, 2), where M is the number of examples.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of tensors containing the indices of matching examples in the first and second tensor, respectively.

Raises:

Type Description
TypeError

If either first_example_ids or second_example_ids is not a tensor.

ValueError

If either first_example_ids or second_example_ids is not a 2D tensor with the second dimension having a size of 2.

Examples:

>>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
>>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
>>> find_matching_indices(img_example_ids, text_example_ids)
(tensor([2, 3]), tensor([0, 1]))
Source code in mmlearn/datasets/core/example.py
def find_matching_indices(
    first_example_ids: torch.Tensor, second_example_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find the indices of matching examples given two tensors of example ids.

    Matching examples are defined as examples with the same value in both tensors.
    This method is useful for finding pairs of examples from different modalities
    that are related to each other in a batch.

    Parameters
    ----------
    first_example_ids : torch.Tensor
        A tensor of example ids of shape `(N, 2)`, where `N` is the number of examples.
    second_example_ids : torch.Tensor
        A tensor of example ids of shape `(M, 2)`, where `M` is the number of examples.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        A tuple of tensors containing the indices of matching examples in the first and
        second tensor, respectively.

    Raises
    ------
    TypeError
        If either `first_example_ids` or `second_example_ids` is not a tensor.
    ValueError
        If either `first_example_ids` or `second_example_ids` is not a 2D tensor
        with the second dimension having a size of `2`.

    Examples
    --------
    >>> img_example_ids = torch.tensor([(0, 0), (0, 1), (1, 0), (1, 1)])
    >>> text_example_ids = torch.tensor([(1, 0), (1, 1), (2, 0), (2, 1), (2, 2)])
    >>> find_matching_indices(img_example_ids, text_example_ids)
    (tensor([2, 3]), tensor([0, 1]))


    """
    if not isinstance(first_example_ids, torch.Tensor) or not isinstance(
        second_example_ids,
        torch.Tensor,
    ):
        raise TypeError(
            f"Expected inputs to be tensors, but got {type(first_example_ids)} "
            f"and {type(second_example_ids)}.",
        )
    val = 2
    if not (first_example_ids.ndim == val and first_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `first_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {first_example_ids.shape}.",
        )
    if not (second_example_ids.ndim == val and second_example_ids.shape[1] == val):
        raise ValueError(
            "Expected argument `second_example_ids` to be a tensor of shape (N, 2), "
            f"but got shape {second_example_ids.shape}.",
        )

    first_example_ids = first_example_ids.unsqueeze(1)  # shape=(N, 1, 2)
    second_example_ids = second_example_ids.unsqueeze(0)  # shape=(1, M, 2)

    # compare all elements; results in a shape (N, M) tensor
    matches = torch.all(first_example_ids == second_example_ids, dim=-1)
    first_indices, second_indices = torch.where(matches)
    return first_indices, second_indices

modalities

Module for managing supported modalities in the library.

Modality dataclass

A representation of a modality in the library.

This class is used to represent a modality in the library. It contains the name of the modality and the properties that can be associated with it. The properties are dynamically generated based on the name of the modality and can be accessed as attributes of the class.

Parameters:

Name Type Description Default
name str

The name of the modality.

required
modality_specific_properties Optional[dict[str, str]]

Additional properties specific to the modality, by default None

None

Raises:

Type Description
ValueError

If the property already exists for the modality or if the format string is invalid.
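
A brief sketch of the generated property names follows. The "genomic" modality name and the extra "counts" property are made up for illustration only.

from mmlearn.datasets.core.modalities import Modality

m = Modality("genomic", modality_specific_properties={"counts": "{}_counts"})
print(m.name)       # genomic
print(m.target)     # genomic_target
print(m.embedding)  # genomic_embedding
print(m.counts)     # genomic_counts (from the custom format string)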

Source code in mmlearn/datasets/core/modalities.py
@dataclass
class Modality:
    """A representation of a modality in the library.

    This class is used to represent a modality in the library. It contains the name of
    the modality and the properties that can be associated with it. The properties are
    dynamically generated based on the name of the modality and can be accessed as
    attributes of the class.

    Parameters
    ----------
    name : str
        The name of the modality.
    modality_specific_properties : Optional[dict[str, str]], optional, default=None
        Additional properties specific to the modality, by default None

    Raises
    ------
    ValueError
        If the property already exists for the modality or if the format string is
        invalid.
    """

    #: The name of the modality.
    name: str

    #: Target/label associated with the modality. This will return ``name_target``.
    target: str = field(init=False, repr=False)

    #: Attention mask associated with the modality. This will return
    # ``name_attention_mask``.
    attention_mask: str = field(init=False, repr=False)

    #: Input mask associated with the modality. This will return ``name_mask``.
    mask: str = field(init=False, repr=False)

    #: Embedding associated with the modality. This will return ``name_embedding``.
    embedding: str = field(init=False, repr=False)

    #: Masked embedding associated with the modality. This will return
    # ``name_masked_embedding``.
    masked_embedding: str = field(init=False, repr=False)

    #: Embedding from an Exponential Moving Average (EMA) encoder associated with
    #: the modality.
    ema_embedding: str = field(init=False, repr=False)

    #: Other properties specific to the modality.
    modality_specific_properties: Optional[dict[str, str]] = field(
        default=None, repr=False
    )

    def __post_init__(self) -> None:
        """Initialize the modality with the name and properties."""
        self.name = self.name.lower()
        self._properties = {}

        for field_name in self.__dataclass_fields__:
            if field_name not in ("name", "modality_specific_properties"):
                field_value = f"{self.name}_{field_name}"
                self._properties[field_name] = field_value
                setattr(self, field_name, field_value)

        if self.modality_specific_properties is not None:
            for (
                property_name,
                format_string,
            ) in self.modality_specific_properties.items():
                self.add_property(property_name, format_string)

    @property
    def properties(self) -> dict[str, str]:
        """Return the properties associated with the modality."""
        return self._properties

    def add_property(self, name: str, format_string: str) -> None:
        """Add a new property to the modality.

        Parameters
        ----------
        name : str
            The name of the property.
        format_string : str
            The format string for the property. The format string should contain a
            placeholder that will be replaced with the name of the modality when the
            property is accessed.

        Warns
        -----
        UserWarning
            If the property already exists for the modality. It will overwrite the
            existing property.

        Raises
        ------
        ValueError
            If `format_string` is invalid. A valid format string contains at least one
            placeholder enclosed in curly braces.
        """
        if name in self._properties:
            warnings.warn(
                f"Property '{name}' already exists for modality '{super().__str__()}'."
                "Will overwrite the existing property.",
                category=UserWarning,
                stacklevel=2,
            )

        if not _is_format_string(format_string):
            raise ValueError(
                f"Invalid format string '{format_string}' for property "
                f"'{name}' of modality '{super().__str__()}'."
            )

        self._properties[name] = format_string.format(self.name)
        setattr(self, name, self._properties[name])

    def __str__(self) -> str:
        """Return the object as a string."""
        return self.name.lower()
properties property
properties

Return the properties associated with the modality.

__post_init__
__post_init__()

Initialize the modality with the name and properties.

Source code in mmlearn/datasets/core/modalities.py
def __post_init__(self) -> None:
    """Initialize the modality with the name and properties."""
    self.name = self.name.lower()
    self._properties = {}

    for field_name in self.__dataclass_fields__:
        if field_name not in ("name", "modality_specific_properties"):
            field_value = f"{self.name}_{field_name}"
            self._properties[field_name] = field_value
            setattr(self, field_name, field_value)

    if self.modality_specific_properties is not None:
        for (
            property_name,
            format_string,
        ) in self.modality_specific_properties.items():
            self.add_property(property_name, format_string)
add_property
add_property(name, format_string)

Add a new property to the modality.

Parameters:

Name Type Description Default
name str

The name of the property.

required
format_string str

The format string for the property. The format string should contain a placeholder that will be replaced with the name of the modality when the property is accessed.

required

Warns:

Type Description
UserWarning

If the property already exists for the modality. It will overwrite the existing property.

Raises:

Type Description
ValueError

If format_string is invalid. A valid format string contains at least one placeholder enclosed in curly braces.

Source code in mmlearn/datasets/core/modalities.py
def add_property(self, name: str, format_string: str) -> None:
    """Add a new property to the modality.

    Parameters
    ----------
    name : str
        The name of the property.
    format_string : str
        The format string for the property. The format string should contain a
        placeholder that will be replaced with the name of the modality when the
        property is accessed.

    Warns
    -----
    UserWarning
        If the property already exists for the modality. It will overwrite the
        existing property.

    Raises
    ------
    ValueError
        If `format_string` is invalid. A valid format string contains at least one
        placeholder enclosed in curly braces.
    """
    if name in self._properties:
        warnings.warn(
            f"Property '{name}' already exists for modality '{super().__str__()}'."
            "Will overwrite the existing property.",
            category=UserWarning,
            stacklevel=2,
        )

    if not _is_format_string(format_string):
        raise ValueError(
            f"Invalid format string '{format_string}' for property "
            f"'{name}' of modality '{super().__str__()}'."
        )

    self._properties[name] = format_string.format(self.name)
    setattr(self, name, self._properties[name])
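
As an illustrative sketch of the format-string mechanism (the property name "special_token" and its format string are made up for this example, and the module-level Modalities registry instance is assumed to be exposed by mmlearn.datasets.core.modalities, as the tokenizer code further down this page suggests):

from mmlearn.datasets.core.modalities import Modalities

# fetch the built-in "text" modality and attach a hypothetical property to it;
# the "{}" placeholder is replaced with the modality name when the property is stored
text = Modalities.get_modality("text")
text.add_property("special_token", "{}_special_token")

print(text.properties["special_token"])  # -> "text_special_token"
print(text.special_token)                # the property is also set as an attribute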
__str__
__str__()

Return the object as a string.

Source code in mmlearn/datasets/core/modalities.py
def __str__(self) -> str:
    """Return the object as a string."""
    return self.name.lower()

ModalityRegistry

Modality registry.

A singleton class that manages the supported modalities (and their properties) in the library. The class provides methods to add new modalities and properties, and to access the existing modalities. The class is implemented as a singleton to ensure that there is only one instance of the registry in the library.

Source code in mmlearn/datasets/core/modalities.py
class ModalityRegistry:
    """Modality registry.

    A singleton class that manages the supported modalities (and their properties) in
    the library. The class provides methods to add new modalities and properties, and
    to access the existing modalities. The class is implemented as a singleton to
    ensure that there is only one instance of the registry in the library.
    """

    _instance: ClassVar[Any] = None
    _modality_registry: dict[str, Modality] = {}

    def __new__(cls) -> Self:
        """Create a new instance of the class if it does not exist."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._modality_registry = {}
        return cls._instance  # type: ignore[no-any-return]

    def register_modality(
        self, name: str, modality_specific_properties: Optional[dict[str, str]] = None
    ) -> None:
        """Add a new modality to the registry.

        Parameters
        ----------
        name : str
            The name of the modality.
        modality_specific_properties : Optional[dict[str, str]], optional, default=None
            Additional properties specific to the modality.

        Warns
        -----
        UserWarning
            If the modality already exists in the registry. It will overwrite the
            existing modality.

        """
        if name.lower() in self._modality_registry:
            warnings.warn(
                f"Modality '{name}' already exists in the registry. Overwriting...",
                category=UserWarning,
                stacklevel=2,
            )

        name = name.lower()
        modality = Modality(name, modality_specific_properties)
        self._modality_registry[name] = modality
        setattr(self, name, modality)

    def add_default_property(self, name: str, format_string: str) -> None:
        """Add a new property that is applicable to all modalities.

        Parameters
        ----------
        name : str
            The name of the property.
        format_string : str
            The format string for the property. The format string should contain a
            placeholder that will be replaced with the name of the modality when the
            property is accessed.

        Warns
        -----
        UserWarning
            If the property already exists for the default properties. It will
            overwrite the existing property.

        Raises
        ------
        ValueError
            If the format string is invalid. A valid format string contains at least one
            placeholder enclosed in curly braces.
        """
        for modality in self._modality_registry.values():
            modality.add_property(name, format_string)

    def has_modality(self, name: str) -> bool:
        """Check if the modality exists in the registry.

        Parameters
        ----------
        name : str
            The name of the modality.

        Returns
        -------
        bool
            True if the modality exists in the registry, False otherwise.
        """
        return name.lower() in self._modality_registry

    def get_modality(self, name: str) -> Modality:
        """Get the modality name from the registry.

        Parameters
        ----------
        name : str
            The name of the modality.

        Returns
        -------
        Modality
            The modality object from the registry.
        """
        return self._modality_registry[name.lower()]

    def get_modality_properties(self, name: str) -> dict[str, str]:
        """Get the properties of a modality from the registry.

        Parameters
        ----------
        name : str
            The name of the modality.

        Returns
        -------
        dict[str, str]
            The properties associated with the modality.
        """
        return self.get_modality(name).properties

    def list_modalities(self) -> list[Modality]:
        """Get the list of supported modalities in the registry.

        Returns
        -------
        list[Modality]
            The list of supported modalities in the registry.
        """
        return list(self._modality_registry.values())

    def __getattr__(self, name: str) -> Modality:
        """Access a modality as an attribute by its name."""
        if name.lower() in self._modality_registry:
            return self._modality_registry[name.lower()]
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )
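
A minimal usage sketch of the registry (the "rgb" modality and "prompt" property are hypothetical names used only for illustration; the module-level Modalities instance is assumed to be exported by mmlearn.datasets.core.modalities):

from mmlearn.datasets.core.modalities import Modalities

# register a new modality; names are stored lower-cased
Modalities.register_modality("rgb")
assert Modalities.has_modality("RGB")

# modalities can be fetched by name or accessed as attributes
rgb = Modalities.get_modality("rgb")
assert rgb is Modalities.rgb

# add a property to every modality currently in the registry
Modalities.add_default_property("prompt", "{}_prompt")
print(Modalities.get_modality_properties("rgb")["prompt"])  # -> "rgb_prompt"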
__new__
__new__()

Create a new instance of the class if it does not exist.

Source code in mmlearn/datasets/core/modalities.py
def __new__(cls) -> Self:
    """Create a new instance of the class if it does not exist."""
    if cls._instance is None:
        cls._instance = super().__new__(cls)
        cls._instance._modality_registry = {}
    return cls._instance  # type: ignore[no-any-return]
register_modality
register_modality(name, modality_specific_properties=None)

Add a new modality to the registry.

Parameters:

Name Type Description Default
name str

The name of the modality.

required
modality_specific_properties Optional[dict[str, str]]

Additional properties specific to the modality.

None

Warns:

Type Description
UserWarning

If the modality already exists in the registry; the existing modality will be overwritten.

Source code in mmlearn/datasets/core/modalities.py
def register_modality(
    self, name: str, modality_specific_properties: Optional[dict[str, str]] = None
) -> None:
    """Add a new modality to the registry.

    Parameters
    ----------
    name : str
        The name of the modality.
    modality_specific_properties : Optional[dict[str, str]], optional, default=None
        Additional properties specific to the modality.

    Warns
    -----
    UserWarning
        If the modality already exists in the registry. It will overwrite the
        existing modality.

    """
    if name.lower() in self._modality_registry:
        warnings.warn(
            f"Modality '{name}' already exists in the registry. Overwriting...",
            category=UserWarning,
            stacklevel=2,
        )

    name = name.lower()
    modality = Modality(name, modality_specific_properties)
    self._modality_registry[name] = modality
    setattr(self, name, modality)
add_default_property
add_default_property(name, format_string)

Add a new property that is applicable to all modalities.

Parameters:

Name Type Description Default
name str

The name of the property.

required
format_string str

The format string for the property. The format string should contain a placeholder that will be replaced with the name of the modality when the property is accessed.

required

Warns:

Type Description
UserWarning

If the property already exists for the default properties, the existing property will be overwritten.

Raises:

Type Description
ValueError

If the format string is invalid. A valid format string contains at least one placeholder enclosed in curly braces.

Source code in mmlearn/datasets/core/modalities.py
def add_default_property(self, name: str, format_string: str) -> None:
    """Add a new property that is applicable to all modalities.

    Parameters
    ----------
    name : str
        The name of the property.
    format_string : str
        The format string for the property. The format string should contain a
        placeholder that will be replaced with the name of the modality when the
        property is accessed.

    Warns
    -----
    UserWarning
        If the property already exists for the default properties. It will
        overwrite the existing property.

    Raises
    ------
    ValueError
        If the format string is invalid. A valid format string contains at least one
        placeholder enclosed in curly braces.
    """
    for modality in self._modality_registry.values():
        modality.add_property(name, format_string)
has_modality
has_modality(name)

Check if the modality exists in the registry.

Parameters:

Name Type Description Default
name str

The name of the modality.

required

Returns:

Type Description
bool

True if the modality exists in the registry, False otherwise.

Source code in mmlearn/datasets/core/modalities.py
def has_modality(self, name: str) -> bool:
    """Check if the modality exists in the registry.

    Parameters
    ----------
    name : str
        The name of the modality.

    Returns
    -------
    bool
        True if the modality exists in the registry, False otherwise.
    """
    return name.lower() in self._modality_registry
get_modality
get_modality(name)

Get the modality name from the registry.

Parameters:

Name Type Description Default
name str

The name of the modality.

required

Returns:

Type Description
Modality

The modality object from the registry.

Source code in mmlearn/datasets/core/modalities.py
def get_modality(self, name: str) -> Modality:
    """Get the modality name from the registry.

    Parameters
    ----------
    name : str
        The name of the modality.

    Returns
    -------
    Modality
        The modality object from the registry.
    """
    return self._modality_registry[name.lower()]
get_modality_properties
get_modality_properties(name)

Get the properties of a modality from the registry.

Parameters:

Name Type Description Default
name str

The name of the modality.

required

Returns:

Type Description
dict[str, str]

The properties associated with the modality.

Source code in mmlearn/datasets/core/modalities.py
def get_modality_properties(self, name: str) -> dict[str, str]:
    """Get the properties of a modality from the registry.

    Parameters
    ----------
    name : str
        The name of the modality.

    Returns
    -------
    dict[str, str]
        The properties associated with the modality.
    """
    return self.get_modality(name).properties
list_modalities
list_modalities()

Get the list of supported modalities in the registry.

Returns:

Type Description
list[Modality]

The list of supported modalities in the registry.

Source code in mmlearn/datasets/core/modalities.py
def list_modalities(self) -> list[Modality]:
    """Get the list of supported modalities in the registry.

    Returns
    -------
    list[Modality]
        The list of supported modalities in the registry.
    """
    return list(self._modality_registry.values())
__getattr__
__getattr__(name)

Access a modality as an attribute by its name.

Source code in mmlearn/datasets/core/modalities.py
def __getattr__(self, name: str) -> Modality:
    """Access a modality as an attribute by its name."""
    if name.lower() in self._modality_registry:
        return self._modality_registry[name.lower()]
    raise AttributeError(
        f"'{self.__class__.__name__}' object has no attribute '{name}'"
    )

samplers

Samplers for data loading.

CombinedDatasetRatioSampler

Bases: Sampler[int]

Sampler for weighted sampling from a :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset.

Parameters:

Name Type Description Default
dataset CombinedDataset

An instance of :py:class:~mmlearn.datasets.core.combined_dataset.CombinedDataset to sample from.

required
ratios Optional[Sequence[float]]

A sequence of ratios for sampling from each dataset in the combined dataset. The length of the sequence must be equal to the number of datasets in the combined dataset (dataset). If None, the length of each dataset in the combined dataset is used as the ratio. The ratios are normalized to sum to 1.

None
num_samples Optional[int]

The number of samples to draw from the combined dataset. If None, the sampler will draw as many samples as there are in the combined dataset. This number must yield at least one sample per dataset in the combined dataset, when multiplied by the corresponding ratio.

None
replacement bool

Whether to sample with replacement or not.

False
shuffle bool

Whether to shuffle the sampled indices or not. If False, the indices of each dataset will appear in the order they are stored in the combined dataset. This is similar to sequential sampling from each dataset. The datasets that make up the combined dataset are still sampled randomly.

True
rank Optional[int]

Rank of the current process within :attr:num_replicas. By default, :attr:rank is retrieved from the current distributed group.

None
num_replicas Optional[int]

Number of processes participating in distributed training. By default, :attr:num_replicas is retrieved from the current distributed group.

None
drop_last bool

Whether to drop the last incomplete batch or not. If True, the sampler will drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

False
seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

0

Attributes:

Name Type Description
dataset CombinedDataset

The dataset to sample from.

num_samples int

The number of samples to draw from the combined dataset.

probs Tensor

The probabilities for sampling from each dataset in the combined dataset. This is computed from the ratios argument and is normalized to sum to 1.

replacement bool

Whether to sample with replacement or not.

shuffle bool

Whether to shuffle the sampled indices or not.

rank int

Rank of the current process within :attr:num_replicas.

num_replicas int

Number of processes participating in distributed training.

drop_last bool

Whether to drop samples to make the number of samples evenly divisible by the number of replicas in distributed mode.

seed int

Random seed used when sampling from the combined dataset and shuffling the sampled indices.

epoch int

Current epoch number. This is used to set the random seed. This is useful in distributed mode to ensure that each process receives a different random ordering of the samples.

total_size int

The total number of samples across all processes.

Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class CombinedDatasetRatioSampler(Sampler[int]):
    """Sampler for weighted sampling from a :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`.

    Parameters
    ----------
    dataset : CombinedDataset
        An instance of :py:class:`~mmlearn.datasets.core.combined_dataset.CombinedDataset`
        to sample from.
    ratios : Optional[Sequence[float]], optional, default=None
        A sequence of ratios for sampling from each dataset in the combined dataset.
        The length of the sequence must be equal to the number of datasets in the
        combined dataset (`dataset`). If `None`, the length of each dataset in the
        combined dataset is used as the ratio. The ratios are normalized to sum to 1.
    num_samples : Optional[int], optional, default=None
        The number of samples to draw from the combined dataset. If `None`, the
        sampler will draw as many samples as there are in the combined dataset.
        This number must yield at least one sample per dataset in the combined
        dataset, when multiplied by the corresponding ratio.
    replacement : bool, default=False
        Whether to sample with replacement or not.
    shuffle : bool, default=True
        Whether to shuffle the sampled indices or not. If `False`, the indices of
        each dataset will appear in the order they are stored in the combined dataset.
        This is similar to sequential sampling from each dataset. The datasets
        that make up the combined dataset are still sampled randomly.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`num_replicas` is retrieved from the current distributed group.
    drop_last : bool, default=False
        Whether to drop the last incomplete batch or not. If `True`, the sampler will
        drop samples to make the number of samples evenly divisible by the number of
        replicas in distributed mode.
    seed : int, default=0
        Random seed used to when sampling from the combined dataset and shuffling
        the sampled indices.

    Attributes
    ----------
    dataset : CombinedDataset
        The dataset to sample from.
    num_samples : int
        The number of samples to draw from the combined dataset.
    probs : torch.Tensor
        The probabilities for sampling from each dataset in the combined dataset.
        This is computed from the `ratios` argument and is normalized to sum to 1.
    replacement : bool
        Whether to sample with replacement or not.
    shuffle : bool
        Whether to shuffle the sampled indices or not.
    rank : int
        Rank of the current process within :attr:`num_replicas`.
    num_replicas : int
        Number of processes participating in distributed training.
    drop_last : bool
        Whether to drop samples to make the number of samples evenly divisible by the
        number of replicas in distributed mode.
    seed : int
        Random seed used to when sampling from the combined dataset and shuffling
        the sampled indices.
    epoch : int
        Current epoch number. This is used to set the random seed. This is useful
        in distributed mode to ensure that each process receives a different random
        ordering of the samples.
    total_size : int
        The total number of samples across all processes.
    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        dataset: CombinedDataset,
        ratios: Optional[Sequence[float]] = None,
        num_samples: Optional[int] = None,
        replacement: bool = False,
        shuffle: bool = True,
        rank: Optional[int] = None,
        num_replicas: Optional[int] = None,
        drop_last: bool = False,
        seed: int = 0,
    ):
        if not isinstance(dataset, CombinedDataset):
            raise TypeError(
                "Expected argument `dataset` to be of type `CombinedDataset`, "
                f"but got {type(dataset)}.",
            )
        if not isinstance(seed, int):
            raise TypeError(
                f"Expected argument `seed` to be an integer, but got {type(seed)}.",
            )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.drop_last = drop_last
        self.replacement = replacement
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

        self._num_samples = num_samples
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "Expected argument `num_samples` to be a positive integer, but got "
                f"{self.num_samples}.",
            )

        if ratios is None:
            ratios = [len(subset) for subset in self.dataset.datasets]

        num_datasets = len(self.dataset.datasets)
        if len(ratios) != num_datasets:
            raise ValueError(
                f"Expected argument `ratios` to be of length {num_datasets}, "
                f"but got length {len(ratios)}.",
            )
        prob_sum = sum(ratios)
        if not all(ratio >= 0 for ratio in ratios) and prob_sum > 0:
            raise ValueError(
                "Expected argument `ratios` to be a sequence of non-negative numbers. "
                f"Got {ratios}.",
            )
        self.probs = torch.tensor(
            [ratio / prob_sum for ratio in ratios],
            dtype=torch.double,
        )
        if any((prob * self.num_samples) <= 0 for prob in self.probs):
            raise ValueError(
                "Expected dataset ratio to result in at least one sample per dataset. "
                f"Got dataset sizes {self.probs * self.num_samples}.",
            )

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        # dataset size might change at runtime
        if self._num_samples is None:
            num_samples = len(self.dataset)
        else:
            num_samples = self._num_samples

        if self.drop_last and num_samples % self.num_replicas != 0:
            # split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            num_samples = math.ceil(
                (num_samples - self.num_replicas) / self.num_replicas,
            )
        else:
            num_samples = math.ceil(num_samples / self.num_replicas)
        return num_samples

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return self.num_samples * self.num_replicas

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that yields sample indices for the combined dataset."""
        generator = torch.Generator()
        seed = self.seed + self.epoch
        generator.manual_seed(seed)

        cumulative_sizes = [0] + self.dataset._cumulative_sizes
        num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
        indices = []
        for i in range(len(self.dataset.datasets)):
            per_dataset_indices: torch.Tensor = torch.multinomial(
                torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
                num_samples_per_dataset[i],
                replacement=self.replacement,
                generator=generator,
            )
            # adjust indices to reflect position in cumulative dataset
            per_dataset_indices += cumulative_sizes[i]
            assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
                f"Indices from dataset {i} exceed dataset size. "
                f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
            )
            indices.append(per_dataset_indices)

        indices = torch.cat(indices)
        if self.shuffle:
            rand_indices = torch.randperm(len(indices), generator=generator)
            indices = indices[rand_indices]

        indices = indices.tolist()  # type: ignore[attr-defined]
        num_indices = len(indices)

        if num_indices < self.total_size:
            padding_size = self.total_size - num_indices
            if padding_size <= num_indices:
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / num_indices))[
                    :padding_size
                ]
        elif num_indices > self.total_size:
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples, (
            f"Expected {self.num_samples} samples, but got {len(indices)}.",
        )

        yield from iter(indices)

    def __len__(self) -> int:
        """Return the total number of samples in the sampler."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch

        # some iterable datasets (especially huggingface iterable datasets) might
        # require setting the epoch to ensure shuffling works properly
        for dataset in self.dataset.datasets:
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)
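
The sketch below shows the typical wiring of this sampler with a CombinedDataset and a DataLoader. The TensorDataset objects stand in for real datasets that yield Example objects, and rank/num_replicas are passed explicitly so the sketch does not require an initialized distributed process group:

import torch
from torch.utils.data import DataLoader, TensorDataset

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.samplers import CombinedDatasetRatioSampler

# two toy map-style datasets of different sizes
combined = CombinedDataset(
    [TensorDataset(torch.arange(100)), TensorDataset(torch.arange(20))]
)

# draw roughly 80%/20% of the samples from the two datasets; sampling with
# replacement allows the smaller dataset to be over-sampled
sampler = CombinedDatasetRatioSampler(
    combined,
    ratios=[0.8, 0.2],
    replacement=True,
    rank=0,
    num_replicas=1,
)
loader = DataLoader(combined, sampler=sampler, batch_size=8)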
num_samples property
num_samples

Return the number of samples managed by the sampler.

total_size property
total_size

Return the total size of the dataset.

__iter__
__iter__()

Return an iterator that yields sample indices for the combined dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that yields sample indices for the combined dataset."""
    generator = torch.Generator()
    seed = self.seed + self.epoch
    generator.manual_seed(seed)

    cumulative_sizes = [0] + self.dataset._cumulative_sizes
    num_samples_per_dataset = [int(prob * self.total_size) for prob in self.probs]
    indices = []
    for i in range(len(self.dataset.datasets)):
        per_dataset_indices: torch.Tensor = torch.multinomial(
            torch.ones(cumulative_sizes[i + 1] - cumulative_sizes[i]),
            num_samples_per_dataset[i],
            replacement=self.replacement,
            generator=generator,
        )
        # adjust indices to reflect position in cumulative dataset
        per_dataset_indices += cumulative_sizes[i]
        assert per_dataset_indices.max() < cumulative_sizes[i + 1], (
            f"Indices from dataset {i} exceed dataset size. "
            f"Got indices {per_dataset_indices} and dataset size {cumulative_sizes[i + 1]}.",
        )
        indices.append(per_dataset_indices)

    indices = torch.cat(indices)
    if self.shuffle:
        rand_indices = torch.randperm(len(indices), generator=generator)
        indices = indices[rand_indices]

    indices = indices.tolist()  # type: ignore[attr-defined]
    num_indices = len(indices)

    if num_indices < self.total_size:
        padding_size = self.total_size - num_indices
        if padding_size <= num_indices:
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / num_indices))[
                :padding_size
            ]
    elif num_indices > self.total_size:
        indices = indices[: self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples, (
        f"Expected {self.num_samples} samples, but got {len(indices)}.",
    )

    yield from iter(indices)
__len__
__len__()

Return the total number of samples in the sampler.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the total number of samples in the sampler."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When :attr:shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

    # some iterable datasets (especially huggingface iterable datasets) might
    # require setting the epoch to ensure shuffling works properly
    for dataset in self.dataset.datasets:
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)

DistributedEvalSampler

Bases: Sampler[int]

Sampler for distributed evaluation.

The main differences between this and :py:class:torch.utils.data.DistributedSampler are that this sampler does not add extra samples to make it evenly divisible and shuffling is disabled by default.

Parameters:

Name Type Description Default
dataset Dataset

Dataset used for sampling.

required
num_replicas Optional[int]

Number of processes participating in distributed training. By default, :attr:num_replicas is retrieved from the current distributed group.

None
rank Optional[int]

Rank of the current process within :attr:num_replicas. By default, :attr:rank is retrieved from the current distributed group.

None
shuffle bool

If True (default), sampler will shuffle the indices.

False
seed int

Random seed used to shuffle the sampler if :attr:shuffle=True. This number should be identical across all processes in the distributed group.

0
Warnings

DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever. See [1]_ for details.

Notes
  • This sampler is for evaluation purposes, where synchronization does not happen every epoch. Synchronization should be done outside the dataloader loop. It is especially useful in conjunction with :py:class:torch.nn.parallel.DistributedDataParallel [2]_.
  • The input Dataset is assumed to be of constant size.
  • This implementation is adapted from [3]_.
References

.. [1] https://github.com/pytorch/pytorch/issues/22584
.. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
.. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py

Examples:

>>> def example():
...     start_epoch, n_epochs = 0, 2
...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
...     for epoch in range(start_epoch, n_epochs):
...         if is_distributed:
...             sampler.set_epoch(epoch)
...         evaluate(loader)
Source code in mmlearn/datasets/core/samplers.py
@store(group="dataloader/sampler", provider="mmlearn")
class DistributedEvalSampler(Sampler[int]):
    """Sampler for distributed evaluation.

    The main differences between this and :py:class:`torch.utils.data.DistributedSampler`
    are that this sampler does not add extra samples to make it evenly divisible and
    shuffling is disabled by default.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset used for sampling.
    num_replicas : Optional[int], optional, default=None
        Number of processes participating in distributed training. By
        default, :attr:`rank` is retrieved from the current distributed group.
    rank : Optional[int], optional, default=None
        Rank of the current process within :attr:`num_replicas`. By default,
        :attr:`rank` is retrieved from the current distributed group.
    shuffle : bool, optional, default=False
        If `True` (default), sampler will shuffle the indices.
    seed : int, optional, default=0
        Random seed used to shuffle the sampler if :attr:`shuffle=True`.
        This number should be identical across all processes in the
        distributed group.

    Warnings
    --------
    DistributedEvalSampler should NOT be used for training. The distributed processes
    could hang forever. See [1]_ for details

    Notes
    -----
    - This sampler is for evaluation purpose where synchronization does not happen
      every epoch. Synchronization should be done outside the dataloader loop.
      It is especially useful in conjunction with
      :py:class:`torch.nn.parallel.DistributedDataParallel` [2]_.
    - The input Dataset is assumed to be of constant size.
    - This implementation is adapted from [3]_.

    References
    ----------
    .. [1] https://github.com/pytorch/pytorch/issues/22584
    .. [2] https://discuss.pytorch.org/t/how-to-validate-in-distributeddataparallel-correctly/94267/11
    .. [3] https://github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py


    Examples
    --------
    >>> def example():
    ...     start_epoch, n_epochs = 0, 2
    ...     sampler = DistributedEvalSampler(dataset) if is_distributed else None
    ...     loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
    ...     for epoch in range(start_epoch, n_epochs):
    ...         if is_distributed:
    ...             sampler.set_epoch(epoch)
    ...         evaluate(loader)

    """  # noqa: W505

    def __init__(
        self,
        dataset: Dataset[Sized],
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        seed: int = 0,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed

    @property
    def total_size(self) -> int:
        """Return the total size of the dataset."""
        return len(self.dataset)

    @property
    def num_samples(self) -> int:
        """Return the number of samples managed by the sampler."""
        indices = list(range(self.total_size))[
            self.rank : self.total_size : self.num_replicas
        ]
        return len(indices)  # true value without extra samples

    def __iter__(self) -> Iterator[int]:
        """Return an iterator that iterates over the indices of the dataset."""
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(self.total_size, generator=g).tolist()
        else:
            indices = list(range(self.total_size))

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Parameters
        ----------
        epoch : int
            Epoch number.

        """
        self.epoch = epoch
total_size property
total_size

Return the total size of the dataset.

num_samples property
num_samples

Return the number of samples managed by the sampler.

__iter__
__iter__()

Return an iterator that iterates over the indices of the dataset.

Source code in mmlearn/datasets/core/samplers.py
def __iter__(self) -> Iterator[int]:
    """Return an iterator that iterates over the indices of the dataset."""
    if self.shuffle:
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()
    else:
        indices = list(range(self.total_size))

    # subsample
    indices = indices[self.rank : self.total_size : self.num_replicas]
    assert len(indices) == self.num_samples

    return iter(indices)
__len__
__len__()

Return the number of samples.

Source code in mmlearn/datasets/core/samplers.py
def __len__(self) -> int:
    """Return the number of samples."""
    return self.num_samples
set_epoch
set_epoch(epoch)

Set the epoch for this sampler.

When :attr:shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
Source code in mmlearn/datasets/core/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Set the epoch for this sampler.

    When :attr:`shuffle=True`, this ensures all replicas use a different random
    ordering for each epoch. Otherwise, the next iteration of this sampler
    will yield the same ordering.

    Parameters
    ----------
    epoch : int
        Epoch number.

    """
    self.epoch = epoch

Dataset Processors

mmlearn.datasets.processors

Data processors.

BlockwiseImagePatchMaskGenerator

Blockwise image patch mask generator.

This is primarily intended for the data2vec method.

Parameters:

Name Type Description Default
input_size Union[int, tuple[int, int]]

The size of the input image. If an integer is provided, the image is assumed to be square.

required
num_masking_patches int

The number of patches to mask.

required
min_num_patches int

The minimum number of patches to mask.

4
max_num_patches int

The maximum number of patches to mask.

None
min_aspect_ratio float

The minimum aspect ratio of the patch.

0.3
max_aspect_ratio float

The maximum aspect ratio of the patch.

None
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn")
class BlockwiseImagePatchMaskGenerator:
    """Blockwise image patch mask generator.

    This is primarily intended for the data2vec method.

    Parameters
    ----------
    input_size : Union[int, tuple[int, int]]
        The size of the input image. If an integer is provided, the image is assumed
        to be square.
    num_masking_patches : int
        The number of patches to mask.
    min_num_patches : int, default=4
        The minimum number of patches to mask.
    max_num_patches : int, default=None
        The maximum number of patches to mask.
    min_aspect_ratio : float, default=0.3
        The minimum aspect ratio of the patch.
    max_aspect_ratio : float, default=None
        The maximum aspect ratio of the patch.
    """

    def __init__(
        self,
        input_size: Union[int, tuple[int, int]],
        num_masking_patches: int,
        min_num_patches: int = 4,
        max_num_patches: Any = None,
        min_aspect_ratio: float = 0.3,
        max_aspect_ratio: Any = None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = (
            num_masking_patches if max_num_patches is None else max_num_patches
        )

        max_aspect_ratio = max_aspect_ratio or 1 / min_aspect_ratio
        self.log_aspect_ratio = (math.log(min_aspect_ratio), math.log(max_aspect_ratio))

    def __repr__(self) -> str:
        """Generate a printable representation.

        Returns
        -------
        str
            A printable representation of the object.

        """
        return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self.max_num_patches,
            self.num_masking_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )

    def get_shape(self) -> tuple[int, int]:
        """Get the shape of the input.

        Returns
        -------
        tuple[int, int]
            The shape of the input as a tuple `(height, width)`.
        """
        return self.height, self.width

    def _mask(self, mask: torch.Tensor, max_mask_patches: int) -> int:
        """Masking function.

        This function mask adjacent patches by first selecting a target area and aspect
        ratio. Since, there might be overlap between selected areas  or the selected
        area might already be masked, it runs for a  maximum of 10 attempts or until the
        specified number of patches (max_mask_patches) is achieved.


        Parameters
        ----------
        mask: torch.Tensor
            Current mask. The mask to be updated.
        max_mask_patches: int
            The maximum number of patches to be masked.

        Returns
        -------
        delta: int
            The number of patches that were successfully masked.

        Notes
        -----
        - `target_area`: Randomly chosen target area for the patch.
        - `aspect_ratio`: Randomly chosen aspect ratio for the patch.
        - `h`: Height of the patch based on the target area and aspect ratio.
        - `w`: Width of the patch based on the target area and aspect ratio.
        - `top`: Randomly chosen top position for the patch.
        - `left`: Randomly chosen left position for the patch.
        - `num_masked`: Number of masked pixels within the proposed patch area.
        - `delta`: Accumulated count of modified pixels.
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top : top + h, left : left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self) -> torch.Tensor:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        mask: torch.Tensor
            A mask of shape (nb_patches, nb_patches)

        """
        mask = torch.zeros(self.get_shape(), dtype=torch.int)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask
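
A short usage sketch. Note that input_size is interpreted as the size of the patch grid (the generated mask has exactly this shape), so 14 here corresponds to, for example, a 224x224 image split into 16x16 patches:

from mmlearn.datasets.processors.masking import BlockwiseImagePatchMaskGenerator

mask_generator = BlockwiseImagePatchMaskGenerator(
    input_size=14,           # 14x14 patch grid
    num_masking_patches=75,  # target number of masked patches
)
mask = mask_generator()      # int tensor of shape (14, 14) with 0/1 entries
print(int(mask.sum()))       # at most 75; masking may stop early if no block fits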

__repr__

__repr__()

Generate a printable representation.

Returns:

Type Description
str

A printable representation of the object.

Source code in mmlearn/datasets/processors/masking.py
def __repr__(self) -> str:
    """Generate a printable representation.

    Returns
    -------
    str
        A printable representation of the object.

    """
    return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
        self.height,
        self.width,
        self.min_num_patches,
        self.max_num_patches,
        self.num_masking_patches,
        self.log_aspect_ratio[0],
        self.log_aspect_ratio[1],
    )

get_shape

get_shape()

Get the shape of the input.

Returns:

Type Description
tuple[int, int]

The shape of the input as a tuple (height, width).

Source code in mmlearn/datasets/processors/masking.py
def get_shape(self) -> tuple[int, int]:
    """Get the shape of the input.

    Returns
    -------
    tuple[int, int]
        The shape of the input as a tuple `(height, width)`.
    """
    return self.height, self.width

__call__

__call__()

Generate a random mask.

Returns a random mask of shape (nb_patches, nb_patches) based on the configuration where the number of patches to be masked is num_masking_patches.

Returns:

Name Type Description
mask Tensor

A mask of shape (nb_patches, nb_patches)

Source code in mmlearn/datasets/processors/masking.py
def __call__(self) -> torch.Tensor:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    mask: torch.Tensor
        A mask of shape (nb_patches, nb_patches)

    """
    mask = torch.zeros(self.get_shape(), dtype=torch.int)
    mask_count = 0
    while mask_count < self.num_masking_patches:
        max_mask_patches = self.num_masking_patches - mask_count
        max_mask_patches = min(max_mask_patches, self.max_num_patches)

        delta = self._mask(mask, max_mask_patches)
        if delta == 0:
            break
        mask_count += delta

    return mask

RandomMaskGenerator

Random mask generator.

Returns a random token mask over the input sequence, where each token is selected for masking with the configured probability. This is intended to be used for tasks like masked language modeling.

Parameters:

Name Type Description Default
probability float

Probability of masking a token.

required
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn", probability=0.15)
class RandomMaskGenerator:
    """Random mask generator.

    Returns a random mask of shape `(nb_patches, nb_patches)` based on the
    configuration where the number of patches to be masked is num_masking_patches.
    **This is intended to be used for tasks like masked language modeling.**

    Parameters
    ----------
    probability : float
        Probability of masking a token.
    """

    def __init__(self, probability: float):
        self.probability = probability

    def __call__(
        self,
        inputs: torch.Tensor,
        tokenizer: PreTrainedTokenizerBase,
        special_tokens_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        inputs : torch.Tensor
            The encoded inputs.
        tokenizer : PreTrainedTokenizer
            The tokenizer.
        special_tokens_mask : Optional[torch.Tensor], default=None
            Mask for special tokens.
        """
        inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training
        # (with probability `self.probability`)
        probability_matrix = torch.full(labels.shape, self.probability)
        if special_tokens_mask is None:
            special_tokens_mask = tokenizer.get_special_tokens_mask(
                labels, already_has_special_tokens=True
            )
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = tokenizer.pad_token_id
        # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # Rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, masked_indices
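
A minimal sketch of how the generator is called, assuming a HuggingFace tokenizer ("bert-base-uncased" is an arbitrary checkpoint chosen for illustration). The inputs argument is passed through tokenizer.pad internally, so an un-tensorized batch encoding is used here, and the special-tokens mask is supplied explicitly:

import torch
from transformers import AutoTokenizer

from mmlearn.datasets.processors.masking import RandomMaskGenerator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
masker = RandomMaskGenerator(probability=0.15)

encoding = tokenizer(["a photo of a cat"], return_special_tokens_mask=True)
special_tokens_mask = torch.tensor(encoding.pop("special_tokens_mask"))

# returns the masked input ids, the labels, and the boolean mask of masked positions
inputs, labels, masked_indices = masker(
    encoding, tokenizer, special_tokens_mask=special_tokens_mask
)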

__call__

__call__(inputs, tokenizer, special_tokens_mask=None)

Generate a random token mask for masked language modeling.

Tokens are selected for masking with the configured probability. Of the selected tokens, 80% are replaced with the tokenizer's mask token, 10% are replaced with a random token, and the remaining 10% are left unchanged.

Parameters:

Name Type Description Default
inputs Tensor

The encoded inputs.

required
tokenizer PreTrainedTokenizerBase

The tokenizer.

required
special_tokens_mask Optional[torch.Tensor]

Mask for special tokens.

None

Returns:

Type Description
tuple[Tensor, Tensor, Tensor]

The masked input ids, the labels, and the boolean mask indicating which positions were selected for masking.

Source code in mmlearn/datasets/processors/masking.py
def __call__(
    self,
    inputs: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
    special_tokens_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    inputs : torch.Tensor
        The encoded inputs.
    tokenizer : PreTrainedTokenizer
        The tokenizer.
    special_tokens_mask : Optional[torch.Tensor], default=None
        Mask for special tokens.
    """
    inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
    labels = inputs.clone()
    # We sample a few tokens in each sequence for MLM training
    # (with probability `self.probability`)
    probability_matrix = torch.full(labels.shape, self.probability)
    if special_tokens_mask is None:
        special_tokens_mask = tokenizer.get_special_tokens_mask(
            labels, already_has_special_tokens=True
        )
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
    else:
        special_tokens_mask = special_tokens_mask.bool()

    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = tokenizer.pad_token_id
    # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
    indices_replaced = (
        torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    )
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # Rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels, masked_indices

HFTokenizer

A wrapper for loading HuggingFace tokenizers.

This class wraps any huggingface tokenizer that can be initialized with :py:meth:transformers.AutoTokenizer.from_pretrained. It preprocesses the input text and returns a dictionary with the tokenized text and other relevant information like attention masks.

Parameters:

Name Type Description Default
model_name_or_path str

Pretrained model name or path - same as in :py:meth:transformers.AutoTokenizer.from_pretrained.

required
max_length Optional[int]

Maximum length of the tokenized sequence. This is passed to the tokenizer :meth:__call__ method.

None
padding bool or str

Padding strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

False
truncation Optional[Union[bool, str]]

Truncation strategy. Same as in :py:meth:transformers.AutoTokenizer.from_pretrained; passed to the tokenizer :meth:__call__ method.

None
**kwargs Any

Additional arguments passed to :py:meth:transformers.AutoTokenizer.from_pretrained.

{}
Source code in mmlearn/datasets/processors/tokenizers.py
@store(group="datasets/tokenizers", provider="mmlearn")
class HFTokenizer:
    """A wrapper for loading HuggingFace tokenizers.

    This class wraps any huggingface tokenizer that can be initialized with
    :py:meth:`transformers.AutoTokenizer.from_pretrained`. It preprocesses the
    input text and returns a dictionary with the tokenized text and other
    relevant information like attention masks.

    Parameters
    ----------
    model_name_or_path : str
        Pretrained model name or path - same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    max_length : Optional[int], optional, default=None
        Maximum length of the tokenized sequence. This is passed to the tokenizer
        :meth:`__call__` method.
    padding : bool or str, default=False
        Padding strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    truncation : Optional[Union[bool, str]], optional, default=None
        Truncation strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    **kwargs : Any
        Additional arguments passed to :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    """  # noqa: W505

    def __init__(
        self,
        model_name_or_path: str,
        max_length: Optional[int] = None,
        padding: Union[bool, str] = False,
        truncation: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def __call__(
        self, sentence: Union[str, list[str]], **kwargs: Any
    ) -> dict[str, torch.Tensor]:
        """Tokenize a text or a list of texts using the HuggingFace tokenizer.

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be tokenized.
        **kwargs : Any
            Additional arguments passed to the tokenizer :meth:`__call__` method.

        Returns
        -------
        dict[str, torch.Tensor]
            Tokenized sentence(s).

        Notes
        -----
        The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
        """
        batch_encoding = self.tokenizer(
            sentence,
            max_length=self.max_length,
            padding=self.padding,
            truncation=self.truncation,
            return_tensors="pt",
            **kwargs,
        )

        if isinstance(
            sentence, str
        ):  # remove batch dimension if input is a single sentence
            for key, value in batch_encoding.items():
                if isinstance(value, torch.Tensor):
                    batch_encoding[key] = torch.squeeze(value, 0)

        # use 'Modalities.TEXT' key for input_ids for consistency
        batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
        return dict(batch_encoding)

__call__

__call__(sentence, **kwargs)

Tokenize a text or a list of texts using the HuggingFace tokenizer.

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be tokenized.

required
**kwargs Any

Additional arguments passed to the tokenizer __call__ method.

{}

Returns:

Type Description
dict[str, Tensor]

Tokenized sentence(s).

Notes

The input_ids key is replaced with Modalities.TEXT for consistency.

Source code in mmlearn/datasets/processors/tokenizers.py
def __call__(
    self, sentence: Union[str, list[str]], **kwargs: Any
) -> dict[str, torch.Tensor]:
    """Tokenize a text or a list of texts using the HuggingFace tokenizer.

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be tokenized.
    **kwargs : Any
        Additional arguments passed to the tokenizer :meth:`__call__` method.

    Returns
    -------
    dict[str, torch.Tensor]
        Tokenized sentence(s).

    Notes
    -----
    The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
    """
    batch_encoding = self.tokenizer(
        sentence,
        max_length=self.max_length,
        padding=self.padding,
        truncation=self.truncation,
        return_tensors="pt",
        **kwargs,
    )

    if isinstance(
        sentence, str
    ):  # remove batch dimension if input is a single sentence
        for key, value in batch_encoding.items():
            if isinstance(value, torch.Tensor):
                batch_encoding[key] = torch.squeeze(value, 0)

    # use 'Modalities.TEXT' key for input_ids for consistency
    batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
    return dict(batch_encoding)
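
A minimal usage sketch, assuming network access to download the bert-base-uncased checkpoint; it illustrates that the batch dimension is squeezed out for a single sentence:

from mmlearn.datasets.processors.tokenizers import HFTokenizer

tokenizer = HFTokenizer(
    "bert-base-uncased", max_length=32, padding="max_length", truncation=True
)

single = tokenizer("a photo of a cat")
print(single["input_ids"].shape)   # torch.Size([32]) - batch dimension removed

batch = tokenizer(["a photo of a cat", "a photo of a dog"])
print(batch["input_ids"].shape)    # torch.Size([2, 32])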

TrimText

Trim text strings as a preprocessing step before tokenization.

Parameters:

Name Type Description Default
trim_size int

The maximum length of the trimmed text.

required
Source code in mmlearn/datasets/processors/transforms.py
@store(group="datasets/transforms", provider="mmlearn")
class TrimText:
    """Trim text strings as a preprocessing step before tokenization.

    Parameters
    ----------
    trim_size : int
        The maximum length of the trimmed text.
    """

    def __init__(self, trim_size: int) -> None:
        self.trim_size = trim_size

    def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
        """Trim the given sentence(s).

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be trimmed.

        Returns
        -------
        Union[str, list[str]]
            Trimmed sentence(s).

        Raises
        ------
        TypeError
            If the input sentence is not a string or list of strings.
        """
        if not isinstance(sentence, (list, str)):
            raise TypeError(
                "Expected argument `sentence` to be a string or list of strings, "
                f"but got {type(sentence)}"
            )

        if isinstance(sentence, str):
            return sentence[: self.trim_size]

        for i, s in enumerate(sentence):
            sentence[i] = s[: self.trim_size]

        return sentence

__call__

__call__(sentence)

Trim the given sentence(s).

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be trimmed.

required

Returns:

Type Description
Union[str, list[str]]

Trimmed sentence(s).

Raises:

Type Description
TypeError

If the input sentence is not a string or list of strings.

Source code in mmlearn/datasets/processors/transforms.py
def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
    """Trim the given sentence(s).

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be trimmed.

    Returns
    -------
    Union[str, list[str]]
        Trimmed sentence(s).

    Raises
    ------
    TypeError
        If the input sentence is not a string or list of strings.
    """
    if not isinstance(sentence, (list, str)):
        raise TypeError(
            "Expected argument `sentence` to be a string or list of strings, "
            f"but got {type(sentence)}"
        )

    if isinstance(sentence, str):
        return sentence[: self.trim_size]

    for i, s in enumerate(sentence):
        sentence[i] = s[: self.trim_size]

    return sentence
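
A quick sketch of the trimming behaviour; note that a list input is trimmed in place and returned:

from mmlearn.datasets.processors.transforms import TrimText

trim = TrimText(trim_size=10)
print(trim("a fairly long sentence"))            # 'a fairly l'
print(trim(["short", "another long sentence"]))  # ['short', 'another lo']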

masking

Token mask generators.

RandomMaskGenerator

Random mask generator.

Generates a random token mask over the input sequence, where each non-special token is masked with the given probability. This is intended to be used for tasks like masked language modeling.

Parameters:

Name Type Description Default
probability float

Probability of masking a token.

required
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn", probability=0.15)
class RandomMaskGenerator:
    """Random mask generator.

    Generates a random token mask over the input sequence, where each
    non-special token is masked with the given probability.
    **This is intended to be used for tasks like masked language modeling.**

    Parameters
    ----------
    probability : float
        Probability of masking a token.
    """

    def __init__(self, probability: float):
        self.probability = probability

    def __call__(
        self,
        inputs: torch.Tensor,
        tokenizer: PreTrainedTokenizerBase,
        special_tokens_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        inputs : torch.Tensor
            The encoded inputs.
        tokenizer : PreTrainedTokenizer
            The tokenizer.
        special_tokens_mask : Optional[torch.Tensor], default=None
            Mask for special tokens.
        """
        inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training
        # (with probability `self.probability`)
        probability_matrix = torch.full(labels.shape, self.probability)
        if special_tokens_mask is None:
            special_tokens_mask = tokenizer.get_special_tokens_mask(
                labels, already_has_special_tokens=True
            )
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = tokenizer.pad_token_id
        # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # Rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, masked_indices
__call__
__call__(inputs, tokenizer, special_tokens_mask=None)

Generate a random token mask for masked language modeling.

Parameters:

Name Type Description Default
inputs Tensor

The encoded inputs.

required
tokenizer PreTrainedTokenizerBase

The tokenizer used to pad the inputs and identify special tokens.

required
special_tokens_mask Optional[Tensor]

Mask for special tokens.

None

Returns:

Type Description
tuple[Tensor, Tensor, Tensor]

The (possibly modified) input ids, the labels and the boolean mask of the masked positions.

Source code in mmlearn/datasets/processors/masking.py
def __call__(
    self,
    inputs: torch.Tensor,
    tokenizer: PreTrainedTokenizerBase,
    special_tokens_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    inputs : torch.Tensor
        The encoded inputs.
    tokenizer : PreTrainedTokenizer
        The tokenizer.
    special_tokens_mask : Optional[torch.Tensor], default=None
        Mask for special tokens.
    """
    inputs = tokenizer.pad(inputs, return_tensors="pt")["input_ids"]
    labels = inputs.clone()
    # We sample a few tokens in each sequence for MLM training
    # (with probability `self.probability`)
    probability_matrix = torch.full(labels.shape, self.probability)
    if special_tokens_mask is None:
        special_tokens_mask = tokenizer.get_special_tokens_mask(
            labels, already_has_special_tokens=True
        )
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
    else:
        special_tokens_mask = special_tokens_mask.bool()

    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = tokenizer.pad_token_id
    # 80% of the time, replace masked input tokens with tokenizer.mask_token([MASK])
    indices_replaced = (
        torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    )
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # Rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels, masked_indices
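
A minimal sketch of masking a small batch for MLM, assuming network access to fetch the bert-base-uncased tokenizer; the special-tokens mask is passed in explicitly, obtained from the tokenizer via return_special_tokens_mask=True:

from transformers import AutoTokenizer

from mmlearn.datasets.processors.masking import RandomMaskGenerator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoding = tokenizer(
    ["a short sentence", "a slightly longer example sentence"],
    padding=True,
    return_special_tokens_mask=True,
    return_tensors="pt",
)
special_tokens_mask = encoding.pop("special_tokens_mask")

masker = RandomMaskGenerator(probability=0.15)
inputs, labels, masked = masker(
    encoding, tokenizer, special_tokens_mask=special_tokens_mask
)
print(inputs.shape, labels.shape, masked.shape)  # all (2, sequence_length)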

BlockwiseImagePatchMaskGenerator

Blockwise image patch mask generator.

This is primarily intended for the data2vec method.

Parameters:

Name Type Description Default
input_size Union[int, tuple[int, int]]

The size of the input image. If an integer is provided, the image is assumed to be square.

required
num_masking_patches int

The number of patches to mask.

required
min_num_patches int

The minimum number of patches to mask.

4
max_num_patches int

The maximum number of patches to mask.

None
min_aspect_ratio float

The minimum aspect ratio of the patch.

0.3
max_aspect_ratio float

The maximum aspect ratio of the patch.

None
Source code in mmlearn/datasets/processors/masking.py
@store(group="datasets/masking", provider="mmlearn")
class BlockwiseImagePatchMaskGenerator:
    """Blockwise image patch mask generator.

    This is primarily intended for the data2vec method.

    Parameters
    ----------
    input_size : Union[int, tuple[int, int]]
        The size of the input image. If an integer is provided, the image is assumed
        to be square.
    num_masking_patches : int
        The number of patches to mask.
    min_num_patches : int, default=4
        The minimum number of patches to mask.
    max_num_patches : int, default=None
        The maximum number of patches to mask.
    min_aspect_ratio : float, default=0.3
        The minimum aspect ratio of the patch.
    max_aspect_ratio : float, default=None
        The maximum aspect ratio of the patch.
    """

    def __init__(
        self,
        input_size: Union[int, tuple[int, int]],
        num_masking_patches: int,
        min_num_patches: int = 4,
        max_num_patches: Any = None,
        min_aspect_ratio: float = 0.3,
        max_aspect_ratio: Any = None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = (
            num_masking_patches if max_num_patches is None else max_num_patches
        )

        max_aspect_ratio = max_aspect_ratio or 1 / min_aspect_ratio
        self.log_aspect_ratio = (math.log(min_aspect_ratio), math.log(max_aspect_ratio))

    def __repr__(self) -> str:
        """Generate a printable representation.

        Returns
        -------
        str
            A printable representation of the object.

        """
        return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self.max_num_patches,
            self.num_masking_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )

    def get_shape(self) -> tuple[int, int]:
        """Get the shape of the input.

        Returns
        -------
        tuple[int, int]
            The shape of the input as a tuple `(height, width)`.
        """
        return self.height, self.width

    def _mask(self, mask: torch.Tensor, max_mask_patches: int) -> int:
        """Masking function.

        This function masks adjacent patches by first selecting a target area and
        aspect ratio. Since there might be overlap between selected areas, or the
        selected area might already be masked, it runs for a maximum of 10 attempts
        or until the specified number of patches (max_mask_patches) is reached.


        Parameters
        ----------
        mask: torch.Tensor
            Current mask. The mask to be updated.
        max_mask_patches: int
            The maximum number of patches to be masked.

        Returns
        -------
        delta: int
            The number of patches that were successfully masked.

        Notes
        -----
        - `target_area`: Randomly chosen target area for the patch.
        - `aspect_ratio`: Randomly chosen aspect ratio for the patch.
        - `h`: Height of the patch based on the target area and aspect ratio.
        - `w`: Width of the patch based on the target area and aspect ratio.
        - `top`: Randomly chosen top position for the patch.
        - `left`: Randomly chosen left position for the patch.
        - `num_masked`: Number of already-masked patches within the proposed patch area.
        - `delta`: Accumulated count of newly masked patches.
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top : top + h, left : left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self) -> torch.Tensor:
        """Generate a random mask.

        Returns a random mask of shape (nb_patches, nb_patches) based on the
        configuration where the number of patches to be masked is num_masking_patches.

        Returns
        -------
        mask: torch.Tensor
            A mask of shape (nb_patches, nb_patches)

        """
        mask = torch.zeros(self.get_shape(), dtype=torch.int)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask
__repr__
__repr__()

Generate a printable representation.

Returns:

Type Description
str

A printable representation of the object.

Source code in mmlearn/datasets/processors/masking.py
def __repr__(self) -> str:
    """Generate a printable representation.

    Returns
    -------
    str
        A printable representation of the object.

    """
    return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
        self.height,
        self.width,
        self.min_num_patches,
        self.max_num_patches,
        self.num_masking_patches,
        self.log_aspect_ratio[0],
        self.log_aspect_ratio[1],
    )
get_shape
get_shape()

Get the shape of the input.

Returns:

Type Description
tuple[int, int]

The shape of the input as a tuple (height, width).

Source code in mmlearn/datasets/processors/masking.py
def get_shape(self) -> tuple[int, int]:
    """Get the shape of the input.

    Returns
    -------
    tuple[int, int]
        The shape of the input as a tuple `(height, width)`.
    """
    return self.height, self.width
__call__
__call__()

Generate a random mask.

Returns a random mask of shape (nb_patches, nb_patches) based on the configuration where the number of patches to be masked is num_masking_patches.

Returns:

Name Type Description
mask Tensor

A mask of shape (nb_patches, nb_patches)

Source code in mmlearn/datasets/processors/masking.py
def __call__(self) -> torch.Tensor:
    """Generate a random mask.

    Returns a random mask of shape (nb_patches, nb_patches) based on the
    configuration where the number of patches to be masked is num_masking_patches.

    Returns
    -------
    mask: torch.Tensor
        A mask of shape (nb_patches, nb_patches)

    """
    mask = torch.zeros(self.get_shape(), dtype=torch.int)
    mask_count = 0
    while mask_count < self.num_masking_patches:
        max_mask_patches = self.num_masking_patches - mask_count
        max_mask_patches = min(max_mask_patches, self.max_num_patches)

        delta = self._mask(mask, max_mask_patches)
        if delta == 0:
            break
        mask_count += delta

    return mask
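
A minimal sketch for a 14x14 patch grid (e.g. a 224x224 image split into 16x16 patches):

from mmlearn.datasets.processors.masking import BlockwiseImagePatchMaskGenerator

mask_gen = BlockwiseImagePatchMaskGenerator(input_size=14, num_masking_patches=75)
mask = mask_gen()
print(mask.shape)       # torch.Size([14, 14])
print(int(mask.sum()))  # close to 75; never more, possibly fewer if no block fits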

IJEPAMaskGenerator dataclass

Generates encoder and predictor masks for preprocessing.

This class generates masks dynamically for batches of examples.

Parameters:

Name Type Description Default
input_size tuple[int, int]

Input image size.

(224, 224)
patch_size int

Size of each patch.

16
min_keep int

Minimum number of patches to keep.

10
allow_overlap bool

Whether to allow overlap between encoder and predictor masks.

False
enc_mask_scale tuple[float, float]

Scale range for encoder mask.

(0.85, 1.0)
pred_mask_scale tuple[float, float]

Scale range for predictor mask.

(0.15, 0.2)
aspect_ratio tuple[float, float]

Aspect ratio range for mask blocks.

(0.75, 1.5)
nenc int

Number of encoder masks to generate.

1
npred int

Number of predictor masks to generate.

4
Source code in mmlearn/datasets/processors/masking.py
@dataclass
class IJEPAMaskGenerator:
    """Generates encoder and predictor masks for preprocessing.

    This class generates masks dynamically for batches of examples.

    Parameters
    ----------
    input_size : tuple[int, int], default=(224, 224)
        Input image size.
    patch_size : int, default=16
        Size of each patch.
    min_keep : int, default=10
        Minimum number of patches to keep.
    allow_overlap : bool, default=False
        Whether to allow overlap between encoder and predictor masks.
    enc_mask_scale : tuple[float, float], default=(0.85, 1.0)
        Scale range for encoder mask.
    pred_mask_scale : tuple[float, float], default=(0.15, 0.2)
        Scale range for predictor mask.
    aspect_ratio : tuple[float, float], default=(0.75, 1.5)
        Aspect ratio range for mask blocks.
    nenc : int, default=1
        Number of encoder masks to generate.
    npred : int, default=4
        Number of predictor masks to generate.
    """

    input_size: tuple[int, int] = (224, 224)
    patch_size: int = 16
    min_keep: int = 10
    allow_overlap: bool = False
    enc_mask_scale: tuple[float, float] = (0.85, 1.0)
    pred_mask_scale: tuple[float, float] = (0.15, 0.2)
    aspect_ratio: tuple[float, float] = (0.75, 1.5)
    nenc: int = 1
    npred: int = 4

    def __post_init__(self) -> None:
        """Initialize the mask generator."""
        self.height = self.input_size[0] // self.patch_size
        self.width = self.input_size[1] // self.patch_size

    def _sample_block_size(
        self,
        generator: torch.Generator,
        scale: tuple[float, float],
        aspect_ratio: tuple[float, float],
    ) -> tuple[int, int]:
        """Sample the size of the mask block based on scale and aspect ratio."""
        _rand = torch.rand(1, generator=generator).item()
        min_s, max_s = scale
        mask_scale = min_s + _rand * (max_s - min_s)
        max_keep = int(self.height * self.width * mask_scale)

        min_ar, max_ar = aspect_ratio
        aspect_ratio_val = min_ar + _rand * (max_ar - min_ar)

        h = int(round(math.sqrt(max_keep * aspect_ratio_val)))
        w = int(round(math.sqrt(max_keep / aspect_ratio_val)))

        h = min(h, self.height - 1)
        w = min(w, self.width - 1)

        return h, w

    def _sample_block_mask(
        self, b_size: tuple[int, int]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample a mask block."""
        h, w = b_size
        top = torch.randint(0, self.height - h, (1,)).item()
        left = torch.randint(0, self.width - w, (1,)).item()
        mask = torch.zeros((self.height, self.width), dtype=torch.int32)
        mask[top : top + h, left : left + w] = 1

        mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
        mask_complement[top : top + h, left : left + w] = 0

        return mask.flatten(), mask_complement.flatten()

    def __call__(self, batch_size: int = 1) -> dict[str, Any]:
        """Generate encoder and predictor masks for a batch of examples.

        Parameters
        ----------
        batch_size : int, default=1
            The batch size for which to generate masks.

        Returns
        -------
        dict[str, Any]
            A dictionary of encoder masks and predictor masks.
        """
        seed = torch.randint(
            0, 2**32, (1,)
        ).item()  # Sample random seed for reproducibility
        g = torch.Generator().manual_seed(seed)

        # Sample block sizes
        p_size = self._sample_block_size(
            generator=g, scale=self.pred_mask_scale, aspect_ratio=self.aspect_ratio
        )
        e_size = self._sample_block_size(
            generator=g, scale=self.enc_mask_scale, aspect_ratio=(1.0, 1.0)
        )

        # Generate predictor masks
        masks_pred, masks_enc = [], []
        for _ in range(self.npred):
            mask_p, _ = self._sample_block_mask(p_size)
            # Expand mask to match batch size
            mask_p = mask_p.unsqueeze(0).expand(batch_size, -1)
            masks_pred.append(mask_p)

        # Generate encoder masks
        for _ in range(self.nenc):
            mask_e, _ = self._sample_block_mask(e_size)
            # Expand mask to match batch size
            mask_e = mask_e.unsqueeze(0).expand(batch_size, -1)
            masks_enc.append(mask_e)

        return {
            "encoder_masks": masks_enc,  # list of tensors of shape (batch_size, N)
            "predictor_masks": masks_pred,  # list of tensors of shape (batch_size, N)
        }
__post_init__
__post_init__()

Initialize the mask generator.

Source code in mmlearn/datasets/processors/masking.py
def __post_init__(self) -> None:
    """Initialize the mask generator."""
    self.height = self.input_size[0] // self.patch_size
    self.width = self.input_size[1] // self.patch_size
__call__
__call__(batch_size=1)

Generate encoder and predictor masks for a batch of examples.

Parameters:

Name Type Description Default
batch_size int

The batch size for which to generate masks.

1

Returns:

Type Description
dict[str, Any]

A dictionary of encoder masks and predictor masks.

Source code in mmlearn/datasets/processors/masking.py
def __call__(self, batch_size: int = 1) -> dict[str, Any]:
    """Generate encoder and predictor masks for a batch of examples.

    Parameters
    ----------
    batch_size : int, default=1
        The batch size for which to generate masks.

    Returns
    -------
    dict[str, Any]
        A dictionary of encoder masks and predictor masks.
    """
    seed = torch.randint(
        0, 2**32, (1,)
    ).item()  # Sample random seed for reproducibility
    g = torch.Generator().manual_seed(seed)

    # Sample block sizes
    p_size = self._sample_block_size(
        generator=g, scale=self.pred_mask_scale, aspect_ratio=self.aspect_ratio
    )
    e_size = self._sample_block_size(
        generator=g, scale=self.enc_mask_scale, aspect_ratio=(1.0, 1.0)
    )

    # Generate predictor masks
    masks_pred, masks_enc = [], []
    for _ in range(self.npred):
        mask_p, _ = self._sample_block_mask(p_size)
        # Expand mask to match batch size
        mask_p = mask_p.unsqueeze(0).expand(batch_size, -1)
        masks_pred.append(mask_p)

    # Generate encoder masks
    for _ in range(self.nenc):
        mask_e, _ = self._sample_block_mask(e_size)
        # Expand mask to match batch size
        mask_e = mask_e.unsqueeze(0).expand(batch_size, -1)
        masks_enc.append(mask_e)

    return {
        "encoder_masks": masks_enc,  # list of tensors of shape (batch_size, N)
        "predictor_masks": masks_pred,  # list of tensors of shape (batch_size, N)
    }
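
A minimal sketch using the default configuration (224x224 images, 16x16 patches, i.e. 196 patches per image):

from mmlearn.datasets.processors.masking import IJEPAMaskGenerator

mask_gen = IJEPAMaskGenerator()
masks = mask_gen(batch_size=8)
print(len(masks["encoder_masks"]), len(masks["predictor_masks"]))  # 1 4
print(masks["encoder_masks"][0].shape)                             # torch.Size([8, 196])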

apply_masks

apply_masks(x, masks)

Apply masks to the input tensor by selecting the patches to keep based on the masks.

This function is primarily intended to be used for the i-JEPA task (mmlearn.tasks.ijepa.IJEPA).

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, N, D).

required
masks Union[Tensor, list[Tensor]]

A list of mask tensors of shape (N,), (1, N), or (B, N).

required

Returns:

Type Description
Tensor

The masked tensor where only the patches indicated by the masks are kept. The output tensor has shape (B * num_masks, N', D), where N' is the number of patches kept.

Source code in mmlearn/datasets/processors/masking.py
def apply_masks(
    x: torch.Tensor, masks: Union[torch.Tensor, list[torch.Tensor]]
) -> torch.Tensor:
    """
    Apply masks to the input tensor by selecting the patches to keep based on the masks.

    This function is primarily intended to be used for the
    :py:class:`i-JEPA <mmlearn.tasks.ijepa.IJEPA>`.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(B, N, D)``.
    masks : Union[torch.Tensor, list[torch.Tensor]]
        A list of mask tensors of shape ``(N,)``, ``(1, N)``, or ``(B, N)``.

    Returns
    -------
    torch.Tensor
        The masked tensor where only the patches indicated by the masks are kept.
        The output tensor has shape ``(B * num_masks, N', D)``, where ``N'`` is
        the number of patches kept.
    """
    all_x = []
    batch_size = x.size(0)
    for m_ in masks:
        m = m_.to(x.device)

        # Ensure mask is at least 2D
        if m.dim() == 1:
            m = m.unsqueeze(0)  # Shape: (1, N)

        # Expand mask to match the batch size if needed
        if m.size(0) == 1 and batch_size > 1:
            m = m.expand(batch_size, -1)  # Shape: (B, N)

        # Expand mask to match x's dimensions
        m_expanded = (
            m.unsqueeze(-1).expand(-1, -1, x.size(-1)).bool()
        )  # Shape: (B, N, D)

        # Use boolean indexing
        selected_patches = x[m_expanded].view(batch_size, -1, x.size(-1))
        all_x.append(selected_patches)

    # Concatenate along the batch dimension
    return torch.cat(all_x, dim=0)
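
A small self-contained sketch with synthetic patch embeddings; each mask keeps the same number of patches for every example, as produced by IJEPAMaskGenerator:

import torch

from mmlearn.datasets.processors.masking import apply_masks

x = torch.randn(2, 16, 8)              # (B, N, D)
keep = torch.zeros(16, dtype=torch.int32)
keep[:4] = 1                           # keep the first 4 patches
out = apply_masks(x, [keep, keep])     # two masks -> output batch is doubled
print(out.shape)                       # torch.Size([4, 4, 8])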

tokenizers

Tokenizers - modules that convert raw input to sequences of tokens.

HFTokenizer

A wrapper for loading HuggingFace tokenizers.

This class wraps any HuggingFace tokenizer that can be initialized with transformers.AutoTokenizer.from_pretrained. It preprocesses the input text and returns a dictionary with the tokenized text and other relevant information like attention masks.

Parameters:

Name Type Description Default
model_name_or_path str

Pretrained model name or path - same as in transformers.AutoTokenizer.from_pretrained.

required
max_length Optional[int]

Maximum length of the tokenized sequence. This is passed to the tokenizer __call__ method.

None
padding bool or str

Padding strategy. Same as in transformers.AutoTokenizer.from_pretrained; passed to the tokenizer __call__ method.

False
truncation Optional[Union[bool, str]]

Truncation strategy. Same as in transformers.AutoTokenizer.from_pretrained; passed to the tokenizer __call__ method.

None
**kwargs Any

Additional arguments passed to transformers.AutoTokenizer.from_pretrained.

{}
Source code in mmlearn/datasets/processors/tokenizers.py
@store(group="datasets/tokenizers", provider="mmlearn")
class HFTokenizer:
    """A wrapper for loading HuggingFace tokenizers.

    This class wraps any huggingface tokenizer that can be initialized with
    :py:meth:`transformers.AutoTokenizer.from_pretrained`. It preprocesses the
    input text and returns a dictionary with the tokenized text and other
    relevant information like attention masks.

    Parameters
    ----------
    model_name_or_path : str
        Pretrained model name or path - same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    max_length : Optional[int], optional, default=None
        Maximum length of the tokenized sequence. This is passed to the tokenizer
        :meth:`__call__` method.
    padding : bool or str, default=False
        Padding strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    truncation : Optional[Union[bool, str]], optional, default=None
        Truncation strategy. Same as in :py:meth:`transformers.AutoTokenizer.from_pretrained`;
        passed to the tokenizer :meth:`__call__` method.
    **kwargs : Any
        Additional arguments passed to :py:meth:`transformers.AutoTokenizer.from_pretrained`.
    """  # noqa: W505

    def __init__(
        self,
        model_name_or_path: str,
        max_length: Optional[int] = None,
        padding: Union[bool, str] = False,
        truncation: Optional[Union[bool, str]] = None,
        **kwargs: Any,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def __call__(
        self, sentence: Union[str, list[str]], **kwargs: Any
    ) -> dict[str, torch.Tensor]:
        """Tokenize a text or a list of texts using the HuggingFace tokenizer.

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be tokenized.
        **kwargs : Any
            Additional arguments passed to the tokenizer :meth:`__call__` method.

        Returns
        -------
        dict[str, torch.Tensor]
            Tokenized sentence(s).

        Notes
        -----
        The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
        """
        batch_encoding = self.tokenizer(
            sentence,
            max_length=self.max_length,
            padding=self.padding,
            truncation=self.truncation,
            return_tensors="pt",
            **kwargs,
        )

        if isinstance(
            sentence, str
        ):  # remove batch dimension if input is a single sentence
            for key, value in batch_encoding.items():
                if isinstance(value, torch.Tensor):
                    batch_encoding[key] = torch.squeeze(value, 0)

        # use 'Modalities.TEXT' key for input_ids for consistency
        batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
        return dict(batch_encoding)
__call__
__call__(sentence, **kwargs)

Tokenize a text or a list of texts using the HuggingFace tokenizer.

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be tokenized.

required
**kwargs Any

Additional arguments passed to the tokenizer __call__ method.

{}

Returns:

Type Description
dict[str, Tensor]

Tokenized sentence(s).

Notes

The input_ids key is replaced with Modalities.TEXT for consistency.

Source code in mmlearn/datasets/processors/tokenizers.py
def __call__(
    self, sentence: Union[str, list[str]], **kwargs: Any
) -> dict[str, torch.Tensor]:
    """Tokenize a text or a list of texts using the HuggingFace tokenizer.

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be tokenized.
    **kwargs : Any
        Additional arguments passed to the tokenizer :meth:`__call__` method.

    Returns
    -------
    dict[str, torch.Tensor]
        Tokenized sentence(s).

    Notes
    -----
    The ``input_ids`` key is replaced with ``Modalities.TEXT`` for consistency.
    """
    batch_encoding = self.tokenizer(
        sentence,
        max_length=self.max_length,
        padding=self.padding,
        truncation=self.truncation,
        return_tensors="pt",
        **kwargs,
    )

    if isinstance(
        sentence, str
    ):  # remove batch dimension if input is a single sentence
        for key, value in batch_encoding.items():
            if isinstance(value, torch.Tensor):
                batch_encoding[key] = torch.squeeze(value, 0)

    # use 'Modalities.TEXT' key for input_ids for consistency
    batch_encoding[Modalities.TEXT.name] = batch_encoding["input_ids"]
    return dict(batch_encoding)

Img2Seq

Bases: Module

Convert a batch of images to a batch of sequences.

Parameters:

Name Type Description Default
img_size tuple of int

The size of the input image.

required
patch_size tuple of int

The size of the patch.

required
n_channels int

The number of channels in the input image.

required
d_model int

The dimension of the output sequence.

required
Source code in mmlearn/datasets/processors/tokenizers.py
class Img2Seq(nn.Module):
    """Convert a batch of images to a batch of sequences.

    Parameters
    ----------
    img_size : tuple of int
        The size of the input image.
    patch_size : tuple of int
        The size of the patch.
    n_channels : int
        The number of channels in the input image.
    d_model : int
        The dimension of the output sequence.

    """

    def __init__(
        self,
        img_size: tuple[int, int],
        patch_size: tuple[int, int],
        n_channels: int,
        d_model: int,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size

        nh, nw = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        n_tokens = nh * nw

        token_dim = patch_size[0] * patch_size[1] * n_channels
        self.linear = nn.Linear(token_dim, d_model)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_emb = nn.Parameter(torch.randn(n_tokens, d_model))

    def __call__(self, batch: torch.Tensor) -> torch.Tensor:
        """Convert a batch of images to a batch of sequences.

        Parameters
        ----------
        batch : torch.Tensor
            Batch of images of shape ``(b, h, w, c)`` where ``b`` is the batch size,
            ``h`` is the height, ``w`` is the width, and ``c`` is the number of
            channels.

        Returns
        -------
        torch.Tensor
            Batch of sequences of shape ``(b, s, d)`` where ``b`` is the batch size,
            ``s`` is the sequence length, and ``d`` is the dimension of the output
            sequence.
        """
        batch = _patchify(batch, self.patch_size)

        b, c, nh, nw, ph, pw = batch.shape

        # Flattening the patches
        batch = torch.permute(batch, [0, 2, 3, 4, 5, 1])
        batch = torch.reshape(batch, [b, nh * nw, ph * pw * c])

        batch = self.linear(batch)
        cls: torch.Tensor = self.cls_token.expand([b, -1, -1])
        emb: torch.Tensor = batch + self.pos_emb

        return torch.cat([cls, emb], axis=1)
__call__
__call__(batch)

Convert a batch of images to a batch of sequences.

Parameters:

Name Type Description Default
batch Tensor

Batch of images of shape (b, h, w, c) where b is the batch size, h is the height, w is the width, and c is the number of channels.

required

Returns:

Type Description
Tensor

Batch of sequences of shape (b, s, d) where b is the batch size, s is the sequence length, and d is the dimension of the output sequence.

Source code in mmlearn/datasets/processors/tokenizers.py
def __call__(self, batch: torch.Tensor) -> torch.Tensor:
    """Convert a batch of images to a batch of sequences.

    Parameters
    ----------
    batch : torch.Tensor
        Batch of images of shape ``(b, h, w, c)`` where ``b`` is the batch size,
        ``h`` is the height, ``w`` is the width, and ``c`` is the number of
        channels.

    Returns
    -------
    torch.Tensor
        Batch of sequences of shape ``(b, s, d)`` where ``b`` is the batch size,
        ``s`` is the sequence length, and ``d`` is the dimension of the output
        sequence.
    """
    batch = _patchify(batch, self.patch_size)

    b, c, nh, nw, ph, pw = batch.shape

    # Flattening the patches
    batch = torch.permute(batch, [0, 2, 3, 4, 5, 1])
    batch = torch.reshape(batch, [b, nh * nw, ph * pw * c])

    batch = self.linear(batch)
    cls: torch.Tensor = self.cls_token.expand([b, -1, -1])
    emb: torch.Tensor = batch + self.pos_emb

    return torch.cat([cls, emb], axis=1)
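
A minimal sketch, assuming _patchify accepts images in the (b, h, w, c) layout documented above; a 32x32 RGB image with 8x8 patches yields 16 patch tokens plus one class token:

import torch

from mmlearn.datasets.processors.tokenizers import Img2Seq

img2seq = Img2Seq(img_size=(32, 32), patch_size=(8, 8), n_channels=3, d_model=64)
images = torch.randn(2, 32, 32, 3)     # (b, h, w, c)
seq = img2seq(images)
print(seq.shape)                       # torch.Size([2, 17, 64])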

transforms

Custom transforms for datasets/inputs.

TrimText

Trim text strings as a preprocessing step before tokenization.

Parameters:

Name Type Description Default
trim_size int

The maximum length of the trimmed text.

required
Source code in mmlearn/datasets/processors/transforms.py
@store(group="datasets/transforms", provider="mmlearn")
class TrimText:
    """Trim text strings as a preprocessing step before tokenization.

    Parameters
    ----------
    trim_size : int
        The maximum length of the trimmed text.
    """

    def __init__(self, trim_size: int) -> None:
        self.trim_size = trim_size

    def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
        """Trim the given sentence(s).

        Parameters
        ----------
        sentence : Union[str, list[str]]
            Sentence(s) to be trimmed.

        Returns
        -------
        Union[str, list[str]]
            Trimmed sentence(s).

        Raises
        ------
        TypeError
            If the input sentence is not a string or list of strings.
        """
        if not isinstance(sentence, (list, str)):
            raise TypeError(
                "Expected argument `sentence` to be a string or list of strings, "
                f"but got {type(sentence)}"
            )

        if isinstance(sentence, str):
            return sentence[: self.trim_size]

        for i, s in enumerate(sentence):
            sentence[i] = s[: self.trim_size]

        return sentence
__call__
__call__(sentence)

Trim the given sentence(s).

Parameters:

Name Type Description Default
sentence Union[str, list[str]]

Sentence(s) to be trimmed.

required

Returns:

Type Description
Union[str, list[str]]

Trimmed sentence(s).

Raises:

Type Description
TypeError

If the input sentence is not a string or list of strings.

Source code in mmlearn/datasets/processors/transforms.py
def __call__(self, sentence: Union[str, list[str]]) -> Union[str, list[str]]:
    """Trim the given sentence(s).

    Parameters
    ----------
    sentence : Union[str, list[str]]
        Sentence(s) to be trimmed.

    Returns
    -------
    Union[str, list[str]]
        Trimmed sentence(s).

    Raises
    ------
    TypeError
        If the input sentence is not a string or list of strings.
    """
    if not isinstance(sentence, (list, str)):
        raise TypeError(
            "Expected argument `sentence` to be a string or list of strings, "
            f"but got {type(sentence)}"
        )

    if isinstance(sentence, str):
        return sentence[: self.trim_size]

    for i, s in enumerate(sentence):
        sentence[i] = s[: self.trim_size]

    return sentence

repeat_interleave_batch

repeat_interleave_batch(x, b, repeat)

Repeat and interleave a tensor across the batch dimension.

Parameters:

Name Type Description Default
x Tensor

Input tensor to be repeated.

required
b int

Size of the batch to be repeated.

required
repeat int

Number of times to repeat each batch.

required

Returns:

Type Description
Tensor

The repeated tensor with shape adjusted for the batch.

Source code in mmlearn/datasets/processors/transforms.py
def repeat_interleave_batch(x: torch.Tensor, b: int, repeat: int) -> torch.Tensor:
    """Repeat and interleave a tensor across the batch dimension.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor to be repeated.
    b : int
        Size of the batch to be repeated.
    repeat : int
        Number of times to repeat each batch.

    Returns
    -------
    torch.Tensor
        The repeated tensor with shape adjusted for the batch.
    """
    n = len(x) // b
    return torch.cat(
        [
            torch.cat([x[i * b : (i + 1) * b] for _ in range(repeat)], dim=0)
            for i in range(n)
        ],
        dim=0,
    )
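
A small sketch showing the repeat pattern: each chunk of b rows is repeated repeat times before moving on to the next chunk:

import torch

from mmlearn.datasets.processors.transforms import repeat_interleave_batch

x = torch.arange(4).unsqueeze(-1)      # rows 0, 1, 2, 3
out = repeat_interleave_batch(x, b=2, repeat=3)
print(out.squeeze(-1).tolist())        # [0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3]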

Modules

mmlearn.modules

Reusable components for building tasks.

ema

Exponential Moving Average (EMA) module.

ExponentialMovingAverage

Exponential Moving Average (EMA) for the input model.

At each step, the parameters of the EMA model are updated as a weighted average of their current values and the input model's parameters.

Parameters:

Name Type Description Default
model Module

The model to apply EMA to.

required
ema_decay float

The initial decay value for EMA.

required
ema_end_decay float

The final decay value for EMA.

required
ema_anneal_end_step int

The number of steps to anneal the decay from ema_decay to ema_end_decay.

required
skip_keys Optional[Union[list[str], Set[str]]]

The keys to skip in the EMA update. These parameters will be copied directly from the model to the EMA model.

None

Raises:

Type Description
RuntimeError

If a deep copy of the model cannot be created.

Source code in mmlearn/modules/ema.py
class ExponentialMovingAverage:
    """Exponential Moving Average (EMA) for the input model.

    At each step, the parameters of the EMA model are updated as a weighted
    average of their current values and the input model's parameters.

    Parameters
    ----------
    model : torch.nn.Module
        The model to apply EMA to.
    ema_decay : float
        The initial decay value for EMA.
    ema_end_decay : float
        The final decay value for EMA.
    ema_anneal_end_step : int
        The number of steps to anneal the decay from ``ema_decay`` to ``ema_end_decay``.
    skip_keys : Optional[Union[list[str], Set[str]]], optional, default=None
        The keys to skip in the EMA update. These parameters will be copied directly
        from the model to the EMA model.

    Raises
    ------
    RuntimeError
        If a deep copy of the model cannot be created.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        ema_decay: float,
        ema_end_decay: float,
        ema_anneal_end_step: int,
        skip_keys: Optional[Union[list[str], Set[str]]] = None,
    ) -> None:
        self.model = self.deepcopy_model(model)

        self.skip_keys: Union[list[str], set[str]] = skip_keys or set()
        self.num_updates = 0
        self.decay = ema_decay  # stores the current decay value
        self.ema_decay = ema_decay
        self.ema_end_decay = ema_end_decay
        self.ema_anneal_end_step = ema_anneal_end_step

        self._model_configured = False

    @staticmethod
    def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
        """Deep copy the model.

        Parameters
        ----------
        model : torch.nn.Module
            The model to copy.

        Returns
        -------
        torch.nn.Module
            The copied model.

        Raises
        ------
        RuntimeError
            If the model cannot be copied.
        """
        try:
            return copy.deepcopy(model)
        except RuntimeError as e:
            raise RuntimeError("Unable to copy the model ", e) from e

    @staticmethod
    def get_annealed_rate(
        start: float,
        end: float,
        curr_step: int,
        total_steps: int,
    ) -> float:
        """Calculate EMA annealing rate."""
        r = end - start
        pct_remaining = 1 - curr_step / total_steps
        return end - r * pct_remaining

    def configure_model(self, device_id: Union[int, torch.device]) -> None:
        """Configure the model for EMA."""
        if self._model_configured:
            return

        self.model.requires_grad_(False)
        self.model.to(device_id)

        self._model_configured = True

    def step(self, new_model: torch.nn.Module) -> None:
        """Perform single EMA update step."""
        if not self._model_configured:
            raise RuntimeError(
                "Model is not configured for EMA. Call `configure_model` first."
            )

        self._update_weights(new_model)
        self._update_ema_decay()

    def restore(self, model: torch.nn.Module) -> torch.nn.Module:
        """Reassign weights from another model.

        Parameters
        ----------
        model : torch.nn.Module
            Model to load weights from.

        Returns
        -------
        torch.nn.Module
            model with new weights
        """
        d = self.model.state_dict()
        model.load_state_dict(d, strict=False)
        return model

    def state_dict(self) -> dict[str, Any]:
        """Return the state dict of the model."""
        return self.model.state_dict()  # type: ignore[no-any-return]

    @torch.no_grad()  # type: ignore[misc]
    def _update_weights(self, new_model: torch.nn.Module) -> None:
        if self.decay < 1:
            ema_state_dict = {}
            ema_params = self.model.state_dict()

            for key, param in new_model.state_dict().items():
                ema_param = ema_params[key].float()

                if param.shape != ema_param.shape:
                    raise ValueError(
                        "Incompatible tensor shapes between student param and teacher param"
                        + "{} vs. {}".format(param.shape, ema_param.shape)
                    )

                if key in self.skip_keys or not param.requires_grad:
                    ema_param = param.to(dtype=ema_param.dtype).clone()
                else:
                    ema_param.mul_(self.decay)
                    ema_param.add_(
                        param.to(dtype=ema_param.dtype),
                        alpha=1 - self.decay,
                    )
                ema_state_dict[key] = ema_param

            self.model.load_state_dict(ema_state_dict, strict=False)
            self.num_updates += 1
        else:
            rank_zero_warn(
                "Exponential Moving Average decay is 1.0, no update is applied to the model.",
                stacklevel=1,
                category=UserWarning,
            )

    def _update_ema_decay(self) -> None:
        if self.ema_decay != self.ema_end_decay:
            if self.num_updates >= self.ema_anneal_end_step:
                decay = self.ema_end_decay
            else:
                decay = self.get_annealed_rate(
                    self.ema_decay,
                    self.ema_end_decay,
                    self.num_updates,
                    self.ema_anneal_end_step,
                )
            self.decay = decay
deepcopy_model staticmethod
deepcopy_model(model)

Deep copy the model.

Parameters:

Name Type Description Default
model Module

The model to copy.

required

Returns:

Type Description
Module

The copied model.

Raises:

Type Description
RuntimeError

If the model cannot be copied.

Source code in mmlearn/modules/ema.py
@staticmethod
def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
    """Deep copy the model.

    Parameters
    ----------
    model : torch.nn.Module
        The model to copy.

    Returns
    -------
    torch.nn.Module
        The copied model.

    Raises
    ------
    RuntimeError
        If the model cannot be copied.
    """
    try:
        return copy.deepcopy(model)
    except RuntimeError as e:
        raise RuntimeError("Unable to copy the model ", e) from e
get_annealed_rate staticmethod
get_annealed_rate(start, end, curr_step, total_steps)

Calculate EMA annealing rate.

Source code in mmlearn/modules/ema.py
@staticmethod
def get_annealed_rate(
    start: float,
    end: float,
    curr_step: int,
    total_steps: int,
) -> float:
    """Calculate EMA annealing rate."""
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining
configure_model
configure_model(device_id)

Configure the model for EMA.

Source code in mmlearn/modules/ema.py
def configure_model(self, device_id: Union[int, torch.device]) -> None:
    """Configure the model for EMA."""
    if self._model_configured:
        return

    self.model.requires_grad_(False)
    self.model.to(device_id)

    self._model_configured = True
step
step(new_model)

Perform single EMA update step.

Source code in mmlearn/modules/ema.py
def step(self, new_model: torch.nn.Module) -> None:
    """Perform single EMA update step."""
    if not self._model_configured:
        raise RuntimeError(
            "Model is not configured for EMA. Call `configure_model` first."
        )

    self._update_weights(new_model)
    self._update_ema_decay()
restore
restore(model)

Reassign weights from another model.

Parameters:

Name Type Description Default
model Module

Model to load weights from.

required

Returns:

Type Description
Module

model with new weights

Source code in mmlearn/modules/ema.py
def restore(self, model: torch.nn.Module) -> torch.nn.Module:
    """Reassign weights from another model.

    Parameters
    ----------
    model : torch.nn.Module
        Model to load weights from.

    Returns
    -------
    torch.nn.Module
        model with new weights
    """
    d = self.model.state_dict()
    model.load_state_dict(d, strict=False)
    return model
state_dict
state_dict()

Return the state dict of the model.

Source code in mmlearn/modules/ema.py
def state_dict(self) -> dict[str, Any]:
    """Return the state dict of the model."""
    return self.model.state_dict()  # type: ignore[no-any-return]
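
A minimal sketch of maintaining an EMA copy of a small model; configure_model must be called before the first step:

import torch

from mmlearn.modules.ema import ExponentialMovingAverage

model = torch.nn.Linear(4, 2)
ema = ExponentialMovingAverage(
    model, ema_decay=0.999, ema_end_decay=0.9999, ema_anneal_end_step=10_000
)
ema.configure_model(torch.device("cpu"))

# ... after each optimizer step on `model` ...
ema.step(model)          # move the EMA weights toward the current weights
teacher = ema.model      # frozen EMA copy, e.g. for use as a teacher network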

encoders

Encoders.

HFCLIPTextEncoder

Bases: Module

Wrapper around the CLIPTextModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",  # required for `peft_config` to be converted to a `PeftConfig` object
)
class HFCLIPTextEncoder(nn.Module):
    """Wrapper around the ``CLIPTextModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden
            states, and the attention weights, if ``output_attentions`` is set
            to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get("attention_mask")
            or inputs.get(Modalities.TEXT.attention_mask),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden
        states, and the attention weights, if ``output_attentions`` is set
        to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get("attention_mask")
        or inputs.get(Modalities.TEXT.attention_mask),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
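
Example

A minimal usage sketch (illustrative, not part of the generated reference). The import path for Modalities is an assumption; adjust it to your installation.

import torch
import transformers
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPTextEncoder

# Load the pretrained CLIP text model and freeze the first half of its layers.
encoder = HFCLIPTextEncoder(
    model_name_or_path="openai/clip-vit-base-patch16",
    freeze_layers=0.5,
)

tokenizer = transformers.CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")
batch = tokenizer(
    ["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt"
)

with torch.no_grad():
    outputs = encoder(
        {
            Modalities.TEXT.name: batch["input_ids"],
            "attention_mask": batch["attention_mask"],
        }
    )
print(outputs.last_hidden_state.shape)  # (2, sequence_length, hidden_size)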

HFCLIPTextEncoderWithProjection

Bases: Module

Wrapper around the CLIPTextModelWithProjection from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the text. If False, the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPTextEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPTextModelWithProjection`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the text. If ``False`` the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPTextConfig,
        )

        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The text embeddings. Will be a tuple with a single element.
        """
        input_ids = inputs[Modalities.TEXT.name]
        attention_mask: Optional[torch.Tensor] = inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        )
        position_ids = inputs.get("position_ids")

        if self.use_all_token_embeddings:
            text_outputs = self.model.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            )
            # TODO: add more options for pooling before projection
            text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
        else:
            text_embeds = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            ).text_embeds

        return (text_embeds,)
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
tuple[Tensor]

The text embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The text embeddings. Will be a tuple with a single element.
    """
    input_ids = inputs[Modalities.TEXT.name]
    attention_mask: Optional[torch.Tensor] = inputs.get(
        "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
    )
    position_ids = inputs.get("position_ids")

    if self.use_all_token_embeddings:
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        # TODO: add more options for pooling before projection
        text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
    else:
        text_embeds = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        ).text_embeds

    return (text_embeds,)
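
Example

An illustrative sketch of the projection variant, under the same assumption about the Modalities import path. The forward pass returns a single-element tuple of projected text embeddings.

import torch
import transformers
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPTextEncoderWithProjection

encoder = HFCLIPTextEncoderWithProjection("openai/clip-vit-base-patch16")
tokenizer = transformers.CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")
batch = tokenizer(["a diagram of a transformer"], padding=True, return_tensors="pt")

with torch.no_grad():
    (text_embeds,) = encoder(
        {
            Modalities.TEXT.name: batch["input_ids"],
            "attention_mask": batch["attention_mask"],
        }
    )
print(text_embeds.shape)  # (1, projection_dim)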

HFCLIPVisionEncoder

Bases: Module

Wrapper around the CLIPVisionModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias Optional[float]

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoder(nn.Module):
    """Wrapper around the ``CLIPVisionModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : Optional[float], optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model.vision_model
        self.pooling_layer = pooling_layer
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.

        """
        # FIXME: handle other vision modalities
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.encoder(
            inputs_embeds=hidden_states,
            output_attentions=inputs.get(
                "output_attentions", self.model.config.output_attentions
            ),
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = encoder_outputs[0]
        if self.pooling_layer is not None:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.

    """
    # FIXME: handle other vision modalities
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.encoder(
        inputs_embeds=hidden_states,
        output_attentions=inputs.get(
            "output_attentions", self.model.config.output_attentions
        ),
        output_hidden_states=True,
        return_dict=True,
    )

    last_hidden_state = encoder_outputs[0]
    if self.pooling_layer is not None:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
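
Example

A minimal sketch for the vision encoder. A random tensor stands in for a batch of preprocessed images (e.g. produced by a CLIPImageProcessor); the Modalities import path is an assumption.

import torch
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoder

encoder = HFCLIPVisionEncoder(model_name_or_path="openai/clip-vit-base-patch16")

pixel_values = torch.randn(2, 3, 224, 224)  # stand-in for preprocessed images
with torch.no_grad():
    outputs = encoder({Modalities.RGB.name: pixel_values})
print(outputs.last_hidden_state.shape)  # (2, num_patches + 1, hidden_size)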

HFCLIPVisionEncoderWithProjection

Bases: Module

Wrapper around the CLIPVisionModelWithProjection class from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the image. If False, the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias float

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs dict[str, Any]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPVisionModelWithProjection`` class from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the image. If ``False``, the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : float, optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : dict[str, Any], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPVisionConfig,
        )

        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The image embeddings. Will be a tuple with a single element.
        """
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.vision_model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.vision_model.encoder(
            inputs_embeds=hidden_states, return_dict=True
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        if self.use_all_token_embeddings:
            pooled_output = last_hidden_state
        else:
            pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.model.vision_model.post_layernorm(pooled_output)

        return (self.model.visual_projection(pooled_output),)
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
tuple[Tensor]

The image embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The image embeddings. Will be a tuple with a single element.
    """
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.vision_model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.vision_model.encoder(
        inputs_embeds=hidden_states, return_dict=True
    )

    last_hidden_state = encoder_outputs.last_hidden_state
    if self.use_all_token_embeddings:
        pooled_output = last_hidden_state
    else:
        pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.model.vision_model.post_layernorm(pooled_output)

    return (self.model.visual_projection(pooled_output),)
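
Example

A sketch of the projection variant under the same assumptions; the forward pass returns a single-element tuple of projected image embeddings.

import torch
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoderWithProjection

encoder = HFCLIPVisionEncoderWithProjection("openai/clip-vit-base-patch16")
with torch.no_grad():
    (image_embeds,) = encoder({Modalities.RGB.name: torch.randn(2, 3, 224, 224)})
print(image_embeds.shape)  # (2, projection_dim)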

HFTextEncoder

Bases: Module

Wrapper around huggingface models in the AutoModelForTextEncoding class.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Raises:

Type Description
ValueError

If the model is a decoder model or if freezing individual layers is not supported for the model type.

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/text.py
@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
class HFTextEncoder(nn.Module):
    """Wrapper around huggingface models in the ``AutoModelForTextEncoding`` class.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Raises
    ------
    ValueError
        If the model is a decoder model or if freezing individual layers is not
        supported for the model type.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.


    """

    def __init__(  # noqa: PLR0912
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ):
        super().__init__()
        if model_config_kwargs is None:
            model_config_kwargs = {}
        model_config_kwargs["output_hidden_states"] = True
        model_config_kwargs["add_pooling_layer"] = False
        model = hf_utils.load_huggingface_model(
            AutoModelForTextEncoding,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        if hasattr(model.config, "is_decoder") and model.config.is_decoder:
            raise ValueError("Model is a decoder. Only encoder models are supported.")

        if not pretrained and freeze_layers:
            rank_zero_warn(
                "Freezing layers when loading a model with random weights may lead to "
                "unexpected behavior. Consider setting `freeze_layers=False` if "
                "`pretrained=False`.",
            )

        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "LayerNorm" in name else False
                )

        if isinstance(
            freeze_layers, (float, int, list)
        ) and model.config.model_type in ["flaubert", "xlm"]:
            # flaubert and xlm models have a different architecture that does not
            # support freezing individual layers in the same way as other models
            raise ValueError(
                f"Freezing individual layers is not supported for {model.config.model_type} "
                "models. Please use `freeze_layers=False` or `freeze_layers=True`."
            )

        # get list of layers
        embeddings = model.embeddings
        encoder = getattr(model, "encoder", None) or getattr(
            model, "transformer", model
        )
        encoder_layers = (
            getattr(encoder, "layer", None)
            or getattr(encoder, "layers", None)
            or getattr(encoder, "block", None)
        )
        if encoder_layers is None and hasattr(encoder, "albert_layer_groups"):
            encoder_layers = [
                layer
                for group in encoder.albert_layer_groups
                for layer in group.albert_layers
            ]
        modules = [embeddings]
        if encoder_layers is not None and isinstance(encoder_layers, list):
            modules.extend(encoder_layers)

        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "LayerNorm" in name else False
                        )

        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get(
                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
            ),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/text.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        ),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
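
Example

A hedged sketch using a BERT-style checkpoint through the AutoModelForTextEncoding wrapper; the Modalities import path is an assumption.

import torch
import transformers
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.text import HFTextEncoder

# Freeze the embeddings (index 0) and the first two encoder layers (indices 1 and 2).
encoder = HFTextEncoder(
    model_name_or_path="bert-base-uncased",
    freeze_layers=[0, 1, 2],
)

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["multimodal representation learning"], return_tensors="pt")

with torch.no_grad():
    outputs = encoder(
        {
            Modalities.TEXT.name: batch["input_ids"],
            "attention_mask": batch["attention_mask"],
        }
    )
print(outputs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)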

TimmViT

Bases: Module

Vision Transformer model from timm.

Parameters:

Name Type Description Default
model_name str

The name of the model to use.

required
modality str

The modality of the input data. This allows this model to be used with different image modalities e.g. RGB, Depth, etc.

"RGB"
projection_dim int

The dimension of the projection head.

768
pretrained bool

Whether to use the pretrained weights.

True
freeze_layers Union[int, float, list[int], bool]

Whether to freeze the layers.

False
freeze_layer_norm bool

Whether to freeze the layer norm.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_kwargs Optional[dict[str, Any]]

Additional keyword arguments for the model.

None
Source code in mmlearn/modules/encoders/vision.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name="vit_base_patch16_224",
    hydra_convert="object",
)
class TimmViT(nn.Module):
    """Vision Transformer model from timm.

    Parameters
    ----------
    model_name : str
        The name of the model to use.
    modality : str, default="RGB"
        The modality of the input data. This allows this model to be used with different
        image modalities e.g. RGB, Depth, etc.
    projection_dim : int, default=768
        The dimension of the projection head.
    pretrained : bool, default=True
        Whether to use the pretrained weights.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze the layers.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer norm.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_kwargs : Optional[dict[str, Any]], default=None
        Additional keyword arguments for the model.
    """

    def __init__(
        self,
        model_name: str,
        modality: str = "RGB",
        projection_dim: int = 768,
        pretrained: bool = True,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.modality = Modalities.get_modality(modality)
        if model_kwargs is None:
            model_kwargs = {}

        self.model: TimmVisionTransformer = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=projection_dim,
            **model_kwargs,
        )
        assert isinstance(self.model, TimmVisionTransformer), (
            f"Model {model_name} is not a Vision Transformer. "
            "Please provide a model name that corresponds to a Vision Transformer."
        )

        self._freeze_layers(freeze_layers, freeze_layer_norm)

        if peft_config is not None:
            self.model = hf_utils._wrap_peft_model(self.model, peft_config)

    def _freeze_layers(
        self, freeze_layers: Union[int, float, list[int], bool], freeze_layer_norm: bool
    ) -> None:
        """Freeze the layers of the model.

        Parameters
        ----------
        freeze_layers : Union[int, float, list[int], bool]
            Whether to freeze the layers.
        freeze_layer_norm : bool
            Whether to freeze the layer norm.
        """
        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in self.model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "norm" in name else False
                )

        modules = [self.model.patch_embed, *self.model.blocks, self.model.norm]
        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "norm" in name else False
                        )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model.
        """
        x = inputs[self.modality.name]
        last_hidden_state, hidden_states = self.model.forward_intermediates(
            x, output_fmt="NLC"
        )
        last_hidden_state = self.model.forward_head(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states
        )

    def get_intermediate_layers(
        self, inputs: dict[str, Any], n: int = 1
    ) -> list[torch.Tensor]:
        """Get the output of the intermediate layers.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the ``Modalities.RGB``
            key.
        n : int, default=1
            The number of intermediate layers to return.

        Returns
        -------
        list[torch.Tensor]
            The outputs of the last n intermediate layers.
        """
        return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore

    def get_patch_info(self) -> tuple[int, int]:
        """Get patch size and number of patches.

        Returns
        -------
        tuple[int, int]
            Patch size and number of patches.
        """
        patch_size = self.model.patch_embed.patch_size[0]
        num_patches = self.model.patch_embed.num_patches
        return patch_size, num_patches
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model.

Source code in mmlearn/modules/encoders/vision.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model.
    """
    x = inputs[self.modality.name]
    last_hidden_state, hidden_states = self.model.forward_intermediates(
        x, output_fmt="NLC"
    )
    last_hidden_state = self.model.forward_head(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state, hidden_states=hidden_states
    )
get_intermediate_layers
get_intermediate_layers(inputs, n=1)

Get the output of the intermediate layers.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required
n int

The number of intermediate layers to return.

1

Returns:

Type Description
list[Tensor]

The outputs of the last n intermediate layers.

Source code in mmlearn/modules/encoders/vision.py
def get_intermediate_layers(
    self, inputs: dict[str, Any], n: int = 1
) -> list[torch.Tensor]:
    """Get the output of the intermediate layers.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the ``Modalities.RGB``
        key.
    n : int, default=1
        The number of intermediate layers to return.

    Returns
    -------
    list[torch.Tensor]
        The outputs of the last n intermediate layers.
    """
    return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore
get_patch_info
get_patch_info()

Get patch size and number of patches.

Returns:

Type Description
tuple[int, int]

Patch size and number of patches.

Source code in mmlearn/modules/encoders/vision.py
def get_patch_info(self) -> tuple[int, int]:
    """Get patch size and number of patches.

    Returns
    -------
    tuple[int, int]
        Patch size and number of patches.
    """
    patch_size = self.model.patch_embed.patch_size[0]
    num_patches = self.model.patch_embed.num_patches
    return patch_size, num_patches
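
Example

A sketch of the timm-based Vision Transformer wrapper; the Modalities import path is an assumption and the model name must be available through timm.

import torch
from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.vision import TimmViT

encoder = TimmViT(model_name="vit_base_patch16_224", projection_dim=512)

with torch.no_grad():
    outputs = encoder({Modalities.RGB.name: torch.randn(2, 3, 224, 224)})
print(outputs.last_hidden_state.shape)  # (2, 512) after pooling and the projection head

patch_size, num_patches = encoder.get_patch_info()  # (16, 196) for this model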

clip

Wrappers and interfaces for CLIP models.

HFCLIPTextEncoder

Bases: Module

Wrapper around the CLIPTextModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",  # required for `peft_config` to be converted to a `PeftConfig` object
)
class HFCLIPTextEncoder(nn.Module):
    """Wrapper around the ``CLIPTextModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden
            states, and the attention weights, if ``output_attentions`` is set
            to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get("attention_mask")
            or inputs.get(Modalities.TEXT.attention_mask),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden
        states, and the attention weights, if ``output_attentions`` is set
        to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get("attention_mask")
        or inputs.get(Modalities.TEXT.attention_mask),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
HFCLIPVisionEncoder

Bases: Module

Wrapper around the CLIPVisionModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias Optional[float]

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoder(nn.Module):
    """Wrapper around the ``CLIPVisionModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : Optional[float], optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model.vision_model
        self.pooling_layer = pooling_layer
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.

        """
        # FIXME: handle other vision modalities
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.encoder(
            inputs_embeds=hidden_states,
            output_attentions=inputs.get(
                "output_attentions", self.model.config.output_attentions
            ),
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = encoder_outputs[0]
        if self.pooling_layer is not None:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.

    """
    # FIXME: handle other vision modalities
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.encoder(
        inputs_embeds=hidden_states,
        output_attentions=inputs.get(
            "output_attentions", self.model.config.output_attentions
        ),
        output_hidden_states=True,
        return_dict=True,
    )

    last_hidden_state = encoder_outputs[0]
    if self.pooling_layer is not None:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
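
A minimal usage sketch (not part of the generated reference); the import path of Modalities below is an assumption based on the package layout:

import torch

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoder

# Load the pretrained CLIP ViT-B/16 vision tower.
encoder = HFCLIPVisionEncoder(model_name_or_path="openai/clip-vit-base-patch16")
encoder.eval()

# The encoder expects a dictionary keyed by the RGB modality name.
inputs = {Modalities.RGB.name: torch.randn(2, 3, 224, 224)}

with torch.no_grad():
    output = encoder(inputs)

print(output.last_hidden_state.shape)  # expected (2, 197, 768) for ViT-B/16
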
HFCLIPTextEncoderWithProjection

Bases: Module

Wrapper around the CLIPTextModelWithProjection from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the text. If False the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft <https://huggingface.co/docs/peft/index>_ library to use to wrap the model for parameter-efficient finetuning.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPTextEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPTextModelWithProjection`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the text. If ``False`` the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPTextConfig,
        )

        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The text embeddings. Will be a tuple with a single element.
        """
        input_ids = inputs[Modalities.TEXT.name]
        attention_mask: Optional[torch.Tensor] = inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        )
        position_ids = inputs.get("position_ids")

        if self.use_all_token_embeddings:
            text_outputs = self.model.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            )
            # TODO: add more options for pooling before projection
            text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
        else:
            text_embeds = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            ).text_embeds

        return (text_embeds,)
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
tuple[Tensor]

The text embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The text embeddings. Will be a tuple with a single element.
    """
    input_ids = inputs[Modalities.TEXT.name]
    attention_mask: Optional[torch.Tensor] = inputs.get(
        "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
    )
    position_ids = inputs.get("position_ids")

    if self.use_all_token_embeddings:
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        # TODO: add more options for pooling before projection
        text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
    else:
        text_embeds = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        ).text_embeds

    return (text_embeds,)
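
A usage sketch pairing the encoder with a HuggingFace tokenizer (the Modalities import path is an assumption):

import torch
from transformers import AutoTokenizer

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPTextEncoderWithProjection

encoder = HFCLIPTextEncoderWithProjection(
    model_name_or_path="openai/clip-vit-base-patch16"
)
encoder.eval()

tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")
batch = tokenizer(
    ["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt"
)

inputs = {
    Modalities.TEXT.name: batch["input_ids"],
    "attention_mask": batch["attention_mask"],
}

with torch.no_grad():
    (text_embeds,) = encoder(inputs)

print(text_embeds.shape)  # expected (2, 512): CLIP ViT-B/16 projection dimension
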
HFCLIPVisionEncoderWithProjection

Bases: Module

Wrapper around the CLIPVisionModelWithProjection class from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings of the image. If False, only the first token embedding will be used.


False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias float

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft <https://huggingface.co/docs/peft/index>_ library to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs dict[str, Any]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPVisionModelWithProjection`` class from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings of the image. If ``False``, only the
        first token embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : float, optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : dict[str, Any], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPVisionConfig,
        )

        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The image embeddings. Will be a tuple with a single element.
        """
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.vision_model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.vision_model.encoder(
            inputs_embeds=hidden_states, return_dict=True
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        if self.use_all_token_embeddings:
            pooled_output = last_hidden_state
        else:
            pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.model.vision_model.post_layernorm(pooled_output)

        return (self.model.visual_projection(pooled_output),)
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
tuple[Tensor]

The image embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The image embeddings. Will be a tuple with a single element.
    """
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.vision_model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.vision_model.encoder(
        inputs_embeds=hidden_states, return_dict=True
    )

    last_hidden_state = encoder_outputs.last_hidden_state
    if self.use_all_token_embeddings:
        pooled_output = last_hidden_state
    else:
        pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.model.vision_model.post_layernorm(pooled_output)

    return (self.model.visual_projection(pooled_output),)
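
A usage sketch; the projected image embeddings share the CLIP joint space with the text embeddings from HFCLIPTextEncoderWithProjection, so the two can be compared with cosine similarity (the Modalities import path is an assumption):

import torch
import torch.nn.functional as F

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoderWithProjection

encoder = HFCLIPVisionEncoderWithProjection(
    model_name_or_path="openai/clip-vit-base-patch16"
)
encoder.eval()

inputs = {Modalities.RGB.name: torch.randn(2, 3, 224, 224)}
with torch.no_grad():
    (image_embeds,) = encoder(inputs)

# L2-normalize before computing similarities against text embeddings.
image_embeds = F.normalize(image_embeds, dim=-1)
print(image_embeds.shape)  # expected (2, 512) for ViT-B/16
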

text

Huggingface text encoder model.

HFTextEncoder

Bases: Module

Wrapper around huggingface models in the AutoModelForTextEncoding class.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft <https://huggingface.co/docs/peft/index>_ library to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Raises:

Type Description
ValueError

If the model is a decoder model or if freezing individual layers is not supported for the model type.

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/text.py
@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
class HFTextEncoder(nn.Module):
    """Wrapper around huggingface models in the ``AutoModelForTextEncoding`` class.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Raises
    ------
    ValueError
        If the model is a decoder model or if freezing individual layers is not
        supported for the model type.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.


    """

    def __init__(  # noqa: PLR0912
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ):
        super().__init__()
        if model_config_kwargs is None:
            model_config_kwargs = {}
        model_config_kwargs["output_hidden_states"] = True
        model_config_kwargs["add_pooling_layer"] = False
        model = hf_utils.load_huggingface_model(
            AutoModelForTextEncoding,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        if hasattr(model.config, "is_decoder") and model.config.is_decoder:
            raise ValueError("Model is a decoder. Only encoder models are supported.")

        if not pretrained and freeze_layers:
            rank_zero_warn(
                "Freezing layers when loading a model with random weights may lead to "
                "unexpected behavior. Consider setting `freeze_layers=False` if "
                "`pretrained=False`.",
            )

        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "LayerNorm" in name else False
                )

        if isinstance(
            freeze_layers, (float, int, list)
        ) and model.config.model_type in ["flaubert", "xlm"]:
            # flaubert and xlm models have a different architecture that does not
            # support freezing individual layers in the same way as other models
            raise ValueError(
                f"Freezing individual layers is not supported for {model.config.model_type} "
                "models. Please use `freeze_layers=False` or `freeze_layers=True`."
            )

        # get list of layers
        embeddings = model.embeddings
        encoder = getattr(model, "encoder", None) or getattr(
            model, "transformer", model
        )
        encoder_layers = (
            getattr(encoder, "layer", None)
            or getattr(encoder, "layers", None)
            or getattr(encoder, "block", None)
        )
        if encoder_layers is None and hasattr(encoder, "albert_layer_groups"):
            encoder_layers = [
                layer
                for group in encoder.albert_layer_groups
                for layer in group.albert_layers
            ]
        modules = [embeddings]
        if encoder_layers is not None and isinstance(encoder_layers, list):
            modules.extend(encoder_layers)

        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "LayerNorm" in name else False
                        )

        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get(
                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
            ),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/text.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        ),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
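
A usage sketch with a BERT-style checkpoint and a fraction of the layers frozen; the checkpoint name and the Modalities import path are illustrative assumptions:

import torch
from transformers import AutoTokenizer

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.text import HFTextEncoder

# freeze_layers=0.5 freezes roughly the first half of the module list
# (the embeddings count as the first entry).
encoder = HFTextEncoder("bert-base-uncased", freeze_layers=0.5)
encoder.eval()

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["multimodal representation learning"], return_tensors="pt")

inputs = {
    Modalities.TEXT.name: batch["input_ids"],
    "attention_mask": batch["attention_mask"],
}
with torch.no_grad():
    output = encoder(inputs)

print(output.last_hidden_state.shape)  # expected (1, seq_len, 768) for BERT-base
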

vision

Vision encoder implementations.

TimmViT

Bases: Module

Vision Transformer model from timm.

Parameters:

Name Type Description Default
model_name str

The name of the model to use.

required
modality str

The modality of the input data. This allows this model to be used with different image modalities e.g. RGB, Depth, etc.

"RGB"
projection_dim int

The dimension of the projection head.

768
pretrained bool

Whether to use the pretrained weights.

True
freeze_layers Union[int, float, list[int], bool]

Whether to freeze the layers.

False
freeze_layer_norm bool

Whether to freeze the layer norm.

True
peft_config Optional[PeftConfig]

The configuration from the peft <https://huggingface.co/docs/peft/index>_ library to use to wrap the model for parameter-efficient finetuning.

None
model_kwargs Optional[dict[str, Any]]

Additional keyword arguments for the model.

None
Source code in mmlearn/modules/encoders/vision.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name="vit_base_patch16_224",
    hydra_convert="object",
)
class TimmViT(nn.Module):
    """Vision Transformer model from timm.

    Parameters
    ----------
    model_name : str
        The name of the model to use.
    modality : str, default="RGB"
        The modality of the input data. This allows this model to be used with different
        image modalities e.g. RGB, Depth, etc.
    projection_dim : int, default=768
        The dimension of the projection head.
    pretrained : bool, default=True
        Whether to use the pretrained weights.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze the layers.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer norm.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_kwargs : Optional[dict[str, Any]], default=None
        Additional keyword arguments for the model.
    """

    def __init__(
        self,
        model_name: str,
        modality: str = "RGB",
        projection_dim: int = 768,
        pretrained: bool = True,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.modality = Modalities.get_modality(modality)
        if model_kwargs is None:
            model_kwargs = {}

        self.model: TimmVisionTransformer = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=projection_dim,
            **model_kwargs,
        )
        assert isinstance(self.model, TimmVisionTransformer), (
            f"Model {model_name} is not a Vision Transformer. "
            "Please provide a model name that corresponds to a Vision Transformer."
        )

        self._freeze_layers(freeze_layers, freeze_layer_norm)

        if peft_config is not None:
            self.model = hf_utils._wrap_peft_model(self.model, peft_config)

    def _freeze_layers(
        self, freeze_layers: Union[int, float, list[int], bool], freeze_layer_norm: bool
    ) -> None:
        """Freeze the layers of the model.

        Parameters
        ----------
        freeze_layers : Union[int, float, list[int], bool]
            Whether to freeze the layers.
        freeze_layer_norm : bool
            Whether to freeze the layer norm.
        """
        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in self.model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "norm" in name else False
                )

        modules = [self.model.patch_embed, *self.model.blocks, self.model.norm]
        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "norm" in name else False
                        )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model.
        """
        x = inputs[self.modality.name]
        last_hidden_state, hidden_states = self.model.forward_intermediates(
            x, output_fmt="NLC"
        )
        last_hidden_state = self.model.forward_head(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states
        )

    def get_intermediate_layers(
        self, inputs: dict[str, Any], n: int = 1
    ) -> list[torch.Tensor]:
        """Get the output of the intermediate layers.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the ``Modalities.RGB``
            key.
        n : int, default=1
            The number of intermediate layers to return.

        Returns
        -------
        list[torch.Tensor]
            The outputs of the last n intermediate layers.
        """
        return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore

    def get_patch_info(self) -> tuple[int, int]:
        """Get patch size and number of patches.

        Returns
        -------
        tuple[int, int]
            Patch size and number of patches.
        """
        patch_size = self.model.patch_embed.patch_size[0]
        num_patches = self.model.patch_embed.num_patches
        return patch_size, num_patches
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model.

Source code in mmlearn/modules/encoders/vision.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model.
    """
    x = inputs[self.modality.name]
    last_hidden_state, hidden_states = self.model.forward_intermediates(
        x, output_fmt="NLC"
    )
    last_hidden_state = self.model.forward_head(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state, hidden_states=hidden_states
    )
get_intermediate_layers
get_intermediate_layers(inputs, n=1)

Get the output of the intermediate layers.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required
n int

The number of intermediate layers to return.

1

Returns:

Type Description
list[Tensor]

The outputs of the last n intermediate layers.

Source code in mmlearn/modules/encoders/vision.py
def get_intermediate_layers(
    self, inputs: dict[str, Any], n: int = 1
) -> list[torch.Tensor]:
    """Get the output of the intermediate layers.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the ``Modalities.RGB``
        key.
    n : int, default=1
        The number of intermediate layers to return.

    Returns
    -------
    list[torch.Tensor]
        The outputs of the last n intermediate layers.
    """
    return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore
get_patch_info
get_patch_info()

Get patch size and number of patches.

Returns:

Type Description
tuple[int, int]

Patch size and number of patches.

Source code in mmlearn/modules/encoders/vision.py
def get_patch_info(self) -> tuple[int, int]:
    """Get patch size and number of patches.

    Returns
    -------
    tuple[int, int]
        Patch size and number of patches.
    """
    patch_size = self.model.patch_embed.patch_size[0]
    num_patches = self.model.patch_embed.num_patches
    return patch_size, num_patches
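
A usage sketch (with pretrained=False to avoid a weight download; the Modalities import path is an assumption):

import torch

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.vision import TimmViT

encoder = TimmViT("vit_base_patch16_224", pretrained=False, projection_dim=512)
encoder.eval()

inputs = {Modalities.RGB.name: torch.randn(2, 3, 224, 224)}
with torch.no_grad():
    output = encoder(inputs)

# The timm classifier head acts as the projection, so the pooled output
# has projection_dim features.
print(output.last_hidden_state.shape)  # expected (2, 512)

patch_size, num_patches = encoder.get_patch_info()
print(patch_size, num_patches)  # 16, 196 for a 224x224 input
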
VisionTransformer

Bases: Module

Vision Transformer.

This module implements a Vision Transformer that processes images using a series of transformer blocks and patch embeddings.

Parameters:

Name Type Description Default
modality str

The modality of the input data. This allows this model to be used with different image modalities e.g. RGB, Depth, etc.

"RGB"
img_size List[int]

List of input image sizes.

None
patch_size int

Size of each patch.

16
in_chans int

Number of input channels.

3
embed_dim int

Embedding dimension.

768
depth int

Number of transformer blocks.

12
num_heads int

Number of attention heads.

12
mlp_ratio float

Ratio of hidden dimension in the MLP.

4.0
qkv_bias bool

If True, add a learnable bias to the query, key, and value projections.

True
qk_scale Optional[float]

Override the default qk scale factor.

None
drop_rate float

Dropout rate for the transformer blocks.

0.0
attn_drop_rate float

Dropout rate for the attention mechanism.

0.0
drop_path_rate float

Dropout rate for stochastic depth.

0.0
norm_layer Callable[..., Module]

Normalization layer to use.

torch.nn.LayerNorm
init_std float

Standard deviation for weight initialization.

0.02
Source code in mmlearn/modules/encoders/vision.py
class VisionTransformer(nn.Module):
    """Vision Transformer.

    This module implements a Vision Transformer that processes images using a
    series of transformer blocks and patch embeddings.

    Parameters
    ----------
    modality : str, optional, default="RGB"
        The modality of the input data. This allows this model to be used with different
        image modalities e.g. RGB, Depth, etc.
    img_size : List[int], optional, default=None
        List of input image sizes.
    patch_size : int, optional, default=16
        Size of each patch.
    in_chans : int, optional, default=3
        Number of input channels.
    embed_dim : int, optional, default=768
        Embedding dimension.
    depth : int, optional, default=12
        Number of transformer blocks.
    num_heads : int, optional, default=12
        Number of attention heads.
    mlp_ratio : float, optional, default=4.0
        Ratio of hidden dimension in the MLP.
    qkv_bias : bool, optional, default=True
        If True, add a learnable bias to the query, key, and value projections.
    qk_scale : Optional[float], optional
        Override the default qk scale factor.
    drop_rate : float, optional, default=0.0
        Dropout rate for the transformer blocks.
    attn_drop_rate : float, optional, default=0.0
        Dropout rate for the attention mechanism.
    drop_path_rate : float, optional, default=0.0
        Dropout rate for stochastic depth.
    norm_layer : Callable[..., torch.nn.Module], optional, default=torch.nn.LayerNorm
        Normalization layer to use.
    init_std : float, optional, default=0.02
        Standard deviation for weight initialization.
    """

    def __init__(
        self,
        modality: str = "RGB",
        img_size: Optional[list[int]] = None,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        global_pool: Literal["", "avg", "avgmax", "max", "token"] = "",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        init_std: float = 0.02,
    ) -> None:
        super().__init__()
        assert global_pool in ("", "avg", "avgmax", "max", "token")

        self.modality = Modalities.get_modality(modality)
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads
        img_size = [224, 224] if img_size is None else img_size

        # Patch Embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size[0],
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        # Positional Embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim), requires_grad=False
        )
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.patch_embed.num_patches**0.5),
            cls_token=False,
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Transformer Blocks
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        self.global_pool = global_pool

        # Weight Initialization
        self.init_std = init_std
        self.apply(self._init_weights)

    def fix_init_weight(self) -> None:
        """Fix initialization of weights by rescaling them according to layer depth."""

        def rescale(param: torch.Tensor, layer_id: int) -> None:
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp[-1].weight.data, layer_id + 1)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize weights for the layers."""
        if isinstance(m, nn.Linear):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self, inputs: dict[str, Any], return_hidden_states: bool = False
    ) -> tuple[torch.Tensor, Optional[list[torch.Tensor]]]:
        """Forward pass through the Vision Transformer."""
        masks = inputs.get(self.modality.mask)
        if masks is not None and not isinstance(masks, list):
            masks = [masks]

        x = inputs[self.modality.name]
        # -- Patchify x
        x = self.patch_embed(x)

        # -- Add positional embedding to x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed

        # -- Mask x
        if masks is not None:
            x = apply_masks(x, masks)

        # -- Initialize a list to store hidden states
        hidden_states: Optional[list[torch.Tensor]] = (
            [] if return_hidden_states else None
        )

        # -- Forward propagation through blocks
        for _i, blk in enumerate(self.blocks):
            x = blk(x)
            if return_hidden_states and hidden_states is not None:
                hidden_states.append(x)

        # -- Apply normalization if present
        if self.norm is not None:
            x = self.norm(x)

        # -- Apply global pooling
        x = global_pool_nlc(x, pool_type=self.global_pool)

        # -- Return both final output and hidden states if requested
        if return_hidden_states:
            return x, hidden_states
        return (x, None)

    def interpolate_pos_encoding(
        self, x: torch.Tensor, pos_embed: torch.Tensor
    ) -> torch.Tensor:
        """Interpolate positional encoding to match the size of the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        pos_embed : torch.Tensor
            Positional embedding tensor.

        Returns
        -------
        torch.Tensor
            Interpolated positional encoding.
        """
        npatch = x.shape[1] - 1
        n = pos_embed.shape[1] - 1
        if npatch == n:
            return pos_embed
        class_emb = pos_embed[:, 0]
        pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(n)), int(math.sqrt(n)), dim).permute(
                0, 3, 1, 2
            ),
            scale_factor=math.sqrt(npatch / n),
            mode="bicubic",
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
fix_init_weight
fix_init_weight()

Fix initialization of weights by rescaling them according to layer depth.

Source code in mmlearn/modules/encoders/vision.py
def fix_init_weight(self) -> None:
    """Fix initialization of weights by rescaling them according to layer depth."""

    def rescale(param: torch.Tensor, layer_id: int) -> None:
        param.div_(math.sqrt(2.0 * layer_id))

    for layer_id, layer in enumerate(self.blocks):
        rescale(layer.attn.proj.weight.data, layer_id + 1)
        rescale(layer.mlp[-1].weight.data, layer_id + 1)
forward
forward(inputs, return_hidden_states=False)

Forward pass through the Vision Transformer.

Source code in mmlearn/modules/encoders/vision.py
def forward(
    self, inputs: dict[str, Any], return_hidden_states: bool = False
) -> tuple[torch.Tensor, Optional[list[torch.Tensor]]]:
    """Forward pass through the Vision Transformer."""
    masks = inputs.get(self.modality.mask)
    if masks is not None and not isinstance(masks, list):
        masks = [masks]

    x = inputs[self.modality.name]
    # -- Patchify x
    x = self.patch_embed(x)

    # -- Add positional embedding to x
    pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
    x = x + pos_embed

    # -- Mask x
    if masks is not None:
        x = apply_masks(x, masks)

    # -- Initialize a list to store hidden states
    hidden_states: Optional[list[torch.Tensor]] = (
        [] if return_hidden_states else None
    )

    # -- Forward propagation through blocks
    for _i, blk in enumerate(self.blocks):
        x = blk(x)
        if return_hidden_states and hidden_states is not None:
            hidden_states.append(x)

    # -- Apply normalization if present
    if self.norm is not None:
        x = self.norm(x)

    # -- Apply global pooling
    x = global_pool_nlc(x, pool_type=self.global_pool)

    # -- Return both final output and hidden states if requested
    if return_hidden_states:
        return x, hidden_states
    return (x, None)
interpolate_pos_encoding
interpolate_pos_encoding(x, pos_embed)

Interpolate positional encoding to match the size of the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
pos_embed Tensor

Positional embedding tensor.

required

Returns:

Type Description
Tensor

Interpolated positional encoding.

Source code in mmlearn/modules/encoders/vision.py
def interpolate_pos_encoding(
    self, x: torch.Tensor, pos_embed: torch.Tensor
) -> torch.Tensor:
    """Interpolate positional encoding to match the size of the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    pos_embed : torch.Tensor
        Positional embedding tensor.

    Returns
    -------
    torch.Tensor
        Interpolated positional encoding.
    """
    npatch = x.shape[1] - 1
    n = pos_embed.shape[1] - 1
    if npatch == n:
        return pos_embed
    class_emb = pos_embed[:, 0]
    pos_embed = pos_embed[:, 1:]
    dim = x.shape[-1]
    pos_embed = nn.functional.interpolate(
        pos_embed.reshape(1, int(math.sqrt(n)), int(math.sqrt(n)), dim).permute(
            0, 3, 1, 2
        ),
        scale_factor=math.sqrt(npatch / n),
        mode="bicubic",
    )
    pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
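
A usage sketch of the standalone VisionTransformer with a reduced depth; with the default global pooling setting ("") the per-patch tokens are returned unpooled. The Modalities import path is an assumption:

import torch

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.vision import VisionTransformer

model = VisionTransformer(
    img_size=[224], patch_size=16, embed_dim=192, depth=2, num_heads=3
)
model.eval()

inputs = {Modalities.RGB.name: torch.randn(2, 3, 224, 224)}
with torch.no_grad():
    tokens, hidden_states = model(inputs, return_hidden_states=True)

print(tokens.shape)        # expected (2, 196, 192): one token per 16x16 patch
print(len(hidden_states))  # 2: one entry per transformer block
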
VisionTransformerPredictor

Bases: Module

Vision Transformer Predictor.

This module implements a Vision Transformer that predicts masked tokens using a series of transformer blocks.

Parameters:

Name Type Description Default
num_patches int

The number of patches in the input image.

196
embed_dim int

The embedding dimension.

768
predictor_embed_dim int

The embedding dimension for the predictor.

384
depth int

The number of transformer blocks.

6
num_heads int

The number of attention heads.

12
mlp_ratio float

Ratio of the hidden dimension in the MLP.

4.0
qkv_bias bool

If True, add a learnable bias to the query, key, and value projections.

True
qk_scale Optional[float]

Override the default qk scale factor.

None
drop_rate float

Dropout rate for the transformer blocks.

0.0
attn_drop_rate float

Dropout rate for the attention mechanism.

0.0
drop_path_rate float

Dropout rate for stochastic depth.

0.0
norm_layer Callable[..., Module]

Normalization layer to use.

torch.nn.LayerNorm
init_std float

Standard deviation for weight initialization.

0.02
Source code in mmlearn/modules/encoders/vision.py
class VisionTransformerPredictor(nn.Module):
    """Vision Transformer Predictor.

    This module implements a Vision Transformer that predicts masked tokens
    using a series of transformer blocks.

    Parameters
    ----------
    num_patches : int
        The number of patches in the input image.
    embed_dim : int, optional, default=768
        The embedding dimension.
    predictor_embed_dim : int, optional, default=384
        The embedding dimension for the predictor.
    depth : int, optional, default=6
        The number of transformer blocks.
    num_heads : int, optional, default=12
        The number of attention heads.
    mlp_ratio : float, optional, default=4.0
        Ratio of the hidden dimension in the MLP.
    qkv_bias : bool, optional, default=True
        If True, add a learnable bias to the query, key, and value projections.
    qk_scale : Optional[float], optional, default=None
        Override the default qk scale factor.
    drop_rate : float, optional, default=0.0
        Dropout rate for the transformer blocks.
    attn_drop_rate : float, optional, default=0.0
        Dropout rate for the attention mechanism.
    drop_path_rate : float, optional, default=0.0
        Dropout rate for stochastic depth.
    norm_layer : Callable[..., torch.nn.Module], optional, default=torch.nn.LayerNorm
        Normalization layer to use.
    init_std : float, optional, default=0.02
        Standard deviation for weight initialization.
    """

    def __init__(
        self,
        num_patches: int = 196,
        embed_dim: int = 768,
        predictor_embed_dim: int = 384,
        depth: int = 6,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        init_std: float = 0.02,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.predictor_embed = nn.Linear(self.embed_dim, predictor_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        # Positional Embedding
        self.predictor_pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, predictor_embed_dim), requires_grad=False
        )
        predictor_pos_embed = get_2d_sincos_pos_embed(
            self.predictor_pos_embed.shape[-1],
            int(self.num_patches**0.5),
            cls_token=False,
        )
        self.predictor_pos_embed.data.copy_(
            torch.from_numpy(predictor_pos_embed).float().unsqueeze(0)
        )

        # Transformer Blocks
        self.predictor_blocks = nn.ModuleList(
            [
                Block(
                    dim=predictor_embed_dim,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        self.predictor_norm = norm_layer(predictor_embed_dim)
        self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)

        # Weight Initialization
        self.init_std = init_std
        _trunc_normal(self.mask_token, std=self.init_std)
        self.apply(self._init_weights)

    def fix_init_weight(self) -> None:
        """Fix initialization of weights by rescaling them according to layer depth."""

        def rescale(param: torch.Tensor, layer_id: int) -> None:
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.predictor_blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize weights for the layers."""
        if isinstance(m, nn.Linear):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
        masks_x: Union[torch.Tensor, list[torch.Tensor]],
        masks: Union[torch.Tensor, list[torch.Tensor]],
    ) -> torch.Tensor:
        """Forward pass through the Vision Transformer Predictor."""
        assert (masks is not None) and (masks_x is not None), (
            "Cannot run predictor without mask indices"
        )

        if not isinstance(masks_x, list):
            masks_x = [masks_x]

        if not isinstance(masks, list):
            masks = [masks]

        # -- Batch Size
        b = len(x) // len(masks_x)

        # -- Map from encoder-dim to predictor-dim
        x = self.predictor_embed(x)

        # -- Add positional embedding to x tokens
        x_pos_embed = self.predictor_pos_embed.repeat(b, 1, 1)
        x += apply_masks(x_pos_embed, masks_x)

        _, n_ctxt, d = x.shape

        # -- Concatenate mask tokens to x
        pos_embs = self.predictor_pos_embed.repeat(b, 1, 1)
        pos_embs = apply_masks(pos_embs, masks)
        pos_embs = repeat_interleave_batch(pos_embs, b, repeat=len(masks_x))
        pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
        pred_tokens += pos_embs
        x = x.repeat(len(masks), 1, 1)
        x = torch.cat([x, pred_tokens], dim=1)

        # -- Forward propagation
        for blk in self.predictor_blocks:
            x = blk(x)
        x = self.predictor_norm(x)

        # -- Return predictions for mask tokens
        x = x[:, n_ctxt:]
        return self.predictor_proj(x)
fix_init_weight
fix_init_weight()

Fix initialization of weights by rescaling them according to layer depth.

Source code in mmlearn/modules/encoders/vision.py
def fix_init_weight(self) -> None:
    """Fix initialization of weights by rescaling them according to layer depth."""

    def rescale(param: torch.Tensor, layer_id: int) -> None:
        param.div_(math.sqrt(2.0 * layer_id))

    for layer_id, layer in enumerate(self.predictor_blocks):
        rescale(layer.attn.proj.weight.data, layer_id + 1)
        rescale(layer.mlp.fc2.weight.data, layer_id + 1)
forward
forward(x, masks_x, masks)

Forward pass through the Vision Transformer Predictor.

Source code in mmlearn/modules/encoders/vision.py
def forward(
    self,
    x: torch.Tensor,
    masks_x: Union[torch.Tensor, list[torch.Tensor]],
    masks: Union[torch.Tensor, list[torch.Tensor]],
) -> torch.Tensor:
    """Forward pass through the Vision Transformer Predictor."""
    assert (masks is not None) and (masks_x is not None), (
        "Cannot run predictor without mask indices"
    )

    if not isinstance(masks_x, list):
        masks_x = [masks_x]

    if not isinstance(masks, list):
        masks = [masks]

    # -- Batch Size
    b = len(x) // len(masks_x)

    # -- Map from encoder-dim to predictor-dim
    x = self.predictor_embed(x)

    # -- Add positional embedding to x tokens
    x_pos_embed = self.predictor_pos_embed.repeat(b, 1, 1)
    x += apply_masks(x_pos_embed, masks_x)

    _, n_ctxt, d = x.shape

    # -- Concatenate mask tokens to x
    pos_embs = self.predictor_pos_embed.repeat(b, 1, 1)
    pos_embs = apply_masks(pos_embs, masks)
    pos_embs = repeat_interleave_batch(pos_embs, b, repeat=len(masks_x))
    pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
    pred_tokens += pos_embs
    x = x.repeat(len(masks), 1, 1)
    x = torch.cat([x, pred_tokens], dim=1)

    # -- Forward propagation
    for blk in self.predictor_blocks:
        x = blk(x)
    x = self.predictor_norm(x)

    # -- Return predictions for mask tokens
    x = x[:, n_ctxt:]
    return self.predictor_proj(x)
vit_predictor
vit_predictor(kwargs=None)

Create a VisionTransformerPredictor model.

Parameters:

Name Type Description Default
kwargs dict[str, Any]

Keyword arguments for the predictor.

None

Returns:

Type Description
VisionTransformerPredictor

An instance of VisionTransformerPredictor.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformerPredictor,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_predictor(
    kwargs: Optional[dict[str, Any]] = None,
) -> VisionTransformerPredictor:
    """Create a VisionTransformerPredictor model.

    Parameters
    ----------
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the predictor.

    Returns
    -------
    VisionTransformerPredictor
        An instance of VisionTransformerPredictor.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformerPredictor(
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
    )
vit_tiny
vit_tiny(patch_size=16, kwargs=None)

Create a VisionTransformer model with tiny configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_tiny(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with tiny configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
vit_small
vit_small(patch_size=16, kwargs=None)

Create a VisionTransformer model with small configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_small(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with small configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
vit_base
vit_base(patch_size=16, kwargs=None)

Create a VisionTransformer model with base configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_base(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with base configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
vit_large
vit_large(patch_size=16, kwargs=None)

Create a VisionTransformer model with large configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_large(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with large configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
vit_huge
vit_huge(patch_size=16, kwargs=None)

Create a VisionTransformer model with huge configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_huge(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with huge configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
vit_giant
vit_giant(patch_size=16, kwargs=None)

Create a VisionTransformer model with giant configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_giant(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with giant configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48 / 11,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
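
A minimal sketch comparing the variant factories, assuming they are importable from mmlearn.modules.encoders.vision (the source module shown above) and that the remaining constructor arguments all have defaults; it only builds the models and counts their parameters:

from mmlearn.modules.encoders.vision import vit_tiny, vit_small, vit_base

for name, factory in [("vit_tiny", vit_tiny), ("vit_small", vit_small), ("vit_base", vit_base)]:
    model = factory(patch_size=16)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"{name}: {num_params / 1e6:.1f}M parameters")

The variants differ only in embedding dimension, depth, and number of attention heads; the patch size and any extra keyword arguments are forwarded unchanged to VisionTransformer.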

layers

Custom, reusable layers for models and tasks.

LearnableLogitScaling

Bases: Module

Logit scaling layer.

Parameters:

Name Type Description Default
init_logit_scale float

Initial value of the logit scale.

1/0.07
learnable bool

If True, the logit scale is learnable. Otherwise, it is fixed.

True
max_logit_scale float

Maximum value of the logit scale.

100
Source code in mmlearn/modules/layers/logit_scaling.py
@store(group="modules/layers", provider="mmlearn")
class LearnableLogitScaling(torch.nn.Module):
    """Logit scaling layer.

    Parameters
    ----------
    init_logit_scale : float, optional, default=1/0.07
        Initial value of the logit scale.
    learnable : bool, optional, default=True
        If True, the logit scale is learnable. Otherwise, it is fixed.
    max_logit_scale : float, optional, default=100
        Maximum value of the logit scale.
    """

    def __init__(
        self,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.init_logit_scale = init_logit_scale
        self.learnable = learnable
        log_logit_scale = torch.ones([]) * np.log(self.init_logit_scale)
        if learnable:
            self.log_logit_scale = torch.nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the logit scaling to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self) -> str:
        """Return the string representation of the layer."""
        return (
            f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
            f" max_logit_scale={self.max_logit_scale}"
        )
forward
forward(x)

Apply the logit scaling to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
Source code in mmlearn/modules/layers/logit_scaling.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the logit scaling to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
extra_repr
extra_repr()

Return the string representation of the layer.

Source code in mmlearn/modules/layers/logit_scaling.py
def extra_repr(self) -> str:
    """Return the string representation of the layer."""
    return (
        f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
        f" max_logit_scale={self.max_logit_scale}"
    )
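
A minimal usage sketch, assuming the class is re-exported from mmlearn.modules.layers (otherwise import it from mmlearn.modules.layers.logit_scaling):

import torch
from mmlearn.modules.layers import LearnableLogitScaling

scaling = LearnableLogitScaling(init_logit_scale=1 / 0.07, learnable=True)
x = torch.randn(2, 5, 8)
y = scaling(x)
# exp(log(1/0.07)) == 1/0.07 ~= 14.29, so every element is scaled by ~14.29
print(torch.allclose(y, x * (1 / 0.07)))  # True

Because the scale is stored as a log-parameter and exponentiated in the forward pass, the effective multiplier stays positive during optimization and is clipped at max_logit_scale.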

MLP

Bases: Sequential

Multi-layer perceptron (MLP).

This module will create a block of Linear -> Normalization -> Activation -> Dropout layers.

Parameters:

Name Type Description Default
in_dim int

The input dimension.

required
out_dim Optional[int]

The output dimension. If not specified, it is set to :attr:in_dim.

None
hidden_dims Optional[list]

The dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with :attr:hidden_dims_multiplier.

None
hidden_dims_multiplier Optional[list]

The multipliers to apply to the input dimension to get the dimensions of the hidden layers. The length of the list determines the number of hidden layers. The multipliers will be used to get the dimensions of the hidden layers. This parameter is mutually exclusive with hidden_dims.

None
apply_multiplier_to_in_dim bool

Whether to apply the :attr:hidden_dims_multiplier to :attr:in_dim to get the dimensions of the hidden layers. If False, the multipliers will be applied to the dimensions of the previous hidden layer, starting from :attr:in_dim. This parameter is only relevant when :attr:hidden_dims_multiplier is specified.

False
norm_layer Optional[Callable[..., Module]]

The normalization layer to use. If not specified, no normalization is used. Partial functions can be used to specify the normalization layer with specific parameters.

None
activation_layer Optional[Callable[..., Module]]

The activation layer to use. If not specified, ReLU is used. Partial functions can be used to specify the activation layer with specific parameters.

torch.nn.ReLU
bias Union[bool, list[bool]]

Whether to use bias in the linear layers.

True
dropout Union[float, list[float]]

The dropout probability to use.

0.0

Raises:

Type Description
ValueError

If both :attr:hidden_dims and :attr:hidden_dims_multiplier are specified or if the lengths of :attr:bias and :attr:hidden_dims do not match or if the lengths of :attr:dropout and :attr:hidden_dims do not match.

Source code in mmlearn/modules/layers/mlp.py
@store(group="modules/layers", provider="mmlearn")
class MLP(torch.nn.Sequential):
    """Multi-layer perceptron (MLP).

    This module will create a block of ``Linear -> Normalization -> Activation -> Dropout``
    layers.

    Parameters
    ----------
    in_dim : int
        The input dimension.
    out_dim : Optional[int], optional, default=None
        The output dimension. If not specified, it is set to :attr:`in_dim`.
    hidden_dims : Optional[list], optional, default=None
        The dimensions of the hidden layers. The length of the list determines the
        number of hidden layers. This parameter is mutually exclusive with
        :attr:`hidden_dims_multiplier`.
    hidden_dims_multiplier : Optional[list], optional, default=None
        The multipliers to apply to the input dimension to get the dimensions of
        the hidden layers. The length of the list determines the number of hidden
        layers. The multipliers will be used to get the dimensions of the hidden
        layers. This parameter is mutually exclusive with `hidden_dims`.
    apply_multiplier_to_in_dim : bool, optional, default=False
        Whether to apply the :attr:`hidden_dims_multiplier` to :attr:`in_dim` to get the
        dimensions of the hidden layers. If ``False``, the multipliers will be applied
        to the dimensions of the previous hidden layer, starting from :attr:`in_dim`.
        This parameter is only relevant when :attr:`hidden_dims_multiplier` is
        specified.
    norm_layer : Optional[Callable[..., torch.nn.Module]], optional, default=None
        The normalization layer to use. If not specified, no normalization is used.
        Partial functions can be used to specify the normalization layer with specific
        parameters.
    activation_layer : Optional[Callable[..., torch.nn.Module]], optional, default=torch.nn.ReLU
        The activation layer to use. If not specified, ReLU is used. Partial functions
        can be used to specify the activation layer with specific parameters.
    bias : Union[bool, list[bool]], optional, default=True
        Whether to use bias in the linear layers.
    dropout : Union[float, list[float]], optional, default=0.0
        The dropout probability to use.

    Raises
    ------
    ValueError
        If both :attr:`hidden_dims` and :attr:`hidden_dims_multiplier` are specified
        or if the lengths of :attr:`bias` and :attr:`hidden_dims` do not match or if
        the lengths of :attr:`dropout` and :attr:`hidden_dims` do not match.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        hidden_dims: Optional[list[int]] = None,
        hidden_dims_multiplier: Optional[list[float]] = None,
        apply_multiplier_to_in_dim: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        bias: Union[bool, list[bool]] = True,
        dropout: Union[float, list[float]] = 0.0,
    ) -> None:
        if hidden_dims is None and hidden_dims_multiplier is None:
            hidden_dims = []
        if hidden_dims is not None and hidden_dims_multiplier is not None:
            raise ValueError(
                "Only one of `hidden_dims` or `hidden_dims_multiplier` must be specified."
            )

        if hidden_dims is None and hidden_dims_multiplier is not None:
            if apply_multiplier_to_in_dim:
                hidden_dims = [
                    int(in_dim * multiplier) for multiplier in hidden_dims_multiplier
                ]
            else:
                hidden_dims = [int(in_dim * hidden_dims_multiplier[0])]
                for multiplier in hidden_dims_multiplier[1:]:
                    hidden_dims.append(int(hidden_dims[-1] * multiplier))

        if isinstance(bias, bool):
            bias_list: list[bool] = [bias] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            bias_list = bias
        if len(bias_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `bias` to be a boolean or a list of booleans with length "
                "equal to the number of linear layers in the MLP."
            )

        if isinstance(dropout, float):
            dropout_list: list[float] = [dropout] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            dropout_list = dropout
        if len(dropout_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `dropout` to be a float or a list of floats with length "
                "equal to the number of linear layers in the MLP."
            )

        # construct list of dimensions for the layers
        dims = [in_dim] + hidden_dims  # type: ignore[operator]
        layers = []
        for layer_idx, (in_features, hidden_features) in enumerate(
            zip(dims[:-1], dims[1:], strict=False)
        ):
            layers.append(
                torch.nn.Linear(in_features, hidden_features, bias=bias_list[layer_idx])
            )
            if norm_layer is not None:
                layers.append(norm_layer(hidden_features))
            if activation_layer is not None:
                layers.append(activation_layer())
            layers.append(torch.nn.Dropout(dropout_list[layer_idx]))

        out_dim = out_dim or in_dim

        layers.append(torch.nn.Linear(dims[-1], out_dim, bias=bias_list[-1]))
        layers.append(torch.nn.Dropout(dropout_list[-1]))

        super().__init__(*layers)
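
A minimal construction sketch, assuming MLP is re-exported from mmlearn.modules.layers (otherwise import it from mmlearn.modules.layers.mlp):

import torch
from mmlearn.modules.layers import MLP

# Two hidden layers of width 32, with LayerNorm between each Linear and ReLU.
mlp = MLP(in_dim=16, out_dim=8, hidden_dims=[32, 32], norm_layer=torch.nn.LayerNorm)
print(mlp(torch.randn(4, 16)).shape)  # torch.Size([4, 8])

# The same hidden sizes expressed as multipliers of the previous layer's width;
# out_dim defaults to in_dim (16) when not given.
mlp2 = MLP(in_dim=16, hidden_dims_multiplier=[2.0, 1.0])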

L2Norm

Bases: Module

L2 normalization.

Parameters:

Name Type Description Default
dim int

The dimension along which to normalize.

required
Source code in mmlearn/modules/layers/normalization.py
@store(group="modules/layers", provider="mmlearn")
class L2Norm(torch.nn.Module):
    """L2 normalization.

    Parameters
    ----------
    dim : int
        The dimension along which to normalize.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply L2 normalization to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)
forward
forward(x)

Apply L2 normalization to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Normalized tensor of shape (batch_sz, seq_len, dim).

Source code in mmlearn/modules/layers/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply L2 normalization to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.nn.functional.normalize(x, dim=self.dim, p=2)
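
A minimal usage sketch, assuming L2Norm is re-exported from mmlearn.modules.layers (otherwise import it from mmlearn.modules.layers.normalization):

import torch
from mmlearn.modules.layers import L2Norm

l2norm = L2Norm(dim=-1)
x = torch.randn(2, 4, 8)
y = l2norm(x)
print(y.norm(dim=-1))  # all approximately 1.0: each vector is projected onto the unit sphere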

PatchDropout

Bases: Module

Patch dropout layer.

Drops patch tokens (after embedding and adding the CLS token) from the input tensor. Usually used in vision transformers to reduce the number of tokens. [1]

Parameters:

Name Type Description Default
keep_rate float

The proportion of tokens to keep.

0.5
bias Optional[float]

The bias to add to the random noise before sorting.

None
token_shuffling bool

If True, the tokens are shuffled.

False
References

[1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023). PatchDropout: Economizing vision transformers using patch dropout. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 3953-3962).

Source code in mmlearn/modules/layers/patch_dropout.py
class PatchDropout(torch.nn.Module):
    """Patch dropout layer.

    Drops patch tokens (after embedding and adding CLS token) from the input tensor.
    Usually used in vision transformers to reduce the number of tokens. [1]_

    Parameters
    ----------
    keep_rate : float, optional, default=0.5
        The proportion of tokens to keep.
    bias : Optional[float], optional, default=None
        The bias to add to the random noise before sorting.
    token_shuffling : bool, optional, default=False
        If True, the tokens are shuffled.

    References
    ----------
    .. [1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023).
       Patchdropout: Economizing vision transformers using patch dropout. In Proceedings
       of the IEEE/CVF Winter Conference on Applications of Computer Vision
       (pp. 3953-3962).
    """

    def __init__(
        self,
        keep_rate: float = 0.5,
        bias: Optional[float] = None,
        token_shuffling: bool = False,
    ):
        super().__init__()
        assert 0 < keep_rate <= 1, "The keep_rate must be in (0,1]"

        self.bias = bias
        self.keep_rate = keep_rate
        self.token_shuffling = token_shuffling

    def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
        """Drop tokens from the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        force_drop : bool, optional, default=False
            If True, the tokens are always dropped, even when the model is in
            evaluation mode.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
        """
        if (not self.training and not force_drop) or self.keep_rate == 1:
            return x

        batch_sz, _, dim = x.shape

        cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
            batch_sz, 1, dtype=torch.int64, device=x.device
        )
        patch_mask = self.uniform_mask(x)
        patch_mask = torch.hstack([cls_mask, patch_mask])

        return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))

    def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Generate token ids to keep from uniform random distribution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

        """
        batch_sz, seq_len, _ = x.shape
        seq_len = seq_len - 1  # patch length (without CLS)

        keep_len = int(seq_len * self.keep_rate)
        noise = torch.rand(batch_sz, seq_len, device=x.device)
        if self.bias is not None:
            noise += self.bias
        ids = torch.argsort(noise, dim=1)
        keep_ids = ids[:, :keep_len]
        if not self.token_shuffling:
            keep_ids = keep_ids.sort(1)[0]
        return keep_ids
forward
forward(x, force_drop=False)

Drop tokens from the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
force_drop bool

If True, the tokens are always dropped, even when the model is in evaluation mode.

False

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len, dim) containing the kept tokens.

Source code in mmlearn/modules/layers/patch_dropout.py
def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
    """Drop tokens from the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    force_drop : bool, optional, default=False
        If True, the tokens are always dropped, even when the model is in
        evaluation mode.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
    """
    if (not self.training and not force_drop) or self.keep_rate == 1:
        return x

    batch_sz, _, dim = x.shape

    cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
        batch_sz, 1, dtype=torch.int64, device=x.device
    )
    patch_mask = self.uniform_mask(x)
    patch_mask = torch.hstack([cls_mask, patch_mask])

    return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))
uniform_mask
uniform_mask(x)

Generate token ids to keep from uniform random distribution.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len) containing the token ids to keep.

Source code in mmlearn/modules/layers/patch_dropout.py
def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
    """Generate token ids to keep from uniform random distribution.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

    """
    batch_sz, seq_len, _ = x.shape
    seq_len = seq_len - 1  # patch length (without CLS)

    keep_len = int(seq_len * self.keep_rate)
    noise = torch.rand(batch_sz, seq_len, device=x.device)
    if self.bias is not None:
        noise += self.bias
    ids = torch.argsort(noise, dim=1)
    keep_ids = ids[:, :keep_len]
    if not self.token_shuffling:
        keep_ids = keep_ids.sort(1)[0]
    return keep_ids
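
A minimal usage sketch (import path assumed to be mmlearn.modules.layers.patch_dropout). With keep_rate=0.5 on a ViT-B/16-style sequence of one CLS token plus 196 patch tokens, half of the patch tokens are kept:

import torch
from mmlearn.modules.layers.patch_dropout import PatchDropout

drop = PatchDropout(keep_rate=0.5)
drop.train()                   # tokens are only dropped in training mode (or with force_drop=True)
x = torch.randn(2, 197, 768)   # CLS token + 196 patch tokens
y = drop(x)
print(y.shape)                 # torch.Size([2, 99, 768]): CLS + 98 kept patch tokens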

attention

Attention modules for Vision Transformer (ViT) and related models.

Attention

Bases: Module

Multi-head Self-Attention Mechanism.

Parameters:

Name Type Description Default
dim int

Number of input dimensions.

required
num_heads int

Number of attention heads.

8
qkv_bias bool

If True, adds a learnable bias to the query, key, value projections.

False
qk_scale Optional[float]

Override the default scale factor for the dot-product attention.

None
attn_drop float

Dropout probability for the attention weights.

0.0
proj_drop float

Dropout probability for the output of the attention layer.

0.0
Source code in mmlearn/modules/layers/attention.py
class Attention(nn.Module):
    """Multi-head Self-Attention Mechanism.

    Parameters
    ----------
    dim : int
        Number of input dimensions.
    num_heads : int, optional, default=8
        Number of attention heads.
    qkv_bias : bool, optional, default=False
        If True, adds a learnable bias to the query, key, value projections.
    qk_scale : Optional[float], optional, default=None
        Override the default scale factor for the dot-product attention.
    attn_drop : float, optional, default=0.0
        Dropout probability for the attention weights.
    proj_drop : float, optional, default=0.0
        Dropout probability for the output of the attention layer.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the multi-head self-attention module.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The output tensor and the attention weights.
        """
        b, n, c = x.shape
        qkv = (
            self.qkv(x)
            .reshape(b, n, 3, self.num_heads, c // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
forward
forward(x)

Forward pass through the multi-head self-attention module.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
tuple[Tensor, Tensor]

The output tensor and the attention weights.

Source code in mmlearn/modules/layers/attention.py
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward pass through the multi-head self-attention module.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        The output tensor and the attention weights.
    """
    b, n, c = x.shape
    qkv = (
        self.qkv(x)
        .reshape(b, n, 3, self.num_heads, c // self.num_heads)
        .permute(2, 0, 3, 1, 4)
    )
    q, k, v = qkv[0], qkv[1], qkv[2]

    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(b, n, c)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x, attn
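
A minimal usage sketch, assuming the import path mmlearn.modules.layers.attention:

import torch
from mmlearn.modules.layers.attention import Attention

attn = Attention(dim=64, num_heads=8, qkv_bias=True)
x = torch.randn(2, 10, 64)
out, weights = attn(x)
print(out.shape)      # torch.Size([2, 10, 64])
print(weights.shape)  # torch.Size([2, 8, 10, 10]): one attention map per head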

embedding

Embedding layers.

PatchEmbed

Bases: Module

Image to Patch Embedding.

This module divides an image into patches and embeds them as a sequence of vectors.

Parameters:

Name Type Description Default
img_size int

Size of the input image (assumed to be square).

224
patch_size int

Size of each image patch (assumed to be square).

16
in_chans int

Number of input channels in the image.

3
embed_dim int

Dimension of the output embeddings.

768
Source code in mmlearn/modules/layers/embedding.py
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    This module divides an image into patches and embeds them as a sequence of vectors.

    Parameters
    ----------
    img_size : int, optional, default=224
        Size of the input image (assumed to be square).
    patch_size : int, optional, default=16
        Size of each image patch (assumed to be square).
    in_chans : int, optional, default=3
        Number of input channels in the image.
    embed_dim : int, optional, default=768
        Dimension of the output embeddings.

    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass to convert an image into patch embeddings."""
        return self.proj(x).flatten(2).transpose(1, 2)
forward
forward(x)

Forward pass to convert an image into patch embeddings.

Source code in mmlearn/modules/layers/embedding.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass to convert an image into patch embeddings."""
    return self.proj(x).flatten(2).transpose(1, 2)
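
A minimal usage sketch, assuming the import path mmlearn.modules.layers.embedding. A 224x224 image split into 16x16 patches yields a 14x14 grid, i.e. 196 patch tokens:

import torch
from mmlearn.modules.layers.embedding import PatchEmbed

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
images = torch.randn(2, 3, 224, 224)
tokens = patch_embed(images)
print(tokens.shape, patch_embed.num_patches)  # torch.Size([2, 196, 768]) 196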
ConvEmbed

Bases: Module

3x3 Convolution stems for ViT following ViTC models.

This module builds convolutional stems for Vision Transformers (ViT) with intermediate batch normalization and ReLU activation.

Parameters:

Name Type Description Default
channels list[int]

List of channel sizes for each convolution layer.

required
strides list[int]

List of stride sizes for each convolution layer.

required
img_size int

Size of the input image (assumed to be square).

224
in_chans int

Number of input channels in the image.

3
batch_norm bool

Whether to include batch normalization after each convolution layer.

True
Source code in mmlearn/modules/layers/embedding.py
class ConvEmbed(nn.Module):
    """3x3 Convolution stems for ViT following ViTC models.

    This module builds convolutional stems for Vision Transformers (ViT)
    with intermediate batch normalization and ReLU activation.

    Parameters
    ----------
    channels : list[int]
        list of channel sizes for each convolution layer.
    strides : list[int]
        list of stride sizes for each convolution layer.
    img_size : int, optional, default=224
        Size of the input image (assumed to be square).
    in_chans : int, optional, default=3
        Number of input channels in the image.
    batch_norm : bool, optional, default=True
        Whether to include batch normalization after each convolution layer.

    """

    def __init__(
        self,
        channels: list[int],
        strides: list[int],
        img_size: int = 224,
        in_chans: int = 3,
        batch_norm: bool = True,
    ) -> None:
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + channels
        for i in range(len(channels) - 2):
            stem += [
                nn.Conv2d(
                    channels[i],
                    channels[i + 1],
                    kernel_size=3,
                    stride=strides[i],
                    padding=1,
                    bias=(not batch_norm),
                )
            ]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i + 1])]
            stem += [nn.ReLU(inplace=True)]
        stem += [
            nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])
        ]
        self.stem = nn.Sequential(*stem)

        # Compute the number of patches
        stride_prod = int(np.prod(strides))
        self.num_patches = (img_size // stride_prod) ** 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the convolutional embedding layers."""
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)
forward
forward(x)

Forward pass through the convolutional embedding layers.

Source code in mmlearn/modules/layers/embedding.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through the convolutional embedding layers."""
    p = self.stem(x)
    return p.flatten(2).transpose(1, 2)
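
A minimal usage sketch with an illustrative (not prescribed) channel/stride schedule whose total stride is 4 * 2 * 2 = 16, matching a 16x16 patch grid on a 224x224 input; the import path is assumed to be mmlearn.modules.layers.embedding:

import torch
from mmlearn.modules.layers.embedding import ConvEmbed

stem = ConvEmbed(channels=[64, 128, 768], strides=[4, 2, 2], img_size=224, in_chans=3)
tokens = stem(torch.randn(2, 3, 224, 224))
print(tokens.shape, stem.num_patches)  # torch.Size([2, 196, 768]) 196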
get_2d_sincos_pos_embed
get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False
)

Generate 2D sine-cosine positional embeddings.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
grid_size int

The size of the grid (both height and width).

required
cls_token bool

Whether to include a class token in the embeddings.

False

Returns:

Name Type Description
pos_embed ndarray

Positional embeddings with shape [grid_size*grid_size, embed_dim] or [1 + grid_size*grid_size, embed_dim] if cls_token is True.

Source code in mmlearn/modules/layers/embedding.py
def get_2d_sincos_pos_embed(
    embed_dim: int, grid_size: int, cls_token: bool = False
) -> np.ndarray:
    """
    Generate 2D sine-cosine positional embeddings.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    grid_size : int
        The size of the grid (both height and width).
    cls_token : bool, optional, default=False
        Whether to include a class token in the embeddings.

    Returns
    -------
    pos_embed : np.ndarray
        Positional embeddings with shape [grid_size*grid_size, embed_dim] or
        [1 + grid_size*grid_size, embed_dim] if cls_token is True.
    """
    grid_h = np.arange(grid_size, dtype=float)
    grid_w = np.arange(grid_size, dtype=float)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
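
A minimal usage sketch, assuming the import path mmlearn.modules.layers.embedding; for a 14x14 patch grid with a CLS token, the result has 1 + 196 rows:

from mmlearn.modules.layers.embedding import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_embed.shape)  # (197, 768): a zero row for the CLS token, then one row per grid position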
get_2d_sincos_pos_embed_from_grid
get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

Generate 2D sine-cosine positional embeddings from a grid.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
grid ndarray

The grid of positions with shape [2, 1, grid_size, grid_size].

required

Returns:

Name Type Description
emb ndarray

Positional embeddings with shape [grid_size*grid_size, embed_dim].

Source code in mmlearn/modules/layers/embedding.py
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray) -> np.ndarray:
    """
    Generate 2D sine-cosine positional embeddings from a grid.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    grid : np.ndarray
        The grid of positions with shape [2, 1, grid_size, grid_size].

    Returns
    -------
    emb : np.ndarray
        Positional embeddings with shape [grid_size*grid_size, embed_dim].
    """
    assert embed_dim % 2 == 0

    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)
get_1d_sincos_pos_embed
get_1d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False
)

Generate 1D sine-cosine positional embeddings.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
grid_size int

The size of the grid.

required
cls_token bool

Whether to include a class token in the embeddings.

False

Returns:

Name Type Description
pos_embed ndarray

Positional embeddings with shape [grid_size, embed_dim] or [1 + grid_size, embed_dim] if cls_token is True.

Source code in mmlearn/modules/layers/embedding.py
def get_1d_sincos_pos_embed(
    embed_dim: int, grid_size: int, cls_token: bool = False
) -> np.ndarray:
    """
    Generate 1D sine-cosine positional embeddings.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    grid_size : int
        The size of the grid.
    cls_token : bool, optional, default=False
        Whether to include a class token in the embeddings.

    Returns
    -------
    pos_embed : np.ndarray
        Positional embeddings with shape [grid_size, embed_dim] or
        [1 + grid_size, embed_dim] if cls_token is True.
    """
    grid = np.arange(grid_size, dtype=float)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
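
A minimal usage sketch, assuming the same import path as above:

from mmlearn.modules.layers.embedding import get_1d_sincos_pos_embed

pos_embed = get_1d_sincos_pos_embed(embed_dim=64, grid_size=50)
print(pos_embed.shape)  # (50, 64): sine features in the first 32 dimensions, cosine in the last 32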
get_1d_sincos_pos_embed_from_grid
get_1d_sincos_pos_embed_from_grid(embed_dim, pos)

Generate 1D sine-cosine positional embeddings from a grid.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
pos ndarray

A list of positions to be encoded, with shape [M,].

required

Returns:

Name Type Description
emb ndarray

Positional embeddings with shape [M, embed_dim].

Source code in mmlearn/modules/layers/embedding.py
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """
    Generate 1D sine-cosine positional embeddings from a grid.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    pos : np.ndarray
        A list of positions to be encoded, with shape [M,].

    Returns
    -------
    emb : np.ndarray
        Positional embeddings with shape [M, embed_dim].
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)

logit_scaling

Learnable logit scaling layer.

LearnableLogitScaling

Bases: Module

Logit scaling layer.

Parameters:

Name Type Description Default
init_logit_scale float

Initial value of the logit scale.

1/0.07
learnable bool

If True, the logit scale is learnable. Otherwise, it is fixed.

True
max_logit_scale float

Maximum value of the logit scale.

100
Source code in mmlearn/modules/layers/logit_scaling.py
@store(group="modules/layers", provider="mmlearn")
class LearnableLogitScaling(torch.nn.Module):
    """Logit scaling layer.

    Parameters
    ----------
    init_logit_scale : float, optional, default=1/0.07
        Initial value of the logit scale.
    learnable : bool, optional, default=True
        If True, the logit scale is learnable. Otherwise, it is fixed.
    max_logit_scale : float, optional, default=100
        Maximum value of the logit scale.
    """

    def __init__(
        self,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.init_logit_scale = init_logit_scale
        self.learnable = learnable
        log_logit_scale = torch.ones([]) * np.log(self.init_logit_scale)
        if learnable:
            self.log_logit_scale = torch.nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the logit scaling to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self) -> str:
        """Return the string representation of the layer."""
        return (
            f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
            f" max_logit_scale={self.max_logit_scale}"
        )
forward
forward(x)

Apply the logit scaling to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
Source code in mmlearn/modules/layers/logit_scaling.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the logit scaling to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
extra_repr
extra_repr()

Return the string representation of the layer.

Source code in mmlearn/modules/layers/logit_scaling.py
def extra_repr(self) -> str:
    """Return the string representation of the layer."""
    return (
        f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
        f" max_logit_scale={self.max_logit_scale}"
    )

mlp

Multi-layer perceptron (MLP).

MLP

Bases: Sequential

Multi-layer perceptron (MLP).

This module will create a block of Linear -> Normalization -> Activation -> Dropout layers.

Parameters:

Name Type Description Default
in_dim int

The input dimension.

required
out_dim Optional[int]

The output dimension. If not specified, it is set to :attr:in_dim.

None
hidden_dims Optional[list]

The dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with :attr:hidden_dims_multiplier.

None
hidden_dims_multiplier Optional[list]

The multipliers to apply to the input dimension to get the dimensions of the hidden layers. The length of the list determines the number of hidden layers. The multipliers will be used to get the dimensions of the hidden layers. This parameter is mutually exclusive with hidden_dims.

None
apply_multiplier_to_in_dim bool

Whether to apply the :attr:hidden_dims_multiplier to :attr:in_dim to get the dimensions of the hidden layers. If False, the multipliers will be applied to the dimensions of the previous hidden layer, starting from :attr:in_dim. This parameter is only relevant when :attr:hidden_dims_multiplier is specified.

False
norm_layer Optional[Callable[..., Module]]

The normalization layer to use. If not specified, no normalization is used. Partial functions can be used to specify the normalization layer with specific parameters.

None
activation_layer Optional[Callable[..., Module]]

The activation layer to use. If not specified, ReLU is used. Partial functions can be used to specify the activation layer with specific parameters.

torch.nn.ReLU
bias Union[bool, list[bool]]

Whether to use bias in the linear layers.

True
dropout Union[float, list[float]]

The dropout probability to use.

0.0

Raises:

Type Description
ValueError

If both :attr:hidden_dims and :attr:hidden_dims_multiplier are specified or if the lengths of :attr:bias and :attr:hidden_dims do not match or if the lengths of :attr:dropout and :attr:hidden_dims do not match.

Source code in mmlearn/modules/layers/mlp.py
@store(group="modules/layers", provider="mmlearn")
class MLP(torch.nn.Sequential):
    """Multi-layer perceptron (MLP).

    This module will create a block of ``Linear -> Normalization -> Activation -> Dropout``
    layers.

    Parameters
    ----------
    in_dim : int
        The input dimension.
    out_dim : Optional[int], optional, default=None
        The output dimension. If not specified, it is set to :attr:`in_dim`.
    hidden_dims : Optional[list], optional, default=None
        The dimensions of the hidden layers. The length of the list determines the
        number of hidden layers. This parameter is mutually exclusive with
        :attr:`hidden_dims_multiplier`.
    hidden_dims_multiplier : Optional[list], optional, default=None
        The multipliers to apply to the input dimension to get the dimensions of
        the hidden layers. The length of the list determines the number of hidden
        layers. The multipliers will be used to get the dimensions of the hidden
        layers. This parameter is mutually exclusive with `hidden_dims`.
    apply_multiplier_to_in_dim : bool, optional, default=False
        Whether to apply the :attr:`hidden_dims_multiplier` to :attr:`in_dim` to get the
        dimensions of the hidden layers. If ``False``, the multipliers will be applied
        to the dimensions of the previous hidden layer, starting from :attr:`in_dim`.
        This parameter is only relevant when :attr:`hidden_dims_multiplier` is
        specified.
    norm_layer : Optional[Callable[..., torch.nn.Module]], optional, default=None
        The normalization layer to use. If not specified, no normalization is used.
        Partial functions can be used to specify the normalization layer with specific
        parameters.
    activation_layer : Optional[Callable[..., torch.nn.Module]], optional, default=torch.nn.ReLU
        The activation layer to use. If not specified, ReLU is used. Partial functions
        can be used to specify the activation layer with specific parameters.
    bias : Union[bool, list[bool]], optional, default=True
        Whether to use bias in the linear layers.
    dropout : Union[float, list[float]], optional, default=0.0
        The dropout probability to use.

    Raises
    ------
    ValueError
        If both :attr:`hidden_dims` and :attr:`hidden_dims_multiplier` are specified
        or if the lengths of :attr:`bias` and :attr:`hidden_dims` do not match or if
        the lengths of :attr:`dropout` and :attr:`hidden_dims` do not match.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        hidden_dims: Optional[list[int]] = None,
        hidden_dims_multiplier: Optional[list[float]] = None,
        apply_multiplier_to_in_dim: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        bias: Union[bool, list[bool]] = True,
        dropout: Union[float, list[float]] = 0.0,
    ) -> None:
        if hidden_dims is None and hidden_dims_multiplier is None:
            hidden_dims = []
        if hidden_dims is not None and hidden_dims_multiplier is not None:
            raise ValueError(
                "Only one of `hidden_dims` or `hidden_dims_multiplier` must be specified."
            )

        if hidden_dims is None and hidden_dims_multiplier is not None:
            if apply_multiplier_to_in_dim:
                hidden_dims = [
                    int(in_dim * multiplier) for multiplier in hidden_dims_multiplier
                ]
            else:
                hidden_dims = [int(in_dim * hidden_dims_multiplier[0])]
                for multiplier in hidden_dims_multiplier[1:]:
                    hidden_dims.append(int(hidden_dims[-1] * multiplier))

        if isinstance(bias, bool):
            bias_list: list[bool] = [bias] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            bias_list = bias
        if len(bias_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `bias` to be a boolean or a list of booleans with length "
                "equal to the number of linear layers in the MLP."
            )

        if isinstance(dropout, float):
            dropout_list: list[float] = [dropout] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            dropout_list = dropout
        if len(dropout_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `dropout` to be a float or a list of floats with length "
                "equal to the number of linear layers in the MLP."
            )

        # construct list of dimensions for the layers
        dims = [in_dim] + hidden_dims  # type: ignore[operator]
        layers = []
        for layer_idx, (in_features, hidden_features) in enumerate(
            zip(dims[:-1], dims[1:], strict=False)
        ):
            layers.append(
                torch.nn.Linear(in_features, hidden_features, bias=bias_list[layer_idx])
            )
            if norm_layer is not None:
                layers.append(norm_layer(hidden_features))
            if activation_layer is not None:
                layers.append(activation_layer())
            layers.append(torch.nn.Dropout(dropout_list[layer_idx]))

        out_dim = out_dim or in_dim

        layers.append(torch.nn.Linear(dims[-1], out_dim, bias=bias_list[-1]))
        layers.append(torch.nn.Dropout(dropout_list[-1]))

        super().__init__(*layers)

normalization

Normalization layers.

L2Norm

Bases: Module

L2 normalization.

Parameters:

Name Type Description Default
dim int

The dimension along which to normalize.

required
Source code in mmlearn/modules/layers/normalization.py
@store(group="modules/layers", provider="mmlearn")
class L2Norm(torch.nn.Module):
    """L2 normalization.

    Parameters
    ----------
    dim : int
        The dimension along which to normalize.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply L2 normalization to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)
forward
forward(x)

Apply L2 normalization to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Normalized tensor of shape (batch_sz, seq_len, dim).

Source code in mmlearn/modules/layers/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply L2 normalization to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.nn.functional.normalize(x, dim=self.dim, p=2)

patch_dropout

Patch dropout layer.

PatchDropout

Bases: Module

Patch dropout layer.

Drops patch tokens (after embedding and adding the CLS token) from the input tensor. Usually used in vision transformers to reduce the number of tokens. [1]

Parameters:

Name Type Description Default
keep_rate float

The proportion of tokens to keep.

0.5
bias Optional[float]

The bias to add to the random noise before sorting.

None
token_shuffling bool

If True, the tokens are shuffled.

False
References

[1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023). PatchDropout: Economizing vision transformers using patch dropout. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 3953-3962).

Source code in mmlearn/modules/layers/patch_dropout.py
class PatchDropout(torch.nn.Module):
    """Patch dropout layer.

    Drops patch tokens (after embedding and adding CLS token) from the input tensor.
    Usually used in vision transformers to reduce the number of tokens. [1]_

    Parameters
    ----------
    keep_rate : float, optional, default=0.5
        The proportion of tokens to keep.
    bias : Optional[float], optional, default=None
        The bias to add to the random noise before sorting.
    token_shuffling : bool, optional, default=False
        If True, the tokens are shuffled.

    References
    ----------
    .. [1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023).
       Patchdropout: Economizing vision transformers using patch dropout. In Proceedings
       of the IEEE/CVF Winter Conference on Applications of Computer Vision
       (pp. 3953-3962).
    """

    def __init__(
        self,
        keep_rate: float = 0.5,
        bias: Optional[float] = None,
        token_shuffling: bool = False,
    ):
        super().__init__()
        assert 0 < keep_rate <= 1, "The keep_rate must be in (0,1]"

        self.bias = bias
        self.keep_rate = keep_rate
        self.token_shuffling = token_shuffling

    def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
        """Drop tokens from the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        force_drop : bool, optional, default=False
            If True, the tokens are always dropped, even when the model is in
            evaluation mode.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
        """
        if (not self.training and not force_drop) or self.keep_rate == 1:
            return x

        batch_sz, _, dim = x.shape

        cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
            batch_sz, 1, dtype=torch.int64, device=x.device
        )
        patch_mask = self.uniform_mask(x)
        patch_mask = torch.hstack([cls_mask, patch_mask])

        return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))

    def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Generate token ids to keep from uniform random distribution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

        """
        batch_sz, seq_len, _ = x.shape
        seq_len = seq_len - 1  # patch length (without CLS)

        keep_len = int(seq_len * self.keep_rate)
        noise = torch.rand(batch_sz, seq_len, device=x.device)
        if self.bias is not None:
            noise += self.bias
        ids = torch.argsort(noise, dim=1)
        keep_ids = ids[:, :keep_len]
        if not self.token_shuffling:
            keep_ids = keep_ids.sort(1)[0]
        return keep_ids
forward
forward(x, force_drop=False)

Drop tokens from the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
force_drop bool

If True, the tokens are always dropped, even when the model is in evaluation mode.

False

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len, dim) containing the kept tokens.

Source code in mmlearn/modules/layers/patch_dropout.py
def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
    """Drop tokens from the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    force_drop : bool, optional, default=False
        If True, the tokens are always dropped, even when the model is in
        evaluation mode.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
    """
    if (not self.training and not force_drop) or self.keep_rate == 1:
        return x

    batch_sz, _, dim = x.shape

    cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
        batch_sz, 1, dtype=torch.int64, device=x.device
    )
    patch_mask = self.uniform_mask(x)
    patch_mask = torch.hstack([cls_mask, patch_mask])

    return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))
uniform_mask
uniform_mask(x)

Generate the token ids to keep from a uniform random distribution.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len) containing the token ids to keep.

Source code in mmlearn/modules/layers/patch_dropout.py
def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
    """Generate token ids to keep from uniform random distribution.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

    """
    batch_sz, seq_len, _ = x.shape
    seq_len = seq_len - 1  # patch length (without CLS)

    keep_len = int(seq_len * self.keep_rate)
    noise = torch.rand(batch_sz, seq_len, device=x.device)
    if self.bias is not None:
        noise += self.bias
    ids = torch.argsort(noise, dim=1)
    keep_ids = ids[:, :keep_len]
    if not self.token_shuffling:
        keep_ids = keep_ids.sort(1)[0]
    return keep_ids
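
A short usage sketch of PatchDropout, assuming the import path matches the source file named above. The first token is treated as the CLS token, which is always kept, as in the implementation.

import torch

from mmlearn.modules.layers.patch_dropout import PatchDropout  # path assumed from the listing above

patch_drop = PatchDropout(keep_rate=0.5)
patch_drop.train()  # tokens are only dropped in training mode (or with force_drop=True)

x = torch.randn(4, 197, 768)  # 1 CLS token + 196 patch tokens, as in a ViT-B/16
out = patch_drop(x)

# The CLS token is always kept; int(196 * 0.5) = 98 patch tokens survive.
print(out.shape)  # torch.Size([4, 99, 768])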

transformer_block

Transformer Block and Embedding Modules for Vision Transformers (ViT).

DropPath

Bases: Module

Drop paths (Stochastic Depth) per sample.

Parameters:

Name Type Description Default
drop_prob float

Probability of dropping paths.

0.0
Source code in mmlearn/modules/layers/transformer_block.py
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample.

    Parameters
    ----------
    drop_prob : float, optional, default=0.0
        Probability of dropping paths.
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through DropPath module."""
        return drop_path(x, self.drop_prob, self.training)
forward
forward(x)

Forward pass through DropPath module.

Source code in mmlearn/modules/layers/transformer_block.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through DropPath module."""
    return drop_path(x, self.drop_prob, self.training)
Block

Bases: Module

Transformer Block.

This module represents a Transformer block that includes self-attention, normalization layers, and a feedforward multi-layer perceptron (MLP) network.

Parameters:

Name Type Description Default
dim int

The input and output dimension of the block.

required
num_heads int

Number of attention heads.

required
mlp_ratio float

Ratio of hidden dimension to the input dimension in the MLP.

4.0
qkv_bias bool

If True, add a learnable bias to the query, key, value projections.

False
qk_scale Optional[float]

Override default qk scale of head_dim ** -0.5 if set.

None
drop float

Dropout probability for the output of attention and MLP layers.

0.0
attn_drop float

Dropout probability for the attention scores.

0.0
drop_path float

Stochastic depth rate, a form of layer dropout.

0.0
act_layer Callable[..., Module]

Activation layer in the MLP.

nn.GELU
norm_layer Callable[..., Module]

Normalization layer.

torch.nn.LayerNorm
Source code in mmlearn/modules/layers/transformer_block.py
class Block(nn.Module):
    """Transformer Block.

    This module represents a Transformer block that includes self-attention,
    normalization layers, and a feedforward multi-layer perceptron (MLP) network.

    Parameters
    ----------
    dim : int
        The input and output dimension of the block.
    num_heads : int
        Number of attention heads.
    mlp_ratio : float, optional, default=4.0
        Ratio of hidden dimension to the input dimension in the MLP.
    qkv_bias : bool, optional, default=False
        If True, add a learnable bias to the query, key, value projections.
    qk_scale : Optional[float], optional, default=None
        Override default qk scale of head_dim ** -0.5 if set.
    drop : float, optional, default=0.0
        Dropout probability for the output of attention and MLP layers.
    attn_drop : float, optional, default=0.0
        Dropout probability for the attention scores.
    drop_path : float, optional, default=0.0
        Stochastic depth rate, a form of layer dropout.
    act_layer : Callable[..., torch.nn.Module], optional, default=nn.GELU
        Activation layer in the MLP.
    norm_layer : Callable[..., torch.nn.Module], optional, default=torch.nn.LayerNorm
        Normalization layer.

    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)

        self.mlp = MLP(
            in_dim=dim,
            hidden_dims_multiplier=[mlp_ratio],
            activation_layer=act_layer,
            bias=True,
            dropout=drop,
        )

    def forward(
        self, x: torch.Tensor, return_attention: bool = False
    ) -> torch.Tensor:
        """Forward pass through the Transformer Block."""
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        return x + self.drop_path(self.mlp(self.norm2(x)))
forward
forward(x, return_attention=False)

Forward pass through the Transformer Block.

Source code in mmlearn/modules/layers/transformer_block.py
def forward(
    self, x: torch.Tensor, return_attention: bool = False
) -> torch.Tensor:
    """Forward pass through the Transformer Block."""
    y, attn = self.attn(self.norm1(x))
    if return_attention:
        return attn
    x = x + self.drop_path(y)
    return x + self.drop_path(self.mlp(self.norm2(x)))
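
As a quick sanity check, the block can be run on random token embeddings. A sketch, assuming the import path matches the source file named above; the attention-map shape noted in the comments follows the usual ViT convention of (batch, heads, tokens, tokens) and is an assumption about the underlying Attention module.

import torch

from mmlearn.modules.layers.transformer_block import Block  # path assumed from the listing above

block = Block(dim=192, num_heads=3, mlp_ratio=4.0, qkv_bias=True, drop_path=0.1)
block.eval()  # DropPath is inactive outside of training

x = torch.randn(2, 197, 192)  # (batch_sz, seq_len, dim)
with torch.no_grad():
    out = block(x)                          # same shape as the input
    attn = block(x, return_attention=True)  # attention maps instead of features

print(out.shape)   # torch.Size([2, 197, 192])
print(attn.shape)  # expected: (2, 3, 197, 197)
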
drop_path
drop_path(x, drop_prob=0.0, training=False)

Drop paths (Stochastic Depth) for regularization during training.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
drop_prob float

Probability of dropping paths.

0.0
training bool

Whether the model is in training mode.

False

Returns:

Name Type Description
output Tensor

Output tensor after applying drop path.

Source code in mmlearn/modules/layers/transformer_block.py
def drop_path(
    x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """Drop paths (Stochastic Depth) for regularization during training.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    drop_prob : float, optional, default=0.0
        Probability of dropping paths.
    training : bool, optional, default=False
        Whether the model is in training mode.

    Returns
    -------
    output : torch.Tensor
        Output tensor after applying drop path.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    return x.div(keep_prob) * random_tensor
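
The division by keep_prob is the standard inverted-dropout correction: surviving samples are scaled up so the expected value of the output matches the input. A small numeric sketch, assuming the import path matches the source file named above:

import torch

from mmlearn.modules.layers.transformer_block import drop_path  # path assumed from the listing above

torch.manual_seed(0)
x = torch.ones(1000, 8, 16)  # 1000 samples in the batch

out = drop_path(x, drop_prob=0.2, training=True)

# Roughly 20% of the samples are zeroed out entirely; survivors are scaled
# by 1 / 0.8 = 1.25, so the overall mean stays close to 1.0.
dropped = (out.abs().sum(dim=(1, 2)) == 0).float().mean()
print(f"fraction dropped ~ {dropped:.2f}, mean ~ {out.mean():.3f}")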

losses

Loss functions.

ContrastiveLoss

Bases: Module

Contrastive Loss.

Parameters:

Name Type Description Default
l2_normalize bool

Whether to L2 normalize the features.

False
local_loss bool

Whether to calculate the loss locally i.e. local_features@global_features.

False
gather_with_grad bool

Whether to gather tensors with gradients.

False
modality_alignment bool

Whether to include modality alignment loss. This loss considers all features from the same modality as positive pairs and all features from different modalities as negative pairs.

False
cache_labels bool

Whether to cache the labels.

False
Source code in mmlearn/modules/losses/contrastive.py
@store(group="modules/losses", provider="mmlearn")
class ContrastiveLoss(nn.Module):
    """Contrastive Loss.

    Parameters
    ----------
    l2_normalize : bool, optional, default=False
        Whether to L2 normalize the features.
    local_loss : bool, optional, default=False
        Whether to calculate the loss locally i.e. ``local_features@global_features``.
    gather_with_grad : bool, optional, default=False
        Whether to gather tensors with gradients.
    modality_alignment : bool, optional, default=False
        Whether to include modality alignment loss. This loss considers all features
        from the same modality as positive pairs and all features from different
        modalities as negative pairs.
    cache_labels : bool, optional, default=False
        Whether to cache the labels.

    """

    def __init__(
        self,
        l2_normalize: bool = False,
        local_loss: bool = False,
        gather_with_grad: bool = False,
        modality_alignment: bool = False,
        cache_labels: bool = False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.l2_normalize = l2_normalize
        self.modality_alignment = modality_alignment

        # cache state
        self._prev_num_logits = 0
        self._labels: dict[torch.device, torch.Tensor] = {}

    def forward(
        self,
        embeddings: dict[str, torch.Tensor],
        example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        modality_loss_pairs: list[LossPairSpec],
    ) -> torch.Tensor:
        """Calculate the contrastive loss.

        Parameters
        ----------
        embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        modality_loss_pairs : list[LossPairSpec]
            Specification of the modality pairs for which the loss should be calculated.

        Returns
        -------
        torch.Tensor
            The contrastive loss.
        """
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if world_size > 1 else 0

        if self.l2_normalize:
            embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

        if world_size > 1:  # gather embeddings and example_ids across all processes
            # NOTE: gathering dictionaries of tensors across all processes
            # (keys + values, as opposed to just values) is especially important
            # for the modality_alignment loss, which requires all embeddings
            all_embeddings = _gather_dicts(
                embeddings,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
            all_example_ids = _gather_dicts(
                example_ids,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
        else:
            all_embeddings = embeddings
            all_example_ids = example_ids

        losses = []
        for loss_pairs in modality_loss_pairs:
            logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
                loss_pairs.modalities,
                per_device_embeddings=embeddings,
                all_embeddings=all_embeddings,
                per_device_example_ids=example_ids,
                all_example_ids=all_example_ids,
                logit_scale=logit_scale,
                world_size=world_size,
            )
            if logits_per_feature_a is None or logits_per_feature_b is None:
                continue

            labels = self._get_ground_truth(
                logits_per_feature_a.shape,
                device=logits_per_feature_a.device,
                rank=rank,
                world_size=world_size,
                skipped_process=skip_flag,
            )

            if labels.numel() != 0:
                losses.append(
                    (
                        (
                            F.cross_entropy(logits_per_feature_a, labels)
                            + F.cross_entropy(logits_per_feature_b, labels)
                        )
                        / 2
                    )
                    * loss_pairs.weight
                )

        if self.modality_alignment:
            losses.append(
                self._compute_modality_alignment_loss(all_embeddings, logit_scale)
            )

        if not losses:  # no loss to compute (e.g. no paired data in batch)
            losses.append(
                torch.tensor(
                    0.0,
                    device=logit_scale.device,
                    dtype=next(iter(embeddings.values())).dtype,
                )
            )

        return torch.stack(losses).sum()

    def _get_ground_truth(
        self,
        logits_shape: tuple[int, int],
        device: torch.device,
        rank: int,
        world_size: int,
        skipped_process: bool,
    ) -> torch.Tensor:
        """Return the ground-truth labels.

        Parameters
        ----------
        logits_shape : tuple[int, int]
            Shape of the logits tensor.
        device : torch.device
            Device on which the labels should be created.
        rank : int
            Rank of the current process.
        world_size : int
            Number of processes.
        skipped_process : bool
            Whether the current process skipped the computation due to lack of data.

        Returns
        -------
        torch.Tensor
            Ground-truth labels.
        """
        num_logits = logits_shape[-1]

        # calculate ground-truth and cache if enabled
        if self._prev_num_logits != num_logits or device not in self._labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)

            if world_size > 1 and self.local_loss:
                local_size = torch.tensor(
                    0 if skipped_process else logits_shape[0], device=device
                )
                # NOTE: all processes must participate in the all_gather operation
                # even if they have no data to contribute.
                sizes = torch.stack(
                    _simple_gather_all_tensors(
                        local_size, group=dist.group.WORLD, world_size=world_size
                    )
                )
                sizes = torch.cat(
                    [torch.tensor([0], device=sizes.device), torch.cumsum(sizes, dim=0)]
                )
                labels = labels[
                    sizes[rank] : sizes[rank + 1] if rank + 1 < world_size else None
                ]

            if self.cache_labels:
                self._labels[device] = labels
                self._prev_num_logits = num_logits
        else:
            labels = self._labels[device]
        return labels

    def _get_logits(  # noqa: PLR0912
        self,
        modalities: tuple[str, str],
        per_device_embeddings: dict[str, torch.Tensor],
        all_embeddings: dict[str, torch.Tensor],
        per_device_example_ids: dict[str, torch.Tensor],
        all_example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        world_size: int,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
        """Calculate the logits for the given modalities.

        Parameters
        ----------
        modalities : tuple[str, str]
            Tuple of modality names.
        per_device_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor. In distributed mode, this contains
            embeddings from all processes.
        per_device_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        all_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index. In distributed
            mode, this contains example IDs from all processes.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        world_size : int
            Number of processes.

        Returns
        -------
        tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]
            Tuple of logits for the given modalities. If embeddings for the given
            modalities are not available, returns `None` for the logits. The last
            element is a flag indicating whether the process skipped the computation
            due to lack of data.
        """
        modality_a = Modalities.get_modality(modalities[0])
        modality_b = Modalities.get_modality(modalities[1])
        skip_flag = False

        if self.local_loss or world_size == 1:
            if not (
                modality_a.embedding in per_device_embeddings
                and modality_b.embedding in per_device_embeddings
            ):
                if world_size > 1:  # NOTE: not all processes exit here, hence skip_flag
                    skip_flag = True
                else:
                    return None, None, skip_flag

            if not skip_flag:
                indices_a, indices_b = find_matching_indices(
                    per_device_example_ids[modality_a.name],
                    per_device_example_ids[modality_b.name],
                )
                if indices_a.numel() == 0 or indices_b.numel() == 0:
                    if world_size > 1:  # not all processes exit here
                        skip_flag = True
                    else:
                        return None, None, skip_flag

            if not skip_flag:
                features_a = per_device_embeddings[modality_a.embedding][indices_a]
                features_b = per_device_embeddings[modality_b.embedding][indices_b]
            else:
                # all processes must participate in the all_gather operation
                # that follows, even if they have no data to contribute. So,
                # we create empty tensors here.
                features_a = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )
                features_b = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )

        if world_size > 1:
            if not (
                modality_a.embedding in all_embeddings
                and modality_b.embedding in all_embeddings
            ):  # all processes exit here
                return None, None, skip_flag

            indices_a, indices_b = find_matching_indices(
                all_example_ids[modality_a.name],
                all_example_ids[modality_b.name],
            )
            if indices_a.numel() == 0 or indices_b.numel() == 0:
                # all processes exit here
                return None, None, skip_flag

            all_features_a = all_embeddings[modality_a.embedding][indices_a]
            all_features_b = all_embeddings[modality_b.embedding][indices_b]

            if self.local_loss:
                if features_a.numel() == 0:
                    features_a = all_features_a
                if features_b.numel() == 0:
                    features_b = all_features_b

                logits_per_feature_a = logit_scale * _safe_matmul(
                    features_a, all_features_b
                )
                logits_per_feature_b = logit_scale * _safe_matmul(
                    features_b, all_features_a
                )
            else:
                logits_per_feature_a = logit_scale * _safe_matmul(
                    all_features_a, all_features_b
                )
                logits_per_feature_b = logits_per_feature_a.T
        else:
            logits_per_feature_a = logit_scale * _safe_matmul(features_a, features_b)
            logits_per_feature_b = logit_scale * _safe_matmul(features_b, features_a)

        return logits_per_feature_a, logits_per_feature_b, skip_flag

    def _compute_modality_alignment_loss(
        self, all_embeddings: dict[str, torch.Tensor], logit_scale: torch.Tensor
    ) -> torch.Tensor:
        """Compute the modality alignment loss.

        This loss considers all features from the same modality as positive pairs
        and all features from different modalities as negative pairs.

        Parameters
        ----------
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        logit_scale : torch.Tensor
            Scale factor for the logits.

        Returns
        -------
        torch.Tensor
            Modality alignment loss.

        Notes
        -----
        This loss does not support `local_loss=True`.
        """
        available_modalities = list(all_embeddings.keys())
        # TODO: support local_loss for modality_alignment?
        # if world_size == 1, all_embeddings == embeddings
        all_features = torch.cat(list(all_embeddings.values()), dim=0)

        positive_indices = torch.tensor(
            [
                (i, j)
                if idx == 0
                else (
                    i + all_embeddings[available_modalities[idx - 1]].size(0),
                    j + all_embeddings[available_modalities[idx - 1]].size(0),
                )
                for idx, k in enumerate(all_embeddings)
                for i, j in itertools.combinations(range(all_embeddings[k].size(0)), 2)
            ],
            device=all_features.device,
        )
        logits = logit_scale * _safe_matmul(all_features, all_features)

        target = torch.eye(all_features.size(0), device=all_features.device)
        target[positive_indices[:, 0], positive_indices[:, 1]] = 1

        modality_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, target, reduction="none"
        )

        target_pos = target.bool()
        target_neg = ~target_pos

        # loss_pos and loss_neg below contain non-zero values only for those
        # elements that are positive pairs and negative pairs respectively.
        loss_pos = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_pos, modality_loss[target_pos])
        loss_neg = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_neg, modality_loss[target_neg])

        loss_pos = loss_pos.sum(dim=1)
        loss_neg = loss_neg.sum(dim=1)
        num_pos = target.sum(dim=1)
        num_neg = logits.size(0) - num_pos

        return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()
forward
forward(
    embeddings,
    example_ids,
    logit_scale,
    modality_loss_pairs,
)

Calculate the contrastive loss.

Parameters:

Name Type Description Default
embeddings dict[str, Tensor]

Dictionary of embeddings, where the key is the modality name and the value is the corresponding embedding tensor.

required
example_ids dict[str, Tensor]

Dictionary of example IDs, where the key is the modality name and the value is a tensor tuple of the dataset index and the example index.

required
logit_scale Tensor

Scale factor for the logits.

required
modality_loss_pairs list[LossPairSpec]

Specification of the modality pairs for which the loss should be calculated.

required

Returns:

Type Description
Tensor

The contrastive loss.

Source code in mmlearn/modules/losses/contrastive.py
def forward(
    self,
    embeddings: dict[str, torch.Tensor],
    example_ids: dict[str, torch.Tensor],
    logit_scale: torch.Tensor,
    modality_loss_pairs: list[LossPairSpec],
) -> torch.Tensor:
    """Calculate the contrastive loss.

    Parameters
    ----------
    embeddings : dict[str, torch.Tensor]
        Dictionary of embeddings, where the key is the modality name and the value
        is the corresponding embedding tensor.
    example_ids : dict[str, torch.Tensor]
        Dictionary of example IDs, where the key is the modality name and the value
        is a tensor tuple of the dataset index and the example index.
    logit_scale : torch.Tensor
        Scale factor for the logits.
    modality_loss_pairs : list[LossPairSpec]
        Specification of the modality pairs for which the loss should be calculated.

    Returns
    -------
    torch.Tensor
        The contrastive loss.
    """
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if world_size > 1 else 0

    if self.l2_normalize:
        embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

    if world_size > 1:  # gather embeddings and example_ids across all processes
        # NOTE: gathering dictionaries of tensors across all processes
        # (keys + values, as opposed to just values) is especially important
        # for the modality_alignment loss, which requires all embeddings
        all_embeddings = _gather_dicts(
            embeddings,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
        all_example_ids = _gather_dicts(
            example_ids,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
    else:
        all_embeddings = embeddings
        all_example_ids = example_ids

    losses = []
    for loss_pairs in modality_loss_pairs:
        logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
            loss_pairs.modalities,
            per_device_embeddings=embeddings,
            all_embeddings=all_embeddings,
            per_device_example_ids=example_ids,
            all_example_ids=all_example_ids,
            logit_scale=logit_scale,
            world_size=world_size,
        )
        if logits_per_feature_a is None or logits_per_feature_b is None:
            continue

        labels = self._get_ground_truth(
            logits_per_feature_a.shape,
            device=logits_per_feature_a.device,
            rank=rank,
            world_size=world_size,
            skipped_process=skip_flag,
        )

        if labels.numel() != 0:
            losses.append(
                (
                    (
                        F.cross_entropy(logits_per_feature_a, labels)
                        + F.cross_entropy(logits_per_feature_b, labels)
                    )
                    / 2
                )
                * loss_pairs.weight
            )

    if self.modality_alignment:
        losses.append(
            self._compute_modality_alignment_loss(all_embeddings, logit_scale)
        )

    if not losses:  # no loss to compute (e.g. no paired data in batch)
        losses.append(
            torch.tensor(
                0.0,
                device=logit_scale.device,
                dtype=next(iter(embeddings.values())).dtype,
            )
        )

    return torch.stack(losses).sum()
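
Setting aside the distributed gathering and example-ID matching, the per-pair term above is the familiar CLIP-style symmetric cross-entropy over a similarity matrix. A standalone sketch of that core step (independent of mmlearn's Modalities registry and LossPairSpec; all names and the temperature value below are illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_sz, dim = 8, 256

# Paired embeddings for two modalities, L2-normalized as with l2_normalize=True.
features_a = F.normalize(torch.randn(batch_sz, dim), dim=-1)
features_b = F.normalize(torch.randn(batch_sz, dim), dim=-1)
logit_scale = torch.tensor(1 / 0.07)  # in practice, the exp() of a learned parameter

# Similarity logits in both directions; matching examples lie on the diagonal.
logits_a = logit_scale * features_a @ features_b.T
logits_b = logit_scale * features_b @ features_a.T
labels = torch.arange(batch_sz)

loss = (F.cross_entropy(logits_a, labels) + F.cross_entropy(logits_b, labels)) / 2
print(loss.item())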

Data2VecLoss

Bases: Module

Data2Vec loss function.

Parameters:

Name Type Description Default
beta float

Specifies the beta parameter for smooth L1 loss. If 0, MSE loss is used.

0
loss_scale Optional[float]

Scaling factor for the loss. If None, uses 1 / sqrt(embedding_dim).

None
reduction str

Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.

'none'

Raises:

Type Description
ValueError

If the reduction mode is not supported.

Source code in mmlearn/modules/losses/data2vec.py
@store(group="modules/losses", provider="mmlearn")
class Data2VecLoss(nn.Module):
    """Data2Vec loss function.

    Parameters
    ----------
    beta : float, optional, default=0
        Specifies the beta parameter for smooth L1 loss. If ``0``, MSE loss is used.
    loss_scale : Optional[float], optional, default=None
        Scaling factor for the loss. If None, uses ``1 / sqrt(embedding_dim)``.
    reduction : str, optional, default='none'
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``.

    Raises
    ------
    ValueError
        If the reduction mode is not supported.
    """

    def __init__(
        self,
        beta: float = 0,
        loss_scale: Optional[float] = None,
        reduction: str = "none",
    ) -> None:
        super().__init__()
        self.beta = beta
        self.loss_scale = loss_scale
        if reduction not in ["none", "mean", "sum"]:
            raise ValueError(f"Unsupported reduction mode: {reduction}")
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the Data2Vec loss.

        Parameters
        ----------
        x : torch.Tensor
            Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
        y : torch.Tensor
            Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

        Returns
        -------
        torch.Tensor
            Data2Vec loss value.

        Raises
        ------
        ValueError
            If the shapes of x and y do not match.
        """
        if x.shape != y.shape:
            raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, y.size(-1))

        if self.beta == 0:
            loss = mse_loss(x, y, reduction="none")
        else:
            loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        loss = loss * scale

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # 'none'
        return loss.view(x.size(0), -1).sum(1)
forward
forward(x, y)

Compute the Data2Vec loss.

Parameters:

Name Type Description Default
x Tensor

Predicted embeddings of shape (batch_size, num_patches, embedding_dim).

required
y Tensor

Target embeddings of shape (batch_size, num_patches, embedding_dim).

required

Returns:

Type Description
Tensor

Data2Vec loss value.

Raises:

Type Description
ValueError

If the shapes of x and y do not match.

Source code in mmlearn/modules/losses/data2vec.py
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the Data2Vec loss.

    Parameters
    ----------
    x : torch.Tensor
        Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
    y : torch.Tensor
        Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

    Returns
    -------
    torch.Tensor
        Data2Vec loss value.

    Raises
    ------
    ValueError
        If the shapes of x and y do not match.
    """
    if x.shape != y.shape:
        raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

    x = x.view(-1, x.size(-1)).float()
    y = y.view(-1, y.size(-1))

    if self.beta == 0:
        loss = mse_loss(x, y, reduction="none")
    else:
        loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

    if self.loss_scale is not None:
        scale = self.loss_scale
    else:
        scale = 1 / math.sqrt(x.size(-1))

    loss = loss * scale

    if self.reduction == "mean":
        return loss.mean()
    if self.reduction == "sum":
        return loss.sum()
    # 'none'
    return loss.view(x.size(0), -1).sum(1)
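
A brief usage sketch of Data2VecLoss, assuming the import path matches the source file named above; the shapes are illustrative only.

import torch

from mmlearn.modules.losses.data2vec import Data2VecLoss  # path assumed from the listing above

loss_fn = Data2VecLoss(beta=2.0, reduction="mean")  # smooth L1 with beta=2, averaged over elements

student = torch.randn(4, 196, 768)  # predicted embeddings
teacher = torch.randn(4, 196, 768)  # target embeddings

loss = loss_fn(student, teacher)
print(loss.item())  # scalar, scaled by 1 / sqrt(768) since loss_scale is None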

contrastive

Implementations of the contrastive loss and its variants.

ContrastiveLoss

Bases: Module

Contrastive Loss.

Parameters:

Name Type Description Default
l2_normalize bool

Whether to L2 normalize the features.

False
local_loss bool

Whether to calculate the loss locally i.e. local_features@global_features.

False
gather_with_grad bool

Whether to gather tensors with gradients.

False
modality_alignment bool

Whether to include modality alignment loss. This loss considers all features from the same modality as positive pairs and all features from different modalities as negative pairs.

False
cache_labels bool

Whether to cache the labels.

False
Source code in mmlearn/modules/losses/contrastive.py
@store(group="modules/losses", provider="mmlearn")
class ContrastiveLoss(nn.Module):
    """Contrastive Loss.

    Parameters
    ----------
    l2_normalize : bool, optional, default=False
        Whether to L2 normalize the features.
    local_loss : bool, optional, default=False
        Whether to calculate the loss locally i.e. ``local_features@global_features``.
    gather_with_grad : bool, optional, default=False
        Whether to gather tensors with gradients.
    modality_alignment : bool, optional, default=False
        Whether to include modality alignment loss. This loss considers all features
        from the same modality as positive pairs and all features from different
        modalities as negative pairs.
    cache_labels : bool, optional, default=False
        Whether to cache the labels.

    """

    def __init__(
        self,
        l2_normalize: bool = False,
        local_loss: bool = False,
        gather_with_grad: bool = False,
        modality_alignment: bool = False,
        cache_labels: bool = False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.l2_normalize = l2_normalize
        self.modality_alignment = modality_alignment

        # cache state
        self._prev_num_logits = 0
        self._labels: dict[torch.device, torch.Tensor] = {}

    def forward(
        self,
        embeddings: dict[str, torch.Tensor],
        example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        modality_loss_pairs: list[LossPairSpec],
    ) -> torch.Tensor:
        """Calculate the contrastive loss.

        Parameters
        ----------
        embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        modality_loss_pairs : list[LossPairSpec]
            Specification of the modality pairs for which the loss should be calculated.

        Returns
        -------
        torch.Tensor
            The contrastive loss.
        """
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if world_size > 1 else 0

        if self.l2_normalize:
            embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

        if world_size > 1:  # gather embeddings and example_ids across all processes
            # NOTE: gathering dictionaries of tensors across all processes
            # (keys + values, as opposed to just values) is especially important
            # for the modality_alignment loss, which requires all embeddings
            all_embeddings = _gather_dicts(
                embeddings,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
            all_example_ids = _gather_dicts(
                example_ids,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
        else:
            all_embeddings = embeddings
            all_example_ids = example_ids

        losses = []
        for loss_pairs in modality_loss_pairs:
            logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
                loss_pairs.modalities,
                per_device_embeddings=embeddings,
                all_embeddings=all_embeddings,
                per_device_example_ids=example_ids,
                all_example_ids=all_example_ids,
                logit_scale=logit_scale,
                world_size=world_size,
            )
            if logits_per_feature_a is None or logits_per_feature_b is None:
                continue

            labels = self._get_ground_truth(
                logits_per_feature_a.shape,
                device=logits_per_feature_a.device,
                rank=rank,
                world_size=world_size,
                skipped_process=skip_flag,
            )

            if labels.numel() != 0:
                losses.append(
                    (
                        (
                            F.cross_entropy(logits_per_feature_a, labels)
                            + F.cross_entropy(logits_per_feature_b, labels)
                        )
                        / 2
                    )
                    * loss_pairs.weight
                )

        if self.modality_alignment:
            losses.append(
                self._compute_modality_alignment_loss(all_embeddings, logit_scale)
            )

        if not losses:  # no loss to compute (e.g. no paired data in batch)
            losses.append(
                torch.tensor(
                    0.0,
                    device=logit_scale.device,
                    dtype=next(iter(embeddings.values())).dtype,
                )
            )

        return torch.stack(losses).sum()

    def _get_ground_truth(
        self,
        logits_shape: tuple[int, int],
        device: torch.device,
        rank: int,
        world_size: int,
        skipped_process: bool,
    ) -> torch.Tensor:
        """Return the ground-truth labels.

        Parameters
        ----------
        logits_shape : tuple[int, int]
            Shape of the logits tensor.
        device : torch.device
            Device on which the labels should be created.
        rank : int
            Rank of the current process.
        world_size : int
            Number of processes.
        skipped_process : bool
            Whether the current process skipped the computation due to lack of data.

        Returns
        -------
        torch.Tensor
            Ground-truth labels.
        """
        num_logits = logits_shape[-1]

        # calculate ground-truth and cache if enabled
        if self._prev_num_logits != num_logits or device not in self._labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)

            if world_size > 1 and self.local_loss:
                local_size = torch.tensor(
                    0 if skipped_process else logits_shape[0], device=device
                )
                # NOTE: all processes must participate in the all_gather operation
                # even if they have no data to contribute.
                sizes = torch.stack(
                    _simple_gather_all_tensors(
                        local_size, group=dist.group.WORLD, world_size=world_size
                    )
                )
                sizes = torch.cat(
                    [torch.tensor([0], device=sizes.device), torch.cumsum(sizes, dim=0)]
                )
                labels = labels[
                    sizes[rank] : sizes[rank + 1] if rank + 1 < world_size else None
                ]

            if self.cache_labels:
                self._labels[device] = labels
                self._prev_num_logits = num_logits
        else:
            labels = self._labels[device]
        return labels

    def _get_logits(  # noqa: PLR0912
        self,
        modalities: tuple[str, str],
        per_device_embeddings: dict[str, torch.Tensor],
        all_embeddings: dict[str, torch.Tensor],
        per_device_example_ids: dict[str, torch.Tensor],
        all_example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        world_size: int,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
        """Calculate the logits for the given modalities.

        Parameters
        ----------
        modalities : tuple[str, str]
            Tuple of modality names.
        per_device_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor. In distributed mode, this contains
            embeddings from all processes.
        per_device_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        all_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index. In distributed
            mode, this contains example IDs from all processes.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        world_size : int
            Number of processes.

        Returns
        -------
        tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]
            Tuple of logits for the given modalities. If embeddings for the given
            modalities are not available, returns `None` for the logits. The last
            element is a flag indicating whether the process skipped the computation
            due to lack of data.
        """
        modality_a = Modalities.get_modality(modalities[0])
        modality_b = Modalities.get_modality(modalities[1])
        skip_flag = False

        if self.local_loss or world_size == 1:
            if not (
                modality_a.embedding in per_device_embeddings
                and modality_b.embedding in per_device_embeddings
            ):
                if world_size > 1:  # NOTE: not all processes exit here, hence skip_flag
                    skip_flag = True
                else:
                    return None, None, skip_flag

            if not skip_flag:
                indices_a, indices_b = find_matching_indices(
                    per_device_example_ids[modality_a.name],
                    per_device_example_ids[modality_b.name],
                )
                if indices_a.numel() == 0 or indices_b.numel() == 0:
                    if world_size > 1:  # not all processes exit here
                        skip_flag = True
                    else:
                        return None, None, skip_flag

            if not skip_flag:
                features_a = per_device_embeddings[modality_a.embedding][indices_a]
                features_b = per_device_embeddings[modality_b.embedding][indices_b]
            else:
                # all processes must participate in the all_gather operation
                # that follows, even if they have no data to contribute. So,
                # we create empty tensors here.
                features_a = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )
                features_b = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )

        if world_size > 1:
            if not (
                modality_a.embedding in all_embeddings
                and modality_b.embedding in all_embeddings
            ):  # all processes exit here
                return None, None, skip_flag

            indices_a, indices_b = find_matching_indices(
                all_example_ids[modality_a.name],
                all_example_ids[modality_b.name],
            )
            if indices_a.numel() == 0 or indices_b.numel() == 0:
                # all processes exit here
                return None, None, skip_flag

            all_features_a = all_embeddings[modality_a.embedding][indices_a]
            all_features_b = all_embeddings[modality_b.embedding][indices_b]

            if self.local_loss:
                if features_a.numel() == 0:
                    features_a = all_features_a
                if features_b.numel() == 0:
                    features_b = all_features_b

                logits_per_feature_a = logit_scale * _safe_matmul(
                    features_a, all_features_b
                )
                logits_per_feature_b = logit_scale * _safe_matmul(
                    features_b, all_features_a
                )
            else:
                logits_per_feature_a = logit_scale * _safe_matmul(
                    all_features_a, all_features_b
                )
                logits_per_feature_b = logits_per_feature_a.T
        else:
            logits_per_feature_a = logit_scale * _safe_matmul(features_a, features_b)
            logits_per_feature_b = logit_scale * _safe_matmul(features_b, features_a)

        return logits_per_feature_a, logits_per_feature_b, skip_flag

    def _compute_modality_alignment_loss(
        self, all_embeddings: dict[str, torch.Tensor], logit_scale: torch.Tensor
    ) -> torch.Tensor:
        """Compute the modality alignment loss.

        This loss considers all features from the same modality as positive pairs
        and all features from different modalities as negative pairs.

        Parameters
        ----------
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        logit_scale : torch.Tensor
            Scale factor for the logits.

        Returns
        -------
        torch.Tensor
            Modality alignment loss.

        Notes
        -----
        This loss does not support `local_loss=True`.
        """
        available_modalities = list(all_embeddings.keys())
        # TODO: support local_loss for modality_alignment?
        # if world_size == 1, all_embeddings == embeddings
        all_features = torch.cat(list(all_embeddings.values()), dim=0)

        positive_indices = torch.tensor(
            [
                (i, j)
                if idx == 0
                else (
                    i + all_embeddings[available_modalities[idx - 1]].size(0),
                    j + all_embeddings[available_modalities[idx - 1]].size(0),
                )
                for idx, k in enumerate(all_embeddings)
                for i, j in itertools.combinations(range(all_embeddings[k].size(0)), 2)
            ],
            device=all_features.device,
        )
        logits = logit_scale * _safe_matmul(all_features, all_features)

        target = torch.eye(all_features.size(0), device=all_features.device)
        target[positive_indices[:, 0], positive_indices[:, 1]] = 1

        modality_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, target, reduction="none"
        )

        target_pos = target.bool()
        target_neg = ~target_pos

        # loss_pos and loss_neg below contain non-zero values only for those
        # elements that are positive pairs and negative pairs respectively.
        loss_pos = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_pos, modality_loss[target_pos])
        loss_neg = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_neg, modality_loss[target_neg])

        loss_pos = loss_pos.sum(dim=1)
        loss_neg = loss_neg.sum(dim=1)
        num_pos = target.sum(dim=1)
        num_neg = logits.size(0) - num_pos

        return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()
forward
forward(
    embeddings,
    example_ids,
    logit_scale,
    modality_loss_pairs,
)

Calculate the contrastive loss.

Parameters:

Name Type Description Default
embeddings dict[str, Tensor]

Dictionary of embeddings, where the key is the modality name and the value is the corresponding embedding tensor.

required
example_ids dict[str, Tensor]

Dictionary of example IDs, where the key is the modality name and the value is a tensor tuple of the dataset index and the example index.

required
logit_scale Tensor

Scale factor for the logits.

required
modality_loss_pairs list[LossPairSpec]

Specification of the modality pairs for which the loss should be calculated.

required

Returns:

Type Description
Tensor

The contrastive loss.

Source code in mmlearn/modules/losses/contrastive.py
def forward(
    self,
    embeddings: dict[str, torch.Tensor],
    example_ids: dict[str, torch.Tensor],
    logit_scale: torch.Tensor,
    modality_loss_pairs: list[LossPairSpec],
) -> torch.Tensor:
    """Calculate the contrastive loss.

    Parameters
    ----------
    embeddings : dict[str, torch.Tensor]
        Dictionary of embeddings, where the key is the modality name and the value
        is the corresponding embedding tensor.
    example_ids : dict[str, torch.Tensor]
        Dictionary of example IDs, where the key is the modality name and the value
        is a tensor tuple of the dataset index and the example index.
    logit_scale : torch.Tensor
        Scale factor for the logits.
    modality_loss_pairs : list[LossPairSpec]
        Specification of the modality pairs for which the loss should be calculated.

    Returns
    -------
    torch.Tensor
        The contrastive loss.
    """
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if world_size > 1 else 0

    if self.l2_normalize:
        embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

    if world_size > 1:  # gather embeddings and example_ids across all processes
        # NOTE: gathering dictionaries of tensors across all processes
        # (keys + values, as opposed to just values) is especially important
        # for the modality_alignment loss, which requires all embeddings
        all_embeddings = _gather_dicts(
            embeddings,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
        all_example_ids = _gather_dicts(
            example_ids,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
    else:
        all_embeddings = embeddings
        all_example_ids = example_ids

    losses = []
    for loss_pairs in modality_loss_pairs:
        logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
            loss_pairs.modalities,
            per_device_embeddings=embeddings,
            all_embeddings=all_embeddings,
            per_device_example_ids=example_ids,
            all_example_ids=all_example_ids,
            logit_scale=logit_scale,
            world_size=world_size,
        )
        if logits_per_feature_a is None or logits_per_feature_b is None:
            continue

        labels = self._get_ground_truth(
            logits_per_feature_a.shape,
            device=logits_per_feature_a.device,
            rank=rank,
            world_size=world_size,
            skipped_process=skip_flag,
        )

        if labels.numel() != 0:
            losses.append(
                (
                    (
                        F.cross_entropy(logits_per_feature_a, labels)
                        + F.cross_entropy(logits_per_feature_b, labels)
                    )
                    / 2
                )
                * loss_pairs.weight
            )

    if self.modality_alignment:
        losses.append(
            self._compute_modality_alignment_loss(all_embeddings, logit_scale)
        )

    if not losses:  # no loss to compute (e.g. no paired data in batch)
        losses.append(
            torch.tensor(
                0.0,
                device=logit_scale.device,
                dtype=next(iter(embeddings.values())).dtype,
            )
        )

    return torch.stack(losses).sum()

data2vec

Implementation of the Data2Vec loss function.

Data2VecLoss

Bases: Module

Data2Vec loss function.

Parameters:

Name Type Description Default
beta float

Specifies the beta parameter for smooth L1 loss. If 0, MSE loss is used.

0
loss_scale Optional[float]

Scaling factor for the loss. If None, uses 1 / sqrt(embedding_dim).

None
reduction str

Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.

'none'

Raises:

Type Description
ValueError

If the reduction mode is not supported.

Source code in mmlearn/modules/losses/data2vec.py
@store(group="modules/losses", provider="mmlearn")
class Data2VecLoss(nn.Module):
    """Data2Vec loss function.

    Parameters
    ----------
    beta : float, optional, default=0
        Specifies the beta parameter for smooth L1 loss. If ``0``, MSE loss is used.
    loss_scale : Optional[float], optional, default=None
        Scaling factor for the loss. If None, uses ``1 / sqrt(embedding_dim)``.
    reduction : str, optional, default='none'
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``.

    Raises
    ------
    ValueError
        If the reduction mode is not supported.
    """

    def __init__(
        self,
        beta: float = 0,
        loss_scale: Optional[float] = None,
        reduction: str = "none",
    ) -> None:
        super().__init__()
        self.beta = beta
        self.loss_scale = loss_scale
        if reduction not in ["none", "mean", "sum"]:
            raise ValueError(f"Unsupported reduction mode: {reduction}")
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the Data2Vec loss.

        Parameters
        ----------
        x : torch.Tensor
            Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
        y : torch.Tensor
            Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

        Returns
        -------
        torch.Tensor
            Data2Vec loss value.

        Raises
        ------
        ValueError
            If the shapes of x and y do not match.
        """
        if x.shape != y.shape:
            raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, y.size(-1))

        if self.beta == 0:
            loss = mse_loss(x, y, reduction="none")
        else:
            loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        loss = loss * scale

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # 'none'
        return loss.view(x.size(0), -1).sum(1)
forward
forward(x, y)

Compute the Data2Vec loss.

Parameters:

Name Type Description Default
x Tensor

Predicted embeddings of shape (batch_size, num_patches, embedding_dim).

required
y Tensor

Target embeddings of shape (batch_size, num_patches, embedding_dim).

required

Returns:

Type Description
Tensor

Data2Vec loss value.

Raises:

Type Description
ValueError

If the shapes of x and y do not match.

Source code in mmlearn/modules/losses/data2vec.py
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the Data2Vec loss.

    Parameters
    ----------
    x : torch.Tensor
        Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
    y : torch.Tensor
        Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

    Returns
    -------
    torch.Tensor
        Data2Vec loss value.

    Raises
    ------
    ValueError
        If the shapes of x and y do not match.
    """
    if x.shape != y.shape:
        raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

    x = x.view(-1, x.size(-1)).float()
    y = y.view(-1, y.size(-1))

    if self.beta == 0:
        loss = mse_loss(x, y, reduction="none")
    else:
        loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

    if self.loss_scale is not None:
        scale = self.loss_scale
    else:
        scale = 1 / math.sqrt(x.size(-1))

    loss = loss * scale

    if self.reduction == "mean":
        return loss.mean()
    if self.reduction == "sum":
        return loss.sum()
    # 'none'
    return loss.view(x.size(0), -1).sum(1)
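
Example — a minimal usage sketch, assuming Data2VecLoss is importable from mmlearn.modules.losses.data2vec (the source path shown above); tensor shapes are illustrative:

import torch
from mmlearn.modules.losses.data2vec import Data2VecLoss

loss_fn = Data2VecLoss(beta=2.0, reduction="mean")  # smooth L1 with beta=2; "mean" yields a scalar

predictions = torch.randn(4, 16, 768)  # (batch_size, num_patches, embedding_dim)
targets = torch.randn(4, 16, 768)

loss = loss_fn(predictions, targets)  # scaled by 1/sqrt(768) since loss_scale is None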

lr_schedulers

Learning rate schedulers for training models.

linear_warmup_cosine_annealing_lr

linear_warmup_cosine_annealing_lr(
    optimizer,
    warmup_steps,
    max_steps,
    start_factor=1 / 3,
    eta_min=0.0,
    last_epoch=-1,
)

Create a linear warmup cosine annealing learning rate scheduler.

Parameters:

Name Type Description Default
optimizer Optimizer

The optimizer for which to schedule the learning rate.

required
warmup_steps int

Maximum number of iterations for linear warmup.

required
max_steps int

Maximum number of iterations.

required
start_factor float

Multiplicative factor for the learning rate at the start of the warmup phase.

1/3
eta_min float

Minimum learning rate.

0
last_epoch int

The index of the last epoch. If set to -1, the learning rate is initialized to the base learning rate.

-1

Returns:

Type Description
LRScheduler

The learning rate scheduler.

Raises:

Type Description
ValueError

If warmup_steps is greater than or equal to max_steps or if warmup_steps is less than or equal to 0.

Source code in mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py
@store(  # type: ignore[misc]
    group="modules/lr_schedulers",
    provider="mmlearn",
    zen_partial=True,
    warmup_steps=MISSING,
    max_steps=MISSING,
)
def linear_warmup_cosine_annealing_lr(
    optimizer: Optimizer,
    warmup_steps: int,
    max_steps: int,
    start_factor: float = 1 / 3,
    eta_min: float = 0.0,
    last_epoch: int = -1,
) -> LRScheduler:
    """Create a linear warmup cosine annealing learning rate scheduler.

    Parameters
    ----------
    optimizer : Optimizer
        The optimizer for which to schedule the learning rate.
    warmup_steps : int
        Maximum number of iterations for linear warmup.
    max_steps : int
        Maximum number of iterations.
    start_factor : float, optional, default=1/3
        Multiplicative factor for the learning rate at the start of the warmup phase.
    eta_min : float, optional, default=0
        Minimum learning rate.
    last_epoch : int, optional, default=-1
        The index of the last epoch. If set to ``-1``, the learning rate is
        initialized to the base learning rate.

    Returns
    -------
    LRScheduler
        The learning rate scheduler.

    Raises
    ------
    ValueError
        If `warmup_steps` is greater than or equal to `max_steps` or if `warmup_steps`
        is less than or equal to 0.
    """
    if warmup_steps >= max_steps:
        raise ValueError(
            "Expected `warmup_steps` to be less than `max_steps` but got "
            f"`warmup_steps={warmup_steps}` and `max_steps={max_steps}`."
        )
    if warmup_steps <= 0:
        raise ValueError(
            "Expected `warmup_steps` to be positive but got "
            f"`warmup_steps={warmup_steps}`."
        )

    linear_lr = LinearLR(
        optimizer,
        start_factor=start_factor,
        total_iters=warmup_steps,
        last_epoch=last_epoch,
    )
    cosine_lr = CosineAnnealingLR(
        optimizer,
        T_max=max_steps - warmup_steps,
        eta_min=eta_min,
        last_epoch=last_epoch,
    )
    return SequentialLR(
        optimizer,
        schedulers=[linear_lr, cosine_lr],
        milestones=[warmup_steps],
        last_epoch=last_epoch,
    )
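
Example — a minimal usage sketch, assuming the function is importable from mmlearn.modules.lr_schedulers.linear_warmup_cosine_lr (the source path shown above); the model and step counts are illustrative:

import torch
from mmlearn.modules.lr_schedulers.linear_warmup_cosine_lr import (
    linear_warmup_cosine_annealing_lr,
)

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Linear warmup from lr/3 to lr over the first 100 steps, then cosine decay
# towards eta_min=0.0 by step 1000.
scheduler = linear_warmup_cosine_annealing_lr(optimizer, warmup_steps=100, max_steps=1000)

for _ in range(1000):
    # ... forward/backward pass would go here ...
    optimizer.step()
    scheduler.step()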

linear_warmup_cosine_lr

Linear warmup cosine annealing learning rate scheduler.

linear_warmup_cosine_annealing_lr
linear_warmup_cosine_annealing_lr(
    optimizer,
    warmup_steps,
    max_steps,
    start_factor=1 / 3,
    eta_min=0.0,
    last_epoch=-1,
)

Create a linear warmup cosine annealing learning rate scheduler.

Parameters:

Name Type Description Default
optimizer Optimizer

The optimizer for which to schedule the learning rate.

required
warmup_steps int

Maximum number of iterations for linear warmup.

required
max_steps int

Maximum number of iterations.

required
start_factor float

Multiplicative factor for the learning rate at the start of the warmup phase.

1/3
eta_min float

Minimum learning rate.

0
last_epoch int

The index of the last epoch. If set to -1, the learning rate is initialized to the base learning rate.

-1

Returns:

Type Description
LRScheduler

The learning rate scheduler.

Raises:

Type Description
ValueError

If warmup_steps is greater than or equal to max_steps or if warmup_steps is less than or equal to 0.

Source code in mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py
@store(  # type: ignore[misc]
    group="modules/lr_schedulers",
    provider="mmlearn",
    zen_partial=True,
    warmup_steps=MISSING,
    max_steps=MISSING,
)
def linear_warmup_cosine_annealing_lr(
    optimizer: Optimizer,
    warmup_steps: int,
    max_steps: int,
    start_factor: float = 1 / 3,
    eta_min: float = 0.0,
    last_epoch: int = -1,
) -> LRScheduler:
    """Create a linear warmup cosine annealing learning rate scheduler.

    Parameters
    ----------
    optimizer : Optimizer
        The optimizer for which to schedule the learning rate.
    warmup_steps : int
        Maximum number of iterations for linear warmup.
    max_steps : int
        Maximum number of iterations.
    start_factor : float, optional, default=1/3
        Multiplicative factor for the learning rate at the start of the warmup phase.
    eta_min : float, optional, default=0
        Minimum learning rate.
    last_epoch : int, optional, default=-1
        The index of the last epoch. If set to ``-1``, the learning rate is
        initialized to the base learning rate.

    Returns
    -------
    LRScheduler
        The learning rate scheduler.

    Raises
    ------
    ValueError
        If `warmup_steps` is greater than or equal to `max_steps` or if `warmup_steps`
        is less than or equal to 0.
    """
    if warmup_steps >= max_steps:
        raise ValueError(
            "Expected `warmup_steps` to be less than `max_steps` but got "
            f"`warmup_steps={warmup_steps}` and `max_steps={max_steps}`."
        )
    if warmup_steps <= 0:
        raise ValueError(
            "Expected `warmup_steps` to be positive but got "
            f"`warmup_steps={warmup_steps}`."
        )

    linear_lr = LinearLR(
        optimizer,
        start_factor=start_factor,
        total_iters=warmup_steps,
        last_epoch=last_epoch,
    )
    cosine_lr = CosineAnnealingLR(
        optimizer,
        T_max=max_steps - warmup_steps,
        eta_min=eta_min,
        last_epoch=last_epoch,
    )
    return SequentialLR(
        optimizer,
        schedulers=[linear_lr, cosine_lr],
        milestones=[warmup_steps],
        last_epoch=last_epoch,
    )

metrics

Metrics for evaluating models.

RetrievalRecallAtK

Bases: Metric

Retrieval Recall@K metric.

Computes the Recall@K for retrieval tasks. The metric is computed as follows:

  1. Compute the cosine similarity between the query and the database.
  2. For each query, sort the database in decreasing order of similarity.
  3. Compute the Recall@K as the number of true positives among the top K elements.

Parameters:

Name Type Description Default
top_k int

The number of top elements to consider for computing the Recall@K.

required
reduction (mean, sum, none, None)

Specifies the reduction to apply after computing the pairwise cosine similarity scores.

"mean"
aggregation (mean, median, min, max) or callable

Specifies the aggregation function to apply to the Recall@K values computed in batches. If a callable is provided, it should accept a tensor of values and a keyword argument 'dim' and return a single scalar value.

"mean"
kwargs Any

Additional arguments to be passed to the torchmetrics.Metric class.

{}

Raises:

Type Description
ValueError
  • If the top_k is not a positive integer or None.
  • If the reduction is not one of {"mean", "sum", "none", None}.
  • If the aggregation is not one of {"mean", "median", "min", "max"} or a custom callable function.
Source code in mmlearn/modules/metrics/retrieval_recall.py
@store(group="modules/metrics", provider="mmlearn")
class RetrievalRecallAtK(Metric):
    """Retrieval Recall@K metric.

    Computes the Recall@K for retrieval tasks. The metric is computed as follows:

    1. Compute the cosine similarity between the query and the database.
    2. For each query, sort the database in decreasing order of similarity.
    3. Compute the Recall@K as the number of true positives among the top K elements.

    Parameters
    ----------
    top_k : int
        The number of top elements to consider for computing the Recall@K.
    reduction : {"mean", "sum", "none", None}, optional, default="sum"
        Specifies the reduction to apply after computing the pairwise cosine similarity
        scores.
    aggregation : {"mean", "median", "min", "max"} or callable, default="mean"
        Specifies the aggregation function to apply to the Recall@K values computed
        in batches. If a callable is provided, it should accept a tensor of values
        and a keyword argument ``'dim'`` and return a single scalar value.
    kwargs : Any
        Additional arguments to be passed to the :py:class:`torchmetrics.Metric` class.

    Raises
    ------
    ValueError

        - If the `top_k` is not a positive integer or None.
        - If the `reduction` is not one of {"mean", "sum", "none", None}.
        - If the `aggregation` is not one of {"mean", "median", "min", "max"} or a
          custom callable function.

    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    indexes: list[torch.Tensor]
    x: list[torch.Tensor]
    y: list[torch.Tensor]
    num_samples: torch.Tensor

    def __init__(
        self,
        top_k: int,
        reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
        aggregation: Union[
            Literal["mean", "median", "min", "max"],
            Callable[[torch.Tensor, int], torch.Tensor],
        ] = "mean",
        **kwargs: Any,
    ) -> None:
        """Initialize the metric."""
        super().__init__(**kwargs)

        if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
            raise ValueError("`top_k` has to be a positive integer or None")
        self.top_k = top_k

        allowed_reduction = ("sum", "mean", "none", None)
        if reduction not in allowed_reduction:
            raise ValueError(
                f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
            )
        self.reduction = reduction

        if not (
            aggregation in ("mean", "median", "min", "max") or callable(aggregation)
        ):
            raise ValueError(
                "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
                f"which takes tensor of values, but got {aggregation}."
            )
        self.aggregation = aggregation

        self.add_state("x", default=[], dist_reduce_fx="cat")
        self.add_state("y", default=[], dist_reduce_fx="cat")
        self.add_state("indexes", default=[], dist_reduce_fx="cat")
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

        self._batch_size: int = -1

        self.compute_on_cpu = True
        self.sync_on_compute = False
        self.dist_sync_on_step = False
        self._to_sync = self.sync_on_compute
        self._should_unsync = False

    def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
        """Check shape, convert dtypes and add to accumulators.

        Parameters
        ----------
        x : torch.Tensor
            Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
            of samples and `D` is the number of dimensions.
        y : torch.Tensor
            Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
            of samples and ``D`` is the number of dimensions.
        indexes : torch.Tensor
            Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
            This specifies which sample in ``y`` is the positive match for each
            sample in ``x``.

        Raises
        ------
        ValueError
            If `indexes` is None.

        """
        if indexes is None:
            raise ValueError("Argument `indexes` cannot be None")

        x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

        # offset batch indexes by the number of samples seen so far
        indexes += self.num_samples

        local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
        if self._is_distributed():
            x = dim_zero_cat(gather_all_tensors(x, self.process_group))
            y = dim_zero_cat(gather_all_tensors(y, self.process_group))
            indexes = dim_zero_cat(
                gather_all_tensors(indexes.clone(), self.process_group)
            )

            # offset indexes for each device
            bsz_per_device = dim_zero_cat(
                gather_all_tensors(local_batch_size, self.process_group)
            )
            cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
            for device_idx in range(1, bsz_per_device.numel()):
                indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                    cum_local_bsz[device_idx - 1]
                )

            # update the global sample count
            self.num_samples += cum_local_bsz[-1]

            self._is_synced = True
        else:
            self.num_samples += x.size(0)

        self.x.append(x)
        self.y.append(y)
        self.indexes.append(indexes)

        if self._batch_size == -1:
            self._batch_size = x.size(0)  # global batch size

    def compute(self) -> torch.Tensor:
        """Compute the metric.

        Returns
        -------
        torch.Tensor
            The computed metric.
        """
        x = dim_zero_cat(self.x)
        y = dim_zero_cat(self.y)

        # normalize embeddings
        x /= x.norm(dim=-1, p=2, keepdim=True)
        y /= y.norm(dim=-1, p=2, keepdim=True)

        # instantiate reduction function
        reduction_mapping: Dict[
            Optional[str], Callable[[torch.Tensor], torch.Tensor]
        ] = {
            "sum": partial(torch.sum, dim=-1),
            "mean": partial(torch.mean, dim=-1),
            "none": lambda x: x,
            None: lambda x: x,
        }

        # concatenate indexes of true pairs
        indexes = dim_zero_cat(self.indexes)

        results: list[torch.Tensor] = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=os.cpu_count() or 1  # use all available CPUs
        ) as executor:
            futures = [
                executor.submit(
                    self._process_batch,
                    start,
                    x,
                    y,
                    indexes,
                    reduction_mapping,
                    self.top_k,
                )
                for start in tqdm(
                    range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
                )
            ]
            for future in concurrent.futures.as_completed(futures):
                results.append(future.result())

        return _retrieval_aggregate(
            (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward method is not supported.

        Raises
        ------
        NotImplementedError
            The forward method is not supported for this metric.
        """
        raise NotImplementedError(
            "RetrievalRecallAtK metric does not support forward method"
        )

    def _is_distributed(self) -> bool:
        if self.distributed_available_fn is not None:
            distributed_available = self.distributed_available_fn

        return distributed_available() if callable(distributed_available) else False

    def _process_batch(
        self,
        start: int,
        x_norm: torch.Tensor,
        y_norm: torch.Tensor,
        indexes: torch.Tensor,
        reduction_mapping: Dict[Optional[str], Callable[[torch.Tensor], torch.Tensor]],
        top_k: int,
    ) -> torch.Tensor:
        """Compute the Recall@K for a batch of samples."""
        end = start + self._batch_size
        x_norm_batch = x_norm[start:end]
        indexes_batch = indexes[start:end]

        similarity = _safe_matmul(x_norm_batch, y_norm)
        scores: torch.Tensor = reduction_mapping[self.reduction](similarity)

        with torch.inference_mode():
            positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
            positive_pairs[torch.arange(len(scores)), indexes_batch] = True

        return _recall_at_k(scores, positive_pairs, top_k)
__init__
__init__(
    top_k, reduction="sum", aggregation="mean", **kwargs
)

Initialize the metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def __init__(
    self,
    top_k: int,
    reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
    aggregation: Union[
        Literal["mean", "median", "min", "max"],
        Callable[[torch.Tensor, int], torch.Tensor],
    ] = "mean",
    **kwargs: Any,
) -> None:
    """Initialize the metric."""
    super().__init__(**kwargs)

    if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
        raise ValueError("`top_k` has to be a positive integer or None")
    self.top_k = top_k

    allowed_reduction = ("sum", "mean", "none", None)
    if reduction not in allowed_reduction:
        raise ValueError(
            f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
        )
    self.reduction = reduction

    if not (
        aggregation in ("mean", "median", "min", "max") or callable(aggregation)
    ):
        raise ValueError(
            "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
            f"which takes tensor of values, but got {aggregation}."
        )
    self.aggregation = aggregation

    self.add_state("x", default=[], dist_reduce_fx="cat")
    self.add_state("y", default=[], dist_reduce_fx="cat")
    self.add_state("indexes", default=[], dist_reduce_fx="cat")
    self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

    self._batch_size: int = -1

    self.compute_on_cpu = True
    self.sync_on_compute = False
    self.dist_sync_on_step = False
    self._to_sync = self.sync_on_compute
    self._should_unsync = False
update
update(x, y, indexes)

Check shape, convert dtypes and add to accumulators.

Parameters:

Name Type Description Default
x Tensor

Embeddings (unnormalized) of shape (N, D) where N is the number of samples and D is the number of dimensions.

required
y Tensor

Embeddings (unnormalized) of shape (M, D) where M is the number of samples and D is the number of dimensions.

required
indexes Tensor

Index tensor of shape (N,) where N is the number of samples. This specifies which sample in y is the positive match for each sample in x.

required

Raises:

Type Description
ValueError

If indexes is None.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
    """Check shape, convert dtypes and add to accumulators.

    Parameters
    ----------
    x : torch.Tensor
        Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
        of samples and `D` is the number of dimensions.
    y : torch.Tensor
        Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
        of samples and ``D`` is the number of dimensions.
    indexes : torch.Tensor
        Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
        This specifies which sample in ``y`` is the positive match for each
        sample in ``x``.

    Raises
    ------
    ValueError
        If `indexes` is None.

    """
    if indexes is None:
        raise ValueError("Argument `indexes` cannot be None")

    x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

    # offset batch indexes by the number of samples seen so far
    indexes += self.num_samples

    local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
    if self._is_distributed():
        x = dim_zero_cat(gather_all_tensors(x, self.process_group))
        y = dim_zero_cat(gather_all_tensors(y, self.process_group))
        indexes = dim_zero_cat(
            gather_all_tensors(indexes.clone(), self.process_group)
        )

        # offset indexes for each device
        bsz_per_device = dim_zero_cat(
            gather_all_tensors(local_batch_size, self.process_group)
        )
        cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
        for device_idx in range(1, bsz_per_device.numel()):
            indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                cum_local_bsz[device_idx - 1]
            )

        # update the global sample count
        self.num_samples += cum_local_bsz[-1]

        self._is_synced = True
    else:
        self.num_samples += x.size(0)

    self.x.append(x)
    self.y.append(y)
    self.indexes.append(indexes)

    if self._batch_size == -1:
        self._batch_size = x.size(0)  # global batch size
compute
compute()

Compute the metric.

Returns:

Type Description
Tensor

The computed metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def compute(self) -> torch.Tensor:
    """Compute the metric.

    Returns
    -------
    torch.Tensor
        The computed metric.
    """
    x = dim_zero_cat(self.x)
    y = dim_zero_cat(self.y)

    # normalize embeddings
    x /= x.norm(dim=-1, p=2, keepdim=True)
    y /= y.norm(dim=-1, p=2, keepdim=True)

    # instantiate reduction function
    reduction_mapping: Dict[
        Optional[str], Callable[[torch.Tensor], torch.Tensor]
    ] = {
        "sum": partial(torch.sum, dim=-1),
        "mean": partial(torch.mean, dim=-1),
        "none": lambda x: x,
        None: lambda x: x,
    }

    # concatenate indexes of true pairs
    indexes = dim_zero_cat(self.indexes)

    results: list[torch.Tensor] = []
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count() or 1  # use all available CPUs
    ) as executor:
        futures = [
            executor.submit(
                self._process_batch,
                start,
                x,
                y,
                indexes,
                reduction_mapping,
                self.top_k,
            )
            for start in tqdm(
                range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
            )
        ]
        for future in concurrent.futures.as_completed(futures):
            results.append(future.result())

    return _retrieval_aggregate(
        (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
    )
forward
forward(*args, **kwargs)

Forward method is not supported.

Raises:

Type Description
NotImplementedError

The forward method is not supported for this metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Forward method is not supported.

    Raises
    ------
    NotImplementedError
        The forward method is not supported for this metric.
    """
    raise NotImplementedError(
        "RetrievalRecallAtK metric does not support forward method"
    )
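
Example — a minimal usage sketch, assuming the metric is importable from mmlearn.modules.metrics.retrieval_recall (the source path shown above); the embeddings and indexes are illustrative, and reduction=None keeps the full per-query similarity scores:

import torch
from mmlearn.modules.metrics.retrieval_recall import RetrievalRecallAtK

metric = RetrievalRecallAtK(top_k=5, reduction=None)

queries = torch.randn(32, 128)   # (N, D) unnormalized query embeddings
database = torch.randn(32, 128)  # (M, D) unnormalized database embeddings
indexes = torch.arange(32)       # queries[i] has database[indexes[i]] as its positive match

metric.update(queries, database, indexes)
recall_at_5 = metric.compute()  # aggregated Recall@5 over all accumulated batches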

retrieval_recall

Retrieval Recall@K metric.

RetrievalRecallAtK

Bases: Metric

Retrieval Recall@K metric.

Computes the Recall@K for retrieval tasks. The metric is computed as follows:

  1. Compute the cosine similarity between the query and the database.
  2. For each query, sort the database in decreasing order of similarity.
  3. Compute the Recall@K as the number of true positives among the top K elements.

Parameters:

Name Type Description Default
top_k int

The number of top elements to consider for computing the Recall@K.

required
reduction (mean, sum, none, None)

Specifies the reduction to apply after computing the pairwise cosine similarity scores.

"mean"
aggregation (mean, median, min, max) or callable

Specifies the aggregation function to apply to the Recall@K values computed in batches. If a callable is provided, it should accept a tensor of values and a keyword argument 'dim' and return a single scalar value.

"mean"
kwargs Any

Additional arguments to be passed to the torchmetrics.Metric class.

{}

Raises:

Type Description
ValueError
  • If the top_k is not a positive integer or None.
  • If the reduction is not one of {"mean", "sum", "none", None}.
  • If the aggregation is not one of {"mean", "median", "min", "max"} or a custom callable function.
Source code in mmlearn/modules/metrics/retrieval_recall.py
@store(group="modules/metrics", provider="mmlearn")
class RetrievalRecallAtK(Metric):
    """Retrieval Recall@K metric.

    Computes the Recall@K for retrieval tasks. The metric is computed as follows:

    1. Compute the cosine similarity between the query and the database.
    2. For each query, sort the database in decreasing order of similarity.
    3. Compute the Recall@K as the number of true positives among the top K elements.

    Parameters
    ----------
    top_k : int
        The number of top elements to consider for computing the Recall@K.
    reduction : {"mean", "sum", "none", None}, optional, default="sum"
        Specifies the reduction to apply after computing the pairwise cosine similarity
        scores.
    aggregation : {"mean", "median", "min", "max"} or callable, default="mean"
        Specifies the aggregation function to apply to the Recall@K values computed
        in batches. If a callable is provided, it should accept a tensor of values
        and a keyword argument ``'dim'`` and return a single scalar value.
    kwargs : Any
        Additional arguments to be passed to the :py:class:`torchmetrics.Metric` class.

    Raises
    ------
    ValueError

        - If the `top_k` is not a positive integer or None.
        - If the `reduction` is not one of {"mean", "sum", "none", None}.
        - If the `aggregation` is not one of {"mean", "median", "min", "max"} or a
          custom callable function.

    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    indexes: list[torch.Tensor]
    x: list[torch.Tensor]
    y: list[torch.Tensor]
    num_samples: torch.Tensor

    def __init__(
        self,
        top_k: int,
        reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
        aggregation: Union[
            Literal["mean", "median", "min", "max"],
            Callable[[torch.Tensor, int], torch.Tensor],
        ] = "mean",
        **kwargs: Any,
    ) -> None:
        """Initialize the metric."""
        super().__init__(**kwargs)

        if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
            raise ValueError("`top_k` has to be a positive integer or None")
        self.top_k = top_k

        allowed_reduction = ("sum", "mean", "none", None)
        if reduction not in allowed_reduction:
            raise ValueError(
                f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
            )
        self.reduction = reduction

        if not (
            aggregation in ("mean", "median", "min", "max") or callable(aggregation)
        ):
            raise ValueError(
                "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
                f"which takes tensor of values, but got {aggregation}."
            )
        self.aggregation = aggregation

        self.add_state("x", default=[], dist_reduce_fx="cat")
        self.add_state("y", default=[], dist_reduce_fx="cat")
        self.add_state("indexes", default=[], dist_reduce_fx="cat")
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

        self._batch_size: int = -1

        self.compute_on_cpu = True
        self.sync_on_compute = False
        self.dist_sync_on_step = False
        self._to_sync = self.sync_on_compute
        self._should_unsync = False

    def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
        """Check shape, convert dtypes and add to accumulators.

        Parameters
        ----------
        x : torch.Tensor
            Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
            of samples and `D` is the number of dimensions.
        y : torch.Tensor
            Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
            of samples and ``D`` is the number of dimensions.
        indexes : torch.Tensor
            Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
            This specifies which sample in ``y`` is the positive match for each
            sample in ``x``.

        Raises
        ------
        ValueError
            If `indexes` is None.

        """
        if indexes is None:
            raise ValueError("Argument `indexes` cannot be None")

        x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

        # offset batch indexes by the number of samples seen so far
        indexes += self.num_samples

        local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
        if self._is_distributed():
            x = dim_zero_cat(gather_all_tensors(x, self.process_group))
            y = dim_zero_cat(gather_all_tensors(y, self.process_group))
            indexes = dim_zero_cat(
                gather_all_tensors(indexes.clone(), self.process_group)
            )

            # offset indexes for each device
            bsz_per_device = dim_zero_cat(
                gather_all_tensors(local_batch_size, self.process_group)
            )
            cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
            for device_idx in range(1, bsz_per_device.numel()):
                indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                    cum_local_bsz[device_idx - 1]
                )

            # update the global sample count
            self.num_samples += cum_local_bsz[-1]

            self._is_synced = True
        else:
            self.num_samples += x.size(0)

        self.x.append(x)
        self.y.append(y)
        self.indexes.append(indexes)

        if self._batch_size == -1:
            self._batch_size = x.size(0)  # global batch size

    def compute(self) -> torch.Tensor:
        """Compute the metric.

        Returns
        -------
        torch.Tensor
            The computed metric.
        """
        x = dim_zero_cat(self.x)
        y = dim_zero_cat(self.y)

        # normalize embeddings
        x /= x.norm(dim=-1, p=2, keepdim=True)
        y /= y.norm(dim=-1, p=2, keepdim=True)

        # instantiate reduction function
        reduction_mapping: Dict[
            Optional[str], Callable[[torch.Tensor], torch.Tensor]
        ] = {
            "sum": partial(torch.sum, dim=-1),
            "mean": partial(torch.mean, dim=-1),
            "none": lambda x: x,
            None: lambda x: x,
        }

        # concatenate indexes of true pairs
        indexes = dim_zero_cat(self.indexes)

        results: list[torch.Tensor] = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=os.cpu_count() or 1  # use all available CPUs
        ) as executor:
            futures = [
                executor.submit(
                    self._process_batch,
                    start,
                    x,
                    y,
                    indexes,
                    reduction_mapping,
                    self.top_k,
                )
                for start in tqdm(
                    range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
                )
            ]
            for future in concurrent.futures.as_completed(futures):
                results.append(future.result())

        return _retrieval_aggregate(
            (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward method is not supported.

        Raises
        ------
        NotImplementedError
            The forward method is not supported for this metric.
        """
        raise NotImplementedError(
            "RetrievalRecallAtK metric does not support forward method"
        )

    def _is_distributed(self) -> bool:
        if self.distributed_available_fn is not None:
            distributed_available = self.distributed_available_fn

        return distributed_available() if callable(distributed_available) else False

    def _process_batch(
        self,
        start: int,
        x_norm: torch.Tensor,
        y_norm: torch.Tensor,
        indexes: torch.Tensor,
        reduction_mapping: Dict[Optional[str], Callable[[torch.Tensor], torch.Tensor]],
        top_k: int,
    ) -> torch.Tensor:
        """Compute the Recall@K for a batch of samples."""
        end = start + self._batch_size
        x_norm_batch = x_norm[start:end]
        indexes_batch = indexes[start:end]

        similarity = _safe_matmul(x_norm_batch, y_norm)
        scores: torch.Tensor = reduction_mapping[self.reduction](similarity)

        with torch.inference_mode():
            positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
            positive_pairs[torch.arange(len(scores)), indexes_batch] = True

        return _recall_at_k(scores, positive_pairs, top_k)
__init__
__init__(
    top_k, reduction="sum", aggregation="mean", **kwargs
)

Initialize the metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def __init__(
    self,
    top_k: int,
    reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
    aggregation: Union[
        Literal["mean", "median", "min", "max"],
        Callable[[torch.Tensor, int], torch.Tensor],
    ] = "mean",
    **kwargs: Any,
) -> None:
    """Initialize the metric."""
    super().__init__(**kwargs)

    if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
        raise ValueError("`top_k` has to be a positive integer or None")
    self.top_k = top_k

    allowed_reduction = ("sum", "mean", "none", None)
    if reduction not in allowed_reduction:
        raise ValueError(
            f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
        )
    self.reduction = reduction

    if not (
        aggregation in ("mean", "median", "min", "max") or callable(aggregation)
    ):
        raise ValueError(
            "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
            f"which takes tensor of values, but got {aggregation}."
        )
    self.aggregation = aggregation

    self.add_state("x", default=[], dist_reduce_fx="cat")
    self.add_state("y", default=[], dist_reduce_fx="cat")
    self.add_state("indexes", default=[], dist_reduce_fx="cat")
    self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

    self._batch_size: int = -1

    self.compute_on_cpu = True
    self.sync_on_compute = False
    self.dist_sync_on_step = False
    self._to_sync = self.sync_on_compute
    self._should_unsync = False
update
update(x, y, indexes)

Check shape, convert dtypes and add to accumulators.

Parameters:

Name Type Description Default
x Tensor

Embeddings (unnormalized) of shape (N, D) where N is the number of samples and D is the number of dimensions.

required
y Tensor

Embeddings (unnormalized) of shape (M, D) where M is the number of samples and D is the number of dimensions.

required
indexes Tensor

Index tensor of shape (N,) where N is the number of samples. This specifies which sample in y is the positive match for each sample in x.

required

Raises:

Type Description
ValueError

If indexes is None.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
    """Check shape, convert dtypes and add to accumulators.

    Parameters
    ----------
    x : torch.Tensor
        Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
        of samples and `D` is the number of dimensions.
    y : torch.Tensor
        Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
        of samples and ``D`` is the number of dimensions.
    indexes : torch.Tensor
        Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
        This specifies which sample in ``y`` is the positive match for each
        sample in ``x``.

    Raises
    ------
    ValueError
        If `indexes` is None.

    """
    if indexes is None:
        raise ValueError("Argument `indexes` cannot be None")

    x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

    # offset batch indexes by the number of samples seen so far
    indexes += self.num_samples

    local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
    if self._is_distributed():
        x = dim_zero_cat(gather_all_tensors(x, self.process_group))
        y = dim_zero_cat(gather_all_tensors(y, self.process_group))
        indexes = dim_zero_cat(
            gather_all_tensors(indexes.clone(), self.process_group)
        )

        # offset indexes for each device
        bsz_per_device = dim_zero_cat(
            gather_all_tensors(local_batch_size, self.process_group)
        )
        cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
        for device_idx in range(1, bsz_per_device.numel()):
            indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                cum_local_bsz[device_idx - 1]
            )

        # update the global sample count
        self.num_samples += cum_local_bsz[-1]

        self._is_synced = True
    else:
        self.num_samples += x.size(0)

    self.x.append(x)
    self.y.append(y)
    self.indexes.append(indexes)

    if self._batch_size == -1:
        self._batch_size = x.size(0)  # global batch size
compute
compute()

Compute the metric.

Returns:

Type Description
Tensor

The computed metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def compute(self) -> torch.Tensor:
    """Compute the metric.

    Returns
    -------
    torch.Tensor
        The computed metric.
    """
    x = dim_zero_cat(self.x)
    y = dim_zero_cat(self.y)

    # normalize embeddings
    x /= x.norm(dim=-1, p=2, keepdim=True)
    y /= y.norm(dim=-1, p=2, keepdim=True)

    # instantiate reduction function
    reduction_mapping: Dict[
        Optional[str], Callable[[torch.Tensor], torch.Tensor]
    ] = {
        "sum": partial(torch.sum, dim=-1),
        "mean": partial(torch.mean, dim=-1),
        "none": lambda x: x,
        None: lambda x: x,
    }

    # concatenate indexes of true pairs
    indexes = dim_zero_cat(self.indexes)

    results: list[torch.Tensor] = []
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count() or 1  # use all available CPUs
    ) as executor:
        futures = [
            executor.submit(
                self._process_batch,
                start,
                x,
                y,
                indexes,
                reduction_mapping,
                self.top_k,
            )
            for start in tqdm(
                range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
            )
        ]
        for future in concurrent.futures.as_completed(futures):
            results.append(future.result())

    return _retrieval_aggregate(
        (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
    )
forward
forward(*args, **kwargs)

Forward method is not supported.

Raises:

Type Description
NotImplementedError

The forward method is not supported for this metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Forward method is not supported.

    Raises
    ------
    NotImplementedError
        The forward method is not supported for this metric.
    """
    raise NotImplementedError(
        "RetrievalRecallAtK metric does not support forward method"
    )

Encoders

mmlearn.modules.encoders

Encoders.

HFCLIPTextEncoder

Bases: Module

Wrapper around the CLIPTextModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient fine-tuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",  # required for `peft_config` to be converted to a `PeftConfig` object
)
class HFCLIPTextEncoder(nn.Module):
    """Wrapper around the ``CLIPTextModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden
            states, and the attention weights, if ``output_attentions`` is set
            to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get("attention_mask")
            or inputs.get(Modalities.TEXT.attention_mask),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

forward

forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden
        states, and the attention weights, if ``output_attentions`` is set
        to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get("attention_mask")
        or inputs.get(Modalities.TEXT.attention_mask),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
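
Example — a minimal usage sketch. The encoder class and the Modalities registry appear in the source above, but the import path for Modalities and the use of AutoTokenizer are assumptions; output_hidden_states=True is requested explicitly so that the forward pass can index outputs.hidden_states:

import torch
from transformers import AutoTokenizer

from mmlearn.datasets.core import Modalities  # assumed import path for the Modalities registry
from mmlearn.modules.encoders.clip import HFCLIPTextEncoder

tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")
encoder = HFCLIPTextEncoder(
    "openai/clip-vit-base-patch16",
    model_config_kwargs={"output_hidden_states": True},  # ensure hidden states are returned
)

batch = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
inputs = {
    Modalities.TEXT.name: batch["input_ids"],
    Modalities.TEXT.attention_mask: batch["attention_mask"],
}

with torch.no_grad():
    output = encoder(inputs)  # transformers BaseModelOutput

print(output.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)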

HFCLIPTextEncoderWithProjection

Bases: Module

Wrapper around the CLIPTextModelWithProjection from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the text. If False, the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient fine-tuning.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPTextEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPTextModelWithProjection`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the text. If ``False`` the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPTextConfig,
        )

        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The text embeddings. Will be a tuple with a single element.
        """
        input_ids = inputs[Modalities.TEXT.name]
        attention_mask: Optional[torch.Tensor] = inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        )
        position_ids = inputs.get("position_ids")

        if self.use_all_token_embeddings:
            text_outputs = self.model.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            )
            # TODO: add more options for pooling before projection
            text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
        else:
            text_embeds = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            ).text_embeds

        return (text_embeds,)

forward

forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
tuple[Tensor]

The text embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The text embeddings. Will be a tuple with a single element.
    """
    input_ids = inputs[Modalities.TEXT.name]
    attention_mask: Optional[torch.Tensor] = inputs.get(
        "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
    )
    position_ids = inputs.get("position_ids")

    if self.use_all_token_embeddings:
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        # TODO: add more options for pooling before projection
        text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
    else:
        text_embeds = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        ).text_embeds

    return (text_embeds,)
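
Example: a minimal usage sketch for this encoder (not taken from the library). The checkpoint name is only illustrative, the tokenizer comes from transformers, and the Modalities import path is an assumption based on the package layout shown elsewhere in this reference.

import torch
from transformers import AutoTokenizer

from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.modules.encoders.clip import HFCLIPTextEncoderWithProjection

encoder = HFCLIPTextEncoderWithProjection("openai/clip-vit-base-patch16")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")

batch = tokenizer(
    ["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True
)
inputs = {
    Modalities.TEXT.name: batch["input_ids"],
    "attention_mask": batch["attention_mask"],
}

encoder.eval()
with torch.no_grad():
    (text_embeds,) = encoder(inputs)  # forward returns a single-element tuple
print(text_embeds.shape)  # (2, 512) for the ViT-B/16 checkpoint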

HFCLIPVisionEncoder

Bases: Module

Wrapper around the CLIPVisionModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias Optional[float]

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoder(nn.Module):
    """Wrapper around the ``CLIPVisionModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : Optional[float], optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model.vision_model
        self.pooling_layer = pooling_layer
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.

        """
        # FIXME: handle other vision modalities
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.encoder(
            inputs_embeds=hidden_states,
            output_attentions=inputs.get(
                "output_attentions", self.model.config.output_attentions
            ),
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = encoder_outputs[0]
        if self.pooling_layer is not None:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

forward

forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.

    """
    # FIXME: handle other vision modalities
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.encoder(
        inputs_embeds=hidden_states,
        output_attentions=inputs.get(
            "output_attentions", self.model.config.output_attentions
        ),
        output_hidden_states=True,
        return_dict=True,
    )

    last_hidden_state = encoder_outputs[0]
    if self.pooling_layer is not None:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
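
Example: a minimal sketch of calling the vision encoder on a dummy image batch (not taken from the library); the Modalities import path is an assumption and the input is a random placeholder tensor.

import torch

from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoder

encoder = HFCLIPVisionEncoder(
    "openai/clip-vit-base-patch16",
    # freeze the whole backbone; layer norms are frozen too since freeze_layer_norm=True
    freeze_layers=True,
)
encoder.eval()

with torch.no_grad():
    output = encoder({Modalities.RGB.name: torch.randn(2, 3, 224, 224)})
print(output.last_hidden_state.shape)  # (2, 197, 768) for ViT-B/16 at 224x224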

HFCLIPVisionEncoderWithProjection

Bases: Module

Wrapper around the CLIPVisionModelWithProjection class from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the image. If False the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias float

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs dict[str, Any]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPVisionModelWithProjection`` class from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the image. If ``False`` the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : float, optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : dict[str, Any], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPVisionConfig,
        )

        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The image embeddings. Will be a tuple with a single element.
        """
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.vision_model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.vision_model.encoder(
            inputs_embeds=hidden_states, return_dict=True
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        if self.use_all_token_embeddings:
            pooled_output = last_hidden_state
        else:
            pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.model.vision_model.post_layernorm(pooled_output)

        return (self.model.visual_projection(pooled_output),)

forward

forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
tuple[Tensor]

The image embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The image embeddings. Will be a tuple with a single element.
    """
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.vision_model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.vision_model.encoder(
        inputs_embeds=hidden_states, return_dict=True
    )

    last_hidden_state = encoder_outputs.last_hidden_state
    if self.use_all_token_embeddings:
        pooled_output = last_hidden_state
    else:
        pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.model.vision_model.post_layernorm(pooled_output)

    return (self.model.visual_projection(pooled_output),)
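
Example: a hedged sketch that pairs this encoder with HFCLIPTextEncoderWithProjection (documented above) to score image-text similarity, CLIP-style. The checkpoint name, tokenizer setup, and Modalities import path are assumptions; the images are random placeholders.

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.modules.encoders.clip import (
    HFCLIPTextEncoderWithProjection,
    HFCLIPVisionEncoderWithProjection,
)

name = "openai/clip-vit-base-patch16"
text_encoder = HFCLIPTextEncoderWithProjection(name).eval()
image_encoder = HFCLIPVisionEncoderWithProjection(name).eval()
tokenizer = AutoTokenizer.from_pretrained(name)

texts = tokenizer(
    ["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True
)
with torch.no_grad():
    (text_embeds,) = text_encoder(
        {
            Modalities.TEXT.name: texts["input_ids"],
            "attention_mask": texts["attention_mask"],
        }
    )
    (image_embeds,) = image_encoder({Modalities.RGB.name: torch.randn(2, 3, 224, 224)})

# Cosine similarity between every image and every caption.
similarity = F.normalize(image_embeds, dim=-1) @ F.normalize(text_embeds, dim=-1).T
print(similarity.shape)  # (2, 2)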

HFTextEncoder

Bases: Module

Wrapper around huggingface models in the AutoModelForTextEncoding class.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Raises:

Type Description
ValueError

If the model is a decoder model or if freezing individual layers is not supported for the model type.

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/text.py
@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
class HFTextEncoder(nn.Module):
    """Wrapper around huggingface models in the ``AutoModelForTextEncoding`` class.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Raises
    ------
    ValueError
        If the model is a decoder model or if freezing individual layers is not
        supported for the model type.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.


    """

    def __init__(  # noqa: PLR0912
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ):
        super().__init__()
        if model_config_kwargs is None:
            model_config_kwargs = {}
        model_config_kwargs["output_hidden_states"] = True
        model_config_kwargs["add_pooling_layer"] = False
        model = hf_utils.load_huggingface_model(
            AutoModelForTextEncoding,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        if hasattr(model.config, "is_decoder") and model.config.is_decoder:
            raise ValueError("Model is a decoder. Only encoder models are supported.")

        if not pretrained and freeze_layers:
            rank_zero_warn(
                "Freezing layers when loading a model with random weights may lead to "
                "unexpected behavior. Consider setting `freeze_layers=False` if "
                "`pretrained=False`.",
            )

        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "LayerNorm" in name else False
                )

        if isinstance(
            freeze_layers, (float, int, list)
        ) and model.config.model_type in ["flaubert", "xlm"]:
            # flaubert and xlm models have a different architecture that does not
            # support freezing individual layers in the same way as other models
            raise ValueError(
                f"Freezing individual layers is not supported for {model.config.model_type} "
                "models. Please use `freeze_layers=False` or `freeze_layers=True`."
            )

        # get list of layers
        embeddings = model.embeddings
        encoder = getattr(model, "encoder", None) or getattr(
            model, "transformer", model
        )
        encoder_layers = (
            getattr(encoder, "layer", None)
            or getattr(encoder, "layers", None)
            or getattr(encoder, "block", None)
        )
        if encoder_layers is None and hasattr(encoder, "albert_layer_groups"):
            encoder_layers = [
                layer
                for group in encoder.albert_layer_groups
                for layer in group.albert_layers
            ]
        modules = [embeddings]
        if encoder_layers is not None and isinstance(encoder_layers, list):
            modules.extend(encoder_layers)

        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "LayerNorm" in name else False
                        )

        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get(
                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
            ),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

forward

forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/text.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        ),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
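
Example: a minimal sketch for HFTextEncoder (not taken from the library). "bert-base-uncased" is only an illustrative encoder-only checkpoint, and the Modalities import path is an assumption.

import torch
from transformers import AutoTokenizer

from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.modules.encoders.text import HFTextEncoder

encoder = HFTextEncoder("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

batch = tokenizer(["multimodal learning with mmlearn"], return_tensors="pt")
inputs = {
    Modalities.TEXT.name: batch["input_ids"],
    "attention_mask": batch["attention_mask"],
}

encoder.eval()
with torch.no_grad():
    output = encoder(inputs)
# last_hidden_state comes from outputs.hidden_states[-1] (no final layer norm applied)
print(output.last_hidden_state.shape)  # (1, seq_len, 768) for bert-base-uncased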

TimmViT

Bases: Module

Vision Transformer model from timm.

Parameters:

Name Type Description Default
model_name str

The name of the model to use.

required
modality str

The modality of the input data. This allows this model to be used with different image modalities e.g. RGB, Depth, etc.

"RGB"
projection_dim int

The dimension of the projection head.

768
pretrained bool

Whether to use the pretrained weights.

True
freeze_layers Union[int, float, list[int], bool]

Whether to freeze the layers.

False
freeze_layer_norm bool

Whether to freeze the layer norm.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_kwargs Optional[dict[str, Any]]

Additional keyword arguments for the model.

None
Source code in mmlearn/modules/encoders/vision.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name="vit_base_patch16_224",
    hydra_convert="object",
)
class TimmViT(nn.Module):
    """Vision Transformer model from timm.

    Parameters
    ----------
    model_name : str
        The name of the model to use.
    modality : str, default="RGB"
        The modality of the input data. This allows this model to be used with different
        image modalities e.g. RGB, Depth, etc.
    projection_dim : int, default=768
        The dimension of the projection head.
    pretrained : bool, default=True
        Whether to use the pretrained weights.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze the layers.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer norm.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_kwargs : Optional[dict[str, Any]], default=None
        Additional keyword arguments for the model.
    """

    def __init__(
        self,
        model_name: str,
        modality: str = "RGB",
        projection_dim: int = 768,
        pretrained: bool = True,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.modality = Modalities.get_modality(modality)
        if model_kwargs is None:
            model_kwargs = {}

        self.model: TimmVisionTransformer = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=projection_dim,
            **model_kwargs,
        )
        assert isinstance(self.model, TimmVisionTransformer), (
            f"Model {model_name} is not a Vision Transformer. "
            "Please provide a model name that corresponds to a Vision Transformer."
        )

        self._freeze_layers(freeze_layers, freeze_layer_norm)

        if peft_config is not None:
            self.model = hf_utils._wrap_peft_model(self.model, peft_config)

    def _freeze_layers(
        self, freeze_layers: Union[int, float, list[int], bool], freeze_layer_norm: bool
    ) -> None:
        """Freeze the layers of the model.

        Parameters
        ----------
        freeze_layers : Union[int, float, list[int], bool]
            Whether to freeze the layers.
        freeze_layer_norm : bool
            Whether to freeze the layer norm.
        """
        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in self.model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "norm" in name else False
                )

        modules = [self.model.patch_embed, *self.model.blocks, self.model.norm]
        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "norm" in name else False
                        )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model.
        """
        x = inputs[self.modality.name]
        last_hidden_state, hidden_states = self.model.forward_intermediates(
            x, output_fmt="NLC"
        )
        last_hidden_state = self.model.forward_head(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states
        )

    def get_intermediate_layers(
        self, inputs: dict[str, Any], n: int = 1
    ) -> list[torch.Tensor]:
        """Get the output of the intermediate layers.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the ``Modalities.RGB``
            key.
        n : int, default=1
            The number of intermediate layers to return.

        Returns
        -------
        list[torch.Tensor]
            The outputs of the last n intermediate layers.
        """
        return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore

    def get_patch_info(self) -> tuple[int, int]:
        """Get patch size and number of patches.

        Returns
        -------
        tuple[int, int]
            Patch size and number of patches.
        """
        patch_size = self.model.patch_embed.patch_size[0]
        num_patches = self.model.patch_embed.num_patches
        return patch_size, num_patches

forward

forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model.

Source code in mmlearn/modules/encoders/vision.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model.
    """
    x = inputs[self.modality.name]
    last_hidden_state, hidden_states = self.model.forward_intermediates(
        x, output_fmt="NLC"
    )
    last_hidden_state = self.model.forward_head(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state, hidden_states=hidden_states
    )

get_intermediate_layers

get_intermediate_layers(inputs, n=1)

Get the output of the intermediate layers.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required
n int

The number of intermediate layers to return.

1

Returns:

Type Description
list[Tensor]

The outputs of the last n intermediate layers.

Source code in mmlearn/modules/encoders/vision.py
def get_intermediate_layers(
    self, inputs: dict[str, Any], n: int = 1
) -> list[torch.Tensor]:
    """Get the output of the intermediate layers.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the ``Modalities.RGB``
        key.
    n : int, default=1
        The number of intermediate layers to return.

    Returns
    -------
    list[torch.Tensor]
        The outputs of the last n intermediate layers.
    """
    return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore

get_patch_info

get_patch_info()

Get patch size and number of patches.

Returns:

Type Description
tuple[int, int]

Patch size and number of patches.

Source code in mmlearn/modules/encoders/vision.py
def get_patch_info(self) -> tuple[int, int]:
    """Get patch size and number of patches.

    Returns
    -------
    tuple[int, int]
        Patch size and number of patches.
    """
    patch_size = self.model.patch_embed.patch_size[0]
    num_patches = self.model.patch_embed.num_patches
    return patch_size, num_patches
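
Example: a minimal sketch for TimmViT (not taken from the library); pretrained=False keeps the sketch offline-friendly, and the Modalities import path is an assumption.

import torch

from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.modules.encoders.vision import TimmViT

encoder = TimmViT("vit_base_patch16_224", projection_dim=512, pretrained=False)

patch_size, num_patches = encoder.get_patch_info()
print(patch_size, num_patches)  # 16, 196 for a 224x224 input with 16x16 patches

encoder.eval()
with torch.no_grad():
    output = encoder({Modalities.RGB.name: torch.randn(1, 3, 224, 224)})
# last_hidden_state is the pooled, projected output of the timm classification head
print(output.last_hidden_state.shape)  # (1, 512), i.e. (batch, projection_dim)
print(len(output.hidden_states))  # intermediate token features from forward_intermediates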

clip

Wrappers and interfaces for CLIP models.

HFCLIPTextEncoder

Bases: Module

Wrapper around the CLIPTextModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",  # required for `peft_config` to be converted to a `PeftConfig` object
)
class HFCLIPTextEncoder(nn.Module):
    """Wrapper around the ``CLIPTextModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden
            states, and the attention weights, if ``output_attentions`` is set
            to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get("attention_mask")
            or inputs.get(Modalities.TEXT.attention_mask),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden
        states, and the attention weights, if ``output_attentions`` is set
        to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get("attention_mask")
        or inputs.get(Modalities.TEXT.attention_mask),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
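
Example: a minimal sketch for HFCLIPTextEncoder (not taken from the library). Because forward reads outputs.hidden_states[-1], the sketch passes model_config_kwargs to request hidden states; the Modalities import path is an assumption.

import torch
from transformers import AutoTokenizer

from mmlearn.datasets.core import Modalities  # import path assumed
from mmlearn.modules.encoders.clip import HFCLIPTextEncoder

encoder = HFCLIPTextEncoder(
    "openai/clip-vit-base-patch16",
    # ensure the underlying model returns hidden states, which forward() indexes into
    model_config_kwargs={"output_hidden_states": True},
)
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")

batch = tokenizer(["a photo of a dog"], return_tensors="pt")

encoder.eval()
with torch.no_grad():
    # attention_mask is omitted: a single, unpadded sequence does not need one
    output = encoder({Modalities.TEXT.name: batch["input_ids"]})
print(output.last_hidden_state.shape)  # (1, seq_len, 512) for the ViT-B/16 text tower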

HFCLIPVisionEncoder

Bases: Module

Wrapper around the CLIPVisionModel from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias Optional[float]

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoder(nn.Module):
    """Wrapper around the ``CLIPVisionModel`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : Optional[float], optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModel,
            model_name_or_path=model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model.vision_model
        self.pooling_layer = pooling_layer
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.

        """
        # FIXME: handle other vision modalities
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.encoder(
            inputs_embeds=hidden_states,
            output_attentions=inputs.get(
                "output_attentions", self.model.config.output_attentions
            ),
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = encoder_outputs[0]
        if self.pooling_layer is not None:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.

    """
    # FIXME: handle other vision modalities
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.encoder(
        inputs_embeds=hidden_states,
        output_attentions=inputs.get(
            "output_attentions", self.model.config.output_attentions
        ),
        output_hidden_states=True,
        return_dict=True,
    )

    last_hidden_state = encoder_outputs[0]
    if self.pooling_layer is not None:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )

HFCLIPTextEncoderWithProjection

Bases: Module

Wrapper around the CLIPTextModelWithProjection from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the text. If False the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPTextEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPTextModelWithProjection`` from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the text. If ``False`` the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPTextModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPTextConfig,
        )

        model = _freeze_text_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The text embeddings. Will be a tuple with a single element.
        """
        input_ids = inputs[Modalities.TEXT.name]
        attention_mask: Optional[torch.Tensor] = inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        )
        position_ids = inputs.get("position_ids")

        if self.use_all_token_embeddings:
            text_outputs = self.model.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            )
            # TODO: add more options for pooling before projection
            text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
        else:
            text_embeds = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            ).text_embeds

        return (text_embeds,)
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
tuple[Tensor]

The text embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The text embeddings. Will be a tuple with a single element.
    """
    input_ids = inputs[Modalities.TEXT.name]
    attention_mask: Optional[torch.Tensor] = inputs.get(
        "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
    )
    position_ids = inputs.get("position_ids")

    if self.use_all_token_embeddings:
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        # TODO: add more options for pooling before projection
        text_embeds = self.model.text_projection(text_outputs.last_hidden_state)
    else:
        text_embeds = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        ).text_embeds

    return (text_embeds,)
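
A minimal usage sketch (not part of the generated reference), assuming HFCLIPTextEncoderWithProjection is importable from mmlearn.modules.encoders.clip and Modalities from mmlearn.datasets.core; the projected embedding size (512 below) is specific to the openai/clip-vit-base-patch16 checkpoint.

import transformers

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPTextEncoderWithProjection

tokenizer = transformers.AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")
encoder = HFCLIPTextEncoderWithProjection("openai/clip-vit-base-patch16")

tokens = tokenizer(
    ["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt"
)
# input_ids are read from the Modalities.TEXT key; the attention mask is looked
# up under "attention_mask".
(text_embeds,) = encoder(
    {
        Modalities.TEXT.name: tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
    }
)
print(text_embeds.shape)  # torch.Size([2, 512]) for this checkpoint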

HFCLIPVisionEncoderWithProjection

Bases: Module

Wrapper around the CLIPVisionModelWithProjection class from HuggingFace.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
use_all_token_embeddings bool

Whether to use all token embeddings for the image. If False the first token embedding will be used.

False
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
patch_dropout_rate float

The proportion of patch embeddings to drop out.

0.0
patch_dropout_shuffle bool

Whether to shuffle the patches while applying patch dropout.

False
patch_dropout_bias float

The bias to apply to the patch dropout mask.

None
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs dict[str, Any]

Additional keyword arguments to pass to the model configuration.

None

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/clip.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name_or_path="openai/clip-vit-base-patch16",
    hydra_convert="object",
)
class HFCLIPVisionEncoderWithProjection(nn.Module):
    """Wrapper around the ``CLIPVisionModelWithProjection`` class from HuggingFace.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the image. If ``False`` the first token
        embedding will be used.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    patch_dropout_rate : float, default=0.0
        The proportion of patch embeddings to drop out.
    patch_dropout_shuffle : bool, default=False
        Whether to shuffle the patches while applying patch dropout.
    patch_dropout_bias : float, optional, default=None
        The bias to apply to the patch dropout mask.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : dict[str, Any], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.

    """

    def __init__(
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        patch_dropout_rate: float = 0.0,
        patch_dropout_shuffle: bool = False,
        patch_dropout_bias: Optional[float] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        self.use_all_token_embeddings = use_all_token_embeddings

        model = hf_utils.load_huggingface_model(
            transformers.CLIPVisionModelWithProjection,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
            config_type=CLIPVisionConfig,
        )

        model = _freeze_vision_model(model, freeze_layers, freeze_layer_norm)
        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.patch_dropout = None
        if patch_dropout_rate > 0:
            self.patch_dropout = PatchDropout(
                keep_rate=1 - patch_dropout_rate,
                token_shuffling=patch_dropout_shuffle,
                bias=patch_dropout_bias,
            )

    def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The image tensor will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        tuple[torch.Tensor]
            The image embeddings. Will be a tuple with a single element.
        """
        pixel_values = inputs[Modalities.RGB.name]
        hidden_states = self.model.vision_model.embeddings(pixel_values)
        if self.patch_dropout is not None:
            hidden_states = self.patch_dropout(hidden_states)
        hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

        encoder_outputs = self.model.vision_model.encoder(
            inputs_embeds=hidden_states, return_dict=True
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        if self.use_all_token_embeddings:
            pooled_output = last_hidden_state
        else:
            pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.model.vision_model.post_layernorm(pooled_output)

        return (self.model.visual_projection(pooled_output),)
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image tensor will be expected under the Modalities.RGB key.

required

Returns:

Type Description
tuple[Tensor]

The image embeddings. Will be a tuple with a single element.

Source code in mmlearn/modules/encoders/clip.py
def forward(self, inputs: dict[str, Any]) -> tuple[torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The image tensor will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    tuple[torch.Tensor]
        The image embeddings. Will be a tuple with a single element.
    """
    pixel_values = inputs[Modalities.RGB.name]
    hidden_states = self.model.vision_model.embeddings(pixel_values)
    if self.patch_dropout is not None:
        hidden_states = self.patch_dropout(hidden_states)
    hidden_states = self.model.vision_model.pre_layrnorm(hidden_states)

    encoder_outputs = self.model.vision_model.encoder(
        inputs_embeds=hidden_states, return_dict=True
    )

    last_hidden_state = encoder_outputs.last_hidden_state
    if self.use_all_token_embeddings:
        pooled_output = last_hidden_state
    else:
        pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.model.vision_model.post_layernorm(pooled_output)

    return (self.model.visual_projection(pooled_output),)
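
A minimal usage sketch (not part of the generated reference), assuming the same import paths as above; random pixel values stand in for images preprocessed with the matching CLIP image processor.

import torch

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.clip import HFCLIPVisionEncoderWithProjection

encoder = HFCLIPVisionEncoderWithProjection("openai/clip-vit-base-patch16")

pixel_values = torch.randn(2, 3, 224, 224)  # stand-in for preprocessed images
# The image tensor is read from the Modalities.RGB key; with
# use_all_token_embeddings=False the CLS token is pooled and projected.
(image_embeds,) = encoder({Modalities.RGB.name: pixel_values})
print(image_embeds.shape)  # torch.Size([2, 512]) for this checkpoint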

text

Huggingface text encoder model.

HFTextEncoder

Bases: Module

Wrapper around huggingface models in the AutoModelForTextEncoding class.

Parameters:

Name Type Description Default
model_name_or_path str

The huggingface model name or a local path from which to load the model.

required
pretrained bool

Whether to load the pretrained weights or not.

True
pooling_layer Optional[Module]

Pooling layer to apply to the last hidden state of the model.

None
freeze_layers Union[int, float, list[int], bool]

Whether to freeze layers of the model and which layers to freeze. If True, all model layers are frozen. If it is an integer, the first N layers of the model are frozen. If it is a float, the first N percent of the layers are frozen. If it is a list of integers, the layers at the indices in the list are frozen.

False
freeze_layer_norm bool

Whether to freeze the layer normalization layers of the model.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration.

None

Raises:

Type Description
ValueError

If the model is a decoder model or if freezing individual layers is not supported for the model type.

Warns:

Type Description
UserWarning

If both peft_config and freeze_layers are set. The peft_config will override the freeze_layers setting.

Source code in mmlearn/modules/encoders/text.py
@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
class HFTextEncoder(nn.Module):
    """Wrapper around huggingface models in the ``AutoModelForTextEncoding`` class.

    Parameters
    ----------
    model_name_or_path : str
        The huggingface model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : Optional[torch.nn.Module], optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze layers of the model and which layers to freeze. If ``True``,
        all model layers are frozen. If it is an integer, the first ``N`` layers of
        the model are frozen. If it is a float, the first ``N`` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Raises
    ------
    ValueError
        If the model is a decoder model or if freezing individual layers is not
        supported for the model type.

    Warns
    -----
    UserWarning
        If both ``peft_config`` and ``freeze_layers`` are set. The ``peft_config``
        will override the ``freeze_layers`` setting.


    """

    def __init__(  # noqa: PLR0912
        self,
        model_name_or_path: str,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[dict[str, Any]] = None,
    ):
        super().__init__()
        if model_config_kwargs is None:
            model_config_kwargs = {}
        model_config_kwargs["output_hidden_states"] = True
        model_config_kwargs["add_pooling_layer"] = False
        model = hf_utils.load_huggingface_model(
            AutoModelForTextEncoding,
            model_name_or_path,
            load_pretrained_weights=pretrained,
            model_config_kwargs=model_config_kwargs,
        )
        if hasattr(model.config, "is_decoder") and model.config.is_decoder:
            raise ValueError("Model is a decoder. Only encoder models are supported.")

        if not pretrained and freeze_layers:
            rank_zero_warn(
                "Freezing layers when loading a model with random weights may lead to "
                "unexpected behavior. Consider setting `freeze_layers=False` if "
                "`pretrained=False`.",
            )

        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "LayerNorm" in name else False
                )

        if isinstance(
            freeze_layers, (float, int, list)
        ) and model.config.model_type in ["flaubert", "xlm"]:
            # flaubert and xlm models have a different architecture that does not
            # support freezing individual layers in the same way as other models
            raise ValueError(
                f"Freezing individual layers is not supported for {model.config.model_type} "
                "models. Please use `freeze_layers=False` or `freeze_layers=True`."
            )

        # get list of layers
        embeddings = model.embeddings
        encoder = getattr(model, "encoder", None) or getattr(
            model, "transformer", model
        )
        encoder_layers = (
            getattr(encoder, "layer", None)
            or getattr(encoder, "layers", None)
            or getattr(encoder, "block", None)
        )
        if encoder_layers is None and hasattr(encoder, "albert_layer_groups"):
            encoder_layers = [
                layer
                for group in encoder.albert_layer_groups
                for layer in group.albert_layers
            ]
        modules = [embeddings]
        if encoder_layers is not None and isinstance(encoder_layers, list):
            modules.extend(encoder_layers)

        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "LayerNorm" in name else False
                        )

        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``input_ids`` will be expected under the
            ``Modalities.TEXT`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if ``output_attentions`` is set to ``True``.
        """
        outputs = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get(
                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
            ),
            position_ids=inputs.get("position_ids"),
            output_attentions=inputs.get("output_attentions"),
            return_dict=True,
        )
        last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
        if self.pooling_layer:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The input_ids will be expected under the Modalities.TEXT key.

required

Returns:

Type Description
BaseModelOutput

The output of the model, including the last hidden state, all hidden states, and the attention weights, if output_attentions is set to True.

Source code in mmlearn/modules/encoders/text.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``input_ids`` will be expected under the
        ``Modalities.TEXT`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model, including the last hidden state, all hidden states,
        and the attention weights, if ``output_attentions`` is set to ``True``.
    """
    outputs = self.model(
        input_ids=inputs[Modalities.TEXT.name],
        attention_mask=inputs.get(
            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
        ),
        position_ids=inputs.get("position_ids"),
        output_attentions=inputs.get("output_attentions"),
        return_dict=True,
    )
    last_hidden_state = outputs.hidden_states[-1]  # NOTE: no layer norm applied
    if self.pooling_layer:
        last_hidden_state = self.pooling_layer(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
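
A minimal usage sketch (not part of the generated reference), assuming Modalities is exported from mmlearn.datasets.core and that bert-base-uncased is an acceptable encoder-only checkpoint for AutoModelForTextEncoding; any other supported encoder-only model name should work the same way.

import transformers

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.text import HFTextEncoder

encoder = HFTextEncoder("bert-base-uncased")
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")

tokens = tokenizer(["multimodal representation learning"], return_tensors="pt")
outputs = encoder(
    {
        Modalities.TEXT.name: tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
    }
)
# BaseModelOutput: last hidden state (no final layer norm applied) plus all
# intermediate hidden states.
print(outputs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)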

vision

Vision encoder implementations.

TimmViT

Bases: Module

Vision Transformer model from timm.

Parameters:

Name Type Description Default
model_name str

The name of the model to use.

required
modality str

The modality of the input data. This allows this model to be used with different image modalities e.g. RGB, Depth, etc.

"RGB"
projection_dim int

The dimension of the projection head.

768
pretrained bool

Whether to use the pretrained weights.

True
freeze_layers Union[int, float, list[int], bool]

Whether to freeze the layers.

False
freeze_layer_norm bool

Whether to freeze the layer norm.

True
peft_config Optional[PeftConfig]

The configuration from the peft library (https://huggingface.co/docs/peft/index) to use to wrap the model for parameter-efficient finetuning.

None
model_kwargs Optional[dict[str, Any]]

Additional keyword arguments for the model.

None
Source code in mmlearn/modules/encoders/vision.py
@store(
    group="modules/encoders",
    provider="mmlearn",
    model_name="vit_base_patch16_224",
    hydra_convert="object",
)
class TimmViT(nn.Module):
    """Vision Transformer model from timm.

    Parameters
    ----------
    model_name : str
        The name of the model to use.
    modality : str, default="RGB"
        The modality of the input data. This allows this model to be used with different
        image modalities e.g. RGB, Depth, etc.
    projection_dim : int, default=768
        The dimension of the projection head.
    pretrained : bool, default=True
        Whether to use the pretrained weights.
    freeze_layers : Union[int, float, list[int], bool], default=False
        Whether to freeze the layers.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer norm.
    peft_config : Optional[PeftConfig], optional, default=None
        The configuration from the `peft <https://huggingface.co/docs/peft/index>`_
        library to use to wrap the model for parameter-efficient finetuning.
    model_kwargs : Optional[dict[str, Any]], default=None
        Additional keyword arguments for the model.
    """

    def __init__(
        self,
        model_name: str,
        modality: str = "RGB",
        projection_dim: int = 768,
        pretrained: bool = True,
        freeze_layers: Union[int, float, list[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.modality = Modalities.get_modality(modality)
        if model_kwargs is None:
            model_kwargs = {}

        self.model: TimmVisionTransformer = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=projection_dim,
            **model_kwargs,
        )
        assert isinstance(self.model, TimmVisionTransformer), (
            f"Model {model_name} is not a Vision Transformer. "
            "Please provide a model name that corresponds to a Vision Transformer."
        )

        self._freeze_layers(freeze_layers, freeze_layer_norm)

        if peft_config is not None:
            self.model = hf_utils._wrap_peft_model(self.model, peft_config)

    def _freeze_layers(
        self, freeze_layers: Union[int, float, list[int], bool], freeze_layer_norm: bool
    ) -> None:
        """Freeze the layers of the model.

        Parameters
        ----------
        freeze_layers : Union[int, float, list[int], bool]
            Whether to freeze the layers.
        freeze_layer_norm : bool
            Whether to freeze the layer norm.
        """
        if isinstance(freeze_layers, bool) and freeze_layers:
            for name, param in self.model.named_parameters():
                param.requires_grad = (
                    (not freeze_layer_norm) if "norm" in name else False
                )

        modules = [self.model.patch_embed, *self.model.blocks, self.model.norm]
        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(modules))
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, module in enumerate(modules):
                if idx in freeze_layers:
                    for name, param in module.named_parameters():
                        param.requires_grad = (
                            (not freeze_layer_norm) if "norm" in name else False
                        )

    def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the
            ``Modalities.RGB`` key.

        Returns
        -------
        BaseModelOutput
            The output of the model.
        """
        x = inputs[self.modality.name]
        last_hidden_state, hidden_states = self.model.forward_intermediates(
            x, output_fmt="NLC"
        )
        last_hidden_state = self.model.forward_head(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states
        )

    def get_intermediate_layers(
        self, inputs: dict[str, Any], n: int = 1
    ) -> list[torch.Tensor]:
        """Get the output of the intermediate layers.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input data. The ``image`` will be expected under the ``Modalities.RGB``
            key.
        n : int, default=1
            The number of intermediate layers to return.

        Returns
        -------
        list[torch.Tensor]
            The outputs of the last n intermediate layers.
        """
        return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore

    def get_patch_info(self) -> tuple[int, int]:
        """Get patch size and number of patches.

        Returns
        -------
        tuple[int, int]
            Patch size and number of patches.
        """
        patch_size = self.model.patch_embed.patch_size[0]
        num_patches = self.model.patch_embed.num_patches
        return patch_size, num_patches
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required

Returns:

Type Description
BaseModelOutput

The output of the model.

Source code in mmlearn/modules/encoders/vision.py
def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the
        ``Modalities.RGB`` key.

    Returns
    -------
    BaseModelOutput
        The output of the model.
    """
    x = inputs[self.modality.name]
    last_hidden_state, hidden_states = self.model.forward_intermediates(
        x, output_fmt="NLC"
    )
    last_hidden_state = self.model.forward_head(last_hidden_state)

    return BaseModelOutput(
        last_hidden_state=last_hidden_state, hidden_states=hidden_states
    )
get_intermediate_layers
get_intermediate_layers(inputs, n=1)

Get the output of the intermediate layers.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input data. The image will be expected under the Modalities.RGB key.

required
n int

The number of intermediate layers to return.

1

Returns:

Type Description
list[Tensor]

The outputs of the last n intermediate layers.

Source code in mmlearn/modules/encoders/vision.py
def get_intermediate_layers(
    self, inputs: dict[str, Any], n: int = 1
) -> list[torch.Tensor]:
    """Get the output of the intermediate layers.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input data. The ``image`` will be expected under the ``Modalities.RGB``
        key.
    n : int, default=1
        The number of intermediate layers to return.

    Returns
    -------
    list[torch.Tensor]
        The outputs of the last n intermediate layers.
    """
    return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore
get_patch_info
get_patch_info()

Get patch size and number of patches.

Returns:

Type Description
tuple[int, int]

Patch size and number of patches.

Source code in mmlearn/modules/encoders/vision.py
def get_patch_info(self) -> tuple[int, int]:
    """Get patch size and number of patches.

    Returns
    -------
    tuple[int, int]
        Patch size and number of patches.
    """
    patch_size = self.model.patch_embed.patch_size[0]
    num_patches = self.model.patch_embed.num_patches
    return patch_size, num_patches
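
A minimal usage sketch (not part of the generated reference), assuming Modalities is exported from mmlearn.datasets.core; pretrained=False avoids downloading weights, and the size of the pooled output follows projection_dim.

import torch

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.vision import TimmViT

encoder = TimmViT("vit_base_patch16_224", pretrained=False, projection_dim=512)

images = torch.randn(2, 3, 224, 224)
outputs = encoder({Modalities.RGB.name: images})
print(outputs.last_hidden_state.shape)  # (2, 512): pooled and projected features
print(len(outputs.hidden_states))       # one entry per intermediate block

patch_size, num_patches = encoder.get_patch_info()
print(patch_size, num_patches)          # 16, 196 for a 224x224 ViT-B/16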

VisionTransformer

Bases: Module

Vision Transformer.

This module implements a Vision Transformer that processes images using a series of transformer blocks and patch embeddings.

Parameters:

Name Type Description Default
modality str

The modality of the input data. This allows this model to be used with different image modalities e.g. RGB, Depth, etc.

"RGB"
img_size List[int]

List of input image sizes.

None
patch_size int

Size of each patch.

16
in_chans int

Number of input channels.

3
embed_dim int

Embedding dimension.

768
depth int

Number of transformer blocks.

12
num_heads int

Number of attention heads.

12
mlp_ratio float

Ratio of hidden dimension in the MLP.

4.0
qkv_bias bool

If True, add a learnable bias to the query, key, and value projections.

True
qk_scale Optional[float]

Override the default qk scale factor.

None
drop_rate float

Dropout rate for the transformer blocks.

0.0
attn_drop_rate float

Dropout rate for the attention mechanism.

0.0
drop_path_rate float

Dropout rate for stochastic depth.

0.0
norm_layer Callable[..., Module]

Normalization layer to use.

torch.nn.LayerNorm
init_std float

Standard deviation for weight initialization.

0.02
Source code in mmlearn/modules/encoders/vision.py
class VisionTransformer(nn.Module):
    """Vision Transformer.

    This module implements a Vision Transformer that processes images using a
    series of transformer blocks and patch embeddings.

    Parameters
    ----------
    modality : str, optional, default="RGB"
        The modality of the input data. This allows this model to be used with different
        image modalities e.g. RGB, Depth, etc.
    img_size : List[int], optional, default=None
        List of input image sizes.
    patch_size : int, optional, default=16
        Size of each patch.
    in_chans : int, optional, default=3
        Number of input channels.
    embed_dim : int, optional, default=768
        Embedding dimension.
    depth : int, optional, default=12
        Number of transformer blocks.
    num_heads : int, optional, default=12
        Number of attention heads.
    mlp_ratio : float, optional, default=4.0
        Ratio of hidden dimension in the MLP.
    qkv_bias : bool, optional, default=True
        If True, add a learnable bias to the query, key, and value projections.
    qk_scale : Optional[float], optional
        Override the default qk scale factor.
    drop_rate : float, optional, default=0.0
        Dropout rate for the transformer blocks.
    attn_drop_rate : float, optional, default=0.0
        Dropout rate for the attention mechanism.
    drop_path_rate : float, optional, default=0.0
        Dropout rate for stochastic depth.
    norm_layer : Callable[..., torch.nn.Module], optional, default=torch.nn.LayerNorm
        Normalization layer to use.
    init_std : float, optional, default=0.02
        Standard deviation for weight initialization.
    """

    def __init__(
        self,
        modality: str = "RGB",
        img_size: Optional[list[int]] = None,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        global_pool: Literal["", "avg", "avgmax", "max", "token"] = "",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        init_std: float = 0.02,
    ) -> None:
        super().__init__()
        assert global_pool in ("", "avg", "avgmax", "max", "token")

        self.modality = Modalities.get_modality(modality)
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads
        img_size = [224, 224] if img_size is None else img_size

        # Patch Embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size[0],
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        # Positional Embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim), requires_grad=False
        )
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.patch_embed.num_patches**0.5),
            cls_token=False,
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Transformer Blocks
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        self.global_pool = global_pool

        # Weight Initialization
        self.init_std = init_std
        self.apply(self._init_weights)

    def fix_init_weight(self) -> None:
        """Fix initialization of weights by rescaling them according to layer depth."""

        def rescale(param: torch.Tensor, layer_id: int) -> None:
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp[-1].weight.data, layer_id + 1)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize weights for the layers."""
        if isinstance(m, nn.Linear):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self, inputs: dict[str, Any], return_hidden_states: bool = False
    ) -> tuple[torch.Tensor, Optional[list[torch.Tensor]]]:
        """Forward pass through the Vision Transformer."""
        masks = inputs.get(self.modality.mask)
        if masks is not None and not isinstance(masks, list):
            masks = [masks]

        x = inputs[self.modality.name]
        # -- Patchify x
        x = self.patch_embed(x)

        # -- Add positional embedding to x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed

        # -- Mask x
        if masks is not None:
            x = apply_masks(x, masks)

        # -- Initialize a list to store hidden states
        hidden_states: Optional[list[torch.Tensor]] = (
            [] if return_hidden_states else None
        )

        # -- Forward propagation through blocks
        for _i, blk in enumerate(self.blocks):
            x = blk(x)
            if return_hidden_states and hidden_states is not None:
                hidden_states.append(x)

        # -- Apply normalization if present
        if self.norm is not None:
            x = self.norm(x)

        # -- Apply global pooling
        x = global_pool_nlc(x, pool_type=self.global_pool)

        # -- Return both final output and hidden states if requested
        if return_hidden_states:
            return x, hidden_states
        return (x, None)

    def interpolate_pos_encoding(
        self, x: torch.Tensor, pos_embed: torch.Tensor
    ) -> torch.Tensor:
        """Interpolate positional encoding to match the size of the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.
        pos_embed : torch.Tensor
            Positional embedding tensor.

        Returns
        -------
        torch.Tensor
            Interpolated positional encoding.
        """
        npatch = x.shape[1] - 1
        n = pos_embed.shape[1] - 1
        if npatch == n:
            return pos_embed
        class_emb = pos_embed[:, 0]
        pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(n)), int(math.sqrt(n)), dim).permute(
                0, 3, 1, 2
            ),
            scale_factor=math.sqrt(npatch / n),
            mode="bicubic",
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
fix_init_weight
fix_init_weight()

Fix initialization of weights by rescaling them according to layer depth.

Source code in mmlearn/modules/encoders/vision.py
def fix_init_weight(self) -> None:
    """Fix initialization of weights by rescaling them according to layer depth."""

    def rescale(param: torch.Tensor, layer_id: int) -> None:
        param.div_(math.sqrt(2.0 * layer_id))

    for layer_id, layer in enumerate(self.blocks):
        rescale(layer.attn.proj.weight.data, layer_id + 1)
        rescale(layer.mlp[-1].weight.data, layer_id + 1)
forward
forward(inputs, return_hidden_states=False)

Forward pass through the Vision Transformer.

Source code in mmlearn/modules/encoders/vision.py
def forward(
    self, inputs: dict[str, Any], return_hidden_states: bool = False
) -> tuple[torch.Tensor, Optional[list[torch.Tensor]]]:
    """Forward pass through the Vision Transformer."""
    masks = inputs.get(self.modality.mask)
    if masks is not None and not isinstance(masks, list):
        masks = [masks]

    x = inputs[self.modality.name]
    # -- Patchify x
    x = self.patch_embed(x)

    # -- Add positional embedding to x
    pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
    x = x + pos_embed

    # -- Mask x
    if masks is not None:
        x = apply_masks(x, masks)

    # -- Initialize a list to store hidden states
    hidden_states: Optional[list[torch.Tensor]] = (
        [] if return_hidden_states else None
    )

    # -- Forward propagation through blocks
    for _i, blk in enumerate(self.blocks):
        x = blk(x)
        if return_hidden_states and hidden_states is not None:
            hidden_states.append(x)

    # -- Apply normalization if present
    if self.norm is not None:
        x = self.norm(x)

    # -- Apply global pooling
    x = global_pool_nlc(x, pool_type=self.global_pool)

    # -- Return both final output and hidden states if requested
    if return_hidden_states:
        return x, hidden_states
    return (x, None)
interpolate_pos_encoding
interpolate_pos_encoding(x, pos_embed)

Interpolate positional encoding to match the size of the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
pos_embed Tensor

Positional embedding tensor.

required

Returns:

Type Description
Tensor

Interpolated positional encoding.

Source code in mmlearn/modules/encoders/vision.py
def interpolate_pos_encoding(
    self, x: torch.Tensor, pos_embed: torch.Tensor
) -> torch.Tensor:
    """Interpolate positional encoding to match the size of the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    pos_embed : torch.Tensor
        Positional embedding tensor.

    Returns
    -------
    torch.Tensor
        Interpolated positional encoding.
    """
    npatch = x.shape[1] - 1
    n = pos_embed.shape[1] - 1
    if npatch == n:
        return pos_embed
    class_emb = pos_embed[:, 0]
    pos_embed = pos_embed[:, 1:]
    dim = x.shape[-1]
    pos_embed = nn.functional.interpolate(
        pos_embed.reshape(1, int(math.sqrt(n)), int(math.sqrt(n)), dim).permute(
            0, 3, 1, 2
        ),
        scale_factor=math.sqrt(npatch / n),
        mode="bicubic",
    )
    pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
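
A minimal usage sketch (not part of the generated reference), assuming Modalities is exported from mmlearn.datasets.core; a deliberately small configuration keeps it cheap, and "avg" pooling averages the patch tokens into a single vector per image.

import torch

from mmlearn.datasets.core import Modalities  # assumed import path
from mmlearn.modules.encoders.vision import VisionTransformer

model = VisionTransformer(embed_dim=192, depth=4, num_heads=3, global_pool="avg")

images = torch.randn(2, 3, 224, 224)
features, hidden_states = model(
    {Modalities.RGB.name: images}, return_hidden_states=True
)
print(features.shape)      # (2, 192): averaged patch tokens
print(len(hidden_states))  # one entry per transformer block (4 here)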

VisionTransformerPredictor

Bases: Module

Vision Transformer Predictor.

This module implements a Vision Transformer that predicts masked tokens using a series of transformer blocks.

Parameters:

Name Type Description Default
num_patches int

The number of patches in the input image.

196
embed_dim int

The embedding dimension.

768
predictor_embed_dim int

The embedding dimension for the predictor.

384
depth int

The number of transformer blocks.

6
num_heads int

The number of attention heads.

12
mlp_ratio float

Ratio of the hidden dimension in the MLP.

4.0
qkv_bias bool

If True, add a learnable bias to the query, key, and value projections.

True
qk_scale Optional[float]

Override the default qk scale factor.

None
drop_rate float

Dropout rate for the transformer blocks.

0.0
attn_drop_rate float

Dropout rate for the attention mechanism.

0.0
drop_path_rate float

Dropout rate for stochastic depth.

0.0
norm_layer Callable[..., Module]

Normalization layer to use.

torch.nn.LayerNorm
init_std float

Standard deviation for weight initialization.

0.02
Source code in mmlearn/modules/encoders/vision.py
class VisionTransformerPredictor(nn.Module):
    """Vision Transformer Predictor.

    This module implements a Vision Transformer that predicts masked tokens
    using a series of transformer blocks.

    Parameters
    ----------
    num_patches : int
        The number of patches in the input image.
    embed_dim : int, optional, default=768
        The embedding dimension.
    predictor_embed_dim : int, optional, default=384
        The embedding dimension for the predictor.
    depth : int, optional, default=6
        The number of transformer blocks.
    num_heads : int, optional, default=12
        The number of attention heads.
    mlp_ratio : float, optional, default=4.0
        Ratio of the hidden dimension in the MLP.
    qkv_bias : bool, optional, default=True
        If True, add a learnable bias to the query, key, and value projections.
    qk_scale : Optional[float], optional, default=None
        Override the default qk scale factor.
    drop_rate : float, optional, default=0.0
        Dropout rate for the transformer blocks.
    attn_drop_rate : float, optional, default=0.0
        Dropout rate for the attention mechanism.
    drop_path_rate : float, optional, default=0.0
        Dropout rate for stochastic depth.
    norm_layer : Callable[..., torch.nn.Module], optional, default=torch.nn.LayerNorm
        Normalization layer to use.
    init_std : float, optional, default=0.02
        Standard deviation for weight initialization.
    """

    def __init__(
        self,
        num_patches: int = 196,
        embed_dim: int = 768,
        predictor_embed_dim: int = 384,
        depth: int = 6,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        init_std: float = 0.02,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.predictor_embed = nn.Linear(self.embed_dim, predictor_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        # Positional Embedding
        self.predictor_pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, predictor_embed_dim), requires_grad=False
        )
        predictor_pos_embed = get_2d_sincos_pos_embed(
            self.predictor_pos_embed.shape[-1],
            int(self.num_patches**0.5),
            cls_token=False,
        )
        self.predictor_pos_embed.data.copy_(
            torch.from_numpy(predictor_pos_embed).float().unsqueeze(0)
        )

        # Transformer Blocks
        self.predictor_blocks = nn.ModuleList(
            [
                Block(
                    dim=predictor_embed_dim,
                    num_heads=self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        self.predictor_norm = norm_layer(predictor_embed_dim)
        self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)

        # Weight Initialization
        self.init_std = init_std
        _trunc_normal(self.mask_token, std=self.init_std)
        self.apply(self._init_weights)

    def fix_init_weight(self) -> None:
        """Fix initialization of weights by rescaling them according to layer depth."""

        def rescale(param: torch.Tensor, layer_id: int) -> None:
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.predictor_blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize weights for the layers."""
        if isinstance(m, nn.Linear):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            _trunc_normal(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
        masks_x: Union[torch.Tensor, list[torch.Tensor]],
        masks: Union[torch.Tensor, list[torch.Tensor]],
    ) -> torch.Tensor:
        """Forward pass through the Vision Transformer Predictor."""
        assert (masks is not None) and (masks_x is not None), (
            "Cannot run predictor without mask indices"
        )

        if not isinstance(masks_x, list):
            masks_x = [masks_x]

        if not isinstance(masks, list):
            masks = [masks]

        # -- Batch Size
        b = len(x) // len(masks_x)

        # -- Map from encoder-dim to predictor-dim
        x = self.predictor_embed(x)

        # -- Add positional embedding to x tokens
        x_pos_embed = self.predictor_pos_embed.repeat(b, 1, 1)
        x += apply_masks(x_pos_embed, masks_x)

        _, n_ctxt, d = x.shape

        # -- Concatenate mask tokens to x
        pos_embs = self.predictor_pos_embed.repeat(b, 1, 1)
        pos_embs = apply_masks(pos_embs, masks)
        pos_embs = repeat_interleave_batch(pos_embs, b, repeat=len(masks_x))
        pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
        pred_tokens += pos_embs
        x = x.repeat(len(masks), 1, 1)
        x = torch.cat([x, pred_tokens], dim=1)

        # -- Forward propagation
        for blk in self.predictor_blocks:
            x = blk(x)
        x = self.predictor_norm(x)

        # -- Return predictions for mask tokens
        x = x[:, n_ctxt:]
        return self.predictor_proj(x)
fix_init_weight
fix_init_weight()

Fix initialization of weights by rescaling them according to layer depth.

Source code in mmlearn/modules/encoders/vision.py
def fix_init_weight(self) -> None:
    """Fix initialization of weights by rescaling them according to layer depth."""

    def rescale(param: torch.Tensor, layer_id: int) -> None:
        param.div_(math.sqrt(2.0 * layer_id))

    for layer_id, layer in enumerate(self.predictor_blocks):
        rescale(layer.attn.proj.weight.data, layer_id + 1)
        rescale(layer.mlp.fc2.weight.data, layer_id + 1)
forward
forward(x, masks_x, masks)

Forward pass through the Vision Transformer Predictor.

Source code in mmlearn/modules/encoders/vision.py
def forward(
    self,
    x: torch.Tensor,
    masks_x: Union[torch.Tensor, list[torch.Tensor]],
    masks: Union[torch.Tensor, list[torch.Tensor]],
) -> torch.Tensor:
    """Forward pass through the Vision Transformer Predictor."""
    assert (masks is not None) and (masks_x is not None), (
        "Cannot run predictor without mask indices"
    )

    if not isinstance(masks_x, list):
        masks_x = [masks_x]

    if not isinstance(masks, list):
        masks = [masks]

    # -- Batch Size
    b = len(x) // len(masks_x)

    # -- Map from encoder-dim to predictor-dim
    x = self.predictor_embed(x)

    # -- Add positional embedding to x tokens
    x_pos_embed = self.predictor_pos_embed.repeat(b, 1, 1)
    x += apply_masks(x_pos_embed, masks_x)

    _, n_ctxt, d = x.shape

    # -- Concatenate mask tokens to x
    pos_embs = self.predictor_pos_embed.repeat(b, 1, 1)
    pos_embs = apply_masks(pos_embs, masks)
    pos_embs = repeat_interleave_batch(pos_embs, b, repeat=len(masks_x))
    pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
    pred_tokens += pos_embs
    x = x.repeat(len(masks), 1, 1)
    x = torch.cat([x, pred_tokens], dim=1)

    # -- Forward propagation
    for blk in self.predictor_blocks:
        x = blk(x)
    x = self.predictor_norm(x)

    # -- Return predictions for mask tokens
    x = x[:, n_ctxt:]
    return self.predictor_proj(x)
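
A shape-level sketch of how the predictor is typically driven in an I-JEPA-style setup (not part of the generated reference). It assumes the masks are integer index tensors of shape (batch, num_patches_selected) that apply_masks uses to gather patch positions; the random indices below only illustrate tensor shapes, not a real masking strategy.

import torch

from mmlearn.modules.encoders.vision import VisionTransformerPredictor

predictor = VisionTransformerPredictor(
    num_patches=196, embed_dim=768, predictor_embed_dim=384, depth=6
)

batch_size, n_ctx, n_tgt = 2, 98, 49
context_feats = torch.randn(batch_size, n_ctx, 768)       # encoder output for context patches
context_idx = torch.randint(0, 196, (batch_size, n_ctx))  # indices of visible (context) patches
target_idx = torch.randint(0, 196, (batch_size, n_tgt))   # indices of patches to predict

preds = predictor(context_feats, context_idx, target_idx)
print(preds.shape)  # (2, 49, 768): one prediction per target patch, back in encoder dimension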

vit_predictor

vit_predictor(kwargs=None)

Create a VisionTransformerPredictor model.

Parameters:

Name Type Description Default
kwargs dict[str, Any]

Keyword arguments for the predictor.

None

Returns:

Type Description
VisionTransformerPredictor

An instance of VisionTransformerPredictor.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformerPredictor,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_predictor(
    kwargs: Optional[dict[str, Any]] = None,
) -> VisionTransformerPredictor:
    """Create a VisionTransformerPredictor model.

    Parameters
    ----------
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the predictor.

    Returns
    -------
    VisionTransformerPredictor
        An instance of VisionTransformerPredictor.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformerPredictor(
        mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
    )

vit_tiny

vit_tiny(patch_size=16, kwargs=None)

Create a VisionTransformer model with tiny configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_tiny(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with tiny configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

vit_small

vit_small(patch_size=16, kwargs=None)

Create a VisionTransformer model with small configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_small(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with small configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

vit_base

vit_base(patch_size=16, kwargs=None)

Create a VisionTransformer model with base configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_base(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with base configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

vit_large

vit_large(patch_size=16, kwargs=None)

Create a VisionTransformer model with large configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_large(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with large configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

vit_huge

vit_huge(patch_size=16, kwargs=None)

Create a VisionTransformer model with huge configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_huge(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with huge configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

vit_giant

vit_giant(patch_size=16, kwargs=None)

Create a VisionTransformer model with giant configuration.

Parameters:

Name Type Description Default
patch_size int

Size of each patch.

16
kwargs dict[str, Any]

Keyword arguments for the model variant.

None

Returns:

Type Description
VisionTransformer

An instance of VisionTransformer.

Source code in mmlearn/modules/encoders/vision.py
@cast(
    VisionTransformer,
    store(
        group="modules/encoders",
        provider="mmlearn",
    ),
)
def vit_giant(
    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
) -> VisionTransformer:
    """Create a VisionTransformer model with giant configuration.

    Parameters
    ----------
    patch_size : int, default=16
        Size of each patch.
    kwargs : dict[str, Any], optional, default=None
        Keyword arguments for the model variant.

    Returns
    -------
    VisionTransformer
        An instance of VisionTransformer.
    """
    if kwargs is None:
        kwargs = {}
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48 / 11,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
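
These factory functions differ only in the hard-coded width, depth, and head count passed to VisionTransformer. The following is a minimal usage sketch, not taken from the library docs; the import path is inferred from the source file location (mmlearn/modules/encoders/vision.py) shown above, and the kwargs argument in the comment is hypothetical:

from mmlearn.modules.encoders.vision import vit_base, vit_small

small = vit_small()             # embed_dim=384, depth=12, num_heads=6
base = vit_base(patch_size=14)  # base configuration with smaller patches
# Additional VisionTransformer arguments can be forwarded through `kwargs`, e.g.
# vit_base(kwargs={"img_size": 196})  # hypothetical argument, for illustration only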

Layers

mmlearn.modules.layers

Custom, reusable layers for models and tasks.

LearnableLogitScaling

Bases: Module

Logit scaling layer.

Parameters:

Name Type Description Default
init_logit_scale float

Initial value of the logit scale.

1/0.07
learnable bool

If True, the logit scale is learnable. Otherwise, it is fixed.

True
max_logit_scale float

Maximum value of the logit scale.

100
Source code in mmlearn/modules/layers/logit_scaling.py
@store(group="modules/layers", provider="mmlearn")
class LearnableLogitScaling(torch.nn.Module):
    """Logit scaling layer.

    Parameters
    ----------
    init_logit_scale : float, optional, default=1/0.07
        Initial value of the logit scale.
    learnable : bool, optional, default=True
        If True, the logit scale is learnable. Otherwise, it is fixed.
    max_logit_scale : float, optional, default=100
        Maximum value of the logit scale.
    """

    def __init__(
        self,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.init_logit_scale = init_logit_scale
        self.learnable = learnable
        log_logit_scale = torch.ones([]) * np.log(self.init_logit_scale)
        if learnable:
            self.log_logit_scale = torch.nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the logit scaling to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self) -> str:
        """Return the string representation of the layer."""
        return (
            f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
            f" max_logit_scale={self.max_logit_scale}"
        )

forward

forward(x)

Apply the logit scaling to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
Source code in mmlearn/modules/layers/logit_scaling.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the logit scaling to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

extra_repr

extra_repr()

Return the string representation of the layer.

Source code in mmlearn/modules/layers/logit_scaling.py
def extra_repr(self) -> str:
    """Return the string representation of the layer."""
    return (
        f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
        f" max_logit_scale={self.max_logit_scale}"
    )
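
A minimal usage sketch (import path inferred from the source location mmlearn/modules/layers/logit_scaling.py): the layer multiplies its input by the exponentiated log-scale, clipped at max_logit_scale.

import torch
from mmlearn.modules.layers.logit_scaling import LearnableLogitScaling

scaling = LearnableLogitScaling(init_logit_scale=1 / 0.07, learnable=True)
logits = torch.randn(2, 8, 512)
scaled = scaling(logits)  # logits * clip(exp(log_logit_scale), max=100)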

MLP

Bases: Sequential

Multi-layer perceptron (MLP).

This module will create a block of Linear -> Normalization -> Activation -> Dropout layers.

Parameters:

Name Type Description Default
in_dim int

The input dimension.

required
out_dim Optional[int]

The output dimension. If not specified, it is set to in_dim.

None
hidden_dims Optional[list]

The dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with hidden_dims_multiplier.

None
hidden_dims_multiplier Optional[list]

The multipliers to apply to the input dimension to get the dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with hidden_dims.

None
apply_multiplier_to_in_dim bool

Whether to apply the hidden_dims_multiplier to in_dim to get the dimensions of the hidden layers. If False, the multipliers will be applied to the dimensions of the previous hidden layer, starting from in_dim. This parameter is only relevant when hidden_dims_multiplier is specified.

False
norm_layer Optional[Callable[..., Module]]

The normalization layer to use. If not specified, no normalization is used. Partial functions can be used to specify the normalization layer with specific parameters.

None
activation_layer Optional[Callable[..., Module]]

The activation layer to use. If not specified, ReLU is used. Partial functions can be used to specify the activation layer with specific parameters.

torch.nn.ReLU
bias Union[bool, list[bool]]

Whether to use bias in the linear layers.

True
dropout Union[float, list[float]]

The dropout probability to use.

0.0

Raises:

Type Description
ValueError

If both hidden_dims and hidden_dims_multiplier are specified, or if the lengths of bias and hidden_dims do not match, or if the lengths of dropout and hidden_dims do not match.

Source code in mmlearn/modules/layers/mlp.py
@store(group="modules/layers", provider="mmlearn")
class MLP(torch.nn.Sequential):
    """Multi-layer perceptron (MLP).

    This module will create a block of ``Linear -> Normalization -> Activation -> Dropout``
    layers.

    Parameters
    ----------
    in_dim : int
        The input dimension.
    out_dim : Optional[int], optional, default=None
        The output dimension. If not specified, it is set to :attr:`in_dim`.
    hidden_dims : Optional[list], optional, default=None
        The dimensions of the hidden layers. The length of the list determines the
        number of hidden layers. This parameter is mutually exclusive with
        :attr:`hidden_dims_multiplier`.
    hidden_dims_multiplier : Optional[list], optional, default=None
        The multipliers to apply to the input dimension to get the dimensions of
        the hidden layers. The length of the list determines the number of hidden
        layers. The multipliers will be used to get the dimensions of the hidden
        layers. This parameter is mutually exclusive with `hidden_dims`.
    apply_multiplier_to_in_dim : bool, optional, default=False
        Whether to apply the :attr:`hidden_dims_multiplier` to :attr:`in_dim` to get the
        dimensions of the hidden layers. If ``False``, the multipliers will be applied
        to the dimensions of the previous hidden layer, starting from :attr:`in_dim`.
        This parameter is only relevant when :attr:`hidden_dims_multiplier` is
        specified.
    norm_layer : Optional[Callable[..., torch.nn.Module]], optional, default=None
        The normalization layer to use. If not specified, no normalization is used.
        Partial functions can be used to specify the normalization layer with specific
        parameters.
    activation_layer : Optional[Callable[..., torch.nn.Module]], optional, default=torch.nn.ReLU
        The activation layer to use. If not specified, ReLU is used. Partial functions
        can be used to specify the activation layer with specific parameters.
    bias : Union[bool, list[bool]], optional, default=True
        Whether to use bias in the linear layers.
    dropout : Union[float, list[float]], optional, default=0.0
        The dropout probability to use.

    Raises
    ------
    ValueError
        If both :attr:`hidden_dims` and :attr:`hidden_dims_multiplier` are specified
        or if the lengths of :attr:`bias` and :attr:`hidden_dims` do not match or if
        the lengths of :attr:`dropout` and :attr:`hidden_dims` do not match.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        hidden_dims: Optional[list[int]] = None,
        hidden_dims_multiplier: Optional[list[float]] = None,
        apply_multiplier_to_in_dim: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        bias: Union[bool, list[bool]] = True,
        dropout: Union[float, list[float]] = 0.0,
    ) -> None:
        if hidden_dims is None and hidden_dims_multiplier is None:
            hidden_dims = []
        if hidden_dims is not None and hidden_dims_multiplier is not None:
            raise ValueError(
                "Only one of `hidden_dims` or `hidden_dims_multiplier` must be specified."
            )

        if hidden_dims is None and hidden_dims_multiplier is not None:
            if apply_multiplier_to_in_dim:
                hidden_dims = [
                    int(in_dim * multiplier) for multiplier in hidden_dims_multiplier
                ]
            else:
                hidden_dims = [int(in_dim * hidden_dims_multiplier[0])]
                for multiplier in hidden_dims_multiplier[1:]:
                    hidden_dims.append(int(hidden_dims[-1] * multiplier))

        if isinstance(bias, bool):
            bias_list: list[bool] = [bias] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            bias_list = bias
        if len(bias_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `bias` to be a boolean or a list of booleans with length "
                "equal to the number of linear layers in the MLP."
            )

        if isinstance(dropout, float):
            dropout_list: list[float] = [dropout] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            dropout_list = dropout
        if len(dropout_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `dropout` to be a float or a list of floats with length "
                "equal to the number of linear layers in the MLP."
            )

        # construct list of dimensions for the layers
        dims = [in_dim] + hidden_dims  # type: ignore[operator]
        layers = []
        for layer_idx, (in_features, hidden_features) in enumerate(
            zip(dims[:-1], dims[1:], strict=False)
        ):
            layers.append(
                torch.nn.Linear(in_features, hidden_features, bias=bias_list[layer_idx])
            )
            if norm_layer is not None:
                layers.append(norm_layer(hidden_features))
            if activation_layer is not None:
                layers.append(activation_layer())
            layers.append(torch.nn.Dropout(dropout_list[layer_idx]))

        out_dim = out_dim or in_dim

        layers.append(torch.nn.Linear(dims[-1], out_dim, bias=bias_list[-1]))
        layers.append(torch.nn.Dropout(dropout_list[-1]))

        super().__init__(*layers)
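
A minimal sketch of the two mutually exclusive ways of sizing the hidden layers (import path inferred from the source location mmlearn/modules/layers/mlp.py):

import torch
from mmlearn.modules.layers.mlp import MLP

# Explicit hidden sizes: 768 -> 1024 -> 1024 -> 256
mlp = MLP(
    in_dim=768,
    out_dim=256,
    hidden_dims=[1024, 1024],
    norm_layer=torch.nn.LayerNorm,
    dropout=0.1,
)
out = mlp(torch.randn(4, 768))  # shape: (4, 256)

# Multiplier-based sizing: 768 -> 3072 -> 768 (out_dim defaults to in_dim)
mlp2 = MLP(in_dim=768, hidden_dims_multiplier=[4.0])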

L2Norm

Bases: Module

L2 normalization.

Parameters:

Name Type Description Default
dim int

The dimension along which to normalize.

required
Source code in mmlearn/modules/layers/normalization.py
@store(group="modules/layers", provider="mmlearn")
class L2Norm(torch.nn.Module):
    """L2 normalization.

    Parameters
    ----------
    dim : int
        The dimension along which to normalize.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply L2 normalization to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)

forward

forward(x)

Apply L2 normalization to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Normalized tensor of shape (batch_sz, seq_len, dim).

Source code in mmlearn/modules/layers/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply L2 normalization to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.nn.functional.normalize(x, dim=self.dim, p=2)
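
A minimal sketch (import path inferred from the source location mmlearn/modules/layers/normalization.py); the layer is a thin wrapper around torch.nn.functional.normalize:

import torch
from mmlearn.modules.layers.normalization import L2Norm

l2_norm = L2Norm(dim=-1)
x = torch.randn(2, 10, 512)
y = l2_norm(x)  # every vector along the last dimension now has unit L2 norm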

PatchDropout

Bases: Module

Patch dropout layer.

Drops patch tokens (after embedding and adding CLS token) from the input tensor. Usually used in vision transformers to reduce the number of tokens. [1]

Parameters:

Name Type Description Default
keep_rate float

The proportion of tokens to keep.

0.5
bias Optional[float]

The bias to add to the random noise before sorting.

None
token_shuffling bool

If True, the tokens are shuffled.

False
References

[1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023). PatchDropout: Economizing Vision Transformers Using Patch Dropout. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 3953-3962).

Source code in mmlearn/modules/layers/patch_dropout.py
class PatchDropout(torch.nn.Module):
    """Patch dropout layer.

    Drops patch tokens (after embedding and adding CLS token) from the input tensor.
    Usually used in vision transformers to reduce the number of tokens. [1]_

    Parameters
    ----------
    keep_rate : float, optional, default=0.5
        The proportion of tokens to keep.
    bias : Optional[float], optional, default=None
        The bias to add to the random noise before sorting.
    token_shuffling : bool, optional, default=False
        If True, the tokens are shuffled.

    References
    ----------
    .. [1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023).
       Patchdropout: Economizing vision transformers using patch dropout. In Proceedings
       of the IEEE/CVF Winter Conference on Applications of Computer Vision
       (pp. 3953-3962).
    """

    def __init__(
        self,
        keep_rate: float = 0.5,
        bias: Optional[float] = None,
        token_shuffling: bool = False,
    ):
        super().__init__()
        assert 0 < keep_rate <= 1, "The keep_rate must be in (0,1]"

        self.bias = bias
        self.keep_rate = keep_rate
        self.token_shuffling = token_shuffling

    def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
        """Drop tokens from the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        force_drop : bool, optional, default=False
            If True, the tokens are always dropped, even when the model is in
            evaluation mode.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
        """
        if (not self.training and not force_drop) or self.keep_rate == 1:
            return x

        batch_sz, _, dim = x.shape

        cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
            batch_sz, 1, dtype=torch.int64, device=x.device
        )
        patch_mask = self.uniform_mask(x)
        patch_mask = torch.hstack([cls_mask, patch_mask])

        return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))

    def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Generate token ids to keep from uniform random distribution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

        """
        batch_sz, seq_len, _ = x.shape
        seq_len = seq_len - 1  # patch length (without CLS)

        keep_len = int(seq_len * self.keep_rate)
        noise = torch.rand(batch_sz, seq_len, device=x.device)
        if self.bias is not None:
            noise += self.bias
        ids = torch.argsort(noise, dim=1)
        keep_ids = ids[:, :keep_len]
        if not self.token_shuffling:
            keep_ids = keep_ids.sort(1)[0]
        return keep_ids

forward

forward(x, force_drop=False)

Drop tokens from the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
force_drop bool

If True, the tokens are always dropped, even when the model is in evaluation mode.

False

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len, dim) containing the kept tokens.

Source code in mmlearn/modules/layers/patch_dropout.py
def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
    """Drop tokens from the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    force_drop : bool, optional, default=False
        If True, the tokens are always dropped, even when the model is in
        evaluation mode.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
    """
    if (not self.training and not force_drop) or self.keep_rate == 1:
        return x

    batch_sz, _, dim = x.shape

    cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
        batch_sz, 1, dtype=torch.int64, device=x.device
    )
    patch_mask = self.uniform_mask(x)
    patch_mask = torch.hstack([cls_mask, patch_mask])

    return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))

uniform_mask

uniform_mask(x)

Generate token ids to keep from uniform random distribution.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len) containing the token ids to keep.

Source code in mmlearn/modules/layers/patch_dropout.py
def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
    """Generate token ids to keep from uniform random distribution.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

    """
    batch_sz, seq_len, _ = x.shape
    seq_len = seq_len - 1  # patch length (without CLS)

    keep_len = int(seq_len * self.keep_rate)
    noise = torch.rand(batch_sz, seq_len, device=x.device)
    if self.bias is not None:
        noise += self.bias
    ids = torch.argsort(noise, dim=1)
    keep_ids = ids[:, :keep_len]
    if not self.token_shuffling:
        keep_ids = keep_ids.sort(1)[0]
    return keep_ids
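
A minimal sketch of the token count before and after dropout (import path inferred from the source location mmlearn/modules/layers/patch_dropout.py); the CLS token at position 0 is always kept:

import torch
from mmlearn.modules.layers.patch_dropout import PatchDropout

patch_drop = PatchDropout(keep_rate=0.5)    # keep half of the patch tokens
tokens = torch.randn(2, 197, 768)           # 1 CLS token + 196 patch tokens
kept = patch_drop(tokens, force_drop=True)  # shape: (2, 1 + 98, 768)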

attention

Attention modules for Vision Transformer (ViT) and related models.

Attention

Bases: Module

Multi-head Self-Attention Mechanism.

Parameters:

Name Type Description Default
dim int

Number of input dimensions.

required
num_heads int

Number of attention heads.

8
qkv_bias bool

If True, adds a learnable bias to the query, key, value projections.

False
qk_scale Optional[float]

Override the default scale factor for the dot-product attention.

None
attn_drop float

Dropout probability for the attention weights.

0.0
proj_drop float

Dropout probability for the output of the attention layer.

0.0
Source code in mmlearn/modules/layers/attention.py
class Attention(nn.Module):
    """Multi-head Self-Attention Mechanism.

    Parameters
    ----------
    dim : int
        Number of input dimensions.
    num_heads : int, optional, default=8
        Number of attention heads.
    qkv_bias : bool, optional, default=False
        If True, adds a learnable bias to the query, key, value projections.
    qk_scale : Optional[float], optional, default=None
        Override the default scale factor for the dot-product attention.
    attn_drop : float, optional, default=0.0
        Dropout probability for the attention weights.
    proj_drop : float, optional, default=0.0
        Dropout probability for the output of the attention layer.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the multi-head self-attention module.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The output tensor and the attention weights.
        """
        b, n, c = x.shape
        qkv = (
            self.qkv(x)
            .reshape(b, n, 3, self.num_heads, c // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
forward
forward(x)

Forward pass through the multi-head self-attention module.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
tuple[Tensor, Tensor]

The output tensor and the attention weights.

Source code in mmlearn/modules/layers/attention.py
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward pass through the multi-head self-attention module.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        The output tensor and the attention weights.
    """
    b, n, c = x.shape
    qkv = (
        self.qkv(x)
        .reshape(b, n, 3, self.num_heads, c // self.num_heads)
        .permute(2, 0, 3, 1, 4)
    )
    q, k, v = qkv[0], qkv[1], qkv[2]

    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(b, n, c)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x, attn
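
A minimal sketch (import path inferred from the source location mmlearn/modules/layers/attention.py); the layer returns both the projected output and the attention weights:

import torch
from mmlearn.modules.layers.attention import Attention

attn = Attention(dim=768, num_heads=12, qkv_bias=True)
x = torch.randn(2, 197, 768)
out, weights = attn(x)  # out: (2, 197, 768), weights: (2, 12, 197, 197)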

embedding

Embedding layers.

PatchEmbed

Bases: Module

Image to Patch Embedding.

This module divides an image into patches and embeds them as a sequence of vectors.

Parameters:

Name Type Description Default
img_size int

Size of the input image (assumed to be square).

224
patch_size int

Size of each image patch (assumed to be square).

16
in_chans int

Number of input channels in the image.

3
embed_dim int

Dimension of the output embeddings.

768
Source code in mmlearn/modules/layers/embedding.py
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    This module divides an image into patches and embeds them as a sequence of vectors.

    Parameters
    ----------
    img_size : int, optional, default=224
        Size of the input image (assumed to be square).
    patch_size : int, optional, default=16
        Size of each image patch (assumed to be square).
    in_chans : int, optional, default=3
        Number of input channels in the image.
    embed_dim : int, optional, default=768
        Dimension of the output embeddings.

    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass to convert an image into patch embeddings."""
        return self.proj(x).flatten(2).transpose(1, 2)
forward
forward(x)

Forward pass to convert an image into patch embeddings.

Source code in mmlearn/modules/layers/embedding.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass to convert an image into patch embeddings."""
    return self.proj(x).flatten(2).transpose(1, 2)
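
A minimal sketch (import path inferred from the source location mmlearn/modules/layers/embedding.py); a 224x224 image split into 16x16 patches yields 196 patch embeddings:

import torch
from mmlearn.modules.layers.embedding import PatchEmbed

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
images = torch.randn(2, 3, 224, 224)
patches = patch_embed(images)  # shape: (2, 196, 768); patch_embed.num_patches == 196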

ConvEmbed

Bases: Module

3x3 Convolution stems for ViT following ViTC models.

This module builds convolutional stems for Vision Transformers (ViT) with intermediate batch normalization and ReLU activation.

Parameters:

Name Type Description Default
channels list[int]

list of channel sizes for each convolution layer.

required
strides list[int]

list of stride sizes for each convolution layer.

required
img_size int

Size of the input image (assumed to be square).

224
in_chans int

Number of input channels in the image.

3
batch_norm bool

Whether to include batch normalization after each convolution layer.

True
Source code in mmlearn/modules/layers/embedding.py
class ConvEmbed(nn.Module):
    """3x3 Convolution stems for ViT following ViTC models.

    This module builds convolutional stems for Vision Transformers (ViT)
    with intermediate batch normalization and ReLU activation.

    Parameters
    ----------
    channels : list[int]
        list of channel sizes for each convolution layer.
    strides : list[int]
        list of stride sizes for each convolution layer.
    img_size : int, optional, default=224
        Size of the input image (assumed to be square).
    in_chans : int, optional, default=3
        Number of input channels in the image.
    batch_norm : bool, optional, default=True
        Whether to include batch normalization after each convolution layer.

    """

    def __init__(
        self,
        channels: list[int],
        strides: list[int],
        img_size: int = 224,
        in_chans: int = 3,
        batch_norm: bool = True,
    ) -> None:
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + channels
        for i in range(len(channels) - 2):
            stem += [
                nn.Conv2d(
                    channels[i],
                    channels[i + 1],
                    kernel_size=3,
                    stride=strides[i],
                    padding=1,
                    bias=(not batch_norm),
                )
            ]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i + 1])]
            stem += [nn.ReLU(inplace=True)]
        stem += [
            nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])
        ]
        self.stem = nn.Sequential(*stem)

        # Compute the number of patches
        stride_prod = int(np.prod(strides))
        self.num_patches = (img_size // stride_prod) ** 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the convolutional embedding layers."""
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)
forward
forward(x)

Forward pass through the convolutional embedding layers.

Source code in mmlearn/modules/layers/embedding.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through the convolutional embedding layers."""
    p = self.stem(x)
    return p.flatten(2).transpose(1, 2)
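
A minimal sketch with illustrative (not prescribed) channel and stride choices; the product of the strides is 2 * 2 * 4 = 16, so a 224x224 image again yields 14 * 14 = 196 tokens:

import torch
from mmlearn.modules.layers.embedding import ConvEmbed

conv_stem = ConvEmbed(channels=[64, 128, 768], strides=[2, 2, 4], img_size=224)
images = torch.randn(2, 3, 224, 224)
tokens = conv_stem(images)  # shape: (2, 196, 768); conv_stem.num_patches == 196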

get_2d_sincos_pos_embed

get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False
)

Generate 2D sine-cosine positional embeddings.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
grid_size int

The size of the grid (both height and width).

required
cls_token bool

Whether to include a class token in the embeddings.

False

Returns:

Name Type Description
pos_embed ndarray

Positional embeddings with shape [grid_size*grid_size, embed_dim] or [1 + grid_size*grid_size, embed_dim] if cls_token is True.

Source code in mmlearn/modules/layers/embedding.py
def get_2d_sincos_pos_embed(
    embed_dim: int, grid_size: int, cls_token: bool = False
) -> np.ndarray:
    """
    Generate 2D sine-cosine positional embeddings.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    grid_size : int
        The size of the grid (both height and width).
    cls_token : bool, optional, default=False
        Whether to include a class token in the embeddings.

    Returns
    -------
    pos_embed : np.ndarray
        Positional embeddings with shape [grid_size*grid_size, embed_dim] or
        [1 + grid_size*grid_size, embed_dim] if cls_token is True.
    """
    grid_h = np.arange(grid_size, dtype=float)
    grid_w = np.arange(grid_size, dtype=float)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
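
A minimal sketch; for a 14x14 patch grid with a CLS token the result has 1 + 196 rows:

from mmlearn.modules.layers.embedding import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_embed.shape)  # (197, 768); row 0 is the all-zero CLS embedding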

get_2d_sincos_pos_embed_from_grid

get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

Generate 2D sine-cosine positional embeddings from a grid.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
grid ndarray

The grid of positions with shape [2, 1, grid_size, grid_size].

required

Returns:

Name Type Description
emb ndarray

Positional embeddings with shape [grid_size*grid_size, embed_dim].

Source code in mmlearn/modules/layers/embedding.py
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray) -> np.ndarray:
    """
    Generate 2D sine-cosine positional embeddings from a grid.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    grid : np.ndarray
        The grid of positions with shape [2, 1, grid_size, grid_size].

    Returns
    -------
    emb : np.ndarray
        Positional embeddings with shape [grid_size*grid_size, embed_dim].
    """
    assert embed_dim % 2 == 0

    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)

get_1d_sincos_pos_embed

get_1d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False
)

Generate 1D sine-cosine positional embeddings.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
grid_size int

The size of the grid.

required
cls_token bool

Whether to include a class token in the embeddings.

False

Returns:

Name Type Description
pos_embed ndarray

Positional embeddings with shape [grid_size, embed_dim] or [1 + grid_size, embed_dim] if cls_token is True.

Source code in mmlearn/modules/layers/embedding.py
def get_1d_sincos_pos_embed(
    embed_dim: int, grid_size: int, cls_token: bool = False
) -> np.ndarray:
    """
    Generate 1D sine-cosine positional embeddings.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    grid_size : int
        The size of the grid.
    cls_token : bool, optional, default=False
        Whether to include a class token in the embeddings.

    Returns
    -------
    pos_embed : np.ndarray
        Positional embeddings with shape [grid_size, embed_dim] or
        [1 + grid_size, embed_dim] if cls_token is True.
    """
    grid = np.arange(grid_size, dtype=float)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

get_1d_sincos_pos_embed_from_grid

get_1d_sincos_pos_embed_from_grid(embed_dim, pos)

Generate 1D sine-cosine positional embeddings from a grid.

Parameters:

Name Type Description Default
embed_dim int

The dimension of the embeddings.

required
pos ndarray

A list of positions to be encoded, with shape [M,].

required

Returns:

Name Type Description
emb ndarray

Positional embeddings with shape [M, embed_dim].

Source code in mmlearn/modules/layers/embedding.py
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """
    Generate 1D sine-cosine positional embeddings from a grid.

    Parameters
    ----------
    embed_dim : int
        The dimension of the embeddings.
    pos : np.ndarray
        A list of positions to be encoded, with shape [M,].

    Returns
    -------
    emb : np.ndarray
        Positional embeddings with shape [M, embed_dim].
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)
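
A minimal sketch; each position is mapped to embed_dim / 2 sine values followed by embed_dim / 2 cosine values:

import numpy as np

from mmlearn.modules.layers.embedding import get_1d_sincos_pos_embed_from_grid

emb = get_1d_sincos_pos_embed_from_grid(embed_dim=64, pos=np.arange(10, dtype=float))
print(emb.shape)  # (10, 64): columns 0-31 are sines, columns 32-63 are cosines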

logit_scaling

Learnable logit scaling layer.

LearnableLogitScaling

Bases: Module

Logit scaling layer.

Parameters:

Name Type Description Default
init_logit_scale float

Initial value of the logit scale.

1/0.07
learnable bool

If True, the logit scale is learnable. Otherwise, it is fixed.

True
max_logit_scale float

Maximum value of the logit scale.

100
Source code in mmlearn/modules/layers/logit_scaling.py
@store(group="modules/layers", provider="mmlearn")
class LearnableLogitScaling(torch.nn.Module):
    """Logit scaling layer.

    Parameters
    ----------
    init_logit_scale : float, optional, default=1/0.07
        Initial value of the logit scale.
    learnable : bool, optional, default=True
        If True, the logit scale is learnable. Otherwise, it is fixed.
    max_logit_scale : float, optional, default=100
        Maximum value of the logit scale.
    """

    def __init__(
        self,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.init_logit_scale = init_logit_scale
        self.learnable = learnable
        log_logit_scale = torch.ones([]) * np.log(self.init_logit_scale)
        if learnable:
            self.log_logit_scale = torch.nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the logit scaling to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self) -> str:
        """Return the string representation of the layer."""
        return (
            f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
            f" max_logit_scale={self.max_logit_scale}"
        )
forward
forward(x)

Apply the logit scaling to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
Source code in mmlearn/modules/layers/logit_scaling.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the logit scaling to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
extra_repr
extra_repr()

Return the string representation of the layer.

Source code in mmlearn/modules/layers/logit_scaling.py
def extra_repr(self) -> str:
    """Return the string representation of the layer."""
    return (
        f"logit_scale_init={self.init_logit_scale},learnable={self.learnable},"
        f" max_logit_scale={self.max_logit_scale}"
    )

mlp

Multi-layer perceptron (MLP).

MLP

Bases: Sequential

Multi-layer perceptron (MLP).

This module will create a block of Linear -> Normalization -> Activation -> Dropout layers.

Parameters:

Name Type Description Default
in_dim int

The input dimension.

required
out_dim Optional[int]

The output dimension. If not specified, it is set to in_dim.

None
hidden_dims Optional[list]

The dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with hidden_dims_multiplier.

None
hidden_dims_multiplier Optional[list]

The multipliers to apply to the input dimension to get the dimensions of the hidden layers. The length of the list determines the number of hidden layers. This parameter is mutually exclusive with hidden_dims.

None
apply_multiplier_to_in_dim bool

Whether to apply the hidden_dims_multiplier to in_dim to get the dimensions of the hidden layers. If False, the multipliers will be applied to the dimensions of the previous hidden layer, starting from in_dim. This parameter is only relevant when hidden_dims_multiplier is specified.

False
norm_layer Optional[Callable[..., Module]]

The normalization layer to use. If not specified, no normalization is used. Partial functions can be used to specify the normalization layer with specific parameters.

None
activation_layer Optional[Callable[..., Module]]

The activation layer to use. If not specified, ReLU is used. Partial functions can be used to specify the activation layer with specific parameters.

torch.nn.ReLU
bias Union[bool, list[bool]]

Whether to use bias in the linear layers.

True
dropout Union[float, list[float]]

The dropout probability to use.

0.0

Raises:

Type Description
ValueError

If both hidden_dims and hidden_dims_multiplier are specified, or if the lengths of bias and hidden_dims do not match, or if the lengths of dropout and hidden_dims do not match.

Source code in mmlearn/modules/layers/mlp.py
@store(group="modules/layers", provider="mmlearn")
class MLP(torch.nn.Sequential):
    """Multi-layer perceptron (MLP).

    This module will create a block of ``Linear -> Normalization -> Activation -> Dropout``
    layers.

    Parameters
    ----------
    in_dim : int
        The input dimension.
    out_dim : Optional[int], optional, default=None
        The output dimension. If not specified, it is set to :attr:`in_dim`.
    hidden_dims : Optional[list], optional, default=None
        The dimensions of the hidden layers. The length of the list determines the
        number of hidden layers. This parameter is mutually exclusive with
        :attr:`hidden_dims_multiplier`.
    hidden_dims_multiplier : Optional[list], optional, default=None
        The multipliers to apply to the input dimension to get the dimensions of
        the hidden layers. The length of the list determines the number of hidden
        layers. The multipliers will be used to get the dimensions of the hidden
        layers. This parameter is mutually exclusive with `hidden_dims`.
    apply_multiplier_to_in_dim : bool, optional, default=False
        Whether to apply the :attr:`hidden_dims_multiplier` to :attr:`in_dim` to get the
        dimensions of the hidden layers. If ``False``, the multipliers will be applied
        to the dimensions of the previous hidden layer, starting from :attr:`in_dim`.
        This parameter is only relevant when :attr:`hidden_dims_multiplier` is
        specified.
    norm_layer : Optional[Callable[..., torch.nn.Module]], optional, default=None
        The normalization layer to use. If not specified, no normalization is used.
        Partial functions can be used to specify the normalization layer with specific
        parameters.
    activation_layer : Optional[Callable[..., torch.nn.Module]], optional, default=torch.nn.ReLU
        The activation layer to use. If not specified, ReLU is used. Partial functions
        can be used to specify the activation layer with specific parameters.
    bias : Union[bool, list[bool]], optional, default=True
        Whether to use bias in the linear layers.
    dropout : Union[float, list[float]], optional, default=0.0
        The dropout probability to use.

    Raises
    ------
    ValueError
        If both :attr:`hidden_dims` and :attr:`hidden_dims_multiplier` are specified
        or if the lengths of :attr:`bias` and :attr:`hidden_dims` do not match or if
        the lengths of :attr:`dropout` and :attr:`hidden_dims` do not match.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        hidden_dims: Optional[list[int]] = None,
        hidden_dims_multiplier: Optional[list[float]] = None,
        apply_multiplier_to_in_dim: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        bias: Union[bool, list[bool]] = True,
        dropout: Union[float, list[float]] = 0.0,
    ) -> None:
        if hidden_dims is None and hidden_dims_multiplier is None:
            hidden_dims = []
        if hidden_dims is not None and hidden_dims_multiplier is not None:
            raise ValueError(
                "Only one of `hidden_dims` or `hidden_dims_multiplier` must be specified."
            )

        if hidden_dims is None and hidden_dims_multiplier is not None:
            if apply_multiplier_to_in_dim:
                hidden_dims = [
                    int(in_dim * multiplier) for multiplier in hidden_dims_multiplier
                ]
            else:
                hidden_dims = [int(in_dim * hidden_dims_multiplier[0])]
                for multiplier in hidden_dims_multiplier[1:]:
                    hidden_dims.append(int(hidden_dims[-1] * multiplier))

        if isinstance(bias, bool):
            bias_list: list[bool] = [bias] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            bias_list = bias
        if len(bias_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `bias` to be a boolean or a list of booleans with length "
                "equal to the number of linear layers in the MLP."
            )

        if isinstance(dropout, float):
            dropout_list: list[float] = [dropout] * (len(hidden_dims) + 1)  # type: ignore[arg-type]
        else:
            dropout_list = dropout
        if len(dropout_list) != len(hidden_dims) + 1:  # type: ignore[arg-type]
            raise ValueError(
                "Expected `dropout` to be a float or a list of floats with length "
                "equal to the number of linear layers in the MLP."
            )

        # construct list of dimensions for the layers
        dims = [in_dim] + hidden_dims  # type: ignore[operator]
        layers = []
        for layer_idx, (in_features, hidden_features) in enumerate(
            zip(dims[:-1], dims[1:], strict=False)
        ):
            layers.append(
                torch.nn.Linear(in_features, hidden_features, bias=bias_list[layer_idx])
            )
            if norm_layer is not None:
                layers.append(norm_layer(hidden_features))
            if activation_layer is not None:
                layers.append(activation_layer())
            layers.append(torch.nn.Dropout(dropout_list[layer_idx]))

        out_dim = out_dim or in_dim

        layers.append(torch.nn.Linear(dims[-1], out_dim, bias=bias_list[-1]))
        layers.append(torch.nn.Dropout(dropout_list[-1]))

        super().__init__(*layers)

normalization

Normalization layers.

L2Norm

Bases: Module

L2 normalization.

Parameters:

Name Type Description Default
dim int

The dimension along which to normalize.

required
Source code in mmlearn/modules/layers/normalization.py
@store(group="modules/layers", provider="mmlearn")
class L2Norm(torch.nn.Module):
    """L2 normalization.

    Parameters
    ----------
    dim : int
        The dimension along which to normalize.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply L2 normalization to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
        """
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)
forward
forward(x)

Apply L2 normalization to the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Normalized tensor of shape (batch_sz, seq_len, dim).

Source code in mmlearn/modules/layers/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply L2 normalization to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Normalized tensor of shape ``(batch_sz, seq_len, dim)``.
    """
    return torch.nn.functional.normalize(x, dim=self.dim, p=2)

patch_dropout

Patch dropout layer.

PatchDropout

Bases: Module

Patch dropout layer.

Drops patch tokens (after embedding and adding CLS token) from the input tensor. Usually used in vision transformers to reduce the number of tokens. [1]

Parameters:

Name Type Description Default
keep_rate float

The proportion of tokens to keep.

0.5
bias Optional[float]

The bias to add to the random noise before sorting.

None
token_shuffling bool

If True, the tokens are shuffled.

False
References

[1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023). PatchDropout: Economizing Vision Transformers Using Patch Dropout. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 3953-3962).

Source code in mmlearn/modules/layers/patch_dropout.py
class PatchDropout(torch.nn.Module):
    """Patch dropout layer.

    Drops patch tokens (after embedding and adding CLS token) from the input tensor.
    Usually used in vision transformers to reduce the number of tokens. [1]_

    Parameters
    ----------
    keep_rate : float, optional, default=0.5
        The proportion of tokens to keep.
    bias : Optional[float], optional, default=None
        The bias to add to the random noise before sorting.
    token_shuffling : bool, optional, default=False
        If True, the tokens are shuffled.

    References
    ----------
    .. [1] Liu, Y., Matsoukas, C., Strand, F., Azizpour, H., & Smith, K. (2023).
       Patchdropout: Economizing vision transformers using patch dropout. In Proceedings
       of the IEEE/CVF Winter Conference on Applications of Computer Vision
       (pp. 3953-3962).
    """

    def __init__(
        self,
        keep_rate: float = 0.5,
        bias: Optional[float] = None,
        token_shuffling: bool = False,
    ):
        super().__init__()
        assert 0 < keep_rate <= 1, "The keep_rate must be in (0,1]"

        self.bias = bias
        self.keep_rate = keep_rate
        self.token_shuffling = token_shuffling

    def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
        """Drop tokens from the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.
        force_drop : bool, optional, default=False
            If True, the tokens are always dropped, even when the model is in
            evaluation mode.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
        """
        if (not self.training and not force_drop) or self.keep_rate == 1:
            return x

        batch_sz, _, dim = x.shape

        cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
            batch_sz, 1, dtype=torch.int64, device=x.device
        )
        patch_mask = self.uniform_mask(x)
        patch_mask = torch.hstack([cls_mask, patch_mask])

        return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))

    def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Generate token ids to keep from uniform random distribution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``(batch_sz, seq_len, dim)``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

        """
        batch_sz, seq_len, _ = x.shape
        seq_len = seq_len - 1  # patch length (without CLS)

        keep_len = int(seq_len * self.keep_rate)
        noise = torch.rand(batch_sz, seq_len, device=x.device)
        if self.bias is not None:
            noise += self.bias
        ids = torch.argsort(noise, dim=1)
        keep_ids = ids[:, :keep_len]
        if not self.token_shuffling:
            keep_ids = keep_ids.sort(1)[0]
        return keep_ids
forward
forward(x, force_drop=False)

Drop tokens from the input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required
force_drop bool

If True, the tokens are always dropped, even when the model is in evaluation mode.

False

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len, dim) containing the kept tokens.

Source code in mmlearn/modules/layers/patch_dropout.py
def forward(self, x: torch.Tensor, force_drop: bool = False) -> torch.Tensor:
    """Drop tokens from the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.
    force_drop : bool, optional, default=False
        If True, the tokens are always dropped, even when the model is in
        evaluation mode.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len, dim)`` containing the kept tokens.
    """
    if (not self.training and not force_drop) or self.keep_rate == 1:
        return x

    batch_sz, _, dim = x.shape

    cls_mask = torch.zeros(  # assumes that CLS is always the 1st element
        batch_sz, 1, dtype=torch.int64, device=x.device
    )
    patch_mask = self.uniform_mask(x)
    patch_mask = torch.hstack([cls_mask, patch_mask])

    return torch.gather(x, dim=1, index=patch_mask.unsqueeze(-1).repeat(1, 1, dim))
uniform_mask
uniform_mask(x)

Generate token ids to keep from uniform random distribution.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_sz, seq_len, dim).

required

Returns:

Type Description
Tensor

Tensor of shape (batch_sz, keep_len) containing the token ids to keep.

Source code in mmlearn/modules/layers/patch_dropout.py
def uniform_mask(self, x: torch.Tensor) -> torch.Tensor:
    """Generate token ids to keep from uniform random distribution.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(batch_sz, seq_len, dim)``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_sz, keep_len)`` containing the token ids to keep.

    """
    batch_sz, seq_len, _ = x.shape
    seq_len = seq_len - 1  # patch length (without CLS)

    keep_len = int(seq_len * self.keep_rate)
    noise = torch.rand(batch_sz, seq_len, device=x.device)
    if self.bias is not None:
        noise += self.bias
    ids = torch.argsort(noise, dim=1)
    keep_ids = ids[:, :keep_len]
    if not self.token_shuffling:
        keep_ids = keep_ids.sort(1)[0]
    return keep_ids

transformer_block

Transformer Block and Embedding Modules for Vision Transformers (ViT).

DropPath

Bases: Module

Drop paths (Stochastic Depth) per sample.

Parameters:

Name Type Description Default
drop_prob float

Probability of dropping paths.

0.0
Source code in mmlearn/modules/layers/transformer_block.py
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample.

    Parameters
    ----------
    drop_prob : float, optional, default=0.0
        Probability of dropping paths.
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through DropPath module."""
        return drop_path(x, self.drop_prob, self.training)
forward
forward(x)

Forward pass through DropPath module.

Source code in mmlearn/modules/layers/transformer_block.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through DropPath module."""
    return drop_path(x, self.drop_prob, self.training)
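
A minimal sketch (import path inferred from the source location mmlearn/modules/layers/transformer_block.py); stochastic depth is only active in training mode:

import torch
from mmlearn.modules.layers.transformer_block import DropPath

drop_path_layer = DropPath(drop_prob=0.1)
drop_path_layer.train()
x = torch.randn(8, 197, 768)
y = drop_path_layer(x)  # during training, each sample's path is dropped with probability 0.1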

Block

Bases: Module

Transformer Block.

This module represents a Transformer block that includes self-attention, normalization layers, and a feedforward multi-layer perceptron (MLP) network.

Parameters:

Name Type Description Default
dim int

The input and output dimension of the block.

required
num_heads int

Number of attention heads.

required
mlp_ratio float

Ratio of hidden dimension to the input dimension in the MLP.

4.0
qkv_bias bool

If True, add a learnable bias to the query, key, value projections.

False
qk_scale Optional[float]

Override default qk scale of head_dim ** -0.5 if set.

None
drop float

Dropout probability for the output of attention and MLP layers.

0.0
attn_drop float

Dropout probability for the attention scores.

0.0
drop_path float

Stochastic depth rate, a form of layer dropout.

0.0
act_layer Callable[..., Module]

Activation layer in the MLP.

nn.GELU
norm_layer Callable[..., Module]

Normalization layer.

torch.nn.LayerNorm
Source code in mmlearn/modules/layers/transformer_block.py
class Block(nn.Module):
    """Transformer Block.

    This module represents a Transformer block that includes self-attention,
    normalization layers, and a feedforward multi-layer perceptron (MLP) network.

    Parameters
    ----------
    dim : int
        The input and output dimension of the block.
    num_heads : int
        Number of attention heads.
    mlp_ratio : float, optional, default=4.0
        Ratio of hidden dimension to the input dimension in the MLP.
    qkv_bias : bool, optional, default=False
        If True, add a learnable bias to the query, key, value projections.
    qk_scale : Optional[float], optional, default=None
        Override default qk scale of head_dim ** -0.5 if set.
    drop : float, optional, default=0.0
        Dropout probability for the output of attention and MLP layers.
    attn_drop : float, optional, default=0.0
        Dropout probability for the attention scores.
    drop_path : float, optional, default=0.0
        Stochastic depth rate, a form of layer dropout.
    act_layer : Callable[..., torch.nn.Module], optional, default=nn.GELU
        Activation layer in the MLP.
    norm_layer : Callable[..., torch.nn.Module], optional, default=torch.nn.LayerNorm
        Normalization layer.

    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)

        self.mlp = MLP(
            in_dim=dim,
            hidden_dims_multiplier=[mlp_ratio],
            activation_layer=act_layer,
            bias=True,
            dropout=drop,
        )

    def forward(
        self, x: torch.Tensor, return_attention: bool = False
    ) -> Union[torch.Tensor, torch.Tensor]:
        """Forward pass through the Transformer Block."""
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        return x + self.drop_path(self.mlp(self.norm2(x)))
forward
forward(x, return_attention=False)

Forward pass through the Transformer Block.

Source code in mmlearn/modules/layers/transformer_block.py
def forward(
    self, x: torch.Tensor, return_attention: bool = False
) -> Union[torch.Tensor, torch.Tensor]:
    """Forward pass through the Transformer Block."""
    y, attn = self.attn(self.norm1(x))
    if return_attention:
        return attn
    x = x + self.drop_path(y)
    return x + self.drop_path(self.mlp(self.norm2(x)))
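
As a usage sketch (the import path is assumed from the source file shown above), a Block can be applied to a sequence of token embeddings as follows; the output has the same shape as the input, and passing return_attention=True returns the attention scores instead.

import torch
from mmlearn.modules.layers.transformer_block import Block  # import path assumed from the source link above

block = Block(dim=192, num_heads=3, mlp_ratio=4.0, drop_path=0.1)

tokens = torch.randn(4, 197, 192)              # (batch, 1 CLS + 196 patches, dim)
out = block(tokens)                            # same shape as the input
attn = block(tokens, return_attention=True)    # attention scores from the self-attention layer
print(out.shape)  # torch.Size([4, 197, 192])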

drop_path

drop_path(x, drop_prob=0.0, training=False)

Drop paths (Stochastic Depth) for regularization during training.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
drop_prob float

Probability of dropping paths.

0.0
training bool

Whether the model is in training mode.

False

Returns:

Name Type Description
output Tensor

Output tensor after applying drop path.

Source code in mmlearn/modules/layers/transformer_block.py
def drop_path(
    x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """Drop paths (Stochastic Depth) for regularization during training.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    drop_prob : float, optional, default=0.0
        Probability of dropping paths.
    training : bool, optional, default=False
        Whether the model is in training mode.

    Returns
    -------
    output : torch.Tensor
        Output tensor after applying drop path.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    return x.div(keep_prob) * random_tensor
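
The effect of drop_path is easiest to see on a small example: during training, each sample in the batch is either zeroed out entirely or rescaled by 1 / keep_prob so that the expected value of the output matches the input; in evaluation mode the input passes through unchanged. A brief sketch (import path assumed from the source link above):

import torch
from mmlearn.modules.layers.transformer_block import drop_path  # import path assumed

x = torch.ones(6, 3, 4)                           # 6 samples in the batch
out = drop_path(x, drop_prob=0.5, training=True)

# Roughly half of the samples are zeroed; survivors are scaled by 1 / 0.5 = 2.
print(out[:, 0, 0])                               # e.g. tensor([2., 0., 2., 2., 0., 0.])

# In evaluation mode the function is the identity.
print(drop_path(x, drop_prob=0.5, training=False).equal(x))  # True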

Losses

mmlearn.modules.losses

Loss functions.

ContrastiveLoss

Bases: Module

Contrastive Loss.

Parameters:

Name Type Description Default
l2_normalize bool

Whether to L2 normalize the features.

False
local_loss bool

Whether to calculate the loss locally, i.e. local_features@global_features.

False
gather_with_grad bool

Whether to gather tensors with gradients.

False
modality_alignment bool

Whether to include modality alignment loss. This loss considers all features from the same modality as positive pairs and all features from different modalities as negative pairs.

False
cache_labels bool

Whether to cache the labels.

False
Source code in mmlearn/modules/losses/contrastive.py
@store(group="modules/losses", provider="mmlearn")
class ContrastiveLoss(nn.Module):
    """Contrastive Loss.

    Parameters
    ----------
    l2_normalize : bool, optional, default=False
        Whether to L2 normalize the features.
    local_loss : bool, optional, default=False
        Whether to calculate the loss locally i.e. ``local_features@global_features``.
    gather_with_grad : bool, optional, default=False
        Whether to gather tensors with gradients.
    modality_alignment : bool, optional, default=False
        Whether to include modality alignment loss. This loss considers all features
        from the same modality as positive pairs and all features from different
        modalities as negative pairs.
    cache_labels : bool, optional, default=False
        Whether to cache the labels.

    """

    def __init__(
        self,
        l2_normalize: bool = False,
        local_loss: bool = False,
        gather_with_grad: bool = False,
        modality_alignment: bool = False,
        cache_labels: bool = False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.l2_normalize = l2_normalize
        self.modality_alignment = modality_alignment

        # cache state
        self._prev_num_logits = 0
        self._labels: dict[torch.device, torch.Tensor] = {}

    def forward(
        self,
        embeddings: dict[str, torch.Tensor],
        example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        modality_loss_pairs: list[LossPairSpec],
    ) -> torch.Tensor:
        """Calculate the contrastive loss.

        Parameters
        ----------
        embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        modality_loss_pairs : list[LossPairSpec]
            Specification of the modality pairs for which the loss should be calculated.

        Returns
        -------
        torch.Tensor
            The contrastive loss.
        """
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if world_size > 1 else 0

        if self.l2_normalize:
            embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

        if world_size > 1:  # gather embeddings and example_ids across all processes
            # NOTE: gathering dictionaries of tensors across all processes
            # (keys + values, as opposed to just values) is especially important
            # for the modality_alignment loss, which requires all embeddings
            all_embeddings = _gather_dicts(
                embeddings,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
            all_example_ids = _gather_dicts(
                example_ids,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
        else:
            all_embeddings = embeddings
            all_example_ids = example_ids

        losses = []
        for loss_pairs in modality_loss_pairs:
            logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
                loss_pairs.modalities,
                per_device_embeddings=embeddings,
                all_embeddings=all_embeddings,
                per_device_example_ids=example_ids,
                all_example_ids=all_example_ids,
                logit_scale=logit_scale,
                world_size=world_size,
            )
            if logits_per_feature_a is None or logits_per_feature_b is None:
                continue

            labels = self._get_ground_truth(
                logits_per_feature_a.shape,
                device=logits_per_feature_a.device,
                rank=rank,
                world_size=world_size,
                skipped_process=skip_flag,
            )

            if labels.numel() != 0:
                losses.append(
                    (
                        (
                            F.cross_entropy(logits_per_feature_a, labels)
                            + F.cross_entropy(logits_per_feature_b, labels)
                        )
                        / 2
                    )
                    * loss_pairs.weight
                )

        if self.modality_alignment:
            losses.append(
                self._compute_modality_alignment_loss(all_embeddings, logit_scale)
            )

        if not losses:  # no loss to compute (e.g. no paired data in batch)
            losses.append(
                torch.tensor(
                    0.0,
                    device=logit_scale.device,
                    dtype=next(iter(embeddings.values())).dtype,
                )
            )

        return torch.stack(losses).sum()

    def _get_ground_truth(
        self,
        logits_shape: tuple[int, int],
        device: torch.device,
        rank: int,
        world_size: int,
        skipped_process: bool,
    ) -> torch.Tensor:
        """Return the ground-truth labels.

        Parameters
        ----------
        logits_shape : tuple[int, int]
            Shape of the logits tensor.
        device : torch.device
            Device on which the labels should be created.
        rank : int
            Rank of the current process.
        world_size : int
            Number of processes.
        skipped_process : bool
            Whether the current process skipped the computation due to lack of data.

        Returns
        -------
        torch.Tensor
            Ground-truth labels.
        """
        num_logits = logits_shape[-1]

        # calculate ground-truth and cache if enabled
        if self._prev_num_logits != num_logits or device not in self._labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)

            if world_size > 1 and self.local_loss:
                local_size = torch.tensor(
                    0 if skipped_process else logits_shape[0], device=device
                )
                # NOTE: all processes must participate in the all_gather operation
                # even if they have no data to contribute.
                sizes = torch.stack(
                    _simple_gather_all_tensors(
                        local_size, group=dist.group.WORLD, world_size=world_size
                    )
                )
                sizes = torch.cat(
                    [torch.tensor([0], device=sizes.device), torch.cumsum(sizes, dim=0)]
                )
                labels = labels[
                    sizes[rank] : sizes[rank + 1] if rank + 1 < world_size else None
                ]

            if self.cache_labels:
                self._labels[device] = labels
                self._prev_num_logits = num_logits
        else:
            labels = self._labels[device]
        return labels

    def _get_logits(  # noqa: PLR0912
        self,
        modalities: tuple[str, str],
        per_device_embeddings: dict[str, torch.Tensor],
        all_embeddings: dict[str, torch.Tensor],
        per_device_example_ids: dict[str, torch.Tensor],
        all_example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        world_size: int,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
        """Calculate the logits for the given modalities.

        Parameters
        ----------
        modalities : tuple[str, str]
            Tuple of modality names.
        per_device_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor. In distributed mode, this contains
            embeddings from all processes.
        per_device_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        all_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index. In distributed
            mode, this contains example IDs from all processes.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        world_size : int
            Number of processes.

        Returns
        -------
        tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]
            Tuple of logits for the given modalities. If embeddings for the given
            modalities are not available, returns `None` for the logits. The last
            element is a flag indicating whether the process skipped the computation
            due to lack of data.
        """
        modality_a = Modalities.get_modality(modalities[0])
        modality_b = Modalities.get_modality(modalities[1])
        skip_flag = False

        if self.local_loss or world_size == 1:
            if not (
                modality_a.embedding in per_device_embeddings
                and modality_b.embedding in per_device_embeddings
            ):
                if world_size > 1:  # NOTE: not all processes exit here, hence skip_flag
                    skip_flag = True
                else:
                    return None, None, skip_flag

            if not skip_flag:
                indices_a, indices_b = find_matching_indices(
                    per_device_example_ids[modality_a.name],
                    per_device_example_ids[modality_b.name],
                )
                if indices_a.numel() == 0 or indices_b.numel() == 0:
                    if world_size > 1:  # not all processes exit here
                        skip_flag = True
                    else:
                        return None, None, skip_flag

            if not skip_flag:
                features_a = per_device_embeddings[modality_a.embedding][indices_a]
                features_b = per_device_embeddings[modality_b.embedding][indices_b]
            else:
                # all processes must participate in the all_gather operation
                # that follows, even if they have no data to contribute. So,
                # we create empty tensors here.
                features_a = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )
                features_b = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )

        if world_size > 1:
            if not (
                modality_a.embedding in all_embeddings
                and modality_b.embedding in all_embeddings
            ):  # all processes exit here
                return None, None, skip_flag

            indices_a, indices_b = find_matching_indices(
                all_example_ids[modality_a.name],
                all_example_ids[modality_b.name],
            )
            if indices_a.numel() == 0 or indices_b.numel() == 0:
                # all processes exit here
                return None, None, skip_flag

            all_features_a = all_embeddings[modality_a.embedding][indices_a]
            all_features_b = all_embeddings[modality_b.embedding][indices_b]

            if self.local_loss:
                if features_a.numel() == 0:
                    features_a = all_features_a
                if features_b.numel() == 0:
                    features_b = all_features_b

                logits_per_feature_a = logit_scale * _safe_matmul(
                    features_a, all_features_b
                )
                logits_per_feature_b = logit_scale * _safe_matmul(
                    features_b, all_features_a
                )
            else:
                logits_per_feature_a = logit_scale * _safe_matmul(
                    all_features_a, all_features_b
                )
                logits_per_feature_b = logits_per_feature_a.T
        else:
            logits_per_feature_a = logit_scale * _safe_matmul(features_a, features_b)
            logits_per_feature_b = logit_scale * _safe_matmul(features_b, features_a)

        return logits_per_feature_a, logits_per_feature_b, skip_flag

    def _compute_modality_alignment_loss(
        self, all_embeddings: dict[str, torch.Tensor], logit_scale: torch.Tensor
    ) -> torch.Tensor:
        """Compute the modality alignment loss.

        This loss considers all features from the same modality as positive pairs
        and all features from different modalities as negative pairs.

        Parameters
        ----------
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        logit_scale : torch.Tensor
            Scale factor for the logits.

        Returns
        -------
        torch.Tensor
            Modality alignment loss.

        Notes
        -----
        This loss does not support `local_loss=True`.
        """
        available_modalities = list(all_embeddings.keys())
        # TODO: support local_loss for modality_alignment?
        # if world_size == 1, all_embeddings == embeddings
        all_features = torch.cat(list(all_embeddings.values()), dim=0)

        positive_indices = torch.tensor(
            [
                (i, j)
                if idx == 0
                else (
                    i + all_embeddings[available_modalities[idx - 1]].size(0),
                    j + all_embeddings[available_modalities[idx - 1]].size(0),
                )
                for idx, k in enumerate(all_embeddings)
                for i, j in itertools.combinations(range(all_embeddings[k].size(0)), 2)
            ],
            device=all_features.device,
        )
        logits = logit_scale * _safe_matmul(all_features, all_features)

        target = torch.eye(all_features.size(0), device=all_features.device)
        target[positive_indices[:, 0], positive_indices[:, 1]] = 1

        modality_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, target, reduction="none"
        )

        target_pos = target.bool()
        target_neg = ~target_pos

        # loss_pos and loss_neg below contain non-zero values only for those
        # elements that are positive pairs and negative pairs respectively.
        loss_pos = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_pos, modality_loss[target_pos])
        loss_neg = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_neg, modality_loss[target_neg])

        loss_pos = loss_pos.sum(dim=1)
        loss_neg = loss_neg.sum(dim=1)
        num_pos = target.sum(dim=1)
        num_neg = logits.size(0) - num_pos

        return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()

forward

forward(
    embeddings,
    example_ids,
    logit_scale,
    modality_loss_pairs,
)

Calculate the contrastive loss.

Parameters:

Name Type Description Default
embeddings dict[str, Tensor]

Dictionary of embeddings, where the key is the modality name and the value is the corresponding embedding tensor.

required
example_ids dict[str, Tensor]

Dictionary of example IDs, where the key is the modality name and the value is a tensor tuple of the dataset index and the example index.

required
logit_scale Tensor

Scale factor for the logits.

required
modality_loss_pairs list[LossPairSpec]

Specification of the modality pairs for which the loss should be calculated.

required

Returns:

Type Description
Tensor

The contrastive loss.

Source code in mmlearn/modules/losses/contrastive.py
def forward(
    self,
    embeddings: dict[str, torch.Tensor],
    example_ids: dict[str, torch.Tensor],
    logit_scale: torch.Tensor,
    modality_loss_pairs: list[LossPairSpec],
) -> torch.Tensor:
    """Calculate the contrastive loss.

    Parameters
    ----------
    embeddings : dict[str, torch.Tensor]
        Dictionary of embeddings, where the key is the modality name and the value
        is the corresponding embedding tensor.
    example_ids : dict[str, torch.Tensor]
        Dictionary of example IDs, where the key is the modality name and the value
        is a tensor tuple of the dataset index and the example index.
    logit_scale : torch.Tensor
        Scale factor for the logits.
    modality_loss_pairs : list[LossPairSpec]
        Specification of the modality pairs for which the loss should be calculated.

    Returns
    -------
    torch.Tensor
        The contrastive loss.
    """
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if world_size > 1 else 0

    if self.l2_normalize:
        embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

    if world_size > 1:  # gather embeddings and example_ids across all processes
        # NOTE: gathering dictionaries of tensors across all processes
        # (keys + values, as opposed to just values) is especially important
        # for the modality_alignment loss, which requires all embeddings
        all_embeddings = _gather_dicts(
            embeddings,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
        all_example_ids = _gather_dicts(
            example_ids,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
    else:
        all_embeddings = embeddings
        all_example_ids = example_ids

    losses = []
    for loss_pairs in modality_loss_pairs:
        logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
            loss_pairs.modalities,
            per_device_embeddings=embeddings,
            all_embeddings=all_embeddings,
            per_device_example_ids=example_ids,
            all_example_ids=all_example_ids,
            logit_scale=logit_scale,
            world_size=world_size,
        )
        if logits_per_feature_a is None or logits_per_feature_b is None:
            continue

        labels = self._get_ground_truth(
            logits_per_feature_a.shape,
            device=logits_per_feature_a.device,
            rank=rank,
            world_size=world_size,
            skipped_process=skip_flag,
        )

        if labels.numel() != 0:
            losses.append(
                (
                    (
                        F.cross_entropy(logits_per_feature_a, labels)
                        + F.cross_entropy(logits_per_feature_b, labels)
                    )
                    / 2
                )
                * loss_pairs.weight
            )

    if self.modality_alignment:
        losses.append(
            self._compute_modality_alignment_loss(all_embeddings, logit_scale)
        )

    if not losses:  # no loss to compute (e.g. no paired data in batch)
        losses.append(
            torch.tensor(
                0.0,
                device=logit_scale.device,
                dtype=next(iter(embeddings.values())).dtype,
            )
        )

    return torch.stack(losses).sum()
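
For a single modality pair on one process with fully matched examples, the per-pair term above reduces to the familiar CLIP-style symmetric cross-entropy. The following self-contained sketch reproduces that core computation in plain PyTorch; it illustrates the formula only and is not a substitute for the class, which additionally handles example matching, distributed gathering, and the optional modality alignment term.

import torch
import torch.nn.functional as F

batch_size, dim = 8, 128
feats_a = F.normalize(torch.randn(batch_size, dim), p=2, dim=-1)  # e.g. image embeddings
feats_b = F.normalize(torch.randn(batch_size, dim), p=2, dim=-1)  # e.g. text embeddings
logit_scale = torch.tensor(100.0)  # exponentiated learnable temperature

logits_per_a = logit_scale * feats_a @ feats_b.T   # (batch_size, batch_size)
logits_per_b = logits_per_a.T

# The i-th row of feats_a is paired with the i-th row of feats_b,
# so the ground-truth "class" for row i is simply i.
labels = torch.arange(batch_size)

loss = (F.cross_entropy(logits_per_a, labels) + F.cross_entropy(logits_per_b, labels)) / 2
print(loss.item())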

Data2VecLoss

Bases: Module

Data2Vec loss function.

Parameters:

Name Type Description Default
beta float

Specifies the beta parameter for smooth L1 loss. If 0, MSE loss is used.

0
loss_scale Optional[float]

Scaling factor for the loss. If None, uses 1 / sqrt(embedding_dim).

None
reduction str

Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.

'none'

Raises:

Type Description
ValueError

If the reduction mode is not supported.

Source code in mmlearn/modules/losses/data2vec.py
@store(group="modules/losses", provider="mmlearn")
class Data2VecLoss(nn.Module):
    """Data2Vec loss function.

    Parameters
    ----------
    beta : float, optional, default=0
        Specifies the beta parameter for smooth L1 loss. If ``0``, MSE loss is used.
    loss_scale : Optional[float], optional, default=None
        Scaling factor for the loss. If None, uses ``1 / sqrt(embedding_dim)``.
    reduction : str, optional, default='none'
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``.

    Raises
    ------
    ValueError
        If the reduction mode is not supported.
    """

    def __init__(
        self,
        beta: float = 0,
        loss_scale: Optional[float] = None,
        reduction: str = "none",
    ) -> None:
        super().__init__()
        self.beta = beta
        self.loss_scale = loss_scale
        if reduction not in ["none", "mean", "sum"]:
            raise ValueError(f"Unsupported reduction mode: {reduction}")
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the Data2Vec loss.

        Parameters
        ----------
        x : torch.Tensor
            Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
        y : torch.Tensor
            Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

        Returns
        -------
        torch.Tensor
            Data2Vec loss value.

        Raises
        ------
        ValueError
            If the shapes of x and y do not match.
        """
        if x.shape != y.shape:
            raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, y.size(-1))

        if self.beta == 0:
            loss = mse_loss(x, y, reduction="none")
        else:
            loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        loss = loss * scale

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # 'none'
        return loss.view(x.size(0), -1).sum(1)

forward

forward(x, y)

Compute the Data2Vec loss.

Parameters:

Name Type Description Default
x Tensor

Predicted embeddings of shape (batch_size, num_patches, embedding_dim).

required
y Tensor

Target embeddings of shape (batch_size, num_patches, embedding_dim).

required

Returns:

Type Description
Tensor

Data2Vec loss value.

Raises:

Type Description
ValueError

If the shapes of x and y do not match.

Source code in mmlearn/modules/losses/data2vec.py
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the Data2Vec loss.

    Parameters
    ----------
    x : torch.Tensor
        Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
    y : torch.Tensor
        Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

    Returns
    -------
    torch.Tensor
        Data2Vec loss value.

    Raises
    ------
    ValueError
        If the shapes of x and y do not match.
    """
    if x.shape != y.shape:
        raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

    x = x.view(-1, x.size(-1)).float()
    y = y.view(-1, y.size(-1))

    if self.beta == 0:
        loss = mse_loss(x, y, reduction="none")
    else:
        loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

    if self.loss_scale is not None:
        scale = self.loss_scale
    else:
        scale = 1 / math.sqrt(x.size(-1))

    loss = loss * scale

    if self.reduction == "mean":
        return loss.mean()
    if self.reduction == "sum":
        return loss.sum()
    # 'none'
    return loss.view(x.size(0), -1).sum(1)
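
A minimal usage sketch, assuming Data2VecLoss is importable from the losses module documented above; the shapes are illustrative values.

import torch
from mmlearn.modules.losses import Data2VecLoss  # import path assumed from the module listing above

loss_fn = Data2VecLoss(beta=2.0, reduction="mean")  # beta > 0 -> smooth L1; beta == 0 -> MSE

preds = torch.randn(4, 196, 768)     # (batch_size, num_patches, embedding_dim)
targets = torch.randn(4, 196, 768)

loss = loss_fn(preds, targets)       # scaled by 1 / sqrt(embedding_dim) by default
print(loss)                          # a scalar, since reduction="mean"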

contrastive

Implementations of the contrastive loss and its variants.

ContrastiveLoss

Bases: Module

Contrastive Loss.

Parameters:

Name Type Description Default
l2_normalize bool

Whether to L2 normalize the features.

False
local_loss bool

Whether to calculate the loss locally, i.e. local_features@global_features.

False
gather_with_grad bool

Whether to gather tensors with gradients.

False
modality_alignment bool

Whether to include modality alignment loss. This loss considers all features from the same modality as positive pairs and all features from different modalities as negative pairs.

False
cache_labels bool

Whether to cache the labels.

False
Source code in mmlearn/modules/losses/contrastive.py
@store(group="modules/losses", provider="mmlearn")
class ContrastiveLoss(nn.Module):
    """Contrastive Loss.

    Parameters
    ----------
    l2_normalize : bool, optional, default=False
        Whether to L2 normalize the features.
    local_loss : bool, optional, default=False
        Whether to calculate the loss locally i.e. ``local_features@global_features``.
    gather_with_grad : bool, optional, default=False
        Whether to gather tensors with gradients.
    modality_alignment : bool, optional, default=False
        Whether to include modality alignment loss. This loss considers all features
        from the same modality as positive pairs and all features from different
        modalities as negative pairs.
    cache_labels : bool, optional, default=False
        Whether to cache the labels.

    """

    def __init__(
        self,
        l2_normalize: bool = False,
        local_loss: bool = False,
        gather_with_grad: bool = False,
        modality_alignment: bool = False,
        cache_labels: bool = False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.l2_normalize = l2_normalize
        self.modality_alignment = modality_alignment

        # cache state
        self._prev_num_logits = 0
        self._labels: dict[torch.device, torch.Tensor] = {}

    def forward(
        self,
        embeddings: dict[str, torch.Tensor],
        example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        modality_loss_pairs: list[LossPairSpec],
    ) -> torch.Tensor:
        """Calculate the contrastive loss.

        Parameters
        ----------
        embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        modality_loss_pairs : list[LossPairSpec]
            Specification of the modality pairs for which the loss should be calculated.

        Returns
        -------
        torch.Tensor
            The contrastive loss.
        """
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if world_size > 1 else 0

        if self.l2_normalize:
            embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

        if world_size > 1:  # gather embeddings and example_ids across all processes
            # NOTE: gathering dictionaries of tensors across all processes
            # (keys + values, as opposed to just values) is especially important
            # for the modality_alignment loss, which requires all embeddings
            all_embeddings = _gather_dicts(
                embeddings,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
            all_example_ids = _gather_dicts(
                example_ids,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=rank,
            )
        else:
            all_embeddings = embeddings
            all_example_ids = example_ids

        losses = []
        for loss_pairs in modality_loss_pairs:
            logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
                loss_pairs.modalities,
                per_device_embeddings=embeddings,
                all_embeddings=all_embeddings,
                per_device_example_ids=example_ids,
                all_example_ids=all_example_ids,
                logit_scale=logit_scale,
                world_size=world_size,
            )
            if logits_per_feature_a is None or logits_per_feature_b is None:
                continue

            labels = self._get_ground_truth(
                logits_per_feature_a.shape,
                device=logits_per_feature_a.device,
                rank=rank,
                world_size=world_size,
                skipped_process=skip_flag,
            )

            if labels.numel() != 0:
                losses.append(
                    (
                        (
                            F.cross_entropy(logits_per_feature_a, labels)
                            + F.cross_entropy(logits_per_feature_b, labels)
                        )
                        / 2
                    )
                    * loss_pairs.weight
                )

        if self.modality_alignment:
            losses.append(
                self._compute_modality_alignment_loss(all_embeddings, logit_scale)
            )

        if not losses:  # no loss to compute (e.g. no paired data in batch)
            losses.append(
                torch.tensor(
                    0.0,
                    device=logit_scale.device,
                    dtype=next(iter(embeddings.values())).dtype,
                )
            )

        return torch.stack(losses).sum()

    def _get_ground_truth(
        self,
        logits_shape: tuple[int, int],
        device: torch.device,
        rank: int,
        world_size: int,
        skipped_process: bool,
    ) -> torch.Tensor:
        """Return the ground-truth labels.

        Parameters
        ----------
        logits_shape : tuple[int, int]
            Shape of the logits tensor.
        device : torch.device
            Device on which the labels should be created.
        rank : int
            Rank of the current process.
        world_size : int
            Number of processes.
        skipped_process : bool
            Whether the current process skipped the computation due to lack of data.

        Returns
        -------
        torch.Tensor
            Ground-truth labels.
        """
        num_logits = logits_shape[-1]

        # calculate ground-truth and cache if enabled
        if self._prev_num_logits != num_logits or device not in self._labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)

            if world_size > 1 and self.local_loss:
                local_size = torch.tensor(
                    0 if skipped_process else logits_shape[0], device=device
                )
                # NOTE: all processes must participate in the all_gather operation
                # even if they have no data to contribute.
                sizes = torch.stack(
                    _simple_gather_all_tensors(
                        local_size, group=dist.group.WORLD, world_size=world_size
                    )
                )
                sizes = torch.cat(
                    [torch.tensor([0], device=sizes.device), torch.cumsum(sizes, dim=0)]
                )
                labels = labels[
                    sizes[rank] : sizes[rank + 1] if rank + 1 < world_size else None
                ]

            if self.cache_labels:
                self._labels[device] = labels
                self._prev_num_logits = num_logits
        else:
            labels = self._labels[device]
        return labels

    def _get_logits(  # noqa: PLR0912
        self,
        modalities: tuple[str, str],
        per_device_embeddings: dict[str, torch.Tensor],
        all_embeddings: dict[str, torch.Tensor],
        per_device_example_ids: dict[str, torch.Tensor],
        all_example_ids: dict[str, torch.Tensor],
        logit_scale: torch.Tensor,
        world_size: int,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
        """Calculate the logits for the given modalities.

        Parameters
        ----------
        modalities : tuple[str, str]
            Tuple of modality names.
        per_device_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor. In distributed mode, this contains
            embeddings from all processes.
        per_device_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index.
        all_example_ids : dict[str, torch.Tensor]
            Dictionary of example IDs, where the key is the modality name and the value
            is a tensor tuple of the dataset index and the example index. In distributed
            mode, this contains example IDs from all processes.
        logit_scale : torch.Tensor
            Scale factor for the logits.
        world_size : int
            Number of processes.

        Returns
        -------
        tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]
            Tuple of logits for the given modalities. If embeddings for the given
            modalities are not available, returns `None` for the logits. The last
            element is a flag indicating whether the process skipped the computation
            due to lack of data.
        """
        modality_a = Modalities.get_modality(modalities[0])
        modality_b = Modalities.get_modality(modalities[1])
        skip_flag = False

        if self.local_loss or world_size == 1:
            if not (
                modality_a.embedding in per_device_embeddings
                and modality_b.embedding in per_device_embeddings
            ):
                if world_size > 1:  # NOTE: not all processes exit here, hence skip_flag
                    skip_flag = True
                else:
                    return None, None, skip_flag

            if not skip_flag:
                indices_a, indices_b = find_matching_indices(
                    per_device_example_ids[modality_a.name],
                    per_device_example_ids[modality_b.name],
                )
                if indices_a.numel() == 0 or indices_b.numel() == 0:
                    if world_size > 1:  # not all processes exit here
                        skip_flag = True
                    else:
                        return None, None, skip_flag

            if not skip_flag:
                features_a = per_device_embeddings[modality_a.embedding][indices_a]
                features_b = per_device_embeddings[modality_b.embedding][indices_b]
            else:
                # all processes must participate in the all_gather operation
                # that follows, even if they have no data to contribute. So,
                # we create empty tensors here.
                features_a = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )
                features_b = torch.empty(
                    0, device=list(per_device_embeddings.values())[0].device
                )

        if world_size > 1:
            if not (
                modality_a.embedding in all_embeddings
                and modality_b.embedding in all_embeddings
            ):  # all processes exit here
                return None, None, skip_flag

            indices_a, indices_b = find_matching_indices(
                all_example_ids[modality_a.name],
                all_example_ids[modality_b.name],
            )
            if indices_a.numel() == 0 or indices_b.numel() == 0:
                # all processes exit here
                return None, None, skip_flag

            all_features_a = all_embeddings[modality_a.embedding][indices_a]
            all_features_b = all_embeddings[modality_b.embedding][indices_b]

            if self.local_loss:
                if features_a.numel() == 0:
                    features_a = all_features_a
                if features_b.numel() == 0:
                    features_b = all_features_b

                logits_per_feature_a = logit_scale * _safe_matmul(
                    features_a, all_features_b
                )
                logits_per_feature_b = logit_scale * _safe_matmul(
                    features_b, all_features_a
                )
            else:
                logits_per_feature_a = logit_scale * _safe_matmul(
                    all_features_a, all_features_b
                )
                logits_per_feature_b = logits_per_feature_a.T
        else:
            logits_per_feature_a = logit_scale * _safe_matmul(features_a, features_b)
            logits_per_feature_b = logit_scale * _safe_matmul(features_b, features_a)

        return logits_per_feature_a, logits_per_feature_b, skip_flag

    def _compute_modality_alignment_loss(
        self, all_embeddings: dict[str, torch.Tensor], logit_scale: torch.Tensor
    ) -> torch.Tensor:
        """Compute the modality alignment loss.

        This loss considers all features from the same modality as positive pairs
        and all features from different modalities as negative pairs.

        Parameters
        ----------
        all_embeddings : dict[str, torch.Tensor]
            Dictionary of embeddings, where the key is the modality name and the value
            is the corresponding embedding tensor.
        logit_scale : torch.Tensor
            Scale factor for the logits.

        Returns
        -------
        torch.Tensor
            Modality alignment loss.

        Notes
        -----
        This loss does not support `local_loss=True`.
        """
        available_modalities = list(all_embeddings.keys())
        # TODO: support local_loss for modality_alignment?
        # if world_size == 1, all_embeddings == embeddings
        all_features = torch.cat(list(all_embeddings.values()), dim=0)

        positive_indices = torch.tensor(
            [
                (i, j)
                if idx == 0
                else (
                    i + all_embeddings[available_modalities[idx - 1]].size(0),
                    j + all_embeddings[available_modalities[idx - 1]].size(0),
                )
                for idx, k in enumerate(all_embeddings)
                for i, j in itertools.combinations(range(all_embeddings[k].size(0)), 2)
            ],
            device=all_features.device,
        )
        logits = logit_scale * _safe_matmul(all_features, all_features)

        target = torch.eye(all_features.size(0), device=all_features.device)
        target[positive_indices[:, 0], positive_indices[:, 1]] = 1

        modality_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, target, reduction="none"
        )

        target_pos = target.bool()
        target_neg = ~target_pos

        # loss_pos and loss_neg below contain non-zero values only for those
        # elements that are positive pairs and negative pairs respectively.
        loss_pos = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_pos, modality_loss[target_pos])
        loss_neg = torch.zeros(
            logits.size(0), logits.size(0), device=target.device
        ).masked_scatter(target_neg, modality_loss[target_neg])

        loss_pos = loss_pos.sum(dim=1)
        loss_neg = loss_neg.sum(dim=1)
        num_pos = target.sum(dim=1)
        num_neg = logits.size(0) - num_pos

        return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()
forward
forward(
    embeddings,
    example_ids,
    logit_scale,
    modality_loss_pairs,
)

Calculate the contrastive loss.

Parameters:

Name Type Description Default
embeddings dict[str, Tensor]

Dictionary of embeddings, where the key is the modality name and the value is the corresponding embedding tensor.

required
example_ids dict[str, Tensor]

Dictionary of example IDs, where the key is the modality name and the value is a tensor tuple of the dataset index and the example index.

required
logit_scale Tensor

Scale factor for the logits.

required
modality_loss_pairs list[LossPairSpec]

Specification of the modality pairs for which the loss should be calculated.

required

Returns:

Type Description
Tensor

The contrastive loss.

Source code in mmlearn/modules/losses/contrastive.py
def forward(
    self,
    embeddings: dict[str, torch.Tensor],
    example_ids: dict[str, torch.Tensor],
    logit_scale: torch.Tensor,
    modality_loss_pairs: list[LossPairSpec],
) -> torch.Tensor:
    """Calculate the contrastive loss.

    Parameters
    ----------
    embeddings : dict[str, torch.Tensor]
        Dictionary of embeddings, where the key is the modality name and the value
        is the corresponding embedding tensor.
    example_ids : dict[str, torch.Tensor]
        Dictionary of example IDs, where the key is the modality name and the value
        is a tensor tuple of the dataset index and the example index.
    logit_scale : torch.Tensor
        Scale factor for the logits.
    modality_loss_pairs : list[LossPairSpec]
        Specification of the modality pairs for which the loss should be calculated.

    Returns
    -------
    torch.Tensor
        The contrastive loss.
    """
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if world_size > 1 else 0

    if self.l2_normalize:
        embeddings = {k: F.normalize(v, p=2, dim=-1) for k, v in embeddings.items()}

    if world_size > 1:  # gather embeddings and example_ids across all processes
        # NOTE: gathering dictionaries of tensors across all processes
        # (keys + values, as opposed to just values) is especially important
        # for the modality_alignment loss, which requires all embeddings
        all_embeddings = _gather_dicts(
            embeddings,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
        all_example_ids = _gather_dicts(
            example_ids,
            local_loss=self.local_loss,
            gather_with_grad=self.gather_with_grad,
            rank=rank,
        )
    else:
        all_embeddings = embeddings
        all_example_ids = example_ids

    losses = []
    for loss_pairs in modality_loss_pairs:
        logits_per_feature_a, logits_per_feature_b, skip_flag = self._get_logits(
            loss_pairs.modalities,
            per_device_embeddings=embeddings,
            all_embeddings=all_embeddings,
            per_device_example_ids=example_ids,
            all_example_ids=all_example_ids,
            logit_scale=logit_scale,
            world_size=world_size,
        )
        if logits_per_feature_a is None or logits_per_feature_b is None:
            continue

        labels = self._get_ground_truth(
            logits_per_feature_a.shape,
            device=logits_per_feature_a.device,
            rank=rank,
            world_size=world_size,
            skipped_process=skip_flag,
        )

        if labels.numel() != 0:
            losses.append(
                (
                    (
                        F.cross_entropy(logits_per_feature_a, labels)
                        + F.cross_entropy(logits_per_feature_b, labels)
                    )
                    / 2
                )
                * loss_pairs.weight
            )

    if self.modality_alignment:
        losses.append(
            self._compute_modality_alignment_loss(all_embeddings, logit_scale)
        )

    if not losses:  # no loss to compute (e.g. no paired data in batch)
        losses.append(
            torch.tensor(
                0.0,
                device=logit_scale.device,
                dtype=next(iter(embeddings.values())).dtype,
            )
        )

    return torch.stack(losses).sum()

data2vec

Implementation of Data2vec loss function.

Data2VecLoss

Bases: Module

Data2Vec loss function.

Parameters:

Name Type Description Default
beta float

Specifies the beta parameter for smooth L1 loss. If 0, MSE loss is used.

0
loss_scale Optional[float]

Scaling factor for the loss. If None, uses 1 / sqrt(embedding_dim).

None
reduction str

Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.

'none'

Raises:

Type Description
ValueError

If the reduction mode is not supported.

Source code in mmlearn/modules/losses/data2vec.py
@store(group="modules/losses", provider="mmlearn")
class Data2VecLoss(nn.Module):
    """Data2Vec loss function.

    Parameters
    ----------
    beta : float, optional, default=0
        Specifies the beta parameter for smooth L1 loss. If ``0``, MSE loss is used.
    loss_scale : Optional[float], optional, default=None
        Scaling factor for the loss. If None, uses ``1 / sqrt(embedding_dim)``.
    reduction : str, optional, default='none'
        Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``.

    Raises
    ------
    ValueError
        If the reduction mode is not supported.
    """

    def __init__(
        self,
        beta: float = 0,
        loss_scale: Optional[float] = None,
        reduction: str = "none",
    ) -> None:
        super().__init__()
        self.beta = beta
        self.loss_scale = loss_scale
        if reduction not in ["none", "mean", "sum"]:
            raise ValueError(f"Unsupported reduction mode: {reduction}")
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the Data2Vec loss.

        Parameters
        ----------
        x : torch.Tensor
            Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
        y : torch.Tensor
            Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

        Returns
        -------
        torch.Tensor
            Data2Vec loss value.

        Raises
        ------
        ValueError
            If the shapes of x and y do not match.
        """
        if x.shape != y.shape:
            raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

        x = x.view(-1, x.size(-1)).float()
        y = y.view(-1, y.size(-1))

        if self.beta == 0:
            loss = mse_loss(x, y, reduction="none")
        else:
            loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(x.size(-1))

        loss = loss * scale

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # 'none'
        return loss.view(x.size(0), -1).sum(1)
forward
forward(x, y)

Compute the Data2Vec loss.

Parameters:

Name Type Description Default
x Tensor

Predicted embeddings of shape (batch_size, num_patches, embedding_dim).

required
y Tensor

Target embeddings of shape (batch_size, num_patches, embedding_dim).

required

Returns:

Type Description
Tensor

Data2Vec loss value.

Raises:

Type Description
ValueError

If the shapes of x and y do not match.

Source code in mmlearn/modules/losses/data2vec.py
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the Data2Vec loss.

    Parameters
    ----------
    x : torch.Tensor
        Predicted embeddings of shape ``(batch_size, num_patches, embedding_dim)``.
    y : torch.Tensor
        Target embeddings of shape ``(batch_size, num_patches, embedding_dim)``.

    Returns
    -------
    torch.Tensor
        Data2Vec loss value.

    Raises
    ------
    ValueError
        If the shapes of x and y do not match.
    """
    if x.shape != y.shape:
        raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

    x = x.view(-1, x.size(-1)).float()
    y = y.view(-1, y.size(-1))

    if self.beta == 0:
        loss = mse_loss(x, y, reduction="none")
    else:
        loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

    if self.loss_scale is not None:
        scale = self.loss_scale
    else:
        scale = 1 / math.sqrt(x.size(-1))

    loss = loss * scale

    if self.reduction == "mean":
        return loss.mean()
    if self.reduction == "sum":
        return loss.sum()
    # 'none'
    return loss.view(x.size(0), -1).sum(1)

Learning Rate Schedulers

mmlearn.modules.lr_schedulers

Learning rate schedulers for training models.

linear_warmup_cosine_annealing_lr

linear_warmup_cosine_annealing_lr(
    optimizer,
    warmup_steps,
    max_steps,
    start_factor=1 / 3,
    eta_min=0.0,
    last_epoch=-1,
)

Create a linear warmup cosine annealing learning rate scheduler.

Parameters:

Name Type Description Default
optimizer Optimizer

The optimizer for which to schedule the learning rate.

required
warmup_steps int

Maximum number of iterations for linear warmup.

required
max_steps int

Maximum number of iterations.

required
start_factor float

Multiplicative factor for the learning rate at the start of the warmup phase.

1/3
eta_min float

Minimum learning rate.

0
last_epoch int

The index of the last epoch. If set to -1, the learning rate is initialized to the base learning rate.

-1

Returns:

Type Description
LRScheduler

The learning rate scheduler.

Raises:

Type Description
ValueError

If warmup_steps is greater than or equal to max_steps or if warmup_steps is less than or equal to 0.

Source code in mmlearn/modules/lr_schedulers/linear_warmup_cosine_lr.py
@store(  # type: ignore[misc]
    group="modules/lr_schedulers",
    provider="mmlearn",
    zen_partial=True,
    warmup_steps=MISSING,
    max_steps=MISSING,
)
def linear_warmup_cosine_annealing_lr(
    optimizer: Optimizer,
    warmup_steps: int,
    max_steps: int,
    start_factor: float = 1 / 3,
    eta_min: float = 0.0,
    last_epoch: int = -1,
) -> LRScheduler:
    """Create a linear warmup cosine annealing learning rate scheduler.

    Parameters
    ----------
    optimizer : Optimizer
        The optimizer for which to schedule the learning rate.
    warmup_steps : int
        Maximum number of iterations for linear warmup.
    max_steps : int
        Maximum number of iterations.
    start_factor : float, optional, default=1/3
        Multiplicative factor for the learning rate at the start of the warmup phase.
    eta_min : float, optional, default=0
        Minimum learning rate.
    last_epoch : int, optional, default=-1
        The index of the last epoch. If set to ``-1``, the learning rate is
        initialized to the base learning rate.

    Returns
    -------
    LRScheduler
        The learning rate scheduler.

    Raises
    ------
    ValueError
        If `warmup_steps` is greater than or equal to `max_steps` or if `warmup_steps`
        is less than or equal to 0.
    """
    if warmup_steps >= max_steps:
        raise ValueError(
            "Expected `warmup_steps` to be less than `max_steps` but got "
            f"`warmup_steps={warmup_steps}` and `max_steps={max_steps}`."
        )
    if warmup_steps <= 0:
        raise ValueError(
            "Expected `warmup_steps` to be positive but got "
            f"`warmup_steps={warmup_steps}`."
        )

    linear_lr = LinearLR(
        optimizer,
        start_factor=start_factor,
        total_iters=warmup_steps,
        last_epoch=last_epoch,
    )
    cosine_lr = CosineAnnealingLR(
        optimizer,
        T_max=max_steps - warmup_steps,
        eta_min=eta_min,
        last_epoch=last_epoch,
    )
    return SequentialLR(
        optimizer,
        schedulers=[linear_lr, cosine_lr],
        milestones=[warmup_steps],
        last_epoch=last_epoch,
    )
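
A short usage sketch for this scheduler (the import path follows the module listing above; the optimizer, step counts, and training loop are illustrative). The learning rate ramps linearly from start_factor * base_lr to base_lr over the first warmup_steps steps, then follows a cosine decay towards eta_min over the remaining max_steps - warmup_steps steps.

import torch

from mmlearn.modules.lr_schedulers import linear_warmup_cosine_annealing_lr

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# 1,000 linear warmup steps, then cosine annealing until step 10,000
scheduler = linear_warmup_cosine_annealing_lr(
    optimizer,
    warmup_steps=1_000,
    max_steps=10_000,
    eta_min=1e-6,
)

for _ in range(10_000):
    optimizer.step()   # normally preceded by a forward/backward pass
    scheduler.step()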

linear_warmup_cosine_lr

Linear warmup cosine annealing learning rate scheduler.

linear_warmup_cosine_annealing_lr

See the entry for linear_warmup_cosine_annealing_lr under mmlearn.modules.lr_schedulers above.

Metrics

mmlearn.modules.metrics

Metrics for evaluating models.

RetrievalRecallAtK

Bases: Metric

Retrieval Recall@K metric.

Computes the Recall@K for retrieval tasks. The metric is computed as follows:

  1. Compute the cosine similarity between the query and the database.
  2. For each query, sort the database in decreasing order of similarity.
  3. Compute the Recall@K as the number of true positives among the top K elements.

Parameters:

Name Type Description Default
top_k int

The number of top elements to consider for computing the Recall@K.

required
reduction (mean, sum, none, None)

Specifies the reduction to apply after computing the pairwise cosine similarity scores.

"sum"
aggregation (mean, median, min, max)

Specifies the aggregation function to apply to the Recall@K values computed in batches. If a callable is provided, it should accept a tensor of values and a keyword argument 'dim' and return a single scalar value.

"mean"
kwargs Any

Additional arguments to be passed to the torchmetrics.Metric class.

{}

Raises:

Type Description
ValueError
  • If the top_k is not a positive integer or None.
  • If the reduction is not one of {"mean", "sum", "none", None}.
  • If the aggregation is not one of {"mean", "median", "min", "max"} or a custom callable function.
Source code in mmlearn/modules/metrics/retrieval_recall.py
@store(group="modules/metrics", provider="mmlearn")
class RetrievalRecallAtK(Metric):
    """Retrieval Recall@K metric.

    Computes the Recall@K for retrieval tasks. The metric is computed as follows:

    1. Compute the cosine similarity between the query and the database.
    2. For each query, sort the database in decreasing order of similarity.
    3. Compute the Recall@K as the number of true positives among the top K elements.

    Parameters
    ----------
    top_k : int
        The number of top elements to consider for computing the Recall@K.
    reduction : {"mean", "sum", "none", None}, optional, default="sum"
        Specifies the reduction to apply after computing the pairwise cosine similarity
        scores.
    aggregation : {"mean", "median", "min", "max"} or callable, default="mean"
        Specifies the aggregation function to apply to the Recall@K values computed
        in batches. If a callable is provided, it should accept a tensor of values
        and a keyword argument ``'dim'`` and return a single scalar value.
    kwargs : Any
        Additional arguments to be passed to the :py:class:`torchmetrics.Metric` class.

    Raises
    ------
    ValueError

        - If the `top_k` is not a positive integer or None.
        - If the `reduction` is not one of {"mean", "sum", "none", None}.
        - If the `aggregation` is not one of {"mean", "median", "min", "max"} or a
          custom callable function.

    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    indexes: list[torch.Tensor]
    x: list[torch.Tensor]
    y: list[torch.Tensor]
    num_samples: torch.Tensor

    def __init__(
        self,
        top_k: int,
        reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
        aggregation: Union[
            Literal["mean", "median", "min", "max"],
            Callable[[torch.Tensor, int], torch.Tensor],
        ] = "mean",
        **kwargs: Any,
    ) -> None:
        """Initialize the metric."""
        super().__init__(**kwargs)

        if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
            raise ValueError("`top_k` has to be a positive integer or None")
        self.top_k = top_k

        allowed_reduction = ("sum", "mean", "none", None)
        if reduction not in allowed_reduction:
            raise ValueError(
                f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
            )
        self.reduction = reduction

        if not (
            aggregation in ("mean", "median", "min", "max") or callable(aggregation)
        ):
            raise ValueError(
                "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable "
                f"function which takes a tensor of values, but got {aggregation}."
            )
        self.aggregation = aggregation

        self.add_state("x", default=[], dist_reduce_fx="cat")
        self.add_state("y", default=[], dist_reduce_fx="cat")
        self.add_state("indexes", default=[], dist_reduce_fx="cat")
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

        self._batch_size: int = -1

        self.compute_on_cpu = True
        self.sync_on_compute = False
        self.dist_sync_on_step = False
        self._to_sync = self.sync_on_compute
        self._should_unsync = False

    def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
        """Check shape, convert dtypes and add to accumulators.

        Parameters
        ----------
        x : torch.Tensor
            Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
            of samples and `D` is the number of dimensions.
        y : torch.Tensor
            Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
            of samples and ``D`` is the number of dimensions.
        indexes : torch.Tensor
            Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
            This specifies which sample in ``y`` is the positive match for each
            sample in ``x``.

        Raises
        ------
        ValueError
            If `indexes` is None.

        """
        if indexes is None:
            raise ValueError("Argument `indexes` cannot be None")

        x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

        # offset batch indexes by the number of samples seen so far
        indexes += self.num_samples

        local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
        if self._is_distributed():
            x = dim_zero_cat(gather_all_tensors(x, self.process_group))
            y = dim_zero_cat(gather_all_tensors(y, self.process_group))
            indexes = dim_zero_cat(
                gather_all_tensors(indexes.clone(), self.process_group)
            )

            # offset indexes for each device
            bsz_per_device = dim_zero_cat(
                gather_all_tensors(local_batch_size, self.process_group)
            )
            cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
            for device_idx in range(1, bsz_per_device.numel()):
                indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                    cum_local_bsz[device_idx - 1]
                )

            # update the global sample count
            self.num_samples += cum_local_bsz[-1]

            self._is_synced = True
        else:
            self.num_samples += x.size(0)

        self.x.append(x)
        self.y.append(y)
        self.indexes.append(indexes)

        if self._batch_size == -1:
            self._batch_size = x.size(0)  # global batch size

    def compute(self) -> torch.Tensor:
        """Compute the metric.

        Returns
        -------
        torch.Tensor
            The computed metric.
        """
        x = dim_zero_cat(self.x)
        y = dim_zero_cat(self.y)

        # normalize embeddings
        x /= x.norm(dim=-1, p=2, keepdim=True)
        y /= y.norm(dim=-1, p=2, keepdim=True)

        # instantiate reduction function
        reduction_mapping: Dict[
            Optional[str], Callable[[torch.Tensor], torch.Tensor]
        ] = {
            "sum": partial(torch.sum, dim=-1),
            "mean": partial(torch.mean, dim=-1),
            "none": lambda x: x,
            None: lambda x: x,
        }

        # concatenate indexes of true pairs
        indexes = dim_zero_cat(self.indexes)

        results: list[torch.Tensor] = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=os.cpu_count() or 1  # use all available CPUs
        ) as executor:
            futures = [
                executor.submit(
                    self._process_batch,
                    start,
                    x,
                    y,
                    indexes,
                    reduction_mapping,
                    self.top_k,
                )
                for start in tqdm(
                    range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
                )
            ]
            for future in concurrent.futures.as_completed(futures):
                results.append(future.result())

        return _retrieval_aggregate(
            (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward method is not supported.

        Raises
        ------
        NotImplementedError
            The forward method is not supported for this metric.
        """
        raise NotImplementedError(
            "RetrievalRecallAtK metric does not support forward method"
        )

    def _is_distributed(self) -> bool:
        if self.distributed_available_fn is not None:
            distributed_available = self.distributed_available_fn

        return distributed_available() if callable(distributed_available) else False

    def _process_batch(
        self,
        start: int,
        x_norm: torch.Tensor,
        y_norm: torch.Tensor,
        indexes: torch.Tensor,
        reduction_mapping: Dict[Optional[str], Callable[[torch.Tensor], torch.Tensor]],
        top_k: int,
    ) -> torch.Tensor:
        """Compute the Recall@K for a batch of samples."""
        end = start + self._batch_size
        x_norm_batch = x_norm[start:end]
        indexes_batch = indexes[start:end]

        similarity = _safe_matmul(x_norm_batch, y_norm)
        scores: torch.Tensor = reduction_mapping[self.reduction](similarity)

        with torch.inference_mode():
            positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
            positive_pairs[torch.arange(len(scores)), indexes_batch] = True

        return _recall_at_k(scores, positive_pairs, top_k)

__init__

__init__(
    top_k, reduction="sum", aggregation="mean", **kwargs
)

Initialize the metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def __init__(
    self,
    top_k: int,
    reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
    aggregation: Union[
        Literal["mean", "median", "min", "max"],
        Callable[[torch.Tensor, int], torch.Tensor],
    ] = "mean",
    **kwargs: Any,
) -> None:
    """Initialize the metric."""
    super().__init__(**kwargs)

    if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
        raise ValueError("`top_k` has to be a positive integer or None")
    self.top_k = top_k

    allowed_reduction = ("sum", "mean", "none", None)
    if reduction not in allowed_reduction:
        raise ValueError(
            f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}"
        )
    self.reduction = reduction

    if not (
        aggregation in ("mean", "median", "min", "max") or callable(aggregation)
    ):
        raise ValueError(
            "Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable "
            f"function which takes a tensor of values, but got {aggregation}."
        )
    self.aggregation = aggregation

    self.add_state("x", default=[], dist_reduce_fx="cat")
    self.add_state("y", default=[], dist_reduce_fx="cat")
    self.add_state("indexes", default=[], dist_reduce_fx="cat")
    self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="cat")

    self._batch_size: int = -1

    self.compute_on_cpu = True
    self.sync_on_compute = False
    self.dist_sync_on_step = False
    self._to_sync = self.sync_on_compute
    self._should_unsync = False

update

update(x, y, indexes)

Check shape, convert dtypes and add to accumulators.

Parameters:

Name Type Description Default
x Tensor

Embeddings (unnormalized) of shape (N, D) where N is the number of samples and D is the number of dimensions.

required
y Tensor

Embeddings (unnormalized) of shape (M, D) where M is the number of samples and D is the number of dimensions.

required
indexes Tensor

Index tensor of shape (N,) where N is the number of samples. This specifies which sample in y is the positive match for each sample in x.

required

Raises:

Type Description
ValueError

If indexes is None.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
    """Check shape, convert dtypes and add to accumulators.

    Parameters
    ----------
    x : torch.Tensor
        Embeddings (unnormalized) of shape ``(N, D)`` where ``N`` is the number
        of samples and `D` is the number of dimensions.
    y : torch.Tensor
        Embeddings (unnormalized) of shape ``(M, D)`` where ``M`` is the number
        of samples and ``D`` is the number of dimensions.
    indexes : torch.Tensor
        Index tensor of shape ``(N,)`` where ``N`` is the number of samples.
        This specifies which sample in ``y`` is the positive match for each
        sample in ``x``.

    Raises
    ------
    ValueError
        If `indexes` is None.

    """
    if indexes is None:
        raise ValueError("Argument `indexes` cannot be None")

    x, y, indexes = _update_batch_inputs(x.clone(), y.clone(), indexes.clone())

    # offset batch indexes by the number of samples seen so far
    indexes += self.num_samples

    local_batch_size = torch.zeros_like(self.num_samples) + x.size(0)
    if self._is_distributed():
        x = dim_zero_cat(gather_all_tensors(x, self.process_group))
        y = dim_zero_cat(gather_all_tensors(y, self.process_group))
        indexes = dim_zero_cat(
            gather_all_tensors(indexes.clone(), self.process_group)
        )

        # offset indexes for each device
        bsz_per_device = dim_zero_cat(
            gather_all_tensors(local_batch_size, self.process_group)
        )
        cum_local_bsz = torch.cumsum(bsz_per_device, dim=0)
        for device_idx in range(1, bsz_per_device.numel()):
            indexes[cum_local_bsz[device_idx - 1] : cum_local_bsz[device_idx]] += (
                cum_local_bsz[device_idx - 1]
            )

        # update the global sample count
        self.num_samples += cum_local_bsz[-1]

        self._is_synced = True
    else:
        self.num_samples += x.size(0)

    self.x.append(x)
    self.y.append(y)
    self.indexes.append(indexes)

    if self._batch_size == -1:
        self._batch_size = x.size(0)  # global batch size

compute

compute()

Compute the metric.

Returns:

Type Description
Tensor

The computed metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def compute(self) -> torch.Tensor:
    """Compute the metric.

    Returns
    -------
    torch.Tensor
        The computed metric.
    """
    x = dim_zero_cat(self.x)
    y = dim_zero_cat(self.y)

    # normalize embeddings
    x /= x.norm(dim=-1, p=2, keepdim=True)
    y /= y.norm(dim=-1, p=2, keepdim=True)

    # instantiate reduction function
    reduction_mapping: Dict[
        Optional[str], Callable[[torch.Tensor], torch.Tensor]
    ] = {
        "sum": partial(torch.sum, dim=-1),
        "mean": partial(torch.mean, dim=-1),
        "none": lambda x: x,
        None: lambda x: x,
    }

    # concatenate indexes of true pairs
    indexes = dim_zero_cat(self.indexes)

    results: list[torch.Tensor] = []
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count() or 1  # use all available CPUs
    ) as executor:
        futures = [
            executor.submit(
                self._process_batch,
                start,
                x,
                y,
                indexes,
                reduction_mapping,
                self.top_k,
            )
            for start in tqdm(
                range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
            )
        ]
        for future in concurrent.futures.as_completed(futures):
            results.append(future.result())

    return _retrieval_aggregate(
        (torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
    )

forward

forward(*args, **kwargs)

Forward method is not supported.

Raises:

Type Description
NotImplementedError

The forward method is not supported for this metric.

Source code in mmlearn/modules/metrics/retrieval_recall.py
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Forward method is not supported.

    Raises
    ------
    NotImplementedError
        The forward method is not supported for this metric.
    """
    raise NotImplementedError(
        "RetrievalRecallAtK metric does not support forward method"
    )
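
A minimal usage sketch for this metric (the import path follows the module listing above; the random embeddings and keyword values are illustrative). Each query's positive match is the database entry at the same position, so the index tensor is simply an arange.

import torch

from mmlearn.modules.metrics import RetrievalRecallAtK

metric = RetrievalRecallAtK(top_k=5, reduction=None, aggregation="mean")

# 64 query and 64 database embeddings, 128-dimensional (random, for illustration)
queries = torch.randn(64, 128)
database = torch.randn(64, 128)
indexes = torch.arange(64)  # indexes[i] is the position of queries[i]'s positive match in the database

metric.update(queries, database, indexes)
print(metric.compute())  # mean Recall@5 over all queries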

retrieval_recall

Retrieval Recall@K metric.

RetrievalRecallAtK

Bases: Metric

See the entry for RetrievalRecallAtK under mmlearn.modules.metrics above.

Tasks

mmlearn.tasks

Modules for pretraining, downstream and evaluation tasks.

ContrastivePretraining

Bases: TrainingTask

Contrastive pretraining task.

This class supports contrastive pretraining with N modalities of data. It allows the sharing of encoders, heads, and postprocessors across modalities. It also supports computing the contrastive loss between specified pairs of modalities, as well as training auxiliary tasks alongside the main contrastive pretraining task.

Parameters:

Name Type Description Default
encoders dict[str, Module]

A dictionary of encoders. The keys can be any string, including the names of any supported modalities. If the keys are not supported modalities, the modality_module_mapping parameter must be provided to map the encoders to specific modalities. The encoders are expected to take a dictionary of input values and return a list-like object with the first element being the encoded values. This first element is passed on to the heads or postprocessors and the remaining elements are ignored.

required
heads Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the encoder outputs, usually projection heads. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a torch.nn.Sequential module. All head modules are expected to take a single input tensor and return a single output tensor.

None
postprocessors Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the head outputs. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a nn.Sequential module. All postprocessor modules are expected to take a single input tensor and return a single output tensor.

None
modality_module_mapping Optional[dict[str, ModuleKeySpec]]

A dictionary mapping modalities to encoders, heads, and postprocessors. Useful for reusing the same instance of a module across multiple modalities.

None
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a functools.partial function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
init_logit_scale float

The initial value of the logit scale parameter. This is the log of the scale factor applied to the logits before computing the contrastive loss.

1 / 0.07
max_logit_scale float

The maximum value of the logit scale parameter. The logit scale parameter is clamped to the range [0, log(max_logit_scale)].

100
learnable_logit_scale bool

Whether the logit scale parameter is learnable. If set to False, the logit scale parameter is treated as a constant.

True
loss Optional[Module]

The loss function to use.

None
modality_loss_pairs Optional[list[LossPairSpec]]

A list of pairs of modalities to compute the contrastive loss between and the weight to apply to each pair.

None
auxiliary_tasks dict[str, AuxiliaryTaskSpec]

Auxiliary tasks to run alongside the main contrastive pretraining task.

  • The auxiliary task module is expected to be a partially-initialized instance of a lightning.pytorch.core.LightningModule created using functools.partial, such that an initialized encoder can be passed as the only argument.
  • The modality parameter specifies the modality of the encoder to use for the auxiliary task. The loss_weight parameter specifies the weight to apply to the auxiliary task loss.
None
log_auxiliary_tasks_loss bool

Whether to log the loss of auxiliary tasks to the main logger.

False
compute_validation_loss bool

Whether to compute the validation loss if a validation dataloader is provided. The loss function must be provided to compute the validation loss.

True
compute_test_loss bool

Whether to compute the test loss if a test dataloader is provided. The loss function must be provided to compute the test loss.

True
evaluation_tasks Optional[dict[str, EvaluationSpec]]

Evaluation tasks to run during validation, while training, and during testing.

None

Raises:

Type Description
ValueError
  • If the loss function is not provided and either the validation or test loss needs to be computed.
  • If the given modality is not supported.
  • If the encoder, head, or postprocessor is not mapped to a modality.
  • If an unsupported modality is found in the loss pair specification.
  • If an unsupported modality is found in the auxiliary tasks.
  • If the auxiliary task is not a partial function.
  • If the evaluation task is not an instance of mmlearn.tasks.hooks.EvaluationHooks.
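
To make the expected inputs concrete, here is a minimal construction sketch. The import path, the "text" and "rgb" modality names, and the toy encoder are assumptions for illustration; the encoder merely follows the documented contract of taking a dictionary of inputs and returning a list-like object whose first element is the encoded tensor.

from functools import partial

import torch
from torch import nn

from mmlearn.tasks import ContrastivePretraining  # assumed import path


class ToyEncoder(nn.Module):
    # Illustrative encoder following the documented contract: a dictionary of
    # input values goes in, a list-like object with the encoded tensor as its
    # first element comes out.
    def __init__(self, input_key: str, out_dim: int = 32) -> None:
        super().__init__()
        self.input_key = input_key
        self.proj = nn.LazyLinear(out_dim)

    def forward(self, inputs: dict) -> tuple:
        return (self.proj(inputs[self.input_key]),)


task = ContrastivePretraining(
    # assumes "text" and "rgb" are registered modality names
    encoders={"text": ToyEncoder("text"), "rgb": ToyEncoder("rgb")},
    heads={"text": nn.Linear(32, 16), "rgb": nn.Linear(32, 16)},
    optimizer=partial(torch.optim.AdamW, lr=1e-4),
    # no loss module is supplied in this sketch, so skip loss computation
    compute_validation_loss=False,
    compute_test_loss=False,
)
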
Source code in mmlearn/tasks/contrastive_pretraining.py
@store(group="task", provider="mmlearn")
class ContrastivePretraining(TrainingTask):
    """Contrastive pretraining task.

    This class supports contrastive pretraining with ``N`` modalities of data. It
    allows the sharing of encoders, heads, and postprocessors across modalities.
    It also supports computing the contrastive loss between specified pairs of
    modalities, as well as training auxiliary tasks alongside the main contrastive
    pretraining task.

    Parameters
    ----------
    encoders : dict[str, torch.nn.Module]
        A dictionary of encoders. The keys can be any string, including the names of
        any supported modalities. If the keys are not supported modalities, the
        ``modality_module_mapping`` parameter must be provided to map the encoders to
        specific modalities. The encoders are expected to take a dictionary of input
        values and return a list-like object with the first element being the encoded
        values. This first element is passed on to the heads or postprocessors and
        the remaining elements are ignored.
    heads : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the encoder outputs, usually projection
        heads. If the keys do not correspond to the name of a supported modality,
        the ``modality_module_mapping`` parameter must be provided. If any of the values
        are dictionaries, they will be wrapped in a :py:class:`torch.nn.Sequential`
        module. All head modules are expected to take a single input tensor and
        return a single output tensor.
    postprocessors : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the head outputs. If the keys do not
        correspond to the name of a supported modality, the `modality_module_mapping`
        parameter must be provided. If any of the values are dictionaries, they will
        be wrapped in a `nn.Sequential` module. All postprocessor modules are expected
        to take a single input tensor and return a single output tensor.
    modality_module_mapping : Optional[dict[str, ModuleKeySpec]], optional, default=None
        A dictionary mapping modalities to encoders, heads, and postprocessors.
        Useful for reusing the same instance of a module across multiple modalities.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    init_logit_scale : float, optional, default=1 / 0.07
        The initial value of the logit scale parameter. This is the log of the scale
        factor applied to the logits before computing the contrastive loss.
    max_logit_scale : float, optional, default=100
        The maximum value of the logit scale parameter. The logit scale parameter
        is clamped to the range ``[0, log(max_logit_scale)]``.
    learnable_logit_scale : bool, optional, default=True
        Whether the logit scale parameter is learnable. If set to False, the logit
        scale parameter is treated as a constant.
    loss : Optional[torch.nn.Module], optional, default=None
        The loss function to use.
    modality_loss_pairs : Optional[list[LossPairSpec]], optional, default=None
        A list of pairs of modalities to compute the contrastive loss between and
        the weight to apply to each pair.
    auxiliary_tasks : dict[str, AuxiliaryTaskSpec], optional, default=None
        Auxiliary tasks to run alongside the main contrastive pretraining task.

        - The auxiliary task module is expected to be a partially-initialized instance
          of a :py:class:`~lightning.pytorch.core.LightningModule` created using
          :py:func:`functools.partial`, such that an initialized encoder can be
          passed as the only argument.
        - The ``modality`` parameter specifies the modality of the encoder to use
          for the auxiliary task. The ``loss_weight`` parameter specifies the weight
          to apply to the auxiliary task loss.
    log_auxiliary_tasks_loss : bool, optional, default=False
        Whether to log the loss of auxiliary tasks to the main logger.
    compute_validation_loss : bool, optional, default=True
        Whether to compute the validation loss if a validation dataloader is provided.
        The loss function must be provided to compute the validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute the test loss if a test dataloader is provided. The loss
        function must be provided to compute the test loss.
    evaluation_tasks : Optional[dict[str, EvaluationSpec]], optional, default=None
        Evaluation tasks to run during validation, while training, and during testing.

    Raises
    ------
    ValueError

        - If the loss function is not provided and either the validation or test loss
          needs to be computed.
        - If the given modality is not supported.
        - If the encoder, head, or postprocessor is not mapped to a modality.
        - If an unsupported modality is found in the loss pair specification.
        - If an unsupported modality is found in the auxiliary tasks.
        - If the auxiliary task is not a partial function.
        - If the evaluation task is not an instance of :py:class:`~mmlearn.tasks.hooks.EvaluationHooks`.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912, PLR0915
        self,
        encoders: dict[str, nn.Module],
        heads: Optional[dict[str, Union[nn.Module, dict[str, nn.Module]]]] = None,
        postprocessors: Optional[
            dict[str, Union[nn.Module, dict[str, nn.Module]]]
        ] = None,
        modality_module_mapping: Optional[dict[str, ModuleKeySpec]] = None,
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable_logit_scale: bool = True,
        loss: Optional[nn.Module] = None,
        modality_loss_pairs: Optional[list[LossPairSpec]] = None,
        auxiliary_tasks: Optional[dict[str, AuxiliaryTaskSpec]] = None,
        log_auxiliary_tasks_loss: bool = False,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
        evaluation_tasks: Optional[dict[str, EvaluationSpec]] = None,
    ) -> None:
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )

        self.save_hyperparameters(
            ignore=[
                "encoders",
                "heads",
                "postprocessors",
                "modality_module_mapping",
                "loss",
                "auxiliary_tasks",
                "evaluation_tasks",
                "modality_loss_pairs",
            ]
        )

        if modality_module_mapping is None:
            # assume all the module dictionaries use the same keys corresponding
            # to modalities
            modality_module_mapping = {}
            for key in encoders:
                modality_module_mapping[key] = ModuleKeySpec(
                    encoder_key=key,
                    head_key=key,
                    postprocessor_key=key,
                )

        # match modalities to encoders, heads, and postprocessors
        modality_encoder_mapping: dict[str, Optional[str]] = {}
        modality_head_mapping: dict[str, Optional[str]] = {}
        modality_postprocessor_mapping: dict[str, Optional[str]] = {}
        for modality_key, module_mapping in modality_module_mapping.items():
            if not Modalities.has_modality(modality_key):
                raise ValueError(_unsupported_modality_error.format(modality_key))
            modality_encoder_mapping[modality_key] = module_mapping.encoder_key
            modality_head_mapping[modality_key] = module_mapping.head_key
            modality_postprocessor_mapping[modality_key] = (
                module_mapping.postprocessor_key
            )

        # ensure all modules are mapped to a modality
        for key in encoders:
            if key not in modality_encoder_mapping.values():
                if not Modalities.has_modality(key):
                    raise ValueError(_unsupported_modality_error.format(key))
                modality_encoder_mapping[key] = key

        if heads is not None:
            for key in heads:
                if key not in modality_head_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_head_mapping[key] = key

        if postprocessors is not None:
            for key in postprocessors:
                if key not in modality_postprocessor_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_postprocessor_mapping[key] = key

        self._available_modalities: list[Modality] = [
            Modalities.get_modality(modality_key)
            for modality_key in modality_encoder_mapping
        ]
        assert len(self._available_modalities) >= 2, (
            "Expected at least two modalities to be available. "
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the encoder modules.
        self.encoders = nn.ModuleDict(
            {
                Modalities.get_modality(modality_key).name: encoders[encoder_key]
                for modality_key, encoder_key in modality_encoder_mapping.items()
                if encoder_key is not None
            }
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the projection head modules. This can be
        #: ``None`` if no heads modules are provided.
        self.heads = None
        if heads is not None:
            self.heads = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: heads[head_key]
                    if isinstance(heads[head_key], nn.Module)
                    else nn.Sequential(*heads[head_key].values())
                    for modality_key, head_key in modality_head_mapping.items()
                    if head_key is not None and head_key in heads
                }
            )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the postprocessor modules. This can be
        #: ``None`` if no postprocessor modules are provided.
        self.postprocessors = None
        if postprocessors is not None:
            self.postprocessors = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: postprocessors[
                        postprocessor_key
                    ]
                    if isinstance(postprocessors[postprocessor_key], nn.Module)
                    else nn.Sequential(*postprocessors[postprocessor_key].values())
                    for modality_key, postprocessor_key in modality_postprocessor_mapping.items()
                    if postprocessor_key is not None
                    and postprocessor_key in postprocessors
                }
            )

        # set up logit scaling
        log_logit_scale = torch.ones([]) * np.log(init_logit_scale)
        self.max_logit_scale = max_logit_scale
        self.learnable_logit_scale = learnable_logit_scale

        if self.learnable_logit_scale:
            self.log_logit_scale = torch.nn.Parameter(
                log_logit_scale, requires_grad=True
            )
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

        # set up contrastive loss pairs
        if modality_loss_pairs is None:
            modality_loss_pairs = [
                LossPairSpec(modalities=(m1.name, m2.name))
                for m1, m2 in itertools.combinations(self._available_modalities, 2)
            ]

        for modality_pair in modality_loss_pairs:
            if not all(
                Modalities.get_modality(modality) in self._available_modalities
                for modality in modality_pair.modalities
            ):
                raise ValueError(
                    "Found unspecified modality in the loss pair specification "
                    f"{modality_pair.modalities}. Available modalities are "
                    f"{self._available_modalities}."
                )

        #: A list of :py:class:`LossPairSpec` instances specifying the pairs of
        #: modalities to compute the contrastive loss between and the weight to
        #: apply to each pair.
        self.modality_loss_pairs = modality_loss_pairs

        # set up auxiliary tasks
        self.aux_task_specs = auxiliary_tasks or {}
        self.auxiliary_tasks: nn.ModuleDict[str, L.LightningModule] = nn.ModuleDict()
        for task_name, task_spec in self.aux_task_specs.items():
            if not Modalities.has_modality(task_spec.modality):
                raise ValueError(
                    f"Found unsupported modality `{task_spec.modality}` in the auxiliary tasks. "
                    f"Available modalities are {self._available_modalities}."
                )
            if not isinstance(task_spec.task, partial):
                raise TypeError(
                    f"Expected auxiliary task to be a partial function, but got {type(task_spec.task)}."
                )

            self.auxiliary_tasks[task_name] = task_spec.task(
                self.encoders[Modalities.get_modality(task_spec.modality).name]
            )

        self.log_auxiliary_tasks_loss = log_auxiliary_tasks_loss

        if evaluation_tasks is not None:
            for eval_task_spec in evaluation_tasks.values():
                if not isinstance(eval_task_spec.task, EvaluationHooks):
                    raise TypeError(
                        f"Expected {eval_task_spec.task} to be an instance of `EvaluationHooks` "
                        f"but got {type(eval_task_spec.task)}."
                    )

        #: A dictionary of evaluation tasks to run during validation (while
        #: training) or during testing.
        self.evaluation_tasks = evaluation_tasks

    def configure_model(self) -> None:
        """Configure the model."""
        if self.auxiliary_tasks:
            for task_name in self.auxiliary_tasks:
                self.auxiliary_tasks[task_name].configure_model()

    def encode(
        self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
    ) -> torch.Tensor:
        """Encode the input values for the given modality.

        Parameters
        ----------
        inputs : dict[str, Any]
            Input values.
        modality : Modality
            The modality to encode.
        normalize : bool, optional, default=False
            Whether to apply L2 normalization to the output (after the head and
            postprocessor layers, if present).

        Returns
        -------
        torch.Tensor
            The encoded values for the specified modality.
        """
        output = self.encoders[modality.name](inputs)[0]

        if self.postprocessors and modality.name in self.postprocessors:
            output = self.postprocessors[modality.name](output)

        if self.heads and modality.name in self.heads:
            output = self.heads[modality.name](output)

        if normalize:
            output = torch.nn.functional.normalize(output, p=2, dim=-1)

        return output

    def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input tensors to encode.

        Returns
        -------
        dict[str, torch.Tensor]
            The encodings for each modality.
        """
        outputs = {
            modality.embedding: self.encode(inputs, modality, normalize=True)
            for modality in self._available_modalities
            if modality.name in inputs
        }

        if not all(
            output.size(-1) == list(outputs.values())[0].size(-1)
            for output in outputs.values()
        ):
            raise ValueError("Expected all model outputs to have the same dimension.")

        return outputs

    def on_train_epoch_start(self) -> None:
        """Prepare for the training epoch.

        This method sets the modules to training mode.
        """
        self.encoders.train()
        if self.heads:
            self.heads.train()
        if self.postprocessors:
            self.postprocessors.train()

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Compute the loss for the batch.

        Parameters
        ----------
        batch : dict[str, Any]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        torch.Tensor
            The loss for the batch.
        """
        outputs = self(batch)

        with torch.no_grad():
            self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

        loss = self._compute_loss(batch, batch_idx, outputs)

        if loss is None:
            raise ValueError("The loss function must be provided for training.")

        self.log("train/loss", loss, prog_bar=True, sync_dist=True)
        self.log(
            "train/logit_scale",
            self.log_logit_scale.exp(),
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )

        return loss

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Zero out the gradients of the model."""
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_before_zero_grad(optimizer)

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch.

        This method sets the modules to evaluation mode and calls the
        ``on_evaluation_epoch_start`` method of each evaluation task.
        """
        self._on_eval_epoch_start("val")

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "val")

    def on_validation_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def test_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "test")

    def on_test_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the model checkpoint after loading.

        The `on_load_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint after loading.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The loaded checkpoint.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_load_checkpoint(checkpoint)

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the checkpoint before saving.

        The `on_save_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint before saving.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The checkpoint to save.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_save_checkpoint(checkpoint)

    def _compute_loss(
        self, batch: dict[str, Any], batch_idx: int, outputs: dict[str, torch.Tensor]
    ) -> Optional[torch.Tensor]:
        if self.loss_fn is None:
            return None

        contrastive_loss = self.loss_fn(
            outputs,
            batch["example_ids"],
            self.log_logit_scale.exp(),
            self.modality_loss_pairs,
        )

        auxiliary_losses: list[torch.Tensor] = []
        if self.auxiliary_tasks:
            for task_name, task_spec in self.aux_task_specs.items():
                auxiliary_task_output = self.auxiliary_tasks[task_name].training_step(
                    batch, batch_idx
                )
                if isinstance(auxiliary_task_output, torch.Tensor):
                    auxiliary_task_loss = auxiliary_task_output
                elif isinstance(auxiliary_task_output, Mapping):
                    auxiliary_task_loss = auxiliary_task_output["loss"]
                else:
                    raise ValueError(
                        "Expected auxiliary task output to be a tensor or a mapping "
                        f"containing a 'loss' key, but got {type(auxiliary_task_output)}."
                    )

                auxiliary_task_loss *= task_spec.loss_weight
                auxiliary_losses.append(auxiliary_task_loss)
                if self.log_auxiliary_tasks_loss:
                    self.log(
                        f"train/{task_name}_loss", auxiliary_task_loss, sync_dist=True
                    )

        if not auxiliary_losses:
            return contrastive_loss

        return torch.stack(auxiliary_losses).sum() + contrastive_loss

    def _on_eval_epoch_start(self, eval_type: Literal["val", "test"]) -> None:
        """Prepare for the evaluation epoch."""
        self.encoders.eval()
        if self.heads:
            self.heads.eval()
        if self.postprocessors:
            self.postprocessors.eval()
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_start(self)

    def _shared_eval_step(
        self,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
        eval_type: Literal["val", "test"],
    ) -> Optional[torch.Tensor]:
        """Run a single evaluation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        loss: Optional[torch.Tensor] = None
        if (eval_type == "val" and self.compute_validation_loss) or (
            eval_type == "test" and self.compute_test_loss
        ):
            outputs = self(batch)
            loss = self._compute_loss(batch, batch_idx, outputs)
            if loss is not None and not self.trainer.sanity_checking:
                self.log(f"{eval_type}/loss", loss, prog_bar=True, sync_dist=True)

        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.evaluation_step(self, batch, batch_idx)

        return loss

    def _on_eval_epoch_end(self, eval_type: Literal["val", "test"]) -> None:
        """Compute and log epoch-level metrics at the end of the evaluation epoch."""
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_end(self)
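
A minimal construction sketch for this task, assuming the class shown above is ContrastivePretraining exported from mmlearn/tasks/contrastive_pretraining.py, that LossPairSpec is importable from the same module, and that "rgb" and "text" are modality names registered with Modalities. The ToyEncoder below is a placeholder, not part of mmlearn; it only mimics the interface encode expects (called with the batch dictionary, returning a tuple whose first element is the embedding).

from functools import partial

import torch
from torch import nn

# Assumed import path; adjust to the actual package layout.
from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining, LossPairSpec


class ToyEncoder(nn.Module):
    """Placeholder encoder: reads one key from the batch dict and returns a 1-tuple."""

    def __init__(self, key: str, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.key = key
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, inputs: dict) -> tuple:
        return (self.proj(inputs[self.key]),)


task = ContrastivePretraining(
    # Keys must be modality names registered with ``Modalities``.
    encoders={
        "rgb": ToyEncoder("rgb", in_dim=2048, out_dim=512),
        "text": ToyEncoder("text", in_dim=768, out_dim=512),
    },
    loss=None,  # supply a contrastive loss object to enable training
    optimizer=partial(torch.optim.AdamW, lr=1e-4),
    modality_loss_pairs=[LossPairSpec(modalities=("rgb", "text"))],
)

With no explicit modality_module_mapping, each encoder key is matched to the modality of the same name, as described in the constructor above.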

configure_model

configure_model()

Configure the model.

Source code in mmlearn/tasks/contrastive_pretraining.py
def configure_model(self) -> None:
    """Configure the model."""
    if self.auxiliary_tasks:
        for task_name in self.auxiliary_tasks:
            self.auxiliary_tasks[task_name].configure_model()

encode

encode(inputs, modality, normalize=False)

Encode the input values for the given modality.

Parameters:

Name Type Description Default
inputs dict[str, Any]

Input values.

required
modality Modality

The modality to encode.

required
normalize bool

Whether to apply L2 normalization to the output (after the head and postprocessor layers, if present).

False

Returns:

Type Description
Tensor

The encoded values for the specified modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def encode(
    self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
) -> torch.Tensor:
    """Encode the input values for the given modality.

    Parameters
    ----------
    inputs : dict[str, Any]
        Input values.
    modality : Modality
        The modality to encode.
    normalize : bool, optional, default=False
        Whether to apply L2 normalization to the output (after the head and
        postprocessor layers, if present).

    Returns
    -------
    torch.Tensor
        The encoded values for the specified modality.
    """
    output = self.encoders[modality.name](inputs)[0]

    if self.postprocessors and modality.name in self.postprocessors:
        output = self.postprocessors[modality.name](output)

    if self.heads and modality.name in self.heads:
        output = self.heads[modality.name](output)

    if normalize:
        output = torch.nn.functional.normalize(output, p=2, dim=-1)

    return output
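
A minimal usage sketch for encode, continuing the construction example above. It assumes "rgb" is a registered modality name, that the batch carries the raw input the corresponding encoder expects, and that Modalities is mmlearn's modality registry (its import path is not shown on this page).

import torch

batch = {"rgb": torch.randn(8, 2048)}
rgb_embeddings = task.encode(batch, Modalities.get_modality("rgb"), normalize=True)
# With normalize=True the embeddings have unit L2 norm along the last dimension.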

forward

forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input tensors to encode.

required

Returns:

Type Description
dict[str, Tensor]

The encodings for each modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input tensors to encode.

    Returns
    -------
    dict[str, torch.Tensor]
        The encodings for each modality.
    """
    outputs = {
        modality.embedding: self.encode(inputs, modality, normalize=True)
        for modality in self._available_modalities
        if modality.name in inputs
    }

    if not all(
        output.size(-1) == list(outputs.values())[0].size(-1)
        for output in outputs.values()
    ):
        raise ValueError("Expected all model outputs to have the same dimension.")

    return outputs
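
A companion sketch for forward, again continuing the example above: every modality present in the batch is encoded (with normalization) and the result is keyed by that modality's embedding attribute.

import torch

# Continuing the sketch above. The exact key names depend on each Modality's
# `embedding` attribute (e.g. something like "rgb_embedding").
batch = {"rgb": torch.randn(8, 2048), "text": torch.randn(8, 768)}
outputs = task(batch)  # calls forward
assert all(t.size(-1) == 512 for t in outputs.values())  # shared projection dim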

on_train_epoch_start

on_train_epoch_start()

Prepare for the training epoch.

This method sets the modules to training mode.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_train_epoch_start(self) -> None:
    """Prepare for the training epoch.

    This method sets the modules to training mode.
    """
    self.encoders.train()
    if self.heads:
        self.heads.train()
    if self.postprocessors:
        self.postprocessors.train()

training_step

training_step(batch, batch_idx)

Compute the loss for the batch.

Parameters:

Name Type Description Default
batch dict[str, Any]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Tensor

The loss for the batch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Compute the loss for the batch.

    Parameters
    ----------
    batch : dict[str, Any]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    torch.Tensor
        The loss for the batch.
    """
    outputs = self(batch)

    with torch.no_grad():
        self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

    loss = self._compute_loss(batch, batch_idx, outputs)

    if loss is None:
        raise ValueError("The loss function must be provided for training.")

    self.log("train/loss", loss, prog_bar=True, sync_dist=True)
    self.log(
        "train/logit_scale",
        self.log_logit_scale.exp(),
        prog_bar=True,
        on_step=True,
        on_epoch=False,
    )

    return loss
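
The temperature used in the contrastive loss is parameterized in log space: training_step clamps log_logit_scale to [0, log(max_logit_scale)] before computing the loss, so the effective multiplier exp(log_logit_scale) always lies in [1, max_logit_scale]. A small numeric illustration follows; the 1/0.07 initial value and the bound of 100 are CLIP-style numbers chosen purely for illustration, not values taken from this page.

import math

import torch

max_logit_scale = 100.0                             # illustrative bound
log_logit_scale = torch.tensor(math.log(1 / 0.07))  # illustrative init, ~2.659

with torch.no_grad():
    log_logit_scale.clamp_(0.0, math.log(max_logit_scale))

scale = log_logit_scale.exp()  # effective multiplier, guaranteed to lie in [1, 100]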

on_before_zero_grad

on_before_zero_grad(optimizer)

Call on_before_zero_grad on the auxiliary tasks, if any.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Zero out the gradients of the model."""
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_before_zero_grad(optimizer)

on_validation_epoch_start

on_validation_epoch_start()

Prepare for the validation epoch.

This method sets the modules to evaluation mode and calls the on_evaluation_epoch_start method of each evaluation task.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch.

    This method sets the modules to evaluation mode and calls the
    ``on_evaluation_epoch_start`` method of each evaluation task.
    """
    self._on_eval_epoch_start("val")

validation_step

validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def validation_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "val")

on_validation_epoch_end

on_validation_epoch_end()

Compute and log epoch-level metrics at the end of the validation epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the validation epoch."""
    self._on_eval_epoch_end("val")

on_test_epoch_start

on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")

test_step

test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def test_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "test")

on_test_epoch_end

on_test_epoch_end()

Compute and log epoch-level metrics at the end of the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the test epoch."""
    self._on_eval_epoch_end("test")

on_load_checkpoint

on_load_checkpoint(checkpoint)

Modify the model checkpoint after loading.

The on_load_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint after loading.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The loaded checkpoint.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the model checkpoint after loading.

    The `on_load_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint after loading.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The loaded checkpoint.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_load_checkpoint(checkpoint)

on_save_checkpoint

on_save_checkpoint(checkpoint)

Modify the checkpoint before saving.

The on_save_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint before saving.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The checkpoint to save.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the checkpoint before saving.

    The `on_save_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint before saving.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The checkpoint to save.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_save_checkpoint(checkpoint)

IJEPA

Bases: TrainingTask

Pretraining module for IJEPA.

This class implements the IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task using PyTorch Lightning. It trains an encoder and a predictor to reconstruct masked regions of an image based on its unmasked context.

Parameters:

Name Type Description Default
encoder VisionTransformer

Vision transformer encoder.

required
predictor VisionTransformerPredictor

Vision transformer predictor.

required
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a functools.partial function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
ema_decay float

Initial momentum for EMA of target encoder.

0.996
ema_decay_end float

Final momentum for EMA of target encoder.

1.0
ema_anneal_end_step int

Number of steps to anneal EMA momentum to ema_decay_end.

1000
loss_fn Optional[Callable[[Tensor, Tensor], Tensor]]

Loss function to use. If not provided, defaults to torch.nn.functional.smooth_l1_loss.

None
compute_validation_loss bool

Whether to compute validation loss.

True
compute_test_loss bool

Whether to compute test loss.

True
Source code in mmlearn/tasks/ijepa.py
@store(group="task", provider="mmlearn", zen_partial=False)
class IJEPA(TrainingTask):
    """Pretraining module for IJEPA.

    This class implements the IJEPA (Image Joint-Embedding Predictive Architecture)
    pretraining task using PyTorch Lightning. It trains an encoder and a predictor to
    reconstruct masked regions of an image based on its unmasked context.

    Parameters
    ----------
    encoder : VisionTransformer
        Vision transformer encoder.
    predictor : VisionTransformerPredictor
        Vision transformer predictor.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    ema_decay : float, optional, default=0.996
        Initial momentum for EMA of target encoder.
    ema_decay_end : float, optional, default=1.0
        Final momentum for EMA of target encoder.
    ema_anneal_end_step : int, optional, default=1000
        Number of steps to anneal EMA momentum to ``ema_decay_end``.
    loss_fn : Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional
        Loss function to use. If not provided, defaults to
        :py:func:`~torch.nn.functional.smooth_l1_loss`.
    compute_validation_loss : bool, optional, default=True
        Whether to compute validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute test loss.
    """  # noqa: W505

    def __init__(
        self,
        encoder: VisionTransformer,
        predictor: VisionTransformerPredictor,
        modality: str = "RGB",
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        ema_decay: float = 0.996,
        ema_decay_end: float = 1.0,
        ema_anneal_end_step: int = 1000,
        loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
    ):
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss_fn if loss_fn is not None else F.smooth_l1_loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )
        self.modality = Modalities.get_modality(modality)
        self.mask_generator = IJEPAMaskGenerator()

        self.encoder = encoder
        self.predictor = predictor

        self.predictor.num_patches = encoder.patch_embed.num_patches
        self.predictor.embed_dim = encoder.embed_dim
        self.predictor.num_heads = encoder.num_heads

        self.target_encoder = ExponentialMovingAverage(
            self.encoder, ema_decay, ema_decay_end, ema_anneal_end_step
        )

    def configure_model(self) -> None:
        """Configure the model."""
        self.target_encoder.configure_model(self.device)

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Perform exponential moving average update of target encoder.

        This is done right after the ``optimizer.step()``, which comes just before
        ``optimizer.zero_grad()`` to account for gradient accumulation.
        """
        if self.target_encoder is not None:
            self.target_encoder.step(self.encoder)

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Perform a single training step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        torch.Tensor
            Loss value.
        """
        return self._shared_step(batch, batch_idx, step_type="train")

    def validation_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed.
        """
        return self._shared_step(batch, batch_idx, step_type="val")

    def test_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed.
        """
        return self._shared_step(batch, batch_idx, step_type="test")

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch."""
        self._on_eval_epoch_start("val")

    def on_validation_epoch_end(self) -> None:
        """Actions at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def on_test_epoch_end(self) -> None:
        """Actions at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Add relevant EMA state to the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to save the EMA state to.
        """
        if self.target_encoder is not None:
            checkpoint["ema_params"] = {
                "decay": self.target_encoder.decay,
                "num_updates": self.target_encoder.num_updates,
            }

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Restore EMA state from the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to restore the EMA state from.
        """
        if "ema_params" in checkpoint and self.target_encoder is not None:
            ema_params = checkpoint.pop("ema_params")
            self.target_encoder.decay = ema_params["decay"]
            self.target_encoder.num_updates = ema_params["num_updates"]

            self.target_encoder.restore(self.encoder)

    def _shared_step(
        self, batch: dict[str, Any], batch_idx: int, step_type: str
    ) -> Optional[torch.Tensor]:
        images = batch[self.modality.name]

        # Generate masks
        batch_size = images.size(0)
        mask_info = self.mask_generator(batch_size=batch_size)

        # Extract masks and move to device
        device = images.device
        encoder_masks = [mask.to(device) for mask in mask_info["encoder_masks"]]
        predictor_masks = [mask.to(device) for mask in mask_info["predictor_masks"]]

        # Forward pass through target encoder to get h
        with torch.no_grad():
            h = self.target_encoder.model(batch)[0]
            h = F.layer_norm(h, h.size()[-1:])
            h_masked = apply_masks(h, predictor_masks)
            h_masked = repeat_interleave_batch(
                h_masked, images.size(0), repeat=len(encoder_masks)
            )

        # Forward pass through encoder with encoder_masks
        batch[self.modality.mask] = encoder_masks
        z = self.encoder(batch)[0]

        # Pass z through predictor with encoder_masks and predictor_masks
        z_pred = self.predictor(z, encoder_masks, predictor_masks)

        if step_type == "train":
            self.log("train/ema_decay", self.target_encoder.decay, prog_bar=True)

        if self.loss_fn is not None and (
            step_type == "train"
            or (step_type == "val" and self.compute_validation_loss)
            or (step_type == "test" and self.compute_test_loss)
        ):
            # Compute loss between z_pred and h_masked
            loss = self.loss_fn(z_pred, h_masked)

            # Log loss
            self.log(f"{step_type}/loss", loss, prog_bar=True, sync_dist=True)

            return loss

        return None

    def _on_eval_epoch_start(self, step_type: str) -> None:
        """Initialize states or configurations at the start of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (step_type == "val" and self.compute_validation_loss) or (
            step_type == "test" and self.compute_test_loss
        ):
            self.log(f"{step_type}/start", 1, prog_bar=True, sync_dist=True)

    def _on_eval_epoch_end(self, step_type: str) -> None:
        """Finalize states or logging at the end of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (step_type == "val" and self.compute_validation_loss) or (
            step_type == "test" and self.compute_test_loss
        ):
            self.log(f"{step_type}/end", 1, prog_bar=True, sync_dist=True)

configure_model

configure_model()

Configure the model.

Source code in mmlearn/tasks/ijepa.py
def configure_model(self) -> None:
    """Configure the model."""
    self.target_encoder.configure_model(self.device)

on_before_zero_grad

on_before_zero_grad(optimizer)

Perform exponential moving average update of target encoder.

This is done right after the optimizer.step(), which comes just before optimizer.zero_grad() to account for gradient accumulation.

Source code in mmlearn/tasks/ijepa.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Perform exponential moving average update of target encoder.

    This is done right after the ``optimizer.step()``, which comes just before
    ``optimizer.zero_grad()`` to account for gradient accumulation.
    """
    if self.target_encoder is not None:
        self.target_encoder.step(self.encoder)
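
The ExponentialMovingAverage implementation itself is not shown on this page. As a rough sketch only, a generic EMA step of the kind used for target encoders looks like the following, where decay would be annealed from ema_decay to ema_decay_end over ema_anneal_end_step updates.

import torch

# Rough sketch of a generic EMA update, not the actual ExponentialMovingAverage
# class used above (its source is not shown on this page).
@torch.no_grad()
def ema_step(target: torch.nn.Module, online: torch.nn.Module, decay: float) -> None:
    for p_target, p_online in zip(target.parameters(), online.parameters()):
        p_target.mul_(decay).add_(p_online, alpha=1.0 - decay)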

training_step

training_step(batch, batch_idx)

Perform a single training step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Tensor

Loss value.

Source code in mmlearn/tasks/ijepa.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Perform a single training step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    torch.Tensor
        Loss value.
    """
    return self._shared_step(batch, batch_idx, step_type="train")

validation_step

validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def validation_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="val")

test_step

test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def test_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="test")

on_validation_epoch_start

on_validation_epoch_start()

Prepare for the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch."""
    self._on_eval_epoch_start("val")

on_validation_epoch_end

on_validation_epoch_end()

Actions at the end of the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_end(self) -> None:
    """Actions at the end of the validation epoch."""
    self._on_eval_epoch_end("val")

on_test_epoch_start

on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")

on_test_epoch_end

on_test_epoch_end()

Actions at the end of the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_end(self) -> None:
    """Actions at the end of the test epoch."""
    self._on_eval_epoch_end("test")

on_save_checkpoint

on_save_checkpoint(checkpoint)

Add relevant EMA state to the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to save the EMA state to.

required
Source code in mmlearn/tasks/ijepa.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Add relevant EMA state to the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to save the EMA state to.
    """
    if self.target_encoder is not None:
        checkpoint["ema_params"] = {
            "decay": self.target_encoder.decay,
            "num_updates": self.target_encoder.num_updates,
        }

on_load_checkpoint

on_load_checkpoint(checkpoint)

Restore EMA state from the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to restore the EMA state from.

required
Source code in mmlearn/tasks/ijepa.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Restore EMA state from the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to restore the EMA state from.
    """
    if "ema_params" in checkpoint and self.target_encoder is not None:
        ema_params = checkpoint.pop("ema_params")
        self.target_encoder.decay = ema_params["decay"]
        self.target_encoder.num_updates = ema_params["num_updates"]

        self.target_encoder.restore(self.encoder)

ZeroShotClassification

Bases: EvaluationHooks

Zero-shot classification evaluation task.

This task evaluates the zero-shot classification performance.

Parameters:

Name Type Description Default
task_specs list[ClassificationTaskSpec]

A list of classification task specifications.

required
tokenizer Callable[[Union[str, list[str]]], Union[Tensor, dict[str, Tensor]]]

A function to tokenize text inputs.

required
Source code in mmlearn/tasks/zero_shot_classification.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotClassification(EvaluationHooks):
    """Zero-shot classification evaluation task.

    This task evaluates the zero-shot classification performance.

    Parameters
    ----------
    task_specs : list[ClassificationTaskSpec]
        A list of classification task specifications.
    tokenizer : Callable[[Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]]
        A function to tokenize text inputs.
    """  # noqa: W505

    def __init__(
        self,
        task_specs: list[ClassificationTaskSpec],
        tokenizer: Callable[
            [Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]
        ],
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.task_specs = task_specs
        for spec in self.task_specs:
            assert Modalities.has_modality(spec.query_modality)

        self.metrics: dict[tuple[str, int], MetricCollection] = {}
        self._embeddings_store: dict[int, torch.Tensor] = {}

    def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
        """Set up the evaluation task.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Raises
        ------
        ValueError
            - If the task is not being run for validation or testing.
            - If the dataset does not have the required attributes to perform zero-shot
              classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
        """
        if pl_module.trainer.validating:
            eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
        elif pl_module.trainer.testing:
            eval_dataset = pl_module.trainer.test_dataloaders.dataset
        else:
            raise ValueError(
                "ZeroShotClassification task is only supported for validation and testing."
            )

        self.all_dataset_info = {}

        # create metrics for each dataset/query_modality combination
        if not self.metrics:
            for dataset_index, dataset in enumerate(eval_dataset.datasets):
                dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
                try:
                    id2label: dict[int, str] = dataset.id2label
                except AttributeError:
                    raise ValueError(
                        f"Dataset '{dataset_name}' must have a `id2label` attribute "
                        "to perform zero-shot classification."
                    ) from None

                try:
                    zero_shot_prompt_templates: list[str] = (
                        dataset.zero_shot_prompt_templates
                    )
                except AttributeError:
                    raise ValueError(
                        "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                    ) from None

                num_classes = len(id2label)

                self.all_dataset_info[dataset_index] = {
                    "name": dataset_name,
                    "id2label": id2label,
                    "prompt_templates": zero_shot_prompt_templates,
                    "num_classes": num_classes,
                }

                for spec in self.task_specs:
                    query_modality = Modalities.get_modality(spec.query_modality).name
                    self.metrics[(query_modality, dataset_index)] = (
                        self._create_metrics(
                            num_classes,
                            spec.top_k,
                            prefix=f"{dataset_name}/{query_modality}_",
                            postfix="",
                        )
                    )

        for metric in self.metrics.values():
            metric.to(pl_module.device)

        for dataset_index, dataset_info in self.all_dataset_info.items():
            id2label = dataset_info["id2label"]
            prompt_templates: list[str] = dataset_info["prompt_templates"]
            labels = list(id2label.values())

            with torch.no_grad():
                chunk_size = 10
                all_embeddings = []

                for i in tqdm(
                    range(0, len(labels), chunk_size),
                    desc="Encoding class descriptions",
                ):
                    batch_labels = labels[i : min(i + chunk_size, len(labels))]
                    descriptions = [
                        template.format(label)
                        for label in batch_labels
                        for template in prompt_templates
                    ]
                    tokenized_descriptions = move_data_to_device(
                        self.tokenizer(descriptions),
                        pl_module.device,
                    )

                    # Encode the chunk using the pl_module's encode method
                    chunk_embeddings = pl_module.encode(
                        tokenized_descriptions, Modalities.TEXT
                    )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                    chunk_embeddings = chunk_embeddings.reshape(
                        len(batch_labels), len(prompt_templates), -1
                    ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                    # Append the chunk embeddings to the list
                    all_embeddings.append(chunk_embeddings)

                # Concatenate all chunk embeddings into a single tensor
                class_embeddings = torch.cat(all_embeddings, dim=0)

            self._embeddings_store[dataset_index] = class_embeddings

    def evaluation_step(
        self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> None:
        """Compute logits and update metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A batch of data.
        batch_idx : int
            The index of the batch.
        """
        if pl_module.trainer.sanity_checking:
            return

        for (query_modality, dataset_index), metric_collection in self.metrics.items():
            matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

            if not matching_indices.numel():
                continue

            class_embeddings = self._embeddings_store[dataset_index]
            query_embeddings: torch.Tensor = pl_module.encode(
                batch, Modalities.get_modality(query_modality)
            )
            query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
            query_embeddings = query_embeddings[matching_indices]

            if self.all_dataset_info[dataset_index]["num_classes"] == 2:
                softmax_output = _safe_matmul(
                    query_embeddings, class_embeddings
                ).softmax(dim=-1)
                logits = softmax_output[:, 1] - softmax_output[:, 0]
            else:
                logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
            targets = batch[Modalities.get_modality(query_modality).target][
                matching_indices
            ]

            metric_collection.update(logits, targets)

    def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
        """Compute and reset metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        dict[str, Any]
            The computed metrics.
        """
        results = {}
        for metric_collection in self.metrics.values():
            results.update(metric_collection.compute())
            metric_collection.reset()

        self._embeddings_store.clear()

        eval_type = "val" if pl_module.trainer.validating else "test"
        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results

    @staticmethod
    def _create_metrics(
        num_classes: int, top_k: list[int], prefix: str, postfix: str
    ) -> MetricCollection:
        """Create a collection of classification metrics."""
        task_type = "binary" if num_classes == 2 else "multiclass"
        acc_metrics = (
            {
                f"top{k}_accuracy": Accuracy(
                    task=task_type, num_classes=num_classes, top_k=k, average="micro"
                )
                for k in top_k
            }
            if num_classes > 2
            else {"accuracy": Accuracy(task=task_type, num_classes=num_classes)}
        )
        return MetricCollection(
            {
                "precision": Precision(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "recall": Recall(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "f1_score_macro": F1Score(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "aucroc": AUROC(task=task_type, num_classes=num_classes),
                **acc_metrics,
            },
            prefix=prefix,
            postfix=postfix,
            compute_groups=True,
        )
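
A minimal construction sketch, assuming ClassificationTaskSpec is exposed from the same module (with query_modality and top_k fields, as used in the code above), that "rgb" is a registered modality name, and using a Hugging Face tokenizer as one possible tokenizer callable.

# Assumed import path; adjust to the actual package layout.
from mmlearn.tasks.zero_shot_classification import (
    ClassificationTaskSpec,
    ZeroShotClassification,
)
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

eval_task = ZeroShotClassification(
    task_specs=[ClassificationTaskSpec(query_modality="rgb", top_k=[1, 5])],
    # Any callable mapping a string or list of strings to a tensor or a dict
    # of tensors will do; a Hugging Face tokenizer is used here as an example.
    tokenizer=lambda text: hf_tokenizer(text, padding=True, return_tensors="pt"),
)

To run it during validation or testing, the instance would be wrapped in an EvaluationSpec (whose task, run_on_validation, and run_on_test fields are consulted by the contrastive pretraining task) and passed via its evaluation_tasks argument.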

on_evaluation_epoch_start

on_evaluation_epoch_start(pl_module)

Set up the evaluation task.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Raises:

Type Description
ValueError
  • If the task is not being run for validation or testing.
  • If the dataset does not have the required attributes to perform zero-shot classification (i.e., id2label and zero_shot_prompt_templates).
Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
    """Set up the evaluation task.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Raises
    ------
    ValueError
        - If the task is not being run for validation or testing.
        - If the dataset does not have the required attributes to perform zero-shot
          classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
    """
    if pl_module.trainer.validating:
        eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
    elif pl_module.trainer.testing:
        eval_dataset = pl_module.trainer.test_dataloaders.dataset
    else:
        raise ValueError(
            "ZeroShotClassification task is only supported for validation and testing."
        )

    self.all_dataset_info = {}

    # create metrics for each dataset/query_modality combination
    if not self.metrics:
        for dataset_index, dataset in enumerate(eval_dataset.datasets):
            dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
            try:
                id2label: dict[int, str] = dataset.id2label
            except AttributeError:
                raise ValueError(
                    f"Dataset '{dataset_name}' must have a `id2label` attribute "
                    "to perform zero-shot classification."
                ) from None

            try:
                zero_shot_prompt_templates: list[str] = (
                    dataset.zero_shot_prompt_templates
                )
            except AttributeError:
                raise ValueError(
                    "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                ) from None

            num_classes = len(id2label)

            self.all_dataset_info[dataset_index] = {
                "name": dataset_name,
                "id2label": id2label,
                "prompt_templates": zero_shot_prompt_templates,
                "num_classes": num_classes,
            }

            for spec in self.task_specs:
                query_modality = Modalities.get_modality(spec.query_modality).name
                self.metrics[(query_modality, dataset_index)] = (
                    self._create_metrics(
                        num_classes,
                        spec.top_k,
                        prefix=f"{dataset_name}/{query_modality}_",
                        postfix="",
                    )
                )

    for metric in self.metrics.values():
        metric.to(pl_module.device)

    for dataset_index, dataset_info in self.all_dataset_info.items():
        id2label = dataset_info["id2label"]
        prompt_templates: list[str] = dataset_info["prompt_templates"]
        labels = list(id2label.values())

        with torch.no_grad():
            chunk_size = 10
            all_embeddings = []

            for i in tqdm(
                range(0, len(labels), chunk_size),
                desc="Encoding class descriptions",
            ):
                batch_labels = labels[i : min(i + chunk_size, len(labels))]
                descriptions = [
                    template.format(label)
                    for label in batch_labels
                    for template in prompt_templates
                ]
                tokenized_descriptions = move_data_to_device(
                    self.tokenizer(descriptions),
                    pl_module.device,
                )

                # Encode the chunk using the pl_module's encode method
                chunk_embeddings = pl_module.encode(
                    tokenized_descriptions, Modalities.TEXT
                )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                chunk_embeddings = chunk_embeddings.reshape(
                    len(batch_labels), len(prompt_templates), -1
                ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                # Append the chunk embeddings to the list
                all_embeddings.append(chunk_embeddings)

            # Concatenate all chunk embeddings into a single tensor
            class_embeddings = torch.cat(all_embeddings, dim=0)

        self._embeddings_store[dataset_index] = class_embeddings
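
The requirements above amount to two attributes on each dataset wrapped by the combined evaluation dataset: an id2label mapping and a list of zero_shot_prompt_templates format strings, each of which is filled in via template.format(label). The sketch below is a minimal, hypothetical dataset that satisfies both; the class name, labels, and templates are illustrative only and not part of mmlearn.

from typing import Any

from torch.utils.data import Dataset


class ToyZeroShotDataset(Dataset):
    """Hypothetical dataset exposing the attributes ZeroShotClassification expects."""

    # Mapping from class index to a human-readable label.
    id2label: dict[int, str] = {0: "cat", 1: "dog"}

    # Prompt templates; each is filled in with a label via ``template.format(label)``.
    zero_shot_prompt_templates: list[str] = [
        "a photo of a {}.",
        "a blurry photo of a {}.",
    ]

    def __len__(self) -> int:
        return 0  # a real dataset would return its number of examples

    def __getitem__(self, index: int) -> Any:
        raise NotImplementedError  # a real dataset would return an example here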

evaluation_step

evaluation_step(pl_module, batch, batch_idx)

Compute logits and update metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A batch of data.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_classification.py
def evaluation_step(
    self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
) -> None:
    """Compute logits and update metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A batch of data.
    batch_idx : int
        The index of the batch.
    """
    if pl_module.trainer.sanity_checking:
        return

    for (query_modality, dataset_index), metric_collection in self.metrics.items():
        matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

        if not matching_indices.numel():
            continue

        class_embeddings = self._embeddings_store[dataset_index]
        query_embeddings: torch.Tensor = pl_module.encode(
            batch, Modalities.get_modality(query_modality)
        )
        query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
        query_embeddings = query_embeddings[matching_indices]

        if self.all_dataset_info[dataset_index]["num_classes"] == 2:
            softmax_output = _safe_matmul(
                query_embeddings, class_embeddings
            ).softmax(dim=-1)
            logits = softmax_output[:, 1] - softmax_output[:, 0]
        else:
            logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
        targets = batch[Modalities.get_modality(query_modality).target][
            matching_indices
        ]

        metric_collection.update(logits, targets)

on_evaluation_epoch_end

on_evaluation_epoch_end(pl_module)

Compute and reset metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
dict[str, Any]

The computed metrics.

Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
    """Compute and reset metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    dict[str, Any]
        The computed metrics.
    """
    results = {}
    for metric_collection in self.metrics.values():
        results.update(metric_collection.compute())
        metric_collection.reset()

    self._embeddings_store.clear()

    eval_type = "val" if pl_module.trainer.validating else "test"
    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results

ZeroShotCrossModalRetrieval

Bases: EvaluationHooks

Zero-shot cross-modal retrieval evaluation task.

This task evaluates the retrieval performance of a model on a set of query-target pairs. The model is expected to produce embeddings for both the query and target modalities. The task computes the retrieval recall at k for each pair of modalities.

Parameters:

Name Type Description Default
task_specs list[RetrievalTaskSpec]

A list of retrieval task specifications. Each specification defines the query and target modalities, as well as the top-k values for which to compute the retrieval recall metrics.

required
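
A minimal construction sketch is shown below. It assumes RetrievalTaskSpec is defined alongside this class in mmlearn/tasks/zero_shot_retrieval.py and that "text" and "rgb" are registered modality names; adjust both to your setup.

from mmlearn.tasks.zero_shot_retrieval import (  # RetrievalTaskSpec location assumed
    RetrievalTaskSpec,
    ZeroShotCrossModalRetrieval,
)

# Text-to-image and image-to-text retrieval recall at k = 1, 5 and 10.
retrieval_task = ZeroShotCrossModalRetrieval(
    task_specs=[
        RetrievalTaskSpec(query_modality="text", target_modality="rgb", top_k=[1, 5, 10]),
        RetrievalTaskSpec(query_modality="rgb", target_modality="text", top_k=[1, 5, 10]),
    ]
)

With this configuration, the metrics are logged during evaluation under names such as val/text_to_rgb_R@1, following the naming used in the source below.
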
Source code in mmlearn/tasks/zero_shot_retrieval.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotCrossModalRetrieval(EvaluationHooks):
    """Zero-shot cross-modal retrieval evaluation task.

    This task evaluates the retrieval performance of a model on a set of query-target
    pairs. The model is expected to produce embeddings for both the query and target
    modalities. The task computes the retrieval recall at `k` for each pair of
    modalities.

    Parameters
    ----------
    task_specs : list[RetrievalTaskSpec]
        A list of retrieval task specifications. Each specification defines the query
        and target modalities, as well as the top-k values for which to compute the
        retrieval recall metrics.

    """

    def __init__(self, task_specs: list[RetrievalTaskSpec]) -> None:
        super().__init__()

        self.task_specs = task_specs
        self.metrics: dict[tuple[str, str], MetricCollection] = {}
        self._available_modalities = set()

        for spec in self.task_specs:
            query_modality = spec.query_modality
            target_modality = spec.target_modality
            assert Modalities.has_modality(query_modality)
            assert Modalities.has_modality(target_modality)

            self.metrics[(query_modality, target_modality)] = MetricCollection(
                {
                    f"{query_modality}_to_{target_modality}_R@{k}": RetrievalRecallAtK(
                        top_k=k, aggregation="mean", reduction="none"
                    )
                    for k in spec.top_k
                }
            )
            self._available_modalities.add(query_modality)
            self._available_modalities.add(target_modality)

    def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
        """Move the metrics to the device of the Lightning module."""
        for metric in self.metrics.values():
            metric.to(pl_module.device)

    def evaluation_step(
        self,
        pl_module: pl.LightningModule,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
    ) -> None:
        """Run the forward pass and update retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A dictionary of batched input tensors.
        batch_idx : int
            The index of the batch.

        """
        if pl_module.trainer.sanity_checking:
            return

        outputs: dict[str, Any] = {}
        for modality_name in self._available_modalities:
            if modality_name in batch:
                outputs[modality_name] = pl_module.encode(
                    batch, Modalities.get_modality(modality_name), normalize=False
                )
        for (query_modality, target_modality), metric in self.metrics.items():
            if query_modality not in outputs or target_modality not in outputs:
                continue
            query_embeddings: torch.Tensor = outputs[query_modality]
            target_embeddings: torch.Tensor = outputs[target_modality]
            indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

            metric.update(query_embeddings, target_embeddings, indexes)

    def on_evaluation_epoch_end(
        self, pl_module: pl.LightningModule
    ) -> Optional[dict[str, Any]]:
        """Compute the retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        Optional[dict[str, Any]]
            A dictionary of evaluation results or `None` if no results are available.
        """
        if pl_module.trainer.sanity_checking:
            return None

        results = {}
        for metric in self.metrics.values():
            results.update(metric.compute())
            metric.reset()

        eval_type = "val" if pl_module.trainer.validating else "test"

        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results

on_evaluation_epoch_start

on_evaluation_epoch_start(pl_module)

Move the metrics to the device of the Lightning module.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
    """Move the metrics to the device of the Lightning module."""
    for metric in self.metrics.values():
        metric.to(pl_module.device)

evaluation_step

evaluation_step(pl_module, batch, batch_idx)

Run the forward pass and update retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A dictionary of batched input tensors.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_retrieval.py
def evaluation_step(
    self,
    pl_module: pl.LightningModule,
    batch: dict[str, torch.Tensor],
    batch_idx: int,
) -> None:
    """Run the forward pass and update retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A dictionary of batched input tensors.
    batch_idx : int
        The index of the batch.

    """
    if pl_module.trainer.sanity_checking:
        return

    outputs: dict[str, Any] = {}
    for modality_name in self._available_modalities:
        if modality_name in batch:
            outputs[modality_name] = pl_module.encode(
                batch, Modalities.get_modality(modality_name), normalize=False
            )
    for (query_modality, target_modality), metric in self.metrics.items():
        if query_modality not in outputs or target_modality not in outputs:
            continue
        query_embeddings: torch.Tensor = outputs[query_modality]
        target_embeddings: torch.Tensor = outputs[target_modality]
        indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

        metric.update(query_embeddings, target_embeddings, indexes)

on_evaluation_epoch_end

on_evaluation_epoch_end(pl_module)

Compute the retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
Optional[dict[str, Any]]

A dictionary of evaluation results or None if no results are available.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_end(
    self, pl_module: pl.LightningModule
) -> Optional[dict[str, Any]]:
    """Compute the retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    Optional[dict[str, Any]]
        A dictionary of evaluation results or `None` if no results are available.
    """
    if pl_module.trainer.sanity_checking:
        return None

    results = {}
    for metric in self.metrics.values():
        results.update(metric.compute())
        metric.reset()

    eval_type = "val" if pl_module.trainer.validating else "test"

    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results

base

Base class for all tasks in mmlearn that require training.

TrainingTask

Bases: LightningModule

Base class for all tasks in mmlearn that require training.

Parameters:

Name Type Description Default
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a partial function, created using functools.partial, that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a partial function that takes the optimizer as the only required argument or a dictionary with a scheduler key that specifies the scheduler and an optional extras key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
loss_fn Optional[Module]

Loss function to use for training.

None
compute_validation_loss bool

Whether to compute the validation loss if a validation dataloader is provided. The loss function must be provided to compute the validation loss.

True
compute_test_loss bool

Whether to compute the test loss if a test dataloader is provided. The loss function must be provided to compute the test loss.

True

Raises:

Type Description
ValueError

If the loss function is not provided and either the validation or test loss needs to be computed.
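
As a minimal illustration of the optimizer and lr_scheduler conventions described above, the values below could be passed to any TrainingTask subclass; the learning rate, scheduler choice, and "extras" settings are placeholders, not recommended defaults.

import functools

import torch

# Optimizer: a partial that only needs the model parameters at call time.
optimizer = functools.partial(torch.optim.AdamW, lr=1e-4, weight_decay=0.05)

# Scheduler: either a bare partial taking the optimizer, or a dictionary with a
# "scheduler" key and an optional "extras" key (forwarded to Lightning's
# lr_scheduler config, e.g. to step the scheduler every training step).
lr_scheduler = {
    "scheduler": functools.partial(
        torch.optim.lr_scheduler.CosineAnnealingLR, T_max=10_000
    ),
    "extras": {"interval": "step"},
}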

Source code in mmlearn/tasks/base.py
class TrainingTask(L.LightningModule):
    """Base class for all tasks in mmlearn that require training.

    Parameters
    ----------
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a partial function,
        created using `functools.partial`, that takes the model parameters as the
        only required argument. If not provided, training will continue without an
        optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a partial function
        that takes the optimizer as the only required argument or a dictionary with
        a `scheduler` key that specifies the scheduler and an optional `extras` key
        that specifies additional arguments to pass to the scheduler. If not provided,
        the learning rate will not be adjusted during training.
    loss_fn : Optional[torch.nn.Module], optional, default=None
        Loss function to use for training.
    compute_validation_loss : bool, optional, default=True
        Whether to compute the validation loss if a validation dataloader is provided.
        The loss function must be provided to compute the validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute the test loss if a test dataloader is provided. The loss
        function must be provided to compute the test loss.

    Raises
    ------
    ValueError
        If the loss function is not provided and either the validation or test loss
        needs to be computed.
    """  # noqa: W505

    def __init__(
        self,
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        loss_fn: Optional[torch.nn.Module] = None,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
    ):
        super().__init__()
        if loss_fn is None and (compute_validation_loss or compute_test_loss):
            raise ValueError(
                "Loss function must be provided to compute validation or test loss."
            )

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.loss_fn = loss_fn
        self.compute_validation_loss = compute_validation_loss
        self.compute_test_loss = compute_test_loss

    def configure_optimizers(self) -> OptimizerLRScheduler:  # noqa: PLR0912
        """Configure the optimizer and learning rate scheduler."""
        if self.optimizer is None:
            rank_zero_warn(
                "Optimizer not provided. Training will continue without an optimizer. "
                "LR scheduler will not be used.",
            )
            return None

        weight_decay: Optional[float] = self.optimizer.keywords.get(
            "weight_decay", None
        )
        if weight_decay is None:  # try getting default value
            kw_param = inspect.signature(self.optimizer.func).parameters.get(
                "weight_decay"
            )
            if kw_param is not None and kw_param.default != inspect.Parameter.empty:
                weight_decay = kw_param.default

        parameters = [param for param in self.parameters() if param.requires_grad]

        if weight_decay is not None:
            decay_params = []
            no_decay_params = []

            for param in self.parameters():
                if not param.requires_grad:
                    continue

                if param.ndim < 2:  # includes all bias and normalization parameters
                    no_decay_params.append(param)
                else:
                    decay_params.append(param)

            parameters = [
                {
                    "params": decay_params,
                    "weight_decay": weight_decay,
                    "name": "weight_decay_params",
                },
                {
                    "params": no_decay_params,
                    "weight_decay": 0.0,
                    "name": "no_weight_decay_params",
                },
            ]

        optimizer = self.optimizer(parameters)
        if not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError(
                "Expected optimizer to be an instance of `torch.optim.Optimizer`, "
                f"but got {type(optimizer)}.",
            )

        if self.lr_scheduler is not None:
            if isinstance(self.lr_scheduler, dict):
                if "scheduler" not in self.lr_scheduler:
                    raise ValueError(
                        "Expected 'scheduler' key in the learning rate scheduler dictionary.",
                    )

                lr_scheduler = self.lr_scheduler["scheduler"](optimizer)
                if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):
                    raise TypeError(
                        "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, "
                        f"but got {type(lr_scheduler)}.",
                    )
                lr_scheduler_dict: dict[
                    str, Union[torch.optim.lr_scheduler.LRScheduler, Any]
                ] = {"scheduler": lr_scheduler}

                if self.lr_scheduler.get("extras"):
                    lr_scheduler_dict.update(self.lr_scheduler["extras"])
                return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}

            lr_scheduler = self.lr_scheduler(optimizer)
            if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):
                raise TypeError(
                    "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, "
                    f"but got {type(lr_scheduler)}.",
                )
            return [optimizer], [lr_scheduler]

        return optimizer

configure_optimizers

configure_optimizers()

Configure the optimizer and learning rate scheduler.

Source code in mmlearn/tasks/base.py
def configure_optimizers(self) -> OptimizerLRScheduler:  # noqa: PLR0912
    """Configure the optimizer and learning rate scheduler."""
    if self.optimizer is None:
        rank_zero_warn(
            "Optimizer not provided. Training will continue without an optimizer. "
            "LR scheduler will not be used.",
        )
        return None

    weight_decay: Optional[float] = self.optimizer.keywords.get(
        "weight_decay", None
    )
    if weight_decay is None:  # try getting default value
        kw_param = inspect.signature(self.optimizer.func).parameters.get(
            "weight_decay"
        )
        if kw_param is not None and kw_param.default != inspect.Parameter.empty:
            weight_decay = kw_param.default

    parameters = [param for param in self.parameters() if param.requires_grad]

    if weight_decay is not None:
        decay_params = []
        no_decay_params = []

        for param in self.parameters():
            if not param.requires_grad:
                continue

            if param.ndim < 2:  # includes all bias and normalization parameters
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        parameters = [
            {
                "params": decay_params,
                "weight_decay": weight_decay,
                "name": "weight_decay_params",
            },
            {
                "params": no_decay_params,
                "weight_decay": 0.0,
                "name": "no_weight_decay_params",
            },
        ]

    optimizer = self.optimizer(parameters)
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise TypeError(
            "Expected optimizer to be an instance of `torch.optim.Optimizer`, "
            f"but got {type(optimizer)}.",
        )

    if self.lr_scheduler is not None:
        if isinstance(self.lr_scheduler, dict):
            if "scheduler" not in self.lr_scheduler:
                raise ValueError(
                    "Expected 'scheduler' key in the learning rate scheduler dictionary.",
                )

            lr_scheduler = self.lr_scheduler["scheduler"](optimizer)
            if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):
                raise TypeError(
                    "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, "
                    f"but got {type(lr_scheduler)}.",
                )
            lr_scheduler_dict: dict[
                str, Union[torch.optim.lr_scheduler.LRScheduler, Any]
            ] = {"scheduler": lr_scheduler}

            if self.lr_scheduler.get("extras"):
                lr_scheduler_dict.update(self.lr_scheduler["extras"])
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}

        lr_scheduler = self.lr_scheduler(optimizer)
        if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):
            raise TypeError(
                "Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, "
                f"but got {type(lr_scheduler)}.",
            )
        return [optimizer], [lr_scheduler]

    return optimizer
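
Note that when a weight_decay value is found (either passed to the partial or taken from the optimizer's default), the method above places every trainable parameter with fewer than two dimensions, which covers biases, normalization weights, and scalars such as the logit scale, into a group with weight_decay=0.0. The standalone sketch below reproduces the same grouping rule outside of mmlearn.

import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 4), nn.LayerNorm(4))

decay_params, no_decay_params = [], []
for param in model.parameters():
    if not param.requires_grad:
        continue
    # ndim < 2 covers biases and normalization weights; weight matrices decay.
    (no_decay_params if param.ndim < 2 else decay_params).append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": 0.05},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=1e-4,
)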

contrastive_pretraining

Contrastive pretraining task.

ModuleKeySpec dataclass

Module key specification for mapping modules to modalities.

Source code in mmlearn/tasks/contrastive_pretraining.py
@dataclass
class ModuleKeySpec:
    """Module key specification for mapping modules to modalities."""

    #: The key of the encoder module. If not provided, the modality name is used.
    encoder_key: Optional[str] = None

    #: The key of the head module. If not provided, the modality name is used.
    head_key: Optional[str] = None

    #: The key of the postprocessor module. If not provided, the modality name is used.
    postprocessor_key: Optional[str] = None

LossPairSpec dataclass

Specification for a pair of modalities to compute the contrastive loss.

Source code in mmlearn/tasks/contrastive_pretraining.py
@dataclass
class LossPairSpec:
    """Specification for a pair of modalities to compute the contrastive loss."""

    #: The pair of modalities to compute the contrastive loss between.
    modalities: tuple[str, str]

    #: The weight to apply to the contrastive loss for the pair of modalities.
    weight: float = 1.0
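
A short, hypothetical example of building a list of loss pairs with this dataclass; the modality names must be registered with Modalities and are assumptions here.

from mmlearn.tasks.contrastive_pretraining import LossPairSpec

# Weight the text<->rgb contrastive term twice as much as the text<->audio term.
modality_loss_pairs = [
    LossPairSpec(modalities=("text", "rgb"), weight=2.0),
    LossPairSpec(modalities=("text", "audio"), weight=1.0),
]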

AuxiliaryTaskSpec dataclass

Specification for an auxiliary task to run alongside the main task.

Source code in mmlearn/tasks/contrastive_pretraining.py
@dataclass
class AuxiliaryTaskSpec:
    """Specification for an auxiliary task to run alongside the main task."""

    #: The modality of the encoder to use for the auxiliary task.
    modality: str

    #: The auxiliary task module. This is expected to be a partially-initialized
    #: instance of a :py:class:`~lightning.pytorch.core.LightningModule` created
    #: using :py:func:`functools.partial`, such that an initialized encoder can be
    #: passed as the only argument.
    task: Any  # `functools.partial[L.LightningModule]` expected

    #: The weight to apply to the auxiliary task loss.
    loss_weight: float = 1.0

EvaluationSpec dataclass

Specification for an evaluation task.

Source code in mmlearn/tasks/contrastive_pretraining.py
@dataclass
class EvaluationSpec:
    """Specification for an evaluation task."""

    #: The evaluation task module. This is expected to be an instance of
    #: :py:class:`~mmlearn.tasks.hooks.EvaluationHooks`.
    task: Any  # `EvaluationHooks` expected

    #: Whether to run the evaluation task during validation.
    run_on_validation: bool = True

    #: Whether to run the evaluation task during testing.
    run_on_test: bool = True

ContrastivePretraining

Bases: TrainingTask

Contrastive pretraining task.

This class supports contrastive pretraining with N modalities of data. It allows the sharing of encoders, heads, and postprocessors across modalities. It also supports computing the contrastive loss between specified pairs of modalities, as well as training auxiliary tasks alongside the main contrastive pretraining task.
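
A condensed sketch of how these pieces fit together for two modalities is shown below; the parameters are described in detail after it. The toy encoder only illustrates the stated contract (a dictionary of inputs in, a list-like object with the encodings as its first element out); the modality names, the loss import path, and the retrieval task configuration are assumptions rather than fixed mmlearn choices.

import functools
from typing import Any

import torch
from torch import nn

from mmlearn.modules.losses import ContrastiveLoss  # import path assumed
from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining, EvaluationSpec
from mmlearn.tasks.zero_shot_retrieval import (  # RetrievalTaskSpec location assumed
    RetrievalTaskSpec,
    ZeroShotCrossModalRetrieval,
)


class ToyEncoder(nn.Module):
    """Toy encoder honouring the contract: dict of inputs in, list-like out."""

    def __init__(self, input_key: str, dim: int = 32) -> None:
        super().__init__()
        self.input_key = input_key
        self.proj = nn.LazyLinear(dim)

    def forward(self, inputs: dict[str, Any]) -> list[torch.Tensor]:
        return [self.proj(inputs[self.input_key].float())]


task = ContrastivePretraining(
    encoders={"text": ToyEncoder("text"), "rgb": ToyEncoder("rgb")},
    heads={"text": nn.Linear(32, 16), "rgb": nn.Linear(32, 16)},
    loss=ContrastiveLoss(),  # any loss module with the expected call signature
    optimizer=functools.partial(torch.optim.AdamW, lr=1e-4),
    evaluation_tasks={
        "retrieval": EvaluationSpec(
            task=ZeroShotCrossModalRetrieval(
                task_specs=[
                    RetrievalTaskSpec(
                        query_modality="text", target_modality="rgb", top_k=[1, 5]
                    )
                ]
            ),
            run_on_validation=True,
            run_on_test=True,
        )
    },
)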

Parameters:

Name Type Description Default
encoders dict[str, Module]

A dictionary of encoders. The keys can be any string, including the names of any supported modalities. If the keys are not supported modalities, the modality_module_mapping parameter must be provided to map the encoders to specific modalities. The encoders are expected to take a dictionary of input values and return a list-like object with the first element being the encoded values. This first element is passed on to the heads or postprocessors and the remaining elements are ignored.

required
heads Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the encoder outputs, usually projection heads. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a torch.nn.Sequential module. All head modules are expected to take a single input tensor and return a single output tensor.

None
postprocessors Optional[dict[str, Union[Module, dict[str, Module]]]]

A dictionary of modules that process the head outputs. If the keys do not correspond to the name of a supported modality, the modality_module_mapping parameter must be provided. If any of the values are dictionaries, they will be wrapped in a nn.Sequential module. All postprocessor modules are expected to take a single input tensor and return a single output tensor.

None
modality_module_mapping Optional[dict[str, ModuleKeySpec]]

A dictionary mapping modalities to encoders, heads, and postprocessors. Useful for reusing the same instance of a module across multiple modalities.

None
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a functools.partial function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
init_logit_scale float

The initial value of the logit scale, i.e. the multiplicative factor applied to the logits before computing the contrastive loss. Its logarithm is what the task stores and, optionally, learns.

1 / 0.07
max_logit_scale float

The maximum value of the logit scale parameter. The logit scale parameter is clamped to the range [0, log(max_logit_scale)].

100
learnable_logit_scale bool

Whether the logit scale parameter is learnable. If set to False, the logit scale parameter is treated as a constant.

True
loss Optional[Module]

The loss function to use.

None
modality_loss_pairs Optional[list[LossPairSpec]]

A list of pairs of modalities to compute the contrastive loss between and the weight to apply to each pair.

None
auxiliary_tasks dict[str, AuxiliaryTaskSpec]

Auxiliary tasks to run alongside the main contrastive pretraining task.

  • The auxiliary task module is expected to be a partially-initialized instance of a lightning.pytorch.core.LightningModule created using functools.partial, such that an initialized encoder can be passed as the only argument.
  • The modality parameter specifies the modality of the encoder to use for the auxiliary task. The loss_weight parameter specifies the weight to apply to the auxiliary task loss.
None
log_auxiliary_tasks_loss bool

Whether to log the loss of auxiliary tasks to the main logger.

False
compute_validation_loss bool

Whether to compute the validation loss if a validation dataloader is provided. The loss function must be provided to compute the validation loss.

True
compute_test_loss bool

Whether to compute the test loss if a test dataloader is provided. The loss function must be provided to compute the test loss.

True
evaluation_tasks Optional[dict[str, EvaluationSpec]]

Evaluation tasks to run during validation (which happens while training) and during testing.

None

Raises:

Type Description
ValueError
  • If the loss function is not provided and either the validation or test loss needs to be computed.
  • If the given modality is not supported.
  • If the encoder, head, or postprocessor is not mapped to a modality.
  • If an unsupported modality is found in the loss pair specification.
  • If an unsupported modality is found in the auxiliary tasks.
  • If the auxiliary task is not a partial function.
  • If the evaluation task is not an instance of mmlearn.tasks.hooks.EvaluationHooks.
Source code in mmlearn/tasks/contrastive_pretraining.py
@store(group="task", provider="mmlearn")
class ContrastivePretraining(TrainingTask):
    """Contrastive pretraining task.

    This class supports contrastive pretraining with ``N`` modalities of data. It
    allows the sharing of encoders, heads, and postprocessors across modalities.
    It also supports computing the contrastive loss between specified pairs of
    modalities, as well as training auxiliary tasks alongside the main contrastive
    pretraining task.

    Parameters
    ----------
    encoders : dict[str, torch.nn.Module]
        A dictionary of encoders. The keys can be any string, including the names of
        any supported modalities. If the keys are not supported modalities, the
        ``modality_module_mapping`` parameter must be provided to map the encoders to
        specific modalities. The encoders are expected to take a dictionary of input
        values and return a list-like object with the first element being the encoded
        values. This first element is passed on to the heads or postprocessors and
        the remaining elements are ignored.
    heads : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the encoder outputs, usually projection
        heads. If the keys do not correspond to the name of a supported modality,
        the ``modality_module_mapping`` parameter must be provided. If any of the values
        are dictionaries, they will be wrapped in a :py:class:`torch.nn.Sequential`
        module. All head modules are expected to take a single input tensor and
        return a single output tensor.
    postprocessors : Optional[dict[str, Union[torch.nn.Module, dict[str, torch.nn.Module]]]], optional, default=None
        A dictionary of modules that process the head outputs. If the keys do not
        correspond to the name of a supported modality, the `modality_module_mapping`
        parameter must be provided. If any of the values are dictionaries, they will
        be wrapped in a `nn.Sequential` module. All postprocessor modules are expected
        to take a single input tensor and return a single output tensor.
    modality_module_mapping : Optional[dict[str, ModuleKeySpec]], optional, default=None
        A dictionary mapping modalities to encoders, heads, and postprocessors.
        Useful for reusing the same instance of a module across multiple modalities.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    init_logit_scale : float, optional, default=1 / 0.07
        The initial value of the logit scale, i.e. the multiplicative factor applied
        to the logits before computing the contrastive loss. Its logarithm is what
        is stored and, optionally, learned.
    max_logit_scale : float, optional, default=100
        The maximum value of the logit scale parameter. The logit scale parameter
        is clamped to the range ``[0, log(max_logit_scale)]``.
    learnable_logit_scale : bool, optional, default=True
        Whether the logit scale parameter is learnable. If set to False, the logit
        scale parameter is treated as a constant.
    loss : Optional[torch.nn.Module], optional, default=None
        The loss function to use.
    modality_loss_pairs : Optional[list[LossPairSpec]], optional, default=None
        A list of pairs of modalities to compute the contrastive loss between and
        the weight to apply to each pair.
    auxiliary_tasks : dict[str, AuxiliaryTaskSpec], optional, default=None
        Auxiliary tasks to run alongside the main contrastive pretraining task.

        - The auxiliary task module is expected to be a partially-initialized instance
          of a :py:class:`~lightning.pytorch.core.LightningModule` created using
          :py:func:`functools.partial`, such that an initialized encoder can be
          passed as the only argument.
        - The ``modality`` parameter specifies the modality of the encoder to use
          for the auxiliary task. The ``loss_weight`` parameter specifies the weight
          to apply to the auxiliary task loss.
    log_auxiliary_tasks_loss : bool, optional, default=False
        Whether to log the loss of auxiliary tasks to the main logger.
    compute_validation_loss : bool, optional, default=True
        Whether to compute the validation loss if a validation dataloader is provided.
        The loss function must be provided to compute the validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute the test loss if a test dataloader is provided. The loss
        function must be provided to compute the test loss.
    evaluation_tasks : Optional[dict[str, EvaluationSpec]], optional, default=None
        Evaluation tasks to run during validation (which happens while training) and
        during testing.

    Raises
    ------
    ValueError

        - If the loss function is not provided and either the validation or test loss
          needs to be computed.
        - If the given modality is not supported.
        - If the encoder, head, or postprocessor is not mapped to a modality.
        - If an unsupported modality is found in the loss pair specification.
        - If an unsupported modality is found in the auxiliary tasks.
        - If the auxiliary task is not a partial function.
        - If the evaluation task is not an instance of :py:class:`~mmlearn.tasks.hooks.EvaluationHooks`.

    """  # noqa: W505

    def __init__(  # noqa: PLR0912, PLR0915
        self,
        encoders: dict[str, nn.Module],
        heads: Optional[dict[str, Union[nn.Module, dict[str, nn.Module]]]] = None,
        postprocessors: Optional[
            dict[str, Union[nn.Module, dict[str, nn.Module]]]
        ] = None,
        modality_module_mapping: Optional[dict[str, ModuleKeySpec]] = None,
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable_logit_scale: bool = True,
        loss: Optional[nn.Module] = None,
        modality_loss_pairs: Optional[list[LossPairSpec]] = None,
        auxiliary_tasks: Optional[dict[str, AuxiliaryTaskSpec]] = None,
        log_auxiliary_tasks_loss: bool = False,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
        evaluation_tasks: Optional[dict[str, EvaluationSpec]] = None,
    ) -> None:
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )

        self.save_hyperparameters(
            ignore=[
                "encoders",
                "heads",
                "postprocessors",
                "modality_module_mapping",
                "loss",
                "auxiliary_tasks",
                "evaluation_tasks",
                "modality_loss_pairs",
            ]
        )

        if modality_module_mapping is None:
            # assume all the module dictionaries use the same keys corresponding
            # to modalities
            modality_module_mapping = {}
            for key in encoders:
                modality_module_mapping[key] = ModuleKeySpec(
                    encoder_key=key,
                    head_key=key,
                    postprocessor_key=key,
                )

        # match modalities to encoders, heads, and postprocessors
        modality_encoder_mapping: dict[str, Optional[str]] = {}
        modality_head_mapping: dict[str, Optional[str]] = {}
        modality_postprocessor_mapping: dict[str, Optional[str]] = {}
        for modality_key, module_mapping in modality_module_mapping.items():
            if not Modalities.has_modality(modality_key):
                raise ValueError(_unsupported_modality_error.format(modality_key))
            modality_encoder_mapping[modality_key] = module_mapping.encoder_key
            modality_head_mapping[modality_key] = module_mapping.head_key
            modality_postprocessor_mapping[modality_key] = (
                module_mapping.postprocessor_key
            )

        # ensure all modules are mapped to a modality
        for key in encoders:
            if key not in modality_encoder_mapping.values():
                if not Modalities.has_modality(key):
                    raise ValueError(_unsupported_modality_error.format(key))
                modality_encoder_mapping[key] = key

        if heads is not None:
            for key in heads:
                if key not in modality_head_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_head_mapping[key] = key

        if postprocessors is not None:
            for key in postprocessors:
                if key not in modality_postprocessor_mapping.values():
                    if not Modalities.has_modality(key):
                        raise ValueError(_unsupported_modality_error.format(key))
                    modality_postprocessor_mapping[key] = key

        self._available_modalities: list[Modality] = [
            Modalities.get_modality(modality_key)
            for modality_key in modality_encoder_mapping
        ]
        assert len(self._available_modalities) >= 2, (
            "Expected at least two modalities to be available. "
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the encoder modules.
        self.encoders = nn.ModuleDict(
            {
                Modalities.get_modality(modality_key).name: encoders[encoder_key]
                for modality_key, encoder_key in modality_encoder_mapping.items()
                if encoder_key is not None
            }
        )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the projection head modules. This can be
        #: ``None`` if no heads modules are provided.
        self.heads = None
        if heads is not None:
            self.heads = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: heads[head_key]
                    if isinstance(heads[head_key], nn.Module)
                    else nn.Sequential(*heads[head_key].values())
                    for modality_key, head_key in modality_head_mapping.items()
                    if head_key is not None and head_key in heads
                }
            )

        #: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
        #: modalities and the values are the postprocessor modules. This can be
        #: ``None`` if no postprocessor modules are provided.
        self.postprocessors = None
        if postprocessors is not None:
            self.postprocessors = nn.ModuleDict(
                {
                    Modalities.get_modality(modality_key).name: postprocessors[
                        postprocessor_key
                    ]
                    if isinstance(postprocessors[postprocessor_key], nn.Module)
                    else nn.Sequential(*postprocessors[postprocessor_key].values())
                    for modality_key, postprocessor_key in modality_postprocessor_mapping.items()
                    if postprocessor_key is not None
                    and postprocessor_key in postprocessors
                }
            )

        # set up logit scaling
        log_logit_scale = torch.ones([]) * np.log(init_logit_scale)
        self.max_logit_scale = max_logit_scale
        self.learnable_logit_scale = learnable_logit_scale

        if self.learnable_logit_scale:
            self.log_logit_scale = torch.nn.Parameter(
                log_logit_scale, requires_grad=True
            )
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

        # set up contrastive loss pairs
        if modality_loss_pairs is None:
            modality_loss_pairs = [
                LossPairSpec(modalities=(m1.name, m2.name))
                for m1, m2 in itertools.combinations(self._available_modalities, 2)
            ]

        for modality_pair in modality_loss_pairs:
            if not all(
                Modalities.get_modality(modality) in self._available_modalities
                for modality in modality_pair.modalities
            ):
                raise ValueError(
                    "Found unspecified modality in the loss pair specification "
                    f"{modality_pair.modalities}. Available modalities are "
                    f"{self._available_modalities}."
                )

        #: A list :py:class:`LossPairSpec` instances specifying the pairs of
        #: modalities to compute the contrastive loss between and the weight to
        #: apply to each pair.
        self.modality_loss_pairs = modality_loss_pairs

        # set up auxiliary tasks
        self.aux_task_specs = auxiliary_tasks or {}
        self.auxiliary_tasks: nn.ModuleDict[str, L.LightningModule] = nn.ModuleDict()
        for task_name, task_spec in self.aux_task_specs.items():
            if not Modalities.has_modality(task_spec.modality):
                raise ValueError(
                    f"Found unsupported modality `{task_spec.modality}` in the auxiliary tasks. "
                    f"Available modalities are {self._available_modalities}."
                )
            if not isinstance(task_spec.task, partial):
                raise TypeError(
                    f"Expected auxiliary task to be a partial function, but got {type(task_spec.task)}."
                )

            self.auxiliary_tasks[task_name] = task_spec.task(
                self.encoders[Modalities.get_modality(task_spec.modality).name]
            )

        self.log_auxiliary_tasks_loss = log_auxiliary_tasks_loss

        if evaluation_tasks is not None:
            for eval_task_spec in evaluation_tasks.values():
                if not isinstance(eval_task_spec.task, EvaluationHooks):
                    raise TypeError(
                        f"Expected {eval_task_spec.task} to be an instance of `EvaluationHooks` "
                        f"but got {type(eval_task_spec.task)}."
                    )

        #: A dictionary of evaluation tasks to run during validation, while training,
        #: or during testing.
        self.evaluation_tasks = evaluation_tasks

    def configure_model(self) -> None:
        """Configure the model."""
        if self.auxiliary_tasks:
            for task_name in self.auxiliary_tasks:
                self.auxiliary_tasks[task_name].configure_model()

    def encode(
        self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
    ) -> torch.Tensor:
        """Encode the input values for the given modality.

        Parameters
        ----------
        inputs : dict[str, Any]
            Input values.
        modality : Modality
            The modality to encode.
        normalize : bool, optional, default=False
            Whether to apply L2 normalization to the output (after the head and
            postprocessor layers, if present).

        Returns
        -------
        torch.Tensor
            The encoded values for the specified modality.
        """
        output = self.encoders[modality.name](inputs)[0]

        if self.postprocessors and modality.name in self.postprocessors:
            output = self.postprocessors[modality.name](output)

        if self.heads and modality.name in self.heads:
            output = self.heads[modality.name](output)

        if normalize:
            output = torch.nn.functional.normalize(output, p=2, dim=-1)

        return output

    def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : dict[str, Any]
            The input tensors to encode.

        Returns
        -------
        dict[str, torch.Tensor]
            The encodings for each modality.
        """
        outputs = {
            modality.embedding: self.encode(inputs, modality, normalize=True)
            for modality in self._available_modalities
            if modality.name in inputs
        }

        if not all(
            output.size(-1) == list(outputs.values())[0].size(-1)
            for output in outputs.values()
        ):
            raise ValueError("Expected all model outputs to have the same dimension.")

        return outputs

    def on_train_epoch_start(self) -> None:
        """Prepare for the training epoch.

        This method sets the modules to training mode.
        """
        self.encoders.train()
        if self.heads:
            self.heads.train()
        if self.postprocessors:
            self.postprocessors.train()

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Compute the loss for the batch.

        Parameters
        ----------
        batch : dict[str, Any]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        torch.Tensor
            The loss for the batch.
        """
        outputs = self(batch)

        with torch.no_grad():
            self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

        loss = self._compute_loss(batch, batch_idx, outputs)

        if loss is None:
            raise ValueError("The loss function must be provided for training.")

        self.log("train/loss", loss, prog_bar=True, sync_dist=True)
        self.log(
            "train/logit_scale",
            self.log_logit_scale.exp(),
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )

        return loss

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Zero out the gradients of the model."""
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_before_zero_grad(optimizer)

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch.

        This method sets the modules to evaluation mode and calls the
        ``on_evaluation_epoch_start`` method of each evaluation task.
        """
        self._on_eval_epoch_start("val")

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "val")

    def on_validation_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def test_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        return self._shared_eval_step(batch, batch_idx, "test")

    def on_test_epoch_end(self) -> None:
        """Compute and log epoch-level metrics at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the model checkpoint after loading.

        The `on_load_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint after loading.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The loaded checkpoint.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_load_checkpoint(checkpoint)

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Modify the checkpoint before saving.

        The `on_save_checkpoint` method of auxiliary tasks is called here to allow
        them to modify the checkpoint before saving.

        Parameters
        ----------
        checkpoint : Dict[str, Any]
            The checkpoint to save.
        """
        if self.auxiliary_tasks:
            for task in self.auxiliary_tasks.values():
                task.on_save_checkpoint(checkpoint)

    def _compute_loss(
        self, batch: dict[str, Any], batch_idx: int, outputs: dict[str, torch.Tensor]
    ) -> Optional[torch.Tensor]:
        if self.loss_fn is None:
            return None

        contrastive_loss = self.loss_fn(
            outputs,
            batch["example_ids"],
            self.log_logit_scale.exp(),
            self.modality_loss_pairs,
        )

        auxiliary_losses: list[torch.Tensor] = []
        if self.auxiliary_tasks:
            for task_name, task_spec in self.aux_task_specs.items():
                auxiliary_task_output = self.auxiliary_tasks[task_name].training_step(
                    batch, batch_idx
                )
                if isinstance(auxiliary_task_output, torch.Tensor):
                    auxiliary_task_loss = auxiliary_task_output
                elif isinstance(auxiliary_task_output, Mapping):
                    auxiliary_task_loss = auxiliary_task_output["loss"]
                else:
                    raise ValueError(
                        "Expected auxiliary task output to be a tensor or a mapping "
                        f"containing a 'loss' key, but got {type(auxiliary_task_output)}."
                    )

                auxiliary_task_loss *= task_spec.loss_weight
                auxiliary_losses.append(auxiliary_task_loss)
                if self.log_auxiliary_tasks_loss:
                    self.log(
                        f"train/{task_name}_loss", auxiliary_task_loss, sync_dist=True
                    )

        if not auxiliary_losses:
            return contrastive_loss

        return torch.stack(auxiliary_losses).sum() + contrastive_loss

    def _on_eval_epoch_start(self, eval_type: Literal["val", "test"]) -> None:
        """Prepare for the evaluation epoch."""
        self.encoders.eval()
        if self.heads:
            self.heads.eval()
        if self.postprocessors:
            self.postprocessors.eval()
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_start(self)

    def _shared_eval_step(
        self,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
        eval_type: Literal["val", "test"],
    ) -> Optional[torch.Tensor]:
        """Run a single evaluation step.

        Parameters
        ----------
        batch : dict[str, torch.Tensor]
            The batch of data to process.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            The loss for the batch or ``None`` if the loss function is not provided.
        """
        loss: Optional[torch.Tensor] = None
        if (eval_type == "val" and self.compute_validation_loss) or (
            eval_type == "test" and self.compute_test_loss
        ):
            outputs = self(batch)
            loss = self._compute_loss(batch, batch_idx, outputs)
            if loss is not None and not self.trainer.sanity_checking:
                self.log(f"{eval_type}/loss", loss, prog_bar=True, sync_dist=True)

        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.evaluation_step(self, batch, batch_idx)

        return loss

    def _on_eval_epoch_end(self, eval_type: Literal["val", "test"]) -> None:
        """Compute and log epoch-level metrics at the end of the evaluation epoch."""
        if self.evaluation_tasks:
            for task_spec in self.evaluation_tasks.values():
                if (eval_type == "val" and task_spec.run_on_validation) or (
                    eval_type == "test" and task_spec.run_on_test
                ):
                    task_spec.task.on_evaluation_epoch_end(self)
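
The total training loss assembled by _compute_loss above is the contrastive loss plus the weighted sum of any auxiliary-task losses. A minimal numeric sketch of that aggregation (the values and the single 0.5 weight are purely illustrative):

import torch

contrastive_loss = torch.tensor(2.0)              # value returned by self.loss_fn(...)
aux_losses = [0.5 * torch.tensor(1.2)]            # each auxiliary loss scaled by its loss_weight
total_loss = torch.stack(aux_losses).sum() + contrastive_loss
print(total_loss)                                 # tensor(2.6000)
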
configure_model
configure_model()

Configure the model.

Source code in mmlearn/tasks/contrastive_pretraining.py
def configure_model(self) -> None:
    """Configure the model."""
    if self.auxiliary_tasks:
        for task_name in self.auxiliary_tasks:
            self.auxiliary_tasks[task_name].configure_model()
encode
encode(inputs, modality, normalize=False)

Encode the input values for the given modality.

Parameters:

Name Type Description Default
inputs dict[str, Any]

Input values.

required
modality Modality

The modality to encode.

required
normalize bool

Whether to apply L2 normalization to the output (after the head and postprocessor layers, if present).

False

Returns:

Type Description
Tensor

The encoded values for the specified modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def encode(
    self, inputs: dict[str, Any], modality: Modality, normalize: bool = False
) -> torch.Tensor:
    """Encode the input values for the given modality.

    Parameters
    ----------
    inputs : dict[str, Any]
        Input values.
    modality : Modality
        The modality to encode.
    normalize : bool, optional, default=False
        Whether to apply L2 normalization to the output (after the head and
        postprocessor layers, if present).

    Returns
    -------
    torch.Tensor
        The encoded values for the specified modality.
    """
    output = self.encoders[modality.name](inputs)[0]

    if self.postprocessors and modality.name in self.postprocessors:
        output = self.postprocessors[modality.name](output)

    if self.heads and modality.name in self.heads:
        output = self.heads[modality.name](output)

    if normalize:
        output = torch.nn.functional.normalize(output, p=2, dim=-1)

    return output
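
A hedged usage sketch for encode, assuming task is an already-configured ContrastivePretraining instance with a text encoder and that Modalities is importable from mmlearn.datasets.core (both are assumptions, not shown in this reference):

import torch

from mmlearn.datasets.core import Modalities  # import path assumed

# `task` stands in for a pre-built ContrastivePretraining instance (construction not shown).
text_batch = {Modalities.TEXT.name: torch.randint(0, 1000, (4, 77))}  # illustrative tokenized text
text_embeddings = task.encode(text_batch, Modalities.TEXT, normalize=True)
print(text_embeddings.shape)  # (4, embed_dim), with unit-norm rows when normalize=True
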
forward
forward(inputs)

Run the forward pass.

Parameters:

Name Type Description Default
inputs dict[str, Any]

The input tensors to encode.

required

Returns:

Type Description
dict[str, Tensor]

The encodings for each modality.

Source code in mmlearn/tasks/contrastive_pretraining.py
def forward(self, inputs: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Run the forward pass.

    Parameters
    ----------
    inputs : dict[str, Any]
        The input tensors to encode.

    Returns
    -------
    dict[str, torch.Tensor]
        The encodings for each modality.
    """
    outputs = {
        modality.embedding: self.encode(inputs, modality, normalize=True)
        for modality in self._available_modalities
        if modality.name in inputs
    }

    if not all(
        output.size(-1) == list(outputs.values())[0].size(-1)
        for output in outputs.values()
    ):
        raise ValueError("Expected all model outputs to have the same dimension.")

    return outputs
on_train_epoch_start
on_train_epoch_start()

Prepare for the training epoch.

This method sets the modules to training mode.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_train_epoch_start(self) -> None:
    """Prepare for the training epoch.

    This method sets the modules to training mode.
    """
    self.encoders.train()
    if self.heads:
        self.heads.train()
    if self.postprocessors:
        self.postprocessors.train()
training_step
training_step(batch, batch_idx)

Compute the loss for the batch.

Parameters:

Name Type Description Default
batch dict[str, Any]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Tensor

The loss for the batch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Compute the loss for the batch.

    Parameters
    ----------
    batch : dict[str, Any]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    torch.Tensor
        The loss for the batch.
    """
    outputs = self(batch)

    with torch.no_grad():
        self.log_logit_scale.clamp_(0, math.log(self.max_logit_scale))

    loss = self._compute_loss(batch, batch_idx, outputs)

    if loss is None:
        raise ValueError("The loss function must be provided for training.")

    self.log("train/loss", loss, prog_bar=True, sync_dist=True)
    self.log(
        "train/logit_scale",
        self.log_logit_scale.exp(),
        prog_bar=True,
        on_step=True,
        on_epoch=False,
    )

    return loss
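
One detail worth noting in the step above: clamping log_logit_scale to [0, log(max_logit_scale)] keeps the effective temperature exp(log_logit_scale) within [1, max_logit_scale]. A small self-contained illustration (the cap of 100.0 is only an example value):

import math

import torch

max_logit_scale = 100.0
log_logit_scale = torch.tensor(5.3)                  # exp(5.3) is roughly 200, above the cap
log_logit_scale.clamp_(0, math.log(max_logit_scale))
print(log_logit_scale.exp())                         # roughly 100.0, the clamped temperature
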
on_before_zero_grad
on_before_zero_grad(optimizer)

Call the on_before_zero_grad hook of any auxiliary tasks before the gradients are zeroed.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Zero out the gradients of the model."""
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_before_zero_grad(optimizer)
on_validation_epoch_start
on_validation_epoch_start()

Prepare for the validation epoch.

This method sets the modules to evaluation mode and calls the on_evaluation_epoch_start method of each evaluation task.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch.

    This method sets the modules to evaluation mode and calls the
    ``on_evaluation_epoch_start`` method of each evaluation task.
    """
    self._on_eval_epoch_start("val")
validation_step
validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def validation_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "val")
on_validation_epoch_end
on_validation_epoch_end()

Compute and log epoch-level metrics at the end of the validation epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_validation_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the validation epoch."""
    self._on_eval_epoch_end("val")
on_test_epoch_start
on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")
test_step
test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Tensor]

The batch of data to process.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Tensor]

The loss for the batch or None if the loss function is not provided.

Source code in mmlearn/tasks/contrastive_pretraining.py
def test_step(
    self, batch: dict[str, torch.Tensor], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, torch.Tensor]
        The batch of data to process.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        The loss for the batch or ``None`` if the loss function is not provided.
    """
    return self._shared_eval_step(batch, batch_idx, "test")
on_test_epoch_end
on_test_epoch_end()

Compute and log epoch-level metrics at the end of the test epoch.

Source code in mmlearn/tasks/contrastive_pretraining.py
def on_test_epoch_end(self) -> None:
    """Compute and log epoch-level metrics at the end of the test epoch."""
    self._on_eval_epoch_end("test")
on_load_checkpoint
on_load_checkpoint(checkpoint)

Modify the model checkpoint after loading.

The on_load_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint after loading.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The loaded checkpoint.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the model checkpoint after loading.

    The `on_load_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint after loading.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The loaded checkpoint.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_load_checkpoint(checkpoint)
on_save_checkpoint
on_save_checkpoint(checkpoint)

Modify the checkpoint before saving.

The on_save_checkpoint method of auxiliary tasks is called here to allow them to modify the checkpoint before saving.

Parameters:

Name Type Description Default
checkpoint Dict[str, Any]

The checkpoint to save.

required
Source code in mmlearn/tasks/contrastive_pretraining.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Modify the checkpoint before saving.

    The `on_save_checkpoint` method of auxiliary tasks is called here to allow
    them to modify the checkpoint before saving.

    Parameters
    ----------
    checkpoint : Dict[str, Any]
        The checkpoint to save.
    """
    if self.auxiliary_tasks:
        for task in self.auxiliary_tasks.values():
            task.on_save_checkpoint(checkpoint)

hooks

Task-related hooks for Lightning modules.

EvaluationHooks

Hooks for evaluation.

Source code in mmlearn/tasks/hooks.py
class EvaluationHooks:
    """Hooks for evaluation."""

    def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
        """Prepare the evaluation loop.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        """

    def evaluation_step(
        self, pl_module: pl.LightningModule, batch: Any, batch_idx: int
    ) -> Optional[Mapping[str, Any]]:
        """Run a single evaluation step.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : Any
            A batch of data.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        Optional[Mapping[str, Any]]
            A dictionary of evaluation results for the batch or ``None`` if no
            batch results are available.

        """
        rank_zero_warn(
            f"`evaluation_step` must be implemented to use {self.__class__.__name__} for evaluation."
        )
        return None

    def on_evaluation_epoch_end(
        self, pl_module: pl.LightningModule
    ) -> Optional[Union[Mapping[str, Any]]]:
        """Run after the evaluation epoch.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        Optional[Union[Mapping[str, Any]]]
            A dictionary of evaluation results or ``None`` if no results are available.
        """
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Prepare the evaluation loop.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
Source code in mmlearn/tasks/hooks.py
def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
    """Prepare the evaluation loop.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    """
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Run a single evaluation step.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch Any

A batch of data.

required
batch_idx int

The index of the batch.

required

Returns:

Type Description
Optional[Mapping[str, Any]]

A dictionary of evaluation results for the batch or None if no batch results are available.

Source code in mmlearn/tasks/hooks.py
def evaluation_step(
    self, pl_module: pl.LightningModule, batch: Any, batch_idx: int
) -> Optional[Mapping[str, Any]]:
    """Run a single evaluation step.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : Any
        A batch of data.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    Optional[Mapping[str, Any]]
        A dictionary of evaluation results for the batch or ``None`` if no
        batch results are available.

    """
    rank_zero_warn(
        f"`evaluation_step` must be implemented to use {self.__class__.__name__} for evaluation."
    )
    return None
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Run after the evaluation epoch.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
Optional[Union[Mapping[str, Any]]]

A dictionary of evaluation results or None if no results are available.

Source code in mmlearn/tasks/hooks.py
def on_evaluation_epoch_end(
    self, pl_module: pl.LightningModule
) -> Optional[Union[Mapping[str, Any]]]:
    """Run after the evaluation epoch.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    Optional[Union[Mapping[str, Any]]]
        A dictionary of evaluation results or ``None`` if no results are available.
    """

ijepa

IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task.

IJEPA

Bases: TrainingTask

Pretraining module for IJEPA.

This class implements the IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task using PyTorch Lightning. It trains an encoder and a predictor to reconstruct masked regions of an image based on its unmasked context.

Parameters:

Name Type Description Default
encoder VisionTransformer

Vision transformer encoder.

required
predictor VisionTransformerPredictor

Vision transformer predictor.

required
optimizer Optional[partial[Optimizer]]

The optimizer to use for training. This is expected to be a functools.partial function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

None
lr_scheduler Optional[Union[dict[str, Union[partial[LRScheduler], Any]], partial[LRScheduler]]]

The learning rate scheduler to use for training. This can be a functools.partial function that takes the optimizer as the only required argument or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training.

None
ema_decay float

Initial momentum for EMA of target encoder.

0.996
ema_decay_end float

Final momentum for EMA of target encoder.

1.0
ema_anneal_end_step int

Number of steps to anneal EMA momentum to ema_decay_end.

1000
loss_fn Optional[Callable[[Tensor, Tensor], Tensor]]

Loss function to use. If not provided, defaults to torch.nn.functional.smooth_l1_loss.

None
compute_validation_loss bool

Whether to compute validation loss.

True
compute_test_loss bool

Whether to compute test loss.

True
Source code in mmlearn/tasks/ijepa.py
@store(group="task", provider="mmlearn", zen_partial=False)
class IJEPA(TrainingTask):
    """Pretraining module for IJEPA.

    This class implements the IJEPA (Image Joint-Embedding Predictive Architecture)
    pretraining task using PyTorch Lightning. It trains an encoder and a predictor to
    reconstruct masked regions of an image based on its unmasked context.

    Parameters
    ----------
    encoder : VisionTransformer
        Vision transformer encoder.
    predictor : VisionTransformerPredictor
        Vision transformer predictor.
    optimizer : Optional[partial[torch.optim.Optimizer]], optional, default=None
        The optimizer to use for training. This is expected to be a :py:func:`~functools.partial`
        function that takes the model parameters as the only required argument.
        If not provided, training will continue without an optimizer.
    lr_scheduler : Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None
        The learning rate scheduler to use for training. This can be a
        :py:func:`~functools.partial` function that takes the optimizer as the only
        required argument or a dictionary with a ``'scheduler'`` key that specifies
        the scheduler and an optional ``'extras'`` key that specifies additional
        arguments to pass to the scheduler. If not provided, the learning rate will
        not be adjusted during training.
    ema_decay : float, optional, default=0.996
        Initial momentum for EMA of target encoder.
    ema_decay_end : float, optional, default=1.0
        Final momentum for EMA of target encoder.
    ema_anneal_end_step : int, optional, default=1000
        Number of steps to anneal EMA momentum to ``ema_decay_end``.
    loss_fn : Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional
        Loss function to use. If not provided, defaults to
        :py:func:`~torch.nn.functional.smooth_l1_loss`.
    compute_validation_loss : bool, optional, default=True
        Whether to compute validation loss.
    compute_test_loss : bool, optional, default=True
        Whether to compute test loss.
    """  # noqa: W505

    def __init__(
        self,
        encoder: VisionTransformer,
        predictor: VisionTransformerPredictor,
        modality: str = "RGB",
        optimizer: Optional[partial[torch.optim.Optimizer]] = None,
        lr_scheduler: Optional[
            Union[
                dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
                partial[torch.optim.lr_scheduler.LRScheduler],
            ]
        ] = None,
        ema_decay: float = 0.996,
        ema_decay_end: float = 1.0,
        ema_anneal_end_step: int = 1000,
        loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        compute_validation_loss: bool = True,
        compute_test_loss: bool = True,
    ):
        super().__init__(
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            loss_fn=loss_fn if loss_fn is not None else F.smooth_l1_loss,
            compute_validation_loss=compute_validation_loss,
            compute_test_loss=compute_test_loss,
        )
        self.modality = Modalities.get_modality(modality)
        self.mask_generator = IJEPAMaskGenerator()

        self.encoder = encoder
        self.predictor = predictor

        self.predictor.num_patches = encoder.patch_embed.num_patches
        self.predictor.embed_dim = encoder.embed_dim
        self.predictor.num_heads = encoder.num_heads

        self.target_encoder = ExponentialMovingAverage(
            self.encoder, ema_decay, ema_decay_end, ema_anneal_end_step
        )

    def configure_model(self) -> None:
        """Configure the model."""
        self.target_encoder.configure_model(self.device)

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Perform exponential moving average update of target encoder.

        This is done right after the ``optimizer.step()``, which comes just before
        ``optimizer.zero_grad()`` to account for gradient accumulation.
        """
        if self.target_encoder is not None:
            self.target_encoder.step(self.encoder)

    def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Perform a single training step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        torch.Tensor
            Loss value.
        """
        return self._shared_step(batch, batch_idx, step_type="train")

    def validation_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single validation step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed.
        """
        return self._shared_step(batch, batch_idx, step_type="val")

    def test_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> Optional[torch.Tensor]:
        """Run a single test step.

        Parameters
        ----------
        batch : dict[str, Any]
            A batch of data.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        Optional[torch.Tensor]
            Loss value or ``None`` if no loss is computed.
        """
        return self._shared_step(batch, batch_idx, step_type="test")

    def on_validation_epoch_start(self) -> None:
        """Prepare for the validation epoch."""
        self._on_eval_epoch_start("val")

    def on_validation_epoch_end(self) -> None:
        """Actions at the end of the validation epoch."""
        self._on_eval_epoch_end("val")

    def on_test_epoch_start(self) -> None:
        """Prepare for the test epoch."""
        self._on_eval_epoch_start("test")

    def on_test_epoch_end(self) -> None:
        """Actions at the end of the test epoch."""
        self._on_eval_epoch_end("test")

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Add relevant EMA state to the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to save the EMA state to.
        """
        if self.target_encoder is not None:
            checkpoint["ema_params"] = {
                "decay": self.target_encoder.decay,
                "num_updates": self.target_encoder.num_updates,
            }

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Restore EMA state from the checkpoint.

        Parameters
        ----------
        checkpoint : dict[str, Any]
            The state dictionary to restore the EMA state from.
        """
        if "ema_params" in checkpoint and self.target_encoder is not None:
            ema_params = checkpoint.pop("ema_params")
            self.target_encoder.decay = ema_params["decay"]
            self.target_encoder.num_updates = ema_params["num_updates"]

            self.target_encoder.restore(self.encoder)

    def _shared_step(
        self, batch: dict[str, Any], batch_idx: int, step_type: str
    ) -> Optional[torch.Tensor]:
        images = batch[self.modality.name]

        # Generate masks
        batch_size = images.size(0)
        mask_info = self.mask_generator(batch_size=batch_size)

        # Extract masks and move to device
        device = images.device
        encoder_masks = [mask.to(device) for mask in mask_info["encoder_masks"]]
        predictor_masks = [mask.to(device) for mask in mask_info["predictor_masks"]]

        # Forward pass through target encoder to get h
        with torch.no_grad():
            h = self.target_encoder.model(batch)[0]
            h = F.layer_norm(h, h.size()[-1:])
            h_masked = apply_masks(h, predictor_masks)
            h_masked = repeat_interleave_batch(
                h_masked, images.size(0), repeat=len(encoder_masks)
            )

        # Forward pass through encoder with encoder_masks
        batch[self.modality.mask] = encoder_masks
        z = self.encoder(batch)[0]

        # Pass z through predictor with encoder_masks and predictor_masks
        z_pred = self.predictor(z, encoder_masks, predictor_masks)

        if step_type == "train":
            self.log("train/ema_decay", self.target_encoder.decay, prog_bar=True)

        if self.loss_fn is not None and (
            step_type == "train"
            or (step_type == "val" and self.compute_validation_loss)
            or (step_type == "test" and self.compute_test_loss)
        ):
            # Compute loss between z_pred and h_masked
            loss = self.loss_fn(z_pred, h_masked)

            # Log loss
            self.log(f"{step_type}/loss", loss, prog_bar=True, sync_dist=True)

            return loss

        return None

    def _on_eval_epoch_start(self, step_type: str) -> None:
        """Initialize states or configurations at the start of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (
            step_type == "val"
            and self.compute_validation_loss
            or step_type == "test"
            and self.compute_test_loss
        ):
            self.log(f"{step_type}/start", 1, prog_bar=True, sync_dist=True)

    def _on_eval_epoch_end(self, step_type: str) -> None:
        """Finalize states or logging at the end of an evaluation epoch.

        Parameters
        ----------
        step_type : str
            Type of the evaluation phase ("val" or "test").
        """
        if (
            step_type == "val"
            and self.compute_validation_loss
            or step_type == "test"
            and self.compute_test_loss
        ):
            self.log(f"{step_type}/end", 1, prog_bar=True, sync_dist=True)
configure_model
configure_model()

Configure the model.

Source code in mmlearn/tasks/ijepa.py
def configure_model(self) -> None:
    """Configure the model."""
    self.target_encoder.configure_model(self.device)
on_before_zero_grad
on_before_zero_grad(optimizer)

Perform exponential moving average update of target encoder.

This is done right after optimizer.step(), which comes just before optimizer.zero_grad(), to account for gradient accumulation.

Source code in mmlearn/tasks/ijepa.py
def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
    """Perform exponential moving average update of target encoder.

    This is done right after the ``optimizer.step()``, which comes just before
    ``optimizer.zero_grad()`` to account for gradient accumulation.
    """
    if self.target_encoder is not None:
        self.target_encoder.step(self.encoder)
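
The EMA momentum used for the target encoder is annealed from ema_decay to ema_decay_end over ema_anneal_end_step steps. The exact schedule lives in ExponentialMovingAverage and is not shown in this reference; the sketch below assumes a simple linear anneal purely for illustration:

def ema_momentum(
    step: int,
    decay: float = 0.996,
    decay_end: float = 1.0,
    anneal_end_step: int = 1000,
) -> float:
    """Hypothetical linear anneal of the EMA momentum (an assumption, not mmlearn's code)."""
    if step >= anneal_end_step:
        return decay_end
    return decay + (decay_end - decay) * step / anneal_end_step


print(ema_momentum(0))     # 0.996
print(ema_momentum(500))   # 0.998
print(ema_momentum(1000))  # 1.0
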
training_step
training_step(batch, batch_idx)

Perform a single training step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Tensor

Loss value.

Source code in mmlearn/tasks/ijepa.py
def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
    """Perform a single training step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    torch.Tensor
        Loss value.
    """
    return self._shared_step(batch, batch_idx, step_type="train")
validation_step
validation_step(batch, batch_idx)

Run a single validation step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def validation_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single validation step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="val")
test_step
test_step(batch, batch_idx)

Run a single test step.

Parameters:

Name Type Description Default
batch dict[str, Any]

A batch of data.

required
batch_idx int

Index of the batch.

required

Returns:

Type Description
Optional[Tensor]

Loss value or None if no loss is computed.

Source code in mmlearn/tasks/ijepa.py
def test_step(
    self, batch: dict[str, Any], batch_idx: int
) -> Optional[torch.Tensor]:
    """Run a single test step.

    Parameters
    ----------
    batch : dict[str, Any]
        A batch of data.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    Optional[torch.Tensor]
        Loss value or ``None`` if no loss is computed.
    """
    return self._shared_step(batch, batch_idx, step_type="test")
on_validation_epoch_start
on_validation_epoch_start()

Prepare for the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_start(self) -> None:
    """Prepare for the validation epoch."""
    self._on_eval_epoch_start("val")
on_validation_epoch_end
on_validation_epoch_end()

Actions at the end of the validation epoch.

Source code in mmlearn/tasks/ijepa.py
def on_validation_epoch_end(self) -> None:
    """Actions at the end of the validation epoch."""
    self._on_eval_epoch_end("val")
on_test_epoch_start
on_test_epoch_start()

Prepare for the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_start(self) -> None:
    """Prepare for the test epoch."""
    self._on_eval_epoch_start("test")
on_test_epoch_end
on_test_epoch_end()

Actions at the end of the test epoch.

Source code in mmlearn/tasks/ijepa.py
def on_test_epoch_end(self) -> None:
    """Actions at the end of the test epoch."""
    self._on_eval_epoch_end("test")
on_save_checkpoint
on_save_checkpoint(checkpoint)

Add relevant EMA state to the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to save the EMA state to.

required
Source code in mmlearn/tasks/ijepa.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Add relevant EMA state to the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to save the EMA state to.
    """
    if self.target_encoder is not None:
        checkpoint["ema_params"] = {
            "decay": self.target_encoder.decay,
            "num_updates": self.target_encoder.num_updates,
        }
on_load_checkpoint
on_load_checkpoint(checkpoint)

Restore EMA state from the checkpoint.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The state dictionary to restore the EMA state from.

required
Source code in mmlearn/tasks/ijepa.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Restore EMA state from the checkpoint.

    Parameters
    ----------
    checkpoint : dict[str, Any]
        The state dictionary to restore the EMA state from.
    """
    if "ema_params" in checkpoint and self.target_encoder is not None:
        ema_params = checkpoint.pop("ema_params")
        self.target_encoder.decay = ema_params["decay"]
        self.target_encoder.num_updates = ema_params["num_updates"]

        self.target_encoder.restore(self.encoder)

zero_shot_classification

Zero-shot classification evaluation task.

ClassificationTaskSpec dataclass

Specification for a classification task.

Source code in mmlearn/tasks/zero_shot_classification.py
@dataclass
class ClassificationTaskSpec:
    """Specification for a classification task."""

    #: The modality of the query input.
    query_modality: str

    #: The top-k values for which to compute the classification metrics like accuracy.
    top_k: list[int]
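
For example, a spec requesting top-1 and top-5 accuracy for an RGB query modality could be built as follows (the modality name is illustrative and must be one registered with Modalities):

from mmlearn.tasks.zero_shot_classification import ClassificationTaskSpec

spec = ClassificationTaskSpec(query_modality="rgb", top_k=[1, 5])
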

ZeroShotClassification

Bases: EvaluationHooks

Zero-shot classification evaluation task.

This task evaluates the zero-shot classification performance.

Parameters:

Name Type Description Default
task_specs list[ClassificationTaskSpec]

A list of classification task specifications.

required
tokenizer Callable[[Union[str, list[str]]], Union[Tensor, dict[str, Tensor]]]

A function to tokenize text inputs.

required
Source code in mmlearn/tasks/zero_shot_classification.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotClassification(EvaluationHooks):
    """Zero-shot classification evaluation task.

    This task evaluates the zero-shot classification performance.

    Parameters
    ----------
    task_specs : list[ClassificationTaskSpec]
        A list of classification task specifications.
    tokenizer : Callable[[Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]]
        A function to tokenize text inputs.
    """  # noqa: W505

    def __init__(
        self,
        task_specs: list[ClassificationTaskSpec],
        tokenizer: Callable[
            [Union[str, list[str]]], Union[torch.Tensor, dict[str, torch.Tensor]]
        ],
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.task_specs = task_specs
        for spec in self.task_specs:
            assert Modalities.has_modality(spec.query_modality)

        self.metrics: dict[tuple[str, int], MetricCollection] = {}
        self._embeddings_store: dict[int, torch.Tensor] = {}

    def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
        """Set up the evaluation task.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Raises
        ------
        ValueError
            - If the task is not being run for validation or testing.
            - If the dataset does not have the required attributes to perform zero-shot
              classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
        """
        if pl_module.trainer.validating:
            eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
        elif pl_module.trainer.testing:
            eval_dataset = pl_module.trainer.test_dataloaders.dataset
        else:
            raise ValueError(
                "ZeroShotClassification task is only supported for validation and testing."
            )

        self.all_dataset_info = {}

        # create metrics for each dataset/query_modality combination
        if not self.metrics:
            for dataset_index, dataset in enumerate(eval_dataset.datasets):
                dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
                try:
                    id2label: dict[int, str] = dataset.id2label
                except AttributeError:
                    raise ValueError(
                        f"Dataset '{dataset_name}' must have a `id2label` attribute "
                        "to perform zero-shot classification."
                    ) from None

                try:
                    zero_shot_prompt_templates: list[str] = (
                        dataset.zero_shot_prompt_templates
                    )
                except AttributeError:
                    raise ValueError(
                        "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                    ) from None

                num_classes = len(id2label)

                self.all_dataset_info[dataset_index] = {
                    "name": dataset_name,
                    "id2label": id2label,
                    "prompt_templates": zero_shot_prompt_templates,
                    "num_classes": num_classes,
                }

                for spec in self.task_specs:
                    query_modality = Modalities.get_modality(spec.query_modality).name
                    self.metrics[(query_modality, dataset_index)] = (
                        self._create_metrics(
                            num_classes,
                            spec.top_k,
                            prefix=f"{dataset_name}/{query_modality}_",
                            postfix="",
                        )
                    )

        for metric in self.metrics.values():
            metric.to(pl_module.device)

        for dataset_index, dataset_info in self.all_dataset_info.items():
            id2label = dataset_info["id2label"]
            prompt_templates: list[str] = dataset_info["prompt_templates"]
            labels = list(id2label.values())

            with torch.no_grad():
                chunk_size = 10
                all_embeddings = []

                for i in tqdm(
                    range(0, len(labels), chunk_size),
                    desc="Encoding class descriptions",
                ):
                    batch_labels = labels[i : min(i + chunk_size, len(labels))]
                    descriptions = [
                        template.format(label)
                        for label in batch_labels
                        for template in prompt_templates
                    ]
                    tokenized_descriptions = move_data_to_device(
                        self.tokenizer(descriptions),
                        pl_module.device,
                    )

                    # Encode the chunk using the pl_module's encode method
                    chunk_embeddings = pl_module.encode(
                        tokenized_descriptions, Modalities.TEXT
                    )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                    chunk_embeddings = chunk_embeddings.reshape(
                        len(batch_labels), len(prompt_templates), -1
                    ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                    chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                    # Append the chunk embeddings to the list
                    all_embeddings.append(chunk_embeddings)

                # Concatenate all chunk embeddings into a single tensor
                class_embeddings = torch.cat(all_embeddings, dim=0)

            self._embeddings_store[dataset_index] = class_embeddings

    def evaluation_step(
        self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
    ) -> None:
        """Compute logits and update metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A batch of data.
        batch_idx : int
            The index of the batch.
        """
        if pl_module.trainer.sanity_checking:
            return

        for (query_modality, dataset_index), metric_collection in self.metrics.items():
            matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

            if not matching_indices.numel():
                continue

            class_embeddings = self._embeddings_store[dataset_index]
            query_embeddings: torch.Tensor = pl_module.encode(
                batch, Modalities.get_modality(query_modality)
            )
            query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
            query_embeddings = query_embeddings[matching_indices]

            if self.all_dataset_info[dataset_index]["num_classes"] == 2:
                softmax_output = _safe_matmul(
                    query_embeddings, class_embeddings
                ).softmax(dim=-1)
                logits = softmax_output[:, 1] - softmax_output[:, 0]
            else:
                logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
            targets = batch[Modalities.get_modality(query_modality).target][
                matching_indices
            ]

            metric_collection.update(logits, targets)

    def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
        """Compute and reset metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        dict[str, Any]
            The computed metrics.
        """
        results = {}
        for metric_collection in self.metrics.values():
            results.update(metric_collection.compute())
            metric_collection.reset()

        self._embeddings_store.clear()

        eval_type = "val" if pl_module.trainer.validating else "test"
        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results

    @staticmethod
    def _create_metrics(
        num_classes: int, top_k: list[int], prefix: str, postfix: str
    ) -> MetricCollection:
        """Create a collection of classification metrics."""
        task_type = "binary" if num_classes == 2 else "multiclass"
        acc_metrics = (
            {
                f"top{k}_accuracy": Accuracy(
                    task=task_type, num_classes=num_classes, top_k=k, average="micro"
                )
                for k in top_k
            }
            if num_classes > 2
            else {"accuracy": Accuracy(task=task_type, num_classes=num_classes)}
        )
        return MetricCollection(
            {
                "precision": Precision(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "recall": Recall(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "f1_score_macro": F1Score(
                    task=task_type,
                    num_classes=num_classes,
                    average="macro" if num_classes > 2 else "micro",
                ),
                "aucroc": AUROC(task=task_type, num_classes=num_classes),
                **acc_metrics,
            },
            prefix=prefix,
            postfix=postfix,
            compute_groups=True,
        )
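
A self-contained illustration of the prompt-ensembling step performed in on_evaluation_epoch_start above: per-template class embeddings are L2-normalized, averaged over the prompt templates, and re-normalized so that each class ends up with a single unit-norm embedding (the tensor sizes are arbitrary):

import torch

num_classes, num_templates, dim = 3, 2, 4
per_template = torch.randn(num_classes * num_templates, dim)
per_template = per_template / per_template.norm(p=2, dim=-1, keepdim=True)

class_embeddings = per_template.reshape(num_classes, num_templates, dim).mean(dim=1)
class_embeddings = class_embeddings / class_embeddings.norm(p=2, dim=-1, keepdim=True)
print(class_embeddings.shape)  # torch.Size([3, 4])
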
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Set up the evaluation task.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Raises:

Type Description
ValueError
  • If the task is not being run for validation or testing.
  • If the dataset does not have the required attributes to perform zero-shot classification (i.e., id2label and zero_shot_prompt_templates).
Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_start(self, pl_module: LightningModule) -> None:
    """Set up the evaluation task.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Raises
    ------
    ValueError
        - If the task is not being run for validation or testing.
        - If the dataset does not have the required attributes to perform zero-shot
          classification (i.e., ``id2label`` and ``zero_shot_prompt_templates``).
    """
    if pl_module.trainer.validating:
        eval_dataset: CombinedDataset = pl_module.trainer.val_dataloaders.dataset
    elif pl_module.trainer.testing:
        eval_dataset = pl_module.trainer.test_dataloaders.dataset
    else:
        raise ValueError(
            "ZeroShotClassification task is only supported for validation and testing."
        )

    self.all_dataset_info = {}

    # create metrics for each dataset/query_modality combination
    if not self.metrics:
        for dataset_index, dataset in enumerate(eval_dataset.datasets):
            dataset_name = getattr(dataset, "name", dataset.__class__.__name__)
            try:
                id2label: dict[int, str] = dataset.id2label
            except AttributeError:
                raise ValueError(
                    f"Dataset '{dataset_name}' must have a `id2label` attribute "
                    "to perform zero-shot classification."
                ) from None

            try:
                zero_shot_prompt_templates: list[str] = (
                    dataset.zero_shot_prompt_templates
                )
            except AttributeError:
                raise ValueError(
                    "Dataset must have a `zero_shot_prompt_templates` attribute to perform zero-shot classification."
                ) from None

            num_classes = len(id2label)

            self.all_dataset_info[dataset_index] = {
                "name": dataset_name,
                "id2label": id2label,
                "prompt_templates": zero_shot_prompt_templates,
                "num_classes": num_classes,
            }

            for spec in self.task_specs:
                query_modality = Modalities.get_modality(spec.query_modality).name
                self.metrics[(query_modality, dataset_index)] = (
                    self._create_metrics(
                        num_classes,
                        spec.top_k,
                        prefix=f"{dataset_name}/{query_modality}_",
                        postfix="",
                    )
                )

    for metric in self.metrics.values():
        metric.to(pl_module.device)

    for dataset_index, dataset_info in self.all_dataset_info.items():
        id2label = dataset_info["id2label"]
        prompt_templates: list[str] = dataset_info["prompt_templates"]
        labels = list(id2label.values())

        with torch.no_grad():
            chunk_size = 10
            all_embeddings = []

            for i in tqdm(
                range(0, len(labels), chunk_size),
                desc="Encoding class descriptions",
            ):
                batch_labels = labels[i : min(i + chunk_size, len(labels))]
                descriptions = [
                    template.format(label)
                    for label in batch_labels
                    for template in prompt_templates
                ]
                tokenized_descriptions = move_data_to_device(
                    self.tokenizer(descriptions),
                    pl_module.device,
                )

                # Encode the chunk using the pl_module's encode method
                chunk_embeddings = pl_module.encode(
                    tokenized_descriptions, Modalities.TEXT
                )  # shape: [chunk_size x len(prompt_templates), embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)
                chunk_embeddings = chunk_embeddings.reshape(
                    len(batch_labels), len(prompt_templates), -1
                ).mean(dim=1)  # shape: [chunk_size, embed_dim]
                chunk_embeddings /= chunk_embeddings.norm(p=2, dim=-1, keepdim=True)

                # Append the chunk embeddings to the list
                all_embeddings.append(chunk_embeddings)

            # Concatenate all chunk embeddings into a single tensor
            class_embeddings = torch.cat(all_embeddings, dim=0)

        self._embeddings_store[dataset_index] = class_embeddings
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Compute logits and update metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A batch of data.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_classification.py
def evaluation_step(
    self, pl_module: LightningModule, batch: dict[str, torch.Tensor], batch_idx: int
) -> None:
    """Compute logits and update metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A batch of data.
    batch_idx : int
        The index of the batch.
    """
    if pl_module.trainer.sanity_checking:
        return

    for (query_modality, dataset_index), metric_collection in self.metrics.items():
        matching_indices = torch.where(batch["dataset_index"] == dataset_index)[0]

        if not matching_indices.numel():
            continue

        class_embeddings = self._embeddings_store[dataset_index]
        query_embeddings: torch.Tensor = pl_module.encode(
            batch, Modalities.get_modality(query_modality)
        )
        query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
        query_embeddings = query_embeddings[matching_indices]

        if self.all_dataset_info[dataset_index]["num_classes"] == 2:
            softmax_output = _safe_matmul(
                query_embeddings, class_embeddings
            ).softmax(dim=-1)
            logits = softmax_output[:, 1] - softmax_output[:, 0]
        else:
            logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
        targets = batch[Modalities.get_modality(query_modality).target][
            matching_indices
        ]

        metric_collection.update(logits, targets)
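
The binary branch in the step above collapses the two class similarities into a single score, p(class 1) - p(class 0), before updating the metrics. A tiny illustration with made-up similarity values:

import torch

similarities = torch.tensor([[0.2, 0.6]])  # one query against the two class embeddings
probs = similarities.softmax(dim=-1)
binary_score = probs[:, 1] - probs[:, 0]
print(binary_score)                        # positive, so the query leans towards class 1
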
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Compute and reset metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
dict[str, Any]

The computed metrics.

Source code in mmlearn/tasks/zero_shot_classification.py
def on_evaluation_epoch_end(self, pl_module: LightningModule) -> dict[str, Any]:
    """Compute and reset metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    dict[str, Any]
        The computed metrics.
    """
    results = {}
    for metric_collection in self.metrics.values():
        results.update(metric_collection.compute())
        metric_collection.reset()

    self._embeddings_store.clear()

    eval_type = "val" if pl_module.trainer.validating else "test"
    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results

zero_shot_retrieval

Zero-shot cross-modal retrieval evaluation task.

RetrievalTaskSpec dataclass

Specification for a retrieval task.

Source code in mmlearn/tasks/zero_shot_retrieval.py
@dataclass
class RetrievalTaskSpec:
    """Specification for a retrieval task."""

    #: The query modality.
    query_modality: str

    #: The target modality.
    target_modality: str

    #: The top-k values for which to compute the retrieval recall metrics.
    top_k: list[int]

ZeroShotCrossModalRetrieval

Bases: EvaluationHooks

Zero-shot cross-modal retrieval evaluation task.

This task evaluates the retrieval performance of a model on a set of query-target pairs. The model is expected to produce embeddings for both the query and target modalities. The task computes the retrieval recall at k for each pair of modalities.

Parameters:

Name Type Description Default
task_specs list[RetrievalTaskSpec]

A list of retrieval task specifications. Each specification defines the query and target modalities, as well as the top-k values for which to compute the retrieval recall metrics.

required
Source code in mmlearn/tasks/zero_shot_retrieval.py
@store(group="eval_task", provider="mmlearn")
class ZeroShotCrossModalRetrieval(EvaluationHooks):
    """Zero-shot cross-modal retrieval evaluation task.

    This task evaluates the retrieval performance of a model on a set of query-target
    pairs. The model is expected to produce embeddings for both the query and target
    modalities. The task computes the retrieval recall at `k` for each pair of
    modalities.

    Parameters
    ----------
    task_specs : list[RetrievalTaskSpec]
        A list of retrieval task specifications. Each specification defines the query
        and target modalities, as well as the top-k values for which to compute the
        retrieval recall metrics.

    """

    def __init__(self, task_specs: list[RetrievalTaskSpec]) -> None:
        super().__init__()

        self.task_specs = task_specs
        self.metrics: dict[tuple[str, str], MetricCollection] = {}
        self._available_modalities = set()

        for spec in self.task_specs:
            query_modality = spec.query_modality
            target_modality = spec.target_modality
            assert Modalities.has_modality(query_modality)
            assert Modalities.has_modality(target_modality)

            self.metrics[(query_modality, target_modality)] = MetricCollection(
                {
                    f"{query_modality}_to_{target_modality}_R@{k}": RetrievalRecallAtK(
                        top_k=k, aggregation="mean", reduction="none"
                    )
                    for k in spec.top_k
                }
            )
            self._available_modalities.add(query_modality)
            self._available_modalities.add(target_modality)

    def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
        """Move the metrics to the device of the Lightning module."""
        for metric in self.metrics.values():
            metric.to(pl_module.device)

    def evaluation_step(
        self,
        pl_module: pl.LightningModule,
        batch: dict[str, torch.Tensor],
        batch_idx: int,
    ) -> None:
        """Run the forward pass and update retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.
        batch : dict[str, torch.Tensor]
            A dictionary of batched input tensors.
        batch_idx : int
            The index of the batch.

        """
        if pl_module.trainer.sanity_checking:
            return

        outputs: dict[str, Any] = {}
        for modality_name in self._available_modalities:
            if modality_name in batch:
                outputs[modality_name] = pl_module.encode(
                    batch, Modalities.get_modality(modality_name), normalize=False
                )
        for (query_modality, target_modality), metric in self.metrics.items():
            if query_modality not in outputs or target_modality not in outputs:
                continue
            query_embeddings: torch.Tensor = outputs[query_modality]
            target_embeddings: torch.Tensor = outputs[target_modality]
            indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

            metric.update(query_embeddings, target_embeddings, indexes)

    def on_evaluation_epoch_end(
        self, pl_module: pl.LightningModule
    ) -> Optional[dict[str, Any]]:
        """Compute the retrieval recall metrics.

        Parameters
        ----------
        pl_module : pl.LightningModule
            A reference to the Lightning module being evaluated.

        Returns
        -------
        Optional[dict[str, Any]]
            A dictionary of evaluation results or `None` if no results are available.
        """
        if pl_module.trainer.sanity_checking:
            return None

        results = {}
        for metric in self.metrics.values():
            results.update(metric.compute())
            metric.reset()

        eval_type = "val" if pl_module.trainer.validating else "test"

        for key, value in results.items():
            pl_module.log(f"{eval_type}/{key}", value)

        return results
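
As a usage sketch (assuming "text" and "rgb" are registered modality names), the task is typically configured with one specification per retrieval direction; the hooks documented below are then invoked by the evaluation loop:

from mmlearn.tasks.zero_shot_retrieval import (
    RetrievalTaskSpec,
    ZeroShotCrossModalRetrieval,
)

# Evaluate retrieval in both directions, reporting recall at 1, 5 and 10.
retrieval_task = ZeroShotCrossModalRetrieval(
    task_specs=[
        RetrievalTaskSpec(query_modality="text", target_modality="rgb", top_k=[1, 5, 10]),
        RetrievalTaskSpec(query_modality="rgb", target_modality="text", top_k=[1, 5, 10]),
    ]
)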
on_evaluation_epoch_start
on_evaluation_epoch_start(pl_module)

Move the metrics to the device of the Lightning module.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
    """Move the metrics to the device of the Lightning module."""
    for metric in self.metrics.values():
        metric.to(pl_module.device)
evaluation_step
evaluation_step(pl_module, batch, batch_idx)

Run the forward pass and update retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required
batch dict[str, Tensor]

A dictionary of batched input tensors.

required
batch_idx int

The index of the batch.

required
Source code in mmlearn/tasks/zero_shot_retrieval.py
def evaluation_step(
    self,
    pl_module: pl.LightningModule,
    batch: dict[str, torch.Tensor],
    batch_idx: int,
) -> None:
    """Run the forward pass and update retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.
    batch : dict[str, torch.Tensor]
        A dictionary of batched input tensors.
    batch_idx : int
        The index of the batch.

    """
    if pl_module.trainer.sanity_checking:
        return

    outputs: dict[str, Any] = {}
    for modality_name in self._available_modalities:
        if modality_name in batch:
            outputs[modality_name] = pl_module.encode(
                batch, Modalities.get_modality(modality_name), normalize=False
            )
    for (query_modality, target_modality), metric in self.metrics.items():
        if query_modality not in outputs or target_modality not in outputs:
            continue
        query_embeddings: torch.Tensor = outputs[query_modality]
        target_embeddings: torch.Tensor = outputs[target_modality]
        indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)

        metric.update(query_embeddings, target_embeddings, indexes)
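
To make the pairing concrete: passing indexes = arange(batch_size) tells the metric that each query's positive target sits at the same row index. A library-independent sketch of recall@k under that assumption (the actual RetrievalRecallAtK metric may differ in normalization and aggregation details):

import torch

def recall_at_k(query_emb: torch.Tensor, target_emb: torch.Tensor, k: int) -> float:
    """Fraction of queries whose matching target (same row index) appears in the top-k."""
    # Cosine similarity between every query and every target.
    query_emb = torch.nn.functional.normalize(query_emb, dim=-1)
    target_emb = torch.nn.functional.normalize(target_emb, dim=-1)
    sims = query_emb @ target_emb.T  # shape: (num_queries, num_targets)
    # Indices of the k most similar targets for each query.
    topk = sims.topk(k, dim=-1).indices
    # A query is a hit if its own row index is among its top-k targets.
    hits = (topk == torch.arange(sims.size(0)).unsqueeze(-1)).any(dim=-1)
    return hits.float().mean().item()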
on_evaluation_epoch_end
on_evaluation_epoch_end(pl_module)

Compute the retrieval recall metrics.

Parameters:

Name Type Description Default
pl_module LightningModule

A reference to the Lightning module being evaluated.

required

Returns:

Type Description
Optional[dict[str, Any]]

A dictionary of evaluation results or None if no results are available.

Source code in mmlearn/tasks/zero_shot_retrieval.py
def on_evaluation_epoch_end(
    self, pl_module: pl.LightningModule
) -> Optional[dict[str, Any]]:
    """Compute the retrieval recall metrics.

    Parameters
    ----------
    pl_module : pl.LightningModule
        A reference to the Lightning module being evaluated.

    Returns
    -------
    Optional[dict[str, Any]]
        A dictionary of evaluation results or `None` if no results are available.
    """
    if pl_module.trainer.sanity_checking:
        return None

    results = {}
    for metric in self.metrics.values():
        results.update(metric.compute())
        metric.reset()

    eval_type = "val" if pl_module.trainer.validating else "test"

    for key, value in results.items():
        pl_module.log(f"{eval_type}/{key}", value)

    return results

Utilities

mmlearn.hf_utils

Utilities for loading components from the HuggingFace transformers library.

load_huggingface_model

load_huggingface_model(
    model_type,
    model_name_or_path,
    load_pretrained_weights=True,
    get_model_attr=None,
    model_config_kwargs=None,
    config_type=None,
)

Load a model from the HuggingFace transformers library.

Parameters:

Name Type Description Default
model_type Type[_BaseAutoModelClass]

The model class to instantiate, e.g. transformers.AutoModel.

required
model_name_or_path str

The model name or path to load the model from.

required
load_pretrained_weights bool

Whether to load the pretrained weights. If False, model_name_or_path is used only to fetch the model configuration and the model is initialized with random weights.

True
get_model_attr Optional[str]

If not None, the attribute of the model to return. For example, if the model is a transformers.AutoModel and get_model_attr='encoder', only the encoder part of the model is returned. If None, the full model is returned.

None
model_config_kwargs Optional[dict[str, Any]]

Additional keyword arguments to pass to the model configuration. Values for keys that are configuration attributes override the loaded configuration values; the handling of key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

None
config_type Optional[Type[PretrainedConfig]]

The class of the configuration to use. If None, transformers.AutoConfig will be used.

None

Returns:

Type Description
Module

The instantiated model.

Source code in mmlearn/hf_utils.py
def load_huggingface_model(
    model_type: Type[_BaseAutoModelClass],
    model_name_or_path: str,
    load_pretrained_weights: bool = True,
    get_model_attr: Optional[str] = None,
    model_config_kwargs: Optional[dict[str, Any]] = None,
    config_type: Optional[Type[PretrainedConfig]] = None,
) -> nn.Module:
    """Load a model from the HuggingFace ``transformers`` library.

    Parameters
    ----------
    model_type : Type[_BaseAutoModelClass]
        The model class to instantiate, e.g. ``transformers.AutoModel``.
    model_name_or_path : str
        The model name or path to load the model from.
    load_pretrained_weights : bool, optional, default=True
        Whether to load the pretrained weights. If ``False``, ``model_name_or_path``
        is used only to fetch the model configuration and the model is initialized
        with random weights.
    get_model_attr : Optional[str], optional, default=None
        If not None, the attribute of the model to return. For example, if the model
        is a ``transformers.AutoModel`` and ``get_model_attr='encoder'``, only the
        encoder part of the model will be returned. If ``None``, the full model
        will be returned.
    model_config_kwargs : Optional[dict[str, Any]], optional, default=None
        Additional keyword arguments to pass to the model configuration.
        Values for keys that are configuration attributes override the loaded
        configuration values; the handling of key/value pairs whose keys are *not*
        configuration attributes is controlled by the ``return_unused_kwargs``
        keyword parameter.
    config_type : Optional[Type[PretrainedConfig]], optional, default=None
        The class of the configuration to use. If None, ``transformers.AutoConfig``
        will be used.

    Returns
    -------
    torch.nn.Module
        The instantiated model.
    """
    model_config_kwargs = model_config_kwargs or {}
    if load_pretrained_weights:
        model = model_type.from_pretrained(model_name_or_path, **model_config_kwargs)
    else:
        if config_type is None:
            config_type = AutoConfig
        config, kwargs = config_type.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path,
            return_unused_kwargs=True,
            **model_config_kwargs,
        )
        model = model_type.from_config(config, **kwargs)

    if get_model_attr is not None and hasattr(model, get_model_attr):
        model = getattr(model, get_model_attr)

    return model
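
A usage sketch (the model name and the text_model attribute are assumptions about a CLIP-style checkpoint, not part of mmlearn itself):

from transformers import AutoModel
from mmlearn.hf_utils import load_huggingface_model

# Load a pretrained CLIP-style model and keep only its text tower.
text_encoder = load_huggingface_model(
    AutoModel,
    "openai/clip-vit-base-patch32",
    load_pretrained_weights=True,
    get_model_attr="text_model",  # assumed attribute name on the underlying CLIPModel
)

# With load_pretrained_weights=False, only the configuration is fetched and the
# model is randomly initialized.
random_init_model = load_huggingface_model(
    AutoModel,
    "openai/clip-vit-base-patch32",
    load_pretrained_weights=False,
)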

mmlearn.constants

Constants.