fl4health.strategies.model_merge_strategy module
- class ModelMergeStrategy(*, fraction_fit=1.0, fraction_evaluate=1.0, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=2, evaluate_fn=None, on_fit_config_fn=None, on_evaluate_config_fn=None, accept_failures=True, fit_metrics_aggregation_fn=None, evaluate_metrics_aggregation_fn=None, weighted_aggregation=True)
Bases: Strategy
- __init__(*, fraction_fit=1.0, fraction_evaluate=1.0, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=2, evaluate_fn=None, on_fit_config_fn=None, on_evaluate_config_fn=None, accept_failures=True, fit_metrics_aggregation_fn=None, evaluate_metrics_aggregation_fn=None, weighted_aggregation=True)
- Model merging strategy in which model weights are collected from the clients, averaged (weighted or unweighted),
and redistributed to the clients for evaluation.
- Parameters:
fraction_fit (float, optional) – Fraction of clients used during training. In case min_fit_clients is larger than fraction_fit * available_clients, min_fit_clients will still be sampled. Defaults to 1.0.
fraction_evaluate (float, optional) – Fraction of clients used during validation. In case min_evaluate_clients is larger than fraction_evaluate * available_clients, min_evaluate_clients will still be sampled. Defaults to 1.0.
min_fit_clients (int, optional) – Minimum number of clients used during training. Defaults to 2.
min_evaluate_clients (int, optional) – Minimum number of clients used during validation. Defaults to 2.
min_available_clients (int, optional) – Minimum number of total clients in the system. Defaults to 2.
evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None, optional) – Optional function used for central server-side evaluation. Defaults to None.
on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional) – Function used to configure training by providing a configuration dictionary. Defaults to None.
on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional) – Function used to configure client-side validation by providing a Config dictionary. Defaults to None.
accept_failures (bool, optional) – Whether or not to accept rounds containing failures. Defaults to True.
fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional) – Function used to aggregate the metrics returned by the clients in the fit round. Defaults to None.
evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional) – Function used to aggregate the metrics returned by the clients in the evaluation round. Defaults to None.
weighted_aggregation (bool, optional) – Determines whether parameter aggregation is a linearly weighted average or a uniform average. Note that, for ModelMergeStrategy, the weights are based on the number of samples in each client's test dataset. Defaults to True.
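The sampling rule described for fraction_fit and min_fit_clients (and, in the same way, for fraction_evaluate and min_evaluate_clients) can be sketched as a small function. This is a minimal illustration, assuming the fractional client count is truncated toward zero with int(); the helper name num_clients_to_sample is hypothetical and is not part of fl4health.

```python
def num_clients_to_sample(num_available: int, fraction: float, min_clients: int) -> int:
    """Hypothetical helper: how many clients to sample in a round.

    Takes fraction * num_available (assumed to be truncated toward zero)
    but never returns fewer than min_clients.
    """
    sampled = int(num_available * fraction)  # e.g. 0.3 * 4 = 1.2 -> 1
    return max(sampled, min_clients)


# fraction_fit=0.3 with 10 available clients samples 3 clients.
print(num_clients_to_sample(10, 0.3, 2))  # 3
# With only 4 available clients, 0.3 * 4 truncates to 1, so min_fit_clients=2 wins.
print(num_clients_to_sample(4, 0.3, 2))  # 2
```

With the defaults (fraction 1.0, minimum 2), every available client is sampled as long as at least two are connected.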
- aggregate_evaluate(server_round, results, failures)
- Aggregate the metrics returned from the clients as a result of the evaluation round.
ModelMergeStrategy assumes that the clients compute only metrics, so the loss is set to None.
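- Parameters:
server_round (int) – Indicates the server round we’re currently on. Only one round for ModelMergeStrategy.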
results (list[tuple[ClientProxy, EvaluateRes]]) – The client identifiers and the results of their local evaluation that need to be aggregated on the server-side. These results are loss values (None in this case) and the metrics dictionary.
failures (list[tuple[ClientProxy, EvaluateRes] | BaseException]) – These are the results and exceptions from clients that experienced an issue during evaluation, such as timeouts or exceptions.
- Returns:
- The aggregated loss value (None for ModelMergeStrategy) and the aggregated metrics. The metrics
are aggregated using evaluate_metrics_aggregation_fn.
- Return type:
tuple[float | None, dict[str, Scalar]]
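The evaluate_metrics_aggregation_fn used by aggregate_evaluate is a MetricsAggregationFn: it receives one (number of examples, metrics dictionary) pair per client and returns a single metrics dictionary. A minimal sketch of an example-weighted average follows; the function name is hypothetical, and Scalar and Metrics are redefined locally as stand-ins for Flower's type aliases so the snippet runs on its own.

```python
from typing import Union

# Local stand-ins for Flower's Scalar and Metrics type aliases.
Scalar = Union[bool, bytes, float, int, str]
Metrics = dict[str, Scalar]


def weighted_average_metrics(results: list[tuple[int, Metrics]]) -> Metrics:
    """Hypothetical evaluate_metrics_aggregation_fn.

    Averages each numeric metric, weighting every client by its example count.
    Assumes all clients report the same metric names and at least one example.
    """
    total_examples = sum(num_examples for num_examples, _ in results)
    weighted_sums: dict[str, float] = {}
    for num_examples, metrics in results:
        for name, value in metrics.items():
            # Ignore non-numeric values (bools, strings, bytes).
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                weighted_sums[name] = weighted_sums.get(name, 0.0) + num_examples * value
    return {name: total / total_examples for name, total in weighted_sums.items()}


client_results = [(100, {"accuracy": 0.5}), (300, {"accuracy": 1.0})]
print(weighted_average_metrics(client_results))  # {'accuracy': 0.875}
```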
- aggregate_fit(server_round, results, failures)
Performs model merging by averaging the client model weights (a weighted or uniform average, as determined by weighted_aggregation) and aggregating the client metrics.
- Parameters:
server_round (int) – Indicates the server round we’re currently on. Only one round for ModelMergeStrategy.
results (list[tuple[ClientProxy, FitRes]]) – The client identifiers and the results of their local fit that need to be aggregated on the server-side.
failures (list[tuple[ClientProxy, FitRes] | BaseException]) – These are the results and exceptions from clients that experienced an issue during fit, such as timeouts or exceptions.
- Returns:
The aggregated model weights and the metrics dictionary.
- Return type:
tuple[Parameters | None, dict[str, Scalar]]
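The merging step, together with the weighted_aggregation option, can be illustrated with a minimal sketch. It is not fl4health's implementation: plain Python lists stand in for Flower's NDArrays, merge_models is a hypothetical name, and each client is assumed to report the sample count used for weighting alongside its weights.

```python
# Plain-Python stand-in for Flower's NDArrays: a list of layers, each a flat list of floats.
Layers = list[list[float]]


def merge_models(client_models: list[tuple[Layers, int]], weighted_aggregation: bool = True) -> Layers:
    """Hypothetical illustration of averaging client models layer by layer.

    With weighted_aggregation=True, each client is weighted by its sample count
    (for ModelMergeStrategy, the size of its test set); otherwise every client
    contributes equally.
    """
    if weighted_aggregation:
        coefficients = [float(num_samples) for _, num_samples in client_models]
    else:
        coefficients = [1.0] * len(client_models)
    total = sum(coefficients)

    merged: Layers = []
    for layer_idx in range(len(client_models[0][0])):
        # Collect the same layer from every client, then combine element-wise.
        layers = [model[layer_idx] for model, _ in client_models]
        merged.append(
            [sum(c * v for c, v in zip(coefficients, values)) / total for values in zip(*layers)]
        )
    return merged


clients = [([[1.0, 2.0], [0.0]], 1), ([[3.0, 6.0], [4.0]], 3)]
print(merge_models(clients, weighted_aggregation=True))   # [[2.5, 5.0], [3.0]]
print(merge_models(clients, weighted_aggregation=False))  # [[2.0, 4.0], [2.0]]
```

In the weighted case the client with three samples pulls the merged weights toward its own values; the uniform case is a plain mean.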
- configure_evaluate(server_round, parameters, client_manager)
Sample and configure clients for an evaluation round.
- Parameters:
server_round (int) – Indicates the server round we’re currently on. Only one round for ModelMergeStrategy.
parameters (Parameters) – The merged parameters used to initialize the clients for the evaluation round. Evaluation only occurs after model merging.
client_manager (ClientManager) – The manager used to sample from the available clients.
- Returns:
- List of sampled client identifiers and the configuration/parameters
to be sent to each client (packaged as EvaluateIns).
- Return type:
list[tuple[ClientProxy, EvaluateIns]]
- configure_fit(server_round, parameters, client_manager)
Sample and configure clients for a fit round.
- In ModelMergeStrategy, the server-side parameters are assumed to be empty, and each client
initializes its weights locally.
- Parameters:
server_round (int) – Indicates the server round we’re currently on.
parameters (Parameters) – Not used.
client_manager (ClientManager) – The manager used to sample from the available clients.
- Returns:
- List of sampled client identifiers and the configuration/parameters to
be sent to each client (packaged as FitIns).
- Return type:
list[tuple[ClientProxy, FitIns]]
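The on_fit_config_fn parameter has type Callable[[int], dict[str, Scalar]]: it receives the server round and returns the configuration dictionary used to configure training. A minimal sketch of such a function follows; the name fit_config and its keys are hypothetical placeholders, and Scalar is redefined locally as a stand-in for Flower's type alias.

```python
from typing import Union

# Local stand-in for Flower's Scalar type alias.
Scalar = Union[bool, bytes, float, int, str]


def fit_config(server_round: int) -> dict[str, Scalar]:
    """Hypothetical on_fit_config_fn: maps the server round to a training config."""
    # The keys below are placeholders; use whatever keys your clients read.
    return {"current_server_round": server_round, "batch_size": 32}


print(fit_config(1))  # {'current_server_round': 1, 'batch_size': 32}
```

The same Callable[[int], dict[str, Scalar]] shape applies to on_evaluate_config_fn.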
- evaluate(server_round, parameters)
- Evaluate the model parameters after the merging has occurred. This function can be used to perform centralized
(i.e., server-side) evaluation of model parameters.
- Parameters:
server_round (int) – Server round. Only one round in ModelMergeStrategy.
parameters (Parameters) – The current model parameters after merging has occurred.
- Returns:
- A Tuple containing loss and a
dictionary containing task-specific metrics (e.g., accuracy).
- Return type:
tuple[float, dict[str, Scalar]] | None
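Centralized evaluation relies on the evaluate_fn passed to the constructor, whose type is Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None]. The sketch below shows one possible evaluate_fn under loudly stated assumptions: the merged model is a hypothetical linear classifier stored as [weights, [bias]], plain Python lists stand in for NDArrays, and SERVER_DATA is made-up held-out data.

```python
from typing import Optional, Union

# Local stand-in for Flower's Scalar type alias.
Scalar = Union[bool, bytes, float, int, str]
# Plain-Python stand-in for NDArrays: [weights, [bias]] of a hypothetical linear model.
Params = list[list[float]]

# Hypothetical held-out server-side data: (features, label) pairs.
SERVER_DATA = [([1.0, 0.0], 1.0), ([0.0, 1.0], 0.0), ([1.0, 1.0], 1.0)]


def evaluate_fn(
    server_round: int, parameters: Params, config: dict[str, Scalar]
) -> Optional[tuple[float, dict[str, Scalar]]]:
    """Hypothetical server-side evaluation of a merged linear classifier."""
    weights, (bias,) = parameters
    squared_errors, correct = 0.0, 0
    for features, label in SERVER_DATA:
        score = sum(w * x for w, x in zip(weights, features)) + bias
        squared_errors += (score - label) ** 2
        # Predict class 1 when the score reaches 0.5.
        correct += int((score >= 0.5) == (label == 1.0))
    loss = squared_errors / len(SERVER_DATA)  # mean squared error
    return loss, {"accuracy": correct / len(SERVER_DATA)}


print(evaluate_fn(1, [[1.0, 0.0], [0.0]], {}))  # (0.0, {'accuracy': 1.0})
```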