class HFPeftModelGenerator(HuggingFaceGeneratorMixin, BaseGenerator):
"""HFPeftModelGenerator Class.
NOTE: this class supports loading PeftModel's from HF Hub or from local.
TODO: support loading custom models via a `~peft.Config` and `~peft.get_peft_model`
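
    Example (an illustrative sketch; the model names below are placeholders,
    not real checkpoints):

        generator = HFPeftModelGenerator(
            model_name="my-org/my-peft-adapter",
            base_model_name="my-org/my-base-model",
        )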
"""
model_config = ConfigDict(protected_namespaces=("pydantic_model_",))
    model_name: str = Field(
        description="Name of the PEFT model, used for loading the model from the HF Hub or a local path."
    )
    base_model_name: str = Field(
        description="Name of the frozen HuggingFace base model, used for loading the model from the HF Hub or a local path."
    )
generation_config: "GenerationConfig" = Field(
description="The generation config used for generating with the PreTrainedModel."
)
    load_model_kwargs: dict = Field(
        description="Optional kwargs dict for loading the PEFT model from HF. Defaults to an empty dict.",
        default_factory=dict,
    )
    load_base_model_kwargs: dict = Field(
        description="Optional kwargs dict for loading the base model from HF. Defaults to an empty dict.",
        default_factory=dict,
    )
_prompt_template: str = PrivateAttr(default=DEFAULT_PROMPT_TEMPLATE)
_model: Optional["PeftModel"] = PrivateAttr(default=None)
    _tokenizer: HFPretrainedTokenizer | None = PrivateAttr(default=None)

def __init__(
self,
model_name: str,
base_model_name: str,
generation_config: Optional["GenerationConfig"] = None,
prompt_template: str | None = None,
load_model_kwargs: dict | None = None,
load_base_model_kwargs: dict | None = None,
load_model_at_init: bool = True,
):
        # if execution reaches here, the check for the huggingface extra has passed
        from transformers.generation.utils import GenerationConfig

        generation_config = generation_config or GenerationConfig()
super().__init__(
model_name=model_name,
base_model_name=base_model_name,
generation_config=generation_config,
prompt_template=prompt_template,
load_model_kwargs=load_model_kwargs or {},
load_base_model_kwargs=load_base_model_kwargs or {},
)
self._tokenizer = HFPretrainedTokenizer(
model_name=base_model_name, load_model_at_init=load_model_at_init
)
self._prompt_template = prompt_template or DEFAULT_PROMPT_TEMPLATE
if load_model_at_init:
self._model = self._load_model_from_hf()
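
    # Illustrative sketch (not part of the class API): with
    # `load_model_at_init=False`, construction stays cheap and the weights are
    # only fetched on first access of `.model`; the names below are placeholders.
    #
    #   generator = HFPeftModelGenerator(
    #       model_name="my-org/my-peft-adapter",
    #       base_model_name="my-org/my-base-model",
    #       load_model_at_init=False,
    #   )
    #   _ = generator.model  # triggers `_load_model_from_hf()` and caches the result
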
@model_validator(mode="before")
@classmethod
def check_dependencies(cls, data: Any) -> Any:
"""Validate that huggingface dependencies are installed."""
check_huggingface_installed(cls.__name__)
        return data

    def _load_model_from_hf(self, **kwargs: Any) -> "PeftModel":
        """Load the `PeftModel` and its frozen base model from the HF Hub or a local path."""
        from peft import PeftModel, prepare_model_for_kbit_training
        from transformers import AutoModelForCausalLM

        # merge any runtime overrides into the stored load kwargs
        self.load_model_kwargs.update(kwargs)
base_model = AutoModelForCausalLM.from_pretrained(
self.base_model_name, **self.load_base_model_kwargs
)
if "quantization_config" in self.load_base_model_kwargs:
# preprocess model for kbit fine-tuning
# https://huggingface.co/docs/peft/developer_guides/quantization
base_model = prepare_model_for_kbit_training(base_model)
return PeftModel.from_pretrained(
base_model, self.model_name, **self.load_model_kwargs
)
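
    # Illustrative sketch: passing a `quantization_config` through
    # `load_base_model_kwargs` activates the k-bit preparation branch above.
    # `BitsAndBytesConfig` is the standard transformers quantization config;
    # the model names are placeholders.
    #
    #   from transformers import BitsAndBytesConfig
    #
    #   generator = HFPeftModelGenerator(
    #       model_name="my-org/my-peft-adapter",
    #       base_model_name="my-org/my-base-model",
    #       load_base_model_kwargs={
    #           "device_map": "auto",
    #           "quantization_config": BitsAndBytesConfig(load_in_4bit=True),
    #       },
    #   )
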
    @property
    def model(self) -> "PeftModel":
        """Return the `PeftModel`, loading it lazily on first access."""
        if self._model is None:
            # load HF PeftModel
            self._model = self._load_model_from_hf()
        return self._model

    @model.setter
    def model(self, value: "PeftModel") -> None:
        self._model = value

    @property
    def tokenizer(self) -> HFPretrainedTokenizer:
        """Return the tokenizer of the frozen base model."""
        return self._tokenizer

    @tokenizer.setter
    def tokenizer(self, value: HFPretrainedTokenizer) -> None:
        self._tokenizer = value

    @property
    def prompt_template(self) -> str:
        """Return the prompt template."""
        return self._prompt_template

    @prompt_template.setter
    def prompt_template(self, value: str) -> None:
        self._prompt_template = value