mmlearn.tasks.ijepa module

IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task.

class IJEPA(encoder, predictor, modality='RGB', optimizer=None, lr_scheduler=None, ema_decay=0.996, ema_decay_end=1.0, ema_anneal_end_step=1000, loss_fn=None, compute_validation_loss=True, compute_test_loss=True)[source]

Bases: TrainingTask

Pretraining module for IJEPA.

This class implements the IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task using PyTorch Lightning. It trains an encoder and a predictor to predict, in representation space, the embeddings of masked regions of an image from its unmasked context; the prediction targets come from a target encoder maintained as an exponential-moving-average (EMA) copy of the trained encoder.

Parameters:
  • encoder (VisionTransformer) – Vision transformer encoder.

  • predictor (VisionTransformerPredictor) – Vision transformer predictor.

  • modality (str, optional, default='RGB') – The modality of the input data. This allows the model to be used with different image modalities, e.g. RGB, Depth.

  • optimizer (Optional[partial[torch.optim.Optimizer]], optional, default=None) – The optimizer to use for training. This is expected to be a partial() function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.

  • lr_scheduler (Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None) – The learning rate scheduler to use for training. This can be a partial() function that takes the optimizer as the only required argument, or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training. The construction example after this list illustrates the dictionary form.

  • ema_decay (float, optional, default=0.996) – Initial momentum for EMA of target encoder.

  • ema_decay_end (float, optional, default=1.0) – Final momentum for EMA of target encoder.

  • ema_anneal_end_step (int, optional, default=1000) – Number of steps to anneal EMA momentum to ema_decay_end.

  • loss_fn (Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional, default=None) – Loss function to use. If not provided, defaults to torch.nn.functional.smooth_l1_loss().

  • compute_validation_loss (bool, optional, default=True) – Whether to compute validation loss.

  • compute_test_loss (bool, optional, default=True) – Whether to compute test loss.
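
A minimal construction sketch. The import path for IJEPA follows this page; `encoder` and `predictor` are assumed to be already-built VisionTransformer / VisionTransformerPredictor instances, and the hyperparameter values are illustrative, not recommendations:

    from functools import partial

    import torch
    from mmlearn.tasks.ijepa import IJEPA

    task = IJEPA(
        encoder=encoder,      # an existing VisionTransformer instance (assumed)
        predictor=predictor,  # an existing VisionTransformerPredictor instance (assumed)
        # A partial() that receives the model parameters as its only required argument.
        optimizer=partial(torch.optim.AdamW, lr=1.5e-4, weight_decay=0.05),
        # Dictionary form: 'scheduler' is a partial() over the optimizer; 'extras'
        # carries additional scheduler arguments (the 'interval' key shown here is
        # an assumption about what gets forwarded to Lightning's scheduler config).
        lr_scheduler={
            "scheduler": partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=10_000),
            "extras": {"interval": "step"},
        },
        ema_decay=0.996,
        ema_decay_end=1.0,
        ema_anneal_end_step=1000,
    )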

configure_model()[source]

Configure the model.

Return type:

None

on_before_zero_grad(optimizer)[source]

Perform an exponential-moving-average update of the target encoder.

This hook runs right after the optimizer step and just before the gradients are zeroed, so with gradient accumulation the EMA update happens once per optimizer step rather than once per batch.

Return type:

None
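
In spirit, the update this hook performs looks like the sketch below; the helper name and the linear annealing schedule are illustrative assumptions, not mmlearn's exact internals:

    import torch

    @torch.no_grad()
    def ema_update(encoder, target_encoder, step,
                   ema_decay=0.996, ema_decay_end=1.0, ema_anneal_end_step=1000):
        # Anneal the momentum linearly from ema_decay to ema_decay_end over
        # ema_anneal_end_step steps, then hold it constant.
        if step >= ema_anneal_end_step:
            momentum = ema_decay_end
        else:
            momentum = ema_decay + (ema_decay_end - ema_decay) * step / ema_anneal_end_step
        # target <- momentum * target + (1 - momentum) * online encoder
        for p_target, p_online in zip(target_encoder.parameters(), encoder.parameters()):
            p_target.mul_(momentum).add_(p_online, alpha=1.0 - momentum)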

on_load_checkpoint(checkpoint)[source]

Restore EMA state from the checkpoint.

Parameters:

checkpoint (dict[str, Any]) – The state dictionary to restore the EMA state from.

Return type:

None

on_save_checkpoint(checkpoint)[source]

Add relevant EMA state to the checkpoint.

Parameters:

checkpoint (dict[str, Any]) – The state dictionary to save the EMA state to.

Return type:

None
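
Together with on_load_checkpoint(), this hook round-trips the EMA state through Lightning checkpoints. A hypothetical sketch of the mechanism; the attribute and key names are assumptions, not mmlearn's actual ones:

    from typing import Any

    class EMACheckpointSketch:
        def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
            # Store the EMA (target encoder) parameters under a dedicated key.
            # The "ema_state" key name is an assumption.
            checkpoint["ema_state"] = {
                k: v.detach().clone()
                for k, v in self.target_encoder.state_dict().items()
            }

        def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
            # Restore the EMA parameters saved above, if present.
            if "ema_state" in checkpoint:
                self.target_encoder.load_state_dict(checkpoint["ema_state"])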

on_test_epoch_end()[source]

Actions at the end of the test epoch.

Return type:

None

on_test_epoch_start()[source]

Prepare for the test epoch.

Return type:

None

on_validation_epoch_end()[source]

Actions at the end of the validation epoch.

Return type:

None

on_validation_epoch_start()[source]

Prepare for the validation epoch.

Return type:

None

test_step(batch, batch_idx)[source]

Run a single test step.

Parameters:
  • batch (dict[str, Any]) – A batch of data.

  • batch_idx (int) – Index of the batch.

Returns:

Loss value, or None if no loss is computed.

Return type:

Optional[torch.Tensor]

training_step(batch, batch_idx)[source]

Perform a single training step.

Parameters:
  • batch (dict[str, Any]) – A batch of data.

  • batch_idx (int) – Index of the batch.

Returns:

Loss value.

Return type:

torch.Tensor

validation_step(batch, batch_idx)[source]

Run a single validation step.

Parameters:
  • batch (dict[str, Any]) – A batch of data.

  • batch_idx (int) – Index of the batch.

Returns:

Loss value or None if no loss is computed.

Return type:

Optional[torch.Tensor]
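
Because IJEPA is a Lightning module, the step methods above are driven by a standard Trainer loop. A hedged usage sketch; `datamodule` is assumed to yield the dict[str, Any] batches these steps expect:

    import lightning.pytorch as pl

    trainer = pl.Trainer(max_steps=10_000, precision="16-mixed")
    trainer.fit(task, datamodule=datamodule)    # `task` from the construction example above
    trainer.test(task, datamodule=datamodule)   # drives test_step and the test epoch hooks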