Source code for florist.api.models.mnist

"""Definitions for the MNIST model."""

from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as f
from fl4health.utils.dataset import TensorDataset
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.sampler import LabelBasedSampler
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from florist.api.models.abstract import LocalDataModel


class MnistNet(LocalDataModel):
    """
    Implementation of the Mnist model.

    A small convolutional network producing 10 raw, unnormalized class scores (logits)
    per sample. Logits are what the criterion returned by ``get_criterion``
    (``torch.nn.CrossEntropyLoss``) expects as input.
    """

    def __init__(self) -> None:
        """Initialize an instance of MnistNet."""
        super().__init__()
        # With no padding, a 1x28x28 input becomes 8x24x24 after conv1 and 8x12x12 after pooling.
        self.conv1 = nn.Conv2d(1, 8, 5)
        self.pool = nn.MaxPool2d(2, 2)
        # 8x12x12 becomes 16x8x8 after conv2 and 16x4x4 after pooling, hence fc1's 16 * 4 * 4 inputs.
        self.conv2 = nn.Conv2d(8, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform a forward pass for the given tensor.

        :param x: (torch.Tensor) the tensor to perform the forward pass on. The layer sizes
            assume single-channel 28x28 images (as in MNIST), presumably batched as
            (batch, 1, 28, 28).
        :return: (torch.Tensor) unnormalized class logits, one row of 10 scores per sample.
        """
        x = self.pool(f.relu(self.conv1(x)))
        x = self.pool(f.relu(self.conv2(x)))
        # NOTE(review): view(-1, ...) infers the batch size, so an input with an unexpected
        # spatial size may be silently regrouped rather than raising; confirm inputs are 28x28.
        x = x.view(-1, 16 * 4 * 4)
        x = f.relu(self.fc1(x))
        # No activation on the output layer: CrossEntropyLoss applies log-softmax itself,
        # and a ReLU here would clamp every negative logit to 0 and kill its gradient.
        return self.fc2(x)

    def get_data_loaders(
        self,
        data_path: Path,
        batch_size: int,
        sampler: Optional[LabelBasedSampler] = None,
    ) -> tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]:
        """
        Return the train and validation data loaders for MNIST data.

        :param data_path: (Path) the local path of the data.
        :param batch_size: (int) the batch size for training.
        :param sampler: (Optional[LabelBasedSampler]) the sampler to be used to sample data.
        :return: (tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]) a tuple with the
            train data loader and validation data loader respectively.
        """
        # Removing LeCun's website from the list of mirrors to pull MNIST dataset from
        # as it is timing out and adding considerable time to our tests.
        # NOTE: this mutates the class attribute ``MNIST.mirrors``, so it affects every
        # MNIST download in this process, not only this call.
        mirror_url_to_remove = "http://yann.lecun.com/exdb/mnist/"
        if mirror_url_to_remove in MNIST.mirrors:
            MNIST.mirrors.remove(mirror_url_to_remove)

        # load_mnist_data returns three values; the third one is not needed here.
        train_loader, val_loader, _ = load_mnist_data(data_path, batch_size, sampler)
        return train_loader, val_loader

    def get_criterion(self) -> _Loss:
        """
        Return the loss for MNIST model.

        :return: (torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss.
        """
        return torch.nn.CrossEntropyLoss()