Source code for fl4health.model_bases.masked_layers.masked_linear

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter

from fl4health.utils.functions import bernoulli_sample


class MaskedLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Implementation of masked linear layers. Like a regular linear layer (i.e., the nn.Linear module), a
        masked linear layer has a weight and a bias. However, the weight and the bias do not receive gradients
        during backpropagation. Instead, two score tensors - one for the weight and another for the bias - are
        maintained. In the forward pass, the score tensors are transformed by the sigmoid function into
        probability scores, which are then used to produce binary masks via Bernoulli sampling. Finally, the
        binary masks are applied to the weight and the bias. During training, gradients with respect to the
        score tensors are computed and used to update the score tensors.

        Note: the scores are not assumed to be bounded between 0 and 1.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias: If set to ``False``, the layer will not learn an additive bias. Default: ``True``.
            device: device on which the layer's tensors are allocated. Default: ``None``.
            dtype: dtype of the layer's tensors. Default: ``None``.

        Attributes:
            weight: weights of the module.
            bias: bias of the module.
            weight_scores: learnable scores for the weights. Has the same shape as weight.
            bias_scores: learnable scores for the bias. Has the same shape as bias.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        self.in_features = in_features
        self.out_features = out_features
        self.weight.requires_grad = False
        self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
        if bias:
            assert self.bias is not None
            self.bias.requires_grad = False
            self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
        else:
            self.register_parameter("bias_scores", None)
    def forward(self, input: Tensor) -> Tensor:
        # Produce probability scores and perform Bernoulli sampling to obtain binary masks
        weight_prob_scores = torch.sigmoid(self.weight_scores)
        weight_mask = bernoulli_sample(weight_prob_scores)
        masked_weight = weight_mask * self.weight
        if self.bias is not None:
            bias_prob_scores = torch.sigmoid(self.bias_scores)
            bias_mask = bernoulli_sample(bias_prob_scores)
            masked_bias = bias_mask * self.bias
        else:
            masked_bias = None
        # Apply the masked weight and bias in a standard linear transformation
        return F.linear(input, masked_weight, masked_bias)
    @classmethod
    def from_pretrained(cls, linear_module: nn.Linear) -> "MaskedLinear":
        """
        Return an instance of MaskedLinear whose weight and bias have the same values as those of linear_module.
        """
        has_bias = linear_module.bias is not None
        masked_linear_module = cls(
            in_features=linear_module.in_features,
            out_features=linear_module.out_features,
            bias=has_bias,
        )
        masked_linear_module.weight = Parameter(linear_module.weight.clone().detach(), requires_grad=False)
        masked_linear_module.weight_scores = Parameter(torch.randn_like(linear_module.weight), requires_grad=True)
        if has_bias:
            masked_linear_module.bias = Parameter(linear_module.bias.clone().detach(), requires_grad=False)
            masked_linear_module.bias_scores = Parameter(torch.randn_like(linear_module.bias), requires_grad=True)
        return masked_linear_module
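
To see how the pieces above fit together, the following is a minimal usage sketch: it converts a plain nn.Linear into a MaskedLinear via from_pretrained and checks that only the score tensors receive gradients. The variable names and tensor shapes are illustrative only and are not part of the module above.

# Minimal usage sketch (illustrative names and shapes; not part of the module above).
import torch
import torch.nn as nn

from fl4health.model_bases.masked_layers.masked_linear import MaskedLinear

layer = nn.Linear(in_features=8, out_features=4)
masked_layer = MaskedLinear.from_pretrained(layer)

x = torch.randn(2, 8)
out = masked_layer(x)  # weight and bias are masked by freshly sampled binary masks
out.sum().backward()

# The underlying weight and bias are frozen; only the scores are trainable.
assert masked_layer.weight.grad is None and masked_layer.bias.grad is None
assert masked_layer.weight_scores.grad is not None
assert masked_layer.bias_scores.grad is not None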
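The gradient path through the sampling step depends on bernoulli_sample, imported above from fl4health.utils.functions. Bernoulli sampling is itself non-differentiable, so something like the straight-through estimator sketched below is needed for the scores to receive gradients: sample hard 0/1 masks in the forward pass and pass gradients through to the probabilities unchanged in the backward pass. This sketch is an assumption about the helper's behavior, not a verbatim copy of it.

# Hedged sketch of a straight-through Bernoulli sampler; the actual
# fl4health.utils.functions.bernoulli_sample may differ in its details.
import torch


class _BernoulliSampleSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, probabilities: torch.Tensor) -> torch.Tensor:
        # Hard 0/1 sample in the forward pass.
        return torch.bernoulli(probabilities)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Straight-through: treat the sampling step as the identity
        # function when propagating gradients to the probabilities.
        return grad_output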