from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parameter import Parameter
from fl4health.utils.functions import bernoulli_sample
TorchShape = int | list[int] | torch.Size
class MaskedLayerNorm(nn.LayerNorm):
def __init__(
self,
normalized_shape: TorchShape,
eps: float = 1e-5,
elementwise_affine: bool = True,
bias: bool = True,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
"""
Implementation of the masked Layer Normalization module. When ``elementwise_affine`` is True, ``nn.LayerNorm``
has a learnable weight and (optional) bias. For ``MaskedLayerNorm``, the weight and bias do not receive
        gradients during backpropagation. Instead, two score tensors - one for the weight and another for the bias - are
maintained. In the forward pass, the score tensors are transformed by the Sigmoid function into probability
scores, which are then used to produce binary masks via Bernoulli sampling. Finally, the binary masks are
applied to the weight and the bias. During training, gradients with respect to the score tensors are computed
and used to update the score tensors.
When ``elementwise_affine`` is False, ``nn.LayerNorm`` does not have weight or bias. Under this condition, both
score tensors are None and ``MaskedLayerNorm`` acts in the same way as ``nn.LayerNorm``.
**NOTE:** The scores are not assumed to be bounded between 0 and 1.
Args:
normalized_shape (TorchShape): Input shape from an expected input. If a single integer is used, it is
treated as a singleton list, and this module will normalize over the last dimension which is expected
to be of that specific size.
            eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
            elementwise_affine (bool): When set to ``True``, this module has learnable per-element affine
                parameters initialized to ones (for weights) and zeros (for biases). Default: ``True``.
            bias (bool): If set to ``False``, the layer will not learn an additive bias (only relevant if
                ``elementwise_affine`` is ``True``). Default: ``True``.
            device (torch.device | None): Device on which to allocate the parameters. Default: None.
            dtype (torch.dtype | None): Desired data type of the parameters. Default: None.
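
        Example (a minimal usage sketch; the batch and feature sizes below are illustrative)::

            layer = MaskedLayerNorm(normalized_shape=16)
            x = torch.randn(4, 16)
            # The frozen weight/bias are masked by Bernoulli samples of the sigmoid-transformed scores.
            y = layer(x)
            # Only the score tensors receive gradients during training.
            assert layer.weight_scores is not None and layer.weight_scores.requires_grad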
"""
        # Attributes:
        #     weight: the weights of the module. The values are initialized to 1.
        #     bias: the bias of the module. The values are initialized to 0.
        #     weight_scores: learnable scores for the weights. Has the same shape as weight. When applied
        #         to the default initial values of self.weight (i.e., all ones), this is equivalent to
        #         randomly dropping out certain features.
        #     bias_scores: learnable scores for the bias. Has the same shape as bias. When applied to
        #         the default initial values of self.bias (i.e., all zeros), it does not have any actual
        #         effect. Thus, bias_scores only influences training when MaskedLayerNorm is created
        #         from some pretrained nn.LayerNorm module whose bias is not all zeros.
super().__init__(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
bias=bias,
device=device,
dtype=dtype,
)
if self.elementwise_affine:
assert self.weight is not None
self.weight.requires_grad = False
self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
if self.bias is not None:
self.bias.requires_grad = False
self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
else:
self.register_parameter("bias_scores", None)
else:
self.register_parameter("weight_scores", None)
self.register_parameter("bias_scores", None)
def forward(self, input: Tensor) -> Tensor:
"""
        Forward pass for ``MaskedLayerNorm``. When ``elementwise_affine`` is ``True``, the frozen weight and bias are
        multiplied by binary masks drawn from Bernoulli distributions parameterized by the sigmoid-transformed score
        tensors before layer normalization is applied.
        Args:
            input (Tensor): Tensor to be normalized by the layer.
        Returns:
            Tensor: Output tensor after masked layer normalization of the input tensor.
"""
if not self.elementwise_affine:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
else:
assert self.weight is not None
weight_prob_scores = torch.sigmoid(self.weight_scores)
weight_mask = bernoulli_sample(weight_prob_scores)
masked_weight = weight_mask * self.weight
if self.bias is not None:
bias_prob_scores = torch.sigmoid(self.bias_scores)
bias_mask = bernoulli_sample(bias_prob_scores)
masked_bias = bias_mask * self.bias
else:
masked_bias = None
return F.layer_norm(input, self.normalized_shape, masked_weight, masked_bias, self.eps)
@classmethod
def from_pretrained(cls, layer_norm_module: nn.LayerNorm) -> MaskedLayerNorm:
"""
Return an instance of ``MaskedLayerNorm`` whose weight and bias have the same values as those of
``layer_norm_module``.
Args:
layer_norm_module (nn.LayerNorm): Target module to be converted
Returns:
MaskedLayerNorm: New copy of the provided module with mask layers added to enable FedPM
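
        Example (a minimal conversion sketch; the feature size below is illustrative)::

            pretrained = nn.LayerNorm(16)
            masked = MaskedLayerNorm.from_pretrained(pretrained)
            # The frozen weight matches the pretrained module; the score tensors are freshly initialized.
            assert torch.equal(masked.weight, pretrained.weight)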
"""
masked_layer_norm_module = cls(
# layer_norm_module.normalized_shape is a tuple so we
# simply transform it into torch.Size so it is compatible with
# the constructor's type signature.
normalized_shape=torch.Size(layer_norm_module.normalized_shape),
eps=layer_norm_module.eps,
elementwise_affine=layer_norm_module.elementwise_affine,
bias=(layer_norm_module.bias is not None),
)
if layer_norm_module.elementwise_affine:
assert layer_norm_module.weight is not None
masked_layer_norm_module.weight = Parameter(layer_norm_module.weight.clone().detach(), requires_grad=False)
masked_layer_norm_module.weight_scores = Parameter(
torch.randn_like(layer_norm_module.weight), requires_grad=True
)
if layer_norm_module.bias is not None:
masked_layer_norm_module.bias = Parameter(layer_norm_module.bias.clone().detach(), requires_grad=False)
masked_layer_norm_module.bias_scores = Parameter(
torch.randn_like(layer_norm_module.bias), requires_grad=True
)
return masked_layer_norm_module
class _MaskedBatchNorm(_BatchNorm):
def __init__(
self,
num_features: int,
eps: float = 1e-5,
        momentum: float | None = 0.1,
affine: bool = True,
track_running_stats: bool = True,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
"""
        Base class for masked batch normalization modules of various dimensions. When ``affine`` is True,
``_BatchNorm`` has a learnable weight and bias. For ``_MaskedBatchNorm``, the weight and bias do not
        receive gradients during backpropagation. Instead, two score tensors - one for the weight and another for the
bias - are maintained. In the forward pass, the score tensors are transformed by the Sigmoid function
into probability scores, which are then used to produce binary masks via Bernoulli sampling. Finally, the
binary masks are applied to the weight and the bias. During training, gradients with respect to the score
tensors are computed and used to update the score tensors.
        When ``affine`` is False, ``_BatchNorm`` does not have weight or bias. Under this condition, both score
        tensors are None and ``_MaskedBatchNorm`` acts in the same way as ``_BatchNorm``.
**NOTE:** The scores are not assumed to be bounded between 0 and 1.
        Args:
            num_features (int): Number of features or channels :math:`C` of the input.
            eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
            momentum (float | None): The value used for the ``running_mean`` and ``running_var`` computation. Can
                be set to ``None`` for cumulative moving average (i.e. simple average). Default: 0.1.
            affine (bool): When set to ``True``, this module has learnable affine parameters. Default: ``True``.
            track_running_stats (bool): When set to ``True``, this module tracks the running mean and variance.
                When set to ``False``, this module does not track such statistics and initializes the statistics
                buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``,
                this module always uses batch statistics in both training and eval modes. Default: ``True``.
            device (torch.device | None): Device on which to allocate the parameters. Default: None.
            dtype (torch.dtype | None): Desired data type of the parameters. Default: None.
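
        Example (a conceptual sketch of the masking pipeline described above; the feature size is illustrative and
        ``bernoulli_sample`` is the helper imported at the top of this module)::

            scores = torch.randn(8)               # unbounded, learnable scores
            probs = torch.sigmoid(scores)         # probabilities in (0, 1)
            mask = bernoulli_sample(probs)        # binary mask
            masked_weight = mask * torch.ones(8)  # applied to the (frozen) weight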
"""
        # Attributes:
        #     weight: the weights of the module. The values are initialized to 1.
        #     bias: the bias of the module. The values are initialized to 0.
        #     weight_scores: learnable scores for the weights. Has the same shape as weight. When applied
        #         to the default initial values of self.weight (i.e., all ones), this is equivalent to
        #         randomly dropping out certain features.
        #     bias_scores: learnable scores for the bias. Has the same shape as bias. When applied to
        #         the default initial values of self.bias (i.e., all zeros), it does not have any actual
        #         effect. Thus, bias_scores only influences training when _MaskedBatchNorm is created
        #         from some pretrained _BatchNorm module whose bias is not all zeros.
super().__init__(num_features, eps, momentum, affine, track_running_stats, device=device, dtype=dtype)
if self.affine:
assert (self.weight is not None) and (self.bias is not None)
self.weight.requires_grad = False
self.bias.requires_grad = False
self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
else:
self.register_parameter("weight_scores", None)
self.register_parameter("bias_scores", None)
def forward(self, input: Tensor) -> Tensor:
"""
        Forward pass for the ``_MaskedBatchNorm`` module. When ``affine`` is ``True``, the frozen weight and bias are
        multiplied by binary masks drawn from Bernoulli distributions parameterized by the sigmoid-transformed score
        tensors before batch normalization is applied.
        Args:
            input (Tensor): Tensor to be normalized by the module.
        Returns:
            Tensor: Output tensor after masked batch normalization of the input tensor.
"""
self._check_input_dim(input)
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None: # type: ignore[has-type]
self.num_batches_tracked.add_(1) # type: ignore[has-type]
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
# Decide whether the mini-batch stats should be used for normalization rather than the buffers.
# Mini-batch stats are used in training mode, and in eval mode when buffers are None.
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
# Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
# passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
# used for normalization (i.e. in eval mode when buffers are not None).
if self.affine:
assert (self.weight is not None) and (self.bias is not None)
weight_prob_scores = torch.sigmoid(self.weight_scores)
weight_mask = bernoulli_sample(weight_prob_scores)
masked_weight = weight_mask * self.weight
bias_prob_scores = torch.sigmoid(self.bias_scores)
bias_mask = bernoulli_sample(bias_prob_scores)
masked_bias = bias_mask * self.bias
else:
masked_weight = None
masked_bias = None
return F.batch_norm(
input,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
masked_weight,
masked_bias,
bn_training,
exponential_average_factor,
self.eps,
)
@classmethod
def from_pretrained(cls, batch_norm_module: _BatchNorm) -> _MaskedBatchNorm:
"""
Mapping a ``_BatchNorm`` module to a ``_MaskedBatchNorm`` by injecting masked layers
Args:
batch_norm_module (_BatchNorm): Module to be transformed to a masked module through layer insertion
Returns:
_MaskedBatchNorm: New copy of the input module with masked layers to enable FedPM
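
        Example (a minimal conversion sketch using the 2D variant defined below; the channel count is illustrative)::

            pretrained = nn.BatchNorm2d(8)
            masked = MaskedBatchNorm2d.from_pretrained(pretrained)
            # The frozen weight/bias match the pretrained module; score tensors and running statistics
            # are freshly initialized rather than copied.
            assert torch.equal(masked.weight, pretrained.weight)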
"""
masked_batch_norm_module = cls(
num_features=batch_norm_module.num_features,
eps=batch_norm_module.eps,
momentum=batch_norm_module.momentum,
affine=batch_norm_module.affine,
track_running_stats=batch_norm_module.track_running_stats,
)
if batch_norm_module.affine:
assert (batch_norm_module.weight is not None) and (batch_norm_module.bias is not None)
masked_batch_norm_module.weight = Parameter(batch_norm_module.weight.clone().detach(), requires_grad=False)
masked_batch_norm_module.weight_scores = Parameter(
torch.randn_like(batch_norm_module.weight), requires_grad=True
)
masked_batch_norm_module.bias = Parameter(batch_norm_module.bias.clone().detach(), requires_grad=False)
masked_batch_norm_module.bias_scores = Parameter(
                torch.randn_like(batch_norm_module.bias), requires_grad=True
)
return masked_batch_norm_module
class MaskedBatchNorm1d(_MaskedBatchNorm):
"""
Applies (masked) Batch Normalization over a 2D or 3D input. Input shape should be ``(N, C)`` or ``(N, C, L)``,
where ``N`` is the batch size, ``C`` is the number of features/channels, and ``L`` is the sequence length.
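
    Example (a minimal sketch of the two supported input shapes; sizes are illustrative)::

        norm = MaskedBatchNorm1d(num_features=8)
        y_2d = norm(torch.randn(4, 8))      # (N, C) input
        y_3d = norm(torch.randn(4, 8, 10))  # (N, C, L) input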
"""
def _check_input_dim(self, input: Tensor) -> None:
if input.dim() != 2 and input.dim() != 3:
raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
class MaskedBatchNorm2d(_MaskedBatchNorm):
"""
    Applies (masked) Batch Normalization over a 4D input (a mini-batch of 2D inputs with an additional channel
    dimension). Input shape should be ``(N, C, H, W)``.
"""
def _check_input_dim(self, input: Tensor) -> None:
if input.dim() != 4:
raise ValueError(f"expected 4D input (got {input.dim()}D input)")
class MaskedBatchNorm3d(_MaskedBatchNorm):
"""
    Applies (masked) Batch Normalization over a 5D input (a mini-batch of 3D inputs with an additional channel
    dimension). Input shape should be ``(N, C, D, H, W)``.
"""
def _check_input_dim(self, input: Tensor) -> None:
if input.dim() != 5:
raise ValueError(f"expected 5D input (got {input.dim()}D input)")