fl4health.strategies.feddg_ga module
- class FairnessMetric(metric_type, metric_name=None, signal=None)
Bases: object
- __init__(metric_type, metric_name=None, signal=None)
Defines a fairness metric whose name and signal default to those of its type but can be overridden if needed.
Instantiates a fairness metric from a type, optionally overriding the metric name and signal.
- Parameters:
metric_type (FairnessMetricType) – the fairness metric type. If CUSTOM, the metric_name and signal must be provided.
metric_name (str | None, optional) – the name of the metric to be used as the fairness metric. Mandatory if metric_type is CUSTOM. Defaults to None.
signal (float | None, optional) – the signal of the fairness metric. Mandatory if metric_type is CUSTOM. Defaults to None.
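To illustrate the override behavior described above, here is a minimal, self-contained sketch. It does not import fl4health: `MetricTypeSketch`, `DEFAULT_SIGNALS` and `FairnessMetricSketch` are hypothetical stand-ins that assume a non-CUSTOM metric falls back to its enum value as the name and to the documented default signal, while a CUSTOM metric must supply both.

```python
from enum import Enum


class MetricTypeSketch(Enum):
    # Enum values copied from FairnessMetricType as documented on this page.
    ACCURACY = "val - prediction - accuracy"
    CUSTOM = "custom"
    LOSS = "val - checkpoint"


# Default signals, taken from the documented return values of signal_for_type.
DEFAULT_SIGNALS = {MetricTypeSketch.ACCURACY: -1.0, MetricTypeSketch.LOSS: 1.0}


class FairnessMetricSketch:
    """Hypothetical stand-in for FairnessMetric: defaults come from the type unless overridden."""

    def __init__(self, metric_type, metric_name=None, signal=None):
        self.metric_type = metric_type
        if metric_type is MetricTypeSketch.CUSTOM:
            # A custom metric has no defaults, so both fields must be supplied.
            if metric_name is None or signal is None:
                raise ValueError("CUSTOM fairness metrics require metric_name and signal")
            self.metric_name = metric_name
            self.signal = signal
        else:
            # Fall back to the enum value and the default signal unless overridden.
            self.metric_name = metric_type.value if metric_name is None else metric_name
            self.signal = DEFAULT_SIGNALS[metric_type] if signal is None else signal


loss_metric = FairnessMetricSketch(MetricTypeSketch.LOSS)
print(loss_metric.metric_name, loss_metric.signal)  # prints: val - checkpoint 1.0

# "val - f1" is a hypothetical metric name that the clients would have to report.
f1_metric = FairnessMetricSketch(MetricTypeSketch.CUSTOM, metric_name="val - f1", signal=-1.0)
```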
- class FairnessMetricType(value)
Bases: Enum
Defines the basic types for fairness metrics, their default names and their default signals.
- ACCURACY = 'val - prediction - accuracy'
- CUSTOM = 'custom'
- LOSS = 'val - checkpoint'
- classmethod signal_for_type(fairness_metric_type)
Return the default signal for the given metric type.
- Parameters:
fairness_metric_type (FairnessMetricType) – the fairness metric type.
- Raises:
SignalForTypeException – if the type is CUSTOM, as the signal has to be defined by the user.
- Returns:
-1.0 if FairnessMetricType.ACCURACY or 1.0 if FairnessMetricType.LOSS.
- Return type:
float
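The contract of signal_for_type can be mirrored in a few lines. The sketch below uses a plain function rather than a classmethod, and `SignalForTypeError` and `oriented_gap` are hypothetical. The orientation step is an assumption about how the signal is used: multiplying a global-minus-local difference by the signal makes a positive value mean "the global model does worse" whether the metric is a loss (lower is better) or an accuracy (higher is better).

```python
from enum import Enum


class MetricTypeSketch(Enum):
    # Enum values copied from FairnessMetricType as documented on this page.
    ACCURACY = "val - prediction - accuracy"
    CUSTOM = "custom"
    LOSS = "val - checkpoint"


class SignalForTypeError(Exception):
    """Hypothetical stand-in for the documented SignalForTypeException."""


def signal_for_type(metric_type: MetricTypeSketch) -> float:
    # Documented contract: -1.0 for accuracy, 1.0 for loss, error for CUSTOM.
    if metric_type is MetricTypeSketch.ACCURACY:
        return -1.0
    if metric_type is MetricTypeSketch.LOSS:
        return 1.0
    raise SignalForTypeError("the signal of a CUSTOM metric must be defined by the user")


def oriented_gap(global_value: float, local_value: float, signal: float) -> float:
    # Assumption: after multiplying by the signal, a positive gap means the
    # global model performs worse than the local model on this metric.
    return (global_value - local_value) * signal


# Loss rose from 0.5 (local) to 0.7 (global): positive gap.
loss_gap = oriented_gap(0.7, 0.5, signal_for_type(MetricTypeSketch.LOSS))
# Accuracy fell from 0.9 (local) to 0.8 (global): also a positive gap.
accuracy_gap = oriented_gap(0.8, 0.9, signal_for_type(MetricTypeSketch.ACCURACY))
```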
- class FedDgGa(*, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=2, evaluate_fn=None, on_fit_config_fn=None, on_evaluate_config_fn=None, accept_failures=True, initial_parameters=None, fit_metrics_aggregation_fn=None, evaluate_metrics_aggregation_fn=None, fairness_metric=None, adjustment_weight_step_size=0.2)
Bases: FedAvg
- __init__(*, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=2, evaluate_fn=None, on_fit_config_fn=None, on_evaluate_config_fn=None, accept_failures=True, initial_parameters=None, fit_metrics_aggregation_fn=None, evaluate_metrics_aggregation_fn=None, fairness_metric=None, adjustment_weight_step_size=0.2)
Strategy for the FedDG-GA algorithm (Federated Domain Generalization with Generalization Adjustment, Zhang et al. 2023). This strategy assumes (and checks) that the configuration sent by the server to the clients contains the key “evaluate_after_fit” set to True. It also ensures that the key “pack_losses_with_val_metrics” is present and set to True. Both keys facilitate the exchange of the evaluation information that the strategy needs to work correctly.
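A configuration function satisfying these requirements might look like the sketch below. The `Scalar` alias mirrors the shape of Flower's config value type; `fit_config`, `check_feddg_ga_config` and `N_SERVER_ROUNDS` are hypothetical, and the n_server_rounds key is included because configure_fit reads it. The comments on the two boolean keys reflect what their names suggest, not a verified description of client behavior.

```python
from typing import Union

# Same shape as Flower's Scalar type alias for config values.
Scalar = Union[bool, bytes, float, int, str]

N_SERVER_ROUNDS = 10  # hypothetical length of the federated run


def fit_config(server_round: int) -> dict[str, Scalar]:
    """Hypothetical on_fit_config_fn for a FedDG-GA run."""
    return {
        "n_server_rounds": N_SERVER_ROUNDS,  # read by configure_fit
        "evaluate_after_fit": True,  # clients evaluate after local training
        "pack_losses_with_val_metrics": True,  # losses are sent along with validation metrics
    }


def check_feddg_ga_config(config: dict[str, Scalar]) -> None:
    # Sketch of the documented requirement; the real strategy's checks may differ.
    for key in ("evaluate_after_fit", "pack_losses_with_val_metrics"):
        if config.get(key) is not True:
            raise ValueError(f"FedDG-GA requires config[{key!r}] to be True")


check_feddg_ga_config(fit_config(server_round=1))  # passes silently
```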
NOTE: For FedDG-GA, we require that fraction_fit and fraction_evaluate are 1.0, as the behavior of the FedDG-GA algorithm is not well-defined when participation in each round of training and evaluation is partial. Thus, we force these values to be 1.0 in the call to super() and do not allow them to be set by the user.
- Parameters:
min_fit_clients (int, optional) – Minimum number of clients used during training. Defaults to 2.
min_evaluate_clients (int, optional) – Minimum number of clients used during validation. Defaults to 2.
min_available_clients (int, optional) – Minimum number of total clients in the system. Defaults to 2.
evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None) – Optional function used for validation. Defaults to None.
on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional) – Function used to configure training. Must be specified for this strategy. Defaults to None.
on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional) – Function used to configure validation. Must be specified for this strategy. Defaults to None.
accept_failures (bool, optional) – Whether or not to accept rounds containing failures. Defaults to True.
initial_parameters (Parameters | None, optional) – Initial global model parameters. Defaults to None.
fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional) – Function used to aggregate the fit metrics. Defaults to None.
evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional) – Function used to aggregate the evaluation metrics. Defaults to None.
fairness_metric (FairnessMetric | None, optional) – The metric used to evaluate the local model of each client against the global model in order to determine its adjustment weight for aggregation. Can be set to any default metric in FairnessMetricType or to a custom metric. Defaults to None.
adjustment_weight_step_size (float, optional) – The step size that determines the magnitude of change of the generalization adjustment weights. It must satisfy 0 < adjustment_weight_step_size < 1. Defaults to 0.2.
- aggregate_evaluate(server_round, results, failures)
Aggregate evaluation losses using a weighted average.
Collects the evaluation metrics and updates the adjustment weights, which will be used when aggregating the results for the next round.
- Parameters:
- Returns:
A tuple containing the aggregated evaluation loss and the aggregated evaluation metrics.
- Return type:
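The adjustment-weight update performed in aggregate_evaluate can be sketched as follows, following the generalization-adjustment idea of FedDG-GA as we understand it: center each client's generalization gap on the mean, scale it by the largest absolute deviation, and move the weight by the current step size. Clipping to [0, 1] and renormalization are assumptions of this sketch, as are the function name and the dict-based inputs; the weighted averaging of the losses is not shown.

```python
def update_adjustment_weights(
    weights: dict[str, float], gaps: dict[str, float], step_size: float
) -> dict[str, float]:
    """Hypothetical generalization-adjustment step.

    weights: current adjustment weight per client id (assumed to sum to 1).
    gaps: oriented generalization gap per client id; larger means the global
        model does comparatively worse on that client's validation data.
    """
    cids = sorted(weights)
    mean_gap = sum(gaps[cid] for cid in cids) / len(cids)
    centered = {cid: gaps[cid] - mean_gap for cid in cids}
    max_abs = max(abs(value) for value in centered.values())
    if max_abs == 0.0:
        # Every client has the same gap, so there is nothing to adjust.
        return dict(weights)
    # Above-average gaps gain weight and below-average gaps lose weight;
    # each weight is clipped to [0, 1] (an assumption of this sketch).
    updated = {
        cid: min(max(weights[cid] + step_size * centered[cid] / max_abs, 0.0), 1.0) for cid in cids
    }
    # Renormalize so the weights can be used directly as aggregation coefficients.
    total = sum(updated.values())
    return {cid: value / total for cid, value in updated.items()}


weights = {"client_a": 1 / 3, "client_b": 1 / 3, "client_c": 1 / 3}
gaps = {"client_a": 0.3, "client_b": 0.1, "client_c": 0.2}
new_weights = update_adjustment_weights(weights, gaps, step_size=0.2)
# client_a (largest gap) ends up with the largest weight, client_b with the smallest.
```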
- aggregate_fit(server_round, results, failures)
Aggregate fit results by weighting them with the adjustment weights and then summing them.
Collects the fit metrics that will be used to change the adjustment weights for the next round.
- Parameters:
- Returns:
A tuple containing the aggregated parameters and the aggregated fit metrics.
- Return type:
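The weight-then-sum aggregation can be sketched with plain Python lists standing in for the NumPy arrays inside Flower's Parameters. The function name is hypothetical. Unlike a FedAvg-style average, the sketch does not divide by client sample counts; it assumes the adjustment weights are already normalized to sum to 1, so the sum is a convex combination of the client models.

```python
def weighted_sum_of_parameters(
    client_params: dict[str, list[list[float]]], adjustment_weights: dict[str, float]
) -> list[list[float]]:
    """Hypothetical sketch: sum each client's layers scaled by its adjustment weight."""
    cids = sorted(client_params)
    # Start from zeros shaped like the first client's layers; all clients must match.
    aggregated = [[0.0] * len(layer) for layer in client_params[cids[0]]]
    for cid in cids:
        weight = adjustment_weights[cid]
        for layer_idx, layer in enumerate(client_params[cid]):
            for value_idx, value in enumerate(layer):
                aggregated[layer_idx][value_idx] += weight * value
    return aggregated


params = {"client_a": [[1.0, 2.0], [3.0]], "client_b": [[3.0, 4.0], [5.0]]}
print(weighted_sum_of_parameters(params, {"client_a": 0.25, "client_b": 0.75}))
# prints: [[2.5, 3.5], [4.5]]
```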
- configure_evaluate(server_round, parameters, client_manager)
Configure the next round of evaluation.
- configure_fit(server_round, parameters, client_manager)
Configure the next round of training. Also collects the number of rounds the training will run for, in order to calculate the adjustment weight step size. Fails if n_server_rounds is not set in the config or if it is not an integer.
- Parameters:
server_round (int) – the current server round.
parameters (Parameters) – the model parameters.
client_manager (ClientManager) – The client manager which holds all currently connected clients. It must be an instance of FixedSamplingClientManager.
- Returns:
the input for the clients’ fit function.
- Return type:
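The documented failure modes of configure_fit amount to a check like the following hypothetical helper. Rejecting bool explicitly is a choice of this sketch (Python treats True as an int); the real strategy may handle that case differently or raise a different exception type.

```python
def read_n_server_rounds(config: dict) -> int:
    """Hypothetical helper mirroring the documented checks on n_server_rounds."""
    if "n_server_rounds" not in config:
        raise ValueError("n_server_rounds must be set in the fit config")
    n_rounds = config["n_server_rounds"]
    # bool is a subclass of int in Python, so reject it explicitly.
    if not isinstance(n_rounds, int) or isinstance(n_rounds, bool):
        raise ValueError("n_server_rounds must be an integer")
    return n_rounds


n_rounds = read_n_server_rounds({"n_server_rounds": 10, "evaluate_after_fit": True})
```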
- get_current_weight_step_size(server_round)
Calculates the current weight step size based on the current server round, the configured adjustment weight step size and the total number of rounds.
- Returns (float):
the current value for the weight step size.
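One way to derive a decaying step size from the three inputs named above is a linear schedule, sketched here under that assumption: the full adjustment_weight_step_size in round 1, shrinking by adjustment_weight_step_size / n_server_rounds in each later round. The function name is hypothetical and the exact schedule used by fl4health may differ.

```python
def current_weight_step_size(server_round: int, step_size: float, n_server_rounds: int) -> float:
    """Hypothetical linear decay of the adjustment weight step size."""
    if not 1 <= server_round <= n_server_rounds:
        raise ValueError("server_round must be in [1, n_server_rounds]")
    # Assumed schedule: full step in round 1, minus a constant decrement per round.
    decay_per_round = step_size / n_server_rounds
    return step_size - (server_round - 1) * decay_per_round


schedule = [current_weight_step_size(r, 0.2, 4) for r in range(1, 5)]
# Roughly 0.2, 0.15, 0.1, 0.05 (up to floating-point rounding).
```

Presumably, a decaying step lets the adjustment weights move quickly early in training and settle as the run approaches its final round.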