imwithye commited on
Commit
8b59b01
·
1 Parent(s): f02352c

implement search

Browse files
Files changed (2) hide show
  1. rlcube/cube2.ipynb +167 -58
  2. rlcube/rlcube/models/search.py +95 -0
rlcube/cube2.ipynb CHANGED
@@ -40,7 +40,6 @@
40
  "source": [
41
  "from rlcube.models.models import DNN\n",
42
  "from rlcube.envs.cube2 import Cube2Env\n",
43
- "import numpy as np\n",
44
  "import torch\n",
45
  "\n",
46
  "net = DNN()\n",
@@ -50,81 +49,191 @@
50
  },
51
  {
52
  "cell_type": "code",
53
- "execution_count": 9,
54
  "id": "16736f3a",
55
  "metadata": {},
56
  "outputs": [
57
  {
58
- "name": "stdout",
59
  "output_type": "stream",
60
  "text": [
61
- "rotationController.setState([[0, 0, 4, 4], [1, 1, 5, 5], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 1, 1], [5, 5, 0, 0]]);\n",
62
- "0.40487873554229736\n",
63
- "4\n",
64
- "\n",
65
- "rotationController.setState([[0, 4, 0, 4], [1, 1, 5, 5], [2, 5, 2, 0], [3, 4, 3, 1], [4, 2, 1, 2], [5, 3, 0, 3]]);\n",
66
- "0.0839405208826065\n",
67
- "7\n",
68
- "\n",
69
- "rotationController.setState([[0, 4, 0, 4], [5, 1, 5, 1], [1, 5, 4, 0], [0, 4, 5, 1], [3, 2, 3, 2], [2, 3, 2, 3]]);\n",
70
- "-0.23320673406124115\n",
71
- "3\n",
72
- "\n",
73
- "rotationController.setState([[0, 5, 0, 1], [5, 4, 5, 0], [1, 5, 4, 4], [0, 4, 1, 1], [3, 3, 2, 2], [2, 3, 2, 3]]);\n",
74
- "0.31869572401046753\n",
75
- "0\n",
76
- "\n",
77
- "rotationController.setState([[5, 5, 1, 1], [4, 4, 0, 0], [5, 5, 4, 4], [0, 0, 1, 1], [3, 3, 2, 2], [3, 3, 2, 2]]);\n",
78
- "-0.16905824840068817\n",
79
- "7\n",
80
- "\n",
81
- "rotationController.setState([[5, 4, 1, 4], [4, 1, 0, 1], [5, 5, 4, 0], [0, 0, 5, 1], [3, 2, 3, 2], [3, 3, 2, 2]]);\n",
82
- "0.20266102254390717\n",
83
- "3\n",
84
- "\n",
85
- "rotationController.setState([[2, 3, 1, 4], [3, 3, 0, 1], [5, 5, 4, 0], [0, 1, 0, 5], [4, 1, 3, 2], [5, 4, 2, 2]]);\n",
86
- "0.6111429333686829\n",
87
- "3\n",
88
- "\n",
89
- "rotationController.setState([[2, 0, 1, 4], [3, 5, 0, 0], [5, 5, 3, 1], [0, 1, 3, 4], [1, 2, 4, 3], [5, 4, 2, 2]]);\n",
90
- "1.3550236225128174\n",
91
- "2\n",
92
- "\n",
93
- "rotationController.setState([[0, 0, 1, 4], [5, 5, 5, 0], [1, 2, 3, 1], [0, 3, 3, 4], [1, 2, 4, 3], [2, 5, 2, 4]]);\n",
94
- "0.9975889325141907\n",
95
- "7\n",
96
- "\n",
97
- "rotationController.setState([[2, 0, 1, 4], [3, 5, 0, 0], [5, 5, 3, 1], [0, 1, 3, 4], [1, 2, 4, 3], [5, 4, 2, 2]]);\n",
98
- "1.3550236225128174\n",
99
- "2\n",
100
- "\n"
101
  ]
102
  }
103
  ],
104
  "source": [
105
- "batch_obs = []\n",
106
- "env = Cube2Env()\n",
107
- "for _ in range(10):\n",
108
- " obs, _, _, _, _ = env.step(env.action_space.sample())\n",
109
- " batch_obs.append(torch.tensor(obs, dtype=torch.float32))\n",
110
- "batched_obs = torch.stack(batch_obs)\n",
111
- "out = net(batched_obs)\n",
112
  "\n",
113
- "for i in range(10):\n",
114
- " env = Cube2Env.from_obs(batch_obs[i])\n",
115
- " env.print_js_code()\n",
116
- " print(out[\"value\"][i].item())\n",
117
- " print(torch.argmax(out[\"policy\"][i]).item())\n",
118
- " print()"
 
119
  ]
120
  },
