# Source code for fl4health.parameter_exchange.partial_parameter_exchanger

from abc import abstractmethod
from typing import Generic, TypeVar

from flwr.common.typing import NDArrays
from torch import nn

from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.parameter_exchange.parameter_packer import ParameterPacker


# Type of the additional (non-weight) information that the parameter packer bundles together with the
# model weights; it parameterizes ``ParameterPacker[T]`` and ``PartialParameterExchanger[T]``.
T = TypeVar("T")


class PartialParameterExchanger(ParameterExchanger, Generic[T]):
    """Base exchanger for partial parameter exchange driven by a selection criterion."""

    def __init__(self, parameter_packer: ParameterPacker[T]) -> None:
        """
        Base class meant to properly facilitate partial parameter exchange through a selection criterion.

        This mechanism is more complicated than, for example, that used by the ``FixedLayerExchanger``, where the
        subset of parameters to exchange does not change dynamically from round to round.

        Args:
            parameter_packer (ParameterPacker[T]): Parameter packer that can be used to pack in more information
                than just the parameters being exchanged. This is important, for example, when exchanging
                different sets of layers in each round.
        """
        super().__init__()
        self.parameter_packer = parameter_packer

    def pack_parameters(self, model_weights: NDArrays, additional_parameters: T) -> NDArrays:
        """
        Pack the model weights together with additional information into a single ``NDArrays``.

        This delegates entirely to ``self.parameter_packer.pack_parameters``.

        Args:
            model_weights (NDArrays): The model weights selected for exchange.
            additional_parameters (T): Extra information to be packed alongside the weights.

        Returns:
            NDArrays: The packed parameters, as produced by the parameter packer.
        """
        return self.parameter_packer.pack_parameters(model_weights, additional_parameters)

    def unpack_parameters(self, packed_parameters: NDArrays) -> tuple[NDArrays, T]:
        """
        Split packed parameters back into model weights and the additional information.

        This delegates entirely to ``self.parameter_packer.unpack_parameters``.

        Args:
            packed_parameters (NDArrays): Parameters previously packed (presumably by ``pack_parameters`` on the
                peer side — confirm against the packer implementation).

        Returns:
            tuple[NDArrays, T]: The model weights and the additional information, as returned by the packer.
        """
        return self.parameter_packer.unpack_parameters(packed_parameters)

    @abstractmethod
    def select_parameters(
        self,
        model: nn.Module,
        initial_model: nn.Module | None = None,
    ) -> tuple[NDArrays, T]:
        """
        Select the subset of ``model``'s parameters to exchange. Must be implemented by subclasses.

        NOTE(review): ``@abstractmethod`` is only enforced at instantiation if ``ParameterExchanger`` uses
        ``ABCMeta`` (not visible here); the ``NotImplementedError`` below guards direct calls either way.

        Args:
            model (nn.Module): The model whose parameters are being selected.
            initial_model (nn.Module | None, optional): An optional reference model that selection criteria may
                compare against. Defaults to None.

        Returns:
            tuple[NDArrays, T]: The selected parameters and the additional information (of type ``T``) to be
            packed alongside them.
        """
        raise NotImplementedError