add model checkpoint
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- __init__.py +0 -5
- __pycache__/__init__.cpython-310.pyc +0 -0
- __pycache__/load.cpython-310.pyc +0 -0
- graph_grammar/.DS_Store +0 -0
- graph_grammar/__init__.py +0 -19
- graph_grammar/__pycache__/__init__.cpython-310.pyc +0 -0
- graph_grammar/__pycache__/hypergraph.cpython-310.pyc +0 -0
- graph_grammar/algo/__init__.py +0 -20
- graph_grammar/algo/__pycache__/__init__.cpython-310.pyc +0 -0
- graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc +0 -0
- graph_grammar/algo/tree_decomposition.py +0 -821
- graph_grammar/graph_grammar/__init__.py +0 -20
- graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/base.py +0 -30
- graph_grammar/graph_grammar/corpus.py +0 -152
- graph_grammar/graph_grammar/hrg.py +0 -1065
- graph_grammar/graph_grammar/symbols.py +0 -180
- graph_grammar/graph_grammar/utils.py +0 -130
- graph_grammar/hypergraph.py +0 -544
- graph_grammar/io/__init__.py +0 -20
- graph_grammar/io/__pycache__/__init__.cpython-310.pyc +0 -0
- graph_grammar/io/__pycache__/smi.cpython-310.pyc +0 -0
- graph_grammar/io/smi.py +0 -559
- graph_grammar/nn/__init__.py +0 -11
- graph_grammar/nn/__pycache__/__init__.cpython-310.pyc +0 -0
- graph_grammar/nn/__pycache__/decoder.cpython-310.pyc +0 -0
- graph_grammar/nn/__pycache__/encoder.cpython-310.pyc +0 -0
- graph_grammar/nn/dataset.py +0 -121
- graph_grammar/nn/decoder.py +0 -158
- graph_grammar/nn/encoder.py +0 -199
- graph_grammar/nn/graph.py +0 -313
- load.py +0 -83
- mhg_gnn.egg-info/PKG-INFO +0 -102
- mhg_gnn.egg-info/SOURCES.txt +0 -46
- mhg_gnn.egg-info/dependency_links.txt +0 -1
- mhg_gnn.egg-info/requires.txt +0 -7
- mhg_gnn.egg-info/top_level.txt +0 -2
- pickles/mhggnn_pretrained_model_0724_2023.pickle → mhggnn_pretrained_model_0724_2023.pickle +0 -0
- models/__init__.py +0 -5
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/mhgvae.cpython-310.pyc +0 -0
- models/mhgvae.py +0 -956
- notebooks/mhg-gnn_encoder_decoder_example.ipynb +0 -114
- paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf +0 -0
- pickles/.DS_Store +0 -0
__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
# -*- coding:utf-8 -*-
|
| 2 |
-
# Rhizome
|
| 3 |
-
# Version beta 0.0, August 2023
|
| 4 |
-
# Property of IBM Research, Accelerated Discovery
|
| 5 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (214 Bytes)
|
|
|
__pycache__/load.cpython-310.pyc
DELETED
|
Binary file (3.04 kB)
|
|
|
graph_grammar/.DS_Store
DELETED
|
Binary file (8.2 kB)
|
|
|
graph_grammar/__init__.py
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
"""
|
| 8 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 9 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 10 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
""" Title """
|
| 14 |
-
|
| 15 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 16 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 17 |
-
__version__ = "0.1"
|
| 18 |
-
__date__ = "Jan 1 2018"
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (666 Bytes)
|
|
|
graph_grammar/__pycache__/hypergraph.cpython-310.pyc
DELETED
|
Binary file (15.3 kB)
|
|
|
graph_grammar/algo/__init__.py
DELETED
|
@@ -1,20 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding:utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Jan 1 2018"
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/algo/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (659 Bytes)
|
|
|
graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc
DELETED
|
Binary file (19.5 kB)
|
|
|
graph_grammar/algo/tree_decomposition.py
DELETED
|
@@ -1,821 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Dec 11 2017"
|
| 20 |
-
|
| 21 |
-
from copy import deepcopy
|
| 22 |
-
from itertools import combinations
|
| 23 |
-
from ..hypergraph import Hypergraph
|
| 24 |
-
import networkx as nx
|
| 25 |
-
import numpy as np
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class CliqueTree(nx.Graph):
    ''' clique tree object

    Every tree node is keyed by an integer and carries a `subhg` attribute
    holding a hypergraph. Node 0 is used as the root throughout this class.

    Attributes
    ----------
    hg : Hypergraph
        Deep copy of the hypergraph to be decomposed (None when not given).
    root_hg : Hypergraph
        Hypergraph on the root node.
    ident_node_dict : dict
        ident_node_dict[key_node] gives a list of nodes that are identical (i.e., the adjacent hyperedges are common)
    '''
    def __init__(self, hg=None, **kwargs):
        ''' build an empty tree around a private copy of `hg`

        Parameters
        ----------
        hg : Hypergraph, optional
            hypergraph to be decomposed. `None` must stay a valid value
            because networkx re-instantiates the class without arguments
            when it copies a graph or builds a subgraph.
        **kwargs
            forwarded to `nx.Graph.__init__`
        '''
        self.hg = deepcopy(hg)
        if self.hg is not None:
            self.ident_node_dict = self.hg.get_identical_node_dict()
        else:
            self.ident_node_dict = {}
        super().__init__(**kwargs)

    @property
    def root_hg(self):
        ''' return the hypergraph on the root node (tree node 0)
        '''
        return self.nodes[0]['subhg']

    @root_hg.setter
    def root_hg(self, hypergraph):
        ''' set the hypergraph on the root node (tree node 0)
        '''
        self.nodes[0]['subhg'] = hypergraph

    def insert_subhg(self, subhypergraph: Hypergraph) -> None:
        ''' insert a subhypergraph, which is extracted from a root hypergraph, into the tree.

        The new tree node is attached to the root. Every other neighbour of
        the root that shares, with the new subhypergraph, at least one node
        that is no longer in the root hypergraph is re-attached to the new
        tree node instead.

        Parameters
        ----------
        subhypergraph : Hypergraph
        '''
        new_node = self.number_of_nodes()
        self.add_node(new_node, subhg=subhypergraph)
        self.add_edge(new_node, 0)
        # Snapshot the root's neighbours: the loop rewires edges on the root.
        for each_node in list(self.adj[0]):
            if each_node == new_node:
                continue
            shared_nodes = self.nodes[each_node]["subhg"].nodes.intersection(
                self.nodes[new_node]["subhg"].nodes) - self.root_hg.nodes
            if len(shared_nodes) != 0:
                self.remove_edge(0, each_node)
                self.add_edge(each_node, new_node)

    def to_irredundant(self) -> None:
        ''' convert the clique tree to be irredundant

        Three passes, all in place:
        1. for each node of `self.hg`, strip it from chains of leaf tree
           nodes (within the subtree of tree nodes containing it) whose
           subhypergraph has no hyperedge adjacent to it;
        2. drop tree nodes whose subhypergraph has no edges when they have
           one neighbour, or bypass them when they have exactly two;
        3. repeatedly merge adjacent tree nodes when one subhypergraph is a
           subhypergraph of the other.
        '''
        # Pass 1: remove hypergraph nodes that only "pass through" leaf tree nodes.
        for each_node in self.hg.nodes:
            subtree = self.subgraph([
                each_tree_node for each_tree_node in self.nodes()
                if each_node in self.nodes[each_tree_node]["subhg"].nodes]).copy()
            redundant_leaf_node_list = [
                each_leaf_node for each_leaf_node in subtree.nodes()
                if subtree.degree(each_leaf_node) == 1
                and len(self.nodes[each_leaf_node]["subhg"].adj_edges(each_node)) == 0]
            for each_red_leaf_node in redundant_leaf_node_list:
                current_node = each_red_leaf_node
                # Walk inwards while the current tree node is still a redundant leaf.
                while subtree.degree(current_node) == 1 \
                      and len(subtree.nodes[current_node]["subhg"].adj_edges(each_node)) == 0:
                    self.nodes[current_node]["subhg"].remove_node(each_node)
                    remove_node = current_node
                    current_node = list(subtree[remove_node])[0]
                    subtree.remove_node(remove_node)

        # Pass 2: a plain list of keys is enough of a snapshot here; deep-copying
        # the node view would duplicate the whole graph and all its hypergraphs.
        for each_node in list(self.nodes):
            if self.nodes[each_node]["subhg"].num_edges == 0:
                neighbor_list = list(self[each_node])
                if len(neighbor_list) == 1:
                    self.remove_node(each_node)
                elif len(neighbor_list) == 2:
                    self.add_edge(*neighbor_list)
                    self.remove_node(each_node)

        # Pass 3: merge nested neighbours until a full sweep changes nothing.
        redundant = True
        while redundant:
            redundant = False
            remove_node_set = set()
            # Snapshot of the edges; entries touching already-removed nodes are skipped.
            for node_1, node_2 in list(self.edges):
                if node_1 in remove_node_set or node_2 in remove_node_set:
                    continue
                if self.nodes[node_1]['subhg'].is_subhg(self.nodes[node_2]['subhg']):
                    redundant = True
                    adj_node_list = set(self.adj[node_1]) - {node_2}
                    self.remove_node(node_1)
                    remove_node_set.add(node_1)
                    for each_node in adj_node_list:
                        self.add_edge(node_2, each_node)
                elif self.nodes[node_2]['subhg'].is_subhg(self.nodes[node_1]['subhg']):
                    redundant = True
                    adj_node_list = set(self.adj[node_2]) - {node_1}
                    self.remove_node(node_2)
                    remove_node_set.add(node_2)
                    for each_node in adj_node_list:
                        self.add_edge(node_1, each_node)

    def node_update(self, key_node: str, subhg) -> None:
        ''' given the root hypergraph, H, and its subhypergraph, sH, shrink H to "H minus sH"
        in place and insert sH into the tree. Nothing is returned.

        Pairs of remaining `subhg` nodes that are not adjacent in the root
        are connected by temporary edges (attribute `tmp=True`).

        Parameters
        ----------
        key_node : str
            key node that must be removed.
        subhg : Hypergraph
        '''
        for each_edge in subhg.edges:
            self.root_hg.remove_edge(each_edge)
        self.root_hg.remove_nodes(self.ident_node_dict[key_node])

        # Nodes of `subhg` that survive in the root hypergraph.
        adj_node_list = list(subhg.nodes)
        for each_node in subhg.nodes:
            if each_node not in self.ident_node_dict[key_node]:
                if set(self.root_hg.adj_edges(each_node)).issubset(subhg.edges):
                    self.root_hg.remove_node(each_node)
                    adj_node_list.remove(each_node)
            else:
                adj_node_list.remove(each_node)

        for each_node_1, each_node_2 in combinations(adj_node_list, 2):
            if not self.root_hg.is_adj(each_node_1, each_node_2):
                self.root_hg.add_edge({each_node_1, each_node_2}, attr_dict=dict(tmp=True))

        subhg.remove_edges_with_attr({'tmp' : True})
        self.insert_subhg(subhg)

    def update(self, subhg, remove_nodes=False):
        ''' given the root hypergraph, H, and its subhypergraph, sH, shrink H to "H minus sH"
        in place and insert sH into the tree. Nothing is returned.

        Parameters
        ----------
        subhg : Hypergraph
        remove_nodes : bool
            if True, temporary edges of the root lying entirely inside
            `subhg.nodes` are removed and no new temporary edge is added.
        '''
        for each_edge in subhg.edges:
            self.root_hg.remove_edge(each_edge)
        if remove_nodes:
            remove_edge_list = [
                each_edge for each_edge in self.root_hg.edges
                if set(self.root_hg.nodes_in_edge(each_edge)).issubset(subhg.nodes)
                and self.root_hg.edge_attr(each_edge).get('tmp', False)]
            self.root_hg.remove_edges(remove_edge_list)

        # Nodes of `subhg` that still have a hyperedge in the root hypergraph.
        adj_node_list = list(subhg.nodes)
        for each_node in subhg.nodes:
            if self.root_hg.degree(each_node) == 0:
                self.root_hg.remove_node(each_node)
                adj_node_list.remove(each_node)

        if len(adj_node_list) != 1 and not remove_nodes:
            # One temporary hyperedge over all surviving boundary nodes.
            self.root_hg.add_edge(set(adj_node_list), attr_dict=dict(tmp=True))
        subhg.remove_edges_with_attr({'tmp':True})
        self.insert_subhg(subhg)
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
def _get_min_deg_node(hg, ident_node_dict: dict, mode='mol'):
|
| 203 |
-
if mode == 'standard':
|
| 204 |
-
degree_dict = hg.degrees()
|
| 205 |
-
min_deg_node = min(degree_dict, key=degree_dict.get)
|
| 206 |
-
min_deg_subhg = hg.adj_subhg(min_deg_node, ident_node_dict)
|
| 207 |
-
return min_deg_node, min_deg_subhg
|
| 208 |
-
elif mode == 'mol':
|
| 209 |
-
degree_dict = hg.degrees()
|
| 210 |
-
min_deg = min(degree_dict.values())
|
| 211 |
-
min_deg_node_list = [each_node for each_node in hg.nodes if degree_dict[each_node]==min_deg]
|
| 212 |
-
min_deg_subhg_list = [hg.adj_subhg(each_min_deg_node, ident_node_dict)
|
| 213 |
-
for each_min_deg_node in min_deg_node_list]
|
| 214 |
-
best_score = np.inf
|
| 215 |
-
best_idx = -1
|
| 216 |
-
for each_idx in range(len(min_deg_subhg_list)):
|
| 217 |
-
if min_deg_subhg_list[each_idx].num_nodes < best_score:
|
| 218 |
-
best_idx = each_idx
|
| 219 |
-
return min_deg_node_list[each_idx], min_deg_subhg_list[each_idx]
|
| 220 |
-
else:
|
| 221 |
-
raise ValueError
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
def tree_decomposition(hg, irredundant=True):
    """ compute a tree decomposition of the input hypergraph

    A copy of `hg` is placed on the root of a new clique tree. The
    minimum-degree node and its adjacent subhypergraph are then split off
    repeatedly, until that subhypergraph spans every remaining node.

    Parameters
    ----------
    hg : Hypergraph
        hypergraph to be decomposed
    irredundant : bool
        if True, irredundant tree decomposition will be computed.

    Returns
    -------
    clique_tree : nx.Graph
        each node contains a subhypergraph of `hg`
    """
    work_hg = hg.copy()
    identical_nodes = hg.get_identical_node_dict()
    tree = CliqueTree(work_hg)
    tree.add_node(0, subhg=work_hg)

    finished = False
    while not finished:
        degrees = work_hg.degrees()
        pivot = min(degrees, key=degrees.get)
        pivot_subhg = work_hg.adj_subhg(pivot, identical_nodes)
        if work_hg.nodes == pivot_subhg.nodes:
            finished = True
        else:
            # split pivot_subhg off the working hypergraph
            tree.node_update(pivot, pivot_subhg)

    tree.root_hg.remove_edges_with_attr({'tmp' : True})

    if irredundant:
        tree.to_irredundant()
    return tree
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
def tree_decomposition_with_hrg(hg, hrg, irredundant=True, return_root=False):
    ''' compute a tree decomposition given a hyperedge replacement grammar.
    the resultant clique tree should induce a less compact HRG.

    First the production rules of `hrg` are reverted on the hypergraph for
    as long as any of them applies; what is left is decomposed with the
    minimum-degree heuristic.

    Parameters
    ----------
    hg : Hypergraph
        hypergraph to be decomposed
    hrg : HyperedgeReplacementGrammar
        current HRG
    irredundant : bool
        if True, irredundant tree decomposition will be computed.
    return_root : bool
        if True, the index of the tree node regarded as root is returned too.

    Returns
    -------
    clique_tree : nx.Graph
        each node contains a subhypergraph of `hg`
    root_node : int
        only when `return_root` is True
    '''
    org_hg = hg.copy()
    ident_node_dict = hg.get_identical_node_dict()
    clique_tree = CliqueTree(org_hg)
    clique_tree.add_node(0, subhg=org_hg)
    root_node = 0

    # phase 1: peel subhypergraphs off by reverting production rules
    progressed = True
    while progressed:
        progressed = False
        for rule in hrg.prod_rule_list:
            org_hg, reverted, subhg = rule.revert(org_hg, True)
            if not reverted:
                continue
            if rule.is_start_rule:
                root_node = clique_tree.number_of_nodes()
            progressed = True
            subhg.remove_edges_with_attr({'terminal' : False})
            clique_tree.root_hg = org_hg
            clique_tree.insert_subhg(subhg)

    clique_tree.root_hg = org_hg

    # replace every non-terminal hyperedge by pairwise temporary edges
    for each_edge in deepcopy(org_hg.edges):
        if org_hg.edge_attr(each_edge)['terminal']:
            continue
        member_nodes = org_hg.nodes_in_edge(each_edge)
        org_hg.remove_edge(each_edge)
        for node_a, node_b in combinations(member_nodes, 2):
            if not org_hg.is_adj(node_a, node_b):
                org_hg.add_edge([node_a, node_b], attr_dict=dict(tmp=True))

    # phase 2: decompose the remainder with the existing algorithm
    if org_hg.degrees():
        while True:
            pivot, pivot_subhg = _get_min_deg_node(org_hg, ident_node_dict)
            if org_hg.nodes == pivot_subhg.nodes:
                break
            # split pivot_subhg off org_hg
            clique_tree.node_update(pivot, pivot_subhg)

    clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
    if irredundant:
        clique_tree.to_irredundant()

    if not return_root:
        return clique_tree

    # The recorded root may have been removed above; fall back to the
    # nearest lower index that still exists (starting from the node count
    # when the recorded root was the initial node 0).
    if root_node not in clique_tree.nodes:
        if root_node == 0:
            root_node = clique_tree.number_of_nodes()
        while root_node not in clique_tree.nodes:
            root_node -= 1
    return clique_tree, root_node
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
def tree_decomposition_from_leaf(hg, irredundant=True):
    """ compute a tree decomposition of the input hypergraph

    Each round first tries to split off the non-temporary hyperedge of
    minimum edge degree; when that fails it falls back to the
    minimum-degree-node step. The loop ends when both fail.

    Parameters
    ----------
    hg : Hypergraph
        hypergraph to be decomposed
    irredundant : bool
        if True, irredundant tree decomposition will be computed.

    Returns
    -------
    clique_tree : nx.Graph
        each node contains a subhypergraph of `hg`
    """
    def apply_normal_decomposition(clique_tree):
        # split off the minimum-degree node together with its adjacent subhypergraph
        root = clique_tree.root_hg
        degrees = root.degrees()
        pivot = min(degrees, key=degrees.get)
        pivot_subhg = root.adj_subhg(pivot, clique_tree.ident_node_dict)
        if root.nodes == pivot_subhg.nodes:
            return clique_tree, False
        clique_tree.node_update(pivot, pivot_subhg)
        return clique_tree, True

    def apply_min_edge_deg_decomposition(clique_tree):
        # split off the non-temporary hyperedge with the smallest edge degree
        root = clique_tree.root_hg
        edge_degrees = root.edge_degrees()
        candidates = [edge for edge in root.edges
                      if not root.edge_attr(edge).get('tmp')]
        if not candidates:
            return clique_tree, False

        best_edge, best_deg = None, np.inf
        for edge in candidates:
            if best_deg > edge_degrees[edge]:
                best_edge, best_deg = edge, edge_degrees[edge]

        members = root.nodes_in_edge(best_edge)
        best_subhg = root.get_subhg(
            members, [best_edge], clique_tree.ident_node_dict)
        if root.nodes == best_subhg.nodes:
            return clique_tree, False
        clique_tree.update(best_subhg)
        return clique_tree, True

    work_hg = hg.copy()
    clique_tree = CliqueTree(work_hg)
    clique_tree.add_node(0, subhg=work_hg)

    while True:
        clique_tree, progressed = apply_min_edge_deg_decomposition(clique_tree)
        if not progressed:
            clique_tree, progressed = apply_normal_decomposition(clique_tree)
        if not progressed:
            break

    clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
    if irredundant:
        clique_tree.to_irredundant()
    return clique_tree
|
| 395 |
-
|
| 396 |
-
def topological_tree_decomposition(
        hg, irredundant=True, rip_labels=True, shrink_cycle=False, contract_cycles=False):
    ''' compute a tree decomposition of the input hypergraph

    Each round tries, in this order and stopping at the first success:
    `_contract_tree`, then `_rip_labels_from_cycles` (if `rip_labels`),
    then `_shrink_cycle` (if `shrink_cycle`), then `_contract_cycles`
    (if `contract_cycles`). The loop ends when a round makes no progress.

    Parameters
    ----------
    hg : Hypergraph
        hypergraph to be decomposed
    irredundant : bool
        if True, irredundant tree decomposition will be computed.
    rip_labels : bool
        if True, `_rip_labels_from_cycles` is tried when `_contract_tree` fails.
    shrink_cycle : bool
        if True, `_shrink_cycle` is tried when the previous steps fail.
    contract_cycles : bool
        if True, `_contract_cycles` is tried when the previous steps fail.

    Returns
    -------
    clique_tree : CliqueTree
        each node contains a subhypergraph of `hg`
    '''
    def _contract_tree(clique_tree):
        ''' contract a single leaf

        A "leaf" here is a non-temporary hyperedge of the root hypergraph
        whose edge degree is 1; the first such edge found is split off.

        Parameters
        ----------
        clique_tree : CliqueTree

        Returns
        -------
        CliqueTree, bool
            bool represents whether this operation succeeds or not.
        '''
        edge_degree_dict = clique_tree.root_hg.edge_degrees()
        leaf_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
                          if (not clique_tree.root_hg.edge_attr(each_edge).get('tmp'))\
                          and edge_degree_dict[each_edge] == 1]
        if not leaf_edge_list:
            return clique_tree, False
        min_deg_edge = leaf_edge_list[0]
        node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
        min_deg_subhg = clique_tree.root_hg.get_subhg(
            node_list, [min_deg_edge], clique_tree.ident_node_dict)
        # Nothing to split off when the subhypergraph already spans the whole root.
        if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
            return clique_tree, False
        clique_tree.update(min_deg_subhg)
        return clique_tree, True

    def _rip_labels_from_cycles(clique_tree, org_hg):
        ''' rip hyperedge-labels off

        Splits off the first hyperedge of the root hypergraph that is also
        an edge of `org_hg` and for which `org_hg.in_cycle` is true.

        Parameters
        ----------
        clique_tree : CliqueTree
        org_hg : Hypergraph

        Returns
        -------
        CliqueTree, bool
            bool represents whether this operation succeeds or not.
        '''
        ident_node_dict = clique_tree.ident_node_dict #hg.get_identical_node_dict()
        for each_edge in clique_tree.root_hg.edges:
            if each_edge in org_hg.edges:
                if org_hg.in_cycle(each_edge):
                    node_list = clique_tree.root_hg.nodes_in_edge(each_edge)
                    subhg = clique_tree.root_hg.get_subhg(
                        node_list, [each_edge], ident_node_dict)
                    if clique_tree.root_hg.nodes == subhg.nodes:
                        return clique_tree, False
                    # `update` mutates the root hypergraph being iterated, so
                    # the function returns right after it instead of continuing.
                    clique_tree.update(subhg)
                    '''
                    in_cycle_dict = {each_node: org_hg.node_attr(each_node)['is_in_ring'] for each_node in node_list}
                    if not all(in_cycle_dict.values()):
                        node_not_in_cycle = [each_node for each_node in in_cycle_dict.keys() if not in_cycle_dict[each_node]][0]
                        node_list = [node_not_in_cycle]
                        node_list.extend(clique_tree.root_hg.adj_nodes(node_not_in_cycle))
                        edge_list = clique_tree.root_hg.adj_edges(node_not_in_cycle)
                        import pdb; pdb.set_trace()
                        subhg = clique_tree.root_hg.get_subhg(
                            node_list, edge_list, ident_node_dict)

                        clique_tree.update(subhg)
                    '''
                    return clique_tree, True
        return clique_tree, False

    def _shrink_cycle(clique_tree):
        ''' shrink a cycle

        Splits off the adjacent subhypergraph of the first root node that is
        in a cycle and passes `filter_subhg`.

        Parameters
        ----------
        clique_tree : CliqueTree

        Returns
        -------
        CliqueTree, bool
            bool represents whether this operation succeeds or not.
        '''
        def filter_subhg(subhg, hg, key_node):
            # True iff exactly 3 nodes of `subhg` are in a cycle of `hg` and no
            # single hyperedge of `hg` contains all of them other than `key_node`.
            num_nodes_cycle = 0
            nodes_in_cycle_list = []
            for each_node in subhg.nodes:
                if hg.in_cycle(each_node):
                    num_nodes_cycle += 1
                    if each_node != key_node:
                        nodes_in_cycle_list.append(each_node)
                    # more than 3 already fails the test below; stop counting
                    if num_nodes_cycle > 3:
                        break
            if num_nodes_cycle != 3:
                return False
            else:
                for each_edge in hg.edges:
                    if set(nodes_in_cycle_list).issubset(hg.nodes_in_edge(each_edge)):
                        return False
                return True

        #ident_node_dict = hg.get_identical_node_dict()
        ident_node_dict = clique_tree.ident_node_dict
        for each_node in clique_tree.root_hg.nodes:
            if clique_tree.root_hg.in_cycle(each_node)\
               and filter_subhg(clique_tree.root_hg.adj_subhg(each_node, ident_node_dict),
                                clique_tree.root_hg,
                                each_node):
                target_node = each_node
                target_subhg = clique_tree.root_hg.adj_subhg(target_node, ident_node_dict)
                if clique_tree.root_hg.nodes == target_subhg.nodes:
                    return clique_tree, False
                clique_tree.update(target_subhg)
                return clique_tree, True
        return clique_tree, False

    def _contract_cycles(clique_tree):
        '''
        remove a subhypergraph that looks like a cycle on a leaf.

        Parameters
        ----------
        clique_tree : CliqueTree

        Returns
        -------
        CliqueTree, bool
            bool represents whether this operation succeeds or not.
        '''
        def _divide_hg(hg):
            ''' divide a hypergraph into subhypergraphs such that
            each subhypergraph is connected to each other in a tree-like way.

            Parameters
            ----------
            hg : Hypergraph

            Returns
            -------
            list of Hypergraphs
                each element corresponds to a subhypergraph of `hg`
            '''
            for each_node in hg.nodes:
                if hg.is_dividable(each_node):
                    # NOTE(review): `adj_edges_dict` is only referenced by the
                    # disabled block below; it is currently unused.
                    adj_edges_dict = {each_edge: hg.in_cycle(each_edge) for each_edge in hg.adj_edges(each_node)}
                    '''
                    if any(adj_edges_dict.values()):
                        import pdb; pdb.set_trace()
                        edge_in_cycle = [each_key for each_key, each_val in adj_edges_dict.items() if each_val][0]
                        subhg1, subhg2, subhg3 = hg.divide(each_node, edge_in_cycle)
                        return _divide_hg(subhg1) + _divide_hg(subhg2) + _divide_hg(subhg3)
                    else:
                    '''
                    # Divide at the first dividable node and recurse on both parts.
                    subhg1, subhg2 = hg.divide(each_node)
                    return _divide_hg(subhg1) + _divide_hg(subhg2)
            return [hg]

        def _is_leaf(hg, divided_subhg) -> bool:
            ''' judge whether subhg is a leaf-like in the original hypergraph

            Implemented as: remove `divided_subhg` from a deep copy of `hg`
            and report whether the remainder is still connected.

            Parameters
            ----------
            hg : Hypergraph
            divided_subhg : Hypergraph
                `divided_subhg` is a subhypergraph of `hg`

            Returns
            -------
            bool
            '''
            '''
            adj_edges_set = set([])
            for each_node in divided_subhg.nodes:
                adj_edges_set.update(set(hg.adj_edges(each_node)))


            _hg = deepcopy(hg)
            _hg.remove_subhg(divided_subhg)
            if nx.is_connected(_hg.hg) != (len(adj_edges_set - divided_subhg.edges) == 1):
                import pdb; pdb.set_trace()
            return len(adj_edges_set - divided_subhg.edges) == 1
            '''
            _hg = deepcopy(hg)
            _hg.remove_subhg(divided_subhg)
            return nx.is_connected(_hg.hg)

        subhg_list = _divide_hg(clique_tree.root_hg)
        if len(subhg_list) == 1:
            return clique_tree, False
        else:
            while len(subhg_list) > 1:
                # pick the leaf-like piece with the most nodes (first wins on ties)
                max_leaf_subhg = None
                for each_subhg in subhg_list:
                    if _is_leaf(clique_tree.root_hg, each_subhg):
                        if max_leaf_subhg is None:
                            max_leaf_subhg = each_subhg
                        elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
                            max_leaf_subhg = each_subhg
                # NOTE(review): if no piece is leaf-like, `max_leaf_subhg` is
                # still None here and is passed to `update` -- confirm whether
                # that case can occur.
                clique_tree.update(max_leaf_subhg)
                subhg_list.remove(max_leaf_subhg)
            return clique_tree, True

    # Work on a copy; the copy sits on the root (tree node 0) and is shrunk in place.
    org_hg = hg.copy()
    clique_tree = CliqueTree(org_hg)
    clique_tree.add_node(0, subhg=org_hg)

    success = True
    while success:
        '''
        clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
        if not success:
            clique_tree, success = _contract_cycles(clique_tree)
        '''
        # Try each enabled strategy in order until one makes progress.
        clique_tree, success = _contract_tree(clique_tree)
        if not success:
            if rip_labels:
                clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
        if not success:
            if shrink_cycle:
                clique_tree, success = _shrink_cycle(clique_tree)
        if not success:
            if contract_cycles:
                clique_tree, success = _contract_cycles(clique_tree)
    # Temporary edges only guide the decomposition; drop them from the root.
    clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
    if irredundant:
        clique_tree.to_irredundant()
    return clique_tree
|
| 634 |
-
|
| 635 |
-
def molecular_tree_decomposition(hg, irredundant=True):
    """ compute a tree decomposition of the input molecular hypergraph

    The hypergraph is first split into pieces by `_divide_hg`. The pieces are
    then attached to the clique tree one at a time; at every step the piece
    with the most nodes among the currently attachable ones is chosen.

    Parameters
    ----------
    hg : Hypergraph
        molecular hypergraph to be decomposed
    irredundant : bool
        if True, irredundant tree decomposition will be computed.

    Returns
    -------
    clique_tree : CliqueTree
        each node contains a subhypergraph of `hg`

    Raises
    ------
    RuntimeError
        if more than one piece is left and none of them can be attached.
    """
    def _divide_hg(hg):
        ''' divide a hypergraph into subhypergraphs such that
        each subhypergraph is connected to each other in a tree-like way.

        Note that `hg` is modified in place when it contains a ring node.

        Parameters
        ----------
        hg : Hypergraph

        Returns
        -------
        list of Hypergraphs
            each element corresponds to a subhypergraph of `hg`
        '''
        is_ring = False
        for each_node in hg.nodes:
            in_ring = hg.node_attr(each_node)['is_in_ring']
            if in_ring:
                is_ring = True
            elif hg.degree(each_node) == 2:
                # a non-ring node of degree 2 is a cut point: split there and recurse.
                subhg1, subhg2 = hg.divide(each_node)
                return _divide_hg(subhg1) + _divide_hg(subhg2)

        if not is_ring:
            return [hg]

        # peel every hyperedge off as its own piece; what is left of `hg`
        # (with the non-ring nodes removed) becomes the last piece.
        subhg_list = []
        remove_edge_list = []
        remove_node_list = []
        for each_edge in hg.edges:
            node_list = hg.nodes_in_edge(each_edge)
            subhg_list.append(
                hg.get_subhg(node_list, [each_edge], hg.get_identical_node_dict()))
            remove_edge_list.append(each_edge)
            for each_node in node_list:
                if not hg.node_attr(each_node)['is_in_ring']:
                    remove_node_list.append(each_node)
        hg.remove_edges(remove_edge_list)
        hg.remove_nodes(remove_node_list, False)
        return subhg_list + [hg]

    def _largest(candidates, is_eligible):
        ''' return the first candidate with the most nodes among those accepted
        by `is_eligible`, or None if no candidate is accepted.
        '''
        return max((each_subhg for each_subhg in candidates if is_eligible(each_subhg)),
                   key=lambda each_subhg: each_subhg.num_nodes,
                   default=None)

    org_hg = hg.copy()
    clique_tree = CliqueTree(org_hg)
    clique_tree.add_node(0, subhg=org_hg)

    subhg_list = _divide_hg(deepcopy(clique_tree.root_hg))
    if len(subhg_list) != 1:
        while len(subhg_list) > 1:
            # `clique_tree.root_hg` is looked up on every call because
            # `clique_tree.update` is invoked between selections.
            # 1st choice: a leaf piece that is not made of ring nodes only.
            chosen = _largest(
                subhg_list,
                lambda each_subhg: _is_leaf(clique_tree.root_hg, each_subhg)
                and not _is_ring(each_subhg))
            if chosen is None:
                # 2nd choice: a piece accepted by `_is_ring_label`.
                chosen = _largest(
                    subhg_list,
                    lambda each_subhg: _is_ring_label(clique_tree.root_hg, each_subhg))
            if chosen is not None:
                clique_tree.update(chosen)
            else:
                # last resort: any leaf piece, ring or not.
                chosen = _largest(
                    subhg_list,
                    lambda each_subhg: _is_leaf(clique_tree.root_hg, each_subhg))
                if chosen is None:
                    break
                # NOTE(review): the meaning of the second positional argument is
                # defined by CliqueTree.update, which is not visible here -- confirm.
                clique_tree.update(chosen, True)
            subhg_list.remove(chosen)
        if len(subhg_list) > 1:
            raise RuntimeError('bug in tree decomposition algorithm')
        clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})

    if irredundant:
        clique_tree.to_irredundant()
    return clique_tree
|
| 753 |
-
|
| 754 |
-
def _is_leaf(hg, subhg) -> bool:
|
| 755 |
-
''' judge whether subhg is a leaf-like in the original hypergraph
|
| 756 |
-
|
| 757 |
-
Parameters
|
| 758 |
-
----------
|
| 759 |
-
hg : Hypergraph
|
| 760 |
-
subhg : Hypergraph
|
| 761 |
-
`subhg` is a subhypergraph of `hg`
|
| 762 |
-
|
| 763 |
-
Returns
|
| 764 |
-
-------
|
| 765 |
-
bool
|
| 766 |
-
'''
|
| 767 |
-
if len(subhg.edges) == 0:
|
| 768 |
-
adj_edge_set = set([])
|
| 769 |
-
subhg_edge_set = set([])
|
| 770 |
-
for each_edge in hg.edges:
|
| 771 |
-
if set(hg.nodes_in_edge(each_edge)).issubset(subhg.nodes) and hg.edge_attr(each_edge).get('tmp', False):
|
| 772 |
-
subhg_edge_set.add(each_edge)
|
| 773 |
-
for each_node in subhg.nodes:
|
| 774 |
-
adj_edge_set.update(set(hg.adj_edges(each_node)))
|
| 775 |
-
if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
|
| 776 |
-
return True
|
| 777 |
-
else:
|
| 778 |
-
return False
|
| 779 |
-
elif len(subhg.edges) == 1:
|
| 780 |
-
adj_edge_set = set([])
|
| 781 |
-
subhg_edge_set = subhg.edges
|
| 782 |
-
for each_node in subhg.nodes:
|
| 783 |
-
for each_adj_edge in hg.adj_edges(each_node):
|
| 784 |
-
adj_edge_set.add(each_adj_edge)
|
| 785 |
-
if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
|
| 786 |
-
return True
|
| 787 |
-
else:
|
| 788 |
-
return False
|
| 789 |
-
else:
|
| 790 |
-
raise ValueError('subhg should be nodes only or one-edge hypergraph.')
|
| 791 |
-
|
| 792 |
-
def _is_ring_label(hg, subhg):
|
| 793 |
-
if len(subhg.edges) != 1:
|
| 794 |
-
return False
|
| 795 |
-
edge_name = list(subhg.edges)[0]
|
| 796 |
-
#assert edge_name in hg.edges, f'{edge_name}'
|
| 797 |
-
is_in_ring = False
|
| 798 |
-
for each_node in subhg.nodes:
|
| 799 |
-
if subhg.node_attr(each_node)['is_in_ring']:
|
| 800 |
-
is_in_ring = True
|
| 801 |
-
else:
|
| 802 |
-
adj_edge_list = list(hg.adj_edges(each_node))
|
| 803 |
-
adj_edge_list.remove(edge_name)
|
| 804 |
-
if len(adj_edge_list) == 1:
|
| 805 |
-
if not hg.edge_attr(adj_edge_list[0]).get('tmp', False):
|
| 806 |
-
return False
|
| 807 |
-
elif len(adj_edge_list) == 0:
|
| 808 |
-
pass
|
| 809 |
-
else:
|
| 810 |
-
raise ValueError
|
| 811 |
-
if is_in_ring:
|
| 812 |
-
return True
|
| 813 |
-
else:
|
| 814 |
-
return False
|
| 815 |
-
|
| 816 |
-
def _is_ring(hg):
|
| 817 |
-
for each_node in hg.nodes:
|
| 818 |
-
if not hg.node_attr(each_node)['is_in_ring']:
|
| 819 |
-
return False
|
| 820 |
-
return True
|
| 821 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/graph_grammar/__init__.py
DELETED
|
@@ -1,20 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE, SO A CAREFUL CHECK NEEDS TO BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Jan 1 2018"
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (680 Bytes)
|
|
|
graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc
DELETED
|
Binary file (1.17 kB)
|
|
|
graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc
DELETED
|
Binary file (4.71 kB)
|
|
|
graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc
DELETED
|
Binary file (29.1 kB)
|
|
|
graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc
DELETED
|
Binary file (5.38 kB)
|
|
|
graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc
DELETED
|
Binary file (3.63 kB)
|
|
|
graph_grammar/graph_grammar/base.py
DELETED
|
@@ -1,30 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE, SO A CAREFUL CHECK NEEDS TO BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Dec 11 2017"
|
| 20 |
-
|
| 21 |
-
from abc import ABCMeta, abstractmethod
|
| 22 |
-
|
| 23 |
-
class GraphGrammarBase(metaclass=ABCMeta):
    """ abstract base class of graph grammars

    Concrete subclasses must override both `learn` and `sample`;
    the class itself cannot be instantiated.
    """

    @abstractmethod
    def learn(self):
        """ to be implemented by subclasses. """

    @abstractmethod
    def sample(self):
        """ to be implemented by subclasses. """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/graph_grammar/corpus.py
DELETED
|
@@ -1,152 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE, SO A CAREFUL CHECK NEEDS TO BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Jun 4 2018"
|
| 20 |
-
|
| 21 |
-
from collections import Counter
|
| 22 |
-
from functools import partial
|
| 23 |
-
from .utils import _easy_node_match, _edge_match, _node_match, common_node_list, _node_match_prod_rule
|
| 24 |
-
from networkx.algorithms.isomorphism import GraphMatcher
|
| 25 |
-
import os
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class CliqueTreeCorpus(object):

    ''' clique tree corpus

    Keeps the registered clique trees together with a list of the distinct
    (up to isomorphism) subhypergraphs found in their nodes.

    Attributes
    ----------
    clique_tree_list : list of CliqueTree
    subhg_list : list of Hypergraph
    '''

    def __init__(self):
        self.clique_tree_list = []
        self.subhg_list = []

    @property
    def size(self):
        ''' the number of distinct subhypergraphs registered so far. '''
        return len(self.subhg_list)

    def add_clique_tree(self, clique_tree):
        ''' register every subhypergraph of `clique_tree` and store the tree.

        Each tree node gets a `subhg_idx` attribute holding the index of its
        subhypergraph in `subhg_list`.

        Parameters
        ----------
        clique_tree : CliqueTree
        '''
        for each_node in clique_tree.nodes:
            subhg = clique_tree.nodes[each_node]['subhg']
            # `add_subhg` returns (index, is_new); only the index is stored.
            subhg_idx, _ = self.add_subhg(subhg)
            clique_tree.nodes[each_node]['subhg_idx'] = subhg_idx
        self.clique_tree_list.append(clique_tree)

    @staticmethod
    def _dfs_with_parents(clique_tree, root_node):
        ''' return [(node, parent), ...] in depth-first order from `root_node`;
        the parent of the root is None.
        '''
        parent_node_dict = {root_node: None}
        visit_order = []
        stack = [root_node]
        while stack:
            current_node = stack.pop()
            for each_child in clique_tree.adj[current_node]:
                if each_child != parent_node_dict[current_node]:
                    stack.append(each_child)
                    parent_node_dict[each_child] = current_node
            visit_order.append((current_node, parent_node_dict[current_node]))
        return visit_order

    def add_to_subhg_list(self, clique_tree, root_node):
        ''' register the subhypergraphs of `clique_tree`, seen as a tree rooted
        at `root_node`.

        Parameters
        ----------
        clique_tree : CliqueTree
        root_node
            node of `clique_tree` used as the root of the traversal

        Returns
        -------
        clique_tree : CliqueTree
            the input tree; each node has `subhg_idx` set.
        '''
        visit_order = self._dfs_with_parents(clique_tree, root_node)

        # 1st pass: in each parent, add a temporary hyperedge over the nodes it
        # shares with the child. This is completed for the whole tree before
        # any subhypergraph is registered.
        for current_node, parent_node in visit_order:
            if parent_node is not None:
                current_subhg = clique_tree.nodes[current_node]['subhg']
                parent_subhg = clique_tree.nodes[parent_node]['subhg']
                common, _ = common_node_list(parent_subhg, current_subhg)
                parent_subhg.add_edge(set(common), attr_dict={'tmp': True})

        # 2nd pass: number the nodes shared with the parent (`ext_id`) and
        # register every subhypergraph, the root included.
        for current_node, parent_node in visit_order:
            current_subhg = clique_tree.nodes[current_node]['subhg']
            if parent_node is not None:
                parent_subhg = clique_tree.nodes[parent_node]['subhg']
                common, _ = common_node_list(parent_subhg, current_subhg)
                for each_idx, each_node in enumerate(common):
                    current_subhg.set_node_attr(each_node, {'ext_id': each_idx})
            subhg_idx, _ = self.add_subhg(current_subhg)
            clique_tree.nodes[current_node]['subhg_idx'] = subhg_idx
        return clique_tree

    @staticmethod
    def _assign_order4hrg(subhg):
        ''' number the nodes of `subhg` (attribute `order4hrg`) in increasing
        order of the hash of their symbols; ties keep the iteration order.
        '''
        # `__hash__()` is called directly, as before, so that the resulting
        # order is exactly the one the original implementation produced.
        node_list = sorted(
            subhg.nodes,
            key=lambda each_node: subhg.node_attr(each_node)['symbol'].__hash__())
        for each_idx, each_node in enumerate(node_list):
            subhg.node_attr(each_node)['order4hrg'] = each_idx

    def add_subhg(self, subhg):
        ''' register `subhg` unless an isomorphic one is already in the corpus.

        When an isomorphic subhypergraph exists, its `order4hrg` (and `ext_id`,
        if any) node attributes are copied onto the matching nodes of `subhg`.

        Parameters
        ----------
        subhg : Hypergraph

        Returns
        -------
        subhg_idx : int
            index of the (new or matching) subhypergraph in `subhg_list`
        is_new : bool
            True if `subhg` was appended to `subhg_list`
        '''
        subhg_bond_symbol_counter = Counter(
            subhg.node_attr(each_node)['symbol'] for each_node in subhg.nodes)
        subhg_atom_symbol_counter = Counter(
            subhg.edge_attr(each_edge).get('symbol', None) for each_edge in subhg.edges)

        for each_idx, each_subhg in enumerate(self.subhg_list):
            # cheap necessary conditions first; the isomorphism test is expensive.
            if subhg.num_nodes != each_subhg.num_nodes \
               or subhg.num_edges != each_subhg.num_edges:
                continue
            each_bond_symbol_counter = Counter(
                each_subhg.node_attr(each_node)['symbol']
                for each_node in each_subhg.nodes)
            if subhg_bond_symbol_counter != each_bond_symbol_counter:
                continue
            each_atom_symbol_counter = Counter(
                each_subhg.edge_attr(each_edge).get('symbol', None)
                for each_edge in each_subhg.edges)
            if subhg_atom_symbol_counter != each_atom_symbol_counter:
                continue

            gm = GraphMatcher(each_subhg.hg,
                              subhg.hg,
                              node_match=_easy_node_match,
                              edge_match=_edge_match)
            # the mapping is indexed by nodes of `each_subhg` below
            isomap = next(gm.isomorphisms_iter(), None)
            if isomap is None:
                continue
            for each_node in each_subhg.nodes:
                known_attr = each_subhg.node_attr(each_node)
                new_attr = subhg.node_attr(isomap[each_node])
                new_attr['order4hrg'] = known_attr['order4hrg']
                if 'ext_id' in known_attr:
                    new_attr['ext_id'] = known_attr['ext_id']
            return each_idx, False

        self._assign_order4hrg(subhg)
        self.subhg_list.append(subhg)
        return len(self.subhg_list) - 1, True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/graph_grammar/hrg.py
DELETED
|
@@ -1,1065 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE, SO A CAREFUL CHECK NEEDS TO BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Dec 11 2017"
|
| 20 |
-
|
| 21 |
-
from .corpus import CliqueTreeCorpus
|
| 22 |
-
from .base import GraphGrammarBase
|
| 23 |
-
from .symbols import TSymbol, NTSymbol, BondSymbol
|
| 24 |
-
from .utils import _node_match, _node_match_prod_rule, _edge_match, masked_softmax, common_node_list
|
| 25 |
-
from ..hypergraph import Hypergraph
|
| 26 |
-
from collections import Counter
|
| 27 |
-
from copy import deepcopy
|
| 28 |
-
from ..algo.tree_decomposition import (
|
| 29 |
-
tree_decomposition,
|
| 30 |
-
tree_decomposition_with_hrg,
|
| 31 |
-
tree_decomposition_from_leaf,
|
| 32 |
-
topological_tree_decomposition,
|
| 33 |
-
molecular_tree_decomposition)
|
| 34 |
-
from functools import partial
|
| 35 |
-
from networkx.algorithms.isomorphism import GraphMatcher
|
| 36 |
-
from typing import List, Dict, Tuple
|
| 37 |
-
import networkx as nx
|
| 38 |
-
import numpy as np
|
| 39 |
-
import torch
|
| 40 |
-
import os
|
| 41 |
-
import random
|
| 42 |
-
|
| 43 |
-
DEBUG = False
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class ProductionRule(object):
|
| 47 |
-
""" A class of a production rule
|
| 48 |
-
|
| 49 |
-
Attributes
|
| 50 |
-
----------
|
| 51 |
-
lhs : Hypergraph or None
|
| 52 |
-
the left hand side of the production rule.
|
| 53 |
-
if None, the rule is a starting rule.
|
| 54 |
-
rhs : Hypergraph
|
| 55 |
-
the right hand side of the production rule.
|
| 56 |
-
"""
|
| 57 |
-
def __init__(self, lhs, rhs):
|
| 58 |
-
self.lhs = lhs
|
| 59 |
-
self.rhs = rhs
|
| 60 |
-
|
| 61 |
-
@property
|
| 62 |
-
def is_start_rule(self) -> bool:
|
| 63 |
-
return self.lhs.num_nodes == 0
|
| 64 |
-
|
| 65 |
-
@property
|
| 66 |
-
def ext_node(self) -> Dict[int, str]:
|
| 67 |
-
""" return a dict of external nodes
|
| 68 |
-
"""
|
| 69 |
-
if self.is_start_rule:
|
| 70 |
-
return {}
|
| 71 |
-
else:
|
| 72 |
-
ext_node_dict = {}
|
| 73 |
-
for each_node in self.lhs.nodes:
|
| 74 |
-
ext_node_dict[self.lhs.node_attr(each_node)["ext_id"]] = each_node
|
| 75 |
-
return ext_node_dict
|
| 76 |
-
|
| 77 |
-
@property
|
| 78 |
-
def lhs_nt_symbol(self) -> NTSymbol:
|
| 79 |
-
if self.is_start_rule:
|
| 80 |
-
return NTSymbol(degree=0, is_aromatic=False, bond_symbol_list=[])
|
| 81 |
-
else:
|
| 82 |
-
return self.lhs.edge_attr(list(self.lhs.edges)[0])['symbol']
|
| 83 |
-
|
| 84 |
-
def rhs_adj_mat(self, node_edge_list):
|
| 85 |
-
''' return the adjacency matrix of rhs of the production rule
|
| 86 |
-
'''
|
| 87 |
-
return nx.adjacency_matrix(self.rhs.hg, node_edge_list)
|
| 88 |
-
|
| 89 |
-
def draw(self, file_path=None):
|
| 90 |
-
return self.rhs.draw(file_path)
|
| 91 |
-
|
| 92 |
-
def is_same(self, prod_rule, ignore_order=False):
|
| 93 |
-
""" judge whether this production rule is
|
| 94 |
-
the same as the input one, `prod_rule`
|
| 95 |
-
|
| 96 |
-
Parameters
|
| 97 |
-
----------
|
| 98 |
-
prod_rule : ProductionRule
|
| 99 |
-
production rule to be compared
|
| 100 |
-
|
| 101 |
-
Returns
|
| 102 |
-
-------
|
| 103 |
-
is_same : bool
|
| 104 |
-
isomap : dict
|
| 105 |
-
isomorphism of nodes and hyperedges.
|
| 106 |
-
ex) {'bond_42': 'bond_37', 'bond_2': 'bond_1',
|
| 107 |
-
'e36': 'e11', 'e16': 'e12', 'e25': 'e18',
|
| 108 |
-
'bond_40': 'bond_38', 'e26': 'e21', 'bond_41': 'bond_39'}.
|
| 109 |
-
key comes from `prod_rule`, value comes from `self`.
|
| 110 |
-
"""
|
| 111 |
-
if self.is_start_rule:
|
| 112 |
-
if not prod_rule.is_start_rule:
|
| 113 |
-
return False, {}
|
| 114 |
-
else:
|
| 115 |
-
if prod_rule.is_start_rule:
|
| 116 |
-
return False, {}
|
| 117 |
-
else:
|
| 118 |
-
if prod_rule.lhs.num_nodes != self.lhs.num_nodes:
|
| 119 |
-
return False, {}
|
| 120 |
-
|
| 121 |
-
if prod_rule.rhs.num_nodes != self.rhs.num_nodes:
|
| 122 |
-
return False, {}
|
| 123 |
-
if prod_rule.rhs.num_edges != self.rhs.num_edges:
|
| 124 |
-
return False, {}
|
| 125 |
-
|
| 126 |
-
subhg_bond_symbol_counter \
|
| 127 |
-
= Counter([prod_rule.rhs.node_attr(each_node)['symbol'] \
|
| 128 |
-
for each_node in prod_rule.rhs.nodes])
|
| 129 |
-
each_bond_symbol_counter \
|
| 130 |
-
= Counter([self.rhs.node_attr(each_node)['symbol'] \
|
| 131 |
-
for each_node in self.rhs.nodes])
|
| 132 |
-
if subhg_bond_symbol_counter != each_bond_symbol_counter:
|
| 133 |
-
return False, {}
|
| 134 |
-
|
| 135 |
-
subhg_atom_symbol_counter \
|
| 136 |
-
= Counter([prod_rule.rhs.edge_attr(each_edge)['symbol'] \
|
| 137 |
-
for each_edge in prod_rule.rhs.edges])
|
| 138 |
-
each_atom_symbol_counter \
|
| 139 |
-
= Counter([self.rhs.edge_attr(each_edge)['symbol'] \
|
| 140 |
-
for each_edge in self.rhs.edges])
|
| 141 |
-
if subhg_atom_symbol_counter != each_atom_symbol_counter:
|
| 142 |
-
return False, {}
|
| 143 |
-
|
| 144 |
-
gm = GraphMatcher(prod_rule.rhs.hg,
|
| 145 |
-
self.rhs.hg,
|
| 146 |
-
partial(_node_match_prod_rule,
|
| 147 |
-
ignore_order=ignore_order),
|
| 148 |
-
partial(_edge_match,
|
| 149 |
-
ignore_order=ignore_order))
|
| 150 |
-
try:
|
| 151 |
-
return True, next(gm.isomorphisms_iter())
|
| 152 |
-
except StopIteration:
|
| 153 |
-
return False, {}
|
| 154 |
-
|
| 155 |
-
def applied_to(self,
|
| 156 |
-
hg: Hypergraph,
|
| 157 |
-
edge: str) -> Tuple[Hypergraph, List[str]]:
|
| 158 |
-
""" augment `hg` by replacing `edge` with `self.rhs`.
|
| 159 |
-
|
| 160 |
-
Parameters
|
| 161 |
-
----------
|
| 162 |
-
hg : Hypergraph
|
| 163 |
-
edge : str
|
| 164 |
-
`edge` must belong to `hg`
|
| 165 |
-
|
| 166 |
-
Returns
|
| 167 |
-
-------
|
| 168 |
-
hg : Hypergraph
|
| 169 |
-
resultant hypergraph
|
| 170 |
-
nt_edge_list : list
|
| 171 |
-
list of non-terminal edges
|
| 172 |
-
"""
|
| 173 |
-
nt_edge_dict = {}
|
| 174 |
-
if self.is_start_rule:
|
| 175 |
-
if (edge is not None) or (hg is not None):
|
| 176 |
-
ValueError("edge and hg must be None for this prod rule.")
|
| 177 |
-
hg = Hypergraph()
|
| 178 |
-
node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
|
| 179 |
-
for num_idx, each_node in enumerate(self.rhs.nodes):
|
| 180 |
-
hg.add_node(f"bond_{num_idx}",
|
| 181 |
-
#attr_dict=deepcopy(self.rhs.node_attr(each_node)))
|
| 182 |
-
attr_dict=self.rhs.node_attr(each_node))
|
| 183 |
-
node_map_rhs[each_node] = f"bond_{num_idx}"
|
| 184 |
-
for each_edge in self.rhs.edges:
|
| 185 |
-
node_list = []
|
| 186 |
-
for each_node in self.rhs.nodes_in_edge(each_edge):
|
| 187 |
-
node_list.append(node_map_rhs[each_node])
|
| 188 |
-
if isinstance(self.rhs.nodes_in_edge(each_edge), set):
|
| 189 |
-
node_list = set(node_list)
|
| 190 |
-
edge_id = hg.add_edge(
|
| 191 |
-
node_list,
|
| 192 |
-
#attr_dict=deepcopy(self.rhs.edge_attr(each_edge)))
|
| 193 |
-
attr_dict=self.rhs.edge_attr(each_edge))
|
| 194 |
-
if "nt_idx" in hg.edge_attr(edge_id):
|
| 195 |
-
nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
|
| 196 |
-
nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
|
| 197 |
-
return hg, nt_edge_list
|
| 198 |
-
else:
|
| 199 |
-
if edge not in hg.edges:
|
| 200 |
-
raise ValueError("the input hyperedge does not exist.")
|
| 201 |
-
if hg.edge_attr(edge)["terminal"]:
|
| 202 |
-
raise ValueError("the input hyperedge is terminal.")
|
| 203 |
-
if hg.edge_attr(edge)['symbol'] != self.lhs_nt_symbol:
|
| 204 |
-
print(hg.edge_attr(edge)['symbol'], self.lhs_nt_symbol)
|
| 205 |
-
raise ValueError("the input hyperedge and lhs have inconsistent number of nodes.")
|
| 206 |
-
if DEBUG:
|
| 207 |
-
for node_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
|
| 208 |
-
other_node = self.lhs.nodes_in_edge(list(self.lhs.edges)[0])[node_idx]
|
| 209 |
-
attr = deepcopy(self.lhs.node_attr(other_node))
|
| 210 |
-
attr.pop('ext_id')
|
| 211 |
-
if hg.node_attr(each_node) != attr:
|
| 212 |
-
raise ValueError('node attributes are inconsistent.')
|
| 213 |
-
|
| 214 |
-
# order of nodes that belong to the non-terminal edge in hg
|
| 215 |
-
nt_order_dict = {} # hg_node -> order ("bond_17" : 1)
|
| 216 |
-
nt_order_dict_inv = {} # order -> hg_node
|
| 217 |
-
for each_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
|
| 218 |
-
nt_order_dict[each_node] = each_idx
|
| 219 |
-
nt_order_dict_inv[each_idx] = each_node
|
| 220 |
-
|
| 221 |
-
# construct a node_map_rhs: rhs -> new hg
|
| 222 |
-
node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
|
| 223 |
-
node_idx = hg.num_nodes
|
| 224 |
-
for each_node in self.rhs.nodes:
|
| 225 |
-
if "ext_id" in self.rhs.node_attr(each_node):
|
| 226 |
-
node_map_rhs[each_node] \
|
| 227 |
-
= nt_order_dict_inv[
|
| 228 |
-
self.rhs.node_attr(each_node)["ext_id"]]
|
| 229 |
-
else:
|
| 230 |
-
node_map_rhs[each_node] = f"bond_{node_idx}"
|
| 231 |
-
node_idx += 1
|
| 232 |
-
|
| 233 |
-
# delete non-terminal
|
| 234 |
-
hg.remove_edge(edge)
|
| 235 |
-
|
| 236 |
-
# add nodes to hg
|
| 237 |
-
for each_node in self.rhs.nodes:
|
| 238 |
-
hg.add_node(node_map_rhs[each_node],
|
| 239 |
-
attr_dict=self.rhs.node_attr(each_node))
|
| 240 |
-
|
| 241 |
-
# add hyperedges to hg
|
| 242 |
-
for each_edge in self.rhs.edges:
|
| 243 |
-
node_list_hg = []
|
| 244 |
-
for each_node in self.rhs.nodes_in_edge(each_edge):
|
| 245 |
-
node_list_hg.append(node_map_rhs[each_node])
|
| 246 |
-
edge_id = hg.add_edge(
|
| 247 |
-
node_list_hg,
|
| 248 |
-
attr_dict=self.rhs.edge_attr(each_edge))#deepcopy(self.rhs.edge_attr(each_edge)))
|
| 249 |
-
if "nt_idx" in hg.edge_attr(edge_id):
|
| 250 |
-
nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
|
| 251 |
-
nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
|
| 252 |
-
return hg, nt_edge_list
|
| 253 |
-
|
| 254 |
-
def revert(self, hg: Hypergraph, return_subhg=False):
    ''' revert applying this production rule.
    i.e., if there exists a subhypergraph that matches the r.h.s. of this production rule,
    this method replaces the subhypergraph with a non-terminal hyperedge.

    Only the first subgraph isomorphism reported by the matcher is examined.
    If that match is connected to the rest of `hg` through anything other than
    its external nodes, reverting is reported as failed.

    Parameters
    ----------
    hg : Hypergraph
        hypergraph to be reverted (modified in place on success)
    return_subhg : bool
        if True, the removed subhypergraph will be returned.

    Returns
    -------
    hg : Hypergraph
        the resultant hypergraph. if it cannot be reverted, the original one is returned without any replacement.
    success : bool
        this indicates whether reverting succeeded or not.
    subhg : Hypergraph
        only returned when `return_subhg` is True. the removed subhypergraph,
        or an empty hypergraph when reverting failed.
    '''
    def _failure():
        # same return shape as the success path, with an empty sub-hypergraph
        if return_subhg:
            return hg, False, Hypergraph()
        return hg, False

    gm = GraphMatcher(hg.hg, self.rhs.hg, node_match=_node_match_prod_rule,
                      edge_match=_edge_match)
    # isomap = {'e35': 'e8', 'bond_13': 'bond_18', 'bond_14': 'bond_19',
    #           'bond_15': 'bond_17', 'e29': 'e23', 'bond_12': 'bond_20'}
    # where keys come from `hg` and values come from `self.rhs`
    isomap = next(gm.subgraph_isomorphisms_iter(), None)
    if isomap is None:
        return _failure()

    ext_node = self.ext_node
    ext_node_set = set(ext_node.values())

    # in case when the matched subhg is connected to the other part via
    # external nodes and more: everything adjacent to an internal element
    # must itself belong to the match.
    subhg_node_set = set(isomap.keys())
    adj_node_set = set(subhg_node_set)
    for each_node in subhg_node_set:
        if isomap[each_node] not in ext_node_set:
            adj_node_set.update(hg.hg.adj[each_node])
    if adj_node_set != subhg_node_set:
        return _failure()
    inv_isomap = {v: k for k, v in isomap.items()}

    if return_subhg:
        # copy the matched part before it is removed from `hg`
        subhg = Hypergraph()
        for each_node in hg.nodes:
            if each_node in isomap:
                subhg.add_node(each_node, attr_dict=hg.node_attr(each_node))
        for each_edge in hg.edges:
            if each_edge in isomap:
                subhg.add_edge(hg.nodes_in_edge(each_edge),
                               attr_dict=hg.edge_attr(each_edge),
                               edge_name=each_edge)
        subhg.edge_idx = hg.edge_idx

    # remove subhg except for the external nodes (edges first, then nodes)
    for each_key in isomap:
        if each_key.startswith('e'):
            hg.remove_edge(each_key)
    for each_key, each_val in isomap.items():
        if each_key.startswith('bond_') and each_val not in ext_node_set:
            hg.remove_node(each_key)

    # add non-terminal hyperedge over the external nodes,
    # following the iteration order of `self.ext_node`
    nt_node_list = [inv_isomap[each_rhs_node]
                    for each_rhs_node in ext_node.values()]
    hg.add_edge(nt_node_list,
                attr_dict=dict(
                    terminal=False,
                    symbol=self.lhs_nt_symbol))
    if return_subhg:
        return hg, True, subhg
    return hg, True
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
class ProductionRuleCorpus(object):

    '''
    A corpus of production rules.
    This class maintains
        (i) list of unique production rules,
        (ii) list of unique edge symbols (both terminal and non-terminal), and
        (iii) list of unique node symbols.

    Attributes
    ----------
    prod_rule_list : list
        list of unique production rules
    edge_symbol_list : list
        list of unique symbols (including both terminal and non-terminal)
    edge_symbol_dict : dict
        edge symbol -> its index in `edge_symbol_list`
    node_symbol_list : list
        list of node symbols
    node_symbol_dict : dict
        node symbol -> its index in `node_symbol_list`
    nt_symbol_list : list
        list of unique lhs symbols
    ext_id_list : list
        list of ext_ids
    lhs_in_prod_rule : array
        a dense 0/1 matrix of shape (#lhs symbols, #production rules);
        entry (i, j) is 1 iff rule j has the i-th symbol of `nt_symbol_list` as its lhs.
    '''

    def __init__(self):
        self.prod_rule_list = []
        self.edge_symbol_list = []
        self.edge_symbol_dict = {}
        self.node_symbol_list = []
        self.node_symbol_dict = {}
        self.nt_symbol_list = []
        self.ext_id_list = []
        # cache of the dense lhs-vs-rule matrix; reset to None by `append`
        self._lhs_in_prod_rule = None
        # (row, col) coordinates of the non-zero entries of that matrix
        self.lhs_in_prod_rule_row_list = []
        self.lhs_in_prod_rule_col_list = []

    @property
    def lhs_in_prod_rule(self):
        ''' return the (cached) dense matrix of lhs symbols vs production rules.

        Returns
        -------
        torch.Tensor, float32, shape (len(nt_symbol_list), len(prod_rule_list))
        '''
        if self._lhs_in_prod_rule is None:
            mat = torch.zeros(len(self.nt_symbol_list),
                              len(self.prod_rule_list),
                              dtype=torch.float32)
            # an empty corpus yields an empty matrix instead of an error
            if self.lhs_in_prod_rule_row_list:
                row_idx = torch.tensor(self.lhs_in_prod_rule_row_list, dtype=torch.long)
                col_idx = torch.tensor(self.lhs_in_prod_rule_col_list, dtype=torch.long)
                mat[row_idx, col_idx] = 1.0
            self._lhs_in_prod_rule = mat
        return self._lhs_in_prod_rule

    @property
    def num_prod_rule(self):
        ''' return the number of production rules

        Returns
        -------
        int : the number of unique production rules
        '''
        return len(self.prod_rule_list)

    @property
    def start_rule_list(self):
        ''' return a list of start rules

        Returns
        -------
        list : list of start rules
        '''
        return [each_prod_rule for each_prod_rule in self.prod_rule_list
                if each_prod_rule.is_start_rule]

    @property
    def num_edge_symbol(self):
        ''' return the number of unique edge symbols '''
        return len(self.edge_symbol_list)

    @property
    def num_node_symbol(self):
        ''' return the number of unique node symbols '''
        return len(self.node_symbol_list)

    @property
    def num_ext_id(self):
        ''' return the number of unique ext_ids '''
        return len(self.ext_id_list)

    def construct_feature_vectors(self):
        ''' this method constructs feature vectors for the production rules collected so far.
        currently, NTSymbol and TSymbol are treated in the same manner.

        Every symbol is encoded by the index of its class name plus one index
        per (attribute name, attribute value) pair found in its `__dict__`.

        Returns
        -------
        feature_dict : dict
            symbol (or `('ext_id', ext_id)`) -> sparse tensor of length `dim`
        dim : int
            the number of distinct features
        '''
        def _attr_keys(symbol):
            # list-valued attributes are turned into tuples so that they are hashable
            keys = []
            for each_attr, each_val in symbol.__dict__.items():
                if isinstance(each_val, list):
                    each_val = tuple(each_val)
                keys.append((each_attr, each_val))
            return keys

        feature_id_dict = {}
        feature_id_dict['TSymbol'] = 0
        feature_id_dict['NTSymbol'] = 1
        feature_id_dict['BondSymbol'] = 2
        for each_symbol in self.edge_symbol_list + self.node_symbol_list:
            for each_key in _attr_keys(each_symbol):
                if each_key not in feature_id_dict:
                    feature_id_dict[each_key] = len(feature_id_dict)
        for each_ext_id in self.ext_id_list:
            feature_id_dict[('ext_id', each_ext_id)] = len(feature_id_dict)
        dim = len(feature_id_dict)

        def _to_feature(idx_list):
            # NOTE(review): legacy sparse constructor kept as-is to preserve the
            # resulting tensor type; confirm before migrating to torch.sparse_coo_tensor.
            return torch.sparse.LongTensor(
                torch.LongTensor([idx_list]),
                torch.ones(len(idx_list)),
                torch.Size([dim])
            )

        feature_dict = {}
        for each_symbol in self.edge_symbol_list + self.node_symbol_list:
            idx_list = [feature_id_dict[each_symbol.__class__.__name__]]
            idx_list.extend(feature_id_dict[each_key]
                            for each_key in _attr_keys(each_symbol))
            feature_dict[each_symbol] = _to_feature(idx_list)
        for each_ext_id in self.ext_id_list:
            feature_dict[('ext_id', each_ext_id)] \
                = _to_feature([feature_id_dict[('ext_id', each_ext_id)]])
        return feature_dict, dim

    def edge_symbol_idx(self, symbol):
        ''' return the index of `symbol` in `edge_symbol_list` (KeyError if unknown) '''
        return self.edge_symbol_dict[symbol]

    def node_symbol_idx(self, symbol):
        ''' return the index of `symbol` in `node_symbol_list` (KeyError if unknown) '''
        return self.node_symbol_dict[symbol]

    def append(self, prod_rule: 'ProductionRule') -> 'Tuple[int, ProductionRule]':
        """ register a production rule and return its index together with the rule.
        Production rules are regarded as the same if
            i) there exists a one-to-one mapping of nodes and edges, and
            ii) all the attributes associated with nodes and hyperedges are the same.

        If an equivalent rule is already registered, `nt_idx` of the
        non-terminal edges in `prod_rule` is rewritten to follow the registered
        rule, and the index of the registered rule is returned.

        Parameters
        ----------
        prod_rule : ProductionRule

        Returns
        -------
        prod_rule_id : int
            production rule index. if new, a new index will be assigned.
        prod_rule : ProductionRule
            the input rule (with `nt_idx` possibly rewritten)

        Raises
        ------
        ValueError
            if a non-terminal edge of `prod_rule` is mapped to an edge of the
            registered rule that has no `nt_idx`.
        """
        for each_idx, each_prod_rule in enumerate(self.prod_rule_list):
            is_same, isomap = prod_rule.is_same(each_prod_rule)
            if not is_same:
                continue
            # we do not care about edge and node names, but care about the order of non-terminal edges.
            # key : edges & nodes in each_prod_rule.rhs , val : those in prod_rule.rhs
            for key, val in isomap.items():
                if key.startswith("bond_"):
                    continue

                # rewrite `nt_idx` in `prod_rule` for further processing
                if "nt_idx" in prod_rule.rhs.edge_attr(val).keys():
                    if "nt_idx" not in each_prod_rule.rhs.edge_attr(key).keys():
                        raise ValueError(
                            'non-terminal edge is matched to an edge without nt_idx.')
                    prod_rule.rhs.set_edge_attr(
                        val,
                        {'nt_idx': each_prod_rule.rhs.edge_attr(key)["nt_idx"]})
            return each_idx, prod_rule

        # brand-new rule: register it and everything it introduces
        self.prod_rule_list.append(prod_rule)
        self._update_edge_symbol_list(prod_rule)
        self._update_node_symbol_list(prod_rule)
        self._update_ext_id_list(prod_rule)

        new_idx = len(self.prod_rule_list) - 1
        lhs_idx = self.nt_symbol_list.index(prod_rule.lhs_nt_symbol)
        self.lhs_in_prod_rule_row_list.append(lhs_idx)
        self.lhs_in_prod_rule_col_list.append(new_idx)
        self._lhs_in_prod_rule = None  # invalidate the cached matrix
        return new_idx, prod_rule

    def get_prod_rule(self, prod_rule_idx: int) -> 'ProductionRule':
        ''' return the production rule registered at `prod_rule_idx` '''
        return self.prod_rule_list[prod_rule_idx]

    @staticmethod
    def _to_float64_array(unmasked_logit_array):
        # anything that is not an ndarray is expected to offer `.numpy()`
        if not isinstance(unmasked_logit_array, np.ndarray):
            unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
        return unmasked_logit_array

    def _lhs_mask(self, nt_symbol):
        # row of `lhs_in_prod_rule` selecting the rules whose lhs is `nt_symbol`
        return self.lhs_in_prod_rule[
            self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)

    def sample(self, unmasked_logit_array, nt_symbol, deterministic=False):
        ''' sample a production rule whose lhs is `nt_symbol`, following `unmasked_logit_array`.

        Parameters
        ----------
        unmasked_logit_array : array-like, length `num_prod_rule`
        nt_symbol : NTSymbol
        deterministic : bool
            if True, the rule with the largest masked probability is returned.

        Returns
        -------
        ProductionRule
        '''
        unmasked_logit_array = self._to_float64_array(unmasked_logit_array)
        prob = masked_softmax(unmasked_logit_array, self._lhs_mask(nt_symbol))
        if deterministic:
            return self.prod_rule_list[np.argmax(prob)]
        return np.random.choice(self.prod_rule_list, 1, p=prob)[0]

    def masked_logprob(self, unmasked_logit_array, nt_symbol):
        ''' return log-probabilities over the rules, masked by the rules whose lhs is `nt_symbol`.

        Parameters
        ----------
        unmasked_logit_array : array-like, length `num_prod_rule`
        nt_symbol : NTSymbol
        '''
        unmasked_logit_array = self._to_float64_array(unmasked_logit_array)
        prob = masked_softmax(unmasked_logit_array, self._lhs_mask(nt_symbol))
        return np.log(prob)

    def _update_edge_symbol_list(self, prod_rule: 'ProductionRule'):
        ''' update edge symbol list, and annotate each rhs edge with `symbol_idx`

        Parameters
        ----------
        prod_rule : ProductionRule
        '''
        if prod_rule.lhs_nt_symbol not in self.nt_symbol_list:
            self.nt_symbol_list.append(prod_rule.lhs_nt_symbol)

        for each_edge in prod_rule.rhs.edges:
            each_symbol = prod_rule.rhs.edge_attr(each_edge)['symbol']
            if each_symbol not in self.edge_symbol_dict:
                self.edge_symbol_dict[each_symbol] = len(self.edge_symbol_list)
                self.edge_symbol_list.append(each_symbol)
            prod_rule.rhs.edge_attr(each_edge)['symbol_idx'] \
                = self.edge_symbol_dict[each_symbol]

    def _update_node_symbol_list(self, prod_rule: 'ProductionRule'):
        ''' update node symbol list, and annotate each rhs node with `symbol_idx`

        Parameters
        ----------
        prod_rule : ProductionRule
        '''
        for each_node in prod_rule.rhs.nodes:
            each_symbol = prod_rule.rhs.node_attr(each_node)['symbol']
            if each_symbol not in self.node_symbol_dict:
                self.node_symbol_dict[each_symbol] = len(self.node_symbol_list)
                self.node_symbol_list.append(each_symbol)
            prod_rule.rhs.node_attr(each_node)['symbol_idx'] \
                = self.node_symbol_dict[each_symbol]

    def _update_ext_id_list(self, prod_rule: 'ProductionRule'):
        ''' collect the ext_ids appearing in the rhs of `prod_rule`

        Parameters
        ----------
        prod_rule : ProductionRule
        '''
        for each_node in prod_rule.rhs.nodes:
            each_attr = prod_rule.rhs.node_attr(each_node)
            if 'ext_id' in each_attr and each_attr['ext_id'] not in self.ext_id_list:
                self.ext_id_list.append(each_attr['ext_id'])
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
class HyperedgeReplacementGrammar(GraphGrammarBase):
    """
    Learn a hyperedge replacement grammar from a set of hypergraphs.

    Attributes
    ----------
    prod_rule_list : list of ProductionRule
        production rules learned from the input hypergraphs
    prod_rule_corpus : ProductionRuleCorpus
        corpus holding the unique production rules
    clique_tree_corpus : CliqueTreeCorpus
        corpus the clique trees are registered to during learning
    tree_decomposition : callable
        hypergraph -> clique tree; `**kwargs` of the constructor are bound to it
    """
    def __init__(self,
                 tree_decomposition=molecular_tree_decomposition,
                 ignore_order=False, **kwargs):
        from functools import partial
        self.prod_rule_corpus = ProductionRuleCorpus()
        self.clique_tree_corpus = CliqueTreeCorpus()
        self.ignore_order = ignore_order
        self.tree_decomposition = partial(tree_decomposition, **kwargs)

    @property
    def num_prod_rule(self):
        ''' return the number of production rules

        Returns
        -------
        int : the number of unique production rules
        '''
        return self.prod_rule_corpus.num_prod_rule

    @property
    def start_rule_list(self):
        ''' return a list of start rules

        Returns
        -------
        list : list of start rules
        '''
        return self.prod_rule_corpus.start_rule_list

    @property
    def prod_rule_list(self):
        ''' return the list of unique production rules held by the corpus '''
        return self.prod_rule_corpus.prod_rule_list

    def _extract_prod_rule_seq(self, clique_tree, root_node):
        ''' traverse `clique_tree` depth-first from `root_node`, registering one
        production rule per tree node, and return the sequence of rule ids. '''
        no_parent = object()  # marks the stack entry of the root
        prod_rule_seq = []
        stack = [(no_parent, root_node)]
        while len(stack) != 0:
            # get a triple of parent, myself, and children
            parent, myself = stack.pop()
            children = sorted(list(dict(clique_tree[myself]).keys()))
            if parent is no_parent:
                parent_hg = None
            else:
                children.remove(parent)
                parent_hg = clique_tree.nodes[parent]["subhg"]

            # extract a temporary production rule
            prod_rule = extract_prod_rule(
                parent_hg,
                clique_tree.nodes[myself]["subhg"],
                [clique_tree.nodes[each_child]["subhg"]
                 for each_child in children],
                clique_tree.nodes[myself].get('subhg_idx', None))

            # update the production rule list
            prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
            children = reorder_children(myself,
                                        children,
                                        prod_rule,
                                        clique_tree)
            # pushed in reverse so that children are popped in `nt_idx` order
            stack.extend([(myself, each_child)
                          for each_child in children[::-1]])
            prod_rule_seq.append(prod_rule_id)
        return prod_rule_seq

    def learn(self, hg_list, logger=print, max_mol=np.inf, print_freq=500):
        """ learn from a list of hypergraphs

        Parameters
        ----------
        hg_list : list of Hypergraph
        logger : callable
            called with a progress message every `print_freq` hypergraphs
            and with the final corpus size.
        max_mol : int or float
            at most this many hypergraphs are processed.
        print_freq : int
            interval (in number of hypergraphs) between progress messages.

        Returns
        -------
        prod_rule_seq_list : list of lists of integers
            each element corresponds to a sequence of production rules to generate each hypergraph.
        """
        prod_rule_seq_list = []
        for each_idx, each_hg in enumerate(hg_list):
            clique_tree = self.tree_decomposition(each_hg)

            # get a pair of myself and children
            root_node = _find_root(clique_tree)
            clique_tree = self.clique_tree_corpus.add_to_subhg_list(clique_tree, root_node)
            prod_rule_seq_list.append(
                self._extract_prod_rule_seq(clique_tree, root_node))

            if (each_idx+1) % print_freq == 0:
                msg = f'#(molecules processed)={each_idx+1}\t'\
                      f'#(production rules)={self.prod_rule_corpus.num_prod_rule}\t#(subhg in corpus)={self.clique_tree_corpus.size}'
                logger(msg)
            # stop once exactly `max_mol` hypergraphs have been processed
            if each_idx + 1 >= max_mol:
                break

        logger(f'corpus_size = {self.clique_tree_corpus.size}')
        return prod_rule_seq_list

    def sample(self, z, deterministic=False):
        """ sample a new hypergraph from HRG.

        Parameters
        ----------
        z : array-like, shape (len, num_prod_rule)
            logit
        deterministic : bool
            if True, deterministic sampling

        Returns
        -------
        Hypergraph

        Raises
        ------
        RuntimeError
            if non-terminals are left after all the rows of `z` are consumed.
        """
        # the last column is dropped before sampling
        # NOTE(review): presumably a padding / termination entry -- confirm against the decoder
        z = z[:, :-1]
        init_prod_rule = self.prod_rule_corpus.sample(z[0], NTSymbol(degree=0,
                                                                     is_aromatic=False,
                                                                     bond_symbol_list=[]),
                                                      deterministic=deterministic)
        hg, nt_edge_list = init_prod_rule.applied_to(None, None)
        stack = deepcopy(nt_edge_list[::-1])
        seq_idx = 0
        while len(stack) != 0 and seq_idx < z.shape[0]-1:
            seq_idx += 1
            nt_edge = stack.pop()
            nt_symbol = hg.edge_attr(nt_edge)['symbol']
            prod_rule = self.prod_rule_corpus.sample(z[seq_idx], nt_symbol, deterministic=deterministic)
            hg, nt_edge_list = prod_rule.applied_to(hg, nt_edge)
            stack.extend(nt_edge_list[::-1])
        if len(stack) != 0:
            raise RuntimeError(f'{len(stack)} non-terminals are left.')
        return hg

    def construct(self, prod_rule_seq):
        """ construct a hypergraph following `prod_rule_seq`

        Parameters
        ----------
        prod_rule_seq : list of integers
            a sequence of production rules.

        Returns
        -------
        UndirectedHypergraph
        """
        seq_idx = 0
        init_prod_rule = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx])
        hg, nt_edge_list = init_prod_rule.applied_to(None, None)
        stack = deepcopy(nt_edge_list[::-1])
        while len(stack) != 0:
            seq_idx += 1
            nt_edge = stack.pop()
            each_prod_rule = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx])
            hg, nt_edge_list = each_prod_rule.applied_to(hg, nt_edge)
            stack.extend(nt_edge_list[::-1])
        return hg

    def update_prod_rule_list(self, prod_rule):
        """ register `prod_rule` to the corpus and return its id together with the rule.
        Production rules are regarded as the same if
            i) there exists a one-to-one mapping of nodes and edges, and
            ii) all the attributes associated with nodes and hyperedges are the same.

        Parameters
        ----------
        prod_rule : ProductionRule

        Returns
        -------
        prod_rule_id : int
            production rule index. if new, a new index will be assigned.
        prod_rule : ProductionRule
            the rule returned by the corpus
        """
        return self.prod_rule_corpus.append(prod_rule)
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
class IncrementalHyperedgeReplacementGrammar(HyperedgeReplacementGrammar):
    '''
    This class learns HRG incrementally leveraging the previously obtained production rules.
    '''
    def __init__(self, tree_decomposition=tree_decomposition_with_hrg, ignore_order=False):
        # `prod_rule_list` is a read-only property of the parent class backed by
        # `prod_rule_corpus`, so the corpus is created here rather than
        # assigning to `prod_rule_list` directly.
        self.prod_rule_corpus = ProductionRuleCorpus()
        self.tree_decomposition = tree_decomposition
        self.ignore_order = ignore_order

    def learn(self, hg_list):
        """ learn from a list of hypergraphs

        Parameters
        ----------
        hg_list : list of UndirectedHypergraph

        Returns
        -------
        prod_rule_seq_list : list of lists of integers
            each element corresponds to a sequence of production rules to generate each hypergraph.
        """
        prod_rule_seq_list = []
        for each_hg in hg_list:
            # the decomposition receives this grammar and also reports the root
            clique_tree, root_node = self.tree_decomposition(each_hg, self, return_root=True)

            prod_rule_seq = []
            stack = []

            # get a pair of myself and children
            children = sorted(list(clique_tree[root_node].keys()))

            # extract a temporary production rule
            prod_rule = extract_prod_rule(
                None,
                clique_tree.nodes[root_node]["subhg"],
                [clique_tree.nodes[each_child]["subhg"] for each_child in children])

            # update the production rule list
            prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
            children = reorder_children(root_node, children, prod_rule, clique_tree)
            stack.extend([(root_node, each_child) for each_child in children[::-1]])
            prod_rule_seq.append(prod_rule_id)

            while len(stack) != 0:
                # get a triple of parent, myself, and children
                parent, myself = stack.pop()
                children = sorted(list(dict(clique_tree[myself]).keys()))
                children.remove(parent)

                # extract a temp prod rule
                prod_rule = extract_prod_rule(
                    clique_tree.nodes[parent]["subhg"],
                    clique_tree.nodes[myself]["subhg"],
                    [clique_tree.nodes[each_child]["subhg"] for each_child in children])

                # update the prod rule list
                prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
                children = reorder_children(myself, children, prod_rule, clique_tree)
                stack.extend([(myself, each_child) for each_child in children[::-1]])
                prod_rule_seq.append(prod_rule_id)
            prod_rule_seq_list.append(prod_rule_seq)
        # NOTE(review): `_compute_stats` is not defined in this part of the file;
        # presumably provided by a base class -- confirm.
        self._compute_stats()
        return prod_rule_seq_list
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
def reorder_children(myself, children, prod_rule, clique_tree):
    """ reorder children so that they match the order in `prod_rule`.

    A child is associated with a non-terminal edge of `prod_rule.rhs` when the
    nodes it shares with `myself` are exactly the nodes of that edge; the child
    is then placed at the position given by the edge's `nt_idx`.

    Parameters
    ----------
    myself : int
    children : list of int
    prod_rule : ProductionRule
    clique_tree : nx.Graph

    Returns
    -------
    new_children : list
        the elements of `children`, reordered by `nt_idx`

    Raises
    ------
    AssertionError
        if two children match the same `nt_idx`, or if the number of matched
        children differs from `len(children)`.
    """
    # the node set shared with each child does not depend on the edge being
    # examined, so it is computed at most once per child (and only on demand).
    common_node_cache = {}

    def _common_node_set(each_child):
        if each_child not in common_node_cache:
            common_node_cache[each_child] = set(
                common_node_list(clique_tree.nodes[myself]["subhg"],
                                 clique_tree.nodes[each_child]["subhg"])[0])
        return common_node_cache[each_child]

    perm = {}  # key : `nt_idx`, val : child
    for each_edge in prod_rule.rhs.edges:
        each_attr = prod_rule.rhs.edge_attr(each_edge)
        if "nt_idx" not in each_attr.keys():
            continue
        edge_node_set = set(prod_rule.rhs.nodes_in_edge(each_edge))
        for each_child in children:
            if edge_node_set == _common_node_set(each_child):
                assert each_attr["nt_idx"] not in perm
                perm[each_attr["nt_idx"]] = each_child
    assert len(perm) == len(children)
    return [perm[i] for i in range(len(perm))]
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
def extract_prod_rule(parent_hg, myself_hg, children_hg_list, subhg_idx=None):
    """ extract a production rule from a triple of `parent_hg`, `myself_hg`, and `children_hg_list`.

    Parameters
    ----------
    parent_hg : Hypergraph or None
        if None, the l.h.s. of the rule is an empty hypergraph.
    myself_hg : Hypergraph
    children_hg_list : list of Hypergraph or None
    subhg_idx : int or None
        stored as `subhg_idx` of the returned production rule.

    Returns
    -------
    ProductionRule, consisting of
        lhs : Hypergraph or None
        rhs : Hypergraph

    Raises
    ------
    ValueError
        if only a part of the nodes shared with the parent already has `ext_id`.
    RuntimeError
        if `DEBUG` is set and the ext_ids of the rule are not 0, 1, 2, ...
    """
    def _add_ext_node(hg, ext_nodes):
        """ mark nodes to be external (ordered ids are assigned)

        Parameters
        ----------
        hg : UndirectedHypergraph
        ext_nodes : list of str
            list of external nodes

        Returns
        -------
        hg : Hypergraph
            nodes in `ext_nodes` are marked to be external

        Raises
        ------
        ValueError
            if some, but not all, of `ext_nodes` already have `ext_id`.
        """
        ext_id_exists = ['ext_id' in hg.node_attr(each_node)
                         for each_node in ext_nodes]
        if ext_id_exists and any(ext_id_exists) != all(ext_id_exists):
            raise ValueError(
                'ext_id is assigned to only a part of the external nodes.')
        if not all(ext_id_exists):
            for ext_id, each_node in enumerate(ext_nodes):
                hg.node_attr(each_node)['ext_id'] = ext_id
        return hg

    def _check_aromatic(hg, node_list):
        # returns (whether any node is aromatic, per-node aromaticity flags)
        node_aromatic_list = [bool(hg.node_attr(each_node)['symbol'].is_aromatic)
                              for each_node in node_list]
        return any(node_aromatic_list), node_aromatic_list

    def _check_ring(hg):
        # True iff every edge is either marked 'tmp' or non-terminal
        for each_edge in hg.edges:
            each_attr = hg.edge_attr(each_edge)
            if not ('tmp' in each_attr or (not each_attr['terminal'])):
                return False
        return True

    def _nt_symbol(owner_hg, ring_hg, node_list):
        # non-terminal symbol for the hyperedge over `node_list`: node symbols
        # and aromaticity are read from `owner_hg`, `for_ring` from `ring_hg`.
        is_aromatic, _ = _check_aromatic(owner_hg, node_list)
        for_ring = _check_ring(ring_hg)
        bond_symbol_list = [owner_hg.node_attr(each_node)['symbol']
                            for each_node in node_list]
        return NTSymbol(degree=len(node_list),
                        is_aromatic=is_aromatic,
                        bond_symbol_list=bond_symbol_list,
                        for_ring=for_ring)

    lhs = Hypergraph()
    if parent_hg is None:
        ext_node_list = []
    else:
        # lhs: a single non-terminal edge over the nodes shared with the parent
        ext_node_list, edge_exists = common_node_list(parent_hg, myself_hg)
        for each_node in ext_node_list:
            lhs.add_node(each_node,
                         deepcopy(myself_hg.node_attr(each_node)))
        lhs.add_edge(
            ext_node_list,
            attr_dict=dict(
                terminal=False,
                edge_exists=edge_exists,
                symbol=_nt_symbol(parent_hg, myself_hg, ext_node_list)))
        lhs = _add_ext_node(lhs, ext_node_list)

    # rhs: myself without temporary edges, with the shared nodes marked external
    rhs = remove_tmp_edge(deepcopy(myself_hg))
    rhs = _add_ext_node(rhs, ext_node_list)

    # one non-terminal edge per child, numbered by `nt_idx` in the given order
    if children_hg_list is not None:
        for nt_idx, each_child_hg in enumerate(children_hg_list):
            child_node_list, edge_exists = common_node_list(myself_hg, each_child_hg)
            rhs.add_edge(
                child_node_list,
                attr_dict=dict(
                    terminal=False,
                    nt_idx=nt_idx,
                    edge_exists=edge_exists,
                    symbol=_nt_symbol(myself_hg, each_child_hg, child_node_list)))

    prod_rule = ProductionRule(lhs, rhs)
    prod_rule.subhg_idx = subhg_idx
    if DEBUG:
        if sorted(list(prod_rule.ext_node.keys())) \
           != list(np.arange(len(prod_rule.ext_node))):
            raise RuntimeError('ext_id is not continuous')
    return prod_rule
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
def _find_root(clique_tree):
|
| 1028 |
-
max_node = None
|
| 1029 |
-
num_nodes_max = -np.inf
|
| 1030 |
-
for each_node in clique_tree.nodes:
|
| 1031 |
-
if clique_tree.nodes[each_node]['subhg'].num_nodes > num_nodes_max:
|
| 1032 |
-
max_node = each_node
|
| 1033 |
-
num_nodes_max = clique_tree.nodes[each_node]['subhg'].num_nodes
|
| 1034 |
-
'''
|
| 1035 |
-
children = sorted(list(clique_tree[each_node].keys()))
|
| 1036 |
-
prod_rule = extract_prod_rule(None,
|
| 1037 |
-
clique_tree.nodes[each_node]["subhg"],
|
| 1038 |
-
[clique_tree.nodes[each_child]["subhg"]
|
| 1039 |
-
for each_child in children])
|
| 1040 |
-
for each_start_rule in start_rule_list:
|
| 1041 |
-
if prod_rule.is_same(each_start_rule):
|
| 1042 |
-
return each_node
|
| 1043 |
-
'''
|
| 1044 |
-
return max_node
|
| 1045 |
-
|
| 1046 |
-
def remove_ext_node(hg):
    ''' drop the 'ext_id' attribute from every node of `hg` (in place)

    Nodes without an 'ext_id' entry are left untouched.
    The same `hg` object is returned.
    '''
    for node in hg.nodes:
        attrs = hg.node_attr(node)
        attrs.pop('ext_id', None)
    return hg
|
| 1050 |
-
|
| 1051 |
-
def remove_nt_edge(hg):
    ''' remove every hyperedge whose 'terminal' attribute is falsy (in place)

    The edges to delete are collected first so that `hg.edges` is not
    modified while it is being iterated. The same `hg` object is returned.
    '''
    non_terminal_edges = [edge for edge in hg.edges
                          if not hg.edge_attr(edge)['terminal']]
    hg.remove_edges(non_terminal_edges)
    return hg
|
| 1058 |
-
|
| 1059 |
-
def remove_tmp_edge(hg):
    ''' remove every hyperedge flagged with a truthy 'tmp' attribute (in place)

    Edges without a 'tmp' entry are kept. The edges to delete are collected
    first so that `hg.edges` is not modified while it is being iterated.
    The same `hg` object is returned.
    '''
    temporary_edges = [edge for edge in hg.edges
                       if hg.edge_attr(edge).get('tmp', False)]
    hg.remove_edges(temporary_edges)
    return hg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/graph_grammar/symbols.py
DELETED
|
@@ -1,180 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE, SO A CAREFUL CHECK NEEDS TO BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
""" Title """
|
| 16 |
-
|
| 17 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 18 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 19 |
-
__version__ = "0.1"
|
| 20 |
-
__date__ = "Jan 1 2018"
|
| 21 |
-
|
| 22 |
-
from typing import List
|
| 23 |
-
|
| 24 |
-
class TSymbol(object):

    ''' terminal symbol

    Attributes
    ----------
    degree : int
        the number of nodes in a hyperedge
    is_aromatic : bool
        whether or not the hyperedge is in an aromatic ring
    symbol : str
        atomic symbol
    num_explicit_Hs : int
        the number of hydrogens associated to this hyperedge
    formal_charge : int
        charge
    chirality : int
        chirality
    '''

    def __init__(self, degree, is_aromatic,
                 symbol, num_explicit_Hs, formal_charge, chirality):
        self.degree = degree
        self.is_aromatic = is_aromatic
        self.symbol = symbol
        self.num_explicit_Hs = num_explicit_Hs
        self.formal_charge = formal_charge
        self.chirality = chirality

    @property
    def terminal(self):
        # a TSymbol is terminal by definition
        return True

    def __eq__(self, other):
        # two terminal symbols are equal iff every attribute agrees;
        # the attributes are checked in declaration order with short-circuit
        if not isinstance(other, TSymbol):
            return False
        return not (self.degree != other.degree
                    or self.is_aromatic != other.is_aromatic
                    or self.symbol != other.symbol
                    or self.num_explicit_Hs != other.num_explicit_Hs
                    or self.formal_charge != other.formal_charge
                    or self.chirality != other.chirality)

    def __hash__(self):
        # hash of the textual form, which lists every compared attribute
        return hash(str(self))

    def __str__(self):
        return (f'degree={self.degree}, is_aromatic={self.is_aromatic}, '
                f'symbol={self.symbol}, '
                f'num_explicit_Hs={self.num_explicit_Hs}, '
                f'formal_charge={self.formal_charge}, chirality={self.chirality}')
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
class NTSymbol(object):

    ''' non-terminal symbol

    Attributes
    ----------
    degree : int
        degree of the hyperedge
    is_aromatic : bool
        aromaticity flag of the hyperedge
    for_ring : bool
        flag stored as given to the constructor (defaults to False)
    bond_symbol_list : list
        symbols of the nodes attached to the hyperedge, in order;
        compared element-wise in `__eq__`
    '''

    def __init__(self, degree: int, is_aromatic: bool,
                 bond_symbol_list: list,
                 for_ring=False):
        self.degree = degree
        self.is_aromatic = is_aromatic
        self.for_ring = for_ring
        self.bond_symbol_list = bond_symbol_list

    @property
    def terminal(self) -> bool:
        # an NTSymbol is non-terminal by definition
        return False

    @property
    def symbol(self):
        # label derived from the degree only, e.g. 'NT2'
        return f'NT{self.degree}'

    def __eq__(self, other) -> bool:
        if not isinstance(other, NTSymbol):
            return False

        if self.degree != other.degree:
            return False
        if self.is_aromatic != other.is_aromatic:
            return False
        if self.for_ring != other.for_ring:
            return False
        if len(self.bond_symbol_list) != len(other.bond_symbol_list):
            return False
        # same length is guaranteed here, so zip covers every position
        return not any(mine != theirs
                       for mine, theirs
                       in zip(self.bond_symbol_list, other.bond_symbol_list))

    def __hash__(self):
        # hash of the textual form, which lists every compared attribute
        return hash(str(self))

    def __str__(self) -> str:
        # NOTE: a ', ' separator now precedes `for_ring`; it was missing
        # and the two fields used to run together.
        return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
            f'bond_symbol_list={[str(each_symbol) for each_symbol in self.bond_symbol_list]}, '\
            f'for_ring={self.for_ring}'
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
class BondSymbol(object):

    ''' Bond symbol

    Attributes
    ----------
    is_aromatic : bool
        aromaticity flag of the bond
    bond_type : int
        bond type
    stereo : int
        stereo descriptor of the bond
    '''

    def __init__(self, is_aromatic: bool,
                 bond_type: int,
                 stereo: int):
        self.is_aromatic = is_aromatic
        self.bond_type = bond_type
        self.stereo = stereo

    def __eq__(self, other) -> bool:
        # equal iff all three attributes agree; checked in order with short-circuit
        if not isinstance(other, BondSymbol):
            return False
        return not (self.is_aromatic != other.is_aromatic
                    or self.bond_type != other.bond_type
                    or self.stereo != other.stereo)

    def __hash__(self):
        # hash of the textual form, which lists every compared attribute
        return hash(str(self))

    def __str__(self) -> str:
        return (f'is_aromatic={self.is_aromatic}, '
                f'bond_type={self.bond_type}, '
                f'stereo={self.stereo}, ')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/graph_grammar/utils.py
DELETED
|
@@ -1,130 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE, SO A CAREFUL CHECK NEEDS TO BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Jun 4 2018"
|
| 20 |
-
|
| 21 |
-
from ..hypergraph import Hypergraph
|
| 22 |
-
from copy import deepcopy
|
| 23 |
-
from typing import List
|
| 24 |
-
import numpy as np
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def common_node_list(hg1: Hypergraph, hg2: Hypergraph) -> tuple:
    """ return the nodes shared by two hypergraphs, and whether `hg1`
    already has a hyperedge made of exactly those nodes

    Parameters
    ----------
    hg1, hg2 : Hypergraph
        either may be None

    Returns
    -------
    node_list : list (or whatever `hg1.nodes_in_edge` returns)
        common nodes, sorted by the 'order4hrg' attribute in `hg1` when the
        first node of `hg1` carries it, otherwise by the hash of each node's
        'symbol'. If `hg1` has a non-terminal hyperedge over exactly these
        nodes, that edge's own node collection is returned instead.
    edge_exists : bool
        True iff `hg1` has a hyperedge over exactly the common nodes
        (order ignored). `([], False)` is returned when either input is None.
    """
    if hg1 is None or hg2 is None:
        return [], False

    node_set = hg1.nodes.intersection(hg2.nodes)

    # Peek at one node without copying the whole node set; the sentinel
    # lets an empty hypergraph fall through instead of raising IndexError.
    missing = object()
    probe_node = next(iter(hg1.nodes), missing)
    if probe_node is not missing and 'order4hrg' in hg1.node_attr(probe_node):
        def sort_key(node):
            return hg1.node_attr(node)['order4hrg']
    else:
        def sort_key(node):
            return hg1.node_attr(node)['symbol'].__hash__()

    # `sorted` is stable, so ties keep the iteration order of `node_set`
    node_list = sorted(node_set, key=sort_key)

    edge_name = hg1.has_edge(node_list, ignore_order=True)
    if not edge_name:
        return node_list, False
    # an edge without a 'terminal' entry is treated as terminal
    if not hg1.edge_attr(edge_name).get('terminal', True):
        node_list = hg1.nodes_in_edge(edge_name)
    return node_list, True
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def _node_match(node1, node2):
|
| 63 |
-
# if the nodes are hyperedges, `atom_attr` determines the match
|
| 64 |
-
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
| 65 |
-
return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
|
| 66 |
-
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
| 67 |
-
# bond_symbol
|
| 68 |
-
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
| 69 |
-
else:
|
| 70 |
-
return False
|
| 71 |
-
|
| 72 |
-
def _easy_node_match(node1, node2):
|
| 73 |
-
# if the nodes are hyperedges, `atom_attr` determines the match
|
| 74 |
-
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
| 75 |
-
return node1["attr_dict"].get('symbol', None) == node2["attr_dict"].get('symbol', None)
|
| 76 |
-
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
| 77 |
-
# bond_symbol
|
| 78 |
-
return node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)\
|
| 79 |
-
and node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
| 80 |
-
else:
|
| 81 |
-
return False
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def _node_match_prod_rule(node1, node2, ignore_order=False):
|
| 85 |
-
# if the nodes are hyperedges, `atom_attr` determines the match
|
| 86 |
-
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
| 87 |
-
return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
|
| 88 |
-
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
| 89 |
-
# ext_id, order4hrg, bond_symbol
|
| 90 |
-
if ignore_order:
|
| 91 |
-
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
| 92 |
-
else:
|
| 93 |
-
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']\
|
| 94 |
-
and node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)
|
| 95 |
-
else:
|
| 96 |
-
return False
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def _edge_match(edge1, edge2, ignore_order=False):
|
| 100 |
-
#return True
|
| 101 |
-
if ignore_order:
|
| 102 |
-
return True
|
| 103 |
-
else:
|
| 104 |
-
return edge1["order"] == edge2["order"]
|
| 105 |
-
|
| 106 |
-
def masked_softmax(logit, mask):
    ''' compute a probability distribution from logit

    Parameters
    ----------
    logit : array, length D
        each element indicates how each dimension is likely to be chosen
        (the larger, the more likely)
    mask : array, length D
        each element is either 0 or 1.
        if 0, the dimension is ignored
        when computing the probability distribution.

    Returns
    -------
    prob_dist : array, length D
        probability distribution computed from logit.
        if `mask[d] = 0`, `prob_dist[d] = 0`.
        if the mask selects nothing the result is 0/0 (NaN), as there is
        no distribution to return.

    Raises
    ------
    ValueError
        if `logit` and `mask` do not have the same shape
    '''
    if logit.shape != mask.shape:
        raise ValueError('logit and mask must have the same shape')
    selected = mask != 0
    # Stabilise with the max over the *selected* logits only: a huge
    # masked-out logit would otherwise underflow every kept entry to 0
    # and turn the result into 0/0.
    c = np.max(logit[selected]) if selected.any() else 0.0
    # Masked-out entries are sent to -inf before exponentiation so that
    # they contribute exactly 0 and can never overflow.
    shifted = np.where(selected, logit - c, -np.inf)
    exp_logit = np.exp(shifted) * mask
    sum_exp_logit = exp_logit @ mask
    return exp_logit / sum_exp_logit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/hypergraph.py
DELETED
|
@@ -1,544 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE, SO A CAREFUL CHECK NEEDS TO BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Jan 31 2018"
|
| 20 |
-
|
| 21 |
-
from copy import deepcopy
|
| 22 |
-
from typing import List, Dict, Tuple
|
| 23 |
-
import networkx as nx
|
| 24 |
-
import numpy as np
|
| 25 |
-
import os
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class Hypergraph(object):
|
| 29 |
-
'''
|
| 30 |
-
A class of a hypergraph.
|
| 31 |
-
Each hyperedge can be ordered. For the ordered case,
|
| 32 |
-
edges adjacent to the hyperedge node are labeled by their orders.
|
| 33 |
-
|
| 34 |
-
Attributes
|
| 35 |
-
----------
|
| 36 |
-
hg : nx.Graph
|
| 37 |
-
a bipartite graph representation of a hypergraph
|
| 38 |
-
edge_idx : int
|
| 39 |
-
total number of hyperedges that exist so far
|
| 40 |
-
'''
|
| 41 |
-
def __init__(self):
    # bipartite graph holding both the nodes and the hyperedges
    self.hg = nx.Graph()
    # counter bumped on every `add_edge`; used to name new edges
    self.edge_idx = 0
    # names and counts of the nodes / hyperedges currently present
    self.nodes = set()
    self.edges = set()
    self.num_nodes = 0
    self.num_edges = 0
    # edge name -> node collection exactly as handed to `add_edge`
    self.nodes_in_edge_dict = {}
|
| 49 |
-
|
| 50 |
-
def add_node(self, node: str, attr_dict=None):
    ''' add a node to hypergraph

    Re-adding an existing node overwrites its stored `attr_dict`
    but does not change the node count.

    Parameters
    ----------
    node : str
        node name
    attr_dict : dict
        dictionary of node attributes
    '''
    self.hg.add_node(node, bipartite='node', attr_dict=attr_dict)
    is_new = node not in self.nodes
    if is_new:
        self.num_nodes += 1
    self.nodes.add(node)
|
| 64 |
-
|
| 65 |
-
def add_edge(self, node_list: List[str], attr_dict=None, edge_name=None):
    ''' add an edge consisting of nodes `node_list`

    Parameters
    ----------
    node_list : list or set
        nodes that make up the edge. A list yields an ordered edge
        (each incidence is labeled with the node's position); a set
        yields an unordered edge (every incidence is labeled -1).
    attr_dict : dict
        dictionary of edge attributes
    edge_name : str, optional
        name of the new edge; must not be in use.
        Defaults to 'e<edge_idx>'.

    Returns
    -------
    str
        name of the added edge

    Raises
    ------
    ValueError
        if `node_list` is neither a list nor a set. The check runs before
        anything is modified, so a rejected call leaves the hypergraph
        untouched.
    '''
    # Validate first: rejecting the input after the edge had been
    # registered would leave a half-added edge behind.
    if isinstance(node_list, list):
        ordered = True
    elif isinstance(node_list, set):
        ordered = False
    else:
        raise ValueError('node_list must be a list or a set')

    if edge_name is None:
        edge = 'e{}'.format(self.edge_idx)
    else:
        assert edge_name not in self.edges
        edge = edge_name

    self.hg.add_node(edge, bipartite='edge', attr_dict=attr_dict)
    if edge not in self.edges:
        self.num_edges += 1
    self.edges.add(edge)
    self.nodes_in_edge_dict[edge] = node_list

    for node_idx, each_node in enumerate(node_list):
        self.hg.add_edge(edge, each_node,
                         order=node_idx if ordered else -1)
        # nodes not seen before are registered on the fly
        if each_node not in self.nodes:
            self.num_nodes += 1
        self.nodes.add(each_node)

    self.edge_idx += 1
    return edge
|
| 102 |
-
|
| 103 |
-
def remove_node(self, node: str, remove_connected_edges=True):
    ''' remove a node

    Parameters
    ----------
    node : str
        node name
    remove_connected_edges : bool
        if True, remove edges that are adjacent to the node
    '''
    if remove_connected_edges:
        # Only the edge names are needed. A shallow snapshot is enough to
        # iterate safely while edges are removed; deep-copying the
        # adjacency mapping and its attribute dicts is unnecessary.
        connected_edges = list(self.adj_edges(node))
        for each_edge in connected_edges:
            self.remove_edge(each_edge)
    self.hg.remove_node(node)
    self.num_nodes -= 1
    self.nodes.remove(node)
|
| 120 |
-
|
| 121 |
-
def remove_nodes(self, node_iter, remove_connected_edges=True):
    ''' remove a set of nodes, one `remove_node` call per node

    Parameters
    ----------
    node_iter : iterator of strings
        nodes to be removed
    remove_connected_edges : bool
        if True, remove edges that are adjacent to the node
    '''
    drop = self.remove_node
    for node in node_iter:
        drop(node, remove_connected_edges)
|
| 133 |
-
|
| 134 |
-
def remove_edge(self, edge: str):
    ''' remove an edge from the bipartite graph and from the bookkeeping
    (`edges`, `num_edges`, `nodes_in_edge_dict`)

    Parameters
    ----------
    edge : str
        edge to be removed
    '''
    self.hg.remove_node(edge)
    self.edges.remove(edge)
    self.num_edges = self.num_edges - 1
    del self.nodes_in_edge_dict[edge]
|
| 146 |
-
|
| 147 |
-
def remove_edges(self, edge_iter):
    ''' remove a set of edges, one `remove_edge` call per edge

    Parameters
    ----------
    edge_iter : iterator of strings
        edges to be removed
    '''
    drop = self.remove_edge
    for edge in edge_iter:
        drop(edge)
|
| 157 |
-
|
| 158 |
-
def remove_edges_with_attr(self, edge_attr_dict):
    ''' remove every edge whose attributes agree with `edge_attr_dict`

    An edge qualifies when, for each key of `edge_attr_dict`, its attribute
    dict holds that key with an equal value; a missing key disqualifies it.
    An empty `edge_attr_dict` therefore selects every edge.

    Parameters
    ----------
    edge_attr_dict : dict
        attribute names and the values they must have
    '''
    def _qualifies(edge):
        for key, wanted in edge_attr_dict.items():
            try:
                if self.edge_attr(edge)[key] != wanted:
                    return False
            except KeyError:
                return False
        return True

    # collect first: `self.edges` must not change while it is iterated
    doomed = [edge for edge in self.edges if _qualifies(edge)]
    self.remove_edges(doomed)
|
| 173 |
-
|
| 174 |
-
def remove_subhg(self, subhg):
    ''' remove subhypergraph.
    all of the hyperedges of `subhg` are removed first; afterwards each
    node of `subhg` is removed if its degree has become 0.

    Parameters
    ----------
    subhg : Hypergraph
    '''
    for edge in subhg.edges:
        self.remove_edge(edge)
    for node in subhg.nodes:
        if self.degree(node) != 0:
            # still attached to an edge outside `subhg`: keep it
            continue
        self.remove_node(node)
|
| 188 |
-
|
| 189 |
-
def nodes_in_edge(self, edge):
    ''' return the nodes in a given edge.

    An edge whose name starts with 'e' is looked up in
    `nodes_in_edge_dict`. Any other edge is rebuilt from the bipartite
    graph using the 'order' label on each incidence.

    Parameters
    ----------
    edge : str
        edge whose nodes are returned

    Returns
    -------
    list or set
        ordered list of the nodes, or a set when every incidence
        is labeled -1 (unordered edge)
    '''
    if edge.startswith('e'):
        return self.nodes_in_edge_dict[edge]

    incidences = self.hg.adj[edge]
    names = list(incidences)
    orders = [incidences[name]['order'] for name in names]
    if all(order == -1 for order in orders):
        return set(names)
    return [names[position] for position in np.argsort(orders)]
|
| 216 |
-
|
| 217 |
-
def adj_edges(self, node):
    ''' return the hyperedges adjacent to `node`

    Parameters
    ----------
    node : str

    Returns
    -------
    the adjacency entry of `node` in the bipartite graph; iterating it
    yields the names of the edges that contain `node`
    '''
    adjacency = self.hg.adj
    return adjacency[node]
|
| 230 |
-
|
| 231 |
-
def adj_nodes(self, node):
    ''' return a set of adjacent nodes

    Parameters
    ----------
    node : str

    Returns
    -------
    set
        nodes sharing at least one hyperedge with `node`,
        `node` itself excluded
    '''
    neighbours = set().union(*(self.nodes_in_edge(edge)
                               for edge in self.adj_edges(node)))
    neighbours.discard(node)
    return neighbours
|
| 248 |
-
|
| 249 |
-
def has_edge(self, node_list, ignore_order=False):
    ''' look for a hyperedge made of exactly the nodes in `node_list`

    Parameters
    ----------
    node_list : list or set
        nodes the edge must consist of
    ignore_order : bool
        if True, compare as sets; otherwise the stored node collection
        must be equal to `node_list` as is

    Returns
    -------
    str or False
        name of a matching edge (the first one met while iterating
        `self.edges`), or False if there is none
    '''
    if ignore_order:
        # the target set does not depend on the edge: build it once
        target = set(node_list)
        for each_edge in self.edges:
            if set(self.nodes_in_edge(each_edge)) == target:
                return each_edge
    else:
        for each_edge in self.edges:
            if self.nodes_in_edge(each_edge) == node_list:
                return each_edge
    return False
|
| 258 |
-
|
| 259 |
-
def degree(self, node):
    ''' number of entries adjacent to `node` in the bipartite graph '''
    adjacent = self.hg.adj[node]
    return len(adjacent)
|
| 261 |
-
|
| 262 |
-
def degrees(self):
    ''' map every node in `self.nodes` to its degree '''
    degree_of = {}
    for node in self.nodes:
        degree_of[node] = self.degree(node)
    return degree_of
|
| 264 |
-
|
| 265 |
-
def edge_degree(self, edge):
    ''' number of nodes in hyperedge `edge` '''
    members = self.nodes_in_edge(edge)
    return len(members)
|
| 267 |
-
|
| 268 |
-
def edge_degrees(self):
    ''' map every edge in `self.edges` to the number of nodes it contains '''
    degree_of = {}
    for edge in self.edges:
        degree_of[edge] = self.edge_degree(edge)
    return degree_of
|
| 270 |
-
|
| 271 |
-
def is_adj(self, node1, node2):
    ''' True iff `node1` is among the nodes adjacent to `node2` '''
    neighbours_of_node2 = self.adj_nodes(node2)
    return node1 in neighbours_of_node2
|
| 273 |
-
|
| 274 |
-
def adj_subhg(self, node, ident_node_dict=None):
    """ return the subhypergraph made of the hyperedges adjacent to `node`
    (and to the nodes identical to it) together with their nodes.
    a hyperedge adjacent to one of those nodes is added as well when all of
    its nodes already belong to one of the collected hyperedges.

    Parameters
    ----------
    node : str
    ident_node_dict : dict
        dict containing identical nodes. see `get_identical_node_dict` for more details

    Returns
    -------
    subhg : Hypergraph
        node and edge attribute dicts are passed on as they are (not copied);
        `edge_idx` is taken over from this hypergraph
    """
    if ident_node_dict is None:
        ident_node_dict = self.get_identical_node_dict()
    identical_nodes = ident_node_dict[node]

    sub_nodes = set(identical_nodes)
    sub_edges = set()
    for ident in identical_nodes:
        sub_edges.update(self.adj_edges(ident))

    # iterate over a snapshot: `sub_edges` grows inside the loop
    for edge in set(sub_edges):
        members = self.nodes_in_edge(edge)
        sub_nodes.update(members)
        member_set = set(members)

        # an edge hanging off one of the members is pulled in when it
        # introduces no node outside `edge`
        for member in members:
            for candidate in set(self.adj_edges(member)) - {edge}:
                if not set(self.nodes_in_edge(candidate)) - member_set:
                    sub_edges.add(candidate)

    subhg = Hypergraph()
    for sub_node in sub_nodes:
        subhg.add_node(sub_node, attr_dict=self.node_attr(sub_node))
    for sub_edge in sub_edges:
        subhg.add_edge(self.nodes_in_edge(sub_edge),
                       attr_dict=self.edge_attr(sub_edge),
                       edge_name=sub_edge)
    subhg.edge_idx = self.edge_idx
    return subhg
|
| 314 |
-
|
| 315 |
-
def get_subhg(self, node_list, edge_list, ident_node_dict=None):
    """ return a subhypergraph consisting of the given nodes (plus the nodes identical to them) and the given hyperedges.

    node and hyperedge attributes are deep-copied into the subhypergraph.

    Parameters
    ----------
    node_list : list
        nodes to include; each one is expanded with `ident_node_dict[node]`.
    edge_list : list
        hyperedges to include.
    ident_node_dict : dict
        dict containing identical nodes. see `get_identical_node_dict` for more details

    Returns
    -------
    subhg : Hypergraph
    """
    if ident_node_dict is None:
        ident_node_dict = self.get_identical_node_dict()
    adj_node_set = set()
    for each_node in node_list:
        adj_node_set.update(set(ident_node_dict[each_node]))
    adj_edge_set = set(edge_list)

    subhg = Hypergraph()
    for each_node in adj_node_set:
        subhg.add_node(each_node,
                       attr_dict=deepcopy(self.node_attr(each_node)))
    for each_edge in adj_edge_set:
        subhg.add_edge(self.nodes_in_edge(each_edge),
                       attr_dict=deepcopy(self.edge_attr(each_edge)),
                       edge_name=each_edge)
    # keep the edge counter aligned with the parent hypergraph
    subhg.edge_idx = self.edge_idx
    return subhg
|
| 346 |
-
|
| 347 |
-
def copy(self):
    ''' return a deep copy of the object

    Returns
    -------
    Hypergraph
    '''
    duplicate = deepcopy(self)
    return duplicate
|
| 355 |
-
|
| 356 |
-
def node_attr(self, node):
    ''' return the attribute dict stored for `node` in the underlying graph `self.hg`.
    '''
    node_record = self.hg.nodes[node]
    return node_record['attr_dict']
|
| 358 |
-
|
| 359 |
-
def edge_attr(self, edge):
    ''' return the attribute dict stored for `edge`; hyperedges are kept in `self.hg.nodes` as well.
    '''
    edge_record = self.hg.nodes[edge]
    return edge_record['attr_dict']
|
| 361 |
-
|
| 362 |
-
def set_node_attr(self, node, attr_dict):
    ''' write every key/value pair of `attr_dict` into the attribute dict of `node`.

    keys already present are overwritten; other existing keys are kept.
    '''
    for key, value in attr_dict.items():
        target = self.hg.nodes[node]['attr_dict']
        target[key] = value
|
| 365 |
-
|
| 366 |
-
def set_edge_attr(self, edge, attr_dict):
    ''' write every key/value pair of `attr_dict` into the attribute dict of `edge`.

    keys already present are overwritten; other existing keys are kept.
    '''
    for key, value in attr_dict.items():
        target = self.hg.nodes[edge]['attr_dict']
        target[key] = value
|
| 369 |
-
|
| 370 |
-
def get_identical_node_dict(self):
    ''' get identical nodes
    nodes are identical if they share the same set of adjacent edges.
    a node without any adjacent edge is identical only to itself.

    Returns
    -------
    ident_node_dict : dict
        ident_node_dict[node] returns a list of nodes that are identical to `node`.
    '''
    ident_node_dict = {}
    for each_node in self.nodes:
        # the adjacency of `each_node` does not change in the inner loop
        own_edges = self.adj_edges(each_node)
        has_edges = len(own_edges) != 0
        ident_node_dict[each_node] = [
            each_other_node for each_other_node in self.nodes
            if each_other_node == each_node
            or (has_edges and own_edges == self.adj_edges(each_other_node))]
    return ident_node_dict
|
| 396 |
-
|
| 397 |
-
def get_leaf_edge(self):
    ''' get a non-'tmp' edge for which `adj_nodes` reports exactly one entry

    Returns
    -------
    if exists, return a leaf edge. otherwise, return None.
    '''
    for candidate in self.edges:
        if len(self.adj_nodes(candidate)) != 1:
            continue
        if 'tmp' in self.edge_attr(candidate):
            continue
        return candidate
    return None
|
| 409 |
-
|
| 410 |
-
def get_nontmp_edge(self):
    ''' return the first edge whose attributes do not contain 'tmp', or None if there is none.
    '''
    non_tmp_edges = (each_edge for each_edge in self.edges
                     if 'tmp' not in self.edge_attr(each_edge))
    return next(non_tmp_edges, None)
|
| 415 |
-
|
| 416 |
-
def is_subhg(self, hg):
    ''' return whether this hypergraph is a subhypergraph of `hg`

    only the membership of node and edge names is compared.

    Returns
    -------
    True if every node and every edge of self is also in hg,
    False otherwise.
    '''
    nodes_contained = all(each_node in hg.nodes for each_node in self.nodes)
    if not nodes_contained:
        return False
    return all(each_edge in hg.edges for each_edge in self.edges)
|
| 431 |
-
|
| 432 |
-
def in_cycle(self, node, visited=None, parent='', root_node='') -> bool:
    ''' if `node` is in a cycle, then return True. otherwise, False.

    Parameters
    ----------
    node : str
        node in a hypergraph
    visited : list
        list of visited nodes, used for recursion
    parent : str
        parent node, used to eliminate a cycle consisting of two nodes and one edge.
    root_node : str
        node the search started from, used for recursion

    Returns
    -------
    bool
    '''
    # a call without a parent is a fresh search and starts with an empty list
    if visited is None or parent == '':
        visited = []
    if root_node == '':
        root_node = node
    visited.append(node)
    for neighbor in self.adj_nodes(node):
        if neighbor in visited:
            # coming back to the root through anything but the parent closes a cycle
            if neighbor != parent and neighbor == root_node:
                return True
            continue
        if self.in_cycle(neighbor, visited, node, root_node):
            return True
    return False
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
def draw(self, file_path=None, with_node=False, with_edge_name=False):
    ''' draw hypergraph

    Parameters
    ----------
    file_path : str
        if given, the drawing is rendered to this path.
    with_node : bool
        if True, every node is drawn and linked to the hyperedges that contain it.
        otherwise only nodes having 'ext_id' are drawn, and adjacent hyperedges are
        linked directly, once per unit of bond order of their common nodes.
    with_edge_name : bool
        if True, the hyperedge name is appended to its label.

    Returns
    -------
    graphviz.Graph
    '''
    import graphviz
    G = graphviz.Graph(format='png')
    dot_style = dict(label='', shape='circle', width='0.1', height='0.1', style='filled')
    for each_node in self.nodes:
        if 'ext_id' in self.node_attr(each_node):
            G.node(each_node, **dot_style, fillcolor='black')
        elif with_node:
            G.node(each_node, **dot_style, fillcolor='gray')

    drawn_pairs = []
    for each_edge in self.edges:
        attr = self.edge_attr(each_edge)
        extra = {}
        if attr.get('terminal', False):
            label = attr['symbol'].symbol
        elif attr.get('tmp', False):
            label = 'tmp'
        else:
            label = attr['symbol'].symbol
            extra = {'style': 'filled'}
        if with_edge_name:
            label = label + ', ' + each_edge
        G.node(each_edge, label=label, fontcolor='black', shape='square', **extra)

        if with_node:
            for each_node in self.nodes_in_edge(each_edge):
                G.edge(each_edge, each_node)
            continue

        for each_node in self.nodes_in_edge(each_edge):
            pair = set([each_node, each_edge])
            if 'ext_id' in self.node_attr(each_node) and pair not in drawn_pairs:
                G.edge(each_edge, each_node)
                drawn_pairs.append(pair)
        for each_other_edge in self.adj_nodes(each_edge):
            pair = set([each_edge, each_other_edge])
            if pair in drawn_pairs:
                continue
            num_bond = 0
            common_node_set = set(self.nodes_in_edge(each_edge)).intersection(
                set(self.nodes_in_edge(each_other_edge)))
            for shared_node in common_node_set:
                bond_type = self.node_attr(shared_node)['symbol'].bond_type
                if bond_type in [1, 2, 3]:
                    num_bond += bond_type
                elif bond_type in [12]:
                    num_bond += 1
                else:
                    raise NotImplementedError('unsupported bond type')
            for _ in range(num_bond):
                G.edge(each_edge, each_other_edge)
            drawn_pairs.append(pair)

    if file_path is not None:
        G.render(file_path, cleanup=True)
    return G
|
| 522 |
-
|
| 523 |
-
def is_dividable(self, node):
    ''' return True if removing `node` from a copy of the underlying graph leaves it disconnected.
    '''
    graph_wo_node = deepcopy(self.hg)
    graph_wo_node.remove_node(node)
    still_connected = nx.is_connected(graph_wo_node)
    return not still_connected
|
| 527 |
-
|
| 528 |
-
def divide(self, node):
    ''' split the hypergraph at `node` into one subhypergraph per connected component.

    `node` is removed from a copy (keeping its hyperedges), and for each connected
    component of what remains a subhypergraph is built with `get_subhg` from `node`,
    the component members whose name starts with 'bond_' (nodes), and the members
    whose name starts with 'e' (hyperedges).

    Returns
    -------
    list of Hypergraph
    '''
    hg_wo_node = deepcopy(self)
    hg_wo_node.remove_node(node, remove_connected_edges=False)

    subhg_list = []
    for each_component in nx.connected_components(hg_wo_node.hg):
        node_list = [node] + [member for member in each_component
                              if member.startswith('bond_')]
        edge_list = [member for member in each_component
                     if member.startswith('e')]
        subhg_list.append(self.get_subhg(node_list, edge_list))
    return subhg_list
|
| 544 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/io/__init__.py
DELETED
|
@@ -1,20 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Jan 1 2018"
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/io/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (669 Bytes)
|
|
|
graph_grammar/io/__pycache__/smi.cpython-310.pyc
DELETED
|
Binary file (12.9 kB)
|
|
|
graph_grammar/io/smi.py
DELETED
|
@@ -1,559 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Jan 12 2018"
|
| 20 |
-
|
| 21 |
-
from copy import deepcopy
|
| 22 |
-
from rdkit import Chem
|
| 23 |
-
from rdkit import RDLogger
|
| 24 |
-
import networkx as nx
|
| 25 |
-
import numpy as np
|
| 26 |
-
from ..hypergraph import Hypergraph
|
| 27 |
-
from ..graph_grammar.symbols import TSymbol, BondSymbol
|
| 28 |
-
|
| 29 |
-
# suppress warnings
|
| 30 |
-
lg = RDLogger.logger()
|
| 31 |
-
lg.setLevel(RDLogger.CRITICAL)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
class HGGen(object):
    """
    load .smi file and yield a hypergraph.

    Attributes
    ----------
    path_to_file : str
        path to .smi file
    kekulize : bool
        kekulize or not
    add_Hs : bool
        add implicit hydrogens to the molecule or not.
    all_single : bool
        if True, all multiple bonds are summarized into a single bond with some attributes
        (stored on the instance; not read by the methods of this class)

    Yields
    ------
    Hypergraph
    """
    def __init__(self, path_to_file, kekulize=True, add_Hs=False, all_single=True):
        # line number of the molecule that will be read next, used in error messages
        self.num_line = 1
        self.mol_gen = Chem.SmilesMolSupplier(path_to_file, titleLine=False)
        self.kekulize = kekulize
        self.add_Hs = add_Hs
        self.all_single = all_single

    def __iter__(self):
        return self

    def __next__(self):
        """ read the next molecule and convert it into a hypergraph.

        Raises
        ------
        ValueError
            if the supplier yields None for the current line (the SMILES could not be parsed).
        """
        # parse errors are reported, not skipped
        each_mol = next(self.mol_gen)
        if each_mol is None:
            raise ValueError(f'incorrect smiles in line {self.num_line}')
        self.num_line += 1
        return mol_to_hg(each_mol, self.kekulize, self.add_Hs)
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def mol_to_bipartite(mol, kekulize):
    """
    get a bipartite representation of a molecule.

    atoms become nodes named "atom_<idx>" and bonds become nodes named "bond_<idx>";
    each bond node is linked to its begin atom and its end atom.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        molecule object
    kekulize : bool
        kekulize or not

    Returns
    -------
    nx.Graph
        a bipartite graph representing which bond is connected to which atoms.

    Raises
    ------
    KeyError
        propagated from `standardize_stereo`; the molecule's SMILES is printed first.
    """
    try:
        mol = standardize_stereo(mol)
    except KeyError:
        print(Chem.MolToSmiles(mol))
        # re-raise the original exception so its message and traceback are kept
        raise

    if kekulize:
        Chem.Kekulize(mol)

    bipartite_g = nx.Graph()
    for each_atom in mol.GetAtoms():
        bipartite_g.add_node(f"atom_{each_atom.GetIdx()}",
                             atom_attr=atom_attr(each_atom, kekulize))

    for each_bond in mol.GetBonds():
        bond_idx = each_bond.GetIdx()
        bipartite_g.add_node(
            f"bond_{bond_idx}",
            bond_attr=bond_attr(each_bond, kekulize))
        bipartite_g.add_edge(
            f"atom_{each_bond.GetBeginAtomIdx()}",
            f"bond_{bond_idx}")
        bipartite_g.add_edge(
            f"atom_{each_bond.GetEndAtomIdx()}",
            f"bond_{bond_idx}")
    return bipartite_g
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
def mol_to_hg(mol, kekulize, add_Hs):
    """
    get a hypergraph representation of a molecule.

    each bond becomes a node and each atom becomes a hyperedge over the bonds
    attached to it, as given by the bipartite graph of `mol_to_bipartite`.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        molecule object
    kekulize : bool
        kekulize or not
    add_Hs : bool
        add implicit hydrogens to the molecule or not.

    Returns
    -------
    Hypergraph
    """
    if add_Hs:
        mol = Chem.AddHs(mol)

    if kekulize:
        Chem.Kekulize(mol)

    bipartite_g = mol_to_bipartite(mol, kekulize)
    hg = Hypergraph()
    atom_names = [each_node for each_node in bipartite_g.nodes()
                  if each_node.startswith('atom_')]
    for each_atom in atom_names:
        node_set = set()
        for each_bond in bipartite_g.adj[each_atom]:
            hg.add_node(each_bond,
                        attr_dict=bipartite_g.nodes[each_bond]['bond_attr'])
            node_set.add(each_bond)
        hg.add_edge(node_set,
                    attr_dict=bipartite_g.nodes[each_atom]['atom_attr'])
    return hg
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
def hg_to_mol(hg, verbose=False):
    """ convert a hypergraph into Mol object

    each hyperedge becomes an atom and each node becomes a bond between the two
    hyperedges adjacent to it.

    Parameters
    ----------
    hg : Hypergraph
    verbose : bool
        if True, also return whether `set_stereo` succeeded.

    Returns
    -------
    mol : rdkit.Chem.rdchem.Mol
        returned alone when `verbose` is False.
    (mol, is_stereo) : tuple
        returned when `verbose` is True.

    Raises
    ------
    ValueError
        if a node has a bond type other than 0-3 or 12.
    RuntimeError
        if the assembled molecule cannot be round-tripped through SMILES.
    """
    mol = Chem.RWMol()
    atom_dict = {}
    bond_set = set()
    for each_edge in hg.edges:
        symbol = hg.edge_attr(each_edge)['symbol']
        atom = Chem.Atom(symbol.symbol)
        atom.SetNumExplicitHs(symbol.num_explicit_Hs)
        atom.SetFormalCharge(symbol.formal_charge)
        atom.SetChiralTag(Chem.rdchem.ChiralType.values[symbol.chirality])
        atom_idx = mol.AddAtom(atom)
        atom_dict[each_edge] = atom_idx

    for each_node in hg.nodes:
        edge_1, edge_2 = hg.adj_edges(each_node)
        # unordered pair, so the same two atoms are bonded only once
        pair = frozenset((edge_1, edge_2))
        if pair not in bond_set:
            bond_symbol = hg.node_attr(each_node)['symbol']
            if bond_symbol.bond_type <= 3:
                num_bond = bond_symbol.bond_type
            elif bond_symbol.bond_type == 12:
                num_bond = 1
            else:
                raise ValueError(f'too many bonds; {bond_symbol.bond_type}')
            _ = mol.AddBond(atom_dict[edge_1],
                            atom_dict[edge_2],
                            order=Chem.rdchem.BondType.values[num_bond])
            bond_idx = mol.GetBondBetweenAtoms(atom_dict[edge_1], atom_dict[edge_2]).GetIdx()

            # stereo
            mol.GetBondWithIdx(bond_idx).SetStereo(
                Chem.rdchem.BondStereo.values[bond_symbol.stereo])
            bond_set.add(pair)
    mol.UpdatePropertyCache()
    mol = mol.GetMol()
    not_stereo_mol = deepcopy(mol)
    if Chem.MolFromSmiles(Chem.MolToSmiles(not_stereo_mol)) is None:
        raise RuntimeError('no valid molecule was obtained.')
    try:
        mol = set_stereo(mol)
        is_stereo = True
    except Exception:
        # best effort: report the failure and continue without stereo information
        import traceback
        traceback.print_exc()
        is_stereo = False
    mol_tmp = deepcopy(mol)
    Chem.SetAromaticity(mol_tmp)
    if Chem.MolFromSmiles(Chem.MolToSmiles(mol_tmp)) is not None:
        mol = mol_tmp
    elif Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is None:
        # neither variant round-trips; fall back to the molecule saved before set_stereo
        mol = not_stereo_mol
    mol.UpdatePropertyCache()
    Chem.GetSymmSSSR(mol)
    mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    if verbose:
        return mol, is_stereo
    return mol
|
| 226 |
-
|
| 227 |
-
def hgs_to_mols(hg_list, ignore_error=False):
    """ convert a list of hypergraphs into a list of Mol objects with `hg_to_mol`.

    Parameters
    ----------
    hg_list : list of Hypergraph
    ignore_error : bool
        if True, a hypergraph whose conversion raises is represented by None;
        otherwise the exception propagates.

    Returns
    -------
    mol_list : list
    """
    if not ignore_error:
        return [hg_to_mol(each_hg) for each_hg in hg_list]

    mol_list = []
    for each_hg in hg_list:
        try:
            mol = hg_to_mol(each_hg)
        except Exception:
            # best effort; KeyboardInterrupt / SystemExit still propagate
            mol = None
        mol_list.append(mol)
    return mol_list
|
| 239 |
-
|
| 240 |
-
def hgs_to_smiles(hg_list, ignore_error=False):
    """ convert a list of hypergraphs into a list of SMILES strings.

    each molecule obtained from `hgs_to_mols` is written to SMILES, parsed back,
    and written again. an entry that fails at any of these steps becomes None.

    Parameters
    ----------
    hg_list : list of Hypergraph
    ignore_error : bool
        forwarded to `hgs_to_mols`.

    Returns
    -------
    smiles_list : list
    """
    mol_list = hgs_to_mols(hg_list, ignore_error)
    smiles_list = []
    for each_mol in mol_list:
        try:
            smiles = Chem.MolToSmiles(
                Chem.MolFromSmiles(
                    Chem.MolToSmiles(each_mol)))
        except Exception:
            # best effort; KeyboardInterrupt / SystemExit still propagate
            smiles = None
        smiles_list.append(smiles)
    return smiles_list
|
| 253 |
-
|
| 254 |
-
def atom_attr(atom, kekulize):
    """
    get atom's attributes

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
    kekulize : bool
        kekulize or not

    Returns
    -------
    atom_attr : dict
        "terminal" : bool
            always True.
        "is_in_ring" : bool
            the atom is in a ring or not.
        "symbol" : TSymbol
            built with degree 0 and the atom's symbol, number of explicit Hs,
            formal charge, and chiral tag; `is_aromatic` is forced to False when
            `kekulize` is True and taken from the atom otherwise.
    """
    is_aromatic = False if kekulize else atom.GetIsAromatic()
    return {'terminal': True,
            'is_in_ring': atom.IsInRing(),
            'symbol': TSymbol(degree=0,
                              is_aromatic=is_aromatic,
                              symbol=atom.GetSymbol(),
                              num_explicit_Hs=atom.GetNumExplicitHs(),
                              formal_charge=atom.GetFormalCharge(),
                              chirality=atom.GetChiralTag().real
            )}
|
| 294 |
-
|
| 295 |
-
def bond_attr(bond, kekulize):
    """
    get bond's attributes

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
    kekulize : bool
        kekulize or not

    Returns
    -------
    bond_attr : dict
        "symbol" : BondSymbol
            holds `is_aromatic`, `bond_type`, and `stereo`. when `kekulize` is True,
            `is_aromatic` is False and an aromatic bond type (12) is stored as 1.
        "is_in_ring" : bool
            the bond is in a ring or not.

        `bond_type` is the integer value of the rdkit bond type:
        {0: rdkit.Chem.rdchem.BondType.UNSPECIFIED,
         1: rdkit.Chem.rdchem.BondType.SINGLE,
         2: rdkit.Chem.rdchem.BondType.DOUBLE,
         3: rdkit.Chem.rdchem.BondType.TRIPLE,
         4: rdkit.Chem.rdchem.BondType.QUADRUPLE,
         5: rdkit.Chem.rdchem.BondType.QUINTUPLE,
         6: rdkit.Chem.rdchem.BondType.HEXTUPLE,
         7: rdkit.Chem.rdchem.BondType.ONEANDAHALF,
         8: rdkit.Chem.rdchem.BondType.TWOANDAHALF,
         9: rdkit.Chem.rdchem.BondType.THREEANDAHALF,
         10: rdkit.Chem.rdchem.BondType.FOURANDAHALF,
         11: rdkit.Chem.rdchem.BondType.FIVEANDAHALF,
         12: rdkit.Chem.rdchem.BondType.AROMATIC,
         13: rdkit.Chem.rdchem.BondType.IONIC,
         14: rdkit.Chem.rdchem.BondType.HYDROGEN,
         15: rdkit.Chem.rdchem.BondType.THREECENTER,
         16: rdkit.Chem.rdchem.BondType.DATIVEONE,
         17: rdkit.Chem.rdchem.BondType.DATIVE,
         18: rdkit.Chem.rdchem.BondType.DATIVEL,
         19: rdkit.Chem.rdchem.BondType.DATIVER,
         20: rdkit.Chem.rdchem.BondType.OTHER,
         21: rdkit.Chem.rdchem.BondType.ZERO}
    """
    bond_type = bond.GetBondType().real
    if kekulize:
        is_aromatic = False
        if bond_type == 12:
            bond_type = 1
    else:
        is_aromatic = bond.GetIsAromatic()
    return {'symbol': BondSymbol(is_aromatic=is_aromatic,
                                 bond_type=bond_type,
                                 stereo=int(bond.GetStereo())),
            'is_in_ring': bond.IsInRing()}
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
def standardize_stereo(mol):
    '''
    make the stereo atoms of every Z/E bond canonical with respect to `_CIPRank`.

    for each bond whose stereo is 2 (Z) or 3 (E), the stereo atom on each side is
    compared with the other neighbor of that side; if the other neighbor has the
    smaller `_CIPRank`, it becomes the stereo atom. the Z/E flag is flipped when
    exactly one side is exchanged, and bond directions are rewritten with
    `safe_set_bond_dir`. a missing other neighbor is treated as rank infinity.

    bond direction codes:
     0: rdkit.Chem.rdchem.BondDir.NONE,
     1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
     2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
     3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
     4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol

    Returns
    -------
    mol : rdkit.Chem.rdchem.Mol

    Raises
    ------
    RuntimeError
        if a stereo atom and the other neighbor on its side have equal ranks.
    '''
    # mol = Chem.AddHs(mol) # this removes CIPRank !!!
    for each_bond in mol.GetBonds():
        if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
            begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
            end_stereo_atom_idx = each_bond.GetEndAtomIdx()
            atom_idx_1 = each_bond.GetStereoAtoms()[0]
            atom_idx_2 = each_bond.GetStereoAtoms()[1]
            # work out which stereo atom sits on the begin side of the bond
            if mol.GetBondBetweenAtoms(atom_idx_1, begin_stereo_atom_idx):
                begin_atom_idx = atom_idx_1
                end_atom_idx = atom_idx_2
            else:
                begin_atom_idx = atom_idx_2
                end_atom_idx = atom_idx_1

            begin_another_atom_idx = None
            assert len(mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()) <= 3
            for each_neighbor in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors():
                each_neighbor_idx = each_neighbor.GetIdx()
                if each_neighbor_idx not in [end_stereo_atom_idx, begin_atom_idx]:
                    begin_another_atom_idx = each_neighbor_idx

            end_another_atom_idx = None
            assert len(mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()) <= 3
            for each_neighbor in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors():
                each_neighbor_idx = each_neighbor.GetIdx()
                if each_neighbor_idx not in [begin_stereo_atom_idx, end_atom_idx]:
                    end_another_atom_idx = each_neighbor_idx

            # relationship between begin_atom_idx and end_atom_idx is encoded in GetStereo
            begin_atom_rank = int(mol.GetAtomWithIdx(begin_atom_idx).GetProp('_CIPRank'))
            end_atom_rank = int(mol.GetAtomWithIdx(end_atom_idx).GetProp('_CIPRank'))
            try:
                begin_another_atom_rank = int(mol.GetAtomWithIdx(begin_another_atom_idx).GetProp('_CIPRank'))
            except Exception:
                # no other neighbor (index is None) or no rank: never preferred
                begin_another_atom_rank = np.inf
            try:
                end_another_atom_rank = int(mol.GetAtomWithIdx(end_another_atom_idx).GetProp('_CIPRank'))
            except Exception:
                # no other neighbor (index is None) or no rank: never preferred
                end_another_atom_rank = np.inf
            if begin_atom_rank < begin_another_atom_rank\
               and end_atom_rank < end_another_atom_rank:
                pass
            elif begin_atom_rank < begin_another_atom_rank\
                 and end_atom_rank > end_another_atom_rank:
                # (begin_atom_idx +) end_another_atom_idx should be in StereoAtoms
                if each_bond.GetStereo() == 2:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
                elif each_bond.GetStereo() == 3:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
                else:
                    raise ValueError('unexpected bond stereo')
                each_bond.SetStereoAtoms(begin_atom_idx, end_another_atom_idx)
            elif begin_atom_rank > begin_another_atom_rank\
                 and end_atom_rank < end_another_atom_rank:
                # (end_atom_idx +) begin_another_atom_idx should be in StereoAtoms
                if each_bond.GetStereo() == 2:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
                elif each_bond.GetStereo() == 3:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
                else:
                    raise ValueError('unexpected bond stereo')
                each_bond.SetStereoAtoms(begin_another_atom_idx, end_atom_idx)
            elif begin_atom_rank > begin_another_atom_rank\
                 and end_atom_rank > end_another_atom_rank:
                # begin_another_atom_idx + end_another_atom_idx should be in StereoAtoms
                if each_bond.GetStereo() == 2:
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
                elif each_bond.GetStereo() == 3:
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
                else:
                    raise ValueError('unexpected bond stereo')
                each_bond.SetStereoAtoms(begin_another_atom_idx, end_another_atom_idx)
            else:
                raise RuntimeError('cannot order stereo atoms: equal _CIPRank values')
    return mol
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
def set_stereo(mol):
    ''' rebuild `mol` from its SMILES and re-encode double-bond stereo as bond directions

    The molecule is round-tripped through SMILES, the stereo flag of every bond
    of `mol` is copied onto the matching bond of the rebuilt molecule, and every
    bond whose stereo value is 2 or 3 then receives bond directions and stereo atoms.

    The integer codes passed to `safe_set_bond_dir` index `Chem.rdchem.BondDir.values`:

    0: rdkit.Chem.rdchem.BondDir.NONE,
    1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
    2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
    3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
    4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,

    Parameters
    ----------
    mol : molecule object (presumably rdkit.Chem.Mol, kekulized -- TODO confirm against callers)
        `Chem.SetAromaticity` is applied to this object when the first
        substructure match fails.

    Returns
    -------
    the molecule rebuilt from the SMILES of `mol` (a different object from the input)

    Raises
    ------
    ValueError
        if the rebuilt molecule cannot be matched to `mol`,
        or if a bond selected as stereo reports a stereo value other than 2 or 3.
    '''
    _mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    Chem.Kekulize(_mol, True)
    substruct_match = mol.GetSubstructMatch(_mol)
    if not substruct_match:
        # mol and _mol are kekulized.
        # sometimes, the order of '=' and '-' changes, which causes mol and _mol not matched.
        Chem.SetAromaticity(mol)
        Chem.SetAromaticity(_mol)
        substruct_match = mol.GetSubstructMatch(_mol)
    try:
        # mol to _mol
        atom_match = {substruct_match[_mol_atom_idx]: _mol_atom_idx
                      for _mol_atom_idx in range(_mol.GetNumAtoms())}
    except IndexError as err:
        # an empty match is the only failure expected here; anything else
        # (e.g. KeyboardInterrupt) must not be converted into a ValueError.
        raise ValueError('two molecules obtained from the same data do not match.') from err

    # copy the stereo flag of each bond of `mol` onto its counterpart in `_mol`
    for each_bond in mol.GetBonds():
        begin_atom_idx = each_bond.GetBeginAtomIdx()
        end_atom_idx = each_bond.GetEndAtomIdx()
        _bond = _mol.GetBondBetweenAtoms(atom_match[begin_atom_idx], atom_match[end_atom_idx])
        _bond.SetStereo(each_bond.GetStereo())

    mol = _mol
    for each_bond in mol.GetBonds():
        if int(each_bond.GetStereo()) in [2, 3]:  # 2=Z (same side), 3=E
            begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
            end_stereo_atom_idx = each_bond.GetEndAtomIdx()
            # neighbors of each end of the bond, excluding the other end
            begin_atom_idx_set = {each_neighbor.GetIdx()
                                  for each_neighbor
                                  in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()
                                  if each_neighbor.GetIdx() != end_stereo_atom_idx}
            end_atom_idx_set = {each_neighbor.GetIdx()
                                for each_neighbor
                                in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()
                                if each_neighbor.GetIdx() != begin_stereo_atom_idx}
            # an end without any other neighbor cannot carry stereo: clear the flag
            if not begin_atom_idx_set:
                each_bond.SetStereo(Chem.rdchem.BondStereo(0))
                continue
            if not end_atom_idx_set:
                each_bond.SetStereo(Chem.rdchem.BondStereo(0))
                continue

            # choose the reference neighbor on each side; with two candidates,
            # the one with the smaller `_CIPRank` is used.
            # NOTE(review): with more than two candidates neither branch runs and
            # the index from a previous iteration (if any) is reused -- confirm
            # that this cannot happen for the molecules handled here.
            if len(begin_atom_idx_set) == 1:
                begin_atom_idx = begin_atom_idx_set.pop()
            elif len(begin_atom_idx_set) == 2:
                atom_idx_1 = begin_atom_idx_set.pop()
                atom_idx_2 = begin_atom_idx_set.pop()
                if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
                    begin_atom_idx = atom_idx_1
                else:
                    begin_atom_idx = atom_idx_2
            if len(end_atom_idx_set) == 1:
                end_atom_idx = end_atom_idx_set.pop()
            elif len(end_atom_idx_set) == 2:
                atom_idx_1 = end_atom_idx_set.pop()
                atom_idx_2 = end_atom_idx_set.pop()
                if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
                    end_atom_idx = atom_idx_1
                else:
                    end_atom_idx = atom_idx_2

            if each_bond.GetStereo() == 2:  # same side
                mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
                each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
            elif each_bond.GetStereo() == 3:  # opposite side
                mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
                each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
            else:
                raise ValueError
    return mol
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
def safe_set_bond_dir(mol, atom_idx_1, atom_idx_2, bond_dir_val):
    ''' set the direction of the bond between two atoms, doing nothing if an index is missing

    Parameters
    ----------
    mol : molecule object providing `GetBondBetweenAtoms`
    atom_idx_1 : int or None
    atom_idx_2 : int or None
    bond_dir_val : int
        index into `Chem.rdchem.BondDir.values`

    Returns
    -------
    the same `mol` object that was passed in
    '''
    both_given = atom_idx_1 is not None and atom_idx_2 is not None
    if both_given:
        target_bond = mol.GetBondBetweenAtoms(atom_idx_1, atom_idx_2)
        target_bond.SetBondDir(Chem.rdchem.BondDir.values[bond_dir_val])
    return mol
|
| 559 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/nn/__init__.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
# -*- coding:utf-8 -*-
|
| 2 |
-
# Rhizome
|
| 3 |
-
# Version beta 0.0, August 2023
|
| 4 |
-
# Property of IBM Research, Accelerated Discovery
|
| 5 |
-
#
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 9 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 10 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
| 11 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/nn/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (508 Bytes)
|
|
|
graph_grammar/nn/__pycache__/decoder.cpython-310.pyc
DELETED
|
Binary file (3.98 kB)
|
|
|
graph_grammar/nn/__pycache__/encoder.cpython-310.pyc
DELETED
|
Binary file (5.38 kB)
|
|
|
graph_grammar/nn/dataset.py
DELETED
|
@@ -1,121 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Apr 18 2018"
|
| 20 |
-
|
| 21 |
-
from torch.utils.data import Dataset, DataLoader
|
| 22 |
-
import torch
|
| 23 |
-
import numpy as np
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def left_padding(sentence_list, max_len, pad_idx=-1, inverse=False):
    ''' pad left

    Parameters
    ----------
    sentence_list : list of sequences of integers
        any sequence type is accepted (list, tuple, ...); may be empty.
    max_len : int
        maximum length of sentences.
        if a sentence is shorter than `max_len`, its left part is padded.
    pad_idx : int
        integer for padding
    inverse : bool
        if True, the sequence is inversed.

    Returns
    -------
    List of torch.LongTensor
        each sentence is left-padded.
        an empty `sentence_list` yields an empty list.

    Raises
    ------
    ValueError
        if some sentence is longer than `max_len`.
    '''
    # normalize to lists so that concatenation with the padding list works
    # for tuples and other sequence types as well
    sentence_list = [list(each_sen) for each_sen in sentence_list]

    # `default=0` keeps an empty input from crashing inside `max`
    max_in_list = max((len(each_sen) for each_sen in sentence_list), default=0)

    if max_in_list > max_len:
        raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))

    if inverse:
        sentence_list = [each_sen[::-1] for each_sen in sentence_list]
    return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen)
            for each_sen in sentence_list]
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def right_padding(sentence_list, max_len, pad_idx=-1):
    ''' pad right

    Parameters
    ----------
    sentence_list : list of sequences of integers
        any sequence type is accepted (list, tuple, ...); may be empty.
    max_len : int
        maximum length of sentences.
        if a sentence is shorter than `max_len`, its right part is padded.
    pad_idx : int
        integer for padding

    Returns
    -------
    List of torch.LongTensor
        each sentence is right-padded.
        an empty `sentence_list` yields an empty list.

    Raises
    ------
    ValueError
        if some sentence is longer than `max_len`.
    '''
    # normalize to lists so that concatenation with the padding list works
    # for tuples and other sequence types as well
    sentence_list = [list(each_sen) for each_sen in sentence_list]

    # `default=0` keeps an empty input from crashing inside `max`
    max_in_list = max((len(each_sen) for each_sen in sentence_list), default=0)
    if max_in_list > max_len:
        raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))

    return [torch.LongTensor(each_sen + [pad_idx] * (max_len - len(each_sen)))
            for each_sen in sentence_list]
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class HRGDataset(Dataset):

    '''
    A class of HRG data

    Each item is a pair (left-padded sequence, right-padded sequence) of
    production rule indices, followed by the target value as `np.float32`
    when `target_val_list` was given.

    Parameters
    ----------
    hrg : object providing `num_prod_rule`
    prod_rule_seq_list : list of sequences of integers
    max_len : int
        length to which every sequence is padded.
    target_val_list : list of numbers, optional
        one target value per sequence.
    inversed_input : bool
        if True, the left-padded sequences are inversed.

    Raises
    ------
    ValueError
        if `target_val_list` is given and its length differs from that of `prod_rule_seq_list`.
    '''

    def __init__(self, hrg, prod_rule_seq_list, max_len, target_val_list=None, inversed_input=False):
        # validate before doing the (more expensive) padding work
        if target_val_list is not None:
            if len(prod_rule_seq_list) != len(target_val_list):
                raise ValueError(f'prod_rule_seq_list and target_val_list have inconsistent lengths: {len(prod_rule_seq_list)}, {len(target_val_list)}')

        self.hrg = hrg
        self.left_prod_rule_seq_list = left_padding(prod_rule_seq_list,
                                                    max_len,
                                                    inverse=inversed_input)

        self.right_prod_rule_seq_list = right_padding(prod_rule_seq_list, max_len)
        self.inversed_input = inversed_input
        # misspelled historical attribute name, kept so existing readers still work
        self.inserved_input = inversed_input
        self.target_val_list = target_val_list

    def __len__(self):
        return len(self.left_prod_rule_seq_list)

    def __getitem__(self, idx):
        if self.target_val_list is not None:
            return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx], np.float32(self.target_val_list[idx])
        else:
            return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx]

    @property
    def vocab_size(self):
        ''' the number of production rules of `hrg` '''
        return self.hrg.num_prod_rule
|
| 111 |
-
|
| 112 |
-
def batch_padding(each_batch, batch_size, padding_idx):
    ''' pad the first two tensors of a mini-batch with rows of `padding_idx` up to `batch_size` rows

    Parameters
    ----------
    each_batch : list of Tensors
        the first two entries are 2-d tensors; they are replaced in place
        in the list by their padded versions.
    batch_size : int
        number of rows each of the two tensors should have after padding.
    padding_idx : int
        value used to fill the added rows.

    Returns
    -------
    each_batch : list of Tensors
        the same list object that was passed in.
    num_pad : int
        number of rows added to the first tensor.
    '''
    num_pad = batch_size - len(each_batch[0])
    if num_pad:
        for each_idx in (0, 1):
            each_tensor = each_batch[each_idx]
            # create the padding on the device of the data so that `torch.cat`
            # also works for batches that do not live on the CPU
            pad_rows = torch.full((batch_size - len(each_tensor), len(each_tensor[0])),
                                  padding_idx,
                                  dtype=torch.int64,
                                  device=each_tensor.device)
            each_batch[each_idx] = torch.cat([each_tensor, pad_rows], dim=0)
    return each_batch, num_pad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/nn/decoder.py
DELETED
|
@@ -1,158 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Aug 9 2018"
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
import abc
|
| 23 |
-
import numpy as np
|
| 24 |
-
import torch
|
| 25 |
-
from torch import nn
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class DecoderBase(nn.Module):

    ''' base class of step-wise decoders; hidden states are kept in `hidden_dict`
    '''

    def __init__(self):
        super().__init__()
        self.hidden_dict = {}

    @abc.abstractmethod
    def forward_one_step(self, tgt_emb_in):
        ''' advance the decoder by one step

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, hidden_dim)
            this base implementation returns None.
        '''
        return None

    @abc.abstractmethod
    def init_hidden(self):
        ''' reset the hidden states; nothing is done in this base implementation
        '''
        return None

    @abc.abstractmethod
    def feed_hidden(self, hidden_dict_0):
        ''' overwrite index 0 of every stored hidden state

        Parameters
        ----------
        hidden_dict_0 : dict
            for each key of `hidden_dict`, the value written to index 0 of that state.
        '''
        for state_name, state in self.hidden_dict.items():
            state[0] = hidden_dict_0[state_name]
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
class GRUDecoder(DecoderBase):

    ''' unidirectional GRU decoder that is advanced one step per call
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 dropout: float, batch_size: int, use_gpu: bool,
                 no_dropout=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        # `no_dropout` switches dropout off without touching the stored rate
        effective_dropout = 0 if no_dropout else self.dropout
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=False,
                            dropout=effective_dropout)
        if self.use_gpu:
            self.model.cuda()
        self.init_hidden()

    def init_hidden(self):
        ''' reset the hidden state `h` to zeros of shape (num_layers, batch_size, hidden_dim)
        '''
        state_shape = (self.num_layers, self.batch_size, self.hidden_dim)
        zero_state = torch.zeros(state_shape, requires_grad=False)
        self.hidden_dict['h'] = zero_state.cuda() if self.use_gpu else zero_state

    def forward_one_step(self, tgt_emb_in):
        ''' advance the GRU by one step and store the new hidden state

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor
            output of the GRU for the input viewed as (batch_size, 1, input_dim).
        '''
        step_input = tgt_emb_in.view(self.batch_size, 1, -1)
        tgt_emb_out, next_hidden = self.model(step_input, self.hidden_dict['h'])
        self.hidden_dict['h'] = next_hidden
        return tgt_emb_out
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
class LSTMDecoder(DecoderBase):

    ''' unidirectional LSTM decoder that is advanced one step per call

    The hidden state and the cell state are kept in `hidden_dict` under the keys 'h' and 'c'.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 dropout: float, batch_size: int, use_gpu: bool,
                 no_dropout=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.model = nn.LSTM(input_size=self.input_dim,
                             hidden_size=self.hidden_dim,
                             num_layers=self.num_layers,
                             batch_first=True,
                             bidirectional=False,
                             dropout=self.dropout if not no_dropout else 0)
        if self.use_gpu:
            self.model.cuda()
        self.init_hidden()

    def init_hidden(self):
        ''' reset 'h' and 'c' to zeros of shape (num_layers, batch_size, hidden_dim)
        '''
        self.hidden_dict['h'] = torch.zeros((self.num_layers,
                                             self.batch_size,
                                             self.hidden_dim),
                                            requires_grad=False)
        self.hidden_dict['c'] = torch.zeros((self.num_layers,
                                             self.batch_size,
                                             self.hidden_dim),
                                            requires_grad=False)
        if self.use_gpu:
            for each_hidden in self.hidden_dict:
                self.hidden_dict[each_hidden] = self.hidden_dict[each_hidden].cuda()

    def forward_one_step(self, tgt_emb_in):
        ''' one-step forward model

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, 1, hidden_dim)
            output of the LSTM for the input viewed as (batch_size, 1, input_dim).
        '''
        # nn.LSTM takes and returns its state as one (h, c) tuple, not as two
        # separate positional arguments / return values
        tgt_hidden_out, (self.hidden_dict['h'], self.hidden_dict['c']) \
            = self.model(tgt_emb_in.view(self.batch_size, 1, -1),
                         (self.hidden_dict['h'], self.hidden_dict['c']))
        return tgt_hidden_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/nn/encoder.py
DELETED
|
@@ -1,199 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Aug 9 2018"
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
import abc
|
| 23 |
-
import numpy as np
|
| 24 |
-
import torch
|
| 25 |
-
import torch.nn.functional as F
|
| 26 |
-
from torch import nn
|
| 27 |
-
from typing import List
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
class EncoderBase(nn.Module):

    ''' base class of sequence encoders
    '''

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def forward(self, in_seq):
        ''' encode a mini-batch of embedded sequences

        Parameters
        ----------
        in_seq : Variable, shape (batch_size, max_len, input_dim)

        Returns
        -------
        hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
            this base implementation returns None.
        '''
        return None

    @abc.abstractmethod
    def init_hidden(self):
        ''' reset the hidden states; nothing is done in this base implementation
        '''
        return None
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
class GRUEncoder(EncoderBase):

    ''' (optionally bidirectional) GRU encoder whose hidden state `h0` is kept between calls
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
                 no_dropout=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        # `no_dropout` switches dropout off without touching the stored rate
        effective_dropout = 0 if no_dropout else self.dropout
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=self.bidirectional,
                            dropout=effective_dropout)
        if self.use_gpu:
            self.model.cuda()
        self.init_hidden()

    def init_hidden(self):
        ''' reset `h0` to zeros of shape ((1 + bidirectional) * num_layers, batch_size, hidden_dim)
        '''
        state_shape = ((self.bidirectional + 1) * self.num_layers,
                       self.batch_size,
                       self.hidden_dim)
        zero_state = torch.zeros(state_shape, requires_grad=False)
        self.h0 = zero_state.cuda() if self.use_gpu else zero_state

    def forward(self, in_seq_emb):
        ''' encode a mini-batch; the final hidden state replaces `h0`

        Parameters
        ----------
        in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)

        Returns
        -------
        hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
        '''
        seq_len = in_seq_emb.size(1)
        rnn_out, last_hidden = self.model(in_seq_emb, self.h0)
        self.h0 = last_hidden
        # split the feature axis of the output into (direction, hidden_dim)
        return rnn_out.view(self.batch_size,
                            seq_len,
                            1 + self.bidirectional,
                            self.hidden_dim)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
class LSTMEncoder(EncoderBase):

    ''' (optionally bidirectional) LSTM encoder whose states `h0` and `c0` are kept between calls
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
                 no_dropout=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        # `no_dropout` switches dropout off without touching the stored rate
        effective_dropout = 0 if no_dropout else self.dropout
        self.model = nn.LSTM(input_size=self.input_dim,
                             hidden_size=self.hidden_dim,
                             num_layers=self.num_layers,
                             batch_first=True,
                             bidirectional=self.bidirectional,
                             dropout=effective_dropout)
        if self.use_gpu:
            self.model.cuda()
        self.init_hidden()

    def init_hidden(self):
        ''' reset `h0` and `c0` to zeros of shape ((1 + bidirectional) * num_layers, batch_size, hidden_dim)
        '''
        state_shape = ((self.bidirectional + 1) * self.num_layers,
                       self.batch_size,
                       self.hidden_dim)
        zero_h = torch.zeros(state_shape, requires_grad=False)
        zero_c = torch.zeros(state_shape, requires_grad=False)
        self.h0 = zero_h.cuda() if self.use_gpu else zero_h
        self.c0 = zero_c.cuda() if self.use_gpu else zero_c

    def forward(self, in_seq_emb):
        ''' encode a mini-batch; the final states replace `h0` and `c0`

        Parameters
        ----------
        in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)

        Returns
        -------
        hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
        '''
        seq_len = in_seq_emb.size(1)
        rnn_out, final_state = self.model(in_seq_emb, (self.h0, self.c0))
        self.h0, self.c0 = final_state
        # split the feature axis of the output into (direction, hidden_dim)
        return rnn_out.view(self.batch_size,
                            seq_len,
                            1 + self.bidirectional,
                            self.hidden_dim)
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
class FullConnectedEncoder(EncoderBase):

    ''' encoder made of a stack of linear layers, each followed by ReLU, applied to the flattened sequence

    `batch_size` is accepted by the constructor but not stored;
    the batch size is read from the input of `forward`.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, max_len: int, hidden_dim_list: List[int],
                 batch_size: int, use_gpu: bool):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.max_len = max_len
        self.hidden_dim_list = hidden_dim_list
        self.use_gpu = use_gpu
        # layer widths: flattened input -> hidden layers -> output
        layer_dims = [input_dim * max_len, *hidden_dim_list, hidden_dim]
        self.linear_list = nn.ModuleList(
            [nn.Linear(dim_in, dim_out)
             for dim_in, dim_out in zip(layer_dims[:-1], layer_dims[1:])])

    def forward(self, in_seq_emb):
        ''' flatten each sequence and pass it through the linear layers

        Parameters
        ----------
        in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, 1, hidden_dim)
        '''
        num_seq = in_seq_emb.size(0)
        hidden = in_seq_emb.view(num_seq, -1)
        for linear_layer in self.linear_list:
            hidden = F.relu(linear_layer(hidden))
        return hidden.view(num_seq, 1, -1)

    def init_hidden(self):
        ''' nothing to reset: this encoder has no hidden state
        '''
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
graph_grammar/nn/graph.py
DELETED
|
@@ -1,313 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
# Rhizome
|
| 4 |
-
# Version beta 0.0, August 2023
|
| 5 |
-
# Property of IBM Research, Accelerated Discovery
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
"""
|
| 9 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
| 10 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
""" Title """
|
| 15 |
-
|
| 16 |
-
__author__ = "Hiroshi Kajino <KAJINO@jp.ibm.com>"
|
| 17 |
-
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
| 18 |
-
__version__ = "0.1"
|
| 19 |
-
__date__ = "Jan 1 2018"
|
| 20 |
-
|
| 21 |
-
import numpy as np
|
| 22 |
-
import torch
|
| 23 |
-
import torch.nn.functional as F
|
| 24 |
-
from graph_grammar.graph_grammar.hrg import ProductionRuleCorpus
|
| 25 |
-
from torch import nn
|
| 26 |
-
from torch.autograd import Variable
|
| 27 |
-
|
| 28 |
-
class MolecularProdRuleEmbedding(nn.Module):
|
| 29 |
-
|
| 30 |
-
''' molecular fingerprint layer
|
| 31 |
-
'''
|
| 32 |
-
|
| 33 |
-
def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
|
| 34 |
-
out_dim=32, element_embed_dim=32,
|
| 35 |
-
num_layers=3, padding_idx=None, use_gpu=False):
|
| 36 |
-
super().__init__()
|
| 37 |
-
if padding_idx is not None:
|
| 38 |
-
assert padding_idx == -1, 'padding_idx must be -1.'
|
| 39 |
-
self.prod_rule_corpus = prod_rule_corpus
|
| 40 |
-
self.layer2layer_activation = layer2layer_activation
|
| 41 |
-
self.layer2out_activation = layer2out_activation
|
| 42 |
-
self.out_dim = out_dim
|
| 43 |
-
self.element_embed_dim = element_embed_dim
|
| 44 |
-
self.num_layers = num_layers
|
| 45 |
-
self.padding_idx = padding_idx
|
| 46 |
-
self.use_gpu = use_gpu
|
| 47 |
-
|
| 48 |
-
self.layer2layer_list = []
|
| 49 |
-
self.layer2out_list = []
|
| 50 |
-
|
| 51 |
-
if self.use_gpu:
|
| 52 |
-
self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
|
| 53 |
-
self.element_embed_dim, requires_grad=True).cuda()
|
| 54 |
-
self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
|
| 55 |
-
self.element_embed_dim, requires_grad=True).cuda()
|
| 56 |
-
self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
|
| 57 |
-
self.element_embed_dim, requires_grad=True).cuda()
|
| 58 |
-
for _ in range(num_layers):
|
| 59 |
-
self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
|
| 60 |
-
self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
|
| 61 |
-
else:
|
| 62 |
-
self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
|
| 63 |
-
self.element_embed_dim, requires_grad=True)
|
| 64 |
-
self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
|
| 65 |
-
self.element_embed_dim, requires_grad=True)
|
| 66 |
-
self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
|
| 67 |
-
self.element_embed_dim, requires_grad=True)
|
| 68 |
-
for _ in range(num_layers):
|
| 69 |
-
self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
|
| 70 |
-
self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def forward(self, prod_rule_idx_seq):
    ''' embed each production rule of a mini-batch by message passing over its right-hand side

    Parameters
    ----------
    prod_rule_idx_seq : (batch_size, length)
        indices into `self.prod_rule_corpus.prod_rule_list`; an entry equal to
        the length of that list is skipped, so its output row stays zero.

    Returns
    -------
    Variable, shape (batch_size, length, out_dim)
        for every rule, the sum over all layers, edges and nodes of the
        activated `layer2out` projection of the aggregated messages.
    '''
    batch_size, length = prod_rule_idx_seq.shape
    out = Variable(torch.zeros((batch_size, length, self.out_dim)))
    if self.use_gpu:
        out = out.cuda()

    rule_list = self.prod_rule_corpus.prod_rule_list
    for row in range(batch_size):
        for col in range(length):
            rule_idx = int(prod_rule_idx_seq[row, col])
            if rule_idx == len(rule_list):
                # index one past the last rule: nothing to embed
                continue
            rhs = rule_list[rule_idx].rhs

            # initial embeddings: edges are looked up in atom_embed, nodes in bond_embed
            embed = {}
            for edge in rhs.edges:
                embed[edge] = self.atom_embed[rhs.edge_attr(edge)['symbol_idx']]
            for node in rhs.nodes:
                embed[node] = self.bond_embed[rhs.node_attr(node)['symbol_idx']]
            # nodes carrying an 'ext_id' attribute additionally get its embedding
            for node in rhs.nodes:
                node_attr = rhs.node_attr(node)
                if 'ext_id' in node_attr:
                    embed[node] = embed[node] + self.ext_id_embed[node_attr['ext_id']]

            for layer_idx in range(self.num_layers):
                to_next = self.layer2layer_list[layer_idx]
                to_out = self.layer2out_list[layer_idx]
                next_embed = {}
                # each edge aggregates itself and the nodes it contains
                for edge in rhs.edges:
                    message = embed[edge]
                    for node in rhs.nodes_in_edge(edge):
                        message = message + embed[node]
                    next_embed[edge] = self.layer2layer_activation(to_next(message))
                    out[row, col, :] = out[row, col, :] \
                        + self.layer2out_activation(to_out(message))
                # each node aggregates itself and its adjacent edges
                for node in rhs.nodes:
                    message = embed[node]
                    for edge in rhs.adj_edges(node):
                        message = message + embed[edge]
                    next_embed[node] = self.layer2layer_activation(to_next(message))
                    out[row, col, :] = out[row, col, :] \
                        + self.layer2out_activation(to_out(message))
                embed = next_embed

    return out
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
class MolecularProdRuleEmbeddingLastLayer(nn.Module):

    ''' molecular fingerprint layer read out after the last message-passing layer

    Edges and nodes of a production rule's right-hand side get learned symbol
    embeddings (`atom_embed` for edges, `bond_embed` for nodes), exchange
    messages for `num_layers` rounds, and only the final embeddings are
    projected to `out_dim` and summed.
    '''

    def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
                 out_dim=32, element_embed_dim=32,
                 num_layers=3, padding_idx=None, use_gpu=False):
        ''' build the symbol embeddings and the per-layer linear maps

        Parameters
        ----------
        prod_rule_corpus : object
            must provide `num_edge_symbol`, `num_node_symbol` and `prod_rule_list`.
        layer2layer_activation, layer2out_activation : callable
            activations applied after the layer-to-layer / layer-to-output maps.
        out_dim : int
        element_embed_dim : int
        num_layers : int
            number of message-passing rounds; `num_layers + 1` linear maps of
            each kind are created, the last `layer2out` one is the read-out.
        padding_idx : None or -1
        use_gpu : bool
            if True, all sub-modules are moved with `.cuda()`.
        '''
        super().__init__()
        if padding_idx is not None:
            assert padding_idx == -1, 'padding_idx must be -1.'
        self.prod_rule_corpus = prod_rule_corpus
        self.layer2layer_activation = layer2layer_activation
        self.layer2out_activation = layer2out_activation
        self.out_dim = out_dim
        self.element_embed_dim = element_embed_dim
        self.num_layers = num_layers
        self.padding_idx = padding_idx
        self.use_gpu = use_gpu

        self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim)
        self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim)

        # NOTE(review): these are plain Python lists, so the Linear layers are
        # not registered as sub-modules (absent from parameters()/state_dict(),
        # untouched by Module.to()). Kept as-is so existing checkpoint keys do
        # not change; switching to nn.ModuleList needs a checkpoint migration.
        self.layer2layer_list = []
        self.layer2out_list = []
        for _ in range(num_layers + 1):
            self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
            self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))

        if self.use_gpu:
            self.atom_embed = self.atom_embed.cuda()
            self.bond_embed = self.bond_embed.cuda()
            self.layer2layer_list = [layer.cuda() for layer in self.layer2layer_list]
            self.layer2out_list = [layer.cuda() for layer in self.layer2out_list]

    def _index_tensor(self, symbol_idx):
        ''' wrap one symbol index as a LongTensor on the device selected by `use_gpu` '''
        idx = torch.LongTensor([symbol_idx])
        return idx.cuda() if self.use_gpu else idx

    def forward(self, prod_rule_idx_seq):
        ''' forward model for mini-batch

        Parameters
        ----------
        prod_rule_idx_seq : (batch_size, length)
            indices into `self.prod_rule_corpus.prod_rule_list`; an entry equal
            to the length of that list is skipped and its output row stays zero.

        Returns
        -------
        Variable, shape (batch_size, length, out_dim)
        '''
        batch_size, length = prod_rule_idx_seq.shape
        out = Variable(torch.zeros((batch_size, length, self.out_dim)))
        if self.use_gpu:
            out = out.cuda()

        rule_list = self.prod_rule_corpus.prod_rule_list
        readout = self.layer2out_list[self.num_layers]
        for row in range(batch_size):
            for col in range(length):
                rule_idx = int(prod_rule_idx_seq[row, col])
                if rule_idx == len(rule_list):
                    continue
                rhs = rule_list[rule_idx].rhs

                embed = {edge: self.atom_embed(self._index_tensor(rhs.edge_attr(edge)['symbol_idx']))
                         for edge in rhs.edges}
                embed.update({node: self.bond_embed(self._index_tensor(rhs.node_attr(node)['symbol_idx']))
                              for node in rhs.nodes})

                for layer_idx in range(self.num_layers):
                    transform = self.layer2layer_list[layer_idx]
                    next_embed = {}
                    for edge in rhs.edges:
                        # Out-of-place sum: the former `v += ...` overwrote the
                        # tensor stored in the dict, which leaked the edge
                        # aggregate into the node messages below and modified
                        # activation outputs that autograd still needs.
                        message = embed[edge]
                        for node in rhs.nodes_in_edge(edge):
                            message = message + embed[node]
                        next_embed[edge] = self.layer2layer_activation(transform(message))
                    for node in rhs.nodes:
                        message = embed[node]
                        for edge in rhs.adj_edges(node):
                            message = message + embed[edge]
                        next_embed[node] = self.layer2layer_activation(transform(message))
                    embed = next_embed

                # Read-out from the last layer: sum over every edge and node.
                # The previous code looped over the edges twice and assigned
                # the projection of a stale aggregate `v` each time, ignoring
                # the loop variable and the final embeddings.
                # NOTE(review): summing is inferred from the sibling layer,
                # which accumulates its per-element outputs — confirm intent.
                for edge in rhs.edges:
                    out[row, col, :] = out[row, col, :] \
                        + self.layer2out_activation(readout(embed[edge])).reshape(-1)
                for node in rhs.nodes:
                    out[row, col, :] = out[row, col, :] \
                        + self.layer2out_activation(readout(embed[node])).reshape(-1)

        return out
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
class MolecularProdRuleEmbeddingUsingFeatures(nn.Module):

    ''' molecular fingerprint layer driven by fixed feature vectors

    Instead of learned symbol embeddings, the initial edge/node states are
    the vectors returned by `prod_rule_corpus.construct_feature_vectors()`.
    '''

    def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
                 out_dim=32, num_layers=3, padding_idx=None, use_gpu=False):
        ''' fetch the feature vectors and create `num_layers` pairs of linear maps

        Parameters
        ----------
        prod_rule_corpus : object
            must provide `construct_feature_vectors()` (returning the feature
            dict and the feature dimension) and `prod_rule_list`.
        layer2layer_activation, layer2out_activation : callable
        out_dim : int
        num_layers : int
        padding_idx : None or -1
        use_gpu : bool
            if True, features are densified and moved with `.cuda()`, and so
            are the linear layers.
        '''
        super().__init__()
        if padding_idx is not None:
            assert padding_idx == -1, 'padding_idx must be -1.'
        self.feature_dict, self.feature_dim = prod_rule_corpus.construct_feature_vectors()
        self.prod_rule_corpus = prod_rule_corpus
        self.layer2layer_activation = layer2layer_activation
        self.layer2out_activation = layer2out_activation
        self.out_dim = out_dim
        self.num_layers = num_layers
        self.padding_idx = padding_idx
        self.use_gpu = use_gpu

        # NOTE(review): plain lists, so these layers are not registered as
        # sub-modules of this nn.Module.
        self.layer2layer_list = []
        self.layer2out_list = []

        if self.use_gpu:
            # the dict returned by the corpus is updated in place
            for key, feature in list(self.feature_dict.items()):
                self.feature_dict[key] = feature.to_dense().cuda()

        for _ in range(num_layers):
            to_next = nn.Linear(self.feature_dim, self.feature_dim)
            if self.use_gpu:
                to_next = to_next.cuda()
            self.layer2layer_list.append(to_next)
            to_out = nn.Linear(self.feature_dim, self.out_dim)
            if self.use_gpu:
                to_out = to_out.cuda()
            self.layer2out_list.append(to_out)

    def forward(self, prod_rule_idx_seq):
        ''' forward model for mini-batch

        Parameters
        ----------
        prod_rule_idx_seq : (batch_size, length)
            indices into `self.prod_rule_corpus.prod_rule_list`; an entry equal
            to the length of that list is skipped and its output row stays zero.

        Returns
        -------
        Variable, shape (batch_size, length, out_dim)
        '''
        batch_size, length = prod_rule_idx_seq.shape
        out = Variable(torch.zeros((batch_size, length, self.out_dim)))
        if self.use_gpu:
            out = out.cuda()

        rule_list = self.prod_rule_corpus.prod_rule_list
        for row in range(batch_size):
            for col in range(length):
                rule_idx = int(prod_rule_idx_seq[row, col])
                if rule_idx == len(rule_list):
                    continue
                rule = rule_list[rule_idx]

                # fixed ordering: sorted edges first, then sorted nodes
                edge_list = sorted(list(rule.rhs.edges))
                node_list = sorted(list(rule.rhs.nodes))
                ordering = edge_list + node_list
                # adjacency plus identity, so every element also keeps its own state
                adj_mat = torch.FloatTensor(
                    rule.rhs_adj_mat(ordering).todense() + np.identity(len(ordering)))
                if self.use_gpu:
                    adj_mat = adj_mat.cuda()

                states = [self.feature_dict[rule.rhs.edge_attr(edge)['symbol']]
                          for edge in edge_list]
                states += [self.feature_dict[rule.rhs.node_attr(node)['symbol']]
                           for node in node_list]
                # external nodes additionally receive their ('ext_id', id) feature
                for ext_node in rule.ext_node.values():
                    position = rule.rhs.num_edges + node_list.index(ext_node)
                    ext_key = ('ext_id', rule.rhs.node_attr(ext_node)['ext_id'])
                    states[position] = states[position] + self.feature_dict[ext_key]
                hidden = torch.stack(states)

                for layer_idx in range(self.num_layers):
                    message = adj_mat @ hidden
                    next_hidden = self.layer2layer_activation(
                        self.layer2layer_list[layer_idx](message))
                    contribution = self.layer2out_activation(
                        self.layer2out_list[layer_idx](message)).sum(dim=0)
                    out[row, col, :] = out[row, col, :] + contribution
                    hidden = next_hidden
        return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
