github-actions[bot]
Auto-sync from demo at Tue Dec 16 08:21:05 UTC 2025
31086ae
import json
import os
import shutil
from dataclasses import dataclass
from typing import Any
try:
import kuzu
except ImportError:
kuzu = None
from graphgen.bases.base_storage import BaseGraphStorage
@dataclass
class KuzuStorage(BaseGraphStorage):
"""
Graph storage implementation based on KuzuDB.
Since KuzuDB is a structured graph database and GraphGen uses dynamic dictionaries for properties,
we map the data to a generic schema:
- Node Table 'Entity': {id: STRING, data: STRING (JSON)}
- Rel Table 'Relation': {FROM Entity TO Entity, data: STRING (JSON)}
"""
working_dir: str = None
namespace: str = None
_db: Any = None
_conn: Any = None
def __post_init__(self):
if kuzu is None:
raise ImportError(
"KuzuDB is not installed. Please install it via `pip install kuzu`."
)
self.db_path = os.path.join(self.working_dir, f"{self.namespace}_kuzu")
self._init_db()
def _init_db(self):
# KuzuDB automatically creates the directory
self._db = kuzu.Database(self.db_path)
self._conn = kuzu.Connection(self._db)
self._init_schema()
print(f"KuzuDB initialized at {self.db_path}")
def _init_schema(self):
"""Initialize the generic Node and Edge tables if they don't exist."""
# Check and create Node table
try:
# We use a generic table name "Entity" to store all nodes
self._conn.execute(
"CREATE NODE TABLE Entity(id STRING, data STRING, PRIMARY KEY(id))"
)
print("Created KuzuDB Node Table 'Entity'")
except RuntimeError as e:
# Usually throws if table exists, verify safely or ignore
print("Node Table 'Entity' already exists or error:", e)
# Check and create Edge table
try:
# We use a generic table name "Relation" to store all edges
self._conn.execute(
"CREATE REL TABLE Relation(FROM Entity TO Entity, data STRING)"
)
print("Created KuzuDB Rel Table 'Relation'")
except RuntimeError as e:
print("Rel Table 'Relation' already exists or error:", e)
def index_done_callback(self):
"""KuzuDB is ACID, changes are immediate, but we can verify generic persistence here."""
def has_node(self, node_id: str) -> bool:
result = self._conn.execute(
"MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
)
count = result.get_next()[0]
return count > 0
def has_edge(self, source_node_id: str, target_node_id: str):
result = self._conn.execute(
"MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst}) RETURN count(e)",
{"src": source_node_id, "dst": target_node_id},
)
count = result.get_next()[0]
return count > 0
def node_degree(self, node_id: str) -> int:
# Calculate total degree (incoming + outgoing)
query = """
MATCH (a:Entity {id: $id})-[e:Relation]-(b:Entity)
RETURN count(e)
"""
result = self._conn.execute(query, {"id": node_id})
if result.has_next():
return result.get_next()[0]
return 0
def edge_degree(self, src_id: str, tgt_id: str) -> int:
# In this context, usually checks existence or multiplicity.
# Kuzu supports multi-edges, so we count them.
query = """
MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
RETURN count(e)
"""
result = self._conn.execute(query, {"src": src_id, "dst": tgt_id})
if result.has_next():
return result.get_next()[0]
return 0
def get_node(self, node_id: str) -> Any:
result = self._conn.execute(
"MATCH (a:Entity {id: $id}) RETURN a.data", {"id": node_id}
)
if result.has_next():
data_str = result.get_next()[0]
return json.loads(data_str) if data_str else {}
return None
def update_node(self, node_id: str, node_data: dict[str, str]):
current_data = self.get_node(node_id)
if current_data is None:
print(f"Node {node_id} not found for update.")
return
# Merge existing data with new data
current_data.update(node_data)
json_data = json.dumps(current_data, ensure_ascii=False)
self._conn.execute(
"MATCH (a:Entity {id: $id}) SET a.data = $data",
{"id": node_id, "data": json_data},
)
def get_all_nodes(self) -> Any:
"""Returns List[Tuple[id, data_dict]]"""
result = self._conn.execute("MATCH (a:Entity) RETURN a.id, a.data")
nodes = []
while result.has_next():
row = result.get_next()
nodes.append((row[0], json.loads(row[1])))
return nodes
def get_edge(self, source_node_id: str, target_node_id: str):
# Warning: If multiple edges exist, this returns the first one found
query = """
MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
RETURN e.data
"""
result = self._conn.execute(
query, {"src": source_node_id, "dst": target_node_id}
)
if result.has_next():
data_str = result.get_next()[0]
return json.loads(data_str) if data_str else {}
return None
def update_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
current_data = self.get_edge(source_node_id, target_node_id)
if current_data is None:
print(f"Edge {source_node_id}->{target_node_id} not found for update.")
return
current_data.update(edge_data)
json_data = json.dumps(current_data, ensure_ascii=False)
query = """
MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
SET e.data = $data
"""
self._conn.execute(
query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
)
def get_all_edges(self) -> Any:
"""Returns List[Tuple[src, dst, data_dict]]"""
query = "MATCH (a:Entity)-[e:Relation]->(b:Entity) RETURN a.id, b.id, e.data"
result = self._conn.execute(query)
edges = []
while result.has_next():
row = result.get_next()
edges.append((row[0], row[1], json.loads(row[2])))
return edges
def get_node_edges(self, source_node_id: str) -> Any:
"""Returns generic edges connected to this node (outgoing)"""
query = """
MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity)
RETURN a.id, b.id, e.data
"""
result = self._conn.execute(query, {"src": source_node_id})
edges = []
while result.has_next():
row = result.get_next()
edges.append((row[0], row[1], json.loads(row[2])))
return edges
def upsert_node(self, node_id: str, node_data: dict[str, str]):
"""
Insert or Update node.
Kuzu supports MERGE clause (similar to Neo4j) to handle upserts.
"""
json_data = json.dumps(node_data, ensure_ascii=False)
query = """
MERGE (a:Entity {id: $id})
ON MATCH SET a.data = $data
ON CREATE SET a.data = $data
"""
self._conn.execute(query, {"id": node_id, "data": json_data})
def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
"""
Insert or Update edge.
Note: We explicitly ensure nodes exist before merging the edge to avoid errors,
although GraphGen generally creates nodes before edges.
"""
# Ensure source node exists
if not self.has_node(source_node_id):
self.upsert_node(source_node_id, {})
# Ensure target node exists
if not self.has_node(target_node_id):
self.upsert_node(target_node_id, {})
json_data = json.dumps(edge_data, ensure_ascii=False)
query = """
MATCH (a:Entity {id: $src}), (b:Entity {id: $dst})
MERGE (a)-[e:Relation]->(b)
ON MATCH SET e.data = $data
ON CREATE SET e.data = $data
"""
self._conn.execute(
query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
)
def delete_node(self, node_id: str):
# DETACH DELETE removes the node and all connected edges
query = "MATCH (a:Entity {id: $id}) DETACH DELETE a"
self._conn.execute(query, {"id": node_id})
print(f"Node {node_id} deleted from KuzuDB.")
def clear(self):
"""Clear all data but keep schema (or drop tables)."""
self._conn.execute("MATCH (n) DETACH DELETE n")
print(f"Graph {self.namespace} cleared.")
def reload(self):
"""For databases that need reloading, KuzuDB auto-manages this."""
def drop(self):
"""Completely remove the database folder."""
if self.db_path and os.path.exists(self.db_path):
shutil.rmtree(self.db_path)
print(f"Dropped KuzuDB at {self.db_path}")