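"""Driver script for running a GraphGen pipeline on Ray Data.

Loads a YAML pipeline configuration, builds the GraphGen Engine with the
registered operators, executes it over a Ray Dataset, and writes each node's
results as JSONL files into a timestamped output directory (one sub-directory
per node), saving the resolved configuration alongside the results.
"""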
import argparse
import os
import time
from importlib import resources
from typing import Any, Dict

import ray
import yaml
from dotenv import load_dotenv
from ray.data.block import Block
from ray.data.datasource.filename_provider import FilenameProvider

from graphgen.engine import Engine
from graphgen.operators import operators
from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger
sys_path = os.path.abspath(os.path.dirname(__file__))

load_dotenv()


def set_working_dir(folder):
    os.makedirs(folder, exist_ok=True)


def save_config(config_path, global_config):
    if not os.path.exists(os.path.dirname(config_path)):
        os.makedirs(os.path.dirname(config_path))
    with open(config_path, "w", encoding="utf-8") as config_file:
        yaml.dump(
            global_config, config_file, default_flow_style=False, allow_unicode=True
        )


class NodeFilenameProvider(FilenameProvider):
    """Names output blocks with the producing node's ID as a prefix."""

    def __init__(self, node_id: str):
        self.node_id = node_id

    def get_filename_for_block(
        self, block: Block, write_uuid: str, task_index: int, block_index: int
    ) -> str:
        # format: {node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl
        return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"

    def get_filename_for_row(
        self,
        row: Dict[str, Any],
        write_uuid: str,
        task_index: int,
        block_index: int,
        row_index: int,
    ) -> str:
        raise NotImplementedError(
            f"Row-based filenames are not supported by write_json. "
            f"Node: {self.node_id}, write_uuid: {write_uuid}"
        )
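# Illustrative example (hypothetical values, not from the repo): a block produced
# by a node with ID "qa_gen", write UUID "ab12", task index 0 and block index 3
# would be written as "qa_gen_ab12_000000_000003.jsonl".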


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        help="Config parameters for GraphGen.",
        default=resources.files("graphgen")
        .joinpath("configs")
        .joinpath("aggregated_config.yaml"),
        type=str,
    )
    parser.add_argument(
        "--output_dir",
        help="Output directory for GraphGen.",
        default=sys_path,
        required=True,
        type=str,
    )
    args = parser.parse_args()

    working_dir = args.output_dir
    with open(args.config_file, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Each run gets a unique, timestamp-based output directory.
    unique_id = int(time.time())
    output_path = os.path.join(working_dir, "output", f"{unique_id}")
    set_working_dir(output_path)

    log_path = os.path.join(working_dir, "logs", "Driver.log")
    driver_logger = set_logger(
        log_path,
        name="GraphGen",
        if_stream=True,
    )
    CURRENT_LOGGER_VAR.set(driver_logger)
    logger.info(
        "GraphGen with unique ID %s logging to %s",
        unique_id,
        log_path,
    )

    engine = Engine(config, operators)
    # Seed the pipeline with an empty dataset; the engine returns a mapping
    # of node IDs to their result datasets.
    ds = ray.data.from_items([])
    results = engine.execute(ds)

    for node_id, dataset in results.items():
        # Write each node into its own sub-directory; using a separate variable
        # keeps later nodes (and the final config.yaml) from nesting inside it.
        node_output_path = os.path.join(output_path, f"{node_id}")
        os.makedirs(node_output_path, exist_ok=True)
        dataset.write_json(
            node_output_path,
            filename_provider=NodeFilenameProvider(node_id),
            pandas_json_args_fn=lambda: {
                "force_ascii": False,
                "orient": "records",
                "lines": True,
            },
        )
        logger.info("Node %s results saved to %s", node_id, node_output_path)

    save_config(os.path.join(output_path, "config.yaml"), config)
    logger.info("GraphGen completed successfully. Data saved to %s", output_path)


if __name__ == "__main__":
    main()
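# Example invocation (the script path here is illustrative, not from the repo):
#   python run_graphgen.py --output_dir ./runs
# This writes <output_dir>/output/<unix_timestamp>/<node_id>/*.jsonl and saves
# the resolved configuration as <output_dir>/output/<unix_timestamp>/config.yaml.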