fl4health.clients.fed_prox_client module

class FedProxClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None)

Bases: AdaptiveDriftConstraintClient

This client implements the FedProx algorithm from "Federated Optimization in Heterogeneous Networks". The idea is straightforward: each client's local loss is augmented with a proximal term. This term is a penalty on the squared L2 norm of the difference between the local client weights during training (\(\mathbf{w}\)) and the initial globally shared weights (\(\mathbf{w}^t\)), scaled by a weight mu.
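In the FedProx paper the local objective is \(F(\mathbf{w}) + \frac{\mu}{2}\lVert \mathbf{w} - \mathbf{w}^t \rVert^2\), where \(F\) is the client's ordinary training loss. The pure-Python sketch below computes that augmented loss over flattened weight lists. It only illustrates the formula: it is not fl4health's implementation, and `proximal_term` and `fedprox_loss` are hypothetical helper names.

```python
from typing import Sequence


def proximal_term(local_weights: Sequence[float], global_weights: Sequence[float], mu: float) -> float:
    """Hypothetical helper: (mu / 2) * ||w - w^t||^2 for flattened weight vectors."""
    if len(local_weights) != len(global_weights):
        raise ValueError("local and global weights must have the same length")
    squared_distance = sum((w - w_t) ** 2 for w, w_t in zip(local_weights, global_weights))
    return 0.5 * mu * squared_distance


def fedprox_loss(
    task_loss: float, local_weights: Sequence[float], global_weights: Sequence[float], mu: float
) -> float:
    """Hypothetical helper: ordinary task loss plus the FedProx proximal penalty."""
    return task_loss + proximal_term(local_weights, global_weights, mu)


# Local weights drifted from [1.0, 2.0] to [1.5, 1.0]: squared distance = 0.25 + 1.0 = 1.25
loss = fedprox_loss(task_loss=0.8, local_weights=[1.5, 1.0], global_weights=[1.0, 2.0], mu=0.1)
print(round(loss, 4))  # 0.8 + 0.05 * 1.25 = 0.8625
```

Setting mu to 0 removes the penalty and recovers plain local training, while larger values of mu keep the local weights closer to the global model received at the start of the round.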

NOTE: The initial value for mu (the drift penalty weight) is set on the server side and passed to each client through parameter exchange. On the client, it is stored under the more general name drift_penalty_weight.
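One way to send mu alongside the model through parameter exchange is to append it to the list of model arrays on the server and split it off again on the client. The sketch below assumes that layout (mu as a trailing one-element entry); `pack_parameters` and `unpack_parameters` are hypothetical names, not fl4health's parameter-packer API, and plain nested lists stand in for the per-layer NumPy arrays used in practice.

```python
from typing import List, Tuple

Weights = List[List[float]]  # stand-in for a list of per-layer NumPy arrays


def pack_parameters(model_weights: Weights, drift_penalty_weight: float) -> Weights:
    # Hypothetical server-side step: append mu as a one-element entry after the model layers.
    return [list(layer) for layer in model_weights] + [[drift_penalty_weight]]


def unpack_parameters(parameters: Weights) -> Tuple[Weights, float]:
    # Hypothetical client-side step: the last entry carries mu, everything before it is the model.
    # The nested target (drift_penalty_weight,) fails loudly if that entry is not a single value.
    *model_weights, (drift_penalty_weight,) = parameters
    return model_weights, drift_penalty_weight


packed = pack_parameters([[0.5, -1.0], [2.0]], drift_penalty_weight=0.01)
weights, mu = unpack_parameters(packed)
print(weights, mu)  # [[0.5, -1.0], [2.0]] 0.01
```

Keeping mu in the same message as the weights means the client needs no extra communication channel, and the server can change the value between rounds.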

update_before_train(current_server_round)

Hook method called before training, given the current server round. NOTE: This method is called immediately AFTER the aggregated parameters are received from the server. For example, MOON and FENDA use it to save global modules after aggregation.

Parameters:

current_server_round (int) – The current server round number.

Return type:

None
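Because this hook runs right after the aggregated parameters arrive, it is a natural point for a FedProx-style client to snapshot the global weights \(\mathbf{w}^t\) before local training changes them. The toy class below is a hypothetical sketch of that pattern using plain Python lists. It does not reproduce fl4health's internals, and `ToyFedProxClient` and `drift_penalty` are illustrative names.

```python
from typing import List, Optional


class ToyFedProxClient:
    """Hypothetical stand-in for a FedProx client; not fl4health's FedProxClient."""

    def __init__(self, weights: List[float], drift_penalty_weight: float) -> None:
        self.weights = weights
        self.drift_penalty_weight = drift_penalty_weight
        self.initial_weights: Optional[List[float]] = None

    def set_parameters(self, global_weights: List[float]) -> None:
        # Simulates receiving the aggregated global model from the server.
        self.weights = list(global_weights)

    def update_before_train(self, current_server_round: int) -> None:
        # Runs right after set_parameters, so self.weights still equals w^t here.
        self.initial_weights = list(self.weights)

    def drift_penalty(self) -> float:
        # (mu / 2) * ||w - w^t||^2 against the snapshot taken in update_before_train.
        if self.initial_weights is None:
            raise RuntimeError("update_before_train must run before computing the penalty")
        squared = sum((w - w0) ** 2 for w, w0 in zip(self.weights, self.initial_weights))
        return 0.5 * self.drift_penalty_weight * squared


client = ToyFedProxClient(weights=[0.0, 0.0], drift_penalty_weight=1.0)
client.set_parameters([1.0, 1.0])
client.update_before_train(current_server_round=1)
client.weights = [1.0, 3.0]  # pretend local training moved the second weight
print(client.drift_penalty())  # 0.5 * 1.0 * (0.0 + 4.0) = 2.0
```

Copying the weights (rather than keeping a reference) matters: if training mutated the same list in place, the snapshot would drift with it and the penalty would always be zero.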