imwithye commited on
Commit
54c4741
·
1 Parent(s): e91cbff

create dataset if not exist

Browse files
Files changed (1) hide show
  1. rlcube/rlcube/train/train.py +4 -1
rlcube/rlcube/train/train.py CHANGED
@@ -1,8 +1,9 @@
1
- from rlcube.models.dataset import Cube2Dataset
2
  from rlcube.models.models import Reward, DNN
3
  from tqdm import tqdm
4
  from torch.utils.data import DataLoader
5
  import torch
 
6
 
7
  if torch.backends.mps.is_available():
8
  device = torch.device("mps")
@@ -15,6 +16,8 @@ print(f"Using device: {device}")
15
 
16
 
17
  def train(epochs: int = 100):
 
 
18
  dataset = Cube2Dataset("dataset.pt")
19
  print("Number of samples:", len(dataset))
20
  print("Number of epochs:", epochs)
 
1
+ from rlcube.models.dataset import Cube2Dataset, create_dataset
2
  from rlcube.models.models import Reward, DNN
3
  from tqdm import tqdm
4
  from torch.utils.data import DataLoader
5
  import torch
6
+ import os
7
 
8
  if torch.backends.mps.is_available():
9
  device = torch.device("mps")
 
16
 
17
 
18
  def train(epochs: int = 100):
19
+ if not os.path.exists("dataset.pt"):
20
+ create_dataset(num_envs=10000, num_steps=20, filepath="dataset.pt")
21
  dataset = Cube2Dataset("dataset.pt")
22
  print("Number of samples:", len(dataset))
23
  print("Number of epochs:", epochs)