class InMemoryKnowledgeStore(BaseKnowledgeStore):
"""InMemoryKnowledgeStore Class."""
cache_dir: str = Field(default=DEFAULT_CACHE_DIR)
_data: dict[str, KnowledgeNode] = PrivateAttr(default_factory=dict)
    # One embedding per entry in `_node_list`; temporarily promoted to a
    # torch.Tensor during retrieval and converted back to a list before mutation.
    _data_storage: list[list[float]] = PrivateAttr(default_factory=list)
_node_list: list[str] = PrivateAttr(default_factory=list)
@classmethod
def from_nodes(cls, nodes: list[KnowledgeNode], **kwargs: Any) -> Self:
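        """Build a store pre-populated with `nodes`; remaining kwargs are passed to the constructor."""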
instance = cls(**kwargs)
instance.load_nodes(nodes)
return instance
def load_node(self, node: KnowledgeNode) -> None:
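        """Add `node` to the store, ignoring node IDs that are already present."""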
        # `retrieve` may have promoted the embedding storage to a torch.Tensor;
        # convert it back to a plain list so it can be mutated, then release memory.
        if isinstance(self._data_storage, torch.Tensor):
            self._data_storage = self._data_storage.cpu().tolist()
            gc.collect()  # Clean up Python garbage
            torch.cuda.empty_cache()  # No-op when CUDA is unavailable
if node.node_id not in self._data:
self._data[node.node_id] = node
self._node_list.append(node.node_id)
self._data_storage.append(node.embedding)
def load_nodes(self, nodes: list[KnowledgeNode]) -> None:
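        """Add `nodes` one at a time via `load_node`."""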
for node in nodes:
self.load_node(node)
def retrieve(
self, query_emb: list[float], top_k: int = DEFAULT_TOP_K
) -> list[tuple[float, KnowledgeNode]]:
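        """Return the `top_k` most similar nodes to `query_emb` as (score, node) pairs."""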
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
query_emb = torch.tensor(query_emb).to(device)
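        # Cache the embedding matrix as a tensor on the same device; the mutating
        # methods (`load_node`, `delete_node`) convert it back to a list.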
if not torch.is_tensor(self._data_storage):
self._data_storage = torch.tensor(self._data_storage).to(device)
node_ids_and_scores = _get_top_k_nodes(
nodes=self._node_list,
embeddings=self._data_storage,
query_emb=query_emb,
top_k=top_k,
)
        return [
            (score, self._data[node_id]) for node_id, score in node_ids_and_scores
        ]
def batch_retrieve(
self, query_embs: list[list[float]], top_k: int = DEFAULT_TOP_K
) -> list[list[tuple[float, "KnowledgeNode"]]]:
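        """Batched retrieval is not supported by the in-memory store."""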
raise NotImplementedError(
f"batch_retrieve is not implemented for {self.__class__.__name__}."
)
def delete_node(self, node_id: str) -> bool:
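        """Delete `node_id` from the store, returning True if it was present."""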
        # As in `load_node`, demote the embedding storage back to a list before
        # mutating it.
        if isinstance(self._data_storage, torch.Tensor):
            self._data_storage = self._data_storage.cpu().tolist()
            gc.collect()  # Clean up Python garbage
            torch.cuda.empty_cache()  # No-op when CUDA is unavailable
        if node_id not in self._data:
            return False
        del self._data[node_id]
        # Keep the parallel node-ID and embedding lists in sync.
        index = self._node_list.index(node_id)
        del self._data_storage[index]
        del self._node_list[index]
        return True
def clear(self) -> None:
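        """Remove all nodes from the store."""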
        self._data = {}
        # Also reset the parallel lists; leaving them populated would desync
        # retrieval from the (now empty) node dict.
        self._node_list = []
        self._data_storage = []
@property
def count(self) -> int:
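        """Number of nodes currently in the store."""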
return len(self._data)
@model_serializer(mode="wrap")
def custom_model_dump(self, handler: Any) -> Dict[str, Any]:
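        """Include the private `_data` mapping when serializing the model."""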
data = handler(self)
data = cast(Dict[str, Any], data)
        # Include the private `_data` mapping, dumping each node to a plain
        # dict so that `persist` can hand the rows directly to pyarrow.
        if self._data:
            data["_data"] = {
                node_id: node.model_dump() for node_id, node in self._data.items()
            }
        return data
def persist(self) -> None:
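        """Write all nodes to `<cache_dir>/<name>.parquet`."""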
serialized_model = self.model_dump()
data_values = list(serialized_model["_data"].values())
parquet_table = pa.Table.from_pylist(data_values)
        filename = Path(self.cache_dir) / f"{self.name}.parquet"
        filename.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(parquet_table, filename)
def load(self) -> None:
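        """Load previously persisted nodes from `<cache_dir>/<name>.parquet`."""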
filename = Path(self.cache_dir) / f"{self.name}.parquet"
if not filename.exists():
msg = f"Knowledge store '{self.name}' not found at expected location: {filename}"
raise KnowledgeStoreNotFoundError(msg)
parquet_data = pq.read_table(filename).to_pylist()
nodes = [KnowledgeNode(**data) for data in parquet_data]
self.load_nodes(nodes)
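

# Usage sketch (illustrative only, not part of the class): assumes
# `KnowledgeNode` accepts `node_id`, `embedding`, and `text_content`
# keyword arguments.
#
#     store = InMemoryKnowledgeStore.from_nodes(
#         [KnowledgeNode(node_id="a", embedding=[0.1, 0.2], text_content="hello")]
#     )
#     score, node = store.retrieve(query_emb=[0.1, 0.2], top_k=1)[0]
#     store.persist()  # writes `<cache_dir>/<name>.parquet`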