Source code for fl4health.parameter_exchange.partial_parameter_exchanger
from abc import abstractmethod
from typing import Generic, TypeVar
import torch.nn as nn
from flwr.common.typing import NDArrays
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.parameter_exchange.parameter_packer import ParameterPacker
T = TypeVar("T")
[docs]
class PartialParameterExchanger(ParameterExchanger, Generic[T]):
    """Parameter exchanger whose payload is model weights plus extra data of type ``T``.

    Packing and unpacking are delegated entirely to the ``ParameterPacker[T]``
    supplied at construction time. Deciding *which* parameters to send is left to
    subclasses via :meth:`select_parameters`.

    NOTE(review): ``@abstractmethod`` only blocks instantiation of subclasses that
    fail to override ``select_parameters`` if ``ParameterExchanger`` uses
    ``abc.ABCMeta``. Confirm against the base class.
    """

    def __init__(self, parameter_packer: ParameterPacker[T]) -> None:
        """Store the packer used to combine and split exchanged payloads.

        Args:
            parameter_packer: Packer that merges model weights with additional
                parameters of type ``T`` and splits them apart again.
        """
        super().__init__()
        # Public attribute: the pack/unpack helpers below read it on every call.
        self.parameter_packer = parameter_packer

    def pack_parameters(self, model_weights: NDArrays, additional_parameters: T) -> NDArrays:
        """Merge ``model_weights`` and ``additional_parameters`` into one ``NDArrays``.

        Args:
            model_weights: Arrays holding the model weights to transmit.
            additional_parameters: Extra payload of type ``T`` to send alongside them.

        Returns:
            Whatever ``self.parameter_packer.pack_parameters`` produces for these inputs.
        """
        packer = self.parameter_packer
        packed = packer.pack_parameters(model_weights, additional_parameters)
        return packed

    def unpack_parameters(self, packed_parameters: NDArrays) -> tuple[NDArrays, T]:
        """Split a packed payload back into its two components.

        Args:
            packed_parameters: Payload presumably produced by :meth:`pack_parameters`
                (or an equivalent packer). Verify against callers.

        Returns:
            The ``(NDArrays, T)`` pair returned by ``self.parameter_packer.unpack_parameters``.
        """
        packer = self.parameter_packer
        weights_and_extra = packer.unpack_parameters(packed_parameters)
        return weights_and_extra

    @abstractmethod
    def select_parameters(self, model: nn.Module, initial_model: nn.Module | None = None) -> tuple[NDArrays, T]:
        """Choose the parameters of ``model`` to exchange; must be overridden.

        The selection rule is defined by subclasses. ``initial_model`` is optional
        and presumably serves as a reference point for the selection. TODO confirm
        per subclass.

        Args:
            model: Model whose parameters are being selected.
            initial_model: Optional second model; ``None`` by default.

        Returns:
            A pair of ``NDArrays`` and additional data of type ``T``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError