Source code for florist.api.models.models

"""Functions and definitions for models and the Model enumeration."""

from enum import Enum

from florist.api.models.abstract import LocalDataModel
from florist.api.models.mnist import MnistNet


class Model(Enum):
    """
    Enumeration of the models supported by the API.

    Each member's value is the model's name as a string.
    """

    MNIST = "MNIST"

    def get_model_class(self) -> type[LocalDataModel]:
        """
        Look up the implementation class for this enumeration member.

        :return: (type[LocalDataModel]) the LocalDataModel class that corresponds
            to this model.
        :raises ValueError: if there is no class associated with this model.
        """
        # Guard clause: any member without a known implementation is rejected.
        if self is not Model.MNIST:
            raise ValueError(f"Model {self.value} not supported.")
        return MnistNet

    @classmethod
    def list(cls) -> list[str]:
        """
        Enumerate the values of every supported model.

        :return: (list[str]) the values of all members, in definition order.
        """
        return [member.value for member in cls]