fl4health.strategies.feddg_ga module

class FairnessMetric(metric_type, metric_name=None, signal=None)[source]

Bases: object

Defines a fairness metric with attributes that can be overridden if needed.

__init__(metric_type, metric_name=None, signal=None)[source]
Instantiates a fairness metric with a type and, optionally, a metric name and signal that override the defaults for that type.

Parameters:
  • metric_type (FairnessMetricType) – The fairness metric type. If CUSTOM, both metric_name and signal must be provided.

  • metric_name (Optional[str]) – The name of the metric to be used as the fairness metric. Optional, defaults to metric_type.value. Mandatory if metric_type is CUSTOM.

  • signal (Optional[float]) – The signal of the fairness metric. Optional, defaults to FairnessMetricType.signal_for_type(metric_type). Mandatory if metric_type is CUSTOM.
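The defaulting rules above can be illustrated with a small standalone sketch. It does not import fl4health: `resolve_fairness_metric`, `DEFAULT_SIGNALS` and `CUSTOM` are hypothetical names, the metric type is represented by its documented string value, and raising `ValueError` for an incomplete CUSTOM metric is an assumption about how the "mandatory" requirement might be enforced.

```python
from typing import Optional

# Default signals keyed by the documented FairnessMetricType values.
# Stand-in table for illustration only; the library resolves defaults
# through FairnessMetricType.signal_for_type.
DEFAULT_SIGNALS = {
    "val - prediction - accuracy": -1.0,  # ACCURACY
    "val - checkpoint": 1.0,  # LOSS
}
CUSTOM = "custom"


def resolve_fairness_metric(
    metric_type: str,
    metric_name: Optional[str] = None,
    signal: Optional[float] = None,
) -> tuple[str, float]:
    """Return the (metric_name, signal) pair that would be in effect."""
    if metric_type == CUSTOM:
        # Assumption: an incomplete CUSTOM metric is rejected outright.
        if metric_name is None or signal is None:
            raise ValueError("CUSTOM requires both metric_name and signal")
        return metric_name, signal
    # Non-custom types fall back to the type's value and default signal.
    name = metric_name if metric_name is not None else metric_type
    sig = signal if signal is not None else DEFAULT_SIGNALS[metric_type]
    return name, sig
```

With these rules, a LOSS metric created without overrides resolves to the name 'val - checkpoint' and the signal 1.0, while a CUSTOM metric must carry both values explicitly.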

class FairnessMetricType(value)[source]

Bases: Enum

Defines the basic types of fairness metrics, their default names and their default signals.

ACCURACY = 'val - prediction - accuracy'
CUSTOM = 'custom'
LOSS = 'val - checkpoint'
classmethod signal_for_type(fairness_metric_type)[source]

Return the default signal for the given metric type.

Parameters:

fairness_metric_type (FairnessMetricType) – The fairness metric type.

Return type:

float

Returns: (float) -1.0 if FairnessMetricType.ACCURACY or 1.0 if FairnessMetricType.LOSS.

Raises: (SignalForTypeException) if the type is CUSTOM, because the signal must then be defined by the user.
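The documented behaviour of signal_for_type can be mirrored in a few lines. The classes below are standalone stand-ins that reuse the documented names and values; they are a sketch of the described contract, not the library's source, and the error message is made up for illustration.

```python
from enum import Enum


class SignalForTypeException(Exception):
    """Stand-in for the exception documented in this module."""


class FairnessMetricType(Enum):
    ACCURACY = "val - prediction - accuracy"
    LOSS = "val - checkpoint"
    CUSTOM = "custom"

    @classmethod
    def signal_for_type(cls, fairness_metric_type: "FairnessMetricType") -> float:
        # Documented defaults: -1.0 for ACCURACY, 1.0 for LOSS.
        if fairness_metric_type is cls.ACCURACY:
            return -1.0
        if fairness_metric_type is cls.LOSS:
            return 1.0
        # CUSTOM has no default; the caller has to supply a signal.
        raise SignalForTypeException("CUSTOM metrics must define their own signal")
```

The opposite signs presumably let a higher-is-better metric (accuracy) and a lower-is-better metric (loss) be handled by the same weight update; that reading is an inference from the default values, not something this page states.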

class FedDgGa(*, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=2, evaluate_fn=None, on_fit_config_fn=None, on_evaluate_config_fn=None, accept_failures=True, initial_parameters=None, fit_metrics_aggregation_fn=None, evaluate_metrics_aggregation_fn=None, fairness_metric=None, adjustment_weight_step_size=0.2)[source]

Bases: FedAvg

__init__(*, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=2, evaluate_fn=None, on_fit_config_fn=None, on_evaluate_config_fn=None, accept_failures=True, initial_parameters=None, fit_metrics_aggregation_fn=None, evaluate_metrics_aggregation_fn=None, fairness_metric=None, adjustment_weight_step_size=0.2)[source]

Strategy for the FedDG-GA algorithm (Federated Domain Generalization with Generalization Adjustment, Zhang et al. 2023). This strategy assumes (and checks) that the configuration sent by the server to the clients contains the key “evaluate_after_fit” set to True. It also ensures that the key “pack_losses_with_val_metrics” is present and set to True. Both settings facilitate the exchange of the evaluation information the strategy needs to work correctly.

Parameters:
  • min_fit_clients (int) – Minimum number of clients used during training. Optional, defaults to 2.

  • min_evaluate_clients (int) – Minimum number of clients used during validation. Optional, defaults to 2.

  • min_available_clients (int) – Minimum number of total clients in the system. Optional, defaults to 2.

  • evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None) – Optional function used for validation. Defaults to None.

  • on_fit_config_fn (Optional[Callable[[int], dict[str, Union[bool, bytes, float, int, str]]]]) – Function used to configure training. Although the default is None, it must be specified for this strategy.

  • on_evaluate_config_fn (Optional[Callable[[int], dict[str, Union[bool, bytes, float, int, str]]]]) – Function used to configure validation. Although the default is None, it must be specified for this strategy.

  • accept_failures (bool) – Whether or not to accept rounds containing failures. Optional, defaults to True.

  • initial_parameters (Optional[Parameters]) – Initial global model parameters. Optional, defaults to None.

  • fit_metrics_aggregation_fn (Optional[Callable[[List[Tuple[int, Dict[str, Union[bool, bytes, float, int, str]]]]], Dict[str, Union[bool, bytes, float, int, str]]]]) – Function used to aggregate the fit metrics. Optional, defaults to None.

  • evaluate_metrics_aggregation_fn (Optional[Callable[[List[Tuple[int, Dict[str, Union[bool, bytes, float, int, str]]]]], Dict[str, Union[bool, bytes, float, int, str]]]]) – Function used to aggregate the evaluation metrics. Optional, defaults to None.

  • fairness_metric (Optional[FairnessMetric]) – The metric used to evaluate the local model of each client against the global model in order to determine its adjustment weight for aggregation. Can be any default metric in FairnessMetricType or a custom metric. Optional, defaults to FairnessMetric(FairnessMetricType.LOSS).

  • adjustment_weight_step_size (float) – The step size that determines the magnitude of change of the generalization adjustment weights. Must satisfy 0 < adjustment_weight_step_size < 1. Optional, defaults to 0.2.
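A configuration function that satisfies the documented checks might look like the sketch below. The function names and the value of `N_SERVER_ROUNDS` are hypothetical; the three keys are the ones this page says the strategy looks for (evaluate_after_fit, pack_losses_with_val_metrics and n_server_rounds). The step-size helper only restates the documented range constraint.

```python
from typing import Union

Scalar = Union[bool, bytes, float, int, str]

N_SERVER_ROUNDS = 10  # hypothetical total number of federated rounds


def fit_config(server_round: int) -> dict[str, Scalar]:
    """Config sent to clients; identical for every round in this sketch."""
    return {
        # Both flags must be present and True for the strategy's checks.
        "evaluate_after_fit": True,
        "pack_losses_with_val_metrics": True,
        # Must be an integer; read when configuring the fit round.
        "n_server_rounds": N_SERVER_ROUNDS,
    }


def is_valid_step_size(step_size: float) -> bool:
    """Documented constraint: 0 < adjustment_weight_step_size < 1."""
    return 0.0 < step_size < 1.0
```

A function of this shape would presumably be passed as on_fit_config_fn (and a similar one as on_evaluate_config_fn) when constructing the strategy.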

aggregate_evaluate(server_round, results, failures)[source]

Aggregate evaluation losses using weighted average.

Collects the evaluation metrics and updates the adjustment weights, which will be used when aggregating the results for the next round.

Parameters:
  • server_round (int) – The current server round.

  • results (list[tuple[ClientProxy, EvaluateRes]]) – The clients’ evaluate results.

  • failures (list[tuple[ClientProxy, EvaluateRes] | BaseException]) – The clients’ evaluate failures.

Return type:

tuple[float | None, dict[str, Union[bool, bytes, float, int, str]]]

Returns:

(tuple[float | None, dict[str, Scalar]]) A tuple containing the aggregated evaluation loss and the aggregated evaluation metrics.
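As a rough illustration of the weighted average mentioned above, the sketch below weights each client's loss by its number of evaluation examples. That weighting is an assumption based on how FedAvg-style strategies usually aggregate losses; the function is a standalone stand-in, not the method itself.

```python
def weighted_loss_avg(results: list[tuple[int, float]]) -> float:
    """Average losses weighted by example counts.

    Each entry is (num_examples, loss) for one client.
    """
    total_examples = sum(num_examples for num_examples, _ in results)
    weighted_sum = sum(num_examples * loss for num_examples, loss in results)
    return weighted_sum / total_examples
```

For two clients with 10 and 30 examples and losses 1.0 and 2.0, the result is (10 * 1.0 + 30 * 2.0) / 40 = 1.75.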

aggregate_fit(server_round, results, failures)[source]

Aggregate fit results by weighing them against the adjustment weights and then summing them.

Collects the fit metrics that will be used to change the adjustment weights for the next round.

Parameters:
  • server_round (int) – The current server round.

  • results (list[tuple[ClientProxy, FitRes]]) – The clients’ fit results.

  • failures (list[tuple[ClientProxy, FitRes] | BaseException]) – The clients’ fit failures.

Return type:

tuple[Parameters | None, dict[str, Union[bool, bytes, float, int, str]]]

Returns:

(tuple[Parameters | None, dict[str, Scalar]]) A tuple containing the aggregated parameters and the aggregated fit metrics.

configure_evaluate(server_round, parameters, client_manager)[source]

Configure the next round of evaluation.

Return type:

list[tuple[ClientProxy, EvaluateIns]]

configure_fit(server_round, parameters, client_manager)[source]

Configure the next round of training.

Also collects the total number of training rounds, which is needed to calculate the adjustment weight step size. Fails if n_server_rounds is not set in the config or is not an integer.

Parameters:
  • server_round (int) – The current server round.

  • parameters (Parameters) – The model parameters.

  • client_manager (ClientManager) – The client manager, which holds all currently connected clients. It must be an instance of FixedSamplingClientManager.

Return type:

list[tuple[ClientProxy, FitIns]]

Returns:

(list[tuple[ClientProxy, FitIns]]) the input for the clients’ fit function.

get_current_weight_step_size(server_round)[source]

Calculates the weight step size for the current round from the current server round, the initial weight step size (adjustment_weight_step_size) and the total number of rounds.

Parameters:

server_round (int) – The current server round.

Return type:

float

Returns: (float) the current value for the weight step size.
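One plausible schedule is a linear decay of the step size over the rounds, in the spirit of the FedDG-GA paper. The sketch below assumes rounds are numbered from 1 and that the step size shrinks by an equal amount each round; both the function name and the exact schedule are assumptions, since this page does not spell out the formula.

```python
def current_weight_step_size(
    initial_step_size: float, server_round: int, num_rounds: int
) -> float:
    """Linearly decayed step size (assumed schedule, rounds start at 1)."""
    decay_per_round = initial_step_size / num_rounds
    # Round 1 uses the full step size; each later round subtracts one decay.
    return initial_step_size - (server_round - 1) * decay_per_round
```

Under this assumption, an initial step size of 0.2 over 10 rounds is still 0.2 in round 1 and has halved by round 6.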

update_weights_by_ga(server_round, cids)[source]

Update the self.adjustment_weights dictionary by calculating the new weights based on the current server round, fit and evaluation metrics.

Parameters:
  • server_round (int) – (int) the current server round.

  • cids (list[str]) – (list[str]) the list of client ids that participated in this round.

Return type:

None
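The generalization adjustment step can be sketched as follows, loosely following the FedDG-GA paper: each client's gap between the global model's metric and its local model's metric is centred on the mean gap, scaled by the current step size, added to the client's weight, and the weights are then renormalised. This is a pure function over plain dictionaries with hypothetical names; the clipping to [0, 1] and the handling of identical gaps are assumptions, and the actual method updates self.adjustment_weights in place.

```python
def ga_update(
    weights: dict[str, float],
    global_metrics: dict[str, float],
    local_metrics: dict[str, float],
    signal: float,
    step_size: float,
) -> dict[str, float]:
    """Return new adjustment weights after one generalization adjustment."""
    cids = list(weights)
    # Generalization gap per client, oriented by the metric's signal.
    gaps = [signal * (global_metrics[cid] - local_metrics[cid]) for cid in cids]
    mean_gap = sum(gaps) / len(gaps)
    centered = [gap - mean_gap for gap in gaps]
    max_dev = max(abs(dev) for dev in centered)

    updated = {}
    for cid, dev in zip(cids, centered):
        # No adjustment when every client has the same gap.
        delta = 0.0 if max_dev == 0 else step_size * dev / max_dev
        # Assumption: keep each weight inside [0, 1] before renormalising.
        updated[cid] = min(max(weights[cid] + delta, 0.0), 1.0)

    total = sum(updated.values())
    return {cid: weight / total for cid, weight in updated.items()}
```

Clients whose gap is above average gain weight in the next aggregation, and clients below average lose weight, by at most step_size before renormalisation.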

weight_and_aggregate_results(results)[source]

Aggregate results by weighing them against the adjustment weights and then summing them.

Parameters:

results (list[tuple[ClientProxy, FitRes]]) – (list[tuple[ClientProxy, FitRes]]) The clients’ fit results.

Return type:

List[ndarray[Any, dtype[Any]]]

Returns:

(NDArrays) the weighted and aggregated results.
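The weighting-and-summing step can be illustrated with plain Python lists standing in for NumPy arrays. In this sketch each client contributes a list of layers, every layer is a flat list of floats, and `adjustment_weights` maps client ids to weights; these names and the data layout are assumptions made to keep the example dependency-free.

```python
def weight_and_aggregate(
    client_layers: dict[str, list[list[float]]],
    adjustment_weights: dict[str, float],
) -> list[list[float]]:
    """Scale each client's layers by its weight and sum them layer-wise."""
    if not client_layers:
        raise ValueError("at least one client result is required")

    aggregated: list[list[float]] = []
    for cid, layers in client_layers.items():
        weight = adjustment_weights[cid]
        scaled = [[weight * value for value in layer] for layer in layers]
        if not aggregated:
            # First client initialises the running sum.
            aggregated = scaled
        else:
            aggregated = [
                [acc + new for acc, new in zip(acc_layer, new_layer)]
                for acc_layer, new_layer in zip(aggregated, scaled)
            ]
    return aggregated
```

When the adjustment weights sum to 1, the result is a convex combination of the clients' parameters, for example 0.25 * [1.0, 2.0] + 0.75 * [3.0, 4.0] = [2.5, 3.5].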

exception SignalForTypeException[source]

Bases: Exception

Raised when there is an error in the signal_for_type function.