fl4health.losses.weight_drift_loss module

class WeightDriftLoss(device)

Bases: Module

__init__(device)

Used to compute the \(l_2\)-inner product between a Torch model and a reference set of weights corresponding to a past version of that model.
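This page does not give the formula explicitly. As a hedged illustration, assume a FedProx-style proximal term; the factor of 1/2 is also an assumption. Let \(w_k\) be the \(k\)-th parameter tensor of the model, \(w_k^{t}\) the matching reference tensor, and \(\mu\) the weight passed to forward. The quantity would then be the \(l_2\)-inner product of the weight difference with itself, summed over parameters:

\[
\mathcal{L}_{\text{drift}}(w) = \frac{\mu}{2} \sum_{k} \left\langle w_k - w_k^{t},\, w_k - w_k^{t} \right\rangle = \frac{\mu}{2} \sum_{k} \lVert w_k - w_k^{t} \rVert_2^2
\]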

Parameters:

device (torch.device) – Device on which the loss should be computed.
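The sketch below shows how a module with this constructor and forward signature could be written. It rests on three assumptions. First, the loss is the (weight / 2)-scaled sum of squared per-parameter differences. Second, constraint_tensors holds one tensor per entry of target_model.parameters(). Third, those tensors appear in the same order as the parameters. WeightDriftLossSketch is a hypothetical name and is not the fl4health class itself.

```python
import torch
import torch.nn as nn


class WeightDriftLossSketch(nn.Module):
    """Hypothetical stand-in for WeightDriftLoss (assumed (weight / 2) * squared l2 drift)."""

    def __init__(self, device: torch.device) -> None:
        super().__init__()
        # Device on which the loss is computed.
        self.device = device

    def forward(
        self, target_model: nn.Module, constraint_tensors: list[torch.Tensor], weight: float
    ) -> torch.Tensor:
        # Move the model and the (frozen) reference tensors to the loss device.
        target_model = target_model.to(self.device)
        constraint_tensors = [t.detach().to(self.device) for t in constraint_tensors]
        model_weights = list(target_model.parameters())
        # Assumption: one reference tensor per parameter, in parameters() order.
        assert len(model_weights) == len(constraint_tensors) > 0
        # Per-parameter squared l2 norm of the drift, i.e. <w_k - w_k^t, w_k - w_k^t>.
        layer_terms = [torch.sum((w - c) ** 2) for w, c in zip(model_weights, constraint_tensors)]
        # Differentiable w.r.t. the model weights only; scaled by weight / 2 (assumed).
        return (weight / 2.0) * torch.stack(layer_terms).sum()
```

With an unchanged model the loss is zero. If every parameter of an nn.Linear(3, 2) (6 weights plus 2 biases) is shifted by 1.0, the sketch returns 0.1 / 2 * 8 = 0.4 for weight=0.1.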

forward(target_model, constraint_tensors, weight)

Compute the \(l_2\)-inner product between a Torch model and a reference set of weights in a differentiable way. The constraint_tensors are treated as frozen.

Parameters:
  • target_model (nn.Module) – Model being constrained by the constraint_tensors. Weights are differentiable.

  • constraint_tensors (list[torch.Tensor]) – Tensors corresponding to a previous version of the target_model.

  • weight (float) – Weight by which to scale the loss.

Returns:

Loss value.

Return type:

torch.Tensor
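The sketch below shows one way the inputs to forward might be prepared in a local training step. The reference weights are snapshotted as detached clones at the start of training, which keeps them frozen. The drift penalty is then added to the task loss before backpropagation. weight_drift_penalty and local_training_step are hypothetical helpers. They mirror the assumed (weight / 2)-scaled squared l2 computation and do not call fl4health.

```python
import torch
import torch.nn as nn


def weight_drift_penalty(model: nn.Module, reference: list[torch.Tensor], weight: float) -> torch.Tensor:
    # Hypothetical helper: (weight / 2) * sum_k ||w_k - w_k^t||_2^2 (assumed scaling).
    terms = [torch.sum((p - r) ** 2) for p, r in zip(model.parameters(), reference)]
    return (weight / 2.0) * torch.stack(terms).sum()


def local_training_step(
    model: nn.Module,
    reference: list[torch.Tensor],
    optimizer: torch.optim.Optimizer,
    x: torch.Tensor,
    y: torch.Tensor,
    weight: float,
) -> torch.Tensor:
    optimizer.zero_grad()
    task_loss = nn.functional.mse_loss(model(x), y)
    # The penalty pulls the weights back toward the frozen reference snapshot.
    total = task_loss + weight_drift_penalty(model, reference, weight)
    total.backward()
    optimizer.step()
    return total.detach()


torch.manual_seed(0)
model = nn.Linear(4, 1)
# Snapshot the weights before local training; detached clones receive no gradients.
reference = [p.detach().clone() for p in model.parameters()]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(5):
    loss = local_training_step(model, reference, optimizer, x, y, weight=0.01)
```

Cloning with detach() matters here. Without it, the reference tensors would share storage with the live parameters and the drift would always be zero.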