HuggingFace PeftModel Generator

HFPeftModelGenerator

Bases: HuggingFaceGeneratorMixin, BaseGenerator

HFPeftModelGenerator Class.

NOTE: this class supports loading a `PeftModel` from the HF Hub or from a local path. TODO: support loading custom models via a `~peft.Config` and `~peft.get_peft_model`.

Source code in src/fed_rag/generators/huggingface/hf_peft_model.py
class HFPeftModelGenerator(HuggingFaceGeneratorMixin, BaseGenerator):
    """HFPeftModelGenerator Class.

    NOTE: this class supports loading a `PeftModel` from the HF Hub or from a local path.
    TODO: support loading custom models via a `~peft.Config` and `~peft.get_peft_model`
    """

    model_config = ConfigDict(protected_namespaces=("pydantic_model_",))
    model_name: str = Field(
        description="Name of Peft model. Used for loading model from HF hub or local."
    )
    base_model_name: str = Field(
        description="Name of the frozen HuggingFace base model. Used for loading the model from HF hub or local."
    )
    generation_config: "GenerationConfig" = Field(
        description="The generation config used for generating with the PreTrainedModel."
    )
    load_model_kwargs: dict = Field(
        description="Optional kwargs dict for loading peft model from HF. Defaults to None.",
        default_factory=dict,
    )
    load_base_model_kwargs: dict = Field(
        description="Optional kwargs dict for loading base model from HF. Defaults to None.",
        default_factory=dict,
    )
    _prompt_template: str = PrivateAttr(default=DEFAULT_PROMPT_TEMPLATE)
    _model: Optional["PeftModel"] = PrivateAttr(default=None)
    _tokenizer: HFPretrainedTokenizer | None = PrivateAttr(default=None)

    def __init__(
        self,
        model_name: str,
        base_model_name: str,
        generation_config: Optional["GenerationConfig"] = None,
        prompt_template: str | None = None,
        load_model_kwargs: dict | None = None,
        load_base_model_kwargs: dict | None = None,
        load_model_at_init: bool = True,
    ):
        # if we reach here, the huggingface extra installation check has passed
        from transformers.generation.utils import GenerationConfig

        generation_config = generation_config or GenerationConfig()
        super().__init__(
            model_name=model_name,
            base_model_name=base_model_name,
            generation_config=generation_config,
            prompt_template=prompt_template,
            load_model_kwargs=load_model_kwargs or {},
            load_base_model_kwargs=load_base_model_kwargs or {},
        )
        self._tokenizer = HFPretrainedTokenizer(
            model_name=base_model_name, load_model_at_init=load_model_at_init
        )
        self._prompt_template = prompt_template or DEFAULT_PROMPT_TEMPLATE
        if load_model_at_init:
            self._model = self._load_model_from_hf()

    @model_validator(mode="before")
    @classmethod
    def check_dependencies(cls, data: Any) -> Any:
        """Validate that huggingface dependencies are installed."""
        check_huggingface_installed(cls.__name__)
        return data

    def _load_model_from_hf(self, **kwargs: Any) -> "PeftModel":
        from peft import PeftModel, prepare_model_for_kbit_training
        from transformers import AutoModelForCausalLM

        self.load_model_kwargs.update(kwargs)  # update load_model_kwargs
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name, **self.load_base_model_kwargs
        )

        if "quantization_config" in self.load_base_model_kwargs:
            # preprocess model for kbit fine-tuning
            # https://huggingface.co/docs/peft/developer_guides/quantization
            base_model = prepare_model_for_kbit_training(base_model)

        return PeftModel.from_pretrained(
            base_model, self.model_name, **self.load_model_kwargs
        )

    @property
    def model(self) -> "PeftModel":
        if self._model is None:
            # load HF PeftModel
            self._model = self._load_model_from_hf()
        return self._model

    @model.setter
    def model(self, value: "PeftModel") -> None:
        self._model = value

    @property
    def tokenizer(self) -> HFPretrainedTokenizer:
        return self._tokenizer

    @tokenizer.setter
    def tokenizer(self, value: HFPretrainedTokenizer) -> None:
        self._tokenizer = value

    @property
    def prompt_template(self) -> str:
        return self._prompt_template

    @prompt_template.setter
    def prompt_template(self, value: str) -> None:
        self._prompt_template = value

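A minimal usage sketch. The import path is inferred from the source layout above, and the adapter and base-model names are placeholders rather than checkpoints this package ships with:

from transformers.generation.utils import GenerationConfig

from fed_rag.generators.huggingface import HFPeftModelGenerator

# `model_name` points to the PEFT adapter, `base_model_name` to the frozen base model.
generator = HFPeftModelGenerator(
    model_name="your-org/your-peft-adapter",         # placeholder adapter repo
    base_model_name="your-org/your-base-causal-lm",  # placeholder base model repo
    generation_config=GenerationConfig(max_new_tokens=128),
    load_model_at_init=False,  # defer the download until the model is first needed
)

# The underlying PeftModel is loaded lazily on first access of `.model`.
peft_model = generator.model
tokenizer = generator.tokenizer
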
check_dependencies classmethod

check_dependencies(data)

Validate that huggingface dependencies are installed.

Source code in src/fed_rag/generators/huggingface/hf_peft_model.py
@model_validator(mode="before")
@classmethod
def check_dependencies(cls, data: Any) -> Any:
    """Validate that huggingface dependencies are installed."""
    check_huggingface_installed(cls.__name__)
    return data

HuggingFace PretrainedModel Generator

HFPretrainedModelGenerator

Bases: HuggingFaceGeneratorMixin, BaseGenerator

Source code in src/fed_rag/generators/huggingface/hf_pretrained_model.py
class HFPretrainedModelGenerator(HuggingFaceGeneratorMixin, BaseGenerator):
    model_config = ConfigDict(protected_namespaces=("pydantic_model_",))
    model_name: str = Field(
        description="Name of HuggingFace model. Used for loading the model from HF hub or local."
    )
    generation_config: "GenerationConfig" = Field(
        description="The generation config used for generating with the PreTrainedModel."
    )
    load_model_kwargs: dict = Field(
        description="Optional kwargs dict for loading models from HF. Defaults to None.",
        default_factory=dict,
    )
    _prompt_template: str = PrivateAttr(default=DEFAULT_PROMPT_TEMPLATE)
    _model: Optional["PreTrainedModel"] = PrivateAttr(default=None)
    _tokenizer: HFPretrainedTokenizer | None = PrivateAttr(default=None)

    def __init__(
        self,
        model_name: str,
        generation_config: Optional["GenerationConfig"] = None,
        prompt_template: str | None = None,
        load_model_kwargs: dict | None = None,
        load_model_at_init: bool = True,
    ):
        # if we reach here, the huggingface extra installation check has passed
        from transformers.generation.utils import GenerationConfig

        generation_config = generation_config or GenerationConfig()
        super().__init__(
            model_name=model_name,
            generation_config=generation_config,
            load_model_kwargs=load_model_kwargs or {},
        )
        self._tokenizer = HFPretrainedTokenizer(
            model_name=model_name, load_model_at_init=load_model_at_init
        )
        self._prompt_template = prompt_template or DEFAULT_PROMPT_TEMPLATE
        if load_model_at_init:
            self._model = self._load_model_from_hf()

    @model_validator(mode="before")
    @classmethod
    def check_dependencies(cls, data: Any) -> Any:
        """Validate that huggingface dependencies are installed."""
        check_huggingface_installed(cls.__name__)
        return data

    def _load_model_from_hf(self, **kwargs: Any) -> "PreTrainedModel":
        from transformers import AutoModelForCausalLM

        self.load_model_kwargs.update(kwargs)
        return AutoModelForCausalLM.from_pretrained(
            self.model_name, **self.load_model_kwargs
        )

    @property
    def model(self) -> "PreTrainedModel":
        if self._model is None:
            # load HF Pretrained Model
            self._model = self._load_model_from_hf()
        return self._model

    @model.setter
    def model(self, value: "PreTrainedModel") -> None:
        self._model = value

    @property
    def tokenizer(self) -> HFPretrainedTokenizer:
        return self._tokenizer

    @tokenizer.setter
    def tokenizer(self, value: HFPretrainedTokenizer) -> None:
        self._tokenizer = value

    @property
    def prompt_template(self) -> str:
        return self._prompt_template

    @prompt_template.setter
    def prompt_template(self, value: str) -> None:
        self._prompt_template = value

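A minimal usage sketch for the plain pretrained-model generator, under the same import-path assumption and with an illustrative model name:

from fed_rag.generators.huggingface import HFPretrainedModelGenerator

generator = HFPretrainedModelGenerator(
    model_name="gpt2",  # any causal LM on the HF Hub, or a local path
    load_model_kwargs={"device_map": "auto"},  # forwarded to AutoModelForCausalLM.from_pretrained
    load_model_at_init=False,
)

# Model and tokenizer are loaded lazily and can also be swapped via the setters.
model = generator.model
# The field names below are illustrative; match them to the library's default template format.
generator.prompt_template = "{context}\n\nQuestion: {query}\nAnswer:"
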
check_dependencies classmethod

check_dependencies(data)

Validate that huggingface dependencies are installed.

Source code in src/fed_rag/generators/huggingface/hf_pretrained_model.py
@model_validator(mode="before")
@classmethod
def check_dependencies(cls, data: Any) -> Any:
    """Validate that huggingface dependencies are installed."""
    check_huggingface_installed(cls.__name__)
    return data

HF Multimodal Model Generator

HFMultimodalModelGenerator

Bases: ImageModalityMixin, AudioModalityMixin, VideoModalityMixin, BaseGenerator

Source code in src/fed_rag/generators/huggingface/hf_multimodal_model.py
class HFMultimodalModelGenerator(
    ImageModalityMixin,
    AudioModalityMixin,
    VideoModalityMixin,
    BaseGenerator,
):
    model_config = ConfigDict(
        protected_namespaces=("pydantic_model_",), arbitrary_types_allowed=True
    )
    model_name: str = Field(description="HuggingFace model name or path.")
    modality_types: set[str] = Field(
        default_factory=lambda: {"text", "image", "audio", "video"}
    )
    generation_config: Optional[Any] = Field(default=None)
    load_model_kwargs: dict = Field(default_factory=dict)
    prompt_template_init: str | None = Field(default=None)
    load_model_at_init: bool = Field(default=True)

    _model: Optional["PreTrainedModel"] = PrivateAttr(default=None)
    _model_cls: Any = PrivateAttr(default=None)
    _processor: Any = PrivateAttr(default=None)
    _prompt_template: str = PrivateAttr(default="")

    @model_validator(mode="before")
    @classmethod
    def _check_hf_available(cls, data: Any) -> Any:
        check_huggingface_installed(cls.__name__)
        return data

    def __init__(self, **data: Any) -> None:
        super().__init__(**data)
        from transformers import AutoConfig, AutoProcessor, GenerationConfig

        self._prompt_template = self.prompt_template_init or ""
        self._processor = AutoProcessor.from_pretrained(self.model_name)
        cfg = AutoConfig.from_pretrained(self.model_name)
        if self.generation_config is None:
            self.generation_config = GenerationConfig()
        self._model_cls = self._detect_model_class(cfg)
        self._model = None
        if self.load_model_at_init:
            self._model = self._load_model_from_hf()

    def _load_model_from_hf(self) -> "PreTrainedModel":
        return self._model_cls.from_pretrained(
            self.model_name, **self.load_model_kwargs
        )

    @staticmethod
    def _detect_model_class(cfg: Any) -> Any:
        from transformers import AutoModel, AutoModelForImageTextToText

        if any(
            getattr(cfg, attr, None) is not None
            for attr in ("vision_config", "audio_config", "video_config")
        ):
            return AutoModelForImageTextToText
        if getattr(cfg, "architectures", None):
            if any("ImageTextToText" in arch for arch in cfg.architectures):
                return AutoModelForImageTextToText
        return AutoModel

    def to_query(self, q: str | Query | Prompt) -> Query:
        if isinstance(q, Query):
            return q
        if isinstance(q, Prompt):
            return Query(
                text=q.text,
                images=getattr(q, "images", None),
                audios=getattr(q, "audios", None),
                videos=getattr(q, "videos", None),
            )
        return Query(text=str(q))

    def to_context(self, c: str | Context | None) -> Context | None:
        if c is None or isinstance(c, Context):
            return c
        return Context(text=str(c))

    def _pack_messages(
        self,
        query: str | Query | list[str] | list[Query],
        context: str | Context | list[str] | list[Context] | None = None,
    ) -> list[dict[str, Any]]:
        queries = (
            [query] if not isinstance(query, list) else query  # type: ignore[arg-type]
        )
        queries = [self.to_query(q) for q in queries]

        if isinstance(context, list):
            contexts = [self.to_context(c) for c in context]
            if len(contexts) != len(queries):
                raise GeneratorError(
                    "Batch mode requires query and context to be the same length"
                )
        else:
            contexts = [self.to_context(context)] * len(queries)

        messages: list[dict[str, Any]] = []
        for q, ctx in zip(queries, contexts):
            content: list[dict[str, Any]] = []
            if ctx is not None:
                if getattr(ctx, "text", None):
                    content.append({"type": "text", "text": ctx.text})
                for im in getattr(ctx, "images", []) or []:
                    if isinstance(im, np.ndarray):
                        im = PILImage.fromarray(im)
                    content.append({"type": "image", "image": im})
                for au in getattr(ctx, "audios", []) or []:
                    content.append({"type": "audio", "audio": au})
                for vid in getattr(ctx, "videos", []) or []:
                    content.append({"type": "video", "video": vid})
            for im in getattr(q, "images", []) or []:
                if isinstance(im, np.ndarray):
                    im = PILImage.fromarray(im)
                content.append({"type": "image", "image": im})
            for au in getattr(q, "audios", []) or []:
                content.append({"type": "audio", "audio": au})
            for vid in getattr(q, "videos", []) or []:
                content.append({"type": "video", "video": vid})
            if getattr(q, "text", None):
                content.append({"type": "text", "text": q.text})

            messages.append({"role": "user", "content": content})
        return messages

    def complete(
        self,
        prompt: Prompt | list[Prompt] | str | list[str] | None = None,
        query: str | Query | list[str] | list[Query] | None = None,
        context: str | Context | list[str] | list[Context] | None = None,
        **kwargs: Any,
    ) -> str | list[str]:
        """Core generation method - contains the main generation logic."""
        max_new_tokens = kwargs.pop("max_new_tokens", 256)
        add_generation_prompt = kwargs.pop("add_generation_prompt", True)

        # Handle both prompt-only and query+context cases
        if prompt is not None:
            # Traditional complete() usage: convert prompt to query, no context
            messages = self._pack_messages(prompt, context=None)
            is_batch = isinstance(prompt, list)
        else:
            # Called from generate(): use query and context
            messages = self._pack_messages(query, context)
            is_batch = isinstance(query, list)

        inputs = self._processor.apply_chat_template(
            messages,
            add_generation_prompt=add_generation_prompt,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )

        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generation = self.model.generate(
                **inputs, max_new_tokens=max_new_tokens, **kwargs
            )
            generation = generation[:, input_len:]
        decoded: list[str] = self._processor.batch_decode(
            generation, skip_special_tokens=True
        )
        if not is_batch:
            if not decoded or not isinstance(decoded[0], str):
                raise GeneratorError(
                    "batch_decode did not return valid output"
                )
            return decoded[0]
        return decoded

    def generate(
        self,
        query: str | Query | list[str] | list[Query],
        context: str | Context | list[str] | list[Context] | None = None,
        **gen_kwargs: Any,
    ) -> str | list[str]:
        """Generate method - formats query+context and calls complete()."""
        return self.complete(query=query, context=context, **gen_kwargs)

    def compute_target_sequence_proba(
        self,
        prompt: Prompt | str,
        target: str,
        **kwargs: Any,
    ) -> torch.Tensor:
        q = self.to_query(prompt)
        base_text = getattr(q, "text", "") or ""
        full_text = base_text + target

        # Create a query with the full text for processing
        full_query = Query(
            text=full_text,
            images=getattr(q, "images", None),
            audios=getattr(q, "audios", None),
            videos=getattr(q, "videos", None),
        )

        # Reuse _pack_messages logic
        messages = self._pack_messages(full_query, context=None)
        inputs = self._processor.apply_chat_template(
            messages,
            add_generation_prompt=False,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        input_ids = inputs["input_ids"]

        # Create base prompt messages for length calculation
        base_query = Query(
            text=base_text,
            images=getattr(q, "images", None),
            audios=getattr(q, "audios", None),
            videos=getattr(q, "videos", None),
        )
        base_messages = self._pack_messages(base_query, context=None)
        prompt_inputs = self._processor.apply_chat_template(
            base_messages,
            add_generation_prompt=False,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        prompt_len = prompt_inputs["input_ids"].shape[-1]
        with torch.no_grad():
            outputs = self.model(**inputs)
        if not hasattr(outputs, "logits") or outputs.logits is None:
            raise GeneratorError(
                "Underlying model does not expose logits; cannot compute probabilities."
            )
        logits = outputs.logits
        target_ids = input_ids[0][prompt_len:]
        target_logits = logits[0, prompt_len - 1 : -1, :]
        log_probs = [
            F.log_softmax(target_logits[i], dim=-1)[tid].item()
            for i, tid in enumerate(target_ids)
        ]
        return torch.exp(torch.tensor(sum(log_probs)))

    @property
    def model(self) -> "PreTrainedModel":
        if self._model is None:
            self._model = self._load_model_from_hf()
        return self._model

    @property
    def tokenizer(self) -> Any:
        if hasattr(self._processor, "tokenizer"):
            return self._processor.tokenizer
        if callable(getattr(self._processor, "encode", None)):
            return self._processor
        raise AttributeError(
            f"{self.__class__.__name__}: This processor does not have a `.tokenizer` attribute. "
            "For some multimodal models, please use `.processor` directly."
        )

    @property
    def processor(self) -> Any:
        return self._processor

    @property
    def prompt_template(self) -> str:
        return self._prompt_template

    @prompt_template.setter
    def prompt_template(self, value: str) -> None:
        self._prompt_template = value

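A short illustrative setup. The model name is a placeholder for any image-text-to-text checkpoint, and the import path for `Query` is an assumption based on how the type is used above:

from PIL import Image

from fed_rag.data_structures import Query  # assumed location of the Query type
from fed_rag.generators.huggingface import HFMultimodalModelGenerator

generator = HFMultimodalModelGenerator(
    model_name="your-org/any-image-text-to-text-model",  # placeholder checkpoint
    load_model_at_init=False,  # defer loading until the first generate() call
)

# Queries may carry image/audio/video payloads alongside text.
query = Query(
    text="What is shown in this picture?",
    images=[Image.open("photo.jpg")],
)
answer = generator.generate(query=query)
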
complete

complete(prompt=None, query=None, context=None, **kwargs)

Core generation method - contains the main generation logic.

Source code in src/fed_rag/generators/huggingface/hf_multimodal_model.py
def complete(
    self,
    prompt: Prompt | list[Prompt] | str | list[str] | None = None,
    query: str | Query | list[str] | list[Query] | None = None,
    context: str | Context | list[str] | list[Context] | None = None,
    **kwargs: Any,
) -> str | list[str]:
    """Core generation method - contains the main generation logic."""
    max_new_tokens = kwargs.pop("max_new_tokens", 256)
    add_generation_prompt = kwargs.pop("add_generation_prompt", True)

    # Handle both prompt-only and query+context cases
    if prompt is not None:
        # Traditional complete() usage: convert prompt to query, no context
        messages = self._pack_messages(prompt, context=None)
        is_batch = isinstance(prompt, list)
    else:
        # Called from generate(): use query and context
        messages = self._pack_messages(query, context)
        is_batch = isinstance(query, list)

    inputs = self._processor.apply_chat_template(
        messages,
        add_generation_prompt=add_generation_prompt,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )

    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = self.model.generate(
            **inputs, max_new_tokens=max_new_tokens, **kwargs
        )
        generation = generation[:, input_len:]
    decoded: list[str] = self._processor.batch_decode(
        generation, skip_special_tokens=True
    )
    if not is_batch:
        if not decoded or not isinstance(decoded[0], str):
            raise GeneratorError(
                "batch_decode did not return valid output"
            )
        return decoded[0]
    return decoded

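Continuing from the instance above, an illustrative call that uses `complete` directly with prompt text only (all strings are placeholders):

# Single prompt: returns one decoded string.
answer = generator.complete(
    prompt="Summarize retrieval-augmented generation in one sentence.",
    max_new_tokens=64,
)

# Batched prompts: returns a list of strings, one per prompt.
answers = generator.complete(
    prompt=["What is federated learning?", "What is RAG?"],
)
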
generate

generate(query, context=None, **gen_kwargs)

Generate method - formats query+context and calls complete().

Source code in src/fed_rag/generators/huggingface/hf_multimodal_model.py
def generate(
    self,
    query: str | Query | list[str] | list[Query],
    context: str | Context | list[str] | list[Context] | None = None,
    **gen_kwargs: Any,
) -> str | list[str]:
    """Generate method - formats query+context and calls complete()."""
    return self.complete(query=query, context=context, **gen_kwargs)
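
Continuing the same illustrative setup, `generate` pairs each query with its retrieved context; in batch mode the query and context lists must have the same length (placeholder strings throughout):

# Single query with a single retrieved context.
answer = generator.generate(
    query="What does the retrieved diagram show?",
    context="Retrieved passage describing the diagram.",
)

# Batch mode: query and context lists are zipped element-wise.
answers = generator.generate(
    query=["First question?", "Second question?"],
    context=["Context for the first.", "Context for the second."],
)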