File size: 3,559 Bytes
fb9c306
acd7cf4
 
31086ae
 
fb9c306
31086ae
acd7cf4
 
31086ae
 
acd7cf4
31086ae
 
 
acd7cf4
 
 
 
 
fb9c306
acd7cf4
 
 
fb9c306
acd7cf4
 
 
fb9c306
 
 
 
 
acd7cf4
31086ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acd7cf4
 
fb9c306
 
 
31086ae
 
 
fb9c306
 
 
 
 
 
 
 
 
acd7cf4
 
 
 
 
fb9c306
acd7cf4
 
fb9c306
31086ae
bda6eda
31086ae
 
 
 
fb9c306
acd7cf4
31086ae
fb9c306
 
 
31086ae
acd7cf4
 
31086ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acd7cf4
bda6eda
fb9c306
acd7cf4
 
fb9c306
acd7cf4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import os
import time
from importlib import resources
from typing import Any, Dict

import ray
import yaml
from dotenv import load_dotenv
from ray.data.block import Block
from ray.data.datasource.filename_provider import FilenameProvider

from graphgen.engine import Engine
from graphgen.operators import operators
from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger

# Absolute path of the directory that contains this file; referenced below as
# the ``default`` of the ``--output_dir`` command-line option.
sys_path = os.path.abspath(os.path.dirname(__file__))

# Runs at import time: python-dotenv loads variables from a ``.env`` file (if
# one is found) into the process environment before anything else executes.
load_dotenv()


def set_working_dir(folder):
    """Create ``folder`` and any missing parent directories.

    The call is a no-op when the directory already exists.

    Args:
        folder: Path of the directory to create.
    """
    os.makedirs(name=folder, exist_ok=True)


def save_config(config_path, global_config):
    """Write ``global_config`` to ``config_path`` as block-style YAML.

    The parent directory of ``config_path`` is created when it does not exist.
    The file is opened in write mode with UTF-8 encoding, so any existing file
    at that path is overwritten.

    Args:
        config_path: Destination file path. May be a bare filename, in which
            case the file is written relative to the current directory.
        global_config: Object to serialize with ``yaml.dump``.
    """
    parent_dir = os.path.dirname(config_path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the parent when there is one.
    # exist_ok=True avoids a race between checking and creating the directory.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(config_path, "w", encoding="utf-8") as config_file:
        yaml.dump(
            global_config,
            config_file,
            default_flow_style=False,  # block style, one key per line
            allow_unicode=True,  # keep non-ASCII text readable, not escaped
        )


class NodeFilenameProvider(FilenameProvider):
    """Filename provider that tags every written block file with a node id.

    Only block-level filenames are supported; asking for a row-level filename
    raises ``NotImplementedError``.
    """

    def __init__(self, node_id: str):
        """Remember the node id used as the filename prefix."""
        self.node_id = node_id

    def get_filename_for_block(
        self, block: Block, write_uuid: str, task_index: int, block_index: int
    ) -> str:
        """Return the filename for one written block.

        Format: ``{node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl``
        (both indices are zero-padded to six digits). ``block`` itself is not
        inspected.
        """
        prefix = f"{self.node_id}_{write_uuid}"
        counters = f"{task_index:06d}_{block_index:06d}"
        return f"{prefix}_{counters}.jsonl"

    def get_filename_for_row(
        self,
        row: Dict[str, Any],
        write_uuid: str,
        task_index: int,
        block_index: int,
        row_index: int,
    ) -> str:
        """Always raise: per-row filenames are not supported.

        Raises:
            NotImplementedError: On every call; the message names the node id
                and ``write_uuid``.
        """
        message = (
            "Row-based filenames are not supported by write_json. "
            f"Node: {self.node_id}, write_uuid: {write_uuid}"
        )
        raise NotImplementedError(message)


def main():
    """Run a GraphGen pipeline from the command line and persist its results.

    Steps:
      1. Parse ``--config_file`` and ``--output_dir``.
      2. Load the YAML configuration.
      3. Create the run directory ``<output_dir>/output/<unix timestamp>`` and
         set up the driver logger at ``<output_dir>/logs/Driver.log``.
      4. Execute the ``Engine`` on an empty Ray dataset.
      5. Write each node's resulting dataset as JSON lines into its own
         ``<run directory>/<node_id>`` sub-directory.
      6. Save the configuration used as ``<run directory>/config.yaml``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        help="Config parameters for GraphGen.",
        default=resources.files("graphgen")
        .joinpath("configs")
        .joinpath("aggregated_config.yaml"),
        type=str,
    )
    # NOTE(review): argparse never applies ``default`` to an option marked
    # ``required=True``, so ``sys_path`` is effectively unused here. Kept as-is
    # to preserve the current CLI contract.
    parser.add_argument(
        "--output_dir",
        help="Output directory for GraphGen.",
        default=sys_path,
        required=True,
        type=str,
    )

    args = parser.parse_args()

    working_dir = args.output_dir

    with open(args.config_file, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader accepts more than plain YAML data types;
        # prefer yaml.safe_load if config files can come from untrusted sources.
        config = yaml.load(f, Loader=yaml.FullLoader)

    # The integer Unix timestamp names this run's output directory.
    unique_id = int(time.time())
    output_path = os.path.join(working_dir, "output", f"{unique_id}")
    set_working_dir(output_path)
    log_path = os.path.join(working_dir, "logs", "Driver.log")
    driver_logger = set_logger(
        log_path,
        name="GraphGen",
        if_stream=True,
    )
    CURRENT_LOGGER_VAR.set(driver_logger)
    logger.info(
        "GraphGen with unique ID %s logging to %s",
        unique_id,
        log_path,
    )

    engine = Engine(config, operators)
    ds = ray.data.from_items([])
    results = engine.execute(ds)

    for node_id, dataset in results.items():
        # Use a separate per-node variable: rebinding ``output_path`` here would
        # nest every node's directory inside the previous node's directory and
        # move config.yaml out of the run directory.
        node_output_path = os.path.join(output_path, f"{node_id}")
        os.makedirs(node_output_path, exist_ok=True)
        dataset.write_json(
            node_output_path,
            filename_provider=NodeFilenameProvider(node_id),
            pandas_json_args_fn=lambda: {
                "force_ascii": False,
                "orient": "records",
                "lines": True,
            },
        )
        logger.info("Node %s results saved to %s", node_id, node_output_path)

    save_config(os.path.join(output_path, "config.yaml"), config)
    logger.info("GraphGen completed successfully. Data saved to %s", output_path)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()