Source code for florist.api.models.mnist

"""Definitions for the MNIST model."""

import torch
import torch.nn.functional as f
from torch import nn


class MnistNet(nn.Module):
    """
    Small convolutional network for MNIST-sized images.

    The network is two 5x5 convolutions, each followed by ReLU and 2x2 max
    pooling, and then two fully connected layers producing 10 outputs.
    """

    def __init__(self) -> None:
        """Build the convolutional, pooling and fully connected layers."""
        super().__init__()
        # Layers are created in a fixed order (conv1, pool, conv2, fc1, fc2) so
        # that parameter ordering and seeded initialization stay stable.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5)
        # 16 channels * 4 * 4 spatial positions: a 28x28 input shrinks to
        # 24 -> 12 -> 8 -> 4 through the two conv/pool stages.
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the network on a batch of inputs.

        :param x: (torch.Tensor) the tensor to perform the forward pass on.
        :return: (torch.Tensor) a result tensor after the forward pass. ReLU is
            applied to the last layer as well, so every value is non-negative.
        """
        # Both convolutional stages share the same conv -> ReLU -> pool shape.
        for conv in (self.conv1, self.conv2):
            x = self.pool(f.relu(conv(x)))

        # Flatten to (-1, 256); the row count is inferred from the element count.
        flattened = x.view(-1, self.fc1.in_features)
        hidden = f.relu(self.fc1(flattened))
        return f.relu(self.fc2(hidden))