Source code for fl4health.utils.metric_aggregation
from collections import defaultdict
from flwr.common.typing import Metrics
[docs]
def uniform_metric_aggregation(
    all_client_metrics: list[tuple[int, Metrics]],
) -> tuple[defaultdict[str, int], Metrics]:
    """
    Function that sums each metric across clients and counts how many clients contributed to each metric.

    NOTE: The returned metrics are the raw per-metric sums. They are NOT yet divided by the client counts. Pass both
    returned values to ``uniform_normalize_metrics`` to obtain the uniform (unweighted) average of each metric.

    Args:
        all_client_metrics (list[tuple[int, Metrics]]): A list of tuples with the
            sample counts and metrics for each client. The sample counts are ignored by this function.

    Returns:
        tuple[defaultdict[str, int], Metrics]: Client counts per metric and the per-metric sums over those clients.

    Raises:
        ValueError: If a metric value is neither an int nor a float.
    """
    aggregated_metrics: Metrics = {}
    total_client_count_by_metric: defaultdict[str, int] = defaultdict(int)
    # Run through all of the metrics
    for _, client_metrics in all_client_metrics:
        for metric_name, metric_value in client_metrics.items():
            if not isinstance(metric_value, (int, float)):
                raise ValueError("Metric type is not supported")
            # Starting from integer 0 keeps all-int metrics as ints, while any float contribution promotes the sum
            # to float. This also lets clients report the same metric as an int on one client and a float on another.
            current_metric_value = aggregated_metrics.get(metric_name, 0)
            # Only ints and floats are ever stored above, so this purely narrows the type for static checkers.
            assert isinstance(current_metric_value, (int, float))
            aggregated_metrics[metric_name] = current_metric_value + metric_value
            total_client_count_by_metric[metric_name] += 1
    # Normalization is intentionally left to the caller: uniform_normalize_metrics returns a new dictionary rather
    # than modifying its input, so invoking it here and discarding the result would have no effect.
    return total_client_count_by_metric, aggregated_metrics
[docs]
def metric_aggregation(
    all_client_metrics: list[tuple[int, Metrics]],
) -> tuple[int, Metrics]:
    """
    Function that computes the sample-weighted sum of each metric across clients along with the total sample count.

    Each client metric value is multiplied by that client's sample count before being accumulated. The returned
    metrics are therefore un-normalized sums: divide them by the returned total number of examples (for instance
    with ``normalize_metrics``) to obtain the weighted average.

    Args:
        all_client_metrics (list[tuple[int, Metrics]]): A list of tuples with the
            sample counts and metrics for each client.

    Returns:
        tuple[int, Metrics]: The total number of examples along with the sample-weighted metric sums.

    Raises:
        ValueError: If a metric value is neither an int nor a float.
    """
    aggregated_metrics: Metrics = {}
    total_examples = 0
    # Run through all of the metrics
    for num_examples_on_client, client_metrics in all_client_metrics:
        total_examples += num_examples_on_client
        for metric_name, metric_value in client_metrics.items():
            if not isinstance(metric_value, (int, float)):
                raise ValueError("Metric type is not supported")
            # Starting from integer 0 keeps all-int metrics as ints, while any float contribution promotes the sum
            # to float. This also lets clients report the same metric as an int on one client and a float on another.
            current_metric_value = aggregated_metrics.get(metric_name, 0)
            # Only ints and floats are ever stored above, so this purely narrows the type for static checkers.
            assert isinstance(current_metric_value, (int, float))
            # Here we assume each metric is normalized by the number of examples on the client. So we scale up to
            # get the "raw" value
            aggregated_metrics[metric_name] = current_metric_value + num_examples_on_client * metric_value
    return total_examples, aggregated_metrics
