Skip to content

HuggingFace

inspect_trainer_signature

inspect_trainer_signature(fn)
Source code in src/fed_rag/inspectors/huggingface/trainer.py
def inspect_trainer_signature(fn: Callable) -> TrainerSignatureSpec:
    """Inspect a HuggingFace trainer function and map its parameters to roles.

    Parameters are scanned in declaration order; ``self`` and ``cls`` are
    skipped. Each remaining parameter is classified by the type name that
    ``get_type_name`` reports for it:

    * the first parameter whose type name is ``SentenceTransformer``,
      ``PreTrainedModel``, ``PeftModel`` or ``HFModelType`` becomes the
      model ("net") parameter;
    * the first ``Dataset``-typed parameter becomes the train-data parameter
      and the second becomes the validation-data parameter;
    * every other parameter (including an extra model or a third
      ``Dataset``) is recorded as an extra train kwarg.

    Args:
        fn: The trainer callable to inspect.

    Returns:
        A ``TrainerSignatureSpec`` naming the model, train-data and
        val-data parameters, the extra kwargs, and the model's type name.

    Raises:
        InvalidReturnType: If the return annotation is missing, is ``Any``,
            is not a class (e.g. a string annotation or a ``typing``
            construct such as ``Optional[...]``), or is not ``TrainResult``
            or a subclass of it.
        MissingNetParam: If no model parameter is found.
        MissingMultipleDataParams: If no ``Dataset`` parameter is found.
        MissingDataParam: If exactly one ``Dataset`` parameter is found.
    """
    sig = inspect.signature(fn)

    # Validate the return type. ``issubclass`` raises ``TypeError`` when its
    # first argument is not a class (string annotations from postponed
    # evaluation, unions, subscripted generics), so treat that as an invalid
    # return type rather than leaking a ``TypeError`` to the caller. A missing
    # annotation is ``inspect.Signature.empty``, a class that is not a
    # TrainResult, so it is rejected by the ``issubclass`` check.
    return_type = sig.return_annotation
    try:
        is_valid_return = return_type is not Any and issubclass(
            return_type, TrainResult
        )
    except TypeError:
        is_valid_return = False
    if not is_valid_return:
        msg = "Trainer should return a fed_rag.types.TrainResult or a subclass of it."
        raise InvalidReturnType(msg)

    # Type names accepted for the model parameter.
    # TODO: should accept union types involving these types
    model_type_names = (
        "SentenceTransformer",
        "PreTrainedModel",
        "PeftModel",
        "HFModelType",
    )

    # inspect fn params
    extra_train_kwargs = []
    net_param = None
    train_data_param = None
    val_data_param = None
    net_parameter_class_name = None

    for name, t in sig.parameters.items():
        if name in ("self", "cls"):
            continue

        if type_name := get_type_name(t):
            if type_name in model_type_names and net_param is None:
                net_param = name
                net_parameter_class_name = type_name
                continue

            if type_name == "Dataset":
                # The first Dataset param is the train set, the second the
                # val set; any further Dataset param falls through below.
                if train_data_param is None:
                    train_data_param = name
                    continue
                if val_data_param is None:
                    val_data_param = name
                    continue

        # Anything not claimed above is forwarded as an extra train kwarg.
        extra_train_kwargs.append(name)

    if net_param is None:
        msg = (
            "Inspection failed to find a model param. "
            "For HuggingFace this param must have type `PreTrainedModel` or `SentenceTransformers`."
        )
        raise MissingNetParam(msg)

    # val_data_param is only ever set after train_data_param, so a missing
    # train param means no Dataset params were found at all.
    if train_data_param is None:
        msg = (
            "Inspection failed to find two data params for train and val datasets. "
            "For HuggingFace these params must be of type `datasets.Dataset`"
        )
        raise MissingMultipleDataParams(msg)

    if val_data_param is None:
        msg = (
            "Inspection found one data param but failed to find another. "
            "Two data params are required for train and val datasets. "
            "For HuggingFace these params must be of type `datasets.Dataset`"
        )
        raise MissingDataParam(msg)

    return TrainerSignatureSpec(
        net_parameter=net_param,
        train_data_param=train_data_param,
        val_data_param=val_data_param,
        extra_train_kwargs=extra_train_kwargs,
        net_parameter_class_name=net_parameter_class_name,
    )

