Skip to content

Huggingface

HuggingFace PeftModel Generator

HFPeftModelGenerator

Bases: HuggingFaceGeneratorMixin, BaseGenerator

HFPeftModelGenerator Class.

NOTE: this class supports loading PeftModels from the HF Hub or from a local path. TODO: support loading custom models via a ~peft.Config and ~peft.get_peft_model

Source code in src/fed_rag/generators/huggingface/hf_peft_model.py
class HFPeftModelGenerator(HuggingFaceGeneratorMixin, BaseGenerator):
    """Generator backed by a HuggingFace ``PeftModel``.

    The frozen base model named by ``base_model_name`` is loaded with
    ``AutoModelForCausalLM.from_pretrained`` and the adapter named by
    ``model_name`` is attached via ``PeftModel.from_pretrained``.

    NOTE: only loading a PeftModel from the HF Hub or from local is supported.
    TODO: support loading custom models via a `~peft.Config` and `~peft.get_peft_model`
    """

    model_config = ConfigDict(protected_namespaces=("pydantic_model_",))

    # --- public (validated) fields -------------------------------------
    model_name: str = Field(
        description="Name of Peft model. Used for loading model from HF hub or local."
    )
    base_model_name: str = Field(
        description="Name of the frozen HuggingFace base model. Used for loading the model from HF hub or local."
    )
    generation_config: "GenerationConfig" = Field(
        description="The generation config used for generating with the PreTrainedModel."
    )
    load_model_kwargs: dict = Field(
        default_factory=dict,
        description="Optional kwargs dict for loading peft model from HF. Defaults to None.",
    )
    load_base_model_kwargs: dict = Field(
        default_factory=dict,
        description="Optional kwargs dict for loading base model from HF. Defaults to None.",
    )

    # --- private state, exposed through the properties below -----------
    _prompt_template: str = PrivateAttr(default=DEFAULT_PROMPT_TEMPLATE)
    _model: Optional["PeftModel"] = PrivateAttr(default=None)
    _tokenizer: HFPretrainedTokenizer | None = PrivateAttr(default=None)

    def __init__(
        self,
        model_name: str,
        base_model_name: str,
        generation_config: Optional["GenerationConfig"] = None,
        prompt_template: str | None = None,
        load_model_kwargs: dict | None = None,
        load_base_model_kwargs: dict | None = None,
        load_model_at_init: bool = True,
    ):
        """Build the generator, optionally loading the model right away.

        Args:
            model_name: Name of the Peft model to load.
            base_model_name: Name of the base model; also used as the
                tokenizer's ``model_name``.
            generation_config: Generation settings; a default
                ``GenerationConfig()`` is created when a falsy value is given.
            prompt_template: Prompt template; ``DEFAULT_PROMPT_TEMPLATE`` is
                used when a falsy value is given.
            load_model_kwargs: Kwargs forwarded to ``PeftModel.from_pretrained``.
            load_base_model_kwargs: Kwargs forwarded to
                ``AutoModelForCausalLM.from_pretrained`` for the base model.
            load_model_at_init: When true the Peft model is loaded here;
                otherwise loading is deferred to the first ``model`` access.
                The flag is also passed through to ``HFPretrainedTokenizer``.
        """
        # Function-scope import keeps transformers out of module import time.
        from transformers.generation.utils import GenerationConfig

        if not generation_config:
            generation_config = GenerationConfig()

        field_values = {
            "model_name": model_name,
            "base_model_name": base_model_name,
            "generation_config": generation_config,
            "prompt_template": prompt_template,
            "load_model_kwargs": load_model_kwargs if load_model_kwargs else {},
            "load_base_model_kwargs": (
                load_base_model_kwargs if load_base_model_kwargs else {}
            ),
        }
        super().__init__(**field_values)

        # The tokenizer is built from the base model's name, not the adapter's.
        self._tokenizer = HFPretrainedTokenizer(
            model_name=base_model_name,
            load_model_at_init=load_model_at_init,
        )
        self._prompt_template = (
            prompt_template if prompt_template else DEFAULT_PROMPT_TEMPLATE
        )

        if not load_model_at_init:
            return
        self._model = self._load_model_from_hf()

    @model_validator(mode="before")
    @classmethod
    def check_dependencies(cls, data: Any) -> Any:
        """Validate that huggingface dependencies are installed.

        Runs as a "before" validator and hands ``data`` back untouched.
        """
        owner_name = cls.__name__
        check_huggingface_installed(owner_name)
        return data

    def _load_model_from_hf(self, **kwargs: Any) -> "PeftModel":
        """Load the base model and wrap it with the Peft adapter.

        Args:
            **kwargs: Extra kwargs merged into ``self.load_model_kwargs``
                (the merge persists on the instance) before the adapter is
                loaded.

        Returns:
            The result of ``PeftModel.from_pretrained``.
        """
        from peft import PeftModel, prepare_model_for_kbit_training
        from transformers import AutoModelForCausalLM

        # Merge call-time overrides into the stored adapter kwargs first.
        self.load_model_kwargs.update(kwargs)

        base_kwargs = self.load_base_model_kwargs
        frozen_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name, **base_kwargs
        )

        uses_quantization = "quantization_config" in base_kwargs
        if uses_quantization:
            # preprocess model for kbit fine-tuning
            # https://huggingface.co/docs/peft/developer_guides/quantization
            frozen_model = prepare_model_for_kbit_training(frozen_model)

        adapter_kwargs = self.load_model_kwargs
        return PeftModel.from_pretrained(
            frozen_model, self.model_name, **adapter_kwargs
        )

    @property
    def model(self) -> "PeftModel":
        """The Peft model, loaded on first access if not already present."""
        current = self._model
        if current is None:
            # load HF PeftModel
            current = self._load_model_from_hf()
            self._model = current
        return current

    @model.setter
    def model(self, value: "PeftModel") -> None:
        """Replace the held Peft model."""
        self._model = value

    @property
    def tokenizer(self) -> HFPretrainedTokenizer:
        """The tokenizer wrapper held by this generator."""
        return self._tokenizer

    @tokenizer.setter
    def tokenizer(self, value: HFPretrainedTokenizer) -> None:
        """Replace the held tokenizer wrapper."""
        self._tokenizer = value

    @property
    def prompt_template(self) -> str:
        """The prompt template string currently in use."""
        return self._prompt_template

    @prompt_template.setter
    def prompt_template(self, value: str) -> None:
        """Replace the prompt template string."""
        self._prompt_template = value

