imwithye commited on
Commit
4883d24
·
1 Parent(s): 156f4d5

generate dataset

Browse files
rlcube/.gitignore CHANGED
@@ -219,4 +219,6 @@ __marimo__/
219
  *.blend1
220
 
221
  # Dataset
222
- dataset.pt
 
 
 
219
  *.blend1
220
 
221
  # Dataset
222
+ dataset.pt
223
+ dataset_*.pt
224
+ !dataset_final.pt
rlcube/rlcube/models/dataset.py CHANGED
@@ -5,7 +5,9 @@ import torch
5
  from tqdm import tqdm
6
 
7
 
8
- def create_dataset(num_envs: int = 10000, num_steps: int = 50):
 
 
9
  states = []
10
  neighbors = []
11
  D = []
@@ -22,11 +24,11 @@ def create_dataset(num_envs: int = 10000, num_steps: int = 50):
22
  neighbors = np.array(neighbors)
23
  D = np.array(D)
24
  dataseet = {
25
- "states": torch.tensor(states),
26
- "neighbors": torch.tensor(neighbors),
27
- "D": torch.tensor(D),
28
  }
29
- torch.save(dataseet, "dataset.pt")
30
 
31
 
32
  class Cube2Dataset(Dataset):
 
5
  from tqdm import tqdm
6
 
7
 
8
+ def create_dataset(
9
+ num_envs: int = 10000, num_steps: int = 50, filepath: str = "dataset.pt"
10
+ ):
11
  states = []
12
  neighbors = []
13
  D = []
 
24
  neighbors = np.array(neighbors)
25
  D = np.array(D)
26
  dataseet = {
27
+ "states": torch.tensor(states, dtype=torch.float32),
28
+ "neighbors": torch.tensor(neighbors, dtype=torch.float32),
29
+ "D": torch.tensor(D, dtype=torch.float32),
30
  }
31
+ torch.save(dataseet, filepath)
32
 
33
 
34
  class Cube2Dataset(Dataset):