from collections.abc import Callable
from functools import reduce
from logging import WARNING
import numpy as np
import torch.nn as nn
from flwr.common import (
FitIns,
MetricsAggregationFn,
NDArrays,
Parameters,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.common.typing import FitRes, Scalar
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from opacus import GradSampleModule
from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager
from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates
from fl4health.strategies.basic_fedavg import BasicFedAvg
from fl4health.utils.functions import decode_and_pseudo_sort_results
from fl4health.utils.parameter_extraction import get_all_model_parameters
class Scaffold(BasicFedAvg):
def __init__(
self,
*,
fraction_fit: float = 1.0,
fraction_evaluate: float = 1.0,
min_available_clients: int = 2,
evaluate_fn: (
Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None
) = None,
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
accept_failures: bool = True,
initial_parameters: Parameters,
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
weighted_eval_losses: bool = True,
learning_rate: float = 1.0,
initial_control_variates: Parameters | None = None,
model: nn.Module | None = None,
) -> None:
"""
Scaffold Federated Learning strategy. Implementation based on https://arxiv.org/pdf/1910.06378.pdf
Args:
initial_parameters (Parameters): Initial model parameters to which all client models are set.
fraction_fit (float, optional): Fraction of clients used during training. Defaults to 1.0.
fraction_evaluate (float, optional): Fraction of clients used during validation. Defaults to 1.0.
min_available_clients (int, optional): Minimum number of total clients in the system.
Defaults to 2.
evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None):
Optional function used for central server-side evaluation. Defaults to None.
on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional):
Function used to configure training by providing a configuration dictionary. Defaults to None.
on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional):
Function used to configure client-side validation by providing a Config dictionary.
Defaults to None.
            accept_failures (bool, optional): Whether or not to accept rounds containing failures. Defaults to True.
fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function.
Defaults to None.
evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function.
Defaults to None.
weighted_eval_losses (bool, optional): Determines whether losses during evaluation are linearly weighted
averages or a uniform average. FedAvg default is weighted average of the losses by client dataset
counts. Defaults to True.
learning_rate (float, optional): Learning rate for server side optimization. Defaults to 1.0.
initial_control_variates (Parameters | None, optional): These are the initial set of control variates
to use for the scaffold strategy both on the server and client sides. It is optional, but if it is not
provided, the strategy must receive a model that reflects the architecture to be used on the clients.
Defaults to None.
model (nn.Module | None, optional): If provided and initial_control_variates is not, this is used to
set the server control variates and the initial control variates on the client side to all zeros.
If initial_control_variates are provided, they take precedence. Defaults to None.
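        Example:
            A minimal construction sketch. The small ``nn.Linear`` model here is purely hypothetical, and we assume
            ``get_all_model_parameters`` from this module is an acceptable way to extract its initial weights:

            >>> model = nn.Linear(2, 1)
            >>> strategy = Scaffold(
            ...     initial_parameters=get_all_model_parameters(model),
            ...     model=model,
            ... )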
"""
self.server_model_weights = parameters_to_ndarrays(initial_parameters)
# Setup the initial control variates on the server-side and store them to be transmitted to the clients
initial_control_variates = self.initialize_control_variates(initial_control_variates, model)
initial_parameters.tensors.extend(initial_control_variates.tensors)
super().__init__(
fraction_fit=fraction_fit,
fraction_evaluate=fraction_evaluate,
min_available_clients=min_available_clients,
evaluate_fn=evaluate_fn,
on_fit_config_fn=on_fit_config_fn,
on_evaluate_config_fn=on_evaluate_config_fn,
accept_failures=accept_failures,
initial_parameters=initial_parameters,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
weighted_aggregation=False,
weighted_eval_losses=weighted_eval_losses,
)
self.learning_rate = learning_rate
self.parameter_packer = ParameterPackerWithControlVariates(len(self.server_model_weights))
def initialize_control_variates(
self, initial_control_variates: Parameters | None, model: nn.Module | None
) -> Parameters:
"""
        This is a helper function for the SCAFFOLD strategy init function to initialize the server_control_variates.
        It initializes the control variates either with the custom values provided or to zeros based on the provided
        model architecture.
Args:
initial_control_variates (Parameters | None): These are the initial set of control variates
to use for the scaffold strategy both on the server and client sides. It is optional, but if it is not
provided, the strategy must receive a model that reflects the architecture to be used on the clients.
Defaults to None.
model (nn.Module | None): If provided and initial_control_variates is not, this is used to
set the server control variates and the initial control variates on the client side to all zeros.
If initial_control_variates are provided, they take precedence. Defaults to None.
Returns:
            Parameters: This quantity represents the initial values for the control variates on both the server and
                the client side.
Raises:
ValueError: This error will be raised if neither a model nor initial control variates are provided.
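        Example:
            A minimal sketch of the zero-initialization path, assuming ``strategy`` is an already constructed
            ``Scaffold`` instance and using a hypothetical ``nn.Linear`` model:

            >>> variates = strategy.initialize_control_variates(None, nn.Linear(2, 1))
            >>> [v.shape for v in parameters_to_ndarrays(variates)]
            [(1, 2), (1,)]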
"""
if initial_control_variates is not None:
# If we've been provided with a set of initial control variates, we use those values
self.server_control_variates = parameters_to_ndarrays(initial_control_variates)
return initial_control_variates
elif model is not None:
# If no initial values are provided but a model structure has been given, we initialize the control
# variates to zeros as recommended in the SCAFFOLD paper.
zero_control_variates = [np.zeros_like(val.data) for val in model.parameters() if val.requires_grad]
self.server_control_variates = zero_control_variates
return ndarrays_to_parameters(zero_control_variates)
else:
# Either a model structure or custom initial values for the control variates must be provided to run
# SCAFFOLD
raise ValueError(
"Both initial_control_variates and model are None. One must be defined in order to establish "
"initial values for the control variates."
)
def aggregate_fit(
self,
server_round: int,
results: list[tuple[ClientProxy, FitRes]],
failures: list[tuple[ClientProxy, FitRes] | BaseException],
) -> tuple[Parameters | None, dict[str, Scalar]]:
"""
        Performs server-side aggregation of model weights and control variates associated with the SCAFFOLD method.
Both model weights and control variates are aggregated through UNWEIGHTED averaging consistent with the paper.
The newly aggregated weights and control variates are then repacked and sent back to the clients.
        This function also handles aggregation of training run metrics (e.g., accuracy over the local training)
        through the fit_metrics_aggregation_fn provided when constructing the strategy.
Args:
            server_round (int): What round of FL we're on (from the server's perspective).
results (list[tuple[ClientProxy, FitRes]]): These are the "successful" training run results. By default
these results are the only ones used in aggregation, even if some of the failed clients have partial
results (in the failures list).
failures (list[tuple[ClientProxy, FitRes] | BaseException]): This is the list of clients that
"failed" during the training phase for one reason or another, including timeouts and exceptions.
Returns:
            tuple[Parameters | None, dict[str, Scalar]]: The aggregated weights and metrics dictionary. The
                parameters are optional and will be None in the event that there are no successful clients or
                there were failures and they are not accepted.
"""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}
        # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations
        # in summing the numpy arrays during aggregation, by ensuring that addition occurs in the same order
        # across runs.
decoded_and_sorted_results = [weights for _, weights, _ in decode_and_pseudo_sort_results(results)]
# x = 1 / |S| * sum(x_i) and c = 1 / |S| * sum(delta_c_i)
# Aggregation operation over packed params (includes both weights and control variate updates)
aggregated_params = self.aggregate(decoded_and_sorted_results)
weights, control_variates_update = self.parameter_packer.unpack_parameters(aggregated_params)
self.server_model_weights = self.compute_updated_weights(weights)
self.server_control_variates = self.compute_updated_control_variates(control_variates_update)
parameters = self.parameter_packer.pack_parameters(self.server_model_weights, self.server_control_variates)
# Aggregate custom metrics if aggregation fn was provided
metrics_aggregated = {}
if self.fit_metrics_aggregation_fn:
fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
elif server_round == 1: # Only log this warning once
log(WARNING, "No fit_metrics_aggregation_fn provided")
return ndarrays_to_parameters(parameters), metrics_aggregated
def compute_parameter_delta(self, params_1: NDArrays, params_2: NDArrays) -> NDArrays:
"""
        Computes the element-wise difference of two lists of NDArrays, where elements in params_2 are subtracted
        from elements in params_1.
Args:
params_1 (NDArrays): Parameters to be subtracted from.
params_2 (NDArrays): Parameters to subtract from params_1.
Returns:
NDArrays: Element-wise subtraction result across all numpy arrays.
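        Example:
            A tiny sketch with hypothetical arrays, assuming ``strategy`` is a constructed ``Scaffold`` instance:

            >>> strategy.compute_parameter_delta([np.array([3.0, 5.0])], [np.array([1.0, 2.0])])
            [array([2., 3.])]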
"""
parameter_delta: NDArrays = [param_1 - param_2 for param_1, param_2 in zip(params_1, params_2)]
return parameter_delta
def compute_updated_parameters(
self, scaling_coefficient: float, original_params: NDArrays, parameter_updates: NDArrays
) -> NDArrays:
"""
        Computes updated_params by moving in the direction of parameter_updates with a step proportional to the
        scaling coefficient. Calculates original_params + scaling_coefficient * parameter_updates.
Args:
scaling_coefficient (float): Scaling length for the parameter updates (can be thought of as
"learning rate").
original_params (NDArrays): parameters to be updated.
parameter_updates (NDArrays): update direction to update the original_params.
Returns:
NDArrays: Updated numpy arrays according to original_params + scaling_coefficient * parameter_updates.
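        Example:
            A small sketch with hypothetical values, stepping from original_params in the direction of
            parameter_updates with a scaling_coefficient of 0.1 (assuming ``strategy`` is a constructed instance):

            >>> strategy.compute_updated_parameters(0.1, [np.array([1.0, 1.0])], [np.array([10.0, -10.0])])
            [array([2., 0.])]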
"""
updated_parameters = [
original_param + scaling_coefficient * update
for original_param, update in zip(original_params, parameter_updates)
]
return updated_parameters
def aggregate(self, params: list[NDArrays]) -> NDArrays:
"""
Simple unweighted average to aggregate params, consistent with SCAFFOLD paper. This is "element-wise"
averaging.
Args:
params (list[NDArrays]): numpy arrays whose entries are to be averaged together.
Returns:
NDArrays: element-wise average over the list of numpy arrays.
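        Example:
            A small sketch of the unweighted element-wise average over two hypothetical single-layer clients
            (assuming ``strategy`` is a constructed instance):

            >>> strategy.aggregate([[np.array([1.0, 2.0])], [np.array([3.0, 4.0])]])
            [array([2., 3.])]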
"""
num_clients = len(params)
# Compute average weights of each layer
params_prime: NDArrays = [reduce(np.add, layer_updates) / num_clients for layer_updates in zip(*params)]
return params_prime
def compute_updated_weights(self, weights: NDArrays) -> NDArrays:
"""
        Computes an update to the current self.server_model_weights. This assumes that weights represents the raw
        weights aggregated from the clients. Therefore, they first need to be turned into a "delta" via
        weights - self.server_model_weights. This delta is then applied with a learning rate scalar (set by
        self.learning_rate) as self.server_model_weights + self.learning_rate * (weights - self.server_model_weights).
Args:
weights (NDArrays): The updated weights (aggregated from the clients).
Returns:
NDArrays: self.server_model_weights + self.learning_rate * (weights - self.server_model_weights)
These are the updated server model weights.
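        Example:
            A minimal sketch with hypothetical values, assuming ``strategy.learning_rate`` is 0.5 and
            ``strategy.server_model_weights`` is ``[np.array([0.0, 0.0])]``:

            >>> strategy.compute_updated_weights([np.array([2.0, 4.0])])
            [array([1., 2.])]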
"""
# x_update = y_i - x
delta_weights = self.compute_parameter_delta(weights, self.server_model_weights)
# x = x + lr * x_update
server_model_weights = self.compute_updated_parameters(
self.learning_rate, self.server_model_weights, delta_weights
)
return server_model_weights
def compute_updated_control_variates(self, control_variates_update: NDArrays) -> NDArrays:
"""
Given the aggregated control variates from the clients, this updates the server control variates in line with
        the paper. If c is the server control variates and c_update is the aggregated update from the clients, then
        this update takes the form c = c + (|S| / N) * c_update, where |S| is the number of clients that participated
        in the round and N is the total number of clients. The ratio |S| / N is the proportion given by fraction_fit.
Args:
control_variates_update (NDArrays): Aggregated control variates received from the clients
(uniformly averaged).
Returns:
NDArrays: Updated server control variates according to the formula.
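        Example:
            A minimal sketch with hypothetical values, assuming ``strategy.fraction_fit`` is 0.5 and
            ``strategy.server_control_variates`` is ``[np.array([1.0, 1.0])]``:

            >>> strategy.compute_updated_control_variates([np.array([2.0, -2.0])])
            [array([2., 0.])]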
"""
# c = c + |S| / N * c_update
server_control_variates = self.compute_updated_parameters(
self.fraction_fit, self.server_control_variates, control_variates_update
)
return server_control_variates
class OpacusScaffold(Scaffold):
def __init__(
self,
*,
model: GradSampleModule,
fraction_fit: float = 1.0,
fraction_evaluate: float = 1.0,
min_available_clients: int = 2,
evaluate_fn: (
Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None
) = None,
on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
accept_failures: bool = True,
fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
weighted_eval_losses: bool = True,
learning_rate: float = 1.0,
) -> None:
"""
        A simple extension of the Scaffold strategy that forces the model being federally trained to be a valid
        Opacus GradSampleModule and, thereby, ensures that the associated parameters are aligned with those of
        Opacus-based models used by the InstanceLevelDpClient.
Args:
            model (GradSampleModule): The model architecture to be federally trained. When using this strategy, the
                provided model must be of type Opacus GradSampleModule. This model will then be used to set the
                initial parameters to be used by all clients AND the initial_control_variates.
**NOTE**: The initial_control_variates are all initialized to zero, as recommended in the SCAFFOLD
paper. If one wants a specific type of control variate initialization, this class will need to be
overridden.
fraction_fit (float, optional): Fraction of clients used during training. Defaults to 1.0.
fraction_evaluate (float, optional): Fraction of clients used during validation. Defaults to 1.0.
min_available_clients (int, optional): Minimum number of total clients in the system.
Defaults to 2.
evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None):
Optional function used for central server-side evaluation. Defaults to None.
on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional):
Function used to configure training by providing a configuration dictionary. Defaults to None.
on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional):
Function used to configure client-side validation by providing a Config dictionary.
Defaults to None.
            accept_failures (bool, optional): Whether or not to accept rounds containing failures. Defaults to True.
fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function.
Defaults to None.
evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function.
Defaults to None.
weighted_eval_losses (bool, optional): Determines whether losses during evaluation are linearly weighted
averages or a uniform average. FedAvg default is weighted average of the losses by client dataset
counts. Defaults to True.
learning_rate (float, optional): Learning rate for server side optimization. Defaults to 1.0.
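        Example:
            A minimal construction sketch, wrapping a purely hypothetical ``nn.Linear`` model in an Opacus
            ``GradSampleModule``:

            >>> strategy = OpacusScaffold(model=GradSampleModule(nn.Linear(2, 1)))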
"""
        assert isinstance(model, GradSampleModule), "Provided model must be an Opacus GradSampleModule"
# Setting the initial parameters to correspond with those of the provided model
initial_parameters = get_all_model_parameters(model)
# Initializing the control variates to be uniformly zero using the structure of the provided model.
initial_control_variates = ndarrays_to_parameters(
[np.zeros_like(val.data) for val in model.parameters() if val.requires_grad]
)
super().__init__(
fraction_fit=fraction_fit,
fraction_evaluate=fraction_evaluate,
min_available_clients=min_available_clients,
evaluate_fn=evaluate_fn,
on_fit_config_fn=on_fit_config_fn,
on_evaluate_config_fn=on_evaluate_config_fn,
accept_failures=accept_failures,
initial_parameters=initial_parameters,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
weighted_eval_losses=weighted_eval_losses,
initial_control_variates=initial_control_variates,
model=None,
learning_rate=learning_rate,
)