import random from collections import deque from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from tqdm import tqdm from graphgen.bases import BaseGraphStorage from graphgen.bases.datatypes import Community from graphgen.models.partitioner.bfs_partitioner import BFSPartitioner NODE_UNIT: str = "n" EDGE_UNIT: str = "e" class ECEPartitioner(BFSPartitioner): """ ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE). We calculate ECE for units in KG (represented as 'comprehension loss') and group units with similar ECE values into the same community. 1. Select a sampling strategy. 2. Choose a unit based on the sampling strategy. 2. Expand the community using BFS. 3. When expending, prefer to add units with the sampling strategy. 4. Stop when the max unit size is reached or the max input length is reached. (A unit is a node or an edge.) """ @staticmethod def _sort_units(units: list, edge_sampling: str) -> list: """ Sort units with edge sampling strategy :param units: total units :param edge_sampling: edge sampling strategy (random, min_loss, max_loss) :return: sorted units """ if edge_sampling == "random": random.shuffle(units) elif edge_sampling == "min_loss": units = sorted( units, key=lambda x: x[-1]["loss"], ) elif edge_sampling == "max_loss": units = sorted( units, key=lambda x: x[-1]["loss"], reverse=True, ) else: raise ValueError(f"Invalid edge sampling: {edge_sampling}") return units def partition( self, g: BaseGraphStorage, max_units_per_community: int = 10, min_units_per_community: int = 1, max_tokens_per_community: int = 10240, unit_sampling: str = "random", **kwargs: Any, ) -> Iterable[Community]: nodes: List[Tuple[str, dict]] = g.get_all_nodes() edges: List[Tuple[str, str, dict]] = g.get_all_edges() adj, _ = self._build_adjacency_list(nodes, edges) node_dict = dict(nodes) edge_dict = {frozenset((u, v)): d for u, v, d in edges} all_units: List[Tuple[str, Any, dict]] = [ (NODE_UNIT, nid, d) for nid, d in nodes ] + [(EDGE_UNIT, frozenset((u, v)), d) for u, v, d in edges] used_n: Set[str] = set() used_e: Set[frozenset[str]] = set() all_units = self._sort_units(all_units, unit_sampling) def _grow_community(seed_unit: Tuple[str, Any, dict]) -> Optional[Community]: nonlocal used_n, used_e community_nodes: Dict[str, dict] = {} community_edges: Dict[frozenset[str], dict] = {} queue = deque() token_sum = 0 def _add_unit(u): nonlocal token_sum t, i, d = u if t == NODE_UNIT: # node if i in used_n or i in community_nodes: return False community_nodes[i] = d used_n.add(i) else: # edge if i in used_e or i in community_edges: return False community_edges[i] = d used_e.add(i) token_sum += d.get("length", 0) return True _add_unit(seed_unit) queue.append(seed_unit) # BFS while queue: if ( len(community_nodes) + len(community_edges) >= max_units_per_community or token_sum >= max_tokens_per_community ): break cur_type, cur_id, _ = queue.popleft() neighbors: List[Tuple[str, Any, dict]] = [] if cur_type == NODE_UNIT: for nb_id in adj.get(cur_id, []): e_key = frozenset((cur_id, nb_id)) if e_key not in used_e and e_key not in community_edges: neighbors.append((EDGE_UNIT, e_key, edge_dict[e_key])) else: for n_id in cur_id: if n_id not in used_n and n_id not in community_nodes: neighbors.append((NODE_UNIT, n_id, node_dict[n_id])) neighbors = self._sort_units(neighbors, unit_sampling) for nb in neighbors: if ( len(community_nodes) + len(community_edges) >= max_units_per_community or token_sum >= max_tokens_per_community ): break if _add_unit(nb): queue.append(nb) if len(community_nodes) + len(community_edges) < min_units_per_community: return None return Community( id=seed_unit[1], nodes=list(community_nodes.keys()), edges=[tuple(edge) for edge in community_edges if isinstance(edge, frozenset) and len(edge)==2], ) for unit in tqdm(all_units, desc="ECE partition"): utype, uid, _ = unit if (utype == NODE_UNIT and uid in used_n) or ( utype == EDGE_UNIT and uid in used_e ): continue comm = _grow_community(unit) if comm: yield comm