# Source code for florist.api.servers.config_parsers
"""Parsers for FL server configurations."""
import json
from ast import literal_eval
from contextlib import suppress
from enum import Enum
from typing import Any, Dict, List
from typing_extensions import Self
[docs]
class BasicConfigParser:
    """Parser for basic server configurations."""

    @classmethod
    def mandatory_fields(cls) -> List[str]:
        """
        Define the mandatory fields for basic server configuration.

        Namely: `n_server_rounds`, `batch_size` and `local_epochs`.

        :return: (List[str]) the list of required fields for basic server configuration.
        """
        return ["n_server_rounds", "batch_size", "local_epochs"]

    @classmethod
    def parse(cls, config_json_str: str) -> Dict[str, Any]:
        """
        Parse a configuration JSON string into a dictionary.

        Top-level string values that are valid Python literals (e.g. `"5"`, `"0.1"`, `"True"`)
        are replaced by the value they evaluate to; every other value is left as decoded by
        `json.loads`. Errors raised by `json.loads` for invalid JSON are not caught here.

        :param config_json_str: (str) the configuration JSON string
        :return: (Dict[str, Any]) The configuration JSON string parsed as a dictionary.
        :raises AssertionError: if the decoded JSON document is not a dictionary.
        :raises IncompleteConfigError: if any of the fields returned by `mandatory_fields()`
            is missing from the configuration.
        """
        config = json.loads(config_json_str)
        # Raised explicitly instead of using `assert` so the check is not stripped under `python -O`.
        if not isinstance(config, dict):
            raise AssertionError("config is not a dictionary")

        for config_name, config_value in config.items():
            # `literal_eval` only parses strings (or AST nodes); any other JSON value
            # is already in its final form, so leave it as is.
            if not isinstance(config_value, str):
                continue
            # These are the exceptions `literal_eval` is documented to raise for input that
            # is not a valid literal. In that case the value stays as the original string.
            with suppress(ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                config[config_name] = literal_eval(config_value)

        for mandatory_field in cls.mandatory_fields():
            if mandatory_field not in config:
                raise IncompleteConfigError(f"Server config does not contain '{mandatory_field}'")

        return config
[docs]
class FedProxConfigParser(BasicConfigParser):
    """Parser for FedProx server configurations."""

    @classmethod
    def mandatory_fields(cls) -> List[str]:
        """
        List the fields a FedProx server configuration must contain.

        The result is the parent class's mandatory fields followed by the FedProx-specific
        ones: `adapt_proximal_weight`, `initial_proximal_weight`, `proximal_weight_delta`
        and `proximal_weight_patience`.

        :return: (List[str]) the list of required fields for FedProx server configuration.
        """
        fedprox_fields = [
            "adapt_proximal_weight",
            "initial_proximal_weight",
            "proximal_weight_delta",
            "proximal_weight_patience",
        ]
        return super().mandatory_fields() + fedprox_fields
[docs]
class ConfigParser(Enum):
    """Enum to define the types of server configuration parsers."""

    BASIC = "BASIC"
    FEDPROX = "FEDPROX"

    @classmethod
    def class_for_parser(cls, config_parser: Self) -> type[BasicConfigParser]:
        """
        Return the class for a given config parser.

        :param config_parser: (ConfigParser) The config parser enumeration instance.
        :return: (type[BasicConfigParser]) A subclass of BasicConfigParser corresponding to the given config parser.
        :raises ValueError: if the config_parser is not supported.
        """
        if config_parser == ConfigParser.BASIC:
            return BasicConfigParser
        if config_parser == ConfigParser.FEDPROX:
            return FedProxConfigParser

        # Every enum member is handled above, so anything reaching this point may not be
        # an enum member at all (e.g. a plain string). Do not assume it has a `.value`
        # attribute, otherwise an AttributeError would mask the documented ValueError.
        parser_name = getattr(config_parser, "value", config_parser)
        raise ValueError(f"Config parser {parser_name} not supported.")

    @classmethod
    def list(cls) -> List[str]:
        """
        List all the supported config parsers.

        :return: (List[str]) a list with the value of each member of this enumeration,
            in definition order.
        """
        return [config_parser.value for config_parser in ConfigParser]
[docs]
class IncompleteConfigError(Exception):
    """Error for server config strings that are missing required information."""