florist.api.models.mnist module
Definitions for the MNIST model.
- class MnistNet
Bases: LocalDataModel
Implementation of the MNIST model.
- forward(x)
Perform a forward pass for the given tensor.
- Parameters:
x (Tensor) – the tensor to perform the forward pass on.
- Return type:
Tensor
- Returns:
(torch.Tensor) the output tensor of the forward pass.
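Because the model's criterion is torch.nn.CrossEntropyLoss, which expects raw, unnormalized class scores, the output of forward is presumably a batch of logits: one row per image and, assuming the usual ten digit classes, ten columns. The minimal sketch below, written in plain Python rather than PyTorch so it runs without dependencies, shows how such logits would typically be turned into predicted digits. The nested-list layout and the predict_digits helper are hypothetical stand-ins for a (batch, 10) tensor and tensor.argmax(dim=1).

```python
from typing import List


def predict_digits(logits: List[List[float]]) -> List[int]:
    """Hypothetical helper: return the index of the largest score in each row.

    Assumes each row holds one unnormalized score per digit class (0-9),
    mirroring what `tensor.argmax(dim=1)` computes on a (batch, 10) tensor.
    """
    predictions = []
    for row in logits:
        # Pick the column index whose score is largest (first one wins on ties).
        predictions.append(max(range(len(row)), key=lambda i: row[i]))
    return predictions


# Two fake rows of logits, as a forward pass over a batch of 2 images might produce.
batch_logits = [
    [0.1, 2.5, -1.0, 0.0, 0.3, 0.2, -0.5, 0.9, 0.0, 0.1],  # largest at index 1
    [1.2, 0.0, 0.4, 3.1, 0.0, 0.7, 0.2, -0.3, 0.8, 0.5],   # largest at index 3
]
print(predict_digits(batch_logits))  # [1, 3]
```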
- get_criterion()
Return the loss function for the MNIST model.
- Return type:
_Loss
- Returns:
(torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss.
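torch.nn.CrossEntropyLoss combines a log-softmax with a negative log-likelihood loss and, with its default reduction='mean', averages the per-sample losses over the batch. The dependency-free sketch below reproduces that computation on plain Python lists to show what the returned criterion measures. It ignores the weight, ignore_index and label_smoothing options that the real loss supports, and cross_entropy is a hypothetical name, not part of this module.

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], targets: List[int]) -> float:
    """Hypothetical sketch of mean cross-entropy, as CrossEntropyLoss computes by default.

    Per sample: loss = logsumexp(row) - row[target]; the result is the batch mean.
    """
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max before exponentiating for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)


# Uniform logits over 10 classes give a loss of log(10) whatever the target is.
print(round(cross_entropy([[0.0] * 10], [7]), 4))  # 2.3026
```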
- get_data_loaders(data_path, batch_size, sampler=None)
Return the train and validation data loaders for the MNIST data.
- Parameters:
data_path, batch_size, sampler (defaults to None)
- Return type:
tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]
- Returns:
(Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]) a tuple with the train data loader and validation data loader respectively.
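Since the method returns a pair, callers would normally unpack it in the documented order: train loader first, validation loader second. The sketch below mimics that contract in plain Python, assuming sequential (unshuffled) batching that keeps a smaller final batch, which matches the torch.utils.data.DataLoader defaults shuffle=False and drop_last=False. The helpers batches and get_loaders_sketch are hypothetical; they do not load MNIST and do not use data_path or sampler.

```python
from typing import Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def batches(samples: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive batches, keeping a smaller final batch (like drop_last=False)."""
    for start in range(0, len(samples), batch_size):
        yield list(samples[start:start + batch_size])


def get_loaders_sketch(
    train: Sequence[T], val: Sequence[T], batch_size: int
) -> Tuple[List[List[T]], List[List[T]]]:
    """Hypothetical stand-in returning (train batches, validation batches), in that order."""
    return list(batches(train, batch_size)), list(batches(val, batch_size))


train_loader, val_loader = get_loaders_sketch(list(range(10)), list(range(4)), batch_size=4)
print([len(b) for b in train_loader])  # [4, 4, 2]
print([len(b) for b in val_loader])    # [4]
```

With the documented method, the equivalent call would be train_loader, val_loader = model.get_data_loaders(data_path, batch_size), with argument values that depend on your setup.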