# Ten-class classification (CIFAR-10): model training script.
import torchvision
from torch.utils.data import DataLoader

# --- Datasets ---------------------------------------------------------------
# Both splits live under the same root and are downloaded if not present.
# ToTensor turns each PIL image into a float tensor.
transform = torchvision.transforms.ToTensor()
train_data, test_data = (
    torchvision.datasets.CIFAR10("dataset/cifar10", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# --- Loaders ----------------------------------------------------------------
# Shuffle only the training split; evaluation order stays fixed.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

# Split sizes are reused later to turn correct-prediction counts into accuracy.
train_data_size, test_data_size = len(train_data), len(test_data)
print("训练数据集大小:", train_data_size)
print("测试数据集大小:", test_data_size)
# Detect Apple-silicon (Metal Performance Shaders) GPU support.
# `torch.backends.mps` only exists in newer PyTorch releases; on older ones the
# import fails, so fall back to "not available" (CPU) instead of crashing.
try:
    import torch.backends.mps
    mps_available = torch.backends.mps.is_available()
except ImportError:
    mps_available = False
# --- Model definition ---------------------------------------------------------
import torch


class Net(torch.nn.Module):
    """Small CNN that maps 3-channel images to 10 class logits.

    Three Conv2d(kernel 5, stride 1, padding 2) + MaxPool2d(2) stages keep the
    spatial size through each conv and halve it at each pool. The first
    Linear layer expects 64*4*4 features, which implies 3x32x32 inputs
    (32 -> 16 -> 8 -> 4).

    NOTE(review): there are no activation functions between layers; max-pooling
    is the only non-linearity. Kept as-is so existing checkpoints still match.
    """

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 5, 1, 2),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(32, 32, 5, 1, 2),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(32, 64, 5, 1, 2),
            torch.nn.MaxPool2d(2),
            torch.nn.Flatten(),
            torch.nn.Linear(64 * 4 * 4, 64),
            torch.nn.Linear(64, 10),
        )

    def forward(self, x):
        """Return raw (un-normalized) logits of shape (N, 10) for a batch x of shape (N, 3, 32, 32)."""
        return self.model(x)
# --- Model, loss and optimizer ---------------------------------------------------
model = Net()
loss_nn = torch.nn.CrossEntropyLoss()
if mps_available:
    # Put the model (and, for symmetry, the criterion, which has no learnable
    # parameters) on the same device the batches are sent to.
    model, loss_nn = model.to("mps"), loss_nn.to("mps")

# Plain SGD over every model parameter; the model is already on its final
# device at this point.
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Global step counters (x-axis for logging) and the number of training epochs.
total_train_step = total_test_step = 0
epochs = 10  # number of full passes over the training set

# TensorBoard event files are written under ./cifar10_logs.
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("cifar10_logs")
# --- Training / evaluation loop ------------------------------------------------
# Batches go to the same device the model was placed on ("cpu" is a no-op move).
device = "mps" if mps_available else "cpu"
for epoch in range(epochs):
    # Training phase: one pass over the shuffled training split.
    model.train()
    for inputs, labels in train_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = loss_nn(outputs, labels)

        # Optimisation step: clear stale gradients, backpropagate, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            print("训练次数:", total_train_step, "Loss:", loss.item())
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # Evaluation phase: eval mode, no gradient tracking.
    model.eval()
    total_test_loss = 0.0  # sum of per-batch mean losses over the test split
    total_correct = 0      # number of correctly classified test samples
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_test_loss += loss_nn(outputs, labels).item()
            # .item() yields a Python int so the accuracy prints as a plain number
            # rather than a device tensor.
            total_correct += (outputs.argmax(1) == labels).sum().item()

    # One evaluation pass per epoch counts as one test step (it was previously
    # incremented once per test batch while being reported once per epoch).
    total_test_step += 1
    accuracy = total_correct / test_data_size * 100
    print("测试次数:", total_test_step, "正确率:", accuracy, "%")
    print("测试次数:", total_test_step, "Loss:", total_test_loss)
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", accuracy, total_test_step)

    # Save a full pickled model snapshot for this epoch.
    # NOTE(review): saving model.state_dict() would be more portable (loading a
    # pickled module needs the Net class importable).
    torch.save(model, f"cifar10_model_{epoch}.pth")

# Flush any buffered TensorBoard events to disk.
writer.close()