Source code for florist.api.servers.config_parsers

"""Parsers for FL server configurations."""

import json
from ast import literal_eval
from contextlib import suppress
from enum import Enum
from typing import Any, Dict, List


class BasicConfigParser:
    """Parser for basic server configurations."""

    @classmethod
    def mandatory_fields(cls) -> List[str]:
        """
        Define the mandatory fields for basic configuration, namely `n_server_rounds`, `batch_size` and `local_epochs`.

        :return: (List[str]) the list of fields for basic server configuration.
        """
        return ["n_server_rounds", "batch_size", "local_epochs"]

    @classmethod
    def parse(cls, config_json_str: str) -> Dict[str, Any]:
        """
        Parse a configuration JSON string into a dictionary.

        String values that hold a Python literal are converted with `ast.literal_eval`,
        so "5" becomes 5 and "0.1" becomes 0.1. Note that any literal is converted, not
        only numbers: "True" becomes True, "None" becomes None and "[1, 2]" becomes a list.
        Values that are not literals (e.g. "abc") or are not strings are left unchanged.

        :param config_json_str: (str) the configuration JSON string
        :return: (Dict[str, Any]) The configuration JSON string parsed as a dictionary.
        :raises json.JSONDecodeError: if `config_json_str` is not valid JSON.
        :raises AssertionError: if the JSON does not decode to an object (dictionary).
        :raises IncompleteConfigError: if a field returned by `mandatory_fields()` is missing;
            only the first missing field is reported.
        """
        config = json.loads(config_json_str)
        # Explicit check instead of `assert` so the validation is not stripped under
        # `python -O`; AssertionError is kept so callers that already catch it still work.
        if not isinstance(config, dict):
            raise AssertionError("config is not a dictionary")

        for config_name in config:
            # Best effort: literal_eval raises (ValueError, SyntaxError, TypeError, ...) for any
            # value that is not a string holding a literal; in that case the value is kept as is.
            # Only existing keys are reassigned, so mutating the dict while iterating is safe.
            # NOTE(review): literal_eval does not execute code, but very large or deeply nested
            # strings can be expensive to parse -- confirm incoming config size is bounded.
            with suppress(Exception):
                config[config_name] = literal_eval(config[config_name])

        for mandatory_field in cls.mandatory_fields():
            if mandatory_field not in config:
                raise IncompleteConfigError(f"Server config does not contain '{mandatory_field}'")

        return config
class ConfigParser(Enum):
    """Enum to define the types of server configuration parsers."""

    BASIC = "BASIC"

    @classmethod
    def class_for_parser(cls, config_parser: "ConfigParser") -> type[BasicConfigParser]:
        """
        Return the class for a given config parser.

        :param config_parser: (ConfigParser) The config parser enumeration instance.
        :return: (type[BasicConfigParser]) A subclass of BasicConfigParser corresponding to the given config parser.
        :raises ValueError: if the config_parser is not supported.
        """
        if config_parser is cls.BASIC:
            return BasicConfigParser
        # Anything reaching this point is not a supported member and may not be a ConfigParser
        # at all (e.g. a raw string), so it may have no `.value`. Fall back to the object itself
        # so the documented ValueError is raised instead of an AttributeError.
        unsupported = getattr(config_parser, "value", config_parser)
        raise ValueError(f"Config parser {unsupported} not supported.")

    @classmethod
    def list(cls) -> List[str]:
        """
        List all the supported config parsers.

        :return: (List[str]) a list of supported config parsers.
        """
        return [config_parser.value for config_parser in cls]
class IncompleteConfigError(Exception):
    """Signals that a server config string is missing information it is required to contain."""