Source code for fl4health.utils.parameter_extraction
from collections.abc import Iterable
import torch
import torch.nn as nn
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Parameters
[docs]
def get_all_model_parameters(model: nn.Module) -> Parameters:
    """
    Extract **ALL** parameters of a PyTorch module, including any state (non-trainable) entries held in its
    ``state_dict``, and package them as a Flower ``Parameters`` object.

    Each tensor in ``model.state_dict()`` is moved to CPU and converted to a numpy array. The resulting list, in
    state-dict order, is then serialized with ``ndarrays_to_parameters``.

    Args:
        model (nn.Module): PyTorch model whose parameters are to be extracted.

    Returns:
        Parameters: Flower Parameters object containing all of the target model's state.
    """
    state = model.state_dict()
    # Only the tensors are shipped; the state-dict keys are dropped, so the receiver relies on ordering alone.
    numpy_arrays = [tensor.cpu().numpy() for tensor in state.values()]
    return ndarrays_to_parameters(numpy_arrays)
[docs]
def check_shape_match(params1: Iterable[torch.Tensor], params2: Iterable[torch.Tensor], error_message: str) -> None:
    """
    Check if the shapes of parameters from two models match.

    Both iterables are fully materialized first, so one-shot iterators such as ``model.parameters()`` are
    supported. The two collections must contain the same number of tensors, and the tensors at each position
    must have identical shapes.

    Args:
        params1 (Iterable[torch.Tensor]): Parameters from the first model.
        params2 (Iterable[torch.Tensor]): Parameters from the second model.
        error_message (str): Error message to display if the shapes do not match. It is appended to every
            raised error.

    Raises:
        AssertionError: If the number of parameters differs, or if any pair of corresponding parameters has
            different shapes.
    """
    params1_list = list(params1)
    params2_list = list(params2)

    # Raise explicitly rather than using ``assert`` so the check still runs under ``python -O``. AssertionError is
    # kept as the exception type so existing callers that catch it continue to work.
    if len(params1_list) != len(params2_list):
        raise AssertionError(
            f"Parameter length mismatch: {len(params1_list)} vs {len(params2_list)}. {error_message}"
        )

    # Report the offending position and both shapes to make mismatches easy to diagnose.
    for index, (param1, param2) in enumerate(zip(params1_list, params2_list)):
        if param1.shape != param2.shape:
            raise AssertionError(
                f"Parameter shape mismatch at index {index}: {tuple(param1.shape)} vs {tuple(param2.shape)}. "
                f"{error_message}"
            )