# Source code for fl4health.model_bases.masked_layers.masked_conv

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _pair, _single, _triple
from torch.nn.parameter import Parameter

from fl4health.utils.functions import bernoulli_sample


class MaskedConv1d(nn.Conv1d):
    """1D convolution whose frozen weight and bias are gated by sampled binary masks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: str | _size_1_t = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Implementation of masked Conv1d layers.

        Like regular Conv1d layers (i.e., ``nn.Conv1d`` module), a masked convolutional layer has a weight (i.e.,
        convolutional filter) and a (optional) bias. However, the weight and the bias do not receive gradient in
        back propagation. Instead, two score tensors - one for the weight and another for the bias - are
        maintained. In the forward pass, the score tensors are transformed by the Sigmoid function into
        probability scores, which are then used to produce binary masks via Bernoulli sampling. Finally, the
        binary masks are applied to the weight and the bias. During training, gradients with respect to the score
        tensors are computed and used to update the score tensors.

        **NOTE:** The scores are not assumed to be bounded between 0 and 1.

        Args:
            in_channels (int): Number of channels in the input image
            out_channels (int): Number of channels produced by the convolution
            kernel_size (int or tuple): Size of the convolving kernel
            stride (int or tuple, optional): Stride of the convolution. Default: 1
            padding (int, tuple or str, optional): Padding added to both sides of the input. Default: 0
            dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
            groups (int, optional): Number of blocked connections from input channels to output channels.
                Default: 1
            bias (bool, optional): If ``True``, adds a bias to the output. Default: ``True``
            padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
                Default: ``'zeros'``
            device (torch.device | None, optional): Forwarded unchanged to ``nn.Conv1d``. Default: None
            dtype (torch.dtype | None, optional): Forwarded unchanged to ``nn.Conv1d``. Default: None
        """
        # Attributes:
        #   weight: weights of the module (frozen, requires_grad=False).
        #   bias: bias of the module (frozen, requires_grad=False), or None when ``bias=False``.
        #   weight_scores: learnable scores for the weights. Has the same shape as weight.
        #   bias_scores: learnable scores for the bias. Has the same shape as bias; None when there is no bias.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        # Only the scores are trained; the underlying filter values stay fixed.
        self.weight.requires_grad = False
        self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
        if bias:
            assert self.bias is not None
            self.bias.requires_grad = False
            self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
        else:
            # Register the name so that ``self.bias_scores`` exists (as None) even without a bias.
            self.register_parameter("bias_scores", None)

    def forward(self, input: Tensor) -> Tensor:
        """
        Forward for the masked 1D convolution.

        Args:
            input (Tensor): input tensor for the layer

        Returns:
            Tensor: output tensor for the convolution
        """
        # Scores -> probabilities -> sampled mask -> masked (frozen) weight.
        weight_prob_scores = torch.sigmoid(self.weight_scores)
        weight_mask = bernoulli_sample(weight_prob_scores)
        masked_weight = weight_mask * self.weight
        if self.bias is not None:
            bias_prob_scores = torch.sigmoid(self.bias_scores)
            bias_mask = bernoulli_sample(bias_prob_scores)
            masked_bias = bias_mask * self.bias
        else:
            masked_bias = None
        return self._conv_forward(input, weight=masked_weight, bias=masked_bias)

    @classmethod
    def from_pretrained(cls, conv_module: nn.Conv1d) -> MaskedConv1d:
        """
        Return an instance of ``MaskedConv1d`` whose weight and bias have the same values as those of
        ``conv_module``.

        The weight and bias are detached copies of the originals with ``requires_grad=False``; the score tensors
        are freshly drawn with ``torch.randn_like``.

        Args:
            conv_module (nn.Conv1d): Module to be converted

        Returns:
            MaskedConv1d: Module with masked layers added to enable FedPM training
        """
        has_bias = conv_module.bias is not None
        # we create new variables below to make mypy happy since kernel_size has
        # type int | tuple[int] and kernel_size_ has type tuple[int]
        kernel_size_ = _single(conv_module.kernel_size)
        stride_ = _single(conv_module.stride)
        # String paddings (e.g. "same"/"valid") are passed through untouched.
        padding_ = conv_module.padding if isinstance(conv_module.padding, str) else _single(conv_module.padding)
        dilation_ = _single(conv_module.dilation)
        masked_conv_module = cls(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            groups=conv_module.groups,
            bias=has_bias,
            padding_mode=conv_module.padding_mode,
        )
        masked_conv_module.weight = Parameter(conv_module.weight.clone().detach(), requires_grad=False)
        # Re-create the scores from the copied weight so they match its dtype and device.
        masked_conv_module.weight_scores = Parameter(torch.randn_like(masked_conv_module.weight), requires_grad=True)
        if has_bias:
            assert conv_module.bias is not None
            masked_conv_module.bias = Parameter(conv_module.bias.clone().detach(), requires_grad=False)
            masked_conv_module.bias_scores = Parameter(torch.randn_like(masked_conv_module.bias), requires_grad=True)
        return masked_conv_module
