fl4health.utils.parameter_extraction module

check_shape_match(params1, params2, error_message)

Check if the shapes of parameters from two models match.

Parameters:
  • params1 (Iterable[torch.Tensor]) – Parameters from the first model.

  • params2 (Iterable[torch.Tensor]) – Parameters from the second model.

  • error_message (str) – Error message to display if the shapes do not match.

Return type:

None
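The description above does not say what happens when the shapes differ. The following is a minimal sketch of such a comparison. It assumes that a `ValueError` carrying `error_message` is raised on a mismatch and that the tensors are compared pairwise in iteration order; neither point is confirmed here. `check_shape_match_sketch` is a hypothetical name, not the library function.

```python
from typing import Iterable

import torch
import torch.nn as nn


def check_shape_match_sketch(
    params1: Iterable[torch.Tensor],
    params2: Iterable[torch.Tensor],
    error_message: str,
) -> None:
    """Hypothetical sketch: raise ValueError if paired tensors differ in shape."""
    shapes1 = [p.shape for p in params1]
    shapes2 = [p.shape for p in params2]
    # Assumption: a different number of tensors also counts as a mismatch.
    if len(shapes1) != len(shapes2):
        raise ValueError(error_message)
    # Compare the i-th tensor of the first model with the i-th tensor of the second.
    for s1, s2 in zip(shapes1, shapes2):
        if s1 != s2:
            raise ValueError(error_message)


if __name__ == "__main__":
    # Identical architectures: returns None without raising.
    check_shape_match_sketch(nn.Linear(4, 2).parameters(), nn.Linear(4, 2).parameters(), "shapes differ")
    # Different output width: weight shapes are (2, 4) vs (3, 4), so this raises.
    try:
        check_shape_match_sketch(nn.Linear(4, 2).parameters(), nn.Linear(4, 3).parameters(), "shapes differ")
    except ValueError as err:
        print(err)  # shapes differ
```

Passing generators such as `model.parameters()` works here because each iterable is consumed exactly once when the shape lists are built.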

get_all_model_parameters(model)

Function to extract ALL parameters associated with a PyTorch module, including any state parameters. These values are converted to NumPy arrays, which are then packaged into a Flower Parameters object.

Parameters:

model (nn.Module) – PyTorch model whose parameters are to be extracted.

Returns:

Flower Parameters object containing all of the target model's state.

Return type:

Parameters