import copy
from collections.abc import Sequence
from logging import INFO
from pathlib import Path
import torch
from flwr.common.logger import log
from flwr.common.typing import Config, NDArrays
from opacus.optimizers.optimizer import DPOptimizer
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
from fl4health.clients.basic_client import BasicClient
from fl4health.clients.instance_level_dp_client import InstanceLevelDpClient
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.utils.losses import LossMeterType, TrainingLosses
from fl4health.utils.metrics import Metric
ScaffoldTrainStepOutput = tuple[torch.Tensor, torch.Tensor]
class ScaffoldClient(BasicClient):
def __init__(
self,
data_path: Path,
metrics: Sequence[Metric],
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
reporters: Sequence[BaseReporter] | None = None,
progress_bar: bool = False,
client_name: str | None = None,
) -> None:
"""
Federated Learning Client for Scaffold strategy.
Implementation based on https://arxiv.org/pdf/1910.06378.pdf.
Args:
            data_path (Path): Path to the data to be used for client-side training.
metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model
device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or
'cuda'
loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over
each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle
both checkpointing and state saving. The module, and its underlying model and state checkpointing
components will determine when and how to do checkpointing during client-side training.
No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client
should send data to. Defaults to None.
            progress_bar (bool, optional): Whether or not to display a progress bar during client training and
                validation. Uses tqdm. Defaults to False.
client_name (str | None, optional): An optional client name that uniquely identifies a client.
If not passed, a hash is randomly generated. Client state will use this as part of its state file
name. Defaults to None.
"""
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
checkpoint_and_state_module=checkpoint_and_state_module,
reporters=reporters,
progress_bar=progress_bar,
client_name=client_name,
)
self.learning_rate: float # eta_l in paper
self.client_control_variates: NDArrays | None = None # c_i in paper
self.client_control_variates_updates: NDArrays | None = None # delta_c_i in paper
self.server_control_variates: NDArrays | None = None # c in paper
        # Scaffold requires vanilla SGD as the optimizer; this is asserted during setup_client
self.optimizers: dict[str, torch.optim.Optimizer]
self.server_model_weights: NDArrays | None = None # x in paper
self.parameter_exchanger: FullParameterExchangerWithPacking[NDArrays]
def get_parameters(self, config: Config) -> NDArrays:
"""
        Packs the model parameters and control variate updates into a single NDArrays object to be sent to the
        server for aggregation.
"""
if not self.initialized:
log(
INFO,
"Setting up client and providing full model parameters to the server for initialization",
)
# If initialized==False, the server is requesting model parameters from which to initialize all other
# clients. As such get_parameters is being called before fit or evaluate, so we must call
# setup_client first.
self.setup_client(config)
# Need all parameters even if normally exchanging partial
return FullParameterExchanger().push_parameters(self.model, config=config)
else:
assert self.model is not None and self.parameter_exchanger is not None
model_weights = self.parameter_exchanger.push_parameters(self.model, config=config)
            # Weights and control variate updates are sent to the server for aggregation. The control variate
            # updates are sent because only the client has access to its previous control variates, so the update
            # can only be computed locally.
assert self.client_control_variates_updates is not None
packed_params = self.parameter_exchanger.pack_parameters(
model_weights, self.client_control_variates_updates
)
return packed_params
def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
"""
        Assumes that the parameters being passed contain the model parameters concatenated with the server control
        variates. They are unpacked for the client to use in training. If it is the first time the model is being
        initialized, we assume the full model is being initialized and use the FullParameterExchanger() to set all
        of the model weights.
        Args:
            parameters (NDArrays): Parameters holding the model state to be loaded into the relevant client model
                along with the server control variates (initial or after aggregation).
            config (Config): The config is sent by the FL server to allow for customization in the function if
                desired.
            fitting_round (bool): Whether the current federated learning round is a fitting round, as opposed to an
                evaluation round.
"""
assert self.model is not None and self.parameter_exchanger is not None
server_model_state, server_control_variates = self.parameter_exchanger.unpack_parameters(parameters)
self.server_control_variates = server_control_variates
super().set_parameters(server_model_state, config, fitting_round)
# Note that we are restricting to weights that require a gradient here because they are used to compute
# control variates
self.server_model_weights = [
model_params.cpu().detach().clone().numpy()
for model_params in self.model.parameters()
if model_params.requires_grad
]
        # If the client control variates do not exist, initialize them to be the same as the server control
        # variates. The server control variates default to 0 but, as stated in the paper, the control variates
        # should be the uniform average of the client variates. So if the server_control_variates are non-zero,
        # this initialization ensures that the average still holds.
if self.client_control_variates is None:
self.client_control_variates = copy.deepcopy(self.server_control_variates)
def update_control_variates(self, local_steps: int) -> None:
"""
        Updates the local control variates, along with the corresponding control variate updates, according to
        option 2 of Equation 4 in https://arxiv.org/pdf/1910.06378.pdf.
        To be called after the weights of the local model have been updated.
"""
assert self.client_control_variates is not None
assert self.server_control_variates is not None
assert self.server_model_weights is not None
assert self.learning_rate is not None
# y_i
client_model_weights = [
val.cpu().detach().clone().numpy() for val in self.model.parameters() if val.requires_grad
]
# (x - y_i)
delta_model_weights = self.compute_parameters_delta(self.server_model_weights, client_model_weights)
# (c_i - c)
delta_control_variates = self.compute_parameters_delta(
self.client_control_variates, self.server_control_variates
)
updated_client_control_variates = self.compute_updated_control_variates(
local_steps, delta_model_weights, delta_control_variates
)
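        # delta_c_i = c_i^plus - c_i. This is the update that get_parameters packs and sends back to the server.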
self.client_control_variates_updates = self.compute_parameters_delta(
updated_client_control_variates, self.client_control_variates
)
# c_i = c_i^plus
self.client_control_variates = updated_client_control_variates
def modify_grad(self) -> None:
"""
        Modifies the gradients of the local model to correct for client drift.
        To be called after the gradients have been computed on a batch of data.
        The updates are not applied to the parameters until step is called on the optimizer.
"""
assert self.client_control_variates is not None
assert self.server_control_variates is not None
model_params_with_grad = [
model_params for model_params in self.model.parameters() if model_params.requires_grad
]
for param, client_cv, server_cv in zip(
model_params_with_grad,
self.client_control_variates,
self.server_control_variates,
):
assert param.grad is not None
tensor_type = param.grad.dtype
server_cv_tensor = torch.from_numpy(server_cv).type(tensor_type)
client_cv_tensor = torch.from_numpy(client_cv).type(tensor_type)
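            # Drift correction from the paper: the corrected gradient is g_i(y_i) - c_i + c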
update = server_cv_tensor.to(self.device) - client_cv_tensor.to(self.device)
param.grad += update
def compute_parameters_delta(self, params_1: NDArrays, params_2: NDArrays) -> NDArrays:
"""
        Computes the element-wise difference of two lists of NDArrays, where the elements of params_2 are
        subtracted from the elements of params_1.
"""
parameter_delta: NDArrays = [param_1 - param_2 for param_1, param_2 in zip(params_1, params_2)]
return parameter_delta
def compute_updated_control_variates(
self,
local_steps: int,
delta_model_weights: NDArrays,
delta_control_variates: NDArrays,
) -> NDArrays:
"""
        Computes the updated local control variates according to option 2 in Equation 4 of the paper.
"""
# coef = 1 / (K * eta_l)
scaling_coefficient = 1 / (local_steps * self.learning_rate)
# c_i^plus = c_i - c + 1/(K*lr) * (x - y_i)
updated_client_control_variates = [
delta_control_variate + scaling_coefficient * delta_model_weight
for delta_control_variate, delta_model_weight in zip(delta_control_variates, delta_model_weights)
]
return updated_client_control_variates
    def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
        """
        Returns a FullParameterExchangerWithPacking that exchanges the full set of model weights packed together
        with the client control variate updates.
        """
        assert self.model is not None
model_size = len(self.model.state_dict())
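        # The packer uses the number of model tensors (model_size) to split a packed payload back into model
        # weights and control variates when unpacking.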
parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size))
return parameter_exchanger
def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None:
"""
Called after training with the number of local_steps performed over the FL round and
the corresponding loss dictionary.
"""
self.update_control_variates(local_steps)
def setup_client(self, config: Config) -> None:
"""
Set dataloaders, optimizers, parameter exchangers and other attributes derived from these.
Then set initialized attribute to True. Extends the basic client to extract the learning rate
from the optimizer and set the learning_rate attribute (used to compute updated control variates).
Args:
config (Config): The config from the server.
"""
super().setup_client(config)
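        # The Scaffold control variate update assumes plain SGD steps (no momentum or adaptive scaling), so only
        # vanilla SGD, or its Opacus DPOptimizer wrapper in the differentially private variant, is supported.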
if isinstance(self, DPScaffoldClient):
assert isinstance(self.optimizers["global"], DPOptimizer)
else:
assert isinstance(self.optimizers["global"], torch.optim.SGD)
self.learning_rate = self.optimizers["global"].defaults["lr"]
class DPScaffoldClient(ScaffoldClient, InstanceLevelDpClient):
"""
    Federated Learning client for the Instance Level Differentially Private Scaffold strategy.
    Implemented as specified in https://arxiv.org/abs/2111.09278.
"""
def __init__(
self,
data_path: Path,
metrics: Sequence[Metric],
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
reporters: Sequence[BaseReporter] | None = None,
progress_bar: bool = False,
client_name: str | None = None,
) -> None:
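        """
        Client implementing the instance level differentially private Scaffold strategy. The constructor arguments
        mirror those of ScaffoldClient and are forwarded to both parent constructors; see ScaffoldClient.__init__
        for a description of each argument.
        """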
ScaffoldClient.__init__(
self,
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
checkpoint_and_state_module=checkpoint_and_state_module,
reporters=reporters,
progress_bar=progress_bar,
client_name=client_name,
)
InstanceLevelDpClient.__init__(
self,
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
checkpoint_and_state_module=checkpoint_and_state_module,
reporters=reporters,
progress_bar=progress_bar,
client_name=client_name,
)
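

# Hypothetical usage sketch (not part of the library): a minimal example, assuming that the usual BasicClient
# hooks (get_model, get_optimizer, get_criterion, get_data_loaders) are the methods a concrete client overrides,
# showing how a ScaffoldClient subclass might be wired up on synthetic data and pointed at a Flower server. The
# network, data, and server address below are placeholders.
if __name__ == "__main__":
    import flwr as fl
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    class SyntheticScaffoldClient(ScaffoldClient):
        def get_model(self, config: Config) -> nn.Module:
            return nn.Linear(10, 2).to(self.device)

        def get_optimizer(self, config: Config) -> torch.optim.Optimizer:
            # Scaffold requires vanilla SGD (asserted in setup_client)
            return torch.optim.SGD(self.model.parameters(), lr=0.05)

        def get_criterion(self, config: Config) -> nn.Module:
            return nn.CrossEntropyLoss()

        def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
            # Synthetic 10-feature binary classification data shared for train and validation
            features, labels = torch.randn(64, 10), torch.randint(0, 2, (64,))
            loader = DataLoader(TensorDataset(features, labels), batch_size=8)
            return loader, loader

    client = SyntheticScaffoldClient(data_path=Path("."), metrics=[], device=torch.device("cpu"))
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())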