class MaskedConv2d(nn.Conv2d):
    """2D convolution whose frozen weight and bias are gated by sampled binary masks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: str | _size_2_t = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Implementation of masked Conv2d layers.

        Like regular Conv2d layers (i.e., ``nn.Conv2d`` module), a masked convolutional layer has a weight (i.e.,
        convolutional filter) and a (optional) bias. However, the weight and the bias do not receive gradient in
        back propagation. Instead, two score tensors - one for the weight and another for the bias - are
        maintained. In the forward pass, the score tensors are transformed by the Sigmoid function into
        probability scores, which are then used to produce binary masks via Bernoulli sampling. Finally, the
        binary masks are applied to the weight and the bias. During training, gradients with respect to the score
        tensors are computed and used to update the score tensors.

        **NOTE:** The scores are not assumed to be bounded between 0 and 1.

        Args:
            in_channels (int): Number of channels in the input image
            out_channels (int): Number of channels produced by the convolution
            kernel_size (int or tuple): Size of the convolving kernel
            stride (int or tuple, optional): Stride of the convolution. Default: 1
            padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: 0
            dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
            groups (int, optional): Number of blocked connections from input channels to output channels.
                Default: 1
            bias (bool, optional): If ``True``, adds a bias to the output. Default: ``True``
            padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
                Default: ``'zeros'``
            device (torch.device | None, optional): Forwarded unchanged to ``nn.Conv2d``. Default: None
            dtype (torch.dtype | None, optional): Forwarded unchanged to ``nn.Conv2d``. Default: None
        """
        # Attributes:
        #   weight: weights of the module (frozen, requires_grad=False).
        #   bias: bias of the module (frozen, requires_grad=False), or None when ``bias=False``.
        #   weight_scores: learnable scores for the weights. Has the same shape as weight.
        #   bias_scores: learnable scores for the bias. Has the same shape as bias; None when there is no bias.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        # Only the scores are trained; the underlying filter values stay fixed.
        self.weight.requires_grad = False
        self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
        if bias:
            assert self.bias is not None
            self.bias.requires_grad = False
            self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
        else:
            # Register the name so that ``self.bias_scores`` exists (as None) even without a bias.
            self.register_parameter("bias_scores", None)

    def forward(self, input: Tensor) -> Tensor:
        """
        Forward for the masked 2D convolution.

        Args:
            input (Tensor): input tensor for the layer

        Returns:
            Tensor: output tensor for the convolution
        """
        # Scores -> probabilities -> sampled mask -> masked (frozen) weight.
        weight_prob_scores = torch.sigmoid(self.weight_scores)
        weight_mask = bernoulli_sample(weight_prob_scores)
        masked_weight = weight_mask * self.weight
        if self.bias is not None:
            bias_prob_scores = torch.sigmoid(self.bias_scores)
            bias_mask = bernoulli_sample(bias_prob_scores)
            masked_bias = bias_mask * self.bias
        else:
            masked_bias = None
        return self._conv_forward(input, weight=masked_weight, bias=masked_bias)

    @classmethod
    def from_pretrained(cls, conv_module: nn.Conv2d) -> MaskedConv2d:
        """
        Return an instance of ``MaskedConv2d`` whose weight and bias have the same values as those of
        ``conv_module``.

        The weight and bias are detached copies of the originals with ``requires_grad=False``; the score tensors
        are freshly drawn with ``torch.randn_like``.

        Args:
            conv_module (nn.Conv2d): Module to be converted

        Returns:
            MaskedConv2d: Module with masked layers to enable FedPM
        """
        has_bias = conv_module.bias is not None
        # New variables keep mypy happy: the module attributes are typed int | tuple, the locals are tuples.
        kernel_size_ = _pair(conv_module.kernel_size)
        stride_ = _pair(conv_module.stride)
        # String paddings (e.g. "same"/"valid") are passed through untouched.
        padding_ = conv_module.padding if isinstance(conv_module.padding, str) else _pair(conv_module.padding)
        dilation_ = _pair(conv_module.dilation)
        masked_conv_module = cls(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            groups=conv_module.groups,
            bias=has_bias,
            padding_mode=conv_module.padding_mode,
        )
        masked_conv_module.weight = Parameter(conv_module.weight.clone().detach(), requires_grad=False)
        # Re-create the scores from the copied weight so they match its dtype and device.
        masked_conv_module.weight_scores = Parameter(torch.randn_like(masked_conv_module.weight), requires_grad=True)
        if has_bias:
            assert conv_module.bias is not None
            masked_conv_module.bias = Parameter(conv_module.bias.clone().detach(), requires_grad=False)
            masked_conv_module.bias_scores = Parameter(torch.randn_like(masked_conv_module.bias), requires_grad=True)
        return masked_conv_module
