fl4health.losses.weight_drift_loss module
- class WeightDriftLoss(device)
Bases: torch.nn.Module
- __init__(device)
Used to compute the \(l_2\)-inner product between a PyTorch model and a reference set of weights corresponding to a past version of that model.
- Parameters:
device (torch.device) – Device on which the loss should be computed.
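The docstring names the quantity but not how it is evaluated. Below is a minimal sketch of one way such a loss could be computed. It assumes the "\(l_2\)-inner product" is the inner product of the weight difference with itself, i.e. the squared distance \( \frac{\lambda}{2} \sum_{\ell} \lVert w_\ell - w_\ell^{\text{ref}} \rVert_2^2 \) summed over parameter tensors \(\ell\). The class name `WeightDriftSketch`, its `forward(model, reference_tensors, weight)` signature, and the \(\lambda/2\) scaling are hypothetical illustrations, not the documented fl4health interface.

```python
from __future__ import annotations

import torch
import torch.nn as nn


class WeightDriftSketch(nn.Module):
    """Hypothetical stand-in for WeightDriftLoss (not the fl4health API).

    Assumption: the loss is (weight / 2) times the squared l2 distance between
    the model's current parameters and a list of reference tensors.
    """

    def __init__(self, device: torch.device) -> None:
        super().__init__()
        self.device = device

    def forward(
        self,
        model: nn.Module,
        reference_tensors: list[torch.Tensor],
        weight: float,
    ) -> torch.Tensor:
        params = list(model.parameters())
        # One reference tensor per parameter tensor, in the same order.
        if len(params) != len(reference_tensors) or not params:
            raise ValueError("reference_tensors must match model.parameters() one-to-one")
        # Per-tensor <w - w_ref, w - w_ref>, the squared l2 norm of the drift.
        per_tensor = [
            torch.sum((p - ref.to(self.device)) ** 2)
            for p, ref in zip(params, reference_tensors)
        ]
        # Sum over all parameter tensors and scale by the (assumed) weight / 2.
        return (weight / 2.0) * torch.stack(per_tensor).sum()


if __name__ == "__main__":
    device = torch.device("cpu")
    model = nn.Linear(2, 1).to(device)
    # Detached snapshot of the weights, e.g. taken at the start of a training round.
    reference = [p.detach().clone() for p in model.parameters()]
    loss_fn = WeightDriftSketch(device)
    print(loss_fn(model, reference, weight=0.1).item())  # no drift yet -> 0.0
```

Because the reference tensors are detached copies, gradients flow only through the live model parameters. This penalizes the model for drifting away from the snapshot, which is the form of the proximal term used in FedProx.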