fl4health.clients.evaluate_client module
- class EvaluateClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, model_checkpoint_path=None, reporters=None, client_name=None)
Bases: BasicClient
- __init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, model_checkpoint_path=None, reporters=None, client_name=None)
This client implements an evaluation-only flow. That is, there is no expectation of parameter exchange with the server past the model initialization stage. The implementing client should instantiate a global model if one is expected from the server, which will be loaded using the passed parameters. If a model checkpoint path is provided, the client attempts to load a local model from the specified path.
- Parameters:
data_path (Path) – Path to the data to be loaded for client-side evaluation.
metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model
device (torch.device) – Device indicator for where to send the model, batches, labels, etc. Often ‘cpu’ or ‘cuda’.
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
model_checkpoint_path (Path | None, optional) – Path to a local model checkpoint. If provided, the client attempts to load a local model from this path. Defaults to None.
reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Defaults to None.
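As a rough illustration of what `LossMeterType.AVERAGE` implies, the sketch below combines per-batch losses into a single value by averaging them. `AverageLossMeter` is a hypothetical stand-in written for this example, not the fl4health loss meter, and it assumes the average is an unweighted mean over batches.

```python
class AverageLossMeter:
    """Hypothetical stand-in: tracks per-batch losses and reports their unweighted mean."""

    def __init__(self) -> None:
        self.batch_losses: list[float] = []

    def update(self, loss: float) -> None:
        # Record the loss computed on one batch.
        self.batch_losses.append(loss)

    def compute(self) -> float:
        # Unweighted mean over batches; 0.0 if nothing has been recorded yet.
        if not self.batch_losses:
            return 0.0
        return sum(self.batch_losses) / len(self.batch_losses)

    def clear(self) -> None:
        # Reset between evaluation rounds.
        self.batch_losses.clear()


meter = AverageLossMeter()
for batch_loss in [0.5, 0.25, 0.75]:
    meter.update(batch_loss)
print(meter.compute())  # 0.5
```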
- evaluate(parameters, config)
Evaluates the model on the validation set, and test set (if defined).
- Parameters:
parameters (NDArrays) – The parameters of the model to be evaluated.
config (Config) – The config object from the server.
- Returns:
- A loss associated with the evaluation, the number of samples in the validation/test set, and the metric_values associated with the evaluation.
- Return type:
tuple[float, int, dict[str, Scalar]]
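To make the shape of the return value concrete, here is a minimal, framework-free sketch of an evaluation loop that produces the same three-part result: an aggregate loss, a sample count, and a dictionary of metric values. `evaluate_sketch`, the toy threshold model, the 0/1 loss, and the `"accuracy"` key are all assumptions for illustration; the real client evaluates with its own loss and the metrics passed to the constructor.

```python
from typing import Callable, Sequence

# One batch = (inputs, labels); plain lists stand in for tensors in this sketch.
Batch = tuple[list[float], list[int]]


def evaluate_sketch(
    model: Callable[[float], int],
    batches: Sequence[Batch],
) -> tuple[float, int, dict[str, float]]:
    """Hypothetical evaluation loop returning (loss, num_samples, metrics)."""
    batch_losses: list[float] = []
    num_samples = 0
    num_correct = 0
    for inputs, labels in batches:
        if not labels:
            continue  # skip empty batches to avoid dividing by zero
        preds = [model(x) for x in inputs]
        # Illustrative 0/1 loss: fraction of wrong predictions in this batch.
        errors = [int(p != y) for p, y in zip(preds, labels)]
        batch_losses.append(sum(errors) / len(errors))
        num_samples += len(labels)
        num_correct += len(errors) - sum(errors)
    # Average the per-batch losses (assumed AVERAGE-style behavior).
    loss = sum(batch_losses) / len(batch_losses) if batch_losses else 0.0
    metrics = {"accuracy": num_correct / num_samples if num_samples else 0.0}
    return loss, num_samples, metrics


def threshold_model(x: float) -> int:
    # Toy "model": predict class 1 when the input exceeds 0.5.
    return int(x > 0.5)


batches = [([0.2, 0.9], [0, 1]), ([0.7, 0.1], [0, 0])]
print(evaluate_sketch(threshold_model, batches))  # (0.25, 4, {'accuracy': 0.75})
```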
- fit(parameters, config)
Processes config, initializes client (if first round) and performs training based on the passed config. If per_round_checkpointer is not None, on initialization the client checks if a checkpointed client state exists to load and at the end of each round the client state is saved.
- Parameters:
parameters (NDArrays) – The parameters of the model to be used in fit.
config (Config) – The config from the server.
- Returns:
The parameters following the local training along with the number of samples in the local training dataset and the computed metrics throughout the fit.
- Return type:
tuple[NDArrays, int, dict[str, Scalar]]
- Raises:
ValueError – If local_steps or local_epochs is not specified in config.
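The documented `ValueError` can be pictured as a check on the server config before any training happens. The helper below is hypothetical: it only mirrors the stated rule that `local_steps` or `local_epochs` must be present, and it does not reproduce any other validation the library may perform (for example, what happens if both keys are set).

```python
from typing import Union

# Mirrors the shape of Flower's Scalar type alias (an assumption for this sketch).
Scalar = Union[bool, bytes, float, int, str]


def require_training_length(config: dict[str, Scalar]) -> str:
    """Hypothetical check: return which training-length key the server config provides."""
    for key in ("local_steps", "local_epochs"):
        if key in config:
            return key
    # Neither key is present, so the client cannot tell how long to train.
    raise ValueError("Either local_steps or local_epochs must be specified in config.")


print(require_training_length({"local_epochs": 1, "batch_size": 32}))  # local_epochs
```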
- get_data_loader(config)
User-defined method that returns a PyTorch DataLoader for validation.
- Return type:
tuple[DataLoader]
- get_local_model(config)
Initializes a model from a local checkpoint, if a model checkpoint path was provided; otherwise returns None. This can be overridden for custom behavior.
- Return type:
Module | None
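The default behavior described above amounts to: return a model loaded from `model_checkpoint_path` if one was given, otherwise `None`. The sketch below uses `pickle` purely as a dependency-free stand-in for PyTorch checkpoint loading (a real override would more likely call `torch.load`), and `load_local_model` is a hypothetical name.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Optional


def load_local_model(model_checkpoint_path: Optional[Path]) -> Optional[Any]:
    """Hypothetical sketch: load a locally checkpointed model, or return None."""
    if model_checkpoint_path is None:
        # No checkpoint configured: there is no local model to evaluate.
        return None
    with open(model_checkpoint_path, "rb") as f:
        # Stand-in for torch.load(model_checkpoint_path); only unpickle trusted files.
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "local_model.pkl"
    ckpt.write_bytes(pickle.dumps({"layer.weight": [0.1, 0.2]}))
    print(load_local_model(ckpt))  # {'layer.weight': [0.1, 0.2]}
print(load_local_model(None))  # None
```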
- get_parameter_exchanger(config)
Parameter exchange is assumed to always be full for evaluation-only clients. If partial weights are exchanged during training, we assume that the checkpoint has been saved locally. However, this functionality may be overridden if a different exchanger is needed.
- Return type:
ParameterExchanger
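A full exchange means every model parameter is sent to the server and every parameter received overwrites the local copy. The class below is a hypothetical, dependency-free illustration of that idea, using plain Python lists keyed by layer name; it is not fl4health's own exchanger class, and its method names are assumptions.

```python
# Plain lists of floats stand in for NumPy arrays / tensors in this sketch.
Weights = list[float]


class FullExchangeSketch:
    """Hypothetical full exchanger: send every layer, overwrite every layer."""

    def push_parameters(self, model: dict[str, Weights]) -> list[Weights]:
        # Send all layers, in the model's (insertion-ordered) layer order.
        return [list(w) for w in model.values()]

    def pull_parameters(self, parameters: list[Weights], model: dict[str, Weights]) -> None:
        # A full exchange expects exactly one array per layer.
        if len(parameters) != len(model):
            raise ValueError("Full exchange expects exactly one array per layer.")
        # Overwrite every layer with the received values.
        for name, new_weights in zip(list(model), parameters):
            model[name] = list(new_weights)


model = {"conv.weight": [0.1, 0.2], "fc.bias": [0.0]}
exchanger = FullExchangeSketch()
print(exchanger.push_parameters(model))  # [[0.1, 0.2], [0.0]]
exchanger.pull_parameters([[1.0, 2.0], [3.0]], model)
print(model)  # {'conv.weight': [1.0, 2.0], 'fc.bias': [3.0]}
```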
- get_parameters(config)
Determines which parameters are sent back to the server for aggregation. This uses a parameter exchanger to determine parameters sent.
- Parameters:
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
- Returns:
- These are the parameters to be sent to the server. At minimum, they represent the relevant model parameters to be aggregated, but they can contain more information.
- Return type:
NDArrays
- initialize_global_model(config)
User-defined method that initializes a global model, which may then be hydrated by parameters sent by the server. By default, no global model is assumed to exist unless specified by the user.
- Return type:
Module | None
- set_parameters(parameters, config, fitting_round)
Sets the local model parameters transferred from the server using a parameter exchanger to coordinate how parameters are set. In the first fitting round, we assume the full model is being initialized and use the FullParameterExchanger() to set all model weights. Otherwise, we use the appropriate parameter exchanger defined by the user depending on the federated learning algorithm being used.
- Parameters:
parameters (NDArrays) – Parameters have information about model state to be added to the relevant client model but may contain more information than that.
config (Config) – The config is sent by the FL server to allow for customization in the function if desired.
fitting_round (bool) – Boolean that indicates whether the current federated learning round is a fitting round or an evaluation round. This is used to help determine which parameter exchange should be used for pulling parameters. A full parameter exchanger is only used if the current federated learning round is the very first fitting round.
- Return type:
None
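The exchanger choice described above reduces to a small selection rule: only the very first fitting round uses a full exchange, and every other call defers to the exchanger chosen for the federated learning algorithm. The function and the `first_fit_done` flag below are hypothetical names for this sketch; the library may track this state differently.

```python
from typing import TypeVar

E = TypeVar("E")


def choose_exchanger(fitting_round: bool, first_fit_done: bool, full: E, configured: E) -> E:
    """Hypothetical rule: full exchange only on the very first fitting round."""
    if fitting_round and not first_fit_done:
        # First fitting round: the whole model is being initialized from the server.
        return full
    # Later fitting rounds and all evaluation rounds: use the algorithm's exchanger.
    return configured


print(choose_exchanger(True, False, "full", "partial"))   # full
print(choose_exchanger(True, True, "full", "partial"))    # partial
print(choose_exchanger(False, False, "full", "partial"))  # partial
```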
- setup_client(config)
Sets data loaders, parameter exchangers, and other attributes for the client.
- Return type:
None
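Taken together, the user-overridable hooks suggest one plausible shape for an evaluation-only setup step: build the data loader, optionally build a global model, optionally load a local model, and pick the exchanger. `SetupSketch`, `MyEvalClient`, and their attributes are hypothetical and only loosely mirror the method names documented here; strings stand in for models, and the real `setup_client` may differ in ordering and in the attributes it sets.

```python
from typing import Any, Optional


class SetupSketch:
    """Hypothetical mirror of the evaluation-only setup hooks (not the fl4health class)."""

    def __init__(self, model_checkpoint_path: Optional[str] = None) -> None:
        self.model_checkpoint_path = model_checkpoint_path

    # Hooks a user would override.
    def get_data_loader(self, config: dict) -> tuple[list[Any]]:
        raise NotImplementedError("User-defined: return the validation data.")

    def initialize_global_model(self, config: dict) -> Optional[Any]:
        return None  # default: no global model expected from the server

    def get_local_model(self, config: dict) -> Optional[Any]:
        # Placeholder for loading a model from self.model_checkpoint_path.
        if self.model_checkpoint_path is None:
            return None
        return f"model@{self.model_checkpoint_path}"

    def get_parameter_exchanger(self, config: dict) -> str:
        return "full"  # evaluation-only clients default to full exchange

    # Setup flow that calls the hooks.
    def setup_client(self, config: dict) -> None:
        (self.data_loader,) = self.get_data_loader(config)
        self.global_model = self.initialize_global_model(config)
        self.local_model = self.get_local_model(config)
        self.parameter_exchanger = self.get_parameter_exchanger(config)


class MyEvalClient(SetupSketch):
    def get_data_loader(self, config: dict) -> tuple[list[Any]]:
        return ([("x0", 0), ("x1", 1)],)


client = MyEvalClient(model_checkpoint_path="local.ckpt")
client.setup_client({})
print(client.global_model, client.local_model)  # None model@local.ckpt
```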