Source code for fl4health.parameter_exchange.full_exchanger

from collections import OrderedDict

import torch
import torch.nn as nn
from flwr.common.typing import Config, NDArrays

from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger


class FullParameterExchanger(ParameterExchanger):
    """Parameter exchanger that sends and receives every entry of a model's ``state_dict``."""

    def push_parameters(
        self, model: nn.Module, initial_model: nn.Module | None = None, config: Config | None = None
    ) -> NDArrays:
        """
        Send all of the model's parameters, ordered by ``state_dict`` keys.

        **NOTE**: Order matters, because it is relied upon by ``pull_parameters`` below.

        Args:
            model (nn.Module): Model containing the weights to be sent.
            initial_model (nn.Module | None, optional): Not used. Defaults to None.
            config (Config | None, optional): Not used. Defaults to None.

        Returns:
            NDArrays: One numpy array per entry in the ``state_dict`` of ``model``, in the order in which the
            ``state_dict`` yields them.
        """
        # Each tensor is moved to CPU before conversion, since ``.numpy()`` only works on CPU tensors.
        return [tensor.cpu().numpy() for tensor in model.state_dict().values()]

    def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Config | None = None) -> None:
        """
        Take a set of parameters in the form of ``NDArrays`` (list of numpy arrays) and inject them into the
        provided model.

        Assumes all model parameters are contained in ``parameters``. The ``state_dict`` is reconstituted by pairing
        the arrays positionally with the model's own ``state_dict`` keys, because ``parameters`` is simply a list of
        arrays with no names attached.

        Args:
            parameters (NDArrays): Parameters to inject into the provided model, in ``state_dict`` key order.
            model (nn.Module): Model to inject the parameters into.
            config (Config | None, optional): Not used. Defaults to None.
        """
        # ``zip`` stops at the shorter input: surplus arrays in ``parameters`` are ignored, while too few arrays
        # leave keys out of the rebuilt state dict, which the strict load below is expected to reject.
        keys = model.state_dict().keys()
        state_dict = OrderedDict((key, torch.tensor(array)) for key, array in zip(keys, parameters))
        model.load_state_dict(state_dict, strict=True)