Bases: BaseKnowledgeStore
InMemoryKnowledgeStore Class.
Source code in src/fed_rag/knowledge_stores/in_memory.py
class InMemoryKnowledgeStore(BaseKnowledgeStore):
    """Knowledge store that keeps its nodes in an in-process dictionary.

    Nodes are keyed by their ``node_id``. The store can be written to, and
    re-populated from, a Parquet file at ``<cache_dir>/<name>.parquet``.
    NOTE(review): ``name`` is not declared in this class; presumably it is
    provided by ``BaseKnowledgeStore`` -- confirm.
    """

    # Directory under which ``persist``/``load`` look for the Parquet file.
    cache_dir: str = Field(default=DEFAULT_CACHE_DIR)
    # Mapping of node_id -> node. Declared as a private attribute, so it is
    # only serialized through ``custom_model_dump`` below.
    _data: dict[str, KnowledgeNode] = PrivateAttr(default_factory=dict)

    @classmethod
    def from_nodes(cls, nodes: list[KnowledgeNode], **kwargs: Any) -> Self:
        """Build a store from ``kwargs`` and load ``nodes`` into it.

        Args:
            nodes: Nodes to load into the new store.
            **kwargs: Forwarded unchanged to the class constructor.

        Returns:
            The newly created, populated store.
        """
        instance = cls(**kwargs)
        instance.load_nodes(nodes)
        return instance

    def load_node(self, node: KnowledgeNode) -> None:
        """Add ``node`` unless a node with the same ``node_id`` is present.

        An already-stored node is never overwritten.
        """
        # setdefault only inserts when the key is absent.
        self._data.setdefault(node.node_id, node)

    def load_nodes(self, nodes: list[KnowledgeNode]) -> None:
        """Add every node in ``nodes`` using ``load_node`` semantics."""
        for node in nodes:
            self.load_node(node)

    def retrieve(
        self, query_emb: list[float], top_k: int
    ) -> list[tuple[float, KnowledgeNode]]:
        """Return the top-k stored nodes for a query embedding.

        Selection and scoring are delegated to ``_get_top_k_nodes``; result
        order is whatever that helper returns.

        Args:
            query_emb: Query embedding passed through to the helper.
            top_k: Number of nodes requested from the helper.

        Returns:
            ``(score, node)`` tuples, one per entry returned by the helper.
        """
        candidates = list(self._data.values())
        node_ids_and_scores = _get_top_k_nodes(
            nodes=candidates, query_emb=query_emb, top_k=top_k
        )
        # Each helper entry is indexed as (node_id, score); swap the order
        # and resolve the id back to the stored node.
        return [
            (entry[1], self._data[entry[0]]) for entry in node_ids_and_scores
        ]

    def delete_node(self, node_id: str) -> bool:
        """Remove the node stored under ``node_id``.

        Returns:
            ``True`` if a node was removed, ``False`` if the id was unknown.
        """
        if node_id not in self._data:
            return False
        self._data.pop(node_id)
        return True

    def clear(self) -> None:
        """Drop all nodes by rebinding the store to a fresh empty dict."""
        self._data = {}

    @property
    def count(self) -> int:
        """Number of nodes currently held in the store."""
        return len(self._data)

    @model_serializer(mode="wrap")
    def custom_model_dump(self, handler: Any) -> Dict[str, Any]:
        """Serialize the model, adding the private ``_data`` mapping.

        The ``"_data"`` key is only added when the store is non-empty, so
        consumers must not assume it is always present.
        """
        data = handler(self)
        data = cast(Dict[str, Any], data)
        # include _data in serialization
        if self._data:
            data["_data"] = self._data
        return data  # type: ignore[no-any-return]

    def persist(self) -> None:
        """Write all nodes to ``<cache_dir>/<name>.parquet``.

        The parent directory is created if missing. An empty store is
        written as an empty table, which replaces any file from an earlier
        ``persist`` call.
        NOTE(review): confirm the installed pyarrow version round-trips a
        zero-row, zero-column table as expected.
        """
        serialized_model = self.model_dump()
        # ``custom_model_dump`` omits "_data" for an empty store; fall back
        # to an empty mapping instead of raising KeyError.
        serialized_nodes = serialized_model.get("_data", {})
        data_values = list(serialized_nodes.values())
        parquet_table = pa.Table.from_pylist(data_values)
        filename = Path(self.cache_dir) / f"{self.name}.parquet"
        filename.parent.mkdir(parents=True, exist_ok=True)
        pq.write_table(parquet_table, filename)

    def load(self) -> None:
        """Load nodes from ``<cache_dir>/<name>.parquet`` into the store.

        Loaded nodes are merged with ``load_node`` semantics: ids already
        present in the store keep their current node.

        Raises:
            KnowledgeStoreNotFoundError: If the Parquet file does not exist.
        """
        filename = Path(self.cache_dir) / f"{self.name}.parquet"
        if not filename.exists():
            msg = f"Knowledge store '{self.name}' not found at expected location: {filename}"
            raise KnowledgeStoreNotFoundError(msg)
        parquet_data = pq.read_table(filename).to_pylist()
        nodes = [KnowledgeNode(**data) for data in parquet_data]
        self.load_nodes(nodes)