Source code for fl4health.parameter_exchange.partial_parameter_exchanger

from abc import abstractmethod
from typing import Generic, TypeVar

import torch.nn as nn
from flwr.common.typing import NDArrays

from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.parameter_exchange.parameter_packer import ParameterPacker

T = TypeVar("T")


class PartialParameterExchanger(ParameterExchanger, Generic[T]):
    """Parameter exchanger that hands (un)packing work off to a ``ParameterPacker``.

    Model weights are exchanged together with an additional payload of generic
    type ``T``. The ``parameter_packer`` supplied at construction decides how the
    weights and that payload are combined into, and recovered from, a single
    ``NDArrays`` list. Concrete subclasses must implement ``select_parameters``.
    """

    def __init__(self, parameter_packer: ParameterPacker[T]) -> None:
        """Initialize the exchanger and keep a reference to its packer.

        Args:
            parameter_packer: Object whose ``pack_parameters`` and
                ``unpack_parameters`` methods perform the actual packing and
                unpacking for this exchanger.
        """
        super().__init__()
        self.parameter_packer = parameter_packer

    def pack_parameters(self, model_weights: NDArrays, additional_parameters: T) -> NDArrays:
        """Combine model weights and an extra payload into one ``NDArrays`` list.

        Args:
            model_weights: The model weights to send.
            additional_parameters: Extra payload of type ``T`` to bundle with them.

        Returns:
            Whatever ``self.parameter_packer.pack_parameters`` produces for these
            inputs, returned unchanged.
        """
        packer = self.parameter_packer
        return packer.pack_parameters(model_weights, additional_parameters)

    def unpack_parameters(self, packed_parameters: NDArrays) -> tuple[NDArrays, T]:
        """Split a packed ``NDArrays`` list back into weights and extra payload.

        Args:
            packed_parameters: A list previously built by the packer's
                ``pack_parameters`` (assumed; the format is defined by the packer).

        Returns:
            The ``(model_weights, additional_parameters)`` pair produced by
            ``self.parameter_packer.unpack_parameters``, returned unchanged.
        """
        packer = self.parameter_packer
        return packer.unpack_parameters(packed_parameters)

    @abstractmethod
    def select_parameters(
        self,
        model: nn.Module,
        initial_model: nn.Module | None = None,
    ) -> tuple[NDArrays, T]:
        """Choose the parameters of ``model`` to exchange; subclasses must override.

        Args:
            model: The model whose parameters are being selected.
            initial_model: Optional second model. What a subclass does with it is
                not defined here; presumably it is a reference copy (e.g. the model
                before local training). TODO: confirm against implementations.

        Returns:
            A ``(weights, additional_parameters)`` pair, matching the shape that
            ``unpack_parameters`` returns.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError