florist.api.models.mnist module

Definitions for the MNIST model.

class MnistNet

Bases: LocalDataModel

Implementation of the MNIST model.

__init__()

Initialize an instance of MnistNet.

forward(x)

Perform a forward pass for the given tensor.

Parameters:

x (torch.Tensor) – the tensor to perform the forward pass on.

Return type:

Tensor

Returns:

(torch.Tensor) the tensor produced by the forward pass.
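As a rough illustration of calling forward, the sketch below runs a batch through a hypothetical stand-in module. StandInNet is not the real MnistNet architecture (which this page does not describe), and the input shape (N, 1, 28, 28) and the 10-class output are assumptions based on the usual MNIST layout.

```python
import torch
from torch import nn


class StandInNet(nn.Module):
    """Hypothetical stand-in; NOT the actual MnistNet layers."""

    def __init__(self) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        # Assumed: 28x28 grayscale images mapped to 10 class scores.
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.flatten(x))


model = StandInNet()
batch = torch.randn(4, 1, 28, 28)  # assumed input shape (N, C, H, W)

# Calling the module invokes forward(); no_grad() skips gradient tracking.
with torch.no_grad():
    logits = model(batch)

# One predicted class index per image in the batch.
predictions = logits.argmax(dim=1)
```

With the real class, the call would presumably look the same (model = MnistNet(); model(batch)), provided the input shape matches what the model expects.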

get_criterion()

Return the loss criterion for the MNIST model.

Return type:

_Loss

Returns:

(torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss.
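For readers unfamiliar with the criterion, the following pure-Python sketch shows the quantity that torch.nn.CrossEntropyLoss computes with its default settings: the negative log-softmax of each row of logits at the target index, averaged over the batch. The function name cross_entropy and the sample values are illustrative only and are not part of florist.

```python
import math


def cross_entropy(logits, targets):
    """Mean of -log(softmax(row)[target]) over a batch (illustrative helper)."""
    total = 0.0
    for row, target in zip(logits, targets):
        # Subtract the row maximum before exponentiating for numerical stability.
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)


# Equal logits over 4 classes give a loss of log(4).
uniform_loss = cross_entropy([[0.0, 0.0, 0.0, 0.0]], [0])

# A large logit on the correct class gives a loss close to zero.
confident_loss = cross_entropy([[10.0, 0.0, 0.0, 0.0]], [0])
```

In a training loop, the object returned by get_criterion() would typically be called as criterion(logits, targets), with logits of shape (N, C) and integer class targets of shape (N,).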

get_data_loaders(data_path, batch_size, sampler=None)

Return the training and validation data loaders for the MNIST data.

Parameters:
  • data_path (Path) – the local path of the data.

  • batch_size (int) – the batch size for training.

  • sampler (Optional[LabelBasedSampler]) – the sampler used to sample the data, if any.

Return type:

tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]

Returns:

A tuple containing the training data loader and the validation data loader, in that order.
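To show the shape of the return value, here is a minimal sketch that builds a (train, validation) pair of DataLoader[TensorDataset] objects from in-memory tensors. It is not the florist implementation: the helper make_loaders, the val_fraction parameter, and the random data are hypothetical, and it neither reads from data_path nor applies a sampler.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loaders(images, labels, batch_size, val_fraction=0.2):
    """Hypothetical helper: split tensors into train and validation loaders."""
    n_val = int(len(images) * val_fraction)
    val_ds = TensorDataset(images[:n_val], labels[:n_val])
    train_ds = TensorDataset(images[n_val:], labels[n_val:])
    # Shuffle only the training data; validation order does not matter.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    return train_loader, val_loader


# Random stand-in data with the usual MNIST image shape (assumed).
images = torch.randn(50, 1, 28, 28)
labels = torch.randint(0, 10, (50,))

# Unpack in the documented order: training loader first, then validation.
train_loader, val_loader = make_loaders(images, labels, batch_size=16)
first_images, first_labels = next(iter(train_loader))
```

The tuple returned by get_data_loaders can presumably be unpacked the same way, with each loader yielding (images, labels) batches.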