fl4health.clients.model_merge_client module
- class ModelMergeClient(data_path, model_path, metrics, device, reporters=None, client_name=None)
Bases: NumPyClient
- __init__(data_path, model_path, metrics, device, reporters=None, client_name=None)
ModelMergeClient supports simply merging client models and subsequently evaluating the merged model.
- Parameters:
data_path (Path) – Path to the data used to load the client's local (test) data.
model_path (Path) – Path to the checkpoint of the client model to be used in model merging.
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model.
device (torch.device) – Device indicator for where to send the model, batches, labels, etc. Often “cpu” or “cuda”.
reporters (Sequence[BaseReporter], optional) – A sequence of FL4Health reporters which the client should send data to.
client_name (str, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated.
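The merging itself happens on the server, and this page does not state which rule is used. Purely as an illustration, the sketch below assumes a sample-weighted, element-wise average of the clients' parameters, with plain Python lists standing in for Flower NDArrays; merge_parameters is a hypothetical helper, not part of fl4health.

```python
from typing import Sequence

# Simplified stand-in for Flower's NDArrays: one flat list of floats per model tensor.
Params = list[list[float]]


def merge_parameters(client_params: Sequence[Params], weights: Sequence[float]) -> Params:
    """Merge client models by an element-wise weighted average (an assumed rule)."""
    if not client_params or len(client_params) != len(weights):
        raise ValueError("need at least one client and exactly one weight per client")
    total = float(sum(weights))
    merged: Params = []
    # zip(*client_params) groups the i-th tensor of every client together,
    # which only makes sense if all clients share the same architecture.
    for tensors in zip(*client_params):
        merged.append(
            [sum(w * t[i] for w, t in zip(weights, tensors)) / total for i in range(len(tensors[0]))]
        )
    return merged


# Two clients with one 2-element tensor each, weighted by their test-set sizes.
print(merge_parameters([[[1.0, 2.0]], [[3.0, 6.0]]], weights=[1, 3]))  # [[2.5, 5.0]]
```

Weighting by test-set size is one plausible choice, since fit reports the local test dataset length; an unweighted mean is equally possible.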
- evaluate(parameters, config)
Evaluate the provided parameters using the locally held dataset.
- Parameters:
parameters (NDArrays) – The current model parameters.
config (Config) – Configuration object from the server.
- Returns:
The float represents the loss, which is assumed to be 0 for the ModelMergeClient. The int represents the number of examples in the local test dataset, and the dictionary contains the metrics computed on the test set.
- Return type:
tuple[float, int, dict[str, Scalar]]
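To make the return convention concrete, here is a dependency-free sketch of the (loss, num_examples, metrics) triple. It assumes the merged model's outputs are already available as (predicted_labels, true_labels) batches and uses an accuracy metric chosen only for illustration; evaluate_sketch is a hypothetical name.

```python
from typing import Iterable, Sequence


def evaluate_sketch(
    batches: Iterable[tuple[Sequence[int], Sequence[int]]],
) -> tuple[float, int, dict[str, float]]:
    """Return (loss, num_examples, metrics) in the shape documented for evaluate."""
    num_examples = 0
    num_correct = 0
    for preds, labels in batches:
        num_examples += len(labels)
        num_correct += sum(int(p == y) for p, y in zip(preds, labels))
    accuracy = num_correct / num_examples if num_examples else 0.0
    # A model merging client computes no loss, so 0.0 is reported.
    return 0.0, num_examples, {"accuracy": accuracy}


print(evaluate_sketch([([0, 1, 1], [0, 1, 0]), ([2], [2])]))  # (0.0, 4, {'accuracy': 0.75})
```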
- fit(parameters, config)
Initializes the client, validates the local client model on local test data, and returns the parameters, the test dataset length, and the test metrics. Importantly, the parameters from the server, which are empty, are not used to initialize the client model.
NOTE: Since we assume the client only provides a test_loader, client evaluation and sample counts are always based on the client's test_loader.
- Parameters:
parameters (NDArrays) – Not used.
config (Config) – The config from the server.
- Returns:
The local model parameters along with the number of samples in the local test dataset and the computed metrics of the local model on the local test dataset.
- Return type:
tuple[NDArrays, int, dict[str, Scalar]]
- Raises:
AssertionError – If the model is initialized before fit is called, which should not happen in the case of the ModelMergeClient.
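The control flow of fit can be mirrored with a small stand-in class. This is a hypothetical sketch, not the fl4health implementation: a list of floats plays the role of the model from get_model, and precomputed (prediction, label) pairs play the role of running that model over the test_loader.

```python
class FitFlowSketch:
    """Hypothetical stand-in mirroring the documented ModelMergeClient.fit flow."""

    def __init__(self, local_params: list[float], test_pairs: list[tuple[int, int]]):
        self._local_params = local_params  # plays the role of the model from get_model
        self._test_pairs = test_pairs      # (prediction, label) pairs from the test_loader
        self.initialized = False
        self.params: list[float] = []

    def setup_client(self, config: dict) -> None:
        self.params = list(self._local_params)
        self.initialized = True

    def fit(self, parameters: list, config: dict) -> tuple[list[float], int, dict[str, float]]:
        # fit is what initializes the client, so it must not be initialized already.
        assert not self.initialized, "client was initialized before fit was called"
        self.setup_client(config)  # the (empty) server parameters are deliberately ignored
        num_examples = len(self._test_pairs)  # sample count comes from the test data only
        correct = sum(p == y for p, y in self._test_pairs)
        metrics = {"accuracy": correct / num_examples if num_examples else 0.0}
        return list(self.params), num_examples, metrics


client = FitFlowSketch([0.5, -0.5], [(1, 1), (0, 1)])
print(client.fit([], {}))  # ([0.5, -0.5], 2, {'accuracy': 0.5})
```

Calling fit a second time on the same instance trips the assertion, matching the documented AssertionError.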
- abstract get_model(config)
User-defined method that returns a PyTorch model. This is the local model that will be communicated to the server for merging.
- Parameters:
config (Config) – The config from the server.
- Returns:
The client model.
- Return type:
nn.Module
- Raises:
NotImplementedError – To be defined in child class.
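A real subclass would presumably build its architecture and load the checkpoint given by the model_path constructor argument (for instance with torch.load); how the base class stores that path is not stated here. The stand-in below swaps the PyTorch checkpoint for a JSON file of named parameters so it runs with the standard library alone; get_model_sketch and the file name are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def get_model_sketch(model_path: Path) -> dict[str, list[float]]:
    """Hypothetical get_model body: load the client checkpoint stored at model_path.

    JSON and a dict of named parameters stand in for a PyTorch checkpoint and an nn.Module,
    so this runs with the standard library alone.
    """
    with model_path.open() as f:
        state = json.load(f)
    # An empty checkpoint would leave nothing to merge, so fail early.
    if not state:
        raise ValueError(f"no parameters found in {model_path}")
    return state


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "client_0.json"
    ckpt.write_text(json.dumps({"linear.weight": [0.1, 0.2], "linear.bias": [0.0]}))
    print(list(get_model_sketch(ckpt)))  # ['linear.weight', 'linear.bias']
```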
- get_parameter_exchanger(config)
Parameter exchange is assumed to always be full for model merging clients. However, this method may be overridden if a different exchanger is needed.
- Used in a non-standard way for ModelMergeClient: set_parameters is only called for evaluate, as the parameters should initially be set to those of the nn.Module returned by get_model.
- Parameters:
config (Config) – Configuration object from the server.
- Returns:
The parameter exchanger used to set and get parameters.
- Return type:
ParameterExchanger
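"Full" exchange means every model tensor is sent and received. Since a PyTorch state_dict is an ordered mapping, position in the list is what ties each array back to its parameter name. The sketch below illustrates that idea with an OrderedDict of float lists; FullExchangeSketch and its method names are hypothetical and do not describe fl4health's exchanger interface.

```python
from collections import OrderedDict


class FullExchangeSketch:
    """Hypothetical full exchanger: every model tensor is sent and received, in order."""

    def export_parameters(self, state: "OrderedDict[str, list[float]]") -> list[list[float]]:
        # Copy each tensor so the payload does not alias the model's storage.
        return [list(values) for values in state.values()]

    def import_parameters(
        self, parameters: list[list[float]], state: "OrderedDict[str, list[float]]"
    ) -> None:
        if len(parameters) != len(state):
            raise ValueError("full exchange expects exactly one array per model tensor")
        # Position is the only link between arrays and names, so order must match.
        for name, values in zip(list(state), parameters):
            state[name] = list(values)


state = OrderedDict([("linear.weight", [1.0, 2.0]), ("linear.bias", [0.5])])
exchanger = FullExchangeSketch()
payload = exchanger.export_parameters(state)
exchanger.import_parameters([[9.0, 9.0], [0.0]], state)
print(payload)  # [[1.0, 2.0], [0.5]]
print(dict(state))  # {'linear.weight': [9.0, 9.0], 'linear.bias': [0.0]}
```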
- get_parameters(config)
Determines which parameters are sent back to the server for aggregation, using a parameter exchanger to select them.
For the ModelMergeClient, we assume that self.setup_client has already been called: since the client does not support client polling, get_parameters is called from fit, and the client should therefore be initialized by this point.
- Parameters:
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
- Returns:
These are the parameters to be sent to the server. At minimum they represent the relevant model parameters to be aggregated, but can contain more information.
- Return type:
NDArrays
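The initialization assumption can be expressed as a guard. In this hypothetical stand-in, get_parameters refuses to run before setup_client, and a single list of floats represents the whole model returned under full exchange.

```python
class GetParametersSketch:
    """Hypothetical stand-in: get_parameters relies on setup_client having run inside fit."""

    def __init__(self, local_params: list[float]):
        self._local_params = local_params
        self.initialized = False
        self.params: list[float] = []

    def setup_client(self, config: dict) -> None:
        self.params = list(self._local_params)
        self.initialized = True

    def get_parameters(self, config: dict) -> list[list[float]]:
        # Without client polling, an uninitialized client here indicates a call-order bug.
        assert self.initialized, "setup_client must run (via fit) before get_parameters"
        return [list(self.params)]  # full exchange: the whole model as a single array


client = GetParametersSketch([0.1, 0.2])
client.setup_client({})
print(client.get_parameters({}))  # [[0.1, 0.2]]
```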
- abstract get_test_data_loader(config)
User-defined method that returns a PyTorch test DataLoader.
- Parameters:
config (Config) – The config from the server.
- Returns:
Client test data loader.
- Return type:
DataLoader
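In a real subclass this would typically return something like torch.utils.data.DataLoader(test_dataset, batch_size=..., shuffle=False), where test_dataset is a placeholder for whatever is loaded from data_path. The dependency-free sketch below only mimics the batching behaviour of such a loader; simple_test_loader is a hypothetical helper.

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def simple_test_loader(samples: Sequence[T], batch_size: int) -> Iterator[list[T]]:
    """Yield consecutive, unshuffled batches, as a test DataLoader with shuffle=False would."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")  # raised once iteration starts
    for start in range(0, len(samples), batch_size):
        yield list(samples[start : start + batch_size])


batches = list(simple_test_loader(list(range(5)), batch_size=2))
print(batches)  # [[0, 1], [2, 3], [4]]
num_examples = sum(len(b) for b in batches)  # the kind of count fit and evaluate report
```

Keeping test batches unshuffled makes evaluation results reproducible across runs.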
- set_parameters(parameters, config)
- Sets the local model parameters transferred from the server, using a parameter exchanger to coordinate how parameters are set.
- For the ModelMergeClient, we assume that the parameters are initially set to those of the nn.Module returned by the user-defined get_model method. Thus, set_parameters is only called once, after model merging has occurred and before federated evaluation.
- Parameters:
parameters (NDArrays) – Parameters have information about model state to be added to the relevant client model but may contain more information than that.
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
- Return type:
None
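The overall ordering (fit sends the local weights, the server merges, then evaluate calls set_parameters once with the merged weights) can be traced with a hypothetical client that records its calls. A plain element-wise mean stands in for the server's merge rule, which this page does not specify.

```python
class LifecycleSketch:
    """Hypothetical client recording the documented call order around model merging."""

    def __init__(self, local_params: list[float]):
        self.params = list(local_params)  # initially the weights of the get_model module
        self.calls: list[str] = []

    def fit(self, parameters: list, config: dict) -> tuple[list[float], int, dict]:
        self.calls.append("fit")
        return list(self.params), 1, {}

    def set_parameters(self, parameters: list[float], config: dict) -> None:
        # Expected exactly once: after merging on the server, before evaluation.
        assert self.calls == ["fit"], f"unexpected call order: {self.calls}"
        self.calls.append("set_parameters")
        self.params = list(parameters)

    def evaluate(self, parameters: list[float], config: dict) -> tuple[float, int, dict]:
        self.set_parameters(parameters, config)  # evaluate is the only caller
        self.calls.append("evaluate")
        return 0.0, 1, {"param_sum": sum(self.params)}


clients = [LifecycleSketch([1.0, 3.0]), LifecycleSketch([3.0, 5.0])]
sent = [c.fit([], {})[0] for c in clients]
merged = [sum(vals) / len(vals) for vals in zip(*sent)]  # simple mean as a stand-in merge
results = [c.evaluate(merged, {}) for c in clients]
print(merged)  # [2.0, 4.0]
print(results[0])  # (0.0, 1, {'param_sum': 6.0})
print(clients[0].calls)  # ['fit', 'set_parameters', 'evaluate']
```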