# Source code for fl4health.parameter_exchange.parameter_selection_criteria

import math
from functools import partial

import torch
import torch.nn as nn
from flwr.common.typing import NDArray, NDArrays
from scipy.stats import bernoulli
from torch import Tensor

from fl4health.model_bases.masked_layers.masked_layers_utils import is_masked_module
from fl4health.utils.typing import LayerSelectionFunction


class LayerSelectionFunctionConstructor:
    def __init__(
        self,
        norm_threshold: float,
        exchange_percentage: float,
        normalize: bool = True,
        select_drift_more: bool = True,
    ) -> None:
        """
        This class leverages ``functools.partial`` to construct layer selection functions,
        which are meant to be used by the ``DynamicLayerExchanger`` class.

        Args:
            norm_threshold (float): A nonnegative real number used to select those layers whose
                drift in l2 norm exceeds (or falls short of) it.
            exchange_percentage (float): Indicates the percentage of layers that are selected.
            normalize (bool, optional): Indicates whether, when calculating the norm of a layer,
                we also divide by the number of parameters in that layer. Defaults to True.
            select_drift_more (bool, optional): Indicates whether layers with larger drift norm
                are selected. Defaults to True.
        """
        assert 0 < exchange_percentage <= 1
        assert 0 < norm_threshold
        self.norm_threshold = norm_threshold
        self.exchange_percentage = exchange_percentage
        self.normalize = normalize
        self.select_drift_more = select_drift_more
    def select_by_threshold(self) -> LayerSelectionFunction:
        return partial(
            select_layers_by_threshold,
            self.norm_threshold,
            self.normalize,
            self.select_drift_more,
        )
    def select_by_percentage(self) -> LayerSelectionFunction:
        return partial(
            select_layers_by_percentage,
            self.exchange_percentage,
            self.normalize,
            self.select_drift_more,
        )
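
# A minimal usage sketch for ``LayerSelectionFunctionConstructor`` (the models and values below
# are hypothetical, for illustration only; they are not part of the library):
def _demo_layer_selection() -> None:
    constructor = LayerSelectionFunctionConstructor(
        norm_threshold=0.1, exchange_percentage=0.5, normalize=True, select_drift_more=True
    )
    select_fn = constructor.select_by_threshold()
    initial_model = nn.Linear(4, 2)
    trained_model = nn.Linear(4, 2)  # stand-in for a model after a round of local training
    layers, layer_names = select_fn(trained_model, initial_model)
    # ``layers`` holds the numpy arrays of layers whose (normalized) drift norm exceeded 0.1;
    # ``layer_names`` holds their keys in ``trained_model.state_dict()``.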
# Selection criteria functions for selecting entire layers. Intended to be used by the
# ``DynamicLayerExchanger`` class via the ``LayerSelectionFunctionConstructor`` class.


def _calculate_drift_norm(t1: torch.Tensor, t2: torch.Tensor, normalize: bool) -> float:
    """
    Compute the l2 norm of the difference between two tensors, optionally normalized by the
    number of elements. Used as the drift measure for layer selection.

    Args:
        t1 (torch.Tensor): First tensor.
        t2 (torch.Tensor): Second tensor.
        normalize (bool): Whether to divide the norm of the difference between the tensors by
            their number of elements.

    Returns:
        float: The (possibly normalized) l2 norm of ``t1 - t2``.
    """
    t_diff = (t1 - t2).float()
    drift_norm = torch.linalg.norm(t_diff)
    if normalize:
        drift_norm /= torch.numel(t_diff)
    return drift_norm.item()
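
# A quick worked example of ``_calculate_drift_norm`` (values are illustrative only):
def _demo_drift_norm() -> None:
    t1 = torch.tensor([3.0, 0.0])
    t2 = torch.tensor([0.0, 4.0])
    # t1 - t2 = (3, -4), whose l2 norm is 5.0.
    assert math.isclose(_calculate_drift_norm(t1, t2, normalize=False), 5.0)
    # With normalize=True the norm is divided by the 2 elements: 5.0 / 2 = 2.5.
    assert math.isclose(_calculate_drift_norm(t1, t2, normalize=True), 2.5)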
def select_layers_by_threshold(
    threshold: float,
    normalize: bool,
    select_drift_more: bool,
    model: nn.Module,
    initial_model: nn.Module,
) -> tuple[NDArrays, list[str]]:
    """
    Return those layers of ``model`` that deviate (in l2 norm) from the corresponding layers of
    ``initial_model`` by at least (or at most) ``threshold``.

    Args:
        threshold (float): Drift threshold to be used for selection. It is a fixed value.
        normalize (bool): Whether to divide the norm of the difference between the tensors by
            their number of elements.
        select_drift_more (bool): Whether we are selecting parameters that have drifted further
            (True) or less far (False) from their comparison values.
        model (nn.Module): Model after training/modification.
        initial_model (nn.Module): Model that we started with, to which we are comparing
            parameters.

    Returns:
        tuple[NDArrays, list[str]]: Layers selected by the process and their corresponding names
        in the model's ``state_dict``.
    """
    layer_names = []
    layers_to_transfer = []
    initial_model_states = initial_model.state_dict()
    model_states = model.state_dict()
    for layer_name, layer_param in model_states.items():
        ghost_of_layer_params_past = initial_model_states[layer_name]
        drift_norm = _calculate_drift_norm(layer_param, ghost_of_layer_params_past, normalize)
        if select_drift_more:
            if drift_norm > threshold:
                layers_to_transfer.append(layer_param.cpu().numpy())
                layer_names.append(layer_name)
        else:
            if drift_norm <= threshold:
                layers_to_transfer.append(layer_param.cpu().numpy())
                layer_names.append(layer_name)
    return layers_to_transfer, layer_names
def select_layers_by_percentage(
    exchange_percentage: float,
    normalize: bool,
    select_drift_more: bool,
    model: nn.Module,
    initial_model: nn.Module,
) -> tuple[NDArrays, list[str]]:
    """
    Select the fixed percentage of layers of ``model`` whose drift (in l2 norm) from the
    corresponding layers of ``initial_model`` is largest (or smallest).

    Args:
        exchange_percentage (float): Percentage of layers to select, in the interval (0, 1].
        normalize (bool): Whether to divide the norm of the difference between the tensors by
            their number of elements.
        select_drift_more (bool): Whether we are selecting parameters that have drifted further
            (True) or less far (False) from their comparison values.
        model (nn.Module): Model after training/modification.
        initial_model (nn.Module): Model that we started with, to which we are comparing
            parameters.

    Returns:
        tuple[NDArrays, list[str]]: Layers selected by the process and their corresponding names
        in the model's ``state_dict``.
    """
    names_to_norm_drift = {}
    initial_model_states = initial_model.state_dict()
    model_states = model.state_dict()

    for layer_name, layer_param in model_states.items():
        layer_param_past = initial_model_states[layer_name]
        drift_norm = _calculate_drift_norm(layer_param, layer_param_past, normalize)
        names_to_norm_drift[layer_name] = drift_norm

    total_param_num = len(names_to_norm_drift.keys())
    num_param_exchange = int(math.ceil(total_param_num * exchange_percentage))
    param_to_exchange_names = sorted(
        names_to_norm_drift.keys(), key=lambda x: names_to_norm_drift[x], reverse=select_drift_more
    )[:num_param_exchange]
    return [model_states[name].cpu().numpy() for name in param_to_exchange_names], param_to_exchange_names
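
# Worked example (illustrative): with 10 layers and exchange_percentage = 0.25, we get
# num_param_exchange = ceil(10 * 0.25) = 3, so the 3 layers with the largest (or, when
# select_drift_more is False, the smallest) drift norms are exchanged.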
# Score generating functions used for selecting arbitrary sets of weights.
# The ones implemented here are those that demonstrated good performance in the super-mask paper:
# https://arxiv.org/abs/1905.01067
def largest_final_magnitude_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]:
    names_to_scores = {}
    for tensor_name, tensor_values in model.state_dict().items():
        names_to_scores[tensor_name] = torch.abs(tensor_values)
    return names_to_scores


def smallest_final_magnitude_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]:
    names_to_scores = {}
    for tensor_name, tensor_values in model.state_dict().items():
        names_to_scores[tensor_name] = (-1) * torch.abs(tensor_values)
    return names_to_scores


def largest_magnitude_change_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]:
    assert initial_model is not None
    names_to_scores = {}
    current_model_states = model.state_dict()
    initial_model_states = initial_model.state_dict()
    for tensor_name, current_tensor_values in current_model_states.items():
        initial_tensor_values = initial_model_states[tensor_name]
        names_to_scores[tensor_name] = torch.abs(current_tensor_values - initial_tensor_values)
    return names_to_scores


def smallest_magnitude_change_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]:
    assert initial_model is not None
    names_to_scores = {}
    current_model_states = model.state_dict()
    initial_model_states = initial_model.state_dict()
    for tensor_name, current_tensor_values in current_model_states.items():
        initial_tensor_values = initial_model_states[tensor_name]
        names_to_scores[tensor_name] = (-1) * torch.abs(current_tensor_values - initial_tensor_values)
    return names_to_scores


def largest_increase_in_magnitude_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]:
    assert initial_model is not None
    names_to_scores = {}
    current_model_states = model.state_dict()
    initial_model_states = initial_model.state_dict()
    for tensor_name, current_tensor_values in current_model_states.items():
        initial_tensor_values = initial_model_states[tensor_name]
        names_to_scores[tensor_name] = torch.abs(current_tensor_values) - torch.abs(initial_tensor_values)
    return names_to_scores


def smallest_increase_in_magnitude_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]:
    assert initial_model is not None
    names_to_scores = {}
    current_model_states = model.state_dict()
    initial_model_states = initial_model.state_dict()
    for tensor_name, current_tensor_values in current_model_states.items():
        initial_tensor_values = initial_model_states[tensor_name]
        names_to_scores[tensor_name] = (-1) * (torch.abs(current_tensor_values) - torch.abs(initial_tensor_values))
    return names_to_scores
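
# Hedged sketch of how these score functions can drive weight selection: compute a score for
# every entry, then keep the top fraction per tensor. The helper below and the 50% keep-rate
# are illustrative assumptions, not part of the library.
def _demo_score_based_selection() -> None:
    model = nn.Linear(4, 2)
    names_to_scores = largest_final_magnitude_scores(model, initial_model=None)
    for tensor_name, scores in names_to_scores.items():
        k = max(1, int(0.5 * scores.numel()))  # keep the top 50% of entries
        cutoff = torch.topk(scores.flatten(), k).values.min()
        mask = scores >= cutoff
        # ``mask`` marks the entries of ``tensor_name`` that would be selected.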
# Helper functions for select_scores_and_sample_masks


def _sample_masks(score_tensor: Tensor) -> NDArray:
    bernoulli_probabilities = torch.sigmoid(score_tensor).cpu().numpy()
    # Perform Bernoulli sampling.
    binary_mask = bernoulli.rvs(bernoulli_probabilities)
    return binary_mask


def _process_masked_module(
    module: nn.Module, model_state_dict: dict[str, Tensor], module_name: str | None = None
) -> tuple[NDArrays, list[str]]:
    """
    Perform Bernoulli sampling using the weight and bias scores of a masked module.

    Args:
        module (nn.Module): The module upon which the operations described above are performed.
            ``module`` can either be a submodule of the model trained in FedPM, or it can be a
            standalone module itself. In the latter case, the argument ``model_state_dict``
            should be the same as ``module.state_dict()``. In either case, it is assumed that
            ``module`` is a masked module.
        model_state_dict (dict[str, Tensor]): The state dictionary of the model trained in FedPM.
        module_name (str | None): The name of ``module`` if it is a submodule of the model
            trained in FedPM. This is used to access the weight and bias score tensors in
            ``model_state_dict``. Defaults to None.
    """
    masks_to_exchange = []
    score_tensor_names = []

    # If module_name is passed in, then we prepend it to "weight_scores" to get the correct
    # key in the state dictionary.
    weight_scores_tensor_name = f"{module_name}.weight_scores" if module_name else "weight_scores"
    score_tensor_names.append(weight_scores_tensor_name)
    weight_scores = model_state_dict[weight_scores_tensor_name]

    # Note: due to the Bernoulli sampling performed here, the parameters selected are in fact
    # binary masks, even though their corresponding names are something like "weight_scores" or
    # "bias_scores". After the tensors have been aggregated by the strategy, they will become
    # score tensors again. This misalignment is allowed because these parameter names will later
    # be used to load the model anyway.
    masks_to_exchange.append(_sample_masks(weight_scores))

    # Do the same thing with bias_scores if it exists.
    if "bias_scores" in module.state_dict().keys():
        bias_scores_tensor_name = f"{module_name}.bias_scores" if module_name else "bias_scores"
        score_tensor_names.append(bias_scores_tensor_name)
        bias_scores = model_state_dict[bias_scores_tensor_name]
        masks_to_exchange.append(_sample_masks(bias_scores))

    return masks_to_exchange, score_tensor_names
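
# A small sketch of the Bernoulli sampling step in isolation (scores are illustrative):
def _demo_sample_masks() -> None:
    scores = torch.tensor([[-2.0, 0.0, 2.0]])
    # sigmoid maps the scores to probabilities of roughly [0.12, 0.5, 0.88]; each entry is then
    # an independent Bernoulli draw, yielding a binary mask with the same shape as the scores.
    mask = _sample_masks(scores)
    assert mask.shape == (1, 3)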
def select_scores_and_sample_masks(model: nn.Module, initial_model: nn.Module | None) -> tuple[NDArrays, list[str]]:
    """
    Selection function that first selects the ``weight_scores`` and ``bias_scores`` parameters
    for the masked layers, and then samples binary masks based on those scores to send to the
    server. This function is meant to be used for the FedPM algorithm.

    **NOTE:** In the current implementation, we always exchange the score tensors for all layers.
    In the future, we might support exchanging a subset of the layers (for example, filtering out
    the masks that are all zeros).
    """
    model_states = model.state_dict()
    with torch.no_grad():
        if is_masked_module(model):
            return _process_masked_module(module=model, model_state_dict=model_states)
        else:
            masks_to_exchange = []
            score_tensor_names = []
            for name, module in model.named_modules():
                if is_masked_module(module):
                    module_masks, module_score_tensor_names = _process_masked_module(
                        module=module, model_state_dict=model_states, module_name=name
                    )
                    masks_to_exchange.extend(module_masks)
                    score_tensor_names.extend(module_score_tensor_names)
            return masks_to_exchange, score_tensor_names