class MaskedConv3d(nn.Conv3d):
    """3D convolution whose frozen weight and bias are gated by sampled binary masks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: str | _size_3_t = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Implementation of masked Conv3d layers.

        Like regular Conv3d layers (i.e., ``nn.Conv3d`` module), a masked convolutional layer has a weight (i.e.,
        convolutional filter) and a (optional) bias. However, the weight and the bias do not receive gradient in
        back propagation. Instead, two score tensors - one for the weight and another for the bias - are
        maintained. In the forward pass, the score tensors are transformed by the Sigmoid function into
        probability scores, which are then used to produce binary masks via Bernoulli sampling. Finally, the
        binary masks are applied to the weight and the bias. During training, gradients with respect to the score
        tensors are computed and used to update the score tensors.

        **NOTE:** The scores are not assumed to be bounded between 0 and 1.

        Args:
            in_channels (int): Number of channels in the input image
            out_channels (int): Number of channels produced by the convolution
            kernel_size (int or tuple): Size of the convolving kernel
            stride (int or tuple, optional): Stride of the convolution. Default: 1
            padding (int, tuple or str, optional): Padding added to all six sides of the input. Default: 0
            dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
            groups (int, optional): Number of blocked connections from input channels to output channels.
                Default: 1
            bias (bool, optional): If ``True``, adds a bias to the output. Default: ``True``
            padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
                Default: ``'zeros'``
            device (torch.device | None, optional): Forwarded unchanged to ``nn.Conv3d``. Default: None
            dtype (torch.dtype | None, optional): Forwarded unchanged to ``nn.Conv3d``. Default: None
        """
        # Attributes:
        #   weight: weights of the module (frozen, requires_grad=False).
        #   bias: bias of the module (frozen, requires_grad=False), or None when ``bias=False``.
        #   weight_scores: learnable scores for the weights. Has the same shape as weight.
        #   bias_scores: learnable scores for the bias. Has the same shape as bias; None when there is no bias.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )
        # Only the scores are trained; the underlying filter values stay fixed.
        self.weight.requires_grad = False
        self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
        if bias:
            assert self.bias is not None
            self.bias.requires_grad = False
            self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
        else:
            # Register the name so that ``self.bias_scores`` exists (as None) even without a bias.
            self.register_parameter("bias_scores", None)

    def forward(self, input: Tensor) -> Tensor:
        """
        Forward for the masked 3D convolution.

        Args:
            input (Tensor): input tensor for the layer

        Returns:
            Tensor: output tensor for the convolution
        """
        # Scores -> probabilities -> sampled mask -> masked (frozen) weight.
        weight_prob_scores = torch.sigmoid(self.weight_scores)
        weight_mask = bernoulli_sample(weight_prob_scores)
        masked_weight = weight_mask * self.weight
        if self.bias is not None:
            bias_prob_scores = torch.sigmoid(self.bias_scores)
            bias_mask = bernoulli_sample(bias_prob_scores)
            masked_bias = bias_mask * self.bias
        else:
            masked_bias = None
        return self._conv_forward(input, weight=masked_weight, bias=masked_bias)

    @classmethod
    def from_pretrained(cls, conv_module: nn.Conv3d) -> MaskedConv3d:
        """
        Return an instance of ``MaskedConv3d`` whose weight and bias have the same values as those of
        ``conv_module``.

        The weight and bias are detached copies of the originals with ``requires_grad=False``; the score tensors
        are freshly drawn with ``torch.randn_like``.

        Args:
            conv_module (nn.Conv3d): Module to convert

        Returns:
            MaskedConv3d: Module with mask layers added to enable FedPM
        """
        has_bias = conv_module.bias is not None
        # New variables keep mypy happy: the module attributes are typed int | tuple, the locals are tuples.
        kernel_size_ = _triple(conv_module.kernel_size)
        stride_ = _triple(conv_module.stride)
        # String paddings (e.g. "same"/"valid") are passed through untouched.
        padding_ = conv_module.padding if isinstance(conv_module.padding, str) else _triple(conv_module.padding)
        dilation_ = _triple(conv_module.dilation)
        masked_conv_module = cls(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            groups=conv_module.groups,
            bias=has_bias,
            padding_mode=conv_module.padding_mode,
        )
        masked_conv_module.weight = Parameter(conv_module.weight.clone().detach(), requires_grad=False)
        # Re-create the scores from the copied weight so they match its dtype and device.
        masked_conv_module.weight_scores = Parameter(torch.randn_like(masked_conv_module.weight), requires_grad=True)
        if has_bias:
            assert conv_module.bias is not None
            masked_conv_module.bias = Parameter(conv_module.bias.clone().detach(), requires_grad=False)
            masked_conv_module.bias_scores = Parameter(torch.randn_like(masked_conv_module.bias), requires_grad=True)
        return masked_conv_module
