Skip to content

PyTorch

inspect_trainer_signature

inspect_trainer_signature(fn)
Source code in src/fed_rag/inspectors/pytorch/trainer.py
def inspect_trainer_signature(fn: Callable) -> TrainerSignatureSpec:
    """Inspect a PyTorch trainer function and map its parameters to roles.

    Parameters are classified by the ``__name__`` of their annotation, so
    torch does not need to be imported here:

    * the first parameter annotated ``Module`` becomes the model param;
    * the first parameter annotated ``DataLoader`` becomes the train data
      param and the second becomes the val data param;
    * every other parameter (except ``self``/``cls``) is recorded as an
      extra train kwarg, including any further ``Module``/``DataLoader``
      parameters.

    Args:
        fn: The trainer callable to inspect.

    Returns:
        A ``TrainerSignatureSpec`` describing the discovered parameters.

    Raises:
        InvalidReturnType: If the return annotation is missing, ``Any``, not
            a class, or not ``TrainResult`` (or a subclass of it).
        MissingNetParam: If no ``Module``-annotated parameter is found.
        MissingMultipleDataParams: If no ``DataLoader``-annotated parameter
            is found.
        MissingDataParam: If exactly one ``DataLoader``-annotated parameter
            is found.
    """
    sig = inspect.signature(fn)

    # validate return type
    return_type = sig.return_annotation
    # `issubclass` raises TypeError when its first argument is not a class
    # (e.g. a string annotation under `from __future__ import annotations`,
    # or a typing construct such as `Optional[...]`). Treat those as an
    # invalid return type instead of leaking a TypeError. A missing
    # annotation is `inspect.Signature.empty`, which is a class but not a
    # TrainResult subclass, so it is rejected too.
    try:
        valid_return = return_type is not Any and issubclass(
            return_type, TrainResult
        )
    except TypeError:
        valid_return = False
    if not valid_return:
        msg = "Trainer should return a fed_rag.types.TrainResult or a subclass of it."
        raise InvalidReturnType(msg)

    # inspect fn params
    extra_train_kwargs = []
    net_param = None
    train_data_param = None
    val_data_param = None
    net_parameter_class_name = None

    for name, t in sig.parameters.items():
        if name in ("self", "cls"):
            continue

        # Match on the annotation's bare class name rather than the class
        # itself; unannotated params have no `__name__` match and fall
        # through to the extra kwargs.
        if type_name := getattr(t.annotation, "__name__", None):
            if type_name == "Module" and net_param is None:
                net_param = name
                net_parameter_class_name = type_name
                continue

            # First DataLoader -> train, second -> val.
            if type_name == "DataLoader" and train_data_param is None:
                train_data_param = name
                continue

            if type_name == "DataLoader" and val_data_param is None:
                val_data_param = name
                continue

        extra_train_kwargs.append(name)

    if net_param is None:
        msg = (
            "Inspection failed to find a model param. "
            "For PyTorch this param must have type `nn.Module`."
        )
        raise MissingNetParam(msg)

    if train_data_param is None:
        msg = (
            "Inspection failed to find two data params for train and val datasets. "
            "For PyTorch these params must be of type `torch.utils.data.DataLoader`"
        )
        raise MissingMultipleDataParams(msg)

    if val_data_param is None:
        msg = (
            "Inspection found one data param but failed to find another. "
            "Two data params are required for train and val datasets. "
            "For PyTorch these params must be of type `torch.utils.data.DataLoader`"
        )
        raise MissingDataParam(msg)

    spec = TrainerSignatureSpec(
        net_parameter=net_param,
        train_data_param=train_data_param,
        val_data_param=val_data_param,
        extra_train_kwargs=extra_train_kwargs,
        net_parameter_class_name=net_parameter_class_name,
    )
    return spec

inspect_tester_signature

inspect_tester_signature(fn)
Source code in src/fed_rag/inspectors/pytorch/tester.py
def inspect_tester_signature(fn: Callable) -> TesterSignatureSpec:
    """Inspect a PyTorch tester function and map its parameters to roles.

    Parameters are classified by the ``__name__`` of their annotation, so
    torch does not need to be imported here:

    * the first parameter annotated ``Module`` becomes the model param;
    * the first parameter annotated ``DataLoader`` becomes the test data
      param;
    * every other parameter (except ``self``/``cls``) is recorded as an
      extra test kwarg, including any further ``Module``/``DataLoader``
      parameters.

    Args:
        fn: The tester callable to inspect.

    Returns:
        A ``TesterSignatureSpec`` describing the discovered parameters.

    Raises:
        InvalidReturnType: If the return annotation is missing, ``Any``, not
            a class, or not ``TestResult`` (or a subclass of it).
        MissingNetParam: If no ``Module``-annotated parameter is found.
        MissingDataParam: If no ``DataLoader``-annotated parameter is found.
    """
    sig = inspect.signature(fn)

    # validate return type
    return_type = sig.return_annotation
    # `issubclass` raises TypeError when its first argument is not a class
    # (e.g. a string annotation under `from __future__ import annotations`,
    # or a typing construct such as `Optional[...]`). Treat those as an
    # invalid return type instead of leaking a TypeError. A missing
    # annotation is `inspect.Signature.empty`, which is a class but not a
    # TestResult subclass, so it is rejected too.
    try:
        valid_return = return_type is not Any and issubclass(
            return_type, TestResult
        )
    except TypeError:
        valid_return = False
    if not valid_return:
        msg = "Tester should return a fed_rag.types.TestResult or a subclass of it."
        raise InvalidReturnType(msg)

    # inspect fn params
    extra_tester_kwargs = []
    net_param = None
    test_data_param = None
    net_parameter_class_name = None

    for name, t in sig.parameters.items():
        if name in ("self", "cls"):
            continue

        # Match on the annotation's bare class name rather than the class
        # itself; unannotated params have no `__name__` match and fall
        # through to the extra kwargs.
        if type_name := getattr(t.annotation, "__name__", None):
            if type_name == "Module" and net_param is None:
                net_param = name
                net_parameter_class_name = type_name
                continue

            if type_name == "DataLoader" and test_data_param is None:
                test_data_param = name
                continue

        extra_tester_kwargs.append(name)

    if net_param is None:
        msg = (
            "Inspection failed to find a model param. "
            "For PyTorch this param must have type `nn.Module`."
        )
        raise MissingNetParam(msg)

    if test_data_param is None:
        msg = (
            "Inspection failed to find a data param for a test dataset. "
            "For PyTorch this param must be of type `torch.utils.data.DataLoader`"
        )
        raise MissingDataParam(msg)

    spec = TesterSignatureSpec(
        net_parameter=net_param,
        test_data_param=test_data_param,
        extra_test_kwargs=extra_tester_kwargs,
        net_parameter_class_name=net_parameter_class_name,
    )
    return spec