Source code for fl4health.strategies.aggregate_utils

from functools import reduce

import numpy as np
from flwr.common import NDArrays
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg


def aggregate_results(results: list[tuple[NDArrays, int]], weighted: bool = True) -> NDArrays:
    """
    Average per-client arrays, either weighted by sample count or uniformly.

    Args:
        results (list[tuple[NDArrays, int]]): One entry per client, pairing that client's NDArrays (a list of
            numpy arrays, most often a model state) with the number of relevant samples it holds (training or
            validation samples where appropriate).
        weighted (bool, optional): If True, average with weights given by the sample counts in each tuple. If
            False, every client contributes equally and the sample counts are ignored. Defaults to True.

    Returns:
        NDArrays: The layer-wise average of the client arrays under the requested scheme.
    """
    if weighted:
        # Sample-count weighting is delegated entirely to the flwr implementation.
        return aggregate(results)

    client_count = len(results)

    # Pre-scale every client's arrays by 1/client_count. The reciprocal is computed inside the loop on purpose:
    # with no clients it is never evaluated, so an empty input yields an empty list rather than an error.
    scaled_states = []
    for client_arrays, _ in results:
        uniform_factor = 1.0 / client_count
        scaled_states.append([array * uniform_factor for array in client_arrays])

    # Transpose from per-client lists to per-layer tuples, then add each layer's scaled arrays in client order.
    averaged_layers = []
    for per_layer_arrays in zip(*scaled_states):
        averaged_layers.append(reduce(np.add, per_layer_arrays))
    return averaged_layers
def aggregate_losses(results: list[tuple[int, float]], weighted: bool = True) -> float:
    """
    Combine per-client loss values into a single average.

    Args:
        results (list[tuple[int, float]]): One ``(sample_count, loss)`` pair per client, in that order. The sample
            counts (training or validation samples where appropriate) only matter when a weighted average is
            requested.
        weighted (bool, optional): If True, average the losses weighted by the sample counts in each tuple. If
            False, take a plain mean of the losses. Defaults to True.

    Returns:
        float: The weighted or unweighted mean of the loss values in ``results``.
    """
    # Put the entries in ascending-loss order before summing, so the floating point accumulation order does not
    # depend on the order in which client results arrived.
    ordered = list(results)
    ordered.sort(key=lambda entry: entry[1])

    if not weighted:
        # Plain mean over clients; sample counts are ignored.
        return sum(loss for _, loss in ordered) / len(ordered)

    # Sample-count weighting is delegated to the flwr implementation.
    return weighted_loss_avg(ordered)