from collections.abc import Callable
from logging import INFO, WARNING
from flwr.common import (
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
GetPropertiesIns,
MetricsAggregationFn,
NDArrays,
Parameters,
Scalar,
ndarrays_to_parameters,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from opacus import GradSampleModule
from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager
from fl4health.strategies.aggregate_utils import aggregate_losses, aggregate_results
from fl4health.strategies.strategy_with_poll import StrategyWithPolling
from fl4health.utils.functions import decode_and_pseudo_sort_results
from fl4health.utils.parameter_extraction import get_all_model_parameters
[docs]
class BasicFedAvg(FedAvg, StrategyWithPolling):
    """FedAvg strategy whose parameter and evaluation-loss averaging can each be weighted or uniform."""

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: (
            Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None
        ) = None,
        on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        accept_failures: bool = True,
        initial_parameters: Parameters | None = None,
        fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        weighted_aggregation: bool = True,
        weighted_eval_losses: bool = True,
    ) -> None:
        """
        Federated Averaging with Flexible Sampling.

        This implementation extends that of Flower in two ways. The first is that it provides an option for
        unweighted averaging, where Flower only offers weighted averaging based on client sample counts. The second
        is that it allows users to rely on Flower's standard sampling or to use a custom sampling approach
        implemented by a custom client manager.

        FedAvg Paper: https://arxiv.org/abs/1602.05629.

        Args:
            fraction_fit (float, optional): Fraction of clients used during training. In case `min_fit_clients` is
                larger than `fraction_fit * available_clients`, `min_fit_clients` will still be sampled.
                Defaults to 1.0.
            fraction_evaluate (float, optional): Fraction of clients used during validation. In case
                `min_evaluate_clients` is larger than `fraction_evaluate * available_clients`,
                `min_evaluate_clients` will still be sampled. Defaults to 1.0.
            min_fit_clients (int, optional): Minimum number of clients used during training. Defaults to 2.
            min_evaluate_clients (int, optional): Minimum number of clients used during validation. Defaults to 2.
            min_available_clients (int, optional): Minimum number of total clients in the system.
                Defaults to 2.
            evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None):
                Optional function used for central server-side evaluation. Defaults to None.
            on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional):
                Function used to configure training by providing a configuration dictionary. Defaults to None.
            on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional):
                Function used to configure client-side validation by providing a Config dictionary.
                Defaults to None.
            accept_failures (bool, optional): Whether or not to accept rounds containing failures. When False, the
                aggregation methods of this class return nothing for any round that reports failures.
                Defaults to True.
            initial_parameters (Parameters | None, optional): Initial global model parameters. Defaults to None.
            fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Function applied to the
                `(num_examples, metrics)` pairs reported by clients after training. Defaults to None.
            evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Function applied to the
                `(num_examples, metrics)` pairs reported by clients after evaluation. Defaults to None.
            weighted_aggregation (bool, optional): Determines whether parameter aggregation is a linearly weighted
                average or a uniform average. FedAvg default is weighted average by client dataset counts.
                Defaults to True.
            weighted_eval_losses (bool, optional): Determines whether losses during evaluation are linearly weighted
                averages or a uniform average. FedAvg default is weighted average of the losses by client dataset
                counts. Defaults to True.
        """
        # Everything Flower's FedAvg already understands is forwarded unchanged.
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        # The two switches added by this strategy; read by aggregate_fit and aggregate_evaluate respectively.
        self.weighted_aggregation = weighted_aggregation
        self.weighted_eval_losses = weighted_eval_losses

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[tuple[ClientProxy, FitRes] | BaseException],
    ) -> tuple[Parameters | None, dict[str, Scalar]]:
        """
        Aggregate the results from the federated fit round.

        This is done with either weighted or unweighted FedAvg, depending on the `weighted_aggregation` setting of
        the strategy.

        Args:
            server_round (int): Indicates the server round we're currently on. Only used here to make sure the
                "missing metrics aggregation function" warning is logged a single time (in round 1).
            results (list[tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local
                training that need to be aggregated on the server-side.
            failures (list[tuple[ClientProxy, FitRes] | BaseException]): These are the results and exceptions
                from clients that experienced an issue during training, such as timeouts or exceptions.

        Returns:
            tuple[Parameters | None, dict[str, Scalar]]: The aggregated model weights and the metrics dictionary.
            `(None, {})` is returned when there are no results, or when failures occurred and the strategy does
            not accept failures.
        """
        # Bail out when there is nothing to aggregate, or when failures occurred and they are not tolerated.
        if not results or (failures and not self.accept_failures):
            return None, {}

        # Decode the client payloads in a pseudo-sorted order (by elements and sample counts) so the numpy
        # additions performed during aggregation always happen in the same order, reducing numerical fluctuation.
        weights_with_counts = [
            (client_weights, client_sample_count)
            for _, client_weights, client_sample_count in decode_and_pseudo_sort_results(results)
        ]

        # Weighted or uniform average (per the strategy setting), converted straight back to Parameters.
        parameters_aggregated = ndarrays_to_parameters(
            aggregate_results(weights_with_counts, self.weighted_aggregation)
        )

        # Custom metrics are only aggregated if the user supplied an aggregation function.
        metrics_fn = self.fit_metrics_aggregation_fn
        if metrics_fn:
            client_fit_metrics = [(fit_res.num_examples, fit_res.metrics) for _, fit_res in results]
            return parameters_aggregated, metrics_fn(client_fit_metrics)

        if server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")
        return parameters_aggregated, {}

    def aggregate_evaluate(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, EvaluateRes]],
        failures: list[tuple[ClientProxy, EvaluateRes] | BaseException],
    ) -> tuple[float | None, dict[str, Scalar]]:
        """
        Aggregate the metrics and losses returned from the clients as a result of the evaluation round.

        Args:
            server_round (int): Indicates the server round we're currently on. Only used here to make sure the
                "missing metrics aggregation function" warning is logged a single time (in round 1).
            results (list[tuple[ClientProxy, EvaluateRes]]): The client identifiers and the results of their local
                evaluation that need to be aggregated on the server-side. These results are loss values and the
                metrics dictionary.
            failures (list[tuple[ClientProxy, EvaluateRes] | BaseException]): These are the results and
                exceptions from clients that experienced an issue during evaluation, such as timeouts or exceptions.

        Returns:
            tuple[float | None, dict[str, Scalar]]: Aggregated loss values and the aggregated metrics. The metrics
            are aggregated according to `evaluate_metrics_aggregation_fn`. `(None, {})` is returned when there are
            no results, or when failures occurred and the strategy does not accept failures.
        """
        # Bail out when there is nothing to aggregate, or when failures occurred and they are not tolerated.
        if not results or (failures and not self.accept_failures):
            return None, {}

        # Pair each client's sample count with its loss, then average (weighted or uniform per the setting).
        counts_and_losses = [(eval_res.num_examples, eval_res.loss) for _, eval_res in results]
        loss_aggregated = aggregate_losses(counts_and_losses, self.weighted_eval_losses)

        # Custom metrics are only aggregated if the user supplied an aggregation function.
        metrics_fn = self.evaluate_metrics_aggregation_fn
        if metrics_fn:
            client_eval_metrics = [(eval_res.num_examples, eval_res.metrics) for _, eval_res in results]
            return loss_aggregated, metrics_fn(client_eval_metrics)

        if server_round == 1:  # Only log this warning once
            log(WARNING, "No evaluate_metrics_aggregation_fn provided")
        return loss_aggregated, {}