class MaskedConvTranspose1d(nn.ConvTranspose1d):
    """1D transposed convolution whose frozen weight and bias are gated by sampled binary masks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        output_padding: _size_1_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_1_t = 1,
        padding_mode: str = "zeros",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Implementation of masked ``ConvTranspose1d`` layers.

        For more information on transposed convolution, please see the PyTorch implementation of
        ``nn.ConvTranspose1d``.
        (https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#ConvTranspose1d)

        Like regular ``ConvTranspose1d`` layers (i.e., ``nn.ConvTranspose1d`` module), a masked transpose
        convolutional layer has a weight (i.e., convolutional filter) and a (optional) bias. However, the weight
        and the bias do not receive gradient in back propagation. Instead, two score tensors - one for the weight
        and another for the bias - are maintained. In the forward pass, the score tensors are transformed by the
        Sigmoid function into probability scores, which are then used to produce binary masks via Bernoulli
        sampling. Finally, the binary masks are applied to the weight and the bias. During training, gradients
        with respect to the score tensors are computed and used to update the score tensors.

        **NOTE:** The scores are not assumed to be bounded between 0 and 1.

        Args:
            in_channels (int): Number of channels in the input image
            out_channels (int): Number of channels produced by the transposed convolution
            kernel_size (int or tuple): Size of the convolving kernel
            stride (int or tuple, optional): Stride of the convolution. Default: 1
            padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding will be
                added to both sides of the input. Default: 0
            output_padding (int or tuple, optional): Additional size added to one side of the output shape.
                Default: 0
            groups (int, optional): Number of blocked connections from input channels to output channels.
                Default: 1
            bias (bool, optional): If ``True``, adds a bias to the output. Default: ``True``
            dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
            padding_mode (str, optional): Forwarded to ``nn.ConvTranspose1d``; ``forward`` only accepts
                ``'zeros'``. Default: ``'zeros'``
            device (torch.device | None, optional): Forwarded unchanged to ``nn.ConvTranspose1d``. Default: None
            dtype (torch.dtype | None, optional): Forwarded unchanged to ``nn.ConvTranspose1d``. Default: None
        """
        # Attributes:
        #   weight (Tensor): weights of the module (frozen, requires_grad=False).
        #   bias (Tensor): bias of the module (frozen, requires_grad=False), or None when ``bias=False``.
        #   weight_scores: learnable scores for the weights. Has the same shape as weight.
        #   bias_scores: learnable scores for the bias. Has the same shape as bias; None when there is no bias.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device,
            dtype,
        )
        # Only the scores are trained; the underlying filter values stay fixed.
        self.weight.requires_grad = False
        self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
        if bias:
            assert self.bias is not None
            self.bias.requires_grad = False
            self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
        else:
            # Register the name so that ``self.bias_scores`` exists (as None) even without a bias.
            self.register_parameter("bias_scores", None)

    def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor:
        """
        Forward for the ``MaskedConvTranspose1d``.

        Args:
            input (Tensor): input to be mapped with the module
            output_size (list[int] | None, optional): Desired output size from the transpose. Defaults to None.

        Raises:
            ValueError: If something other than "zeros" padding has been requested.

        Returns:
            Tensor: Output tensor.
        """
        # Note: the same check is already present in super().__init__
        if self.padding_mode != "zeros":
            raise ValueError("Only `zeros` padding mode is supported for ConvTranspose1d")

        assert isinstance(self.padding, tuple)
        # (The type ignore below is just used to resolve some small typing issue.)
        # One cannot replace List by Tuple or Sequence in "_output_padding"
        # because TorchScript does not support `Sequence[T]` or `tuple[T, ...]`.
        output_padding = self._output_padding(
            input,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims=1,
            dilation=self.dilation,  # type: ignore[arg-type]
        )
        # Scores -> probabilities -> sampled mask -> masked (frozen) weight.
        weight_prob_scores = torch.sigmoid(self.weight_scores)
        weight_mask = bernoulli_sample(weight_prob_scores)
        masked_weight = weight_mask * self.weight
        if self.bias is not None:
            bias_prob_scores = torch.sigmoid(self.bias_scores)
            bias_mask = bernoulli_sample(bias_prob_scores)
            masked_bias = bias_mask * self.bias
        else:
            masked_bias = None
        return F.conv_transpose1d(
            input, masked_weight, masked_bias, self.stride, self.padding, output_padding, self.groups, self.dilation
        )

    @classmethod
    def from_pretrained(cls, conv_module: nn.ConvTranspose1d) -> MaskedConvTranspose1d:
        """
        Return an instance of ``MaskedConvTranspose1d`` whose weight and bias have the same values as those of
        ``conv_module``.

        The weight and bias are detached copies of the originals with ``requires_grad=False``; the score tensors
        are freshly drawn with ``torch.randn_like``.

        Args:
            conv_module (nn.ConvTranspose1d): Target module to be converted

        Returns:
            MaskedConvTranspose1d: Module with masked layers to enable FedPM
        """
        has_bias = conv_module.bias is not None
        # we create new variables below to make mypy happy since kernel_size has
        # type int | tuple[int] and kernel_size_ has type tuple[int]
        kernel_size_ = _single(conv_module.kernel_size)
        stride_ = _single(conv_module.stride)
        padding_ = _single(conv_module.padding)
        dilation_ = _single(conv_module.dilation)
        output_padding_ = _single(conv_module.output_padding)
        masked_conv_module = cls(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=kernel_size_,
            stride=stride_,
            padding=padding_,
            output_padding=output_padding_,
            groups=conv_module.groups,
            bias=has_bias,
            dilation=dilation_,
            padding_mode=conv_module.padding_mode,
        )
        masked_conv_module.weight = Parameter(conv_module.weight.clone().detach(), requires_grad=False)
        # Re-create the scores from the copied weight so they match its dtype and device.
        masked_conv_module.weight_scores = Parameter(torch.randn_like(masked_conv_module.weight), requires_grad=True)
        if has_bias:
            assert conv_module.bias is not None
            masked_conv_module.bias = Parameter(conv_module.bias.clone().detach(), requires_grad=False)
            masked_conv_module.bias_scores = Parameter(torch.randn_like(masked_conv_module.bias), requires_grad=True)
        return masked_conv_module
