File size: 2,377 Bytes
31086ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Set

# rocksdict is a lightweight C wrapper around RocksDB for Python, pylint may not recognize it
# pylint: disable=no-name-in-module
from rocksdict import Rdict

from graphgen.bases.base_storage import BaseKVStorage


@dataclass
class RocksDBKVStorage(BaseKVStorage):
    """Key-value storage backed by a RocksDB database opened through ``Rdict``.

    The database is located at ``<working_dir>/<namespace>.db`` and is opened
    in ``__post_init__``. ``working_dir`` and ``namespace`` are read from
    ``self``; they are presumably fields declared on ``BaseKVStorage``
    (verify against the base class).
    """

    # Open database handle; assigned in __post_init__.
    _db: Rdict = None
    # Path of the database on disk; assigned in __post_init__.
    _db_path: str = None

    def __post_init__(self):
        """Compute the database path, make sure its directory exists, and open it."""
        self._db_path = os.path.join(self.working_dir, f"{self.namespace}.db")
        # Do not rely on the storage engine to create missing parent
        # directories. An empty working_dir means "current directory", which
        # always exists (and os.makedirs("") would raise), so skip it then.
        if self.working_dir:
            os.makedirs(self.working_dir, exist_ok=True)
        self._db = Rdict(self._db_path)
        print(
            f"RocksDBKVStorage initialized for namespace '{self.namespace}' at '{self._db_path}'"
        )

    @property
    def data(self):
        """Return the underlying ``Rdict`` handle."""
        return self._db

    def all_keys(self) -> List[str]:
        """Return every key currently stored, materialized as a list."""
        return list(self._db.keys())

    def index_done_callback(self):
        """Flush the database and report that the flush happened."""
        self._db.flush()
        print(f"RocksDB flushed for {self.namespace}")

    def get_by_id(self, id: str) -> Any:
        """Return the value stored under ``id``, or ``None`` if it is absent."""
        return self._db.get(id, None)

    def get_by_ids(self, ids: List[str], fields: List[str] = None) -> List[Any]:
        """Look up several keys at once.

        Args:
            ids: Keys to fetch; the result has one entry per key, in order.
            fields: When given, each found item is reduced to the entries
                whose key is listed here (items are expected to expose
                ``.items()``). When ``None``, items are returned unchanged.

        Returns:
            A list aligned with ``ids``; missing keys yield ``None``.
        """
        # Build the lookup set once instead of scanning the list per entry.
        wanted = None if fields is None else set(fields)
        result = []
        for key in ids:
            item = self._db.get(key, None)
            if item is None or wanted is None:
                # Missing key (None) or no projection requested (item as-is).
                result.append(item)
            else:
                result.append({k: v for k, v in item.items() if k in wanted})
        return result

    def get_all(self) -> Dict[str, Dict]:
        """Return the whole database content as a plain ``dict``."""
        return dict(self._db)

    def filter_keys(self, data: List[str]) -> Set[str]:
        """Return the subset of ``data`` that is not yet present in the database."""
        return {s for s in data if s not in self._db}

    def upsert(self, data: Dict[str, Any]):
        """Insert the entries of ``data`` whose keys are not stored yet.

        Existing keys are left untouched (their values are not overwritten).

        Returns:
            A dict containing only the entries that were actually inserted.
        """
        # Decide what is new before writing anything, then write.
        left_data = {k: v for k, v in data.items() if k not in self._db}
        # If left_data is very large, a RocksDB write batch would be a
        # worthwhile optimization over per-key writes.
        for k, v in left_data.items():
            self._db[k] = v
        return left_data

    def drop(self):
        """Close the database, destroy it via ``Rdict.destroy``, and reopen it at the same path."""
        self._db.close()
        Rdict.destroy(self._db_path)
        self._db = Rdict(self._db_path)
        print(f"Dropped RocksDB {self.namespace}")

    def close(self):
        """Close the database handle if one has been opened."""
        # Compare against None explicitly: the handle's truthiness is not a
        # reliable signal for "a database is open".
        if self._db is not None:
            self._db.close()

    def reload(self):
        """For databases that need reloading, RocksDB auto-manages this."""