[docs]
def normalize_metrics(total_examples: int, aggregated_metrics: Metrics) -> Metrics:
    """
    Function that divides every numeric aggregated metric by the provided sample count.

    Metrics whose values are neither floats nor ints are left out of the result. The input dictionary is not
    modified.

    Args:
        total_examples (int): The total number of samples across all client datasets.
        aggregated_metrics (Metrics): Metrics that have been aggregated across clients.

    Returns:
        Metrics: A new dictionary holding the metrics normalized by total_examples.
    """
    averaged: Metrics = {
        name: value / total_examples
        for name, value in aggregated_metrics.items()
        if isinstance(value, (float, int))
    }
    return averaged
[docs]
def uniform_normalize_metrics(
    total_client_count_by_metric: defaultdict[str, int], aggregated_metrics: Metrics
) -> Metrics:
    """
    Function that divides every numeric aggregated metric by the number of clients that contributed to it.

    Metrics whose values are neither floats nor ints are left out of the result. The aggregated metrics
    dictionary is not modified.

    Args:
        total_client_count_by_metric (defaultdict[str, int]): The count of clients that contributed to each metric.
        aggregated_metrics (Metrics): Metrics that have been aggregated across clients.

    Returns:
        Metrics: A new dictionary holding the normalized metrics.
    """
    averaged: Metrics = {
        name: value / total_client_count_by_metric[name]
        for name, value in aggregated_metrics.items()
        if isinstance(value, (float, int))
    }
    return averaged
[docs]
def fit_metrics_aggregation_fn(
    all_client_metrics: list[tuple[int, Metrics]],
) -> Metrics:
    """
    Function for fit that produces the sample-weighted average of the client metrics.

    Run by the server to aggregate the metrics returned by each client's fit function. The weighted sums come from
    ``metric_aggregation`` and are then divided by the total sample count via ``normalize_metrics``.

    Args:
        all_client_metrics (list[tuple[int, Metrics]]): A list of tuples with the
            sample counts and metrics for each client.

    Returns:
        Metrics: The aggregated normalized metrics.
    """
    # NOTE: The first value of each tuple is the number of examples for FedAvg
    sample_total, weighted_sums = metric_aggregation(all_client_metrics)
    fit_metrics = normalize_metrics(sample_total, weighted_sums)
    return fit_metrics
[docs]
def evaluate_metrics_aggregation_fn(
    all_client_metrics: list[tuple[int, Metrics]],
) -> Metrics:
    """
    Function for evaluate that produces the sample-weighted average of the client metrics.

    Run by the server to aggregate the metrics returned by each client's evaluate function. The weighted sums come
    from ``metric_aggregation`` and are then divided by the total sample count via ``normalize_metrics``.

    Args:
        all_client_metrics (list[tuple[int, Metrics]]): A list of tuples with the
            sample counts and metrics for each client.

    Returns:
        Metrics: The aggregated normalized metrics.
    """
    # NOTE: The first value of each tuple is the number of examples for FedAvg
    sample_total, weighted_sums = metric_aggregation(all_client_metrics)
    evaluation_metrics = normalize_metrics(sample_total, weighted_sums)
    return evaluation_metrics
[docs]
def uniform_evaluate_metrics_aggregation_fn(
    all_client_metrics: list[tuple[int, Metrics]],
) -> Metrics:
    """
    Function for evaluate that produces the uniform (unweighted) average of the client metrics, where each metric
    is divided by the number of clients that contributed to it.

    Run by the server to aggregate the metrics returned by each client's evaluate function. The per-metric sums and
    client counts come from ``uniform_metric_aggregation`` and are combined via ``uniform_normalize_metrics``.

    Args:
        all_client_metrics (list[tuple[int, Metrics]]): A list of tuples with the
            sample counts and metrics for each client.

    Returns:
        Metrics: The aggregated normalized metrics.
    """
    # NOTE: The first value of each tuple is the number of examples for FedAvg, but it is not used here.
    clients_per_metric, metric_sums = uniform_metric_aggregation(all_client_metrics)
    evaluation_metrics = uniform_normalize_metrics(clients_per_metric, metric_sums)
    return evaluation_metrics