check_dependencies classmethod

check_dependencies(data)

Validate that huggingface dependencies are installed.

Source code in src/fed_rag/generators/huggingface/hf_peft_model.py
@model_validator(mode="before")
@classmethod
def check_dependencies(cls, data: Any) -> Any:
    """Validate that huggingface dependencies are installed.

    Runs as a "before" validator and hands ``data`` back untouched.
    """
    owner_name = cls.__name__
    check_huggingface_installed(owner_name)
    return data

HuggingFace PretrainedModel Generator

HFPretrainedModelGenerator

Bases: HuggingFaceGeneratorMixin, BaseGenerator

Source code in src/fed_rag/generators/huggingface/hf_pretrained_model.py
class HFPretrainedModelGenerator(HuggingFaceGeneratorMixin, BaseGenerator):
    """Generator backed by a HuggingFace causal language model.

    The model named by ``model_name`` is loaded with
    ``AutoModelForCausalLM.from_pretrained``, either at construction time or
    on first access to ``model``.
    """

    model_config = ConfigDict(protected_namespaces=("pydantic_model_",))

    # --- public (validated) fields -------------------------------------
    model_name: str = Field(
        description="Name of HuggingFace model. Used for loading the model from HF hub or local."
    )
    generation_config: "GenerationConfig" = Field(
        description="The generation config used for generating with the PreTrainedModel."
    )
    load_model_kwargs: dict = Field(
        default_factory=dict,
        description="Optional kwargs dict for loading models from HF. Defaults to None.",
    )

    # --- private state, exposed through the properties below -----------
    _prompt_template: str = PrivateAttr(default=DEFAULT_PROMPT_TEMPLATE)
    _model: Optional["PreTrainedModel"] = PrivateAttr(default=None)
    _tokenizer: HFPretrainedTokenizer | None = PrivateAttr(default=None)

    def __init__(
        self,
        model_name: str,
        generation_config: Optional["GenerationConfig"] = None,
        prompt_template: str | None = None,
        load_model_kwargs: dict | None = None,
        load_model_at_init: bool = True,
    ):
        """Build the generator, optionally loading the model right away.

        Args:
            model_name: Name of the model to load; also used as the
                tokenizer's ``model_name``.
            generation_config: Generation settings; a default
                ``GenerationConfig()`` is created when a falsy value is given.
            prompt_template: Prompt template; ``DEFAULT_PROMPT_TEMPLATE`` is
                used when a falsy value is given.
            load_model_kwargs: Kwargs forwarded to
                ``AutoModelForCausalLM.from_pretrained``.
            load_model_at_init: When true the model is loaded here; otherwise
                loading is deferred to the first ``model`` access. The flag is
                also passed through to ``HFPretrainedTokenizer``.
        """
        # Function-scope import keeps transformers out of module import time.
        from transformers.generation.utils import GenerationConfig

        if not generation_config:
            generation_config = GenerationConfig()

        field_values = {
            "model_name": model_name,
            "generation_config": generation_config,
            "load_model_kwargs": load_model_kwargs if load_model_kwargs else {},
        }
        super().__init__(**field_values)

        self._tokenizer = HFPretrainedTokenizer(
            model_name=model_name,
            load_model_at_init=load_model_at_init,
        )
        self._prompt_template = (
            prompt_template if prompt_template else DEFAULT_PROMPT_TEMPLATE
        )

        if not load_model_at_init:
            return
        self._model = self._load_model_from_hf()

    @model_validator(mode="before")
    @classmethod
    def check_dependencies(cls, data: Any) -> Any:
        """Validate that huggingface dependencies are installed.

        Runs as a "before" validator and hands ``data`` back untouched.
        """
        owner_name = cls.__name__
        check_huggingface_installed(owner_name)
        return data

    def _load_model_from_hf(self, **kwargs: Any) -> "PreTrainedModel":
        """Load the causal LM named by ``model_name``.

        Args:
            **kwargs: Extra kwargs merged into ``self.load_model_kwargs``
                (the merge persists on the instance) before loading.

        Returns:
            The result of ``AutoModelForCausalLM.from_pretrained``.
        """
        from transformers import AutoModelForCausalLM

        # Merge call-time overrides into the stored kwargs first.
        self.load_model_kwargs.update(kwargs)

        loader_kwargs = self.load_model_kwargs
        return AutoModelForCausalLM.from_pretrained(
            self.model_name, **loader_kwargs
        )

    @property
    def model(self) -> "PreTrainedModel":
        """The pretrained model, loaded on first access if not already present."""
        current = self._model
        if current is None:
            # load HF Pretrained Model
            current = self._load_model_from_hf()
            self._model = current
        return current

    @model.setter
    def model(self, value: "PreTrainedModel") -> None:
        """Replace the held pretrained model."""
        self._model = value

    @property
    def tokenizer(self) -> HFPretrainedTokenizer:
        """The tokenizer wrapper held by this generator."""
        return self._tokenizer

    @tokenizer.setter
    def tokenizer(self, value: HFPretrainedTokenizer) -> None:
        """Replace the held tokenizer wrapper."""
        self._tokenizer = value

    @property
    def prompt_template(self) -> str:
        """The prompt template string currently in use."""
        return self._prompt_template

    @prompt_template.setter
    def prompt_template(self, value: str) -> None:
        """Replace the prompt template string."""
        self._prompt_template = value

check_dependencies classmethod

check_dependencies(data)

Validate that huggingface dependencies are installed.

Source code in src/fed_rag/generators/huggingface/hf_pretrained_model.py
@model_validator(mode="before")
@classmethod
def check_dependencies(cls, data: Any) -> Any:
    """Validate that huggingface dependencies are installed.

    Runs as a "before" validator and hands ``data`` back untouched.
    """
    owner_name = cls.__name__
    check_huggingface_installed(owner_name)
    return data