florist.api.clients.mnist module

Implementation of the MNIST client and model.

class MnistClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpointer=None, reporters=None, progress_bar=False, intermediate_client_state_dir=None, client_name=None)

Bases: BasicClient

Implementation of the MNIST client.

get_criterion(config)

Return the loss function for MNIST data.

Parameters:

config (Dict[str, Union[bool, bytes, float, int, str]]) – (Config) the Config object for this client.

Return type:

_Loss

Returns:

(torch.nn.modules.loss._Loss) An instance of torch.nn.CrossEntropyLoss.
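
A minimal sketch of what the returned criterion computes, assuming only standard PyTorch. The logits and labels are made-up illustration values, not FLorist data:

```python
import torch

# Same loss class that get_criterion is documented to return.
criterion = torch.nn.CrossEntropyLoss()

# Hypothetical batch: 4 samples, 10 classes (one per MNIST digit).
logits = torch.zeros(4, 10)          # identical score for every class
labels = torch.tensor([0, 3, 7, 9])  # integer class indices

# CrossEntropyLoss takes raw logits and averages over the batch by default.
loss = criterion(logits, labels)

# With identical logits, every class gets probability 1/10,
# so the loss is ln(10), roughly 2.3026.
print(loss.item())
```

Because the criterion applies log-softmax internally, a model paired with it would normally output raw logits rather than probabilities.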

get_data_loaders(config)

Return the training and validation data loaders for MNIST data.

Parameters:

config (Dict[str, Union[bool, bytes, float, int, str]]) – (Config) the Config object for this client.

Return type:

Tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]

Returns:

(Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]) A tuple containing the train data loader and the validation data loader, in that order.
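
To illustrate the shape of the return value, here is a sketch that builds a (train, validation) pair of loaders from random tensors. The data, the 80/20 split, and the batch size are all assumptions for illustration; the actual client may load and split its data differently:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 100 grayscale 28x28 images with digit labels.
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))

# Assumed 80/20 split and batch size of 32 (not taken from the FLorist source).
train_set = TensorDataset(images[:80], labels[:80])
val_set = TensorDataset(images[80:], labels[80:])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)

# Same ordering as the documented return value: train first, validation second.
loaders = (train_loader, val_loader)

for batch_images, batch_labels in val_loader:
    print(batch_images.shape, batch_labels.shape)
```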

get_model(config)

Return the model for MNIST data.

Parameters:

config (Dict[str, Union[bool, bytes, float, int, str]]) – (Config) the Config object for this client.

Return type:

Module

Returns:

(torch.nn.Module) An instance of florist.api.clients.mnist.MnistNet.
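
The architecture of MnistNet is not described in this section. The sketch below uses a hypothetical stand-in, ExampleNet, only to show the kind of torch.nn.Module a get_model implementation could return for 28x28 single-channel inputs and 10 classes. The layer sizes are assumptions, not the real MnistNet:

```python
import torch
import torch.nn.functional as F
from torch import nn


class ExampleNet(nn.Module):
    """Hypothetical small CNN; NOT the actual MnistNet."""

    def __init__(self) -> None:
        super().__init__()
        # 1 input channel -> 8 feature maps; 28x28 becomes 26x26.
        self.conv = nn.Conv2d(1, 8, kernel_size=3)
        # After 2x2 max pooling the feature maps are 13x13.
        self.fc = nn.Linear(8 * 13 * 13, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.max_pool2d(F.relu(self.conv(x)), 2)
        # Flatten everything except the batch dimension, then emit 10 logits.
        return self.fc(torch.flatten(x, 1))


model = ExampleNet()
logits = model(torch.rand(2, 1, 28, 28))  # batch of 2 fake images
print(logits.shape)
```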

get_optimizer(config)

Return the optimizer for MNIST data.

Parameters:

config (Dict[str, Union[bool, bytes, float, int, str]]) – (Config) the Config object for this client.

Return type:

Optimizer

Returns:

(torch.optim.Optimizer) An instance of torch.optim.SGD with a learning rate of 0.001 and a momentum of 0.9.
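
A sketch of the optimizer configuration described above, applied to a single training step. The linear model and the input batch are hypothetical placeholders; only the SGD hyperparameters come from this documentation:

```python
import torch

# Hypothetical placeholder model; the real client optimizes its own model.
model = torch.nn.Linear(4, 2)

# Hyperparameters as documented: learning rate 0.001, momentum 0.9.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# One illustrative training step on made-up data.
inputs = torch.ones(3, 4)
targets = torch.tensor([0, 1, 0])
loss = torch.nn.CrossEntropyLoss()(model(inputs), targets)

optimizer.zero_grad()  # clear gradients left over from any previous step
loss.backward()        # compute gradients of the loss
optimizer.step()       # apply the SGD-with-momentum update
```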