Spaces:
Running
Running
File size: 2,377 Bytes
31086ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Set
# rocksdict is a lightweight C wrapper around RocksDB for Python, pylint may not recognize it
# pylint: disable=no-name-in-module
from rocksdict import Rdict
from graphgen.bases.base_storage import BaseKVStorage
@dataclass
class RocksDBKVStorage(BaseKVStorage):
    """Key-value storage backed by an on-disk RocksDB database via ``rocksdict``.

    Each namespace is stored in its own database at
    ``<working_dir>/<namespace>.db``. ``working_dir`` and ``namespace`` are
    read from ``self`` but not defined here; presumably they are fields of
    ``BaseKVStorage`` (TODO confirm against the base class).
    """

    # Open database handle: created in __post_init__/drop, reset to None by close().
    _db: Rdict = None
    # Filesystem path of this namespace's database, derived in __post_init__.
    _db_path: str = None

    def __post_init__(self):
        """Open (creating if needed) the RocksDB database for this namespace."""
        # Make sure the parent directory exists so opening the database does not
        # depend on someone else having created working_dir. An empty
        # working_dir means "current directory" and needs no creation.
        if self.working_dir:
            os.makedirs(self.working_dir, exist_ok=True)
        self._db_path = os.path.join(self.working_dir, f"{self.namespace}.db")
        self._db = Rdict(self._db_path)
        print(
            f"RocksDBKVStorage initialized for namespace '{self.namespace}' at '{self._db_path}'"
        )

    @property
    def data(self):
        """The underlying ``Rdict`` handle (``None`` after ``close()``)."""
        return self._db

    def all_keys(self) -> List[str]:
        """Return a list of every key currently stored in the database."""
        return list(self._db.keys())

    def index_done_callback(self):
        """Flush buffered writes to disk by calling ``Rdict.flush()``."""
        self._db.flush()
        print(f"RocksDB flushed for {self.namespace}")

    def get_by_id(self, id: str) -> Any:  # pylint: disable=redefined-builtin
        """Return the value stored under ``id``, or ``None`` if it is absent."""
        return self._db.get(id, None)

    def get_by_ids(self, ids: List[str], fields: List[str] = None) -> List[Any]:
        """Look up several keys at once, preserving the order of ``ids``.

        Args:
            ids: Keys to fetch.
            fields: If given, each found value is projected to only these keys
                (this assumes stored values are dicts when ``fields`` is used).

        Returns:
            One entry per id: ``None`` for a missing key, otherwise the stored
            value (or its projection onto ``fields``).
        """
        # Build the projection set once instead of scanning the list per field.
        wanted = None if fields is None else set(fields)
        result = []
        for key in ids:
            item = self._db.get(key, None)
            if item is None:
                result.append(None)
            elif wanted is None:
                result.append(item)
            else:
                result.append({k: v for k, v in item.items() if k in wanted})
        return result

    def get_all(self) -> Dict[str, Dict]:
        """Return a plain ``dict`` snapshot of every key/value pair."""
        # A single items() pass; dict(self._db) would list the keys and then
        # perform a separate lookup for each one.
        return dict(self._db.items())

    def filter_keys(self, data: List[str]) -> Set[str]:
        """Return the keys from ``data`` that are not yet stored."""
        return {key for key in data if key not in self._db}

    def upsert(self, data: Dict[str, Any]):
        """Insert the entries of ``data`` whose keys are not already stored.

        Keys that already exist are left untouched (insert-if-absent).

        Returns:
            dict: The subset of ``data`` that was actually written.
        """
        # Decide which keys are new before writing anything.
        left_data = {k: v for k, v in data.items() if k not in self._db}
        for k, v in left_data.items():
            self._db[k] = v
        # For very large batches, an Rdict write batch could reduce per-put overhead.
        return left_data

    def drop(self):
        """Delete this namespace's database on disk and reopen it empty."""
        # Tolerate a prior close(), which leaves the handle as None.
        if self._db is not None:
            self._db.close()
        Rdict.destroy(self._db_path)
        self._db = Rdict(self._db_path)
        print(f"Dropped RocksDB {self.namespace}")

    def close(self):
        """Close the database handle; safe to call more than once."""
        if self._db is not None:
            self._db.close()
            # Drop the reference so a repeated close() is a no-op rather than
            # a second close() on an already-closed handle.
            self._db = None

    def reload(self):
        """No-op: RocksDB manages its own on-disk state, so nothing is reloaded."""
|