Source code for fl4health.parameter_exchange.fedpm_exchanger

import torch
import torch.nn as nn
from flwr.common.typing import Config, NDArrays

from fl4health.parameter_exchange.layer_exchanger import DynamicLayerExchanger
from fl4health.parameter_exchange.parameter_selection_criteria import select_scores_and_sample_masks
from fl4health.utils.functions import sigmoid_inverse


class FedPmExchanger(DynamicLayerExchanger):
    """
    Layer exchanger that loads received scores into a model after mapping them through the inverse Sigmoid.

    The exchanger is constructed with ``select_scores_and_sample_masks`` as the selection function handed to
    ``DynamicLayerExchanger``.
    NOTE(review): how the parent class uses that function when pushing parameters is not visible in this
    module; confirm against ``DynamicLayerExchanger``.
    """

    def __init__(self) -> None:
        """Initialize the exchanger with ``select_scores_and_sample_masks`` as its layer selection function."""
        super().__init__(select_scores_and_sample_masks)

    def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Config | None = None) -> None:
        """
        Load the received parameters into ``model`` after applying the inverse Sigmoid to each of them.

        The packed ``parameters`` are split into per-layer arrays and their state-dict names with
        ``self.unpack_parameters``. Each array is converted to a tensor, passed through ``sigmoid_inverse``,
        and written over the entry of the same name in the model's state dict. The inverse Sigmoid is applied
        because the scores for masked layers are supposed to be unbounded.

        Args:
            parameters (NDArrays): Packed parameters, as understood by ``self.unpack_parameters``.
            model (nn.Module): Model whose state dict is updated.
            config (Config | None, optional): Not used by this implementation. Defaults to None.

        Raises:
            ValueError: If the number of unpacked layer names differs from the number of unpacked layer
                arrays. This is raised before the model is modified.
            RuntimeError: If ``model.load_state_dict`` rejects the updated state dict (it is called with
                ``strict=True``).
        """
        current_state = model.state_dict()
        layer_params, layer_names = self.unpack_parameters(parameters)

        # No gradient tracking is needed while rebuilding the state dict entries.
        with torch.no_grad():
            # strict=True: a mismatch between names and arrays would otherwise be silently truncated,
            # leaving some layers un-updated without any error.
            for layer_name, layer_param in zip(layer_names, layer_params, strict=True):
                # Apply the inverse of the Sigmoid function
                # since the scores for masked layers are supposed to be unbounded.
                current_state[layer_name] = sigmoid_inverse(torch.tensor(layer_param))

        model.load_state_dict(current_state, strict=True)