# Source code for fl4health.clients.fedper_client

from flwr.common.typing import Config

from fl4health.clients.basic_client import BasicClient
from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger


class FedPerClient(BasicClient):
    """
    Client implementing the FedPer method (https://arxiv.org/abs/1912.00818).

    A global feature extractor is shared by all clients and trained through FedAvg, while a private classifier
    stays unique to each client. Training is the same as in ``BasicClient``. The one difference is the parameter
    exchanger: it is a ``FixedLayerExchanger`` that exchanges only the feature-extraction base. Building it
    requires the model to be a ``SequentiallySplitExchangeBaseModel``.
    """

    def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
        """
        Build an exchanger that only exchanges the layers the split model marks for exchange.

        Args:
            config (Config): Configuration values. Not read by this implementation.

        Returns:
            ParameterExchanger: A ``FixedLayerExchanger`` built from ``self.model.layers_to_exchange()``.

        Raises:
            AssertionError: If ``self.model`` is not a ``SequentiallySplitExchangeBaseModel`` (only raised while
                Python assertions are enabled).
        """
        model = self.model
        # Only the split model type exposes ``layers_to_exchange``; the isinstance check also narrows the type
        # for static checkers.
        assert isinstance(model, SequentiallySplitExchangeBaseModel), (
            "Models for FedPer must be of type SequentiallySplitExchangeBaseModel to facilitate partial weight "
            f"exchange. The current model is of type {type(model)}."
        )
        shared_layers = model.layers_to_exchange()
        return FixedLayerExchanger(shared_layers)