Source code for fl4health.losses.weight_drift_loss
import torch
import torch.nn as nn
[docs]
class WeightDriftLoss(nn.Module):
[docs]
def __init__(
self,
device: torch.device,
) -> None:
"""
Used to compute the :math:`l_2`-inner product between a Torch model and a reference set of weights
corresponding to a past version of that model.
Args:
device (torch.device): Device on which the loss should be computed.
"""
super().__init__()
self.device = device
def _compute_weight_difference_inner_product(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Compute the :math:`l_2`-inner product between two tensors. This amounts to the Frobenius norm of the
difference between the tensors :math:`\\Vert x - y \\Vert_F.`
Args:
x (torch.Tensor): first tensor
y (torch.Tensor): second tensor
Returns:
torch.Tensor: Frobenius norm of their difference
"""
return torch.pow(torch.linalg.norm(x - y), 2.0)
[docs]
def forward(self, target_model: nn.Module, constraint_tensors: list[torch.Tensor], weight: float) -> torch.Tensor:
"""
Compute the :math:`l_2`-inner product between a Torch model and a reference set of weights in a differentiable
way. The `constraint_tenors` are frozen.
Args:
target_model (nn.Module): Model being constrained by the `constraint_tensors`. Weights are differentiable.
constraint_tensors (list[torch.Tensor]): Tensors corresponding to a previous version of the target_model
weight (float): Weight to scale the loss with
Returns:
torch.Tensor: Loss value.
"""
# move model and tensors to device if needed
target_model = target_model.to(self.device)
constraint_tensors = [constraint_tensor.to(self.device) for constraint_tensor in constraint_tensors]
model_weights = [layer_weights for layer_weights in target_model.parameters()]
assert len(constraint_tensors) == len(model_weights)
assert len(model_weights) > 0
layer_inner_products: list[torch.Tensor] = [
self._compute_weight_difference_inner_product(constraint_layer_weights, model_layer_weights)
for constraint_layer_weights, model_layer_weights in zip(constraint_tensors, model_weights)
]
# Network l2 inner product tensor
# NOTE: Scaling by 1/2 is for grad consistency.
return (weight / 2.0) * torch.stack(layer_inner_products).sum()