class HFSentenceTransformerRetriever(BaseRetriever):
    """Retriever backed by HuggingFace SentenceTransformer models.

    Supports two configurations:
      - a single shared encoder (``model_name``) used for both queries and
        context, or
      - a dual-encoder setup (``query_model_name`` + ``context_model_name``).

    Models are loaded eagerly in ``__init__`` when ``load_model_at_init=True``,
    otherwise lazily on first access via the ``encoder`` / ``query_encoder`` /
    ``context_encoder`` properties.

    Raises:
        MissingExtraError: If the ``huggingface`` extra is not installed.
    """

    # "model_" is a pydantic-protected prefix by default; narrow it so the
    # `model_name` fields below are allowed.
    model_config = ConfigDict(protected_namespaces=("pydantic_model_",))
    model_name: str | None = Field(
        description="Name of HuggingFace SentenceTransformer model.",
        default=None,
    )
    query_model_name: str | None = Field(
        description="Name of HuggingFace SentenceTransformer model used for encoding queries.",
        default=None,
    )
    context_model_name: str | None = Field(
        description="Name of HuggingFace SentenceTransformer model used for encoding context.",
        default=None,
    )
    load_model_kwargs: LoadKwargs = Field(
        description="Optional kwargs dict for loading models from HF. Defaults to None.",
        default_factory=LoadKwargs,
    )
    # Lazily-populated model caches; see the corresponding properties.
    _encoder: Optional["SentenceTransformer"] = PrivateAttr(default=None)
    _query_encoder: Optional["SentenceTransformer"] = PrivateAttr(default=None)
    _context_encoder: Optional["SentenceTransformer"] = PrivateAttr(default=None)

    def __init__(
        self,
        model_name: str | None = None,
        query_model_name: str | None = None,
        context_model_name: str | None = None,
        load_model_kwargs: LoadKwargs | dict | None = None,
        load_model_at_init: bool = True,
    ):
        """Initialize the retriever.

        Args:
            model_name: Shared encoder model name. If set, it is used for both
                queries and context; otherwise the dual-encoder names are used.
            query_model_name: Query-encoder model name (dual-encoder setup).
            context_model_name: Context-encoder model name (dual-encoder setup).
            load_model_kwargs: Per-encoder load kwargs. A plain dict is applied
                to all three encoders; ``None`` becomes an empty ``LoadKwargs``.
            load_model_at_init: If True, load the model(s) eagerly here;
                otherwise defer to first property access.

        Raises:
            MissingExtraError: If the ``huggingface`` extra is not installed.
        """
        if not _has_huggingface:
            msg = (
                f"`{self.__class__.__name__}` requires `huggingface` extra to be installed. "
                "To fix please run `pip install fed-rag[huggingface]`."
            )
            raise MissingExtraError(msg)

        if isinstance(load_model_kwargs, dict):
            # use same dict for all
            load_model_kwargs = LoadKwargs(
                encoder=load_model_kwargs,
                query_encoder=load_model_kwargs,
                context_encoder=load_model_kwargs,
            )
        load_model_kwargs = (
            load_model_kwargs if load_model_kwargs else LoadKwargs()
        )

        super().__init__(
            model_name=model_name,
            query_model_name=query_model_name,
            context_model_name=context_model_name,
            load_model_kwargs=load_model_kwargs,
        )

        if load_model_at_init:
            # Shared-encoder mode loads one model; dual-encoder mode loads two.
            if model_name:
                self._encoder = self._load_model_from_hf(load_type="encoder")
            else:
                self._query_encoder = self._load_model_from_hf(
                    load_type="query_encoder"
                )
                self._context_encoder = self._load_model_from_hf(
                    load_type="context_encoder"
                )

    def _load_model_from_hf(
        self,
        load_type: Literal["encoder", "query_encoder", "context_encoder"],
        **kwargs: Any,
    ) -> "SentenceTransformer":
        """Load a SentenceTransformer from HF for the given encoder slot.

        Call-site ``kwargs`` override the configured ``load_model_kwargs``.

        Raises:
            InvalidLoadType: If ``load_type`` is not one of the three slots.
        """
        if load_type == "encoder":
            load_kwargs = self.load_model_kwargs.encoder
            load_kwargs.update(kwargs)
            return SentenceTransformer(self.model_name, **load_kwargs)
        elif load_type == "context_encoder":
            load_kwargs = self.load_model_kwargs.context_encoder
            load_kwargs.update(kwargs)
            return SentenceTransformer(self.context_model_name, **load_kwargs)
        elif load_type == "query_encoder":
            load_kwargs = self.load_model_kwargs.query_encoder
            load_kwargs.update(kwargs)
            return SentenceTransformer(self.query_model_name, **load_kwargs)
        else:
            raise InvalidLoadType("Invalid `load_type` supplied.")

    def encode_context(
        self, context: str | list[str], **kwargs: Any
    ) -> torch.Tensor:
        """Encode context text(s) with the shared or context encoder.

        ``kwargs`` are forwarded to ``SentenceTransformer.encode`` (e.g.
        ``batch_size``, ``convert_to_tensor``, ``normalize_embeddings``).
        """
        # validation guarantees one of these is not None
        encoder = self.encoder if self.encoder else self.context_encoder
        encoder = cast(SentenceTransformer, encoder)
        # BUGFIX: previously **kwargs was accepted but silently dropped.
        return encoder.encode(context, **kwargs)

    def encode_query(
        self, query: str | list[str], **kwargs: Any
    ) -> torch.Tensor:
        """Encode query text(s) with the shared or query encoder.

        ``kwargs`` are forwarded to ``SentenceTransformer.encode``.
        """
        # validation guarantees one of these is not None
        encoder = self.encoder if self.encoder else self.query_encoder
        encoder = cast(SentenceTransformer, encoder)
        # BUGFIX: previously **kwargs was accepted but silently dropped.
        return encoder.encode(query, **kwargs)

    @property
    def encoder(self) -> Optional["SentenceTransformer"]:
        """Shared encoder, lazily loaded on first access when configured."""
        if self.model_name and self._encoder is None:
            self._encoder = self._load_model_from_hf(load_type="encoder")
        return self._encoder

    @property
    def query_encoder(self) -> Optional["SentenceTransformer"]:
        """Query encoder, lazily loaded on first access when configured."""
        if self.query_model_name and self._query_encoder is None:
            self._query_encoder = self._load_model_from_hf(
                load_type="query_encoder"
            )
        return self._query_encoder

    @property
    def context_encoder(self) -> Optional["SentenceTransformer"]:
        """Context encoder, lazily loaded on first access when configured."""
        if self.context_model_name and self._context_encoder is None:
            self._context_encoder = self._load_model_from_hf(
                load_type="context_encoder"
            )
        return self._context_encoder