# Source code for fl4health.utils.typing

import logging
from collections.abc import Callable
from enum import Enum

import torch
import torch.nn as nn
from flwr.common import EvaluateRes, FitRes
from flwr.common.typing import NDArrays
from flwr.server.client_proxy import ClientProxy

# Shared type aliases built from torch and flwr types.

# Either a single tensor or a dictionary of tensors keyed by string.
TorchInputType = torch.Tensor | dict[str, torch.Tensor]
# Same shape of union as TorchInputType: a single tensor or a string-keyed dictionary of tensors.
TorchTargetType = torch.Tensor | dict[str, torch.Tensor]
# Always a string-keyed dictionary of tensors (no bare-tensor alternative).
# NOTE(review): the name suggests model predictions - confirm against the code that uses it.
TorchPredType = dict[str, torch.Tensor]
# Always a string-keyed dictionary of tensors (no bare-tensor alternative).
# NOTE(review): the name suggests intermediate features - confirm against the code that uses it.
TorchFeatureType = dict[str, torch.Tensor]
# A callable that takes one tensor and returns one tensor.
TorchTransformFunction = Callable[[torch.Tensor], torch.Tensor]
# A callable that takes an nn.Module and a second, optional (possibly None) nn.Module, and returns a tuple of
# (NDArrays, list of strings).
# NOTE(review): presumably the selected layers' arrays and their names - confirm against the implementations.
LayerSelectionFunction = Callable[[nn.Module, nn.Module | None], tuple[NDArrays, list[str]]]

# A list in which each entry is either a (ClientProxy, FitRes) pair or a BaseException instance.
FitFailures = list[tuple[ClientProxy, FitRes] | BaseException]
# A list in which each entry is either a (ClientProxy, EvaluateRes) pair or a BaseException instance.
EvaluateFailures = list[tuple[ClientProxy, EvaluateRes] | BaseException]


class LogLevel(Enum):
    """
    Enumeration of the severity levels defined by the standard library ``logging`` module.

    Each member's value is the integer constant of the same name from ``logging``, so ``LogLevel.<NAME>.value`` can
    be used wherever a numeric logging level is expected.
    """

    NOTSET = logging.NOTSET
    DEBUG = logging.DEBUG
    INFO = logging.INFO
    WARNING = logging.WARNING
    ERROR = logging.ERROR
    CRITICAL = logging.CRITICAL