mmlearn.modules.ema.ExponentialMovingAverage
- class ExponentialMovingAverage(model, ema_decay, ema_end_decay, ema_anneal_end_step, device_id=None, skip_keys=None)
Bases: object
Exponential Moving Average (EMA) for the input model.
At each step, the parameters of the EMA model are updated as a weighted average of the EMA model’s current parameters and the model’s parameters.
- Parameters:
model (torch.nn.Module) – The model to apply EMA to.
ema_decay (float) – The initial decay value for EMA.
ema_end_decay (float) – The final decay value for EMA.
ema_anneal_end_step (int) – The number of steps to anneal the decay from ema_decay to ema_end_decay.
device_id (Optional[Union[int, torch.device]], optional, default=None) – The device to move the model to.
skip_keys (Optional[Union[list[str], Set[str]]], optional, default=None) – The keys to skip in the EMA update. These parameters will be copied directly from the model to the EMA model.
- Raises:
RuntimeError – If a deep copy of the model cannot be created.
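The update described above can be sketched in plain Python. This is a minimal illustration rather than mmlearn's implementation: it assumes the common rule ema = decay * ema + (1 - decay) * param, stores parameters as floats in dictionaries instead of tensors, and uses a hypothetical helper named ema_update. Keys listed in skip_keys are copied directly from the model, as the parameter description states.

```python
from typing import Dict, Iterable, Optional


def ema_update(
    ema_params: Dict[str, float],
    model_params: Dict[str, float],
    decay: float,
    skip_keys: Optional[Iterable[str]] = None,
) -> Dict[str, float]:
    """Return updated EMA parameters (hypothetical helper, assumed update rule)."""
    skip = set(skip_keys or ())
    updated = {}
    for name, value in model_params.items():
        if name in skip:
            # Skipped keys are copied straight from the model.
            updated[name] = value
        else:
            # Weighted average of the old EMA value and the current model value.
            updated[name] = decay * ema_params[name] + (1.0 - decay) * value
    return updated


ema = {"weight": 1.0, "bias": 0.0}
model = {"weight": 2.0, "bias": 5.0}
new_ema = ema_update(ema, model, decay=0.5, skip_keys={"bias"})
print(new_ema)
```

A decay close to 1 makes the EMA parameters change slowly. A smaller decay makes them track the model more closely.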
Methods
- static deepcopy_model(model)
Deep copy the model.
- Parameters:
model (torch.nn.Module) – The model to copy.
- Returns:
The copied model.
- Return type:
torch.nn.Module
- Raises:
RuntimeError – If the model cannot be copied.
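The copy-or-raise behavior documented here might look like the following sketch. It assumes the copy is made with the standard library's copy.deepcopy and that any failure is re-raised as RuntimeError. The helper deepcopy_or_raise and the two small classes are hypothetical stand-ins, not mmlearn code.

```python
import copy
import threading


def deepcopy_or_raise(obj):
    """Deep copy obj, re-raising any failure as RuntimeError (hypothetical helper)."""
    try:
        return copy.deepcopy(obj)
    except Exception as exc:
        raise RuntimeError("Unable to deep copy the model") from exc


class Copyable:
    def __init__(self):
        self.weights = [1.0, 2.0]


class NotCopyable:
    def __init__(self):
        # Lock objects cannot be deep copied, so copying this object fails.
        self.lock = threading.Lock()


original = Copyable()
clone = deepcopy_or_raise(original)
clone.weights[0] = 99.0  # Changing the copy leaves the original untouched.

try:
    deepcopy_or_raise(NotCopyable())
    copy_failed = False
except RuntimeError:
    copy_failed = True
```

Because the EMA model is an independent copy, later updates to it do not modify the model being trained.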
- static get_annealed_rate(start, end, curr_step, total_steps)
Calculate EMA annealing rate.
- Return type:
float
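As an illustration of annealing a decay value between two endpoints, the sketch below assumes a linear schedule from start to end over total_steps. The function name annealed_rate is hypothetical, and both the linear form and the clamp to end once curr_step reaches total_steps are assumptions of this sketch, not confirmed behavior of get_annealed_rate.

```python
def annealed_rate(start: float, end: float, curr_step: int, total_steps: int) -> float:
    """Linearly interpolate from start to end (assumed schedule)."""
    if curr_step >= total_steps:
        # Assumption: hold the final value once annealing has finished.
        return end
    pct_remaining = 1.0 - curr_step / total_steps
    return end - (end - start) * pct_remaining


# Decay values for steps 0..5 when annealing from 0.5 to 1.0 over 4 steps.
rates = [annealed_rate(0.5, 1.0, step, 4) for step in range(6)]
print(rates)
```

In terms of the constructor arguments, start would correspond to ema_decay, end to ema_end_decay, and total_steps to ema_anneal_end_step.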
- restore(model)
Reassign weights from another model.
- Parameters:
model (torch.nn.Module) – Model to load weights from.
- Returns:
model with new weights
- Return type:
torch.nn.Module