class MaskedConvTranspose2d(nn.ConvTranspose2d):
    """2D transposed convolution whose frozen weight and bias are gated by sampled binary masks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        output_padding: _size_2_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_2_t = 1,
        padding_mode: str = "zeros",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Implementation of masked ``ConvTranspose2d`` layers.

        For more information on transposed convolution, please see the PyTorch implementation of
        ``nn.ConvTranspose2d``.
        (https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#ConvTranspose2d)

        Like regular ``ConvTranspose2d`` layers (i.e., ``nn.ConvTranspose2d`` module), a masked transpose
        convolutional layer has a weight (i.e., convolutional filter) and a (optional) bias. However, the weight
        and the bias do not receive gradient in back propagation. Instead, two score tensors - one for the weight
        and another for the bias - are maintained. In the forward pass, the score tensors are transformed by the
        Sigmoid function into probability scores, which are then used to produce binary masks via Bernoulli
        sampling. Finally, the binary masks are applied to the weight and the bias. During training, gradients
        with respect to the score tensors are computed and used to update the score tensors.

        **NOTE:** The scores are not assumed to be bounded between 0 and 1.

        Args:
            in_channels (int): Number of channels in the input image
            out_channels (int): Number of channels produced by the transposed convolution
            kernel_size (int or tuple): Size of the convolving kernel
            stride (int or tuple, optional): Stride of the convolution. Default: 1
            padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding will be
                added to both sides of each dimension in the input. Default: 0
            output_padding (int or tuple, optional): Additional size added to one side of each dimension in the
                output shape. Default: 0
            groups (int, optional): Number of blocked connections from input channels to output channels.
                Default: 1
            bias (bool, optional): If ``True``, adds a bias to the output. Default: ``True``
            dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
            padding_mode (str, optional): Forwarded to ``nn.ConvTranspose2d``; ``forward`` only accepts
                ``'zeros'``. Default: ``'zeros'``
            device (torch.device | None, optional): Forwarded unchanged to ``nn.ConvTranspose2d``. Default: None
            dtype (torch.dtype | None, optional): Forwarded unchanged to ``nn.ConvTranspose2d``. Default: None
        """
        # Attributes:
        #   weight (Tensor): weights of the module (frozen, requires_grad=False).
        #   bias (Tensor): bias of the module (frozen, requires_grad=False), or None when ``bias=False``.
        #   weight_scores: learnable scores for the weights. Has the same shape as weight.
        #   bias_scores: learnable scores for the bias. Has the same shape as bias; None when there is no bias.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device,
            dtype,
        )
        # Only the scores are trained; the underlying filter values stay fixed.
        self.weight.requires_grad = False
        self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
        if bias:
            assert self.bias is not None
            self.bias.requires_grad = False
            self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
        else:
            # Register the name so that ``self.bias_scores`` exists (as None) even without a bias.
            self.register_parameter("bias_scores", None)

    def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor:
        """
        Maps input tensor through the ``MaskedConvTranspose2d`` module.

        Args:
            input (Tensor): tensor to be mapped
            output_size (list[int] | None, optional): Desired output size from the module. Defaults to None.

        Raises:
            ValueError: Thrown if anything except "zeros" padding is requested

        Returns:
            Tensor: Mapped tensor
        """
        # Note: the same check is already present in super().__init__
        if self.padding_mode != "zeros":
            # The message names this layer's own base class (it previously said ConvTranspose1d).
            raise ValueError("Only `zeros` padding mode is supported for ConvTranspose2d")

        assert isinstance(self.padding, tuple)
        # The type ignores resolve a small typing mismatch with ``_output_padding``'s list-typed parameters.
        output_padding = self._output_padding(
            input,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims=2,
            dilation=self.dilation,  # type: ignore[arg-type]
        )
        # Scores -> probabilities -> sampled mask -> masked (frozen) weight.
        weight_prob_scores = torch.sigmoid(self.weight_scores)
        weight_mask = bernoulli_sample(weight_prob_scores)
        masked_weight = weight_mask * self.weight
        if self.bias is not None:
            bias_prob_scores = torch.sigmoid(self.bias_scores)
            bias_mask = bernoulli_sample(bias_prob_scores)
            masked_bias = bias_mask * self.bias
        else:
            masked_bias = None
        return F.conv_transpose2d(
            input, masked_weight, masked_bias, self.stride, self.padding, output_padding, self.groups, self.dilation
        )

    @classmethod
    def from_pretrained(cls, conv_module: nn.ConvTranspose2d) -> MaskedConvTranspose2d:
        """
        Return an instance of ``MaskedConvTranspose2d`` whose weight and bias have the same values as those of
        ``conv_module``.

        The weight and bias are detached copies of the originals with ``requires_grad=False``; the score tensors
        are freshly drawn with ``torch.randn_like``.

        Args:
            conv_module (nn.ConvTranspose2d): Target module to be converted

        Returns:
            MaskedConvTranspose2d: Module with mask layers added to enable FedPM
        """
        has_bias = conv_module.bias is not None
        # we create new variables below to make mypy happy since kernel_size has
        # type int | tuple[int] and kernel_size_ has type tuple[int]
        kernel_size_ = _pair(conv_module.kernel_size)
        stride_ = _pair(conv_module.stride)
        padding_ = _pair(conv_module.padding)
        dilation_ = _pair(conv_module.dilation)
        output_padding_ = _pair(conv_module.output_padding)
        masked_conv_module = cls(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=kernel_size_,
            stride=stride_,
            padding=padding_,
            output_padding=output_padding_,
            groups=conv_module.groups,
            bias=has_bias,
            dilation=dilation_,
            padding_mode=conv_module.padding_mode,
        )
        masked_conv_module.weight = Parameter(conv_module.weight.clone().detach(), requires_grad=False)
        # Re-create the scores from the copied weight so they match its dtype and device.
        masked_conv_module.weight_scores = Parameter(torch.randn_like(masked_conv_module.weight), requires_grad=True)
        if has_bias:
            assert conv_module.bias is not None
            masked_conv_module.bias = Parameter(conv_module.bias.clone().detach(), requires_grad=False)
            masked_conv_module.bias_scores = Parameter(torch.randn_like(masked_conv_module.bias), requires_grad=True)
        return masked_conv_module
