import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parameter import Parameter
from fl4health.utils.functions import bernoulli_sample
TorchShape = int | list[int] | torch.Size
class MaskedLayerNorm(nn.LayerNorm):
def __init__(
self,
normalized_shape: TorchShape,
eps: float = 1e-5,
elementwise_affine: bool = True,
bias: bool = True,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
"""
Implementation of the masked Layer Normalization module. When elementwise_affine is True,
nn.LayerNorm has a learnable weight and (optional) bias. For MaskedLayerNorm,
        the weight and bias do not receive gradients during backpropagation.
        Instead, two score tensors - one for the weight and another for the bias - are maintained.
        In the forward pass, the score tensors are transformed by the Sigmoid function into probability scores,
        which are then used to produce binary masks via Bernoulli sampling.
Finally, the binary masks are applied to the weight and the bias. During training,
gradients with respect to the score tensors are computed and used to update the score tensors.
When elementwise_affine is False, nn.LayerNorm does not have weight or bias.
Under this condition, both score tensors are None and MaskedLayerNorm acts in the same way as nn.LayerNorm.
Note: the scores are not assumed to be bounded between 0 and 1.
Args:
normalized_shape (TorchShape): input shape from an expected input.
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
elementwise_affine: a boolean value that when set to ``True``, this module
has learnable per-element affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.
bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
:attr:`elementwise_affine` is ``True``). Default: ``True``.
Attributes:
weight: the weights of the module. The values are initialized to 1.
bias: the bias of the module. The values are initialized to 0.
            weight_scores: learnable scores for the weights. Has the same shape as weight. When applied
                to the default initial values of self.weight (i.e., all ones), this is equivalent to
                randomly dropping out certain features.
            bias_scores: learnable scores for the bias. Has the same shape as bias. When applied to
                the default initial values of self.bias (i.e., all zeros), it does not have any actual
                effect. Thus, bias_scores only influences training when MaskedLayerNorm is created
                from some pretrained nn.LayerNorm module whose bias is not all zeros.
"""
super().__init__(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
bias=bias,
device=device,
dtype=dtype,
)
if self.elementwise_affine:
assert self.weight is not None
self.weight.requires_grad = False
self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
if self.bias is not None:
self.bias.requires_grad = False
self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
else:
self.register_parameter("bias_scores", None)
else:
self.register_parameter("weight_scores", None)
self.register_parameter("bias_scores", None)
def forward(self, input: Tensor) -> Tensor:
if not self.elementwise_affine:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
else:
assert self.weight is not None
weight_prob_scores = torch.sigmoid(self.weight_scores)
weight_mask = bernoulli_sample(weight_prob_scores)
masked_weight = weight_mask * self.weight
if self.bias is not None:
bias_prob_scores = torch.sigmoid(self.bias_scores)
bias_mask = bernoulli_sample(bias_prob_scores)
masked_bias = bias_mask * self.bias
else:
masked_bias = None
return F.layer_norm(input, self.normalized_shape, masked_weight, masked_bias, self.eps)
@classmethod
def from_pretrained(cls, layer_norm_module: nn.LayerNorm) -> "MaskedLayerNorm":
"""
Return an instance of MaskedLayerNorm whose weight and bias have the same values as those of layer_norm_module.
"""
masked_layer_norm_module = cls(
# layer_norm_module.normalized_shape is a tuple so we
# simply transform it into torch.Size so it is compatible with
# the constructor's type signature.
normalized_shape=torch.Size(layer_norm_module.normalized_shape),
eps=layer_norm_module.eps,
elementwise_affine=layer_norm_module.elementwise_affine,
bias=(layer_norm_module.bias is not None),
)
if layer_norm_module.elementwise_affine:
assert layer_norm_module.weight is not None
masked_layer_norm_module.weight = Parameter(layer_norm_module.weight.clone().detach(), requires_grad=False)
masked_layer_norm_module.weight_scores = Parameter(
torch.randn_like(layer_norm_module.weight), requires_grad=True
)
if layer_norm_module.bias is not None:
masked_layer_norm_module.bias = Parameter(layer_norm_module.bias.clone().detach(), requires_grad=False)
masked_layer_norm_module.bias_scores = Parameter(
torch.randn_like(layer_norm_module.bias), requires_grad=True
)
return masked_layer_norm_module
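

# Illustrative usage sketch: wrap a freshly constructed nn.LayerNorm (standing in
# for a genuinely pretrained layer) and check that, as documented above, the copied
# weight and bias stay frozen while the score tensors are the trainable parameters.
# The feature size of 8 and batch size of 4 are arbitrary example values.
def _example_masked_layer_norm() -> None:
    masked = MaskedLayerNorm.from_pretrained(nn.LayerNorm(8))
    masked(torch.randn(4, 8)).sum().backward()
    # weight and bias are frozen copies of the wrapped layer's parameters ...
    assert not masked.weight.requires_grad and not masked.bias.requires_grad
    # ... while the score tensors are the only trainable parameters of the module.
    assert masked.weight_scores.requires_grad and masked.bias_scores.requires_grad

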
class _MaskedBatchNorm(_BatchNorm):
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
"""
Base class for masked batch normalization modules of various dimensions. When affine is True,
_BatchNorm has a learnable weight and bias. For _MaskedBatchNorm,
        the weight and bias do not receive gradients during backpropagation.
        Instead, two score tensors - one for the weight and another for the bias - are maintained.
        In the forward pass, the score tensors are transformed by the Sigmoid function into probability scores,
        which are then used to produce binary masks via Bernoulli sampling.
Finally, the binary masks are applied to the weight and the bias. During training,
gradients with respect to the score tensors are computed and used to update the score tensors.
When affine is False, _BatchNorm does not have weight or bias.
Under this condition, both score tensors are None and _MaskedBatchNorm acts in the same way as _BatchNorm.
Note: the scores are not assumed to be bounded between 0 and 1.
Args:
num_features: number of features or channels :math:`C` of the input
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
            track_running_stats: a boolean value that when set to ``True``, this
                module tracks the running mean and variance, and when set to ``False``,
                this module does not track such statistics, and initializes statistics
                buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
                When these buffers are ``None``, this module always uses batch statistics
                in both training and eval modes. Default: ``True``
Attributes:
weight: the weights of the module. The values are initialized to 1.
bias: the bias of the module. The values are initialized to 0.
            weight_scores: learnable scores for the weights. Has the same shape as weight. When applied
                to the default initial values of self.weight (i.e., all ones), this is equivalent to
                randomly dropping out certain features.
            bias_scores: learnable scores for the bias. Has the same shape as bias. When applied to
                the default initial values of self.bias (i.e., all zeros), it does not have any actual
                effect. Thus, bias_scores only influences training when _MaskedBatchNorm is created
                from some pretrained batch normalization module whose bias is not all zeros.
"""
super().__init__(num_features, eps, momentum, affine, track_running_stats, device=device, dtype=dtype)
if self.affine:
assert (self.weight is not None) and (self.bias is not None)
self.weight.requires_grad = False
self.bias.requires_grad = False
self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
else:
self.register_parameter("weight_scores", None)
self.register_parameter("bias_scores", None)
def forward(self, input: Tensor) -> Tensor:
self._check_input_dim(input)
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None: # type: ignore[has-type]
self.num_batches_tracked.add_(1) # type: ignore[has-type]
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
# Decide whether the mini-batch stats should be used for normalization rather than the buffers.
# Mini-batch stats are used in training mode, and in eval mode when buffers are None.
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
# Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
# passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
# used for normalization (i.e. in eval mode when buffers are not None).
if self.affine:
assert (self.weight is not None) and (self.bias is not None)
weight_prob_scores = torch.sigmoid(self.weight_scores)
weight_mask = bernoulli_sample(weight_prob_scores)
masked_weight = weight_mask * self.weight
bias_prob_scores = torch.sigmoid(self.bias_scores)
bias_mask = bernoulli_sample(bias_prob_scores)
masked_bias = bias_mask * self.bias
else:
masked_weight = None
masked_bias = None
return F.batch_norm(
input,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
masked_weight,
masked_bias,
bn_training,
exponential_average_factor,
self.eps,
)
@classmethod
    def from_pretrained(cls, batch_norm_module: _BatchNorm) -> "_MaskedBatchNorm":
        """
        Return an instance of _MaskedBatchNorm whose weight and bias have the same values as those of
        batch_norm_module.
        """
        masked_batch_norm_module = cls(
num_features=batch_norm_module.num_features,
eps=batch_norm_module.eps,
momentum=batch_norm_module.momentum,
affine=batch_norm_module.affine,
track_running_stats=batch_norm_module.track_running_stats,
)
if batch_norm_module.affine:
assert (batch_norm_module.weight is not None) and (batch_norm_module.bias is not None)
masked_batch_norm_module.weight = Parameter(batch_norm_module.weight.clone().detach(), requires_grad=False)
masked_batch_norm_module.weight_scores = Parameter(
torch.randn_like(batch_norm_module.weight), requires_grad=True
)
masked_batch_norm_module.bias = Parameter(batch_norm_module.bias.clone().detach(), requires_grad=False)
            masked_batch_norm_module.bias_scores = Parameter(
                torch.randn_like(batch_norm_module.bias), requires_grad=True
            )
return masked_batch_norm_module
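

# Illustrative sketch of the masking step performed in the forward passes above:
# unbounded scores are squashed through a sigmoid, a binary mask is drawn with
# bernoulli_sample, and the mask is applied elementwise to the frozen parameter.
# The size of 16 is an arbitrary example value.
def _example_masking_step() -> Tensor:
    frozen_weight = torch.ones(16)                 # stands in for a frozen weight
    scores = torch.randn(16, requires_grad=True)   # stands in for a score tensor
    prob_scores = torch.sigmoid(scores)            # probabilities in (0, 1)
    mask = bernoulli_sample(prob_scores)           # binary mask, one entry per element
    return mask * frozen_weight                    # masked parameter passed to the norm op

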
class MaskedBatchNorm1d(_MaskedBatchNorm):
"""
Applies (masked) Batch Normalization over a 2D or 3D input.
Input shape should be (N, C) or (N, C, L), where N is the batch size,
C is the number of features/channels, and L is the sequence length.
"""
def _check_input_dim(self, input: Tensor) -> None:
if input.dim() != 2 and input.dim() != 3:
raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
class MaskedBatchNorm2d(_MaskedBatchNorm):
"""
Applies (masked) Batch Normalization over a 4D input (a mini-batch of 2D inputs
with additional channel dimension).
"""
def _check_input_dim(self, input: Tensor) -> None:
if input.dim() != 4:
raise ValueError(f"expected 4D input (got {input.dim()}D input)")
class MaskedBatchNorm3d(_MaskedBatchNorm):
"""
Applies (masked) Batch Normalization over a 5D input (a mini-batch of 3D inputs
with additional channel dimension).
"""
def _check_input_dim(self, input: Tensor) -> None:
if input.dim() != 5:
raise ValueError(f"expected 5D input (got {input.dim()}D input)")