121
  {
122
  "cell_type": "code",
123
- "execution_count": null,
124
  "id": "aee2a911",
125
  "metadata": {},
126
  "outputs": [],
127
- "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  }
129
  ],
130
  "metadata": {
 
40
  "source": [
41
  "from rlcube.models.models import DNN\n",
42
  "from rlcube.envs.cube2 import Cube2Env\n",
 
43
  "import torch\n",
44
  "\n",
45
  "net = DNN()\n",
 
49
  },
50
  {
51
  "cell_type": "code",
52
+ "execution_count": 2,
53
  "id": "16736f3a",
54
  "metadata": {},
55
  "outputs": [
56
  {
57
+ "name": "stderr",
58
  "output_type": "stream",
59
  "text": [
60
+ "100%|██████████| 300/300 [00:02<00:00, 132.06it/s]\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  ]
62
  }
63
  ],
64
  "source": [
65
+ "from rlcube.models.search import MonteCarloTree\n",
 
 
 
 
 
 
66
  "\n",
67
+ "env = Cube2Env()\n",
68
+ "actions = []\n",
69
+ "for _ in range(3):\n",
70
+ " action = env.action_space.sample()\n",
71
+ " actions.append(action)\n",
72
+ " env.step(action)\n",
73
+ "tree = MonteCarloTree(env.obs())"
74
  ]
75
  },
76
  {
77
  "cell_type": "code",
78
+ "execution_count": 3,
79
  "id": "aee2a911",
80
  "metadata": {},
81
  "outputs": [],
82
+ "source": [
83
+ "node = tree.root"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": 4,
89
+ "id": "048f58c9",
90
+ "metadata": {},
91
+ "outputs": [
92
+ {
93
+ "data": {
94
+ "text/plain": [
95
+ "[np.int64(8), np.int64(1), np.int64(4)]"
96
+ ]
97
+ },
98
+ "execution_count": 4,
99
+ "metadata": {},
100
+ "output_type": "execute_result"
101
+ }
102
+ ],
103
+ "source": [
104
+ "actions"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 5,
110
+ "id": "00994021",
111
+ "metadata": {},
112
+ "outputs": [
113
+ {
114
+ "data": {
115
+ "text/plain": [
116
+ "tensor([3.4725e+00, 3.3189e+00, 1.2619e-02, 3.1231e-01, 1.1286e-02, 2.5817e-02,\n",
117
+ " 1.6722e-02, 2.1334e-02, 3.4603e+00, 7.5021e-02, 2.5891e-02, 2.8712e-03])"
118
+ ]
119
+ },
120
+ "execution_count": 5,
121
+ "metadata": {},
122
+ "output_type": "execute_result"
123
+ }
124
+ ],
125
+ "source": [
126
+ "node.u()"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": 6,
132
+ "id": "fb9ac54c",
133
+ "metadata": {},
134
+ "outputs": [
135
+ {
136
+ "data": {
137
+ "text/plain": [
138
+ "defaultdict(<function rlcube.models.search.Node.__init__.<locals>.<lambda>()>,\n",
139
+ " {0: 276,\n",
140
+ " 1: 7,\n",
141
+ " 2: 0,\n",
142
+ " 3: 0,\n",
143
+ " 4: 0,\n",
144
+ " 5: 0,\n",
145
+ " 6: 0,\n",
146
+ " 7: 0,\n",
147
+ " 8: 16,\n",
148
+ " 9: 0,\n",
149
+ " 10: 0,\n",
150
+ " 11: 0})"
151
+ ]
152
+ },
153
+ "execution_count": 6,
154
+ "metadata": {},
155
+ "output_type": "execute_result"
156
+ }
157
+ ],
158
+ "source": [
159
+ "node.N"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": 7,
165
+ "id": "2f8a09d1",
166
+ "metadata": {},
167
+ "outputs": [
168
+ {
169
+ "data": {
170
+ "text/plain": [
171
+ "defaultdict(<function rlcube.models.search.Node.__init__.<locals>.<lambda>()>,\n",
172
+ " {0: tensor([3.4720]),\n",
173
+ " 1: tensor([1.8959]),\n",
174
+ " 2: 0,\n",
175
+ " 3: 0,\n",
176
+ " 4: 0,\n",
177
+ " 5: 0,\n",
178
+ " 6: 0,\n",
179
+ " 7: 0,\n",
180
+ " 8: tensor([2.7285]),\n",
181
+ " 9: 0,\n",
182
+ " 10: 0,\n",
183
+ " 11: 0})"
184
+ ]
185
+ },
186
+ "execution_count": 7,
187
+ "metadata": {},
188
+ "output_type": "execute_result"
189
+ }
190
+ ],
191
+ "source": [
192
+ "node.W"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": 8,
198
+ "id": "3e341459",
199
+ "metadata": {},
200
+ "outputs": [
201
+ {
202
+ "data": {
203
+ "text/plain": [
204
+ "defaultdict(<function rlcube.models.search.Node.__init__.<locals>.<lambda>()>,\n",
205
+ " {0: 4,\n",
206
+ " 1: 0,\n",
207
+ " 2: 0,\n",
208
+ " 3: 0,\n",
209
+ " 4: 0,\n",
210
+ " 5: 2,\n",
211
+ " 6: 0,\n",
212
+ " 7: 0,\n",
213
+ " 8: 269,\n",
214
+ " 9: 0,\n",
215
+ " 10: 0,\n",
216
+ " 11: 0})"
217
+ ]
218
+ },
219
+ "execution_count": 8,
220
+ "metadata": {},
221
+ "output_type": "execute_result"
222
+ }
223
+ ],
224
+ "source": [
225
+ "node.children[0].N"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "id": "51dddf56",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "node.children[8].N"
236
+ ]
237
  }
238
  ],
239
  "metadata": {
rlcube/rlcube/models/search.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import torch
3
+ from rlcube.models.models import DNN
4
+ from rlcube.envs.cube2 import Cube2Env
5
+ from tqdm import tqdm
6
+
7
+ net = DNN()
8
+ net.load("models/model_best.pth")
9
+ net.eval()
10
+
11
+
12
+ class Node:
13
+ def __init__(self, obs, parent=None):
14
+ self.obs = torch.tensor(obs, dtype=torch.float32)
15
+ self.parent = parent
16
+
17
+ out = net(self.obs.unsqueeze(0))
18
+ value = out["value"].detach()
19
+ policy = torch.softmax(out["policy"].detach(), dim=1)
20
+
21
+ self.is_solved = Cube2Env.from_obs(obs).is_solved()
22
+ self.value = torch.tensor(1) if self.is_solved else value.view(-1)
23
+ self.policy = policy.view(-1)
24
+
25
+ self.children = defaultdict(lambda: None)
26
+ self.N = defaultdict(lambda: 0)
27
+ self.W = defaultdict(lambda: 0)
28
+
29
+ def is_leaf(self):
30
+ return len(self.children) == 0
31
+
32
+ def u(self):
33
+ c = 1.414
34
+ n_sum = torch.sum(torch.tensor([self.N[action] for action in range(12)]))
35
+ u = torch.tensor(
36
+ [
37
+ c
38
+ * self.policy[action].item()
39
+ * torch.sqrt(n_sum)
40
+ / (self.N[action] + 1)
41
+ + self.W[action]
42
+ for action in range(12)
43
+ ]
44
+ )
45
+ return u
46
+
47
+ def select_action(self):
48
+ return torch.argmax(self.u()).item()
49
+
50
+
51
+ class MonteCarloTree:
52
+ def __init__(self, obs, max_simulations=300):
53
+ self.obs = obs
54
+ self.max_simulations = max_simulations
55
+ self.root = Node(obs)
56
+ self.nodes = [self.root]
57
+ self.is_solved = False
58
+ self._build()
59
+
60
+ def _build(self):
61
+ for _ in tqdm(range(self.max_simulations)):
62
+ if self.is_solved:
63
+ break
64
+
65
+ node = self.root
66
+ path = []
67
+
68
+ # Selection
69
+ while not node.is_leaf():
70
+ action = node.select_action()
71
+ path.append((node, action))
72
+ node = node.children[action]
73
+
74
+ # Expansion
75
+ env = Cube2Env.from_obs(node.obs)
76
+ adjacent_obs = env.adjacent_obs()
77
+ for i in range(12):
78
+ obs = adjacent_obs[i]
79
+ child = Node(obs, node)
80
+ node.children[i] = child
81
+ self.nodes.append(child)
82
+ self.is_solved = self.is_solved or child.is_solved
83
+
84
+ # Backup
85
+ for parent, action in reversed(path):
86
+ parent.N[action] += 1
87
+ parent.W[action] = max(parent.W[action], node.value)
88
+
89
+
90
+ if __name__ == "__main__":
91
+ env = Cube2Env()
92
+ for _ in range(3):
93
+ env.step(env.action_space.sample())
94
+ tree = MonteCarloTree(env.obs())
95
+ print(tree.is_solved)