Source code for fl4health.clients.fedper_client
from flwr.common.typing import Config
from fl4health.clients.basic_client import BasicClient
from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
[docs]
class FedPerClient(BasicClient):
    """
    Client to implement the FedPer method (https://arxiv.org/abs/1912.00818).

    Trains a global feature extractor shared by all clients through FedAvg and a private classifier that is unique to
    each client. The training is nearly identical to the ``BasicClient`` with the exception that our parameter
    exchanger needs to be a fixed layer exchanger that only exchanges the feature extraction base, which relies on the
    model being of type ``SequentiallySplitExchangeBaseModel``.
    """

    def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
        """
        Build the parameter exchanger used by this client.

        The exchanger is a ``FixedLayerExchanger`` constructed from the layers reported by
        ``self.model.layers_to_exchange()``.

        Args:
            config (Config): Configuration passed in by the caller. It is not read by this implementation.

        Returns:
            ParameterExchanger: A ``FixedLayerExchanger`` built from ``self.model.layers_to_exchange()``.

        Raises:
            AssertionError: If ``self.model`` is not an instance of ``SequentiallySplitExchangeBaseModel``.
        """
        if not isinstance(self.model, SequentiallySplitExchangeBaseModel):
            # Raised explicitly instead of through an ``assert`` statement so the check is not compiled away when
            # Python runs with ``-O``. The exception type is kept as AssertionError so existing callers observe
            # the same failure as before.
            raise AssertionError(
                "Models for FedPer must be of type SequentiallySplitExchangeBaseModel to facilitate partial weight "
                f"exchange. The current model is of type {type(self.model)}."
            )
        return FixedLayerExchanger(self.model.layers_to_exchange())