class HFSentenceTransformerRetriever(BaseRetriever):
model_config = ConfigDict(protected_namespaces=("pydantic_model_",))
model_name: str | None = Field(
description="Name of HuggingFace SentenceTransformer model.",
default=None,
)
query_model_name: str | None = Field(
description="Name of HuggingFace SentenceTransformer model used for encoding queries.",
default=None,
)
context_model_name: str | None = Field(
description="Name of HuggingFace SentenceTransformer model used for encoding context.",
default=None,
)
load_model_kwargs: LoadKwargs = Field(
description="Optional kwargs dict for loading models from HF. Defaults to None.",
default_factory=LoadKwargs,
)
_encoder: Optional["SentenceTransformer"] = PrivateAttr(default=None)
_query_encoder: Optional["SentenceTransformer"] = PrivateAttr(default=None)
_context_encoder: Optional["SentenceTransformer"] = PrivateAttr(
default=None
)
def __init__(
self,
model_name: str | None = None,
query_model_name: str | None = None,
context_model_name: str | None = None,
load_model_kwargs: LoadKwargs | dict | None = None,
load_model_at_init: bool = True,
):
if not _has_huggingface:
msg = (
f"`{self.__class__.__name__}` requires `huggingface` extra to be installed. "
"To fix please run `pip install fed-rag[huggingface]`."
)
raise MissingExtraError(msg)
if isinstance(load_model_kwargs, dict):
# use same dict for all
load_model_kwargs = LoadKwargs(
encoder=load_model_kwargs,
query_encoder=load_model_kwargs,
context_encoder=load_model_kwargs,
)
load_model_kwargs = (
load_model_kwargs if load_model_kwargs else LoadKwargs()
)
super().__init__(
model_name=model_name,
query_model_name=query_model_name,
context_model_name=context_model_name,
load_model_kwargs=load_model_kwargs,
)
if load_model_at_init:
if model_name:
self._encoder = self._load_model_from_hf(load_type="encoder")
else:
self._query_encoder = self._load_model_from_hf(
load_type="query_encoder"
)
self._context_encoder = self._load_model_from_hf(
load_type="context_encoder"
)
def _load_model_from_hf(
self,
load_type: Literal["encoder", "query_encoder", "context_encoder"],
**kwargs: Any,
) -> "SentenceTransformer":
if load_type == "encoder":
load_kwargs = self.load_model_kwargs.encoder
load_kwargs.update(kwargs)
return SentenceTransformer(self.model_name, **load_kwargs)
elif load_type == "context_encoder":
load_kwargs = self.load_model_kwargs.context_encoder
load_kwargs.update(kwargs)
return SentenceTransformer(self.context_model_name, **load_kwargs)
elif load_type == "query_encoder":
load_kwargs = self.load_model_kwargs.query_encoder
load_kwargs.update(kwargs)
return SentenceTransformer(self.query_model_name, **load_kwargs)
else:
raise InvalidLoadType("Invalid `load_type` supplied.")
def encode_context(
self, context: str | list[str], **kwargs: Any
) -> torch.Tensor:
# validation guarantees one of these is not None
encoder = self.encoder if self.encoder else self.context_encoder
encoder = cast(SentenceTransformer, encoder)
return encoder.encode(context)
def encode_query(
self, query: str | list[str], **kwargs: Any
) -> torch.Tensor:
# validation guarantees one of these is not None
encoder = self.encoder if self.encoder else self.query_encoder
encoder = cast(SentenceTransformer, encoder)
return encoder.encode(query)
@property
def encoder(self) -> Optional["SentenceTransformer"]:
if self.model_name and self._encoder is None:
self._encoder = self._load_model_from_hf(load_type="encoder")
return self._encoder
@property
def query_encoder(self) -> Optional["SentenceTransformer"]:
if self.query_model_name and self._query_encoder is None:
self._query_encoder = self._load_model_from_hf(
load_type="query_encoder"
)
return self._query_encoder
@property
def context_encoder(self) -> Optional["SentenceTransformer"]:
if self.context_model_name and self._context_encoder is None:
self._context_encoder = self._load_model_from_hf(
load_type="context_encoder"
)
return self._context_encoder