Spaces:
Sleeping
Sleeping
generate dataset
Browse files- rlcube/.gitignore +3 -1
- rlcube/rlcube/models/dataset.py +7 -5
rlcube/.gitignore
CHANGED
|
@@ -219,4 +219,6 @@ __marimo__/
|
|
| 219 |
*.blend1
|
| 220 |
|
| 221 |
# Dataset
|
| 222 |
-
dataset.pt
|
|
|
|
|
|
|
|
|
| 219 |
*.blend1
|
| 220 |
|
| 221 |
# Dataset
|
| 222 |
+
dataset.pt
|
| 223 |
+
dataset_*.pt
|
| 224 |
+
!dataset_final.pt
|
rlcube/rlcube/models/dataset.py
CHANGED
|
@@ -5,7 +5,9 @@ import torch
|
|
| 5 |
from tqdm import tqdm
|
| 6 |
|
| 7 |
|
| 8 |
-
def create_dataset(
|
|
|
|
|
|
|
| 9 |
states = []
|
| 10 |
neighbors = []
|
| 11 |
D = []
|
|
@@ -22,11 +24,11 @@ def create_dataset(num_envs: int = 10000, num_steps: int = 50):
|
|
| 22 |
neighbors = np.array(neighbors)
|
| 23 |
D = np.array(D)
|
| 24 |
dataseet = {
|
| 25 |
-
"states": torch.tensor(states),
|
| 26 |
-
"neighbors": torch.tensor(neighbors),
|
| 27 |
-
"D": torch.tensor(D),
|
| 28 |
}
|
| 29 |
-
torch.save(dataseet,
|
| 30 |
|
| 31 |
|
| 32 |
class Cube2Dataset(Dataset):
|
|
|
|
| 5 |
from tqdm import tqdm
|
| 6 |
|
| 7 |
|
| 8 |
+
def create_dataset(
|
| 9 |
+
num_envs: int = 10000, num_steps: int = 50, filepath: str = "dataset.pt"
|
| 10 |
+
):
|
| 11 |
states = []
|
| 12 |
neighbors = []
|
| 13 |
D = []
|
|
|
|
| 24 |
neighbors = np.array(neighbors)
|
| 25 |
D = np.array(D)
|
| 26 |
dataseet = {
|
| 27 |
+
"states": torch.tensor(states, dtype=torch.float32),
|
| 28 |
+
"neighbors": torch.tensor(neighbors, dtype=torch.float32),
|
| 29 |
+
"D": torch.tensor(D, dtype=torch.float32),
|
| 30 |
}
|
| 31 |
+
torch.save(dataseet, filepath)
|
| 32 |
|
| 33 |
|
| 34 |
class Cube2Dataset(Dataset):
|