class HFPeftModelGenerator(HuggingFaceGeneratorMixin, BaseGenerator):
"""HFPeftModelGenerator Class.
NOTE: this class supports loading PeftModel's from HF Hub or from local.
TODO: support loading custom models via a `~peft.Config` and `~peft.get_peft_model`
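
    Example (illustrative only; the model names below are placeholders, not
    real checkpoints):

        generator = HFPeftModelGenerator(
            model_name="my-org/my-lora-adapter",
            base_model_name="my-org/my-base-model",
            load_model_at_init=False,  # defer the HF download
        )
        model = generator.model  # first access lazily loads the PeftModel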
"""

    model_config = ConfigDict(protected_namespaces=("pydantic_model_",))
    model_name: str = Field(
        description="Name of the PEFT model. Used for loading the model from the HF Hub or a local path."
    )
    base_model_name: str = Field(
        description="Name of the frozen HuggingFace base model. Used for loading the model from the HF Hub or a local path."
    )
generation_config: "GenerationConfig" = Field(
description="The generation config used for generating with the PreTrainedModel."
)
    load_model_kwargs: dict = Field(
        description="Optional kwargs dict for loading the PEFT model from HF. Defaults to an empty dict.",
        default_factory=dict,
    )
    load_base_model_kwargs: dict = Field(
        description="Optional kwargs dict for loading the base model from HF. Defaults to an empty dict.",
        default_factory=dict,
    )
prompt_template: str = Field(description="Prompt template for RAG.")
_model: Optional["PeftModel"] = PrivateAttr(default=None)
    _tokenizer: HFPretrainedTokenizer | None = PrivateAttr(default=None)

def __init__(
self,
model_name: str,
base_model_name: str,
generation_config: Optional["GenerationConfig"] = None,
prompt_template: str | None = None,
load_model_kwargs: dict | None = None,
load_base_model_kwargs: dict | None = None,
load_model_at_init: bool = True,
):
        if not _has_huggingface:
            msg = (
                f"`{self.__class__.__name__}` requires the `huggingface` extra to be installed. "
                "To fix, please run `pip install fed-rag[huggingface]`."
            )
            raise MissingExtraError(msg)
        generation_config = generation_config or GenerationConfig()
        prompt_template = prompt_template or DEFAULT_PROMPT_TEMPLATE

        super().__init__(
            model_name=model_name,
            base_model_name=base_model_name,
            generation_config=generation_config,
            prompt_template=prompt_template,
            load_model_kwargs=load_model_kwargs or {},
            load_base_model_kwargs=load_base_model_kwargs or {},
        )
self._tokenizer = HFPretrainedTokenizer(
model_name=base_model_name, load_model_at_init=load_model_at_init
)
if load_model_at_init:
self._model = self._load_model_from_hf()

    def _load_model_from_hf(self, **kwargs: Any) -> "PeftModel":
        """Load the PEFT model and its frozen base model from the HF Hub or a local path."""
        load_base_kwargs = self.load_base_model_kwargs
        load_kwargs = self.load_model_kwargs
        load_kwargs.update(kwargs)
        self.load_model_kwargs = load_kwargs  # persist any runtime overrides
base_model = AutoModelForCausalLM.from_pretrained(
self.base_model_name, **load_base_kwargs
)
if "quantization_config" in load_base_kwargs:
# preprocess model for kbit fine-tuning
# https://huggingface.co/docs/peft/developer_guides/quantization
base_model = prepare_model_for_kbit_training(base_model)
return PeftModel.from_pretrained(
base_model, self.model_name, **load_kwargs
)
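
    # A minimal sketch of the k-bit path above, assuming the standard
    # `transformers.BitsAndBytesConfig` API; the model names and config values
    # are placeholders, not defaults of this library:
    #
    #   from transformers import BitsAndBytesConfig
    #
    #   generator = HFPeftModelGenerator(
    #       model_name="my-org/my-lora-adapter",
    #       base_model_name="my-org/my-base-model",
    #       load_base_model_kwargs={
    #           "quantization_config": BitsAndBytesConfig(load_in_4bit=True),
    #           "device_map": "auto",
    #       },
    #   )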

    @property
    def model(self) -> "PeftModel":
        if self._model is None:
            # lazily load the HF PeftModel on first access
            self._model = self._load_model_from_hf()
        return self._model

    @model.setter
    def model(self, value: "PeftModel") -> None:
        self._model = value

    @property
    def tokenizer(self) -> HFPretrainedTokenizer:
        return self._tokenizer

    @tokenizer.setter
    def tokenizer(self, value: HFPretrainedTokenizer) -> None:
        self._tokenizer = value
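
# A minimal usage sketch, assuming the `huggingface` extra is installed. Model
# names are placeholders, and `generate()` is provided by
# `HuggingFaceGeneratorMixin`; its exact signature may vary by version:
#
#   generator = HFPeftModelGenerator(
#       model_name="my-org/my-lora-adapter",
#       base_model_name="my-org/my-base-model",
#       generation_config=GenerationConfig(max_new_tokens=64),
#       load_model_at_init=False,  # defer loading until `.model` is accessed
#   )
#   output = generator.generate(query="...", context="...")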