# Source code for florist.api.servers.common
"""Common functions and definitions for servers."""
from enum import Enum
from typing import List
from torch import nn
from florist.api.models.mnist import MnistNet
class Model(Enum):
    """Enumeration of supported models."""

    MNIST = "MNIST"

    @classmethod
    def class_for_model(cls, model: "Model") -> type[nn.Module]:
        """
        Return the class for a given model.

        :param model: (Model) The model enumeration object.
        :return: (type[torch.nn.Module]) A torch.nn.Module class corresponding to the given model.
        :raises ValueError: if the model is not supported.
        """
        if model == Model.MNIST:
            return MnistNet

        # Anything that is not a Model member (e.g. a raw string) has no
        # ``.value``; report it as-is so the documented ValueError is raised
        # instead of an AttributeError.
        unsupported = model.value if isinstance(model, Model) else model
        raise ValueError(f"Model {unsupported} not supported.")

    @classmethod
    def list(cls) -> List[str]:
        """
        List all the supported models.

        :return: (List[str]) a list with the value of each supported model, in definition order.
        """
        return [member.value for member in cls]