# Source code for fl4health.losses.weight_drift_loss
import torch
import torch.nn as nn
# [docs]
class WeightDriftLoss(nn.Module):
    """
    Penalty on the squared l2 distance between a model's parameters and a list of reference tensors.

    For a model with parameter tensors ``w_i`` and reference tensors ``c_i`` the loss is
    ``(weight / 2) * sum_i ||w_i - c_i||^2``.
    """

    def __init__(
        self,
        device: torch.device,
    ) -> None:
        """
        Store the device on which the loss is computed.

        Args:
            device (torch.device): Device that the model and the reference tensors are moved to in ``forward``.
        """
        super().__init__()
        self.device = device

    def _compute_weight_difference_inner_product(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute the squared l2 norm of ``x - y``.

        Args:
            x (torch.Tensor): First tensor.
            y (torch.Tensor): Second tensor; ``x - y`` must be a valid (broadcastable) subtraction.

        Returns:
            torch.Tensor: Scalar tensor holding ``||x - y||^2``.
        """
        # With neither ``ord`` nor ``dim`` given, ``torch.linalg.norm`` flattens its input and takes the 2-norm,
        # so tensors of any rank are reduced to a single scalar here.
        return torch.pow(torch.linalg.norm(x - y), 2.0)

    def forward(self, target_model: nn.Module, constraint_tensors: list[torch.Tensor], weight: float) -> torch.Tensor:
        """
        Compute the weighted squared l2 distance between the parameters of ``target_model`` and
        ``constraint_tensors``.

        Args:
            target_model (nn.Module): Model whose parameters are compared against the reference tensors. It is
                moved to ``self.device``.
            constraint_tensors (list[torch.Tensor]): Reference tensors, one per entry of
                ``target_model.parameters()`` and in the same order.
            weight (float): Scalar multiplier applied to the summed squared distances.

        Returns:
            torch.Tensor: Scalar tensor equal to ``(weight / 2) * sum_i ||w_i - c_i||^2``.

        Raises:
            AssertionError: If the number of reference tensors differs from the number of model parameter tensors,
                or if the model has no parameters.
        """
        # Move model and tensors to device if needed.
        # NOTE: ``nn.Module.to`` relocates the module's parameters in place, so the caller's model is affected too.
        target_model = target_model.to(self.device)
        constraint_tensors = [constraint_tensor.to(self.device) for constraint_tensor in constraint_tensors]

        model_weights = list(target_model.parameters())
        assert len(constraint_tensors) == len(model_weights), (
            f"Expected {len(model_weights)} constraint tensors (one per model parameter tensor) "
            f"but received {len(constraint_tensors)}."
        )
        assert len(model_weights) > 0, "target_model has no parameters to compute a drift loss for."

        layer_inner_products: list[torch.Tensor] = [
            self._compute_weight_difference_inner_product(constraint_layer_weights, model_layer_weights)
            for constraint_layer_weights, model_layer_weights in zip(constraint_tensors, model_weights)
        ]

        # Network l2 inner product tensor
        # NOTE: Scaling by 1/2 is for grad consistency: the gradient with respect to each parameter tensor is then
        # weight * (w_i - c_i), without an extra factor of 2.
        return (weight / 2.0) * torch.stack(layer_inner_products).sum()