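"""Ray-based DAG execution engine.

Builds a ray.data pipeline from a declarative node configuration: nodes are
topologically sorted, each node's operator is resolved from a user-supplied
registry of functions or classes, and the datasets produced by the DAG's
leaf nodes are returned.
"""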
import inspect
import logging
from collections import defaultdict, deque
from typing import Any, Callable, Dict, List, Set

import ray
import ray.data

from graphgen.bases import Config, Node
from graphgen.utils import logger

class Engine:
    def __init__(
        self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
    ):
        self.config = Config(**config)
        self.global_params = self.config.global_params
        self.functions = functions
        self.datasets: Dict[str, ray.data.Dataset] = {}

        if not ray.is_initialized():
            context = ray.init(
                ignore_reinit_error=True,
                logging_level=logging.ERROR,
                log_to_driver=True,
                **ray_init_kwargs,
            )
            logger.info("Ray Dashboard URL: %s", context.dashboard_url)
    @staticmethod
    def _topo_sort(nodes: List[Node]) -> List[Node]:
        """Topologically sort nodes with Kahn's algorithm, validating that
        every dependency id exists and that the graph is acyclic."""
        id_to_node: Dict[str, Node] = {n.id: n for n in nodes}

        indeg: Dict[str, int] = {nid: 0 for nid in id_to_node}
        adj: Dict[str, List[str]] = defaultdict(list)
        for n in nodes:
            nid = n.id
            deps: List[str] = n.dependencies
            uniq_deps: Set[str] = set(deps)
            for d in uniq_deps:
                if d not in id_to_node:
                    raise ValueError(
                        f"Dependency {d} of node {nid} is not defined in the configuration."
                    )
                indeg[nid] += 1
                adj[d].append(nid)

        zero_deg: deque = deque(
            [id_to_node[nid] for nid, deg in indeg.items() if deg == 0]
        )
        sorted_nodes: List[Node] = []
        while zero_deg:
            cur = zero_deg.popleft()
            sorted_nodes.append(cur)
            for nb_id in adj.get(cur.id, []):
                indeg[nb_id] -= 1
                if indeg[nb_id] == 0:
                    zero_deg.append(id_to_node[nb_id])

        if len(sorted_nodes) != len(nodes):
            remaining = [nid for nid, deg in indeg.items() if deg > 0]
            raise ValueError(
                "The configuration contains a cycle and cannot be executed. "
                f"Nodes with unresolved dependencies: {remaining}"
            )
        return sorted_nodes
    def _get_input_dataset(
        self, node: Node, initial_ds: ray.data.Dataset
    ) -> ray.data.Dataset:
        deps = node.dependencies
        if not deps:
            return initial_ds
        if len(deps) == 1:
            return self.datasets[deps[0]]
        # union concatenates the dependency datasets; they should share a schema.
        main_ds = self.datasets[deps[0]]
        other_dss = [self.datasets[d] for d in deps[1:]]
        return main_ds.union(*other_dss)
    def _execute_node(self, node: Node, initial_ds: ray.data.Dataset):
        def _filter_kwargs(
            func_or_class: Callable,
            global_params: Dict[str, Any],
            func_params: Dict[str, Any],
        ) -> Dict[str, Any]:
            """Keep only the kwargs the operator can accept.

            1. global_params: forwarded only when named in the signature.
            2. func_params: forwarded when named in the signature, or
               wholesale when the signature declares **kwargs.
            """
            try:
                sig = inspect.signature(func_or_class)
            except ValueError:
                return {}
            params = sig.parameters
            final_kwargs = {}
            has_var_keywords = any(
                p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
            )
            valid_keys = set(params.keys())
            for k, v in global_params.items():
                if k in valid_keys:
                    final_kwargs[k] = v
            for k, v in func_params.items():
                if k in valid_keys or has_var_keywords:
                    final_kwargs[k] = v
            return final_kwargs

        if node.op_name not in self.functions:
            raise ValueError(f"Operator {node.op_name} not found for node {node.id}")
        op_handler = self.functions[node.op_name]
        node_params = _filter_kwargs(op_handler, self.global_params, node.params or {})
        if node.type == "source":
            self.datasets[node.id] = op_handler(**node_params)
            return

        input_ds = self._get_input_dataset(node, initial_ds)

        if inspect.isclass(op_handler):
            execution_params = node.execution_params or {}
            replicas = execution_params.get("replicas", 1)
            batch_size = (
                int(execution_params.get("batch_size"))
                if "batch_size" in execution_params
                else "default"
            )
            compute_resources = execution_params.get("compute_resources") or {}
            if node.type == "aggregate":
                self.datasets[node.id] = input_ds.repartition(1).map_batches(
                    op_handler,
                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
                    batch_size=None,  # aggregate processes the whole dataset at once
                    num_gpus=compute_resources.get("num_gpus", 0),
                    fn_constructor_kwargs=node_params,
                    batch_format="pandas",
                )
            else:
                # map, filter, flatmap, map_batch: the actor pool processes data batch by batch
                self.datasets[node.id] = input_ds.map_batches(
                    op_handler,
                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
                    batch_size=batch_size,
                    num_gpus=compute_resources.get("num_gpus", 0),
                    fn_constructor_kwargs=node_params,
                    batch_format="pandas",
                )
        else:

            def func_wrapper(row_or_batch: Dict[str, Any]) -> Any:
                return op_handler(row_or_batch, **node_params)

            if node.type == "map":
                self.datasets[node.id] = input_ds.map(func_wrapper)
            elif node.type == "filter":
                self.datasets[node.id] = input_ds.filter(func_wrapper)
            elif node.type == "flatmap":
                self.datasets[node.id] = input_ds.flat_map(func_wrapper)
            elif node.type == "aggregate":
                self.datasets[node.id] = input_ds.repartition(1).map_batches(
                    func_wrapper,
                    batch_size=None,  # aggregate processes the whole dataset at once
                    batch_format="default",
                )
            elif node.type == "map_batch":
                self.datasets[node.id] = input_ds.map_batches(func_wrapper)
            else:
                raise ValueError(
                    f"Unsupported node type {node.type} for node {node.id}"
                )
    @staticmethod
    def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
        """Return the ids of nodes that no other node depends on."""
        all_ids = {n.id for n in nodes}
        deps_set: Set[str] = set()
        for n in nodes:
            deps_set.update(n.dependencies)
        return all_ids - deps_set
    def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
        sorted_nodes = self._topo_sort(self.config.nodes)
        for node in sorted_nodes:
            self._execute_node(node, initial_ds)
        leaf_nodes = self._find_leaf_nodes(sorted_nodes)
        # Datasets are lazy; callers materialize results, e.g. with ds.take_all().
        return {node_id: self.datasets[node_id] for node_id in leaf_nodes}
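
# ---------------------------------------------------------------------------
# Usage sketch (hypothetical): the exact Config/Node schema lives in
# graphgen.bases and may differ from what is assumed here. The operator
# names, config keys, and sample data below are illustrative only; this
# merely shows the intended call pattern of Engine.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pandas as pd

    class Doubler:
        """Class-based operator: constructed once per actor with
        fn_constructor_kwargs, then called once per pandas batch."""

        def __init__(self, factor: int = 2):
            self.factor = factor

        def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
            batch["value"] = batch["value"] * self.factor
            return batch

    def add_offset(row: Dict[str, Any], offset: int = 0) -> Dict[str, Any]:
        """Function-based operator: called per row with the filtered params."""
        row["value"] = row["value"] + offset
        return row

    # Assumed config layout: a dict that Config(**config) can parse into
    # Node objects with id/op_name/type/dependencies/params attributes.
    config = {
        "global_params": {},
        "nodes": [
            {"id": "shift", "op_name": "add_offset", "type": "map",
             "dependencies": [], "params": {"offset": 10}},
            {"id": "scale", "op_name": "Doubler", "type": "map_batch",
             "dependencies": ["shift"], "params": {"factor": 3}},
        ],
    }
    engine = Engine(config, {"add_offset": add_offset, "Doubler": Doubler})
    initial = ray.data.from_items([{"value": i} for i in range(4)])
    # Only the leaf node ("scale") is returned; materialize it to inspect rows.
    for node_id, ds in engine.execute(initial).items():
        print(node_id, ds.take_all())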