# Source code for florist.api.models.mnist
"""Definitions for the MNIST model."""
import torch
import torch.nn.functional as f
from torch import nn
# [docs]  (Sphinx viewcode link marker left over from HTML extraction; not part of the module)
class MnistNet(nn.Module):
    """
    Implementation of the Mnist model.

    The network is two 5x5 convolutions (1 -> 8 -> 16 channels), each followed by a
    ReLU and a 2x2 max-pool, and then two fully connected layers (256 -> 120 -> 10).
    """

    def __init__(self) -> None:
        """Initialize an instance of MnistNet."""
        super().__init__()
        # Layers are created in a fixed order so that state_dict keys and seeded
        # parameter initialisation stay stable.
        self.conv1 = nn.Conv2d(1, 8, 5)
        # A single pooling module is shared by both convolution stages.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 16, 5)
        # fc1 expects the flattened conv2/pool output: 16 channels of 4x4 features.
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform a forward pass for the given tensor.

        :param x: (torch.Tensor) the tensor to perform the forward pass on. It must have
            a single channel and a spatial size that reduces to 4x4 after the two
            convolution/pool stages (e.g. a batch of 1x28x28 images).
        :return: (torch.Tensor) a result tensor after the forward pass, with 10 values
            per flattened sample. The values are non-negative because a ReLU is applied
            to the output layer as well.
        """
        x = self.pool(f.relu(self.conv1(x)))
        x = self.pool(f.relu(self.conv2(x)))
        # Flatten every sample to the 256 features fc1 was sized for.
        x = x.view(-1, 16 * 4 * 4)
        x = f.relu(self.fc1(x))
        return f.relu(self.fc2(x))