load.py
DELETED
|
@@ -1,83 +0,0 @@
|
|
| 1 |
-
# -*- coding:utf-8 -*-
|
| 2 |
-
# Rhizome
|
| 3 |
-
# Version beta 0.0, August 2023
|
| 4 |
-
# Property of IBM Research, Accelerated Discovery
|
| 5 |
-
#
|
| 6 |
-
|
| 7 |
-
import os
|
| 8 |
-
import pickle
|
| 9 |
-
import sys
|
| 10 |
-
|
| 11 |
-
from rdkit import Chem
|
| 12 |
-
import torch
|
| 13 |
-
from torch_geometric.utils.smiles import from_smiles
|
| 14 |
-
|
| 15 |
-
from typing import Any, Dict, List, Optional, Union
|
| 16 |
-
from typing_extensions import Self
|
| 17 |
-
|
| 18 |
-
from .graph_grammar.io.smi import hg_to_mol
|
| 19 |
-
from .models.mhgvae import GrammarGINVAE
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class PretrainedModelWrapper:
    """Convenience wrapper exposing encode/decode on a pretrained GrammarGINVAE."""

    model: GrammarGINVAE

    def __init__(self, model_dict: Dict[str, Any]) -> None:
        """Rebuild the model from a checkpoint dictionary and put it in eval mode.

        Parameters
        ----------
        model_dict : dict
            must contain the keys 'gnn_params', 'num_features',
            'num_edge_features', 'hrg', 'max_length' and 'model_state_dict'.
        """
        json_params = model_dict['gnn_params']
        # Work on a copy so the feature sizes are not written back into the
        # caller's nested 'encoder_params' dictionary.
        encoder_params = dict(json_params['encoder_params'])
        encoder_params['node_feature_size'] = model_dict['num_features']
        encoder_params['edge_feature_size'] = model_dict['num_edge_features']
        # rank=-1 and batch_size=512 are fixed here, as in the original loader.
        self.model = GrammarGINVAE(model_dict['hrg'], rank=-1, encoder_params=encoder_params,
                                   decoder_params=json_params['decoder_params'],
                                   prod_rule_embed_params=json_params["prod_rule_embed_params"],
                                   batch_size=512, max_len=model_dict['max_length'])
        self.model.load_state_dict(model_dict['model_state_dict'])

        self.model.eval()

    def to(self, device: Union[str, int, torch.device]) -> Self:
        """Move the wrapped model to `device` and return this wrapper.

        A string, or an integer when CUDA is available, is passed to
        `torch.device` directly; an integer without CUDA is interpreted as
        an "mps" device index.
        """
        if not isinstance(device, torch.device):
            if isinstance(device, str) or torch.cuda.is_available():
                device = torch.device(device)
            else:
                device = torch.device("mps", device)

        self.model = self.model.to(device)
        return self

    def encode(self, data: List[str]) -> List[torch.Tensor]:
        """Embed each SMILES string with the model's graph encoder.

        Returns one entry per input: the first element of what
        `self.model.graph_embed` returns for that molecule graph.
        """
        # The model's device does not change while encoding; look it up once.
        device = next(self.model.parameters()).device
        output = []
        for smiles in data:
            # The previous check compared the parameter tensor itself with the
            # string 'cpu' and used `g.cpu()` only for its truth value, so it
            # never inspected the actual devices; move the graph explicitly.
            graph = from_smiles(smiles).to(device)
            ltvec = self.model.graph_embed(graph.x, graph.edge_index, graph.edge_attr, graph.batch)
            output.append(ltvec[0])
        return output

    def decode(self, data: List[torch.Tensor]) -> List[Optional[str]]:
        """Decode each embedding back to a SMILES string.

        Each vector goes through `get_mean_var`, `reparameterize` and
        `decode` of the wrapped model; an entry is None when the model's
        first returned flag is falsy for that input.
        """
        output: List[Optional[str]] = []
        for d in data:
            mu, logvar = self.model.get_mean_var(d.unsqueeze(0))
            z = self.model.reparameterize(mu, logvar)
            flags, _, hgs = self.model.decode(z)
            if flags[0]:
                reconstructed_mol, _ = hg_to_mol(hgs[0], True)
                output.append(Chem.MolToSmiles(reconstructed_mol))
            else:
                output.append(None)
        return output
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
        "PretrainedModelWrapper"]:
    """Find `model_name` under the entries of `sys.path` and load it.

    Each `sys.path` entry is tried in order; the first existing file is
    unpickled and wrapped. Returns None when no entry contains the file.

    Paths are built with `os.path.join`, so an empty `sys.path` entry
    (the current directory) resolves relative to the working directory
    rather than to the filesystem root, and an absolute `model_name` is
    used as given.

    NOTE: the checkpoint is read with `pickle.load`, which can execute
    arbitrary code; only load files from a trusted source.
    """
    for base in sys.path:
        candidate = os.path.join(base, model_name)
        if os.path.isfile(candidate):
            with open(candidate, "rb") as f:
                model_dict = pickle.load(f)
            return PretrainedModelWrapper(model_dict)
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mhg_gnn.egg-info/PKG-INFO
DELETED
|
@@ -1,102 +0,0 @@
|
|
| 1 |
-
Metadata-Version: 2.1
|
| 2 |
-
Name: mhg-gnn
|
| 3 |
-
Version: 0.0
|
| 4 |
-
Summary: Package for mhg-gnn
|
| 5 |
-
Author: team
|
| 6 |
-
License: TBD
|
| 7 |
-
Classifier: Programming Language :: Python :: 3
|
| 8 |
-
Classifier: Programming Language :: Python :: 3.9
|
| 9 |
-
Description-Content-Type: text/markdown
|
| 10 |
-
Requires-Dist: networkx>=2.8
|
| 11 |
-
Requires-Dist: numpy<2.0.0,>=1.23.5
|
| 12 |
-
Requires-Dist: pandas>=1.5.3
|
| 13 |
-
Requires-Dist: rdkit-pypi<2023.9.6,>=2022.9.4
|
| 14 |
-
Requires-Dist: torch>=2.0.0
|
| 15 |
-
Requires-Dist: torchinfo>=1.8.0
|
| 16 |
-
Requires-Dist: torch-geometric>=2.3.1
|
| 17 |
-
|
| 18 |
-
# mhg-gnn
|
| 19 |
-
|
| 20 |
-
This repository provides PyTorch source code associated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network".
|
| 21 |
-
|
| 22 |
-
**Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
|
| 23 |
-
|
| 24 |
-
For more information contact: SEIJITKD@jp.ibm.com
|
| 25 |
-
|
| 26 |
-

