def inspect_trainer_signature(fn: Callable) -> TrainerSignatureSpec:
    sig = inspect.signature(fn)
    # validate return type
    return_type = sig.return_annotation
    if (return_type is Any) or not issubclass(return_type, TrainResult):
        msg = "Trainer should return a fed_rag.data_structures.TrainResult or a subclsas of it."
        raise InvalidReturnType(msg)
    # inspect fn params
    extra_train_kwargs = []
    net_param = None
    train_data_param = None
    val_data_param = None
    net_parameter_class_name = None
    for name, t in sig.parameters.items():
        if name in ("self", "cls"):
            continue
        if type_name := getattr(t.annotation, "__name__", None):
            if type_name == "Module" and net_param is None:
                net_param = name
                net_parameter_class_name = type_name
                continue
            if type_name == "DataLoader" and train_data_param is None:
                train_data_param = name
                continue
            if type_name == "DataLoader" and val_data_param is None:
                val_data_param = name
                continue
        extra_train_kwargs.append(name)
    if net_param is None:
        msg = (
            "Inspection failed to find a model param. "
            "For PyTorch this param must have type `nn.Module`."
        )
        raise MissingNetParam(msg)
    if train_data_param is None:
        msg = (
            "Inspection failed to find two data params for train and val datasets."
            "For PyTorch these params must be of type `torch.utils.data.DataLoader`"
        )
        raise MissingMultipleDataParams(msg)
    if val_data_param is None:
        msg = (
            "Inspection found one data param but failed to find another. "
            "Two data params are required for train and val datasets."
            "For PyTorch these params must be of type `torch.utils.data.DataLoader`"
        )
        raise MissingDataParam(msg)
    spec = TrainerSignatureSpec(
        net_parameter=net_param,
        train_data_param=train_data_param,
        val_data_param=val_data_param,
        extra_train_kwargs=extra_train_kwargs,
        net_parameter_class_name=net_parameter_class_name,
    )
    return spec