import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter
from fl4health.utils.functions import bernoulli_sample
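
# Note on gradients: the scores are only trainable if `bernoulli_sample` provides a
# surrogate gradient for the (non-differentiable) sampling step. fl4health implements
# it as a torch.autograd.Function; a straight-through estimator is the usual choice,
# sketched below under that assumption (illustrative, not the verbatim library code):
#
#     class _BernoulliSampleSTE(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, probs: Tensor) -> Tensor:
#             # Draw hard 0/1 masks from the given probabilities.
#             return torch.bernoulli(probs)
#
#         @staticmethod
#         def backward(ctx, grad_output: Tensor) -> Tensor:
#             # Straight-through: pass the gradient to the probabilities unchanged.
#             return grad_output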
class MaskedLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
"""
Implementation of masked linear layers.
Like regular linear layers (i.e., nn.Linear module), a masked linear layer has a weight and a bias.
However, the weight and the bias do not receive gradient in back propagation.
Instead, two score tensors - one for the weight and another for the bias - are maintained.
In the forward pass, the score tensors are transformed by the Sigmoid function into probability scores,
which are then used to produce binary masks via bernoulli sampling.
Finally, the binary masks are applied to the weight and the bias. During training,
gradients with respect to the score tensors are computed and used to update the score tensors.
Note: the scores are not assumed to be bounded between 0 and 1.
Args:
in_features: size of each input sample
out_features: size of each output sample
bias: If set to ``False``, the layer will not learn an additive bias.
Default: ``True``
Attributes:
weight: weights of the module.
bias: bias of the module.
weight_score: learnable scores for the weights. Has the same shape as weight.
bias_score: learnable scores for the bias. Has the same shape as bias.
"""
        super().__init__(in_features, out_features, bias, device, dtype)
        # Freeze the weight and bias; only the score tensors are trained.
        # (in_features and out_features are already set by nn.Linear.__init__.)
        self.weight.requires_grad = False
        self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
        if bias:
            assert self.bias is not None
            self.bias.requires_grad = False
            self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
        else:
            self.register_parameter("bias_scores", None)
def forward(self, input: Tensor) -> Tensor:
        # Turn the scores into probabilities, then draw binary masks via Bernoulli sampling.
        # bernoulli_sample is autograd-aware, so gradients flow back to the score tensors.
        weight_prob_scores = torch.sigmoid(self.weight_scores)
        weight_mask = bernoulli_sample(weight_prob_scores)
        masked_weight = weight_mask * self.weight
        if self.bias is not None:
            bias_prob_scores = torch.sigmoid(self.bias_scores)
            bias_mask = bernoulli_sample(bias_prob_scores)
            masked_bias = bias_mask * self.bias
        else:
            masked_bias = None
        # The masks were applied above; compute the affine map with the masked parameters.
        return F.linear(input, masked_weight, masked_bias)
@classmethod
def from_pretrained(cls, linear_module: nn.Linear) -> "MaskedLinear":
"""
Return an instance of MaskedLinear whose weight and bias have the same values as those of linear_module.
"""
        has_bias = linear_module.bias is not None
        masked_linear_module = cls(
            in_features=linear_module.in_features,
            out_features=linear_module.out_features,
            bias=has_bias,
        )
        # Copy the pretrained parameters as frozen tensors and re-initialize the scores.
        masked_linear_module.weight = Parameter(linear_module.weight.clone().detach(), requires_grad=False)
        masked_linear_module.weight_scores = Parameter(torch.randn_like(linear_module.weight), requires_grad=True)
        if has_bias:
            masked_linear_module.bias = Parameter(linear_module.bias.clone().detach(), requires_grad=False)
            masked_linear_module.bias_scores = Parameter(torch.randn_like(linear_module.bias), requires_grad=True)
        return masked_linear_module
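

# Usage sketch (illustrative; the sizes and loss below are arbitrary assumptions):
# wrap a pretrained nn.Linear, then train only the mask scores. The frozen weight
# receives no gradient, while the scores do, via the sampler's surrogate gradient.
if __name__ == "__main__":
    base = nn.Linear(16, 4)
    masked = MaskedLinear.from_pretrained(base)

    # Only the score tensors are optimized.
    optimizer = torch.optim.SGD([masked.weight_scores, masked.bias_scores], lr=0.1)

    x = torch.randn(8, 16)
    loss = masked(x).pow(2).mean()
    loss.backward()

    assert masked.weight.grad is None  # frozen parameters get no gradient
    assert masked.weight_scores.grad is not None  # scores are updated instead
    optimizer.step()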