mmlearn.tasks.ijepa module
IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task.
- class IJEPA(encoder, predictor, modality='RGB', optimizer=None, lr_scheduler=None, ema_decay=0.996, ema_decay_end=1.0, ema_anneal_end_step=1000, loss_fn=None, compute_validation_loss=True, compute_test_loss=True)[source]
Bases: TrainingTask
Pretraining module for IJEPA.
This class implements the IJEPA (Image Joint-Embedding Predictive Architecture) pretraining task using PyTorch Lightning. It trains an encoder and a predictor to reconstruct masked regions of an image based on its unmasked context.
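The sketch below illustrates the objective this task optimizes. It is a simplified, hypothetical rendering, not the module's actual training_step: the helper names, mask handling, and tensor shapes are assumptions.

```python
# Simplified IJEPA objective (illustrative sketch, not mmlearn's implementation).
import torch
import torch.nn.functional as F

def ijepa_objective(context_encoder, target_encoder, predictor,
                    images, context_mask, target_mask):
    """Predict target-encoder representations of masked patches from context."""
    with torch.no_grad():
        # Targets come from the EMA ("target") encoder applied to the full image.
        full = target_encoder(images)      # (B, num_patches, dim)
        targets = full[:, target_mask]     # representations to be predicted
    # The online encoder sees only the unmasked (context) patches.
    context = context_encoder(images, context_mask)
    # The predictor fills in the masked positions from the context.
    preds = predictor(context, context_mask, target_mask)
    # The default loss is smooth L1 between predictions and targets.
    return F.smooth_l1_loss(preds, targets)
```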
- Parameters:
encoder (VisionTransformer) – Vision transformer encoder.
predictor (VisionTransformerPredictor) – Vision transformer predictor.
modality (str, optional, default='RGB') – The modality of the input data.
optimizer (Optional[partial[torch.optim.Optimizer]], optional, default=None) – The optimizer to use for training. This is expected to be a partial() function that takes the model parameters as the only required argument. If not provided, training will continue without an optimizer.
lr_scheduler (Optional[Union[dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]]], optional, default=None) – The learning rate scheduler to use for training. This can be a partial() function that takes the optimizer as the only required argument, or a dictionary with a 'scheduler' key that specifies the scheduler and an optional 'extras' key that specifies additional arguments to pass to the scheduler. If not provided, the learning rate will not be adjusted during training. A configuration sketch follows this list.
ema_decay (float, optional, default=0.996) – Initial momentum for the EMA update of the target encoder.
ema_decay_end (float, optional, default=1.0) – Final momentum for the EMA update of the target encoder.
ema_anneal_end_step (int, optional, default=1000) – Number of steps over which to anneal the EMA momentum to ema_decay_end.
loss_fn (Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional, default=None) – Loss function to use. If not provided, defaults to smooth_l1_loss().
compute_validation_loss (bool, optional, default=True) – Whether to compute validation loss.
compute_test_loss (bool, optional, default=True) – Whether to compute test loss.
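As referenced above, here is a hypothetical configuration sketch. The encoder and predictor are assumed to be constructed elsewhere, and the optimizer and scheduler hyperparameters are illustrative, not recommended values.

```python
import functools
import torch
from mmlearn.tasks.ijepa import IJEPA

encoder = ...    # a VisionTransformer, built elsewhere
predictor = ...  # a VisionTransformerPredictor, built elsewhere

task = IJEPA(
    encoder=encoder,
    predictor=predictor,
    # A partial() that takes the model parameters as its only required argument.
    optimizer=functools.partial(torch.optim.AdamW, lr=1.5e-4, weight_decay=0.05),
    # Dictionary form: a 'scheduler' key plus an optional 'extras' key whose
    # contents are passed to the scheduler ('eta_min' here is an assumption).
    lr_scheduler={
        "scheduler": functools.partial(
            torch.optim.lr_scheduler.CosineAnnealingLR, T_max=100
        ),
        "extras": {"eta_min": 1e-6},
    },
    ema_decay=0.996,
    ema_decay_end=1.0,
    ema_anneal_end_step=1000,
)
```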
- on_before_zero_grad(optimizer)[source]
Perform an exponential moving average (EMA) update of the target encoder.
This is done right after the optimizer step, which comes just before zero_grad, to account for gradient accumulation. A sketch of the update follows below.
- Return type:
None
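The update is equivalent to the following sketch. The linear annealing schedule and variable names are assumptions based on the ema_decay, ema_decay_end, and ema_anneal_end_step parameters above, not mmlearn internals.

```python
# Minimal sketch of the EMA update performed in on_before_zero_grad
# (illustrative only; not mmlearn's actual implementation).
import torch

@torch.no_grad()
def ema_update(encoder, target_encoder, step, ema_decay=0.996,
               ema_decay_end=1.0, ema_anneal_end_step=1000):
    # Anneal the momentum linearly from ema_decay to ema_decay_end.
    progress = min(step / ema_anneal_end_step, 1.0)
    momentum = ema_decay + (ema_decay_end - ema_decay) * progress
    # target <- momentum * target + (1 - momentum) * online
    for p_online, p_target in zip(encoder.parameters(),
                                  target_encoder.parameters()):
        p_target.mul_(momentum).add_(p_online, alpha=1.0 - momentum)
```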
- test_step(batch, batch_idx)[source]
Run a single test step.
- Parameters:
batch (dict[str, Any]) – The batch of data to process.
batch_idx (int) – The index of the batch.
- Returns:
Loss value, or None if no loss is computed.
- Return type:
Optional[torch.Tensor]
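For completeness, a hypothetical usage sketch: running the test loop through a Lightning Trainer, which invokes test_step once per batch. The task and dataloader are assumed to exist.

```python
import lightning as L

task = ...         # an IJEPA instance, configured as above
test_loader = ...  # a torch DataLoader over the test set

trainer = L.Trainer(accelerator="auto", devices=1)
# Calls IJEPA.test_step for each batch; the loss is computed and logged
# when compute_test_loss=True.
trainer.test(task, dataloaders=test_loader)
```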