Source code for fl4health.parameter_exchange.full_exchanger

from collections import OrderedDict

import torch
import torch.nn as nn
from flwr.common.typing import Config, NDArrays

from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger


[docs] class FullParameterExchanger(ParameterExchanger):
[docs] def push_parameters( self, model: nn.Module, initial_model: nn.Module | None = None, config: Config | None = None ) -> NDArrays: # Sending all of parameters ordered by state_dict keys # NOTE: Order matters, because it is relied upon by pull_parameters below return [val.cpu().numpy() for _, val in model.state_dict().items()]
[docs] def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Config | None = None) -> None: # Assumes all model parameters are contained in parameters # The state_dict is reconstituted because parameters is simply a list of bytes params_dict = zip(model.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) model.load_state_dict(state_dict, strict=True)