[docs]
class OpacusBasicFedAvg(BasicFedAvg):
    """BasicFedAvg variant that takes its initial global parameters from an Opacus GradSampleModule."""

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        *,
        model: GradSampleModule,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: (
            Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None
        ) = None,
        on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        accept_failures: bool = True,
        fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        weighted_aggregation: bool = True,
        weighted_eval_losses: bool = True,
    ) -> None:
        """
        This strategy is a simple extension of the BasicFedAvg strategy to force the model being federally trained to
        be a valid Opacus GradSampleModule and, thereby, ensure that the associated parameters are aligned with
        those of Opacus based models used by the InstanceLevelDpClient.

        Args:
            model (GradSampleModule): The model architecture to be federally trained. When using this strategy,
                the model must be of type Opacus GradSampleModule. This model will then be used to set
                `initial_parameters`, the initial parameters to be used by all clients.
            fraction_fit (float, optional): Fraction of clients used during training. In case `min_fit_clients` is
                larger than `fraction_fit * available_clients`, `min_fit_clients` will still be sampled.
                Defaults to 1.0.
            fraction_evaluate (float, optional): Fraction of clients used during validation. In case
                `min_evaluate_clients` is larger than `fraction_evaluate * available_clients`,
                `min_evaluate_clients` will still be sampled. Defaults to 1.0.
            min_fit_clients (int, optional): Minimum number of clients used during training. Defaults to 2.
            min_evaluate_clients (int, optional): Minimum number of clients used during validation. Defaults to 2.
            min_available_clients (int, optional): Minimum number of total clients in the system.
                Defaults to 2.
            evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None):
                Optional function used for central server-side evaluation. Defaults to None.
            on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional):
                Function used to configure training by providing a configuration dictionary. Defaults to None.
            on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional):
                Function used to configure client-side validation by providing a Config dictionary.
                Defaults to None.
            accept_failures (bool, optional): Whether or not to accept rounds containing failures.
                Defaults to True.
            fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function.
                Defaults to None.
            evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function.
                Defaults to None.
            weighted_aggregation (bool, optional): Determines whether parameter aggregation is a linearly weighted
                average or a uniform average. FedAvg default is weighted average by client dataset counts.
                Defaults to True.
            weighted_eval_losses (bool, optional): Determines whether losses during evaluation are linearly weighted
                averages or a uniform average. FedAvg default is weighted average of the losses by client dataset
                counts. Defaults to True.

        Raises:
            AssertionError: If `model` is not an Opacus GradSampleModule.
        """
        # Validate up front with an explicit check instead of a bare ``assert``: assert statements are removed when
        # Python runs with -O, which would let a non-Opacus model through unnoticed. AssertionError is raised
        # deliberately so callers that relied on the previous assert keep seeing the same exception type.
        if not isinstance(model, GradSampleModule):
            raise AssertionError("Provided model must be Opacus type GradSampleModule")

        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
            weighted_aggregation=weighted_aggregation,
            weighted_eval_losses=weighted_eval_losses,
        )
        # Setting the initial parameters to correspond with those of the provided model
        self.initial_parameters = get_all_model_parameters(model)