fl4health.strategies.model_merge_strategy module
- class ModelMergeStrategy(*, fraction_fit=1.0, fraction_evaluate=1.0, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=2, evaluate_fn=None, on_fit_config_fn=None, on_evaluate_config_fn=None, accept_failures=True, fit_metrics_aggregation_fn=None, evaluate_metrics_aggregation_fn=None, weighted_aggregation=True)
Bases: Strategy
- __init__(*, fraction_fit=1.0, fraction_evaluate=1.0, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=2, evaluate_fn=None, on_fit_config_fn=None, on_evaluate_config_fn=None, accept_failures=True, fit_metrics_aggregation_fn=None, evaluate_metrics_aggregation_fn=None, weighted_aggregation=True)
Model merging strategy in which weights are loaded from clients, averaged (weighted or unweighted), and redistributed to the clients for evaluation.
- Parameters:
fraction_fit (float, optional) – Fraction of clients used during training. If min_fit_clients is larger than fraction_fit * available_clients, min_fit_clients clients are still sampled. Defaults to 1.0.
fraction_evaluate (float, optional) – Fraction of clients used during validation. If min_evaluate_clients is larger than fraction_evaluate * available_clients, min_evaluate_clients clients are still sampled. Defaults to 1.0.
min_fit_clients (int, optional) – Minimum number of clients used during training. Defaults to 2.
min_evaluate_clients (int, optional) – Minimum number of clients used during validation. Defaults to 2.
min_available_clients (int, optional) – Minimum number of total clients in the system. Defaults to 2.
evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None) – Optional function used for central server-side evaluation. Defaults to None.
on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional) – Function used to configure training by providing a configuration dictionary. Defaults to None.
on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional) – Function used to configure client-side validation by providing a Config dictionary. Defaults to None.
accept_failures (bool, optional) – Whether or not to accept rounds containing failures. Defaults to True.
fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional) – Function used to aggregate the metrics returned by clients after fitting. Defaults to None.
evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional) – Function used to aggregate the metrics returned by clients after evaluation. Defaults to None.
weighted_aggregation (bool, optional) – Determines whether parameter aggregation is a linearly weighted average or a uniform average. Note that, for ModelMergeStrategy, the weighting is based on the number of samples in each client's test dataset. Defaults to True.
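The sampling rule described for fraction_fit and min_fit_clients (and, symmetrically, fraction_evaluate and min_evaluate_clients) can be sketched as below. This assumes the strategy computes the sample size the way Flower's FedAvg does, by flooring fraction * available_clients and then raising the result to the configured minimum; the helper name sample_size is hypothetical.

```python
def sample_size(num_available: int, fraction: float, min_clients: int) -> int:
    """Return how many clients to sample in a round (hypothetical helper).

    The fraction of available clients is floored to an int, then raised to
    min_clients if it falls below that minimum.
    """
    return max(int(num_available * fraction), min_clients)


# Hypothetical deployment with 10 connected clients.
print(sample_size(10, fraction=1.0, min_clients=2))   # 10
print(sample_size(10, fraction=0.5, min_clients=2))   # 5
print(sample_size(10, fraction=0.05, min_clients=2))  # 2 (int(0.5) == 0, raised to the minimum)
```

With the defaults (fraction_fit=1.0, min_fit_clients=2), every available client participates. In Flower's FedAvg, the client manager additionally waits until min_available_clients clients are connected before sampling.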
- aggregate_evaluate(server_round, results, failures)
Aggregate the metrics returned from the clients as a result of the evaluation round.
ModelMergeStrategy assumes that only metrics are computed on the clients and that the loss is set to None.
- Parameters:
results (list[tuple[ClientProxy, EvaluateRes]]) – The client identifiers and the results of their local evaluation that need to be aggregated on the server-side. These results are loss values (None in this case) and the metrics dictionary.
failures (list[tuple[ClientProxy, EvaluateRes] | BaseException]) – These are the results and exceptions from clients that experienced an issue during evaluation, such as timeouts or exceptions.
- Returns:
Aggregated loss values and the aggregated metrics. The metrics are aggregated according to evaluate_metrics_aggregation_fn.
- Return type:
tuple[float | None, dict[str, Scalar]]
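Because ModelMergeStrategy clients report only metrics, evaluate_metrics_aggregation_fn determines what the evaluation round produces. Below is one possible such function, a sample-weighted average of each numeric metric, written against Flower's MetricsAggregationFn shape (a list of (num_examples, metrics) pairs in, a metrics dictionary out). The name weighted_metrics_average is hypothetical, the Scalar and Metrics aliases are redefined locally rather than imported from Flower, and the sketch assumes every client reports the same metric keys.

```python
from typing import Union

# Local aliases mirroring Flower's Scalar and Metrics types.
Scalar = Union[bool, bytes, float, int, str]
Metrics = dict[str, Scalar]


def weighted_metrics_average(results: list[tuple[int, Metrics]]) -> Metrics:
    """Sample-weighted average of each numeric metric across clients.

    Each entry is (num_examples, metrics) for one client. Non-numeric
    values (bool, bytes, str) are skipped.
    """
    total_examples = sum(num_examples for num_examples, _ in results)
    if total_examples == 0:
        return {}
    weighted_sums: dict[str, float] = {}
    for num_examples, metrics in results:
        for name, value in metrics.items():
            # bool is a subclass of int, so exclude it explicitly.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                weighted_sums[name] = weighted_sums.get(name, 0.0) + num_examples * value
    return {name: total / total_examples for name, total in weighted_sums.items()}


print(weighted_metrics_average([(100, {"accuracy": 0.9}), (300, {"accuracy": 0.7})]))
# {'accuracy': 0.75}
```

Such a function could then be supplied as ModelMergeStrategy(evaluate_metrics_aggregation_fn=weighted_metrics_average).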
- aggregate_fit(server_round, results, failures)
Performs model merging by averaging client weights (linearly weighted or uniform, as determined by weighted_aggregation) and aggregating the client metrics.
- Parameters:
server_round (int) – Indicates the server round we’re currently on. Only one round for ModelMergeStrategy.
results (list[tuple[ClientProxy, FitRes]]) – The client identifiers and the results of their local fit that need to be aggregated on the server-side.
failures (list[tuple[ClientProxy, FitRes] | BaseException]) – These are the results and exceptions from clients that experienced an issue during fit, such as timeouts or exceptions.
- Returns:
The aggregated model weights and the metrics dictionary.
- Return type:
tuple[Parameters | None, dict[str, Scalar]]
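The merging step can be sketched in plain Python, with nested lists standing in for Flower's NDArrays. This illustrates the documented behavior of weighted_aggregation and accept_failures under those assumptions; it is not fl4health's implementation, and merge_client_weights and its arguments are hypothetical. The sample counts are whatever each client reports, which for ModelMergeStrategy is the size of its test dataset.

```python
from typing import Optional

# Stand-ins for Flower's NDArrays: a model is a list of layers, each a flat list of floats.
Layer = list[float]
Weights = list[Layer]


def merge_client_weights(
    results: list[tuple[Weights, int]],
    num_failures: int = 0,
    accept_failures: bool = True,
    weighted_aggregation: bool = True,
) -> Optional[Weights]:
    """Average client weights layer by layer.

    Each result is (client_weights, num_examples). With weighted_aggregation,
    clients contribute in proportion to num_examples (assumed positive);
    otherwise every client counts equally.
    """
    if not results:
        return None
    # Mirror accept_failures: refuse to merge if any client failed.
    if num_failures > 0 and not accept_failures:
        return None
    coefficients = [num_examples if weighted_aggregation else 1 for _, num_examples in results]
    total = sum(coefficients)
    merged: Weights = []
    for layer_idx, first_layer in enumerate(results[0][0]):
        merged_layer = [
            sum(c * weights[layer_idx][i] for (weights, _), c in zip(results, coefficients)) / total
            for i in range(len(first_layer))
        ]
        merged.append(merged_layer)
    return merged


client_a = ([[1.0, 2.0], [0.0]], 1)  # (weights, number of test samples)
client_b = ([[3.0, 6.0], [4.0]], 3)
print(merge_client_weights([client_a, client_b]))                              # [[2.5, 5.0], [3.0]]
print(merge_client_weights([client_a, client_b], weighted_aggregation=False))  # [[2.0, 4.0], [2.0]]
print(merge_client_weights([client_a, client_b], num_failures=1, accept_failures=False))  # None
```

Weighting by sample count lets clients with larger test sets pull the merged model further toward their own weights; the uniform option treats each client as an equal contributor regardless of dataset size.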
- configure_evaluate(server_round, parameters, client_manager)
Sample and configure clients for an evaluation round.
- Parameters:
server_round (int) – Indicates the server round we’re currently on. Only one round for ModelMergeStrategy.
parameters (Parameters) – The parameters used to initialize the clients for the evaluation round. Evaluation only occurs after model merging.
client_manager (ClientManager) – The manager used to sample from the available clients.
- Returns:
List of sampled client identifiers and the configuration/parameters to be sent to each client (packaged as EvaluateIns).
- Return type:
list[tuple[ClientProxy, EvaluateIns]]
- configure_fit(server_round, parameters, client_manager)
Sample and configure clients for a fit round.
In ModelMergeStrategy, it is assumed that server-side parameters are empty and that clients initialize their weights locally.
- Parameters:
server_round (int) – Indicates the server round we’re currently on.
parameters (Parameters) – Not used.
client_manager (ClientManager) – The manager used to sample from the available clients.
- Returns:
List of sampled client identifiers and the configuration/parameters to be sent to each client (packaged as FitIns).
- Return type:
list[tuple[ClientProxy, FitIns]]
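The configure_fit assumption that no server-side parameters are sent can be illustrated as follows: every sampled client receives an empty parameter list together with the same configuration produced by on_fit_config_fn. In this sketch, FitInstructions is a hypothetical stand-in for Flower's FitIns, clients are plain string IDs rather than ClientProxy objects, and the seeded random sampling is an assumption made for reproducibility.

```python
import random
from dataclasses import dataclass, field
from typing import Callable, Optional, Union

Scalar = Union[bool, bytes, float, int, str]  # local alias mirroring Flower's Scalar


@dataclass
class FitInstructions:
    """Hypothetical stand-in for Flower's FitIns."""

    parameters: list = field(default_factory=list)
    config: dict[str, Scalar] = field(default_factory=dict)


def configure_fit_sketch(
    server_round: int,
    client_ids: list[str],
    num_to_sample: int,
    on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None,
    seed: int = 0,
) -> list[tuple[str, FitInstructions]]:
    """Pair each sampled client with empty parameters and a shared config."""
    config = on_fit_config_fn(server_round) if on_fit_config_fn is not None else {}
    sampled = random.Random(seed).sample(client_ids, num_to_sample)
    # Parameters stay empty: each client loads its own locally trained weights.
    return [(cid, FitInstructions(parameters=[], config=dict(config))) for cid in sampled]


instructions = configure_fit_sketch(
    server_round=1,
    client_ids=["client_0", "client_1", "client_2"],
    num_to_sample=2,
    on_fit_config_fn=lambda rnd: {"current_server_round": rnd},
)
for cid, ins in instructions:
    print(cid, ins.parameters, ins.config)
```

Each client gets its own copy of the config dictionary so that a client-side mutation cannot leak into another client's instructions.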
- evaluate(server_round, parameters)
Evaluate the model parameters after the merging has occurred. This function can be used to perform centralized (i.e., server-side) evaluation of model parameters.
- Parameters:
server_round (int) – Server round. Only one round in ModelMergeStrategy.
parameters (Parameters) – The current model parameters after merging has occurred.
- Returns:
A tuple containing the loss and a dictionary of task-specific metrics (e.g., accuracy).
- Return type:
tuple[float, dict[str, Scalar]] | None
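A server-side evaluate_fn has to match the signature documented for that constructor parameter: it receives the server round, the merged parameters, and a config dictionary, and returns a (loss, metrics) tuple or None. The sketch below evaluates a hypothetical one-feature linear model, whose merged weights are assumed to be laid out as [[w], [b]], on a hypothetical held-out set, using nested lists in place of NDArrays.

```python
from typing import Optional, Union

Scalar = Union[bool, bytes, float, int, str]  # local alias mirroring Flower's Scalar
Weights = list[list[float]]  # stand-in for NDArrays: [[w], [b]]

# Hypothetical held-out (x, y) pairs generated from y = 2x + 1.
HELD_OUT = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]


def evaluate_fn(
    server_round: int, parameters: Weights, config: dict[str, Scalar]
) -> Optional[tuple[float, dict[str, Scalar]]]:
    """Mean squared error of the merged linear model on the held-out set."""
    if not HELD_OUT:
        return None  # Signal that no centralized evaluation result is available.
    w, b = parameters[0][0], parameters[1][0]
    squared_errors = [(w * x + b - y) ** 2 for x, y in HELD_OUT]
    mse = sum(squared_errors) / len(squared_errors)
    return mse, {"mse": mse, "server_round": server_round}


print(evaluate_fn(1, [[2.0], [0.0]], {}))  # (1.0, {'mse': 1.0, 'server_round': 1})
```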