mmlearn.modules.ema

Exponential Moving Average (EMA) module.

Classes

ExponentialMovingAverage

Exponential Moving Average (EMA) for the input model.

class ExponentialMovingAverage(model, ema_decay, ema_end_decay, ema_anneal_end_step, device_id=None, skip_keys=None)

Exponential Moving Average (EMA) for the input model.

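At each step, every parameter of the EMA model is updated to a weighted average of its current value and the corresponding parameter of the input model, with the decay rate as the weight on the current EMA value.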
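As a minimal sketch of this update rule, assuming the standard EMA formulation ema ← decay · ema + (1 − decay) · param, the snippet below applies one step. Parameters are modelled as a plain dict of name → float purely for illustration, and `ema_update` is a hypothetical helper, not part of mmlearn.

```python
def ema_update(ema_params: dict[str, float], model_params: dict[str, float], decay: float) -> None:
    """Hypothetical helper: update ema_params in place with one EMA step."""
    for name, param in model_params.items():
        # A decay close to 1 keeps the EMA value close to its previous value.
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * param


ema = {"w": 1.0}
ema_update(ema, {"w": 0.0}, decay=0.9)
print(ema["w"])  # 0.9
```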

Parameters:
  • model (torch.nn.Module) – The model to apply EMA to.

  • ema_decay (float) – The initial decay value for EMA.

  • ema_end_decay (float) – The final decay value for EMA.

  • ema_anneal_end_step (int) – The number of steps to anneal the decay from ema_decay to ema_end_decay.

  • device_id (Optional[Union[int, torch.device]], optional, default=None) – The device to move the EMA copy of the model to.

  • skip_keys (Optional[Union[list[str], Set[str]]], optional, default=None) – The keys to skip in the EMA update. These parameters will be copied directly from the model to the EMA model.

Raises:

RuntimeError – If a deep copy of the model cannot be created.
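The skip_keys behaviour described above can be sketched as follows: names listed in skip_keys bypass the weighted average and take the model's value verbatim. This is an illustrative sketch over a dict of name → float; `ema_update_with_skip` and the key names are hypothetical, not mmlearn's internals.

```python
from collections.abc import Iterable
from typing import Optional


def ema_update_with_skip(
    ema_params: dict[str, float],
    model_params: dict[str, float],
    decay: float,
    skip_keys: Optional[Iterable[str]] = None,
) -> None:
    """Hypothetical helper: one EMA step that copies skipped keys directly."""
    skip = set(skip_keys or ())
    for name, param in model_params.items():
        if name in skip:
            # Skipped keys are copied from the model without averaging.
            ema_params[name] = param
        else:
            ema_params[name] = decay * ema_params[name] + (1.0 - decay) * param


ema = {"weight": 1.0, "num_batches_tracked": 0.0}
ema_update_with_skip(
    ema, {"weight": 0.0, "num_batches_tracked": 5.0}, decay=0.9, skip_keys=["num_batches_tracked"]
)
print(ema)  # {'weight': 0.9, 'num_batches_tracked': 5.0}
```

Copying rather than averaging makes sense for entries such as counters, where an average of old and new values has no meaning.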

static deepcopy_model(model)

Deep copy the model.

Parameters:

model (torch.nn.Module) – The model to copy.

Returns:

The copied model.

Return type:

torch.nn.Module

Raises:

RuntimeError – If the model cannot be copied.
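A minimal sketch of the documented contract (return an independent copy, raise RuntimeError if copying fails), assuming the copy is made with Python's `copy.deepcopy`. The stand-alone function below is hypothetical and not the library's exact implementation.

```python
import copy
from typing import TypeVar

T = TypeVar("T")


def deepcopy_model_sketch(model: T) -> T:
    """Hypothetical: return an independent deep copy, or raise RuntimeError on failure."""
    try:
        return copy.deepcopy(model)
    except Exception as e:
        # Re-raise as RuntimeError to match the documented exception type.
        raise RuntimeError(f"Unable to copy the model: {e}") from e


original = {"w": [1.0, 2.0]}
clone = deepcopy_model_sketch(original)
clone["w"][0] = 99.0
print(original["w"][0])  # 1.0
```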

static get_annealed_rate(start, end, curr_step, total_steps)

Calculate the EMA decay rate at step curr_step, annealed from start to end over total_steps steps.

Return type:

float
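The source does not state the shape of the annealing schedule. As a sketch assuming a linear schedule that holds at end once total_steps is reached, the decay would move from ema_decay to ema_end_decay over ema_anneal_end_step steps like this; `linear_annealed_rate` is a hypothetical name.

```python
def linear_annealed_rate(start: float, end: float, curr_step: int, total_steps: int) -> float:
    """Hypothetical linear schedule: move from start to end over total_steps, then hold at end."""
    if curr_step >= total_steps:
        # Past the annealing window the decay stays at its final value.
        return end
    fraction_done = curr_step / total_steps
    return start + (end - start) * fraction_done


for step in (0, 500, 1000, 2000):
    print(step, linear_annealed_rate(0.999, 0.9999, step, 1000))
```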

restore(model)

Load the EMA weights into the given model.

Parameters:

model (torch.nn.Module) – The model to load the EMA weights into.

Returns:

The model with the EMA weights loaded.

Return type:

torch.nn.Module

state_dict()

Return the state dict of the EMA model.

Return type:

dict[str, Any]

step(new_model)

Perform a single EMA update step using the parameters of new_model.

Return type:

None
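To show how step ties the pieces together, here is a toy, self-contained sketch that anneals the decay, applies the weighted average, copies skipped keys, and advances a step counter. It assumes a linear annealing schedule and represents parameters as a dict of name → float; `ToyEMA` is hypothetical and not mmlearn's class.

```python
from typing import Iterable, Optional


class ToyEMA:
    """Hypothetical EMA tracker over a dict of name -> float parameters."""

    def __init__(
        self,
        params: dict[str, float],
        ema_decay: float,
        ema_end_decay: float,
        ema_anneal_end_step: int,
        skip_keys: Optional[Iterable[str]] = None,
    ) -> None:
        self.params = dict(params)  # independent copy of the initial weights
        self.ema_decay = ema_decay
        self.ema_end_decay = ema_end_decay
        self.ema_anneal_end_step = ema_anneal_end_step
        self.skip_keys = set(skip_keys or ())
        self.num_updates = 0

    def _annealed_decay(self) -> float:
        # Assumed linear schedule from ema_decay to ema_end_decay.
        if self.num_updates >= self.ema_anneal_end_step:
            return self.ema_end_decay
        frac = self.num_updates / self.ema_anneal_end_step
        return self.ema_decay + (self.ema_end_decay - self.ema_decay) * frac

    def step(self, new_params: dict[str, float]) -> None:
        decay = self._annealed_decay()
        for name, value in new_params.items():
            if name in self.skip_keys:
                self.params[name] = value  # copied directly, no averaging
            else:
                self.params[name] = decay * self.params[name] + (1.0 - decay) * value
        self.num_updates += 1

    def state_dict(self) -> dict[str, float]:
        return dict(self.params)  # return a copy so callers cannot mutate internal state


ema = ToyEMA({"w": 0.0}, ema_decay=0.5, ema_end_decay=1.0, ema_anneal_end_step=2)
ema.step({"w": 1.0})  # decay 0.5  -> w = 0.5
ema.step({"w": 1.0})  # decay 0.75 -> w = 0.625
ema.step({"w": 1.0})  # decay 1.0  -> w stays 0.625
print(ema.state_dict())  # {'w': 0.625}
```

The end decay of 1.0 is chosen only to keep the arithmetic exact; in practice both decay values are typically just below 1.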