"""Exponential Moving Average (EMA) module."""
import copy
from typing import Any, Optional, Set, Union
import torch
from lightning.fabric.utilities import rank_zero_warn
class ExponentialMovingAverage:
"""Exponential Moving Average (EMA) for the input model.

    At each step, the parameters of the EMA model are updated as a weighted
    average of its current parameters and the input model's parameters.

Parameters
----------
model : torch.nn.Module
The model to apply EMA to.
ema_decay : float
The initial decay value for EMA.
ema_end_decay : float
The final decay value for EMA.
ema_anneal_end_step : int
The number of steps to anneal the decay from ``ema_decay`` to ``ema_end_decay``.
device_id : Optional[Union[int, torch.device]], optional, default=None
        The device to move the EMA model to.
skip_keys : Optional[Union[list[str], Set[str]]], optional, default=None
The keys to skip in the EMA update. These parameters will be copied directly
from the model to the EMA model.
Raises
------
RuntimeError
If a deep copy of the model cannot be created.
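
    Examples
    --------
    A minimal usage sketch; the linear layer, loop, and hyperparameter values
    below are illustrative rather than prescriptive.

    >>> import torch
    >>> model = torch.nn.Linear(8, 2)
    >>> ema = ExponentialMovingAverage(
    ...     model, ema_decay=0.999, ema_end_decay=0.9999, ema_anneal_end_step=10_000
    ... )
    >>> for _ in range(3):  # stand-in for a real training loop
    ...     # ... optimizer step on ``model`` would go here ...
    ...     ema.step(model)
    >>> model = ema.restore(model)  # copy the averaged weights back into ``model``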
"""
def __init__(
self,
model: torch.nn.Module,
ema_decay: float,
ema_end_decay: float,
ema_anneal_end_step: int,
device_id: Optional[Union[int, torch.device]] = None,
skip_keys: Optional[Union[list[str], Set[str]]] = None,
):
self.model = self.deepcopy_model(model)
self.model.requires_grad_(False)
if device_id is not None:
self.model.to(device_id)
self.skip_keys: Union[list[str], set[str]] = skip_keys or set()
self.num_updates = 0
self.decay = ema_decay # stores the current decay value
self.ema_decay = ema_decay
self.ema_end_decay = ema_end_decay
self.ema_anneal_end_step = ema_anneal_end_step
@staticmethod
def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
"""Deep copy the model.
Parameters
----------
model : torch.nn.Module
The model to copy.
Returns
-------
torch.nn.Module
The copied model.
Raises
------
RuntimeError
If the model cannot be copied.
"""
try:
return copy.deepcopy(model)
except RuntimeError as e:
            raise RuntimeError(f"Unable to copy the model: {e}") from e
@staticmethod
def get_annealed_rate(
start: float,
end: float,
curr_step: int,
total_steps: int,
) -> float:
"""Calculate EMA annealing rate."""
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining
def step(self, new_model: torch.nn.Module) -> None:
"""Perform single EMA update step."""
self._update_weights(new_model)
self._update_ema_decay()
def restore(self, model: torch.nn.Module) -> torch.nn.Module:
"""Reassign weights from another model.
Parameters
----------
model : torch.nn.Module
Model to load weights from.
Returns
-------
torch.nn.Module
model with new weights
"""
d = self.model.state_dict()
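        # ``strict=False`` tolerates keys that are missing from or unexpected in ``model``.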
model.load_state_dict(d, strict=False)
return model
def state_dict(self) -> dict[str, Any]:
"""Return the state dict of the model."""
return self.model.state_dict() # type: ignore[no-any-return]
@torch.no_grad() # type: ignore[misc]
def _update_weights(self, new_model: torch.nn.Module) -> None:
if self.decay < 1:
ema_state_dict = {}
ema_params = self.model.state_dict()
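            # Walk the incoming state dict and compute the new EMA value for each tensor.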
for key, param in new_model.state_dict().items():
ema_param = ema_params[key].float()
if param.shape != ema_param.shape:
raise ValueError(
"Incompatible tensor shapes between student param and teacher param"
+ "{} vs. {}".format(param.shape, ema_param.shape)
)
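                # Tensors listed in ``skip_keys`` or that do not require gradients are
                # copied from ``new_model`` directly instead of being averaged.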
if key in self.skip_keys or not param.requires_grad:
ema_param = param.to(dtype=ema_param.dtype).clone()
else:
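                    # Standard EMA update: ema_param = decay * ema_param + (1 - decay) * param.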
ema_param.mul_(self.decay)
ema_param.add_(
param.to(dtype=ema_param.dtype),
alpha=1 - self.decay,
)
ema_state_dict[key] = ema_param
self.model.load_state_dict(ema_state_dict, strict=False)
self.num_updates += 1
else:
rank_zero_warn(
"Exponential Moving Average decay is 1.0, no update is applied to the model.",
stacklevel=1,
category=UserWarning,
)
def _update_ema_decay(self) -> None:
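        # Linearly anneal the decay from ``ema_decay`` toward ``ema_end_decay`` over
        # ``ema_anneal_end_step`` updates, then hold it at ``ema_end_decay``.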
if self.ema_decay != self.ema_end_decay:
if self.num_updates >= self.ema_anneal_end_step:
decay = self.ema_end_decay
else:
decay = self.get_annealed_rate(
self.ema_decay,
self.ema_end_decay,
self.num_updates,
self.ema_anneal_end_step,
)
self.decay = decay