imwithye committed on
Commit
156f4d5
·
1 Parent(s): 6537541

add reward net

Browse files
Files changed (1) hide show
  1. rlcube/rlcube/models/models.py +32 -0
rlcube/rlcube/models/models.py CHANGED
@@ -2,6 +2,27 @@ import torch.nn as nn
2
  import torch.nn.functional as F
3
  import torch
4
  from tensordict import TensorDict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  class ResidualBlock(nn.Module):
@@ -54,10 +75,21 @@ class DNN(nn.Module):
54
 
55
 
56
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
57
  print("Testing ResidualBlock, input_dim=24, hidden_dim=128")
58
  x = torch.randn(4, 24, 6)
59
  print("Input shape:", x.shape)
60
  print("Output shape:", ResidualBlock(6, 128)(x).shape)
 
61
 
62
  print("Testing Cube2VNetwork, input_dim=24, num_residual_blocks=4")
63
  x = torch.randn(4, 24, 6)
 
2
  import torch.nn.functional as F
3
  import torch
4
  from tensordict import TensorDict
5
+ from rlcube.envs.cube2 import Cube2
6
+ import numpy as np
7
+
8
+
9
class RewardNet(nn.Module):
    """Parameter-free module that scores a batch of cube observations.

    An observation is rewarded with 1 when, on each of the 6 faces, all 4
    stickers share the same argmax index along the last axis; otherwise the
    reward is -1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, batch_obs):
        """Return one reward per batch entry.

        Args:
            batch_obs: 3-D tensor whose axis 1 holds 24 stickers (it is viewed
                as 6 faces x 4 stickers) and whose axis 2 is reduced with
                ``argmax`` to a single index per sticker.

        Returns:
            Tensor of shape ``(batch,)`` with the dtype and device of
            ``batch_obs``: 1 where every face is uniform, -1 elsewhere.
        """
        batch_size = batch_obs.shape[0]

        # Index of the largest entry per sticker: (batch, 24).
        sticker_ids = batch_obs.argmax(dim=2)
        # Group the stickers by face: (batch, 24) -> (batch, 6, 4).
        per_face = sticker_ids.view(batch_size, 6, 4)
        # A sticker matches when it equals the first sticker of its own face.
        matches_first = per_face.eq(per_face[:, :, :1])
        # Solved only if every sticker on every face matches: (batch,).
        is_solved = matches_first.flatten(start_dim=1).all(dim=1)

        # Build the two reward scalars alongside the input so torch.where
        # yields a result with the same dtype and device as batch_obs.
        scalar_opts = {"device": batch_obs.device, "dtype": batch_obs.dtype}
        solved_reward = torch.tensor(1, **scalar_opts)
        unsolved_reward = torch.tensor(-1, **scalar_opts)
        return torch.where(is_solved, solved_reward, unsolved_reward)
26
 
27
 
28
  class ResidualBlock(nn.Module):
 
75
 
76
 
77
  if __name__ == "__main__":
78
+ print("Testing RewardNet")
79
+ env = Cube2()
80
+ obs, _ = env.reset()
81
+ obs1, _, _, _, _ = env.step(1)
82
+ obs2, _, _, _, _ = env.step(2)
83
+ x = torch.tensor(np.array([obs, obs1, obs2]))
84
+ print("Input shape:", x.shape)
85
+ print("Output:", RewardNet()(x))
86
+ print()
87
+
88
  print("Testing ResidualBlock, input_dim=24, hidden_dim=128")
89
  x = torch.randn(4, 24, 6)
90
  print("Input shape:", x.shape)
91
  print("Output shape:", ResidualBlock(6, 128)(x).shape)
92
+ print()
93
 
94
  print("Testing Cube2VNetwork, input_dim=24, num_residual_blocks=4")
95
  x = torch.randn(4, 24, 6)