fl4health.strategies.flash module

class Flash(*, fraction_fit=1.0, fraction_evaluate=1.0, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=2, evaluate_fn=None, on_fit_config_fn=None, on_evaluate_config_fn=None, accept_failures=True, initial_parameters, fit_metrics_aggregation_fn=None, evaluate_metrics_aggregation_fn=None, eta=0.1, eta_l=0.1, beta_1=0.9, beta_2=0.99, tau=1e-09, weighted_aggregation=False, weighted_eval_losses=False)[source]

Bases: BasicFedAvg

__init__(*, fraction_fit=1.0, fraction_evaluate=1.0, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=2, evaluate_fn=None, on_fit_config_fn=None, on_evaluate_config_fn=None, accept_failures=True, initial_parameters, fit_metrics_aggregation_fn=None, evaluate_metrics_aggregation_fn=None, eta=0.1, eta_l=0.1, beta_1=0.9, beta_2=0.99, tau=1e-09, weighted_aggregation=False, weighted_eval_losses=False)[source]

Flash: Concept Drift Adaptation in Federated Learning.

Implementation based on https://proceedings.mlr.press/v202/panchal23a/panchal23a.pdf

Parameters:
  • fraction_fit (float, optional) – Fraction of clients used during training. Defaults to 1.0.

  • fraction_evaluate (float, optional) – Fraction of clients used during validation. Defaults to 1.0.

  • min_fit_clients (int, optional) – Minimum number of clients used during training. Defaults to 2.

  • min_evaluate_clients (int, optional) – Minimum number of clients used during validation. Defaults to 2.

  • min_available_clients (int, optional) – Minimum number of total clients in the system. Defaults to 2.

  • evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None, optional) – Function used for validation. It receives the server round, the model parameters, and a configuration dictionary, and returns a loss together with a metrics dictionary, or None. Here NDArrays is a list of NumPy arrays and Scalar is one of bool, bytes, float, int, or str. Defaults to None.

  • on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional) – Function used to configure training. Defaults to None.

  • on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional) – Function used to configure validation. Defaults to None.

  • accept_failures (bool, optional) – Whether to accept rounds that contain failures. Defaults to True.

  • initial_parameters (Parameters) – Initial global model parameters.

  • fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional) – Function used to aggregate the metrics that clients report during training. Defaults to None.

  • evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional) – Function used to aggregate the metrics that clients report during evaluation. Defaults to None.

  • eta (float, optional) – Server-side learning rate. Defaults to 0.1.

  • eta_l (float, optional) – Client-side learning rate. Defaults to 0.1.

  • beta_1 (float, optional) – Momentum parameter. Defaults to 0.9.

  • beta_2 (float, optional) – Second moment parameter. Defaults to 0.99.

  • tau (float, optional) – Controls the algorithm’s degree of adaptability. Defaults to 1e-9.

  • weighted_aggregation (bool, optional) – Determines whether client parameters are aggregated with a linearly weighted average (True) or a uniform average over the number of clients (False). Flash uses the uniform average by default. Defaults to False.

  • weighted_eval_losses (bool, optional) – Determines whether evaluation losses are aggregated with a linearly weighted average (True) or a uniform average (False), in which the total loss is divided by the number of clients. Flash uses the uniform average by default. Defaults to False.
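
The difference between the uniform and the linearly weighted average can be illustrated with a small standalone sketch. It is not taken from fl4health: the helper name average_updates is hypothetical, parameters are flat Python lists rather than NumPy arrays, and it assumes that "linearly weighted" means weighting each client by its number of examples.

```python
# Hypothetical helper for illustration only; not part of fl4health.
def average_updates(client_updates, weighted):
    """Average per-client parameter vectors.

    client_updates: list of (params, num_examples) pairs, where params is a
    flat list of floats. With weighted=True each client counts in proportion
    to num_examples (an assumption about what "linearly weighted" means);
    with weighted=False every client counts equally.
    """
    if not client_updates:
        raise ValueError("need at least one client update")
    if weighted:
        weights = [float(n) for _, n in client_updates]
    else:
        weights = [1.0 for _ in client_updates]
    total = sum(weights)
    dim = len(client_updates[0][0])
    return [
        sum(w * params[i] for w, (params, _) in zip(weights, client_updates)) / total
        for i in range(dim)
    ]


# Two clients: one with 10 examples, one with 30.
updates = [([1.0, 2.0], 10), ([3.0, 6.0], 30)]
uniform = average_updates(updates, weighted=False)     # [2.0, 4.0]
by_samples = average_updates(updates, weighted=True)   # [2.5, 5.0]
```

With the uniform average both clients pull equally on the result, while the weighted average moves toward the client holding more data. The same distinction applies to evaluation losses under weighted_eval_losses.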

aggregate_fit(server_round, results, failures)[source]

Aggregate fit results using the Flash method.

Return type:

tuple[Parameters | None, dict[str, Union[bool, bytes, float, int, str]]]
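
As a rough illustration of how eta, beta_1, beta_2, and tau might interact, the sketch below follows one reading of the server update described in the linked paper. It is an assumption-laden sketch, not the library's aggregate_fit: the function name flash_step is hypothetical, parameters are flat Python lists, beta_3 is computed elementwise, and the zero-denominator guard is added here for safety. The exact formulas should be checked against the paper and the source code.

```python
import math


# Hypothetical sketch; the real aggregate_fit may differ in its details.
def flash_step(x, delta, m, v, d, eta=0.1, beta_1=0.9, beta_2=0.99, tau=1e-9):
    """Apply one adaptive server update to flat parameter lists.

    x: current global parameters; delta: averaged client update;
    m, v: first and second moment estimates; d: drift-aware term.
    Returns the updated (x, m, v, d).
    """
    new_x, new_m, new_v, new_d = [], [], [], []
    for x_i, g, m_i, v_i, d_i in zip(x, delta, m, v, d):
        g2 = g * g
        m_n = beta_1 * m_i + (1 - beta_1) * g    # momentum
        v_n = beta_2 * v_i + (1 - beta_2) * g2   # second moment
        # Assumed form of the adaptive coefficient, built from the previous v.
        denom = abs(g2 - v_i) + abs(v_i)
        beta_3 = abs(v_i) / denom if denom > 0 else 0.0  # guard added here
        d_n = beta_3 * d_i + (1 - beta_3) * (g2 - v_n)   # drift-aware term
        new_x.append(x_i + eta * m_n / (math.sqrt(v_n) - d_n + tau))
        new_m.append(m_n)
        new_v.append(v_n)
        new_d.append(d_n)
    return new_x, new_m, new_v, new_d


# First round: all state starts at zero and the averaged update is small.
x1, m1, v1, d1 = flash_step([0.0], [0.01], [0.0], [0.0], [0.0])
```

In this reading, m and v behave like Adam-style moment estimates, and d shrinks the denominator when the squared update departs from its running average, which enlarges the effective step.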