florist.api.models.abstract module

Abstract model classes.

class LocalDataModel(*args, **kwargs)

Bases: Module, ABC

Abstract class for a model that has its data stored locally.

abstract get_criterion()

Return the loss function for this model.

Return type:

_Loss

Returns:

The loss function for this model, an instance of torch.nn.modules.loss._Loss.

abstract get_data_loaders(data_path, batch_size, sampler=None)

Return the train and validation data loaders for the model, built from its local data.

Parameters:
  • data_path (Path) – the local path of the data.

  • batch_size (int) – the batch size for training.

  • sampler (Optional[LabelBasedSampler]) – the sampler used to sample the data; defaults to None.

Return type:

tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]

Returns:

A tuple containing the train data loader and the validation data loader, in that order.
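The following sketch shows one way the body of get_data_loaders could be written. It is a free function so that it runs standalone. Several details are hypothetical and are not part of the documented API: the function name make_data_loaders, the val_fraction parameter, the on-disk format (a dict {"x": ..., "y": ...} saved with torch.save as data.pt), and the sampler interface. The sampler is assumed to expose subsample(dataset) returning a TensorDataset, so check the real LabelBasedSampler before relying on this. The function splits the data by slicing tensors rather than with random_split, so both loaders wrap TensorDataset objects, matching the documented return type.

```python
from __future__ import annotations

import tempfile
from pathlib import Path
from typing import Optional

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_data_loaders(
    data_path: Path,
    batch_size: int,
    sampler: Optional[object] = None,
    val_fraction: float = 0.2,
) -> tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]:
    # Assumed file format: {"x": features, "y": labels} saved with torch.save.
    data = torch.load(data_path / "data.pt")
    dataset = TensorDataset(data["x"], data["y"])
    if sampler is not None:
        # Assumption: the sampler exposes subsample(dataset) -> TensorDataset.
        dataset = sampler.subsample(dataset)

    x, y = dataset.tensors
    # Shuffle once with a fixed seed, then slice so both splits stay TensorDatasets.
    perm = torch.randperm(len(dataset), generator=torch.Generator().manual_seed(42))
    n_val = int(len(dataset) * val_fraction)
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    train_ds = TensorDataset(x[train_idx], y[train_idx])
    val_ds = TensorDataset(x[val_idx], y[val_idx])
    # Shuffle training batches each epoch; keep validation order fixed.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    return train_loader, val_loader


# Usage: write 50 random samples to a temporary directory, then load them.
with tempfile.TemporaryDirectory() as tmp:
    samples = {"x": torch.randn(50, 4), "y": torch.randint(0, 2, (50,))}
    torch.save(samples, Path(tmp) / "data.pt")
    train_loader, val_loader = make_data_loaders(Path(tmp), batch_size=8)
```

With 50 samples and the default val_fraction, the sketch places 40 samples in the train loader and 10 in the validation loader.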