# Training entry script: builds CIFAR-10 datasets and loaders, instantiates
# ColorNet with an MSE loss and an Adam optimizer, then calls train_model.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

from model import ColorNet

# Settings used below, named once so they are easy to locate and adjust.
DATA_ROOT = './data'
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10

# ToTensor is the only transform applied to either split.
transform = transforms.Compose([transforms.ToTensor()])

# Both splits share the same root directory and transform; each call passes
# download=True.
train_dataset = datasets.CIFAR10(
    root=DATA_ROOT,
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.CIFAR10(
    root=DATA_ROOT,
    train=False,
    transform=transform,
    download=True,
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = ColorNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE(review): train_model is looked up on the ColorNet instance and `model`
# is also passed explicitly as the first argument -- confirm against model.py
# that this matches the intended signature.
model.train_model(
    model,
    train_loader,
    criterion,
    optimizer,
    num_epochs=NUM_EPOCHS,
)