inspect_tester_signature

inspect_tester_signature(fn)
Source code in src/fed_rag/inspectors/huggingface/tester.py
def inspect_tester_signature(fn: Callable) -> TesterSignatureSpec:
    """Inspect a HuggingFace tester function and map its parameters to roles.

    Parameters are scanned in declaration order; ``self`` and ``cls`` are
    skipped. Each remaining parameter is classified by the type name that
    ``get_type_name`` reports for it:

    * the first parameter whose type name is ``PreTrainedModel``,
      ``SentenceTransformer``, ``PeftModel`` or ``HFModelType`` becomes the
      model ("net") parameter;
    * the first ``Dataset``-typed parameter becomes the test-data parameter;
    * every other parameter is recorded as an extra test kwarg.

    Args:
        fn: The tester callable to inspect.

    Returns:
        A ``TesterSignatureSpec`` naming the model and test-data parameters,
        the extra kwargs, and the model's type name.

    Raises:
        InvalidReturnType: If the return annotation is missing, is ``Any``,
            is not a class (e.g. a string annotation or a ``typing``
            construct such as ``Optional[...]``), or is not ``TestResult``
            or a subclass of it.
        MissingNetParam: If no model parameter is found.
        MissingDataParam: If no ``Dataset`` parameter is found.
    """
    sig = inspect.signature(fn)

    # Validate the return type. ``issubclass`` raises ``TypeError`` when its
    # first argument is not a class (string annotations from postponed
    # evaluation, unions, subscripted generics), so treat that as an invalid
    # return type rather than leaking a ``TypeError`` to the caller. A missing
    # annotation is ``inspect.Signature.empty``, a class that is not a
    # TestResult, so it is rejected by the ``issubclass`` check.
    return_type = sig.return_annotation
    try:
        is_valid_return = return_type is not Any and issubclass(
            return_type, TestResult
        )
    except TypeError:
        is_valid_return = False
    if not is_valid_return:
        msg = "Tester should return a fed_rag.types.TestResult or a subclass of it."
        raise InvalidReturnType(msg)

    # Type names accepted for the model parameter.
    model_type_names = (
        "PreTrainedModel",
        "SentenceTransformer",
        "PeftModel",
        "HFModelType",
    )

    # inspect fn params
    extra_tester_kwargs = []
    net_param = None
    test_data_param = None
    net_parameter_class_name = None

    for name, t in sig.parameters.items():
        if name in ("self", "cls"):
            continue

        if type_name := get_type_name(t):
            if type_name in model_type_names and net_param is None:
                net_param = name
                net_parameter_class_name = type_name
                continue

            if type_name == "Dataset" and test_data_param is None:
                test_data_param = name
                continue

        # Anything not claimed above is forwarded as an extra test kwarg.
        extra_tester_kwargs.append(name)

    if net_param is None:
        msg = (
            "Inspection failed to find a model param. "
            "For HuggingFace this param must have type `PreTrainedModel` or `SentenceTransformers`."
        )
        raise MissingNetParam(msg)

    if test_data_param is None:
        msg = (
            "Inspection failed to find a data param for a test dataset. "
            "For HuggingFace these params must be of type `datasets.Dataset`"
        )
        raise MissingDataParam(msg)

    return TesterSignatureSpec(
        net_parameter=net_param,
        test_data_param=test_data_param,
        extra_test_kwargs=extra_tester_kwargs,
        net_parameter_class_name=net_parameter_class_name,
    )