|
| 27 |
-
|
| 28 |
-
## Introduction
|
| 29 |
-
|
| 30 |
-
We present MHG-GNN, an autoencoder architecture
|
| 31 |
-
that has an encoder based on GNN and a decoder based on a sequential model with MHG.
|
| 32 |
-
Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
|
| 33 |
-
demonstrate high predictive performance on molecular graph data.
|
| 34 |
-
In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
|
| 35 |
-
|
| 36 |
-
## Table of Contents
|
| 37 |
-
|
| 38 |
-
1. [Getting Started](#getting-started)
|
| 39 |
-
1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
|
| 40 |
-
2. [Replicating Conda Environment](#replicating-conda-environment)
|
| 41 |
-
2. [Feature Extraction](#feature-extraction)
|
| 42 |
-
|
| 43 |
-
## Getting Started
|
| 44 |
-
|
| 45 |
-
**This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
|
| 46 |
-
|
| 47 |
-
### Pretrained Models and Training Logs
|
| 48 |
-
|
| 49 |
-
We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
|
| 50 |
-
|
| 51 |
-
Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
|
| 52 |
-
|
| 53 |
-
### Replicating Conda Environment
|
| 54 |
-
|
| 55 |
-
Follow these steps to replicate our Conda environment and install the necessary libraries:
|
| 56 |
-
|
| 57 |
-
```
|
| 58 |
-
conda create --name mhg-gnn-env python=3.8.18
|
| 59 |
-
conda activate mhg-gnn-env
|
| 60 |
-
```
|
| 61 |
-
|
| 62 |
-
#### Install Packages with Conda
|
| 63 |
-
|
| 64 |
-
```
|
| 65 |
-
conda install -c conda-forge networkx=2.8
|
| 66 |
-
conda install numpy=1.23.5
|
| 67 |
-
# conda install -c conda-forge rdkit=2022.9.4
|
| 68 |
-
conda install pytorch=2.0.0 torchvision torchaudio -c pytorch
|
| 69 |
-
conda install -c conda-forge torchinfo=1.8.0
|
| 70 |
-
conda install pyg -c pyg
|
| 71 |
-
```
|
| 72 |
-
|
| 73 |
-
#### Install Packages with pip
|
| 74 |
-
```
|
| 75 |
-
pip install rdkit torch-nl==0.3 torch-scatter torch-sparse
|
| 76 |
-
```
|
| 77 |
-
|
| 78 |
-
## Feature Extraction
|
| 79 |
-
|
| 80 |
-
The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
|
| 81 |
-
|
| 82 |
-
To load mhg-gnn, you can simply use:
|
| 83 |
-
|
| 84 |
-
```python
|
| 85 |
-
import torch
|
| 86 |
-
import load
|
| 87 |
-
|
| 88 |
-
model = load.load()
|
| 89 |
-
```
|
| 90 |
-
|
| 91 |
-
To encode SMILES into embeddings, you can use:
|
| 92 |
-
|
| 93 |
-
```python
|
| 94 |
-
with torch.no_grad():
|
| 95 |
-
repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
|
| 96 |
-
```
|
| 97 |
-
|
| 98 |
-
To decode the embeddings back into SMILES strings, use the `decode` function:
|
| 99 |
-
|
| 100 |
-
```python
|
| 101 |
-
orig = model.decode(repr)
|
| 102 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mhg_gnn.egg-info/SOURCES.txt
DELETED
|
@@ -1,46 +0,0 @@
|
|
| 1 |
-
README.md
|
| 2 |
-
setup.cfg
|
| 3 |
-
setup.py
|
| 4 |
-
./graph_grammar/__init__.py
|
| 5 |
-
./graph_grammar/hypergraph.py
|
| 6 |
-
./graph_grammar/algo/__init__.py
|
| 7 |
-
./graph_grammar/algo/tree_decomposition.py
|
| 8 |
-
./graph_grammar/graph_grammar/__init__.py
|
| 9 |
-
./graph_grammar/graph_grammar/base.py
|
| 10 |
-
./graph_grammar/graph_grammar/corpus.py
|
| 11 |
-
./graph_grammar/graph_grammar/hrg.py
|
| 12 |
-
./graph_grammar/graph_grammar/symbols.py
|
| 13 |
-
./graph_grammar/graph_grammar/utils.py
|
| 14 |
-
./graph_grammar/io/__init__.py
|
| 15 |
-
./graph_grammar/io/smi.py
|
| 16 |
-
./graph_grammar/nn/__init__.py
|
| 17 |
-
./graph_grammar/nn/dataset.py
|
| 18 |
-
./graph_grammar/nn/decoder.py
|
| 19 |
-
./graph_grammar/nn/encoder.py
|
| 20 |
-
./graph_grammar/nn/graph.py
|
| 21 |
-
./models/__init__.py
|
| 22 |
-
./models/mhgvae.py
|
| 23 |
-
graph_grammar/__init__.py
|
| 24 |
-
graph_grammar/hypergraph.py
|
| 25 |
-
graph_grammar/algo/__init__.py
|
| 26 |
-
graph_grammar/algo/tree_decomposition.py
|
| 27 |
-
graph_grammar/graph_grammar/__init__.py
|
| 28 |
-
graph_grammar/graph_grammar/base.py
|
| 29 |
-
graph_grammar/graph_grammar/corpus.py
|
| 30 |
-
graph_grammar/graph_grammar/hrg.py
|
| 31 |
-
graph_grammar/graph_grammar/symbols.py
|
| 32 |
-
graph_grammar/graph_grammar/utils.py
|
| 33 |
-
graph_grammar/io/__init__.py
|
| 34 |
-
graph_grammar/io/smi.py
|
| 35 |
-
graph_grammar/nn/__init__.py
|
| 36 |
-
graph_grammar/nn/dataset.py
|
| 37 |
-
graph_grammar/nn/decoder.py
|
| 38 |
-
graph_grammar/nn/encoder.py
|
| 39 |
-
graph_grammar/nn/graph.py
|
| 40 |
-
mhg_gnn.egg-info/PKG-INFO
|
| 41 |
-
mhg_gnn.egg-info/SOURCES.txt
|
| 42 |
-
mhg_gnn.egg-info/dependency_links.txt
|
| 43 |
-
mhg_gnn.egg-info/requires.txt
|
| 44 |
-
mhg_gnn.egg-info/top_level.txt
|
| 45 |
-
models/__init__.py
|
| 46 |
-
models/mhgvae.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mhg_gnn.egg-info/dependency_links.txt
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
|
|
|
|
|
|
mhg_gnn.egg-info/requires.txt
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
networkx>=2.8
|
| 2 |
-
numpy<2.0.0,>=1.23.5
|
| 3 |
-
pandas>=1.5.3
|
| 4 |
-
rdkit-pypi<2023.9.6,>=2022.9.4
|
| 5 |
-
torch>=2.0.0
|
| 6 |
-
torchinfo>=1.8.0
|
| 7 |
-
torch-geometric>=2.3.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mhg_gnn.egg-info/top_level.txt
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
graph_grammar
|
| 2 |
-
models
|
|
|
|
|
|
|
|
|
pickles/mhggnn_pretrained_model_0724_2023.pickle → mhggnn_pretrained_model_0724_2023.pickle
RENAMED
|
File without changes
|
models/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
# -*- coding:utf-8 -*-
|
| 2 |
-
# Rhizome
|
| 3 |
-
# Version beta 0.0, August 2023
|
| 4 |
-
# Property of IBM Research, Accelerated Discovery
|
| 5 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (221 Bytes)
|
|
|
models/__pycache__/mhgvae.cpython-310.pyc
DELETED
|
Binary file (24.8 kB)
|
|
|
models/mhgvae.py
DELETED
|
@@ -1,956 +0,0 @@
|
|
| 1 |
-
# -*- coding:utf-8 -*-
|
| 2 |
-
# Rhizome
|
| 3 |
-
# Version beta 0.0, August 2023
|
| 4 |
-
# Property of IBM Research, Accelerated Discovery
|
| 5 |
-
#
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
PLEASE NOTE THIS IMPLEMENTATION INCLUDES ADAPTED SOURCE CODE
|
| 9 |
-
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE,
|
| 10 |
-
E.G., GRUEncoder/GRUDecoder, GrammarSeq2SeqVAE AND EVEN SOME METHODS OF GrammarGINVAE.
|
| 11 |
-
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
import numpy as np
|
| 15 |
-
import logging
|
| 16 |
-
|
| 17 |
-
import torch
|
| 18 |
-
from torch.autograd import Variable
|
| 19 |
-
import torch.nn as nn
|
| 20 |
-
import torch.nn.functional as F
|
| 21 |
-
from torch.nn.modules.loss import _Loss
|
| 22 |
-
|
| 23 |
-
from torch_geometric.nn import MessagePassing
|
| 24 |
-
from torch_geometric.nn import global_add_pool
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
from ..graph_grammar.graph_grammar.symbols import NTSymbol
|
| 28 |
-
from ..graph_grammar.nn.encoder import EncoderBase
|
| 29 |
-
from ..graph_grammar.nn.decoder import DecoderBase
|
| 30 |
-
|
| 31 |
-
def get_atom_edge_feature_dims():
    ''' number of possible values per atom feature and per edge feature

    Returns
    -------
    tuple of two lists
        lengths of the value lists in torch_geometric's `x_map` (atoms) and
        `e_map` (edges), in the iteration order of those mappings.
    '''
    from torch_geometric.utils.smiles import x_map, e_map
    atom_dims = [len(choices) for choices in x_map.values()]
    edge_dims = [len(choices) for choices in e_map.values()]
    return atom_dims, edge_dims
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
class FeatureEmbedding(nn.Module):
    ''' sum of one embedding table per categorical feature column '''

    def __init__(self, input_dims, embedded_dim):
        '''
        Parameters
        ----------
        input_dims : iterable of int
            number of categories of each feature column, in column order.
        embedded_dim : int
            size of every embedding vector.
        '''
        super().__init__()
        self.embedding_list = nn.ModuleList()
        for dim in input_dims:
            self.embedding_list.append(nn.Embedding(dim, embedded_dim))

    def forward(self, x):
        ''' embed every column of `x` with its own table and add the results

        Parameters
        ----------
        x : Tensor, shape (num_items, num_features)
            category indices; cast to int before the lookup.

        Returns
        -------
        Tensor, shape (num_items, embedded_dim), or the integer 0 when `x`
        has no columns.
        '''
        num_columns = x.shape[1]
        if num_columns == 0:
            return 0
        # the module's device is the same for every column; look it up once
        device = next(self.parameters()).device
        output = 0
        for i in range(num_columns):
            column = x[:, i].to(torch.int)
            if column.device != device:
                column = column.to(device)
            # out-of-place sum keeps each lookup result untouched
            output = output + self.embedding_list[i](column)
        return output
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
class GRUEncoder(EncoderBase):
    ''' batch-first GRU encoder that keeps its initial hidden state in `self.h0` '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 bidirectional: bool, dropout: float, batch_size: int, rank: int=-1,
                 no_dropout: bool=False):
        '''
        Parameters
        ----------
        input_dim, hidden_dim, num_layers, bidirectional
            forwarded to `nn.GRU`.
        dropout : float
            GRU dropout; replaced by 0 when `no_dropout` is True.
        batch_size : int
            upper bound on the batch dimension of the hidden state.
        rank : int
            device index; a negative value leaves everything where it was created.
        '''
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.batch_size = batch_size
        self.rank = rank
        gru_dropout = 0 if no_dropout else self.dropout
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=self.bidirectional,
                            dropout=gru_dropout)
        if self.rank >= 0:
            self.model = self.model.to(self._rank_target())
        self.init_hidden(self.batch_size)

    def _rank_target(self):
        ''' device for `self.rank`: the CUDA index if CUDA is available, else an mps device '''
        if torch.cuda.is_available():
            return self.rank
        # support mac mps
        return torch.device("mps", self.rank)

    def init_hidden(self, bsize):
        ''' reset `self.h0` to zeros for a batch of at most `self.batch_size` items '''
        num_directions = self.bidirectional + 1
        hidden = torch.zeros((num_directions * self.num_layers,
                              min(self.batch_size, bsize),
                              self.hidden_dim),
                             requires_grad=False)
        if self.rank >= 0:
            hidden = hidden.to(self._rank_target())
        self.h0 = hidden

    def to(self, device):
        ''' move module and hidden state to `device`, then refresh `rank` from the parameters '''
        moved = super().to(device)
        moved.model = moved.model.to(device)
        moved.h0 = moved.h0.to(device)
        moved.rank = next(moved.parameters()).get_device()
        return moved

    def forward(self, in_seq_emb):
        ''' forward model

        Parameters
        ----------
        in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)

        Returns
        -------
        hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
        '''
        # Kishi: I think original MHG had this init_hidden()
        self.init_hidden(in_seq_emb.size(0))
        max_len = in_seq_emb.size(1)
        gru_out, self.h0 = self.model(in_seq_emb, self.h0)
        # The GRU returns (batch_size, seq_len, (1 or 2) * hidden_size); with
        # bidirectional encoding the two directions are concatenated along the
        # last axis (first half forward RNN, latter half backward RNN).
        # Split that axis so each direction gets its own slot:
        # (batch_size, seq_len, 1 or 2, hidden_size).
        return gru_out.view(-1, max_len, 1 + self.bidirectional, self.hidden_dim)
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
class GRUDecoder(DecoderBase):
    ''' unidirectional batch-first GRU decoder advanced one step at a time '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 dropout: float, batch_size: int, rank: int=-1,
                 no_dropout: bool=False):
        '''
        Parameters
        ----------
        input_dim, hidden_dim, num_layers
            forwarded to `nn.GRU`.
        dropout : float
            GRU dropout; replaced by 0 when `no_dropout` is True.
        batch_size : int
            upper bound on the batch dimension of the hidden state.
        rank : int
            device index; a negative value leaves everything where it was created.
        '''
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.rank = rank
        gru_dropout = 0 if no_dropout else self.dropout
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=False,
                            dropout=gru_dropout)
        if self.rank >= 0:
            self.model = self.model.to(self._rank_target())
        self.init_hidden(self.batch_size)

    def _rank_target(self):
        ''' device for `self.rank`: the CUDA index if CUDA is available, else an mps device '''
        if torch.cuda.is_available():
            return self.rank
        # support mac mps
        return torch.device("mps", self.rank)

    def init_hidden(self, bsize):
        ''' reset the hidden state stored under key 'h' of `self.hidden_dict` to zeros '''
        # presumably `hidden_dict` is created by DecoderBase.__init__ — confirm
        hidden = torch.zeros((self.num_layers,
                              min(self.batch_size, bsize),
                              self.hidden_dim),
                             requires_grad=False)
        if self.rank >= 0:
            hidden = hidden.to(self._rank_target())
        self.hidden_dict['h'] = hidden

    def to(self, device):
        ''' move module and stored hidden states to `device`, then refresh `rank` '''
        moved = super().to(device)
        moved.model = moved.model.to(device)
        for key in self.hidden_dict:
            moved.hidden_dict[key] = moved.hidden_dict[key].to(device)
        moved.rank = next(moved.parameters()).get_device()
        return moved

    def forward_one_step(self, tgt_emb_in):
        ''' one-step forward model

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor
            GRU output for the length-1 sequence built from `tgt_emb_in`;
            the updated hidden state is stored back under `hidden_dict['h']`.
        '''
        bsize = tgt_emb_in.size(0)
        step_input = tgt_emb_in.view(bsize, 1, -1)
        tgt_emb_out, next_hidden = self.model(step_input, self.hidden_dict['h'])
        self.hidden_dict['h'] = next_hidden
        return tgt_emb_out
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
class NodeMLP(nn.Module):
    '''
    Two-layer perceptron: Linear -> BatchNorm1d -> ReLU -> Linear.

    Parameters
    ----------
    input_size : int
        number of input features
    output_size : int
        number of output features
    hidden_size : int
        width of the hidden layer (also the BatchNorm1d feature count)
    '''

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        # Attribute names and creation order are kept so that state_dict keys
        # and parameter initialization order stay the same.
        self.lin1 = nn.Linear(input_size, hidden_size)
        self.nbat = nn.BatchNorm1d(hidden_size)
        self.lin2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        ''' Apply the perceptron to `x` and return the result of the last
        linear layer (no activation on the output).
        '''
        hidden = self.nbat(self.lin1(x)).relu()
        return self.lin2(hidden)
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
class GINLayer(MessagePassing):
    '''
    Message-passing layer that adds embedded edge features to neighbor
    features, aggregates them, and updates each node with a `NodeMLP`.

    Parameters
    ----------
    node_input_size : int
        input width of the node MLP
    node_output_size : int
        output width of the node MLP and of the edge embedding
    node_hidden_size : int
        hidden width of the node MLP
    edge_input_size
        forwarded as the first argument of `FeatureEmbedding`

    NOTE(review): the edge embedding (width `node_output_size`) is added to the
    incoming node features, so this presumably requires
    `node_input_size == node_output_size` -- confirm against callers.
    '''

    def __init__(self, node_input_size, node_output_size, node_hidden_size, edge_input_size):
        super().__init__()
        self.node_mlp = NodeMLP(node_input_size, node_output_size, node_hidden_size)
        self.edge_mlp = FeatureEmbedding(edge_input_size, node_output_size)
        # Learnable weight on each node's own features: (1 + eps) * x.
        self.eps = nn.Parameter(torch.tensor([0.0]))

    def forward(self, x, edge_index, edge_attr):
        ''' Combine each node's own features with its aggregated messages,
        apply ReLU and return the output of the node MLP.
        '''
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        combined = ((1.0 + self.eps) * x + aggregated).relu()
        return self.node_mlp(combined)

    def message(self, x_j, edge_attr):
        ''' Message along one edge: ReLU(neighbor features + embedded edge
        features). Parameter names are part of the propagate() contract.
        '''
        edge_emb = self.edge_mlp(edge_attr)
        return (x_j + edge_emb).relu()

    def update(self, aggr_out):
        ''' Return the aggregated messages unchanged. '''
        return aggr_out
|
| 233 |
-
|
| 234 |
-
#TODO implement the case where features of atoms and edges are considered
|
| 235 |
-
# Check GraphMVP and ogb (open graph benchmark) to realize this
|
| 236 |
-
class GIN(torch.nn.Module):
    '''
    Stack of `GINLayer` modules producing one embedding per graph.

    The pooled node features before the first layer and after every layer are
    concatenated, so the output width is
    ``hidden_channels * (1 + proximity_size)``.

    Parameters
    ----------
    node_feature_size, edge_feature_size
        accepted for interface compatibility; not used here -- the feature
        dimensions come from `get_atom_edge_feature_dims()`.
    hidden_channels : int
        width of node embeddings in every layer
    proximity_size : int
        number of GIN layers
    dropout : float
        dropout probability applied after each layer and on the final output
    '''

    def __init__(self, node_feature_size, edge_feature_size, hidden_channels=64,
                 proximity_size=3, dropout=0.1):
        super().__init__()
        atom_dim, edge_dim = get_atom_edge_feature_dims()
        self.trans = FeatureEmbedding(atom_dim, hidden_channels)
        mlp_hidden = hidden_channels * 2
        self.mlist = nn.ModuleList([
            GINLayer(hidden_channels, hidden_channels, mlp_hidden, edge_dim)
            for _ in range(proximity_size)
        ])
        self.dropout = dropout
        self.proximity_size = proximity_size

    def forward(self, x, edge_index, edge_attr, batch_size):
        ''' Embed a batch of graphs.

        Parameters
        ----------
        x, edge_attr
            node and edge features; both are cast to float first.
        edge_index
            edge connectivity, forwarded to every layer.
        batch_size
            forwarded as the second positional argument of `global_add_pool`.
            NOTE(review): despite its name this is presumably the
            node-to-graph assignment vector -- confirm against callers.

        Returns
        -------
        Tensor
            concatenation (dim=1) of the pooled features of every stage, with
            dropout applied.
        '''
        node_h = self.trans(x.to(torch.float))
        edge_feat = edge_attr.to(torch.float)
        # TODO Check if this x is consistent with global_add_pool
        pooled = [global_add_pool(node_h, batch_size)]
        for layer in self.mlist:
            node_h = layer(node_h, edge_index=edge_index, edge_attr=edge_feat).relu()
            node_h = F.dropout(node_h, p=self.dropout, training=self.training)
            pooled.append(global_add_pool(node_h, batch_size))
        graph_h = torch.cat(pooled, dim=1)
        return F.dropout(graph_h, p=self.dropout, training=self.training)
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
# TODO copied from MHG implementation and adapted here.
|
| 283 |
-
class GrammarSeq2SeqVAE(nn.Module):

    '''
    Variational seq2seq with grammar.
    TODO: rewrite this class using mixin

    A GRU encoder maps a sequence of production-rule indices to the mean and
    log-variance of a latent Gaussian. A GRU decoder, whose initial hidden
    state and first input are computed from the latent vector, either scores a
    given target sequence (teacher forcing) or generates hypergraphs by
    applying production rules one at a time.

    Parameters
    ----------
    hrg
        grammar object; `num_prod_rule`, `prod_rule_corpus` and
        `prod_rule_list` are read from it.
    rank : int
        device index used for explicit placement at construction time;
        a negative value skips the placement.
    latent_dim : int
        dimension of the latent vector
    max_len : int
        number of decoding steps
    batch_size : int
        batch size used for hidden-state allocation and for sampling
    padding_idx : int
        padding index; it is normalized modulo `vocab_size`, so the default
        -1 refers to the last vocabulary entry.
    encoder_params, decoder_params : dict
        keyword arguments forwarded to `GRUEncoder` / `GRUDecoder`.
        NOTE: the defaults are shared dict objects; they are never mutated here.
    prod_rule_embed_params : dict
        must contain 'out_dim', the embedding width.
    no_dropout : bool
        forwarded to the encoder and the decoder.
    '''

    def __init__(self, hrg, rank=-1, latent_dim=64, max_len=80,
                 batch_size=64, padding_idx=-1,
                 encoder_params={'hidden_dim': 384, 'num_layers': 3, 'bidirectional': True,
                                 'dropout': 0.1},
                 decoder_params={'hidden_dim': 384,
                                 'num_layers': 3,
                                 'dropout': 0.1},
                 prod_rule_embed_params={'out_dim': 128},
                 no_dropout=False):

        super().__init__()
        # TODO USE GRU FOR ENCODING AND DECODING
        self.hrg = hrg
        self.rank = rank
        self.prod_rule_corpus = hrg.prod_rule_corpus
        self.prod_rule_embed_params = prod_rule_embed_params

        # One extra vocabulary entry is reserved in addition to the rules.
        self.vocab_size = hrg.num_prod_rule + 1
        self.batch_size = batch_size
        self.padding_idx = int(np.mod(padding_idx, self.vocab_size))
        self.no_dropout = no_dropout

        self.latent_dim = latent_dim
        self.max_len = max_len
        self.encoder_params = encoder_params
        self.decoder_params = decoder_params

        # TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
        embed_out_dim = self.prod_rule_embed_params['out_dim']
        # use MolecularProdRuleEmbedding later on
        self.src_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
                                          padding_idx=self.padding_idx)
        self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
                                          padding_idx=self.padding_idx)

        # USE a GRU-based encoder in MHG
        self.encoder = GRUEncoder(input_dim=embed_out_dim, batch_size=self.batch_size,
                                  rank=self.rank, no_dropout=self.no_dropout,
                                  **self.encoder_params)

        lin_dim = (self.encoder_params.get('bidirectional', False) + 1) * self.encoder_params['hidden_dim']
        lin_out_dim = self.latent_dim
        self.hidden2mean = nn.Linear(lin_dim, lin_out_dim, bias=False)
        self.hidden2logvar = nn.Linear(lin_dim, lin_out_dim)

        # USE a GRU-based decoder in MHG
        self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
                                  rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
        self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
        self.latent2hidden_dict = nn.ModuleDict()
        dec_lin_out_dim = self.decoder_params['hidden_dim']
        for each_hidden in self.decoder.hidden_dict.keys():
            projection = nn.Linear(self.latent_dim, dec_lin_out_dim)
            if self.rank >= 0:
                if torch.cuda.is_available():
                    projection = projection.to(self.rank)
                else:
                    # support mac mps
                    projection = projection.to(torch.device("mps", self.rank))
            self.latent2hidden_dict[each_hidden] = projection

        self.dec2vocab = nn.Linear(dec_lin_out_dim, self.vocab_size)

        # TODO Do we need this?
        self.src_embedding.weight.data.uniform_(-0.1, 0.1)
        self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)

        self.encoder.init_hidden(self.batch_size)
        self.decoder.init_hidden(self.batch_size)

    def to(self, device):
        ''' Move the model to `device`.

        `nn.Module.to` moves every registered parameter; the encoder and the
        decoder are additionally moved through their own `to` so that they can
        relocate their stored hidden states. `rank` is refreshed from the
        device of the first parameter.
        '''
        newself = super().to(device)
        newself.encoder = newself.encoder.to(device)
        newself.decoder = newself.decoder.to(device)
        newself.rank = next(newself.parameters()).get_device()
        return newself

    def forward(self, in_seq, out_seq):
        ''' forward model

        Parameters
        ----------
        in_seq : Tensor, shape (batch_size, length)
            each element corresponds to word index.
            where the index should be less than `vocab_size`
        out_seq : Tensor
            target indices used for teacher forcing in `decode`

        Returns
        -------
        Tensor, shape (batch_size, length, vocab_size)
            logit of each word (applying softmax yields the probability)
        mu, logvar : Tensor
            parameters of the approximate posterior returned by `encode`
        '''
        mu, logvar = self.encode(in_seq)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, out_seq), mu, logvar

    def encode(self, in_seq):
        ''' Encode index sequences into the posterior mean and log-variance.

        The encoder output has shape
        (batch_size, max_len, 1 + bidirectional, hidden_dim). The summary is
        the forward state at the last position, concatenated with the backward
        state at the first position when the encoder is bidirectional.
        '''
        src_emb = self.src_embedding(in_seq)
        src_h = self.encoder.forward(src_emb)
        if self.encoder_params.get('bidirectional', False):
            summary = torch.cat((src_h[:, -1, 0, :], src_h[:, 0, 1, :]), dim=1)
        else:
            # Select the single direction explicitly so that the summary is
            # (batch_size, hidden_dim), as in the bidirectional branch.
            summary = src_h[:, -1, 0, :]
        return self.hidden2mean(summary), self.hidden2logvar(summary)

    def reparameterize(self, mu, logvar, training=True):
        ''' Draw ``mu + eps * exp(0.5 * logvar)`` with eps ~ N(0, I) when
        `training` is True; otherwise return `mu`.
        '''
        if not training:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like allocates the noise on the same device/dtype as std.
        eps = torch.randn_like(std)
        return eps * std + mu

    #TODO Not tested. Need to implement this in case of molecular structure generation
    def sample(self, sample_size=-1, deterministic=True, return_z=False):
        ''' Sample hypergraphs from the prior.

        Parameters
        ----------
        sample_size : int
            number of samples; -1 means `batch_size`.
        deterministic : bool
            forwarded to `decode`.
        return_z : bool
            if True, also return the latent vectors as a numpy array.

        Returns
        -------
        list of hypergraphs, optionally followed by the latent vectors
        '''
        self.eval()
        self.encoder.init_hidden(self.batch_size)
        self.decoder.init_hidden(self.batch_size)
        if sample_size == -1:
            sample_size = self.batch_size

        device = next(self.parameters()).device
        num_iter = int(np.ceil(sample_size / self.batch_size))
        hg_list = []
        z_list = []
        for _ in range(num_iter):
            z = torch.randn(self.batch_size, self.latent_dim, device=device)
            # decode() returns (gen_finish_list, vocab_logit, hg_list) here.
            _, _, each_hg_list = self.decode(z, deterministic=deterministic)
            z_list.append(z)
            hg_list += each_hg_list
        z = torch.cat(z_list)[:sample_size]
        hg_list = hg_list[:sample_size]
        if return_z:
            return hg_list, z.cpu().detach().numpy()
        return hg_list

    def decode(self, z=None, out_seq=None, deterministic=True):
        ''' Decode latent vectors.

        Parameters
        ----------
        z : Tensor, shape (batch_size, latent_dim), optional
            latent vectors; drawn from N(0, I) when omitted.
        out_seq : Tensor, optional
            if given, run teacher forcing for `max_len` steps and return the
            logits only.
        deterministic : bool
            forwarded to `prod_rule_corpus.sample` during generation.

        Returns
        -------
        Tensor
            vocabulary logits, when `out_seq` is given
        (gen_finish_list, vocab_logit, hg_list)
            when `out_seq` is None: per-sample completion flags, the logits of
            every step, and the generated hypergraphs.
        '''
        device = next(self.parameters()).device
        if z is None:
            z = torch.randn(self.batch_size, self.latent_dim)
        z = z.to(device)

        hidden_dict_0 = {}
        for each_hidden in self.latent2hidden_dict.keys():
            hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
        bsize = z.size(0)
        self.decoder.init_hidden(bsize)
        self.decoder.feed_hidden(hidden_dict_0)

        if out_seq is not None:
            # Teacher forcing: the first input comes from z, the rest are the
            # target embeddings shifted right by one position.
            tgt_emb0 = self.latent2tgt_emb(z)
            tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
            out_seq_emb = self.tgt_embedding(out_seq)
            tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
            tgt_emb_pred_list = []
            for each_idx in range(self.max_len):
                tgt_emb_pred = self.decoder.forward_one_step(tgt_emb[:, each_idx, :].view(bsize, 1, -1))
                tgt_emb_pred_list.append(tgt_emb_pred)
            vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
            return vocab_logit

        with torch.no_grad():
            tgt_emb = self.latent2tgt_emb(z)
            tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
            tgt_emb_pred_list = []
            stack_list = [[] for _ in range(bsize)]
            hg_list = [None] * bsize
            nt_symbol_list = [NTSymbol(degree=0,
                                       is_aromatic=False,
                                       bond_symbol_list=[])
                              for _ in range(bsize)]
            nt_edge_list = [None] * bsize
            gen_finish_list = [False] * bsize
            padding_id = int(np.mod(self.padding_idx, self.vocab_size))

            for _ in range(self.max_len):
                tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
                tgt_emb_pred_list.append(tgt_emb_pred)
                vocab_logit = self.dec2vocab(tgt_emb_pred)
                for each_batch_idx in range(bsize):
                    if not gen_finish_list[each_batch_idx]: # if generation has not finished
                        # get production rule greedily (last logit is dropped)
                        prod_rule = self.hrg.prod_rule_corpus.sample(
                            vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
                            nt_symbol_list[each_batch_idx],
                            deterministic=deterministic)
                        # convert production rule into an index
                        tgt_id = self.hrg.prod_rule_list.index(prod_rule)
                        # apply the production rule
                        hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(
                            hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
                        # add non-terminals to the stack
                        stack_list[each_batch_idx].extend(nt_edges[::-1])
                        # if the stack size is 0, generation has finished!
                        if len(stack_list[each_batch_idx]) == 0:
                            gen_finish_list[each_batch_idx] = True
                        else:
                            nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
                            nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(
                                nt_edge_list[each_batch_idx])['symbol']
                    else:
                        tgt_id = padding_id
                    # The chosen rule (or padding) becomes the next input.
                    indice_tensor = torch.tensor([tgt_id], dtype=torch.long, device=device)
                    tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
            vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
            return gen_finish_list, vocab_logit, hg_list
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
# TODO A lot of duplicates with GrammarVAE. Clean up it if necessary
|
| 520 |
-
class GrammarGINVAE(nn.Module):
|
| 521 |
-
|
| 522 |
-
'''
|
| 523 |
-
Variational autoencoder based on GIN and grammar
|
| 524 |
-
'''
|
| 525 |
-
|
| 526 |
-
def __init__(self, hrg, rank=-1, max_len=80,
|
| 527 |
-
batch_size=64, padding_idx=-1,
|
| 528 |
-
encoder_params={'node_feature_size': 4, 'edge_feature_size': 3,
|
| 529 |
-
'hidden_channels': 64, 'proximity_size': 3,
|
| 530 |
-
'dropout': 0.1},
|
| 531 |
-
decoder_params={'hidden_dim': 384, 'num_layers': 3,
|
| 532 |
-
'dropout': 0.1},
|
| 533 |
-
prod_rule_embed_params={'out_dim': 128},
|
| 534 |
-
no_dropout=False):
|
| 535 |
-
|
| 536 |
-
super().__init__()
|
| 537 |
-
# TODO USE GRU FOR ENCODING AND DECODING
|
| 538 |
-
self.hrg = hrg
|
| 539 |
-
self.rank = rank
|
| 540 |
-
self.prod_rule_corpus = hrg.prod_rule_corpus
|
| 541 |
-
self.prod_rule_embed_params = prod_rule_embed_params
|
| 542 |
-
|
| 543 |
-
self.vocab_size = hrg.num_prod_rule + 1
|
| 544 |
-
self.batch_size = batch_size
|
| 545 |
-
self.padding_idx = np.mod(padding_idx, self.vocab_size)
|
| 546 |
-
self.no_dropout = no_dropout
|
| 547 |
-
self.max_len = max_len
|
| 548 |
-
self.encoder_params = encoder_params
|
| 549 |
-
self.decoder_params = decoder_params
|
| 550 |
-
|
| 551 |
-
# TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
|
| 552 |
-
embed_out_dim = self.prod_rule_embed_params['out_dim']
|
| 553 |
-
#use MolecularProdRuleEmbedding later on
|
| 554 |
-
self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
|
| 555 |
-
padding_idx=self.padding_idx)
|
| 556 |
-
|
| 557 |
-
self.encoder = GIN(**self.encoder_params)
|
| 558 |
-
self.latent_dim = self.encoder_params['hidden_channels']
|
| 559 |
-
self.proximity_size = self.encoder_params['proximity_size']
|
| 560 |
-
hidden_dim = self.decoder_params['hidden_dim']
|
| 561 |
-
self.hidden2mean = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim, bias=False)
|
| 562 |
-
self.hidden2logvar = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim)
|
| 563 |
-
|
| 564 |
-
self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
|
| 565 |
-
rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
|
| 566 |
-
self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
|
| 567 |
-
self.latent2hidden_dict = nn.ModuleDict()
|
| 568 |
-
for each_hidden in self.decoder.hidden_dict.keys():
|
| 569 |
-
self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, hidden_dim)
|
| 570 |
-
if self.rank >= 0:
|
| 571 |
-
if torch.cuda.is_available():
|
| 572 |
-
self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank)
|
| 573 |
-
else:
|
| 574 |
-
# support mac mps
|
| 575 |
-
self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank))
|
| 576 |
-
|
| 577 |
-
self.dec2vocab = nn.Linear(hidden_dim, self.vocab_size)
|
| 578 |
-
self.decoder.init_hidden(self.batch_size)
|
| 579 |
-
|
| 580 |
-
# TODO Do we need this?
|
| 581 |
-
if hasattr(self.tgt_embedding, 'weight'):
|
| 582 |
-
self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)
|
| 583 |
-
self.decoder.init_hidden(self.batch_size)
|
| 584 |
-
|
| 585 |
-
def to(self, device):
|
| 586 |
-
newself = super().to(device)
|
| 587 |
-
newself.encoder = newself.encoder.to(device)
|
| 588 |
-
newself.decoder = newself.decoder.to(device)
|
| 589 |
-
newself.rank = next(newself.encoder.parameters()).get_device()
|
| 590 |
-
return newself
|
| 591 |
-
|
| 592 |
-
def forward(self, x, edge_index, edge_attr, batch_size, out_seq=None, sched_prob = None):
|
| 593 |
-
mu, logvar = self.encode(x, edge_index, edge_attr, batch_size)
|
| 594 |
-
z = self.reparameterize(mu, logvar)
|
| 595 |
-
return self.decode(z, out_seq, sched_prob=sched_prob), mu, logvar
|
| 596 |
-
|
| 597 |
-
#TODO Not tested. Need to implement this in case of molecular structure generation
|
| 598 |
-
def sample(self, sample_size=-1, deterministic=True, return_z=False):
|
| 599 |
-
self.eval()
|
| 600 |
-
self.init_hidden()
|
| 601 |
-
if sample_size == -1:
|
| 602 |
-
sample_size = self.batch_size
|
| 603 |
-
|
| 604 |
-
num_iter = int(np.ceil(sample_size / self.batch_size))
|
| 605 |
-
hg_list = []
|
| 606 |
-
z_list = []
|
| 607 |
-
for _ in range(num_iter):
|
| 608 |
-
z = Variable(torch.normal(
|
| 609 |
-
torch.zeros(self.batch_size, self.latent_dim),
|
| 610 |
-
torch.ones(self.batch_size * self.latent_dim))).cuda()
|
| 611 |
-
_, each_hg_list = self.decode(z, deterministic=deterministic)
|
| 612 |
-
z_list.append(z)
|
| 613 |
-
hg_list += each_hg_list
|
| 614 |
-
z = torch.cat(z_list)[:sample_size]
|
| 615 |
-
hg_list = hg_list[:sample_size]
|
| 616 |
-
if return_z:
|
| 617 |
-
return hg_list, z.cpu().detach().numpy()
|
| 618 |
-
else:
|
| 619 |
-
return hg_list
|
| 620 |
-
|
| 621 |
-
def decode(self, z=None, out_seq=None, deterministic=True, sched_prob=None):
|
| 622 |
-
if z is None:
|
| 623 |
-
z = Variable(torch.normal(
|
| 624 |
-
torch.zeros(self.batch_size, self.latent_dim),
|
| 625 |
-
torch.ones(self.batch_size * self.latent_dim)))
|
| 626 |
-
if self.rank >= 0:
|
| 627 |
-
z = z.to(next(self.parameters()).device)
|
| 628 |
-
|
| 629 |
-
hidden_dict_0 = {}
|
| 630 |
-
for each_hidden in self.latent2hidden_dict.keys():
|
| 631 |
-
hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
|
| 632 |
-
bsize = z.size(0)
|
| 633 |
-
self.decoder.init_hidden(bsize)
|
| 634 |
-
self.decoder.feed_hidden(hidden_dict_0)
|
| 635 |
-
|
| 636 |
-
if out_seq is not None:
|
| 637 |
-
tgt_emb0 = self.latent2tgt_emb(z)
|
| 638 |
-
tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
|
| 639 |
-
out_seq_emb = self.tgt_embedding(out_seq)
|
| 640 |
-
tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
|
| 641 |
-
tgt_emb_pred_list = []
|
| 642 |
-
tgt_emb_pred = None
|
| 643 |
-
for each_idx in range(self.max_len):
|
| 644 |
-
if tgt_emb_pred is None or sched_prob is None or torch.rand(1)[0] <= sched_prob:
|
| 645 |
-
inp = tgt_emb[:, each_idx, :].view(bsize, 1, -1)
|
| 646 |
-
else:
|
| 647 |
-
cur_logit = self.dec2vocab(tgt_emb_pred)
|
| 648 |
-
yi = torch.argmax(cur_logit, dim=2)
|
| 649 |
-
inp = self.tgt_embedding(yi)
|
| 650 |
-
tgt_emb_pred = self.decoder.forward_one_step(inp)
|
| 651 |
-
tgt_emb_pred_list.append(tgt_emb_pred)
|
| 652 |
-
vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
|
| 653 |
-
return vocab_logit
|
| 654 |
-
else:
|
| 655 |
-
with torch.no_grad():
|
| 656 |
-
tgt_emb = self.latent2tgt_emb(z)
|
| 657 |
-
tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
|
| 658 |
-
tgt_emb_pred_list = []
|
| 659 |
-
stack_list = []
|
| 660 |
-
hg_list = []
|
| 661 |
-
nt_symbol_list = []
|
| 662 |
-
nt_edge_list = []
|
| 663 |
-
gen_finish_list = []
|
| 664 |
-
for _ in range(bsize):
|
| 665 |
-
stack_list.append([])
|
| 666 |
-
hg_list.append(None)
|
| 667 |
-
nt_symbol_list.append(NTSymbol(degree=0,
|
| 668 |
-
is_aromatic=False,
|
| 669 |
-
bond_symbol_list=[]))
|
| 670 |
-
nt_edge_list.append(None)
|
| 671 |
-
gen_finish_list.append(False)
|
| 672 |
-
|
| 673 |
-
for _ in range(self.max_len):
|
| 674 |
-
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
|
| 675 |
-
tgt_emb_pred_list.append(tgt_emb_pred)
|
| 676 |
-
vocab_logit = self.dec2vocab(tgt_emb_pred)
|
| 677 |
-
for each_batch_idx in range(bsize):
|
| 678 |
-
if not gen_finish_list[each_batch_idx]: # if generation has not finished
|
| 679 |
-
# get production rule greedily
|
| 680 |
-
prod_rule = self.hrg.prod_rule_corpus.sample(vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
|
| 681 |
-
nt_symbol_list[each_batch_idx],
|
| 682 |
-
deterministic=deterministic)
|
| 683 |
-
# convert production rule into an index
|
| 684 |
-
tgt_id = self.hrg.prod_rule_list.index(prod_rule)
|
| 685 |
-
# apply the production rule
|
| 686 |
-
hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
|
| 687 |
-
# add non-terminals to the stack
|
| 688 |
-
stack_list[each_batch_idx].extend(nt_edges[::-1])
|
| 689 |
-
# if the stack size is 0, generation has finished!
|
| 690 |
-
if len(stack_list[each_batch_idx]) == 0:
|
| 691 |
-
gen_finish_list[each_batch_idx] = True
|
| 692 |
-
else:
|
| 693 |
-
nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
|
| 694 |
-
nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
|
| 695 |
-
else:
|
| 696 |
-
tgt_id = np.mod(self.padding_idx, self.vocab_size)
|
| 697 |
-
indice_tensor = torch.LongTensor([tgt_id])
|
| 698 |
-
if self.rank >= 0:
|
| 699 |
-
indice_tensor = indice_tensor.to(next(self.parameters()).device)
|
| 700 |
-
tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
|
| 701 |
-
vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
|
| 702 |
-
return gen_finish_list, vocab_logit, hg_list
|
| 703 |
-
|
| 704 |
-
#TODO Not tested. Need to implement this in case of molecular structure generation
|
| 705 |
-
def conditional_distribution(self, z, tgt_id_list):
|
| 706 |
-
self.eval()
|
| 707 |
-
self.init_hidden()
|
| 708 |
-
z = z.cuda()
|
| 709 |
-
|
| 710 |
-
hidden_dict_0 = {}
|
| 711 |
-
for each_hidden in self.latent2hidden_dict.keys():
|
| 712 |
-
hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
|
| 713 |
-
self.decoder.feed_hidden(hidden_dict_0)
|
| 714 |
-
|
| 715 |
-
with torch.no_grad():
|
| 716 |
-
tgt_emb = self.latent2tgt_emb(z)
|
| 717 |
-
tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
|
| 718 |
-
nt_symbol_list = []
|
| 719 |
-
stack_list = []
|
| 720 |
-
hg_list = []
|
| 721 |
-
nt_edge_list = []
|
| 722 |
-
gen_finish_list = []
|
| 723 |
-
for _ in range(self.batch_size):
|
| 724 |
-
nt_symbol_list.append(NTSymbol(degree=0,
|
| 725 |
-
is_aromatic=False,
|
| 726 |
-
bond_symbol_list=[]))
|
| 727 |
-
stack_list.append([])
|
| 728 |
-
hg_list.append(None)
|
| 729 |
-
nt_edge_list.append(None)
|
| 730 |
-
gen_finish_list.append(False)
|
| 731 |
-
|
| 732 |
-
for each_position in range(len(tgt_id_list[0])):
|
| 733 |
-
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
|
| 734 |
-
for each_batch_idx in range(self.batch_size):
|
| 735 |
-
if not gen_finish_list[each_batch_idx]: # if generation has not finished
|
| 736 |
-
# use the prespecified target ids
|
| 737 |
-
tgt_id = tgt_id_list[each_batch_idx][each_position]
|
| 738 |
-
prod_rule = self.hrg.prod_rule_list[tgt_id]
|
| 739 |
-
# apply the production rule
|
| 740 |
-
hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
|
| 741 |
-
# add non-terminals to the stack
|
| 742 |
-
stack_list[each_batch_idx].extend(nt_edges[::-1])
|
| 743 |
-
# if the stack size is 0, generation has finished!
|
| 744 |
-
if len(stack_list[each_batch_idx]) == 0:
|
| 745 |
-
gen_finish_list[each_batch_idx] = True
|
| 746 |
-
else:
|
| 747 |
-
nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
|
| 748 |
-
nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
|
| 749 |
-
else:
|
| 750 |
-
tgt_id = np.mod(self.padding_idx, self.vocab_size)
|
| 751 |
-
indice_tensor = torch.LongTensor([tgt_id])
|
| 752 |
-
indice_tensor = indice_tensor.cuda()
|
| 753 |
-
tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
|
| 754 |
-
|
| 755 |
-
# last one step
|
| 756 |
-
conditional_logprob_list = []
|
| 757 |
-
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
|
| 758 |
-
vocab_logit = self.dec2vocab(tgt_emb_pred)
|
| 759 |
-
for each_batch_idx in range(self.batch_size):
|
| 760 |
-
if not gen_finish_list[each_batch_idx]: # if generation has not finished
|
| 761 |
-
# get production rule greedily
|
| 762 |
-
masked_logprob = self.hrg.prod_rule_corpus.masked_logprob(
|
| 763 |
-
vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
|
| 764 |
-
nt_symbol_list[each_batch_idx])
|
| 765 |
-
conditional_logprob_list.append(masked_logprob)
|
| 766 |
-
else:
|
| 767 |
-
conditional_logprob_list.append(None)
|
| 768 |
-
return conditional_logprob_list
|
| 769 |
-
|
| 770 |
-
#TODO Not tested. Need to implement this in case of molecular structure generation
|
| 771 |
-
def decode_with_beam_search(self, z, beam_width=1):
|
| 772 |
-
''' Decode a latent vector using beam search.
|
| 773 |
-
|
| 774 |
-
Parameters
|
| 775 |
-
----------
|
| 776 |
-
z
|
| 777 |
-
latent vector
|
| 778 |
-
beam_width : int
|
| 779 |
-
parameter for beam search
|
| 780 |
-
|
| 781 |
-
Returns
|
| 782 |
-
-------
|
| 783 |
-
List of Hypergraphs
|
| 784 |
-
'''
|
| 785 |
-
if self.batch_size != 1:
|
| 786 |
-
raise ValueError('this method works only under batch_size=1')
|
| 787 |
-
if self.padding_idx != -1:
|
| 788 |
-
raise ValueError('this method works only under padding_idx=-1')
|
| 789 |
-
top_k_tgt_id_list = [[]] * beam_width
|
| 790 |
-
logprob_list = [0.] * beam_width
|
| 791 |
-
|
| 792 |
-
for each_len in range(self.max_len):
|
| 793 |
-
expanded_logprob_list = np.repeat(logprob_list, self.vocab_size) # including padding_idx
|
| 794 |
-
expanded_length_list = np.array([0] * (beam_width * self.vocab_size))
|
| 795 |
-
for each_beam_idx, each_candidate in enumerate(top_k_tgt_id_list):
|
| 796 |
-
conditional_logprob = self.conditional_distribution(z, [each_candidate])[0]
|
| 797 |
-
if conditional_logprob is None:
|
| 798 |
-
expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
|
| 799 |
-
= logprob_list[each_beam_idx]
|
| 800 |
-
expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
|
| 801 |
-
= -np.inf
|
| 802 |
-
expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
|
| 803 |
-
= len(each_candidate)
|
| 804 |
-
else:
|
| 805 |
-
expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
|
| 806 |
-
= logprob_list[each_beam_idx] + conditional_logprob
|
| 807 |
-
expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
|
| 808 |
-
= -np.inf
|
| 809 |
-
expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
|
| 810 |
-
= len(each_candidate) + 1
|
| 811 |
-
score_list = np.array(expanded_logprob_list) / np.array(expanded_length_list)
|
| 812 |
-
if each_len == 0:
|
| 813 |
-
top_k_list = np.argsort(score_list[:self.vocab_size])[::-1][:beam_width]
|
| 814 |
-
else:
|
| 815 |
-
top_k_list = np.argsort(score_list)[::-1][:beam_width]
|
| 816 |
-
next_top_k_tgt_id_list = []
|
| 817 |
-
next_logprob_list = []
|
| 818 |
-
for each_top_k in top_k_list:
|
| 819 |
-
beam_idx = each_top_k // self.vocab_size
|
| 820 |
-
vocab_idx = each_top_k % self.vocab_size
|
| 821 |
-
if vocab_idx == self.vocab_size - 1:
|
| 822 |
-
next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx])
|
| 823 |
-
next_logprob_list.append(expanded_logprob_list[each_top_k])
|
| 824 |
-
else:
|
| 825 |
-
next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx] + [vocab_idx])
|
| 826 |
-
next_logprob_list.append(expanded_logprob_list[each_top_k])
|
| 827 |
-
top_k_tgt_id_list = next_top_k_tgt_id_list
|
| 828 |
-
logprob_list = next_logprob_list
|
| 829 |
-
|
| 830 |
-
# construct hypergraphs
|
| 831 |
-
hg_list = []
|
| 832 |
-
for each_tgt_id_list in top_k_tgt_id_list:
|
| 833 |
-
hg = None
|
| 834 |
-
stack = []
|
| 835 |
-
nt_edge = None
|
| 836 |
-
for each_idx, each_prod_rule_id in enumerate(each_tgt_id_list):
|
| 837 |
-
prod_rule = self.hrg.prod_rule_list[each_prod_rule_id]
|
| 838 |
-
hg, nt_edges = prod_rule.applied_to(hg, nt_edge)
|
| 839 |
-
stack.extend(nt_edges[::-1])
|
| 840 |
-
try:
|
| 841 |
-
nt_edge = stack.pop()
|
| 842 |
-
except IndexError:
|
| 843 |
-
if each_idx == len(each_tgt_id_list) - 1:
|
| 844 |
-
break
|
| 845 |
-
else:
|
| 846 |
-
raise ValueError('some bugs')
|
| 847 |
-
hg_list.append(hg)
|
| 848 |
-
return hg_list
|
| 849 |
-
|
| 850 |
-
def graph_embed(self, x, edge_index, edge_attr, batch_size):
    ''' embed a batch of graphs with the encoder

    All four arguments are handed to `self.encoder.forward` unchanged and
    the encoder output is returned as is.

    Parameters
    ----------
    x, edge_index, edge_attr, batch_size
        inputs of the encoder; presumably node features, edge indices and
        edge features of a batched graph -- confirm against the encoder.

    Returns
    -------
    whatever `self.encoder.forward` returns (the hidden representation)
    '''
    return self.encoder.forward(x, edge_index, edge_attr, batch_size)
|
| 853 |
-
|
| 854 |
-
def encode(self, x, edge_index, edge_attr, batch_size):
    ''' encode a batch of graphs into the parameters of the latent distribution

    The graphs are first embedded by `self.graph_embed`, and the embedding
    is then converted by `self.get_mean_var`.

    Parameters
    ----------
    x, edge_index, edge_attr, batch_size
        forwarded unchanged to `self.graph_embed`

    Returns
    -------
    tuple (mean, log variance), as produced by `self.get_mean_var`
    '''
    mean, log_var = self.get_mean_var(
        self.graph_embed(x, edge_index, edge_attr, batch_size))
    return mean, log_var
|
| 860 |
-
|
| 861 |
-
def get_mean_var(self, src_h):
    ''' map a hidden representation to the mean and log variance

    Parameters
    ----------
    src_h : torch.Tensor
        hidden representation, fed to both `self.hidden2mean` and
        `self.hidden2logvar`

    Returns
    -------
    tuple (mu, lv)
        outputs of the two layers, each squashed by `tanh`
    '''
    # tanh is applied to the outputs of the two layers, not to `src_h` itself.
    mean = torch.tanh(self.hidden2mean(src_h))
    log_var = torch.tanh(self.hidden2logvar(src_h))
    return mean, log_var
|
| 868 |
-
|
| 869 |
-
def reparameterize(self, mu, logvar, training=True):
    ''' draw a latent vector with the reparameterization trick

    Parameters
    ----------
    mu : torch.Tensor
        mean of the normal distribution
    logvar : torch.Tensor
        log variance of the normal distribution
    training : bool
        if True, return `mu + eps * std` with `eps` sampled from the
        standard normal distribution; if False, return `mu` itself.

    Returns
    -------
    torch.Tensor
        sampled latent vector (or `mu` when `training` is False)
    '''
    if not training:
        return mu
    std = torch.exp(0.5 * logvar)
    # `randn_like` allocates the noise with the dtype and device of `std`,
    # so neither the deprecated `Variable` wrapper nor an explicit device
    # transfer is needed.
    eps = torch.randn_like(std)
    return eps * std + mu
|
| 878 |
-
|
| 879 |
-
# Copied from the MHG implementation and adapted
|
| 880 |
-
class GrammarVAELoss(_Loss):

    '''
    a loss function for Grammar VAE

    Attributes
    ----------
    hrg : HyperedgeReplacementGrammar
    beta : float
        coefficient of KL divergence
    rank : int
        stored as given; `forward` does not consult it, because the mask is
        placed on the device of the logits.
    '''

    def __init__(self, rank, hrg, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.hrg = hrg
        self.beta = beta
        self.rank = rank

    def forward(self, mu, logvar, in_seq_pred, in_seq):
        ''' compute VAE loss

        Parameters
        ----------
        in_seq_pred : torch.Tensor, shape (batch_size, max_len, vocab_size)
            logit
        in_seq : torch.Tensor, shape (batch_size, max_len)
            each element corresponds to a word id in vocabulary.
        mu : torch.Tensor, shape (batch_size, hidden_dim)
        logvar : torch.Tensor, shape (batch_size, hidden_dim)
            mean and log variance of the normal distribution

        Returns
        -------
        torch.Tensor
            summed cross entropy of the masked logits plus `beta` times the
            KL divergence
        '''
        batch_size = in_seq_pred.shape[0]
        max_len = in_seq_pred.shape[1]
        vocab_size = in_seq_pred.shape[2]
        last_idx = vocab_size - 1
        mask = torch.zeros(in_seq_pred.shape)

        # Convert the ids to Python ints once instead of reading one tensor
        # element per step.
        rule_ids = in_seq.tolist()
        for each_batch in range(batch_size):
            for each_idx in range(max_len):
                prod_rule_idx = rule_ids[each_batch][each_idx]
                if prod_rule_idx == last_idx:
                    # The last vocabulary id is special-cased (presumably the
                    # padding / end symbol): only its own logit is kept.
                    # TODO: determine whether this position should be skipped
                    # or not.
                    mask[each_batch, each_idx, prod_rule_idx] = 1
                    continue
                corpus = self.hrg.prod_rule_corpus
                lhs = corpus.prod_rule_list[prod_rule_idx].lhs_nt_symbol
                lhs_idx = corpus.nt_symbol_list.index(lhs)
                # presumably a 0/1 row marking the production rules that
                # share this left-hand side -- confirm against the corpus.
                mask[each_batch, each_idx, :-1] = torch.FloatTensor(
                    corpus.lhs_in_prod_rule[lhs_idx])

        # A loss module registers no parameters, so the device is taken from
        # the logits rather than from `next(self.parameters())`, which raises
        # StopIteration here.
        mask = mask.to(in_seq_pred.device)
        # NOTE(review): the mask multiplies the logits, so masked entries
        # become 0 rather than -inf and still receive probability mass.
        in_seq_pred = mask * in_seq_pred

        cross_entropy = F.cross_entropy(
            in_seq_pred.view(-1, vocab_size),
            in_seq.reshape(-1),
            reduction='sum',
        )
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return cross_entropy + self.beta * kl_div
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
class VAELoss(_Loss):

    '''
    a loss function for a VAE: summed cross entropy plus weighted KL divergence

    Attributes
    ----------
    beta : float
        coefficient of KL divergence
    '''

    def __init__(self, beta=0.01):
        super().__init__()
        self.beta = beta

    def forward(self, mean, log_var, dec_outputs, targets):
        ''' compute VAE loss

        Parameters
        ----------
        mean : torch.Tensor
        log_var : torch.Tensor
            mean and log variance of the normal distribution
        dec_outputs : torch.Tensor
            logit; flattened to rows of size `dec_outputs.size(2)`
        targets : torch.Tensor
            target ids; flattened to one dimension

        Returns
        -------
        torch.Tensor
            reconstruction loss plus `beta` times the KL divergence
        '''
        # `get_device` returns a negative index for CPU tensors; the targets
        # are moved only when `mean` reports a non-negative device index.
        device_index = mean.get_device()
        if device_index >= 0:
            targets = targets.to(device_index)

        flat_logits = dec_outputs.view(-1, dec_outputs.size(2))
        flat_targets = targets.view(-1)
        reconstruction = F.cross_entropy(flat_logits, flat_targets, reduction='sum')

        # This sum is the negative KL divergence, hence the subtraction below.
        neg_kl = 0.5 * torch.sum(1 + log_var - mean ** 2 - torch.exp(log_var))
        return reconstruction - self.beta * neg_kl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
notebooks/mhg-gnn_encoder_decoder_example.ipynb
DELETED
|
@@ -1,114 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": null,
|
| 6 |
-
"id": "829ddc03",
|
| 7 |
-
"metadata": {},
|
| 8 |
-
"outputs": [],
|
| 9 |
-
"source": [
|
| 10 |
-
"import sys\n",
|
| 11 |
-
"sys.path.append('..')"
|
| 12 |
-
]
|
| 13 |
-
},
|
| 14 |
-
{
|
| 15 |
-
"cell_type": "code",
|
| 16 |
-
"execution_count": null,
|
| 17 |
-
"id": "ea820e23",
|
| 18 |
-
"metadata": {},
|
| 19 |
-
"outputs": [],
|
| 20 |
-
"source": [
|
| 21 |
-
"import torch\n",
|
| 22 |
-
"import load"
|
| 23 |
-
]
|
| 24 |
-
},
|
| 25 |
-
{
|
| 26 |
-
"cell_type": "markdown",
|
| 27 |
-
"id": "b9a51fa8",
|
| 28 |
-
"metadata": {},
|
| 29 |
-
"source": [
|
| 30 |
-
"# Load MHG-GNN"
|
| 31 |
-
]
|
| 32 |
-
},
|
| 33 |
-
{
|
| 34 |
-
"cell_type": "code",
|
| 35 |
-
"execution_count": null,
|
| 36 |
-
"id": "c6ea1fc8",
|
| 37 |
-
"metadata": {},
|
| 38 |
-
"outputs": [],
|
| 39 |
-
"source": [
|
| 40 |
-
"model_ckp = \"models/model_checkpoints/mhg_model/pickles/mhggnn_pretrained_model_radius7_1116_2023.pickle\"\n",
|
| 41 |
-
"\n",
|
| 42 |
-
"model = load.load(model_name = model_ckp)\n",
|
| 43 |
-
"if model is None:\n",
|
| 44 |
-
" print(\"Model not loaded, please check you have MHG pickle file\")\n",
|
| 45 |
-
"else:\n",
|
| 46 |
-
" print(\"MHG model loaded\")"
|
| 47 |
-
]
|
| 48 |
-
},
|
| 49 |
-
{
|
| 50 |
-
"cell_type": "markdown",
|
| 51 |
-
"id": "b4a0b557",
|
| 52 |
-
"metadata": {},
|
| 53 |
-
"source": [
|
| 54 |
-
"# Embeddings\n",
|
| 55 |
-
"\n",
|
| 56 |
-
"※ Replace the SMILES example list with your dataset"
|
| 57 |
-
]
|
| 58 |
-
},
|
| 59 |
-
{
|
| 60 |
-
"cell_type": "code",
|
| 61 |
-
"execution_count": null,
|
| 62 |
-
"id": "c63a6be6",
|
| 63 |
-
"metadata": {},
|
| 64 |
-
"outputs": [],
|
| 65 |
-
"source": [
|
| 66 |
-
"with torch.no_grad():\n",
|
| 67 |
-
" repr = model.encode([\"CCO\", \"O=C=O\", \"OC(=O)c1ccccc1C(=O)O\"])\n",
|
| 68 |
-
" \n",
|
| 69 |
-
"# Print the latent vectors\n",
|
| 70 |
-
"print(repr)"
|
| 71 |
-
]
|
| 72 |
-
},
|
| 73 |
-
{
|
| 74 |
-
"cell_type": "markdown",
|
| 75 |
-
"id": "a59f9442",
|
| 76 |
-
"metadata": {},
|
| 77 |
-
"source": [
|
| 78 |
-
"# Decoding"
|
| 79 |
-
]
|
| 80 |
-
},
|
| 81 |
-
{
|
| 82 |
-
"cell_type": "code",
|
| 83 |
-
"execution_count": null,
|
| 84 |
-
"id": "6a0d8a41",
|
| 85 |
-
"metadata": {},
|
| 86 |
-
"outputs": [],
|
| 87 |
-
"source": [
|
| 88 |
-
"orig = model.decode(repr)\n",
|
| 89 |
-
"print(orig)"
|
| 90 |
-
]
|
| 91 |
-
}
|
| 92 |
-
],
|
| 93 |
-
"metadata": {
|
| 94 |
-
"kernelspec": {
|
| 95 |
-
"display_name": "Python 3 (ipykernel)",
|
| 96 |
-
"language": "python",
|
| 97 |
-
"name": "python3"
|
| 98 |
-
},
|
| 99 |
-
"language_info": {
|
| 100 |
-
"codemirror_mode": {
|
| 101 |
-
"name": "ipython",
|
| 102 |
-
"version": 3
|
| 103 |
-
},
|
| 104 |
-
"file_extension": ".py",
|
| 105 |
-
"mimetype": "text/x-python",
|
| 106 |
-
"name": "python",
|
| 107 |
-
"nbconvert_exporter": "python",
|
| 108 |
-
"pygments_lexer": "ipython3",
|
| 109 |
-
"version": "3.7.10"
|
| 110 |
-
}
|
| 111 |
-
},
|
| 112 |
-
"nbformat": 4,
|
| 113 |
-
"nbformat_minor": 5
|
| 114 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf
DELETED
|
Binary file (343 kB)
|
|
|
pickles/.DS_Store
DELETED
|
Binary file (6.15 kB)
|
|
|