# Source code for florist.api.models.mnist
"""Definitions for the MNIST model."""
from pathlib import Path
from typing import Optional
import torch
import torch.nn.functional as f
from fl4health.utils.dataset import TensorDataset
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.sampler import LabelBasedSampler
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from florist.api.models.abstract import LocalDataModel
class MnistNet(LocalDataModel):
    """
    Implementation of the Mnist model.

    A small convolutional network: two conv -> ReLU -> max-pool stages followed by two fully
    connected layers producing 10 class scores. The layer sizes assume single-channel 28x28
    inputs: 28 -> 24 (5x5 conv) -> 12 (2x2 pool) -> 8 (5x5 conv) -> 4 (2x2 pool), i.e.
    16 feature maps of 4x4 entering the classifier head.
    """

    # Size of the flattened conv output (16 channels of 4x4 maps). Shared by ``fc1`` and the
    # flattening step in ``forward`` so the two cannot drift apart.
    _FLAT_FEATURES = 16 * 4 * 4

    def __init__(self) -> None:
        """Initialize an instance of MnistNet and build its layers."""
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 16, 5)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform a forward pass for the given tensor.

        :param x: (torch.Tensor) single-channel 28x28 images, e.g. a batch of shape (N, 1, 28, 28).
        :return: (torch.Tensor) unnormalized class scores (logits) of shape (N, 10).
        """
        x = self.pool(f.relu(self.conv1(x)))
        x = self.pool(f.relu(self.conv2(x)))
        # reshape (unlike view) also works on non-contiguous tensors; the -1 keeps inferring
        # the batch dimension exactly as before.
        x = x.reshape(-1, self._FLAT_FEATURES)
        x = f.relu(self.fc1(x))
        # Return raw logits: the criterion (CrossEntropyLoss) applies log-softmax itself. A ReLU
        # here would clamp every negative score to 0, so wrong classes could never be pushed
        # below zero and a sample whose scores were all negative would receive no gradient.
        return self.fc2(x)

    def get_data_loaders(
        self,
        data_path: Path,
        batch_size: int,
        sampler: Optional[LabelBasedSampler] = None,
    ) -> tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]:
        """
        Return the train and validation data loaders for MNIST data.

        Side effect: removes LeCun's website from torchvision's class-level ``MNIST.mirrors``
        list, which affects every MNIST download in the current process.

        :param data_path: (Path) the local path of the data.
        :param batch_size: (int) the batch size for training.
        :param sampler: (Optional[LabelBasedSampler]) the sampler to be used to sample data;
            forwarded as-is to ``load_mnist_data``.
        :return: (tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]) a tuple with the
            train data loader and validation data loader respectively.
        """
        # Removing LeCun's website from the list of mirrors to pull MNIST dataset from
        # as it is timing out and adding considerable time to our tests.
        mirror_url_to_remove = "http://yann.lecun.com/exdb/mnist/"
        try:
            # A single remove() avoids the check-then-remove race on this shared list.
            MNIST.mirrors.remove(mirror_url_to_remove)
        except ValueError:
            # Mirror already absent (e.g. removed by an earlier call): nothing to do.
            pass
        # The third value returned by load_mnist_data is not needed here and is discarded.
        train_loader, val_loader, _ = load_mnist_data(data_path, batch_size, sampler)
        return train_loader, val_loader

    def get_criterion(self) -> _Loss:
        """
        Return the loss for MNIST model.

        :return: (torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss, which
            expects the raw logits produced by ``forward``.
        """
        return nn.CrossEntropyLoss()