class MaskedConvTranspose3d(nn.ConvTranspose3d):
    """3D transposed convolution whose frozen weight and bias are gated by sampled binary masks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        output_padding: _size_3_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_3_t = 1,
        padding_mode: str = "zeros",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Implementation of masked ``ConvTranspose3d`` layers.

        For more information on transposed convolution, please see the PyTorch implementation of
        ``nn.ConvTranspose3d``.
        (https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#ConvTranspose3d)

        Like regular ``ConvTranspose3d`` layers (i.e., ``nn.ConvTranspose3d`` module), a masked transpose
        convolutional layer has a weight (i.e., convolutional filter) and a (optional) bias. However, the weight
        and the bias do not receive gradient in back propagation. Instead, two score tensors - one for the weight
        and another for the bias - are maintained. In the forward pass, the score tensors are transformed by the
        Sigmoid function into probability scores, which are then used to produce binary masks via Bernoulli
        sampling. Finally, the binary masks are applied to the weight and the bias. During training, gradients
        with respect to the score tensors are computed and used to update the score tensors.

        **NOTE:** The scores are not assumed to be bounded between 0 and 1.

        Args:
            in_channels (int): Number of channels in the input image
            out_channels (int): Number of channels produced by the transposed convolution
            kernel_size (int or tuple): Size of the convolving kernel
            stride (int or tuple, optional): Stride of the convolution. Default: 1
            padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding will be
                added to both sides of each dimension in the input. Default: 0
            output_padding (int or tuple, optional): Additional size added to one side of each dimension in the
                output shape. Default: 0
            groups (int, optional): Number of blocked connections from input channels to output channels.
                Default: 1
            bias (bool, optional): If ``True``, adds a bias to the output. Default: ``True``
            dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
            padding_mode (str, optional): Forwarded to ``nn.ConvTranspose3d``; ``forward`` only accepts
                ``'zeros'``. Default: ``'zeros'``
            device (torch.device | None, optional): Forwarded unchanged to ``nn.ConvTranspose3d``. Default: None
            dtype (torch.dtype | None, optional): Forwarded unchanged to ``nn.ConvTranspose3d``. Default: None
        """
        # Attributes:
        #   weight (Tensor): weights of the module (frozen, requires_grad=False).
        #   bias (Tensor): bias of the module (frozen, requires_grad=False), or None when ``bias=False``.
        #   weight_scores: learnable scores for the weights. Has the same shape as weight.
        #   bias_scores: learnable scores for the bias. Has the same shape as bias; None when there is no bias.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            bias,
            dilation,
            padding_mode,
            device,
            dtype,
        )
        # Only the scores are trained; the underlying filter values stay fixed.
        self.weight.requires_grad = False
        self.weight_scores = Parameter(torch.randn_like(self.weight), requires_grad=True)
        if bias:
            assert self.bias is not None
            self.bias.requires_grad = False
            self.bias_scores = Parameter(torch.randn_like(self.bias), requires_grad=True)
        else:
            # Register the name so that ``self.bias_scores`` exists (as None) even without a bias.
            self.register_parameter("bias_scores", None)

    def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor:
        """
        Maps the input tensor with ``MaskedConvTranspose3d``.

        Args:
            input (Tensor): Tensor to be mapped
            output_size (list[int] | None, optional): Desired output size from the transpose. Defaults to None.

        Raises:
            ValueError: Throws if anything except "zeros" padding is requested

        Returns:
            Tensor: Mapped tensor
        """
        # Note: the same check is already present in super().__init__
        if self.padding_mode != "zeros":
            # The message names this layer's own base class (it previously said ConvTranspose1d).
            raise ValueError("Only `zeros` padding mode is supported for ConvTranspose3d")

        assert isinstance(self.padding, tuple)
        # The type ignores resolve a small typing mismatch with ``_output_padding``'s list-typed parameters.
        output_padding = self._output_padding(
            input,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims=3,
            dilation=self.dilation,  # type: ignore[arg-type]
        )
        # Scores -> probabilities -> sampled mask -> masked (frozen) weight.
        weight_prob_scores = torch.sigmoid(self.weight_scores)
        weight_mask = bernoulli_sample(weight_prob_scores)
        masked_weight = weight_mask * self.weight
        if self.bias is not None:
            bias_prob_scores = torch.sigmoid(self.bias_scores)
            bias_mask = bernoulli_sample(bias_prob_scores)
            masked_bias = bias_mask * self.bias
        else:
            masked_bias = None
        return F.conv_transpose3d(
            input, masked_weight, masked_bias, self.stride, self.padding, output_padding, self.groups, self.dilation
        )

    @classmethod
    def from_pretrained(cls, conv_module: nn.ConvTranspose3d) -> MaskedConvTranspose3d:
        """
        Return an instance of ``MaskedConvTranspose3d`` whose weight and bias have the same values as those of
        ``conv_module``.

        The weight and bias are detached copies of the originals with ``requires_grad=False``; the score tensors
        are freshly drawn with ``torch.randn_like``.

        Args:
            conv_module (nn.ConvTranspose3d): Target module to be converted

        Returns:
            MaskedConvTranspose3d: Module with masked layers added to enable FedPM
        """
        has_bias = conv_module.bias is not None
        # we create new variables below to make mypy happy since kernel_size has
        # type int | tuple[int] and kernel_size_ has type tuple[int]
        kernel_size_ = _triple(conv_module.kernel_size)
        stride_ = _triple(conv_module.stride)
        padding_ = _triple(conv_module.padding)
        dilation_ = _triple(conv_module.dilation)
        output_padding_ = _triple(conv_module.output_padding)
        masked_conv_module = cls(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=kernel_size_,
            stride=stride_,
            padding=padding_,
            output_padding=output_padding_,
            groups=conv_module.groups,
            bias=has_bias,
            dilation=dilation_,
            padding_mode=conv_module.padding_mode,
        )
        masked_conv_module.weight = Parameter(conv_module.weight.clone().detach(), requires_grad=False)
        # Re-create the scores from the copied weight so they match its dtype and device.
        masked_conv_module.weight_scores = Parameter(torch.randn_like(masked_conv_module.weight), requires_grad=True)
        if has_bias:
            assert conv_module.bias is not None
            masked_conv_module.bias = Parameter(conv_module.bias.clone().detach(), requires_grad=False)
            masked_conv_module.bias_scores = Parameter(torch.randn_like(masked_conv_module.bias), requires_grad=True)
        return masked_conv_module