"""Utilities for extracting PEFT parameters from models (module fl4health.utils.peft_parameter_extraction)."""
import torch.nn as nn
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Parameters
from peft import get_peft_model_state_dict
def get_all_peft_parameters_from_model(model: nn.Module) -> Parameters:
    """
    Function to extract the PEFT parameters associated with a PyTorch module. The extracted tensors are
    converted to numpy arrays and then packed into a Flower Parameters object.

    Args:
        model (nn.Module): PyTorch model whose PEFT parameters are to be extracted. Presumably this must be
            a PEFT-wrapped model, since the state is obtained via ``peft.get_peft_model_state_dict``.

    Returns:
        Parameters: Flower Parameters object containing the model's PEFT state: one array per entry of the
        PEFT state dict, in the state dict's iteration order.
    """
    # Only the entries selected by peft.get_peft_model_state_dict are exported, not the full model state.
    peft_state_dict = get_peft_model_state_dict(model)
    # Keys are not needed: only the values are transmitted, in the dict's iteration order. Tensors are moved
    # to the CPU first because numpy conversion requires CPU-resident tensors.
    return ndarrays_to_parameters([tensor.cpu().numpy() for tensor in peft_state_dict.values()])