Source code for florist.api.clients.mnist

"""Implementation of the MNIST client and model."""

from typing import Tuple

import torch
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.dataset import TensorDataset
from fl4health.utils.load_data import load_mnist_data
from flwr.common.typing import Config
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from florist.api.models.mnist import MnistNet


class MnistClient(BasicClient):  # type: ignore[misc]
    """MNIST client: supplies the data loaders, model, optimizer and loss for MNIST."""

    def get_data_loaders(self, config: Config) -> Tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]:
        """
        Build the train and validation data loaders for the MNIST data.

        The data is loaded from ``self.data_path`` using the batch size found under
        the ``"batch_size"`` key of the config (converted to ``int``).

        :param config: (Config) the Config object for this client. Must contain a
            ``"batch_size"`` entry.
        :return: (Tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]) the train
            data loader followed by the validation data loader.
        """
        # LeCun's website is dropped from the MNIST download mirrors because it is
        # timing out and adding considerable time to our tests.
        lecun_mirror = "http://yann.lecun.com/exdb/mnist/"
        if lecun_mirror in MNIST.mirrors:
            MNIST.mirrors.remove(lecun_mirror)

        data_path = self.data_path
        batch_size = int(config["batch_size"])
        # The third element returned by load_mnist_data is not needed here.
        train_loader, val_loader, _unused = load_mnist_data(data_path, batch_size=batch_size)
        return train_loader, val_loader

    def get_model(self, config: Config) -> nn.Module:
        """
        Create the model used for the MNIST data.

        :param config: (Config) the Config object for this client. Not used by this method.
        :return: (torch.nn.Module) a new instance of florist.api.models.mnist.MnistNet.
        """
        model: nn.Module = MnistNet()
        return model

    def get_optimizer(self, config: Config) -> Optimizer:
        """
        Create the optimizer used for the MNIST data.

        :param config: (Config) the Config object for this client. Not used by this method.
        :return: (torch.optim.Optimizer) an instance of torch.optim.SGD over the parameters
            of ``self.model``, with learning rate of 0.001 and momentum of 0.9.
        """
        model_parameters = self.model.parameters()
        return torch.optim.SGD(model_parameters, lr=0.001, momentum=0.9)

    def get_criterion(self, config: Config) -> _Loss:
        """
        Create the loss used for the MNIST data.

        :param config: (Config) the Config object for this client. Not used by this method.
        :return: (torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss.
        """
        criterion: _Loss = nn.CrossEntropyLoss()
        return criterion