模型训练基本流程


# 十分类问题 模型训练

import torchvision

# Prepare the datasets.
# One shared transform converts every image to a tensor for both splits.
transform = torchvision.transforms.ToTensor()

# Training split (download=True asks torchvision to fetch the data into the root folder).
train_data = torchvision.datasets.CIFAR10(
    root="dataset/cifar10",
    train=True,
    transform=transform,
    download=True,
)

# Test split, stored under the same root folder.
test_data = torchvision.datasets.CIFAR10(
    root="dataset/cifar10",
    train=False,
    transform=transform,
    download=True,
)

# Load the datasets with DataLoader

from torch.utils.data import DataLoader

# Batch both splits in groups of 64; only the training data is shuffled.
train_dataloader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)

# Sample counts of each split; test_data_size is the denominator of the
# accuracy figure printed during evaluation.
train_data_size = len(train_data)
test_data_size = len(test_data)

print("训练数据集大小:", train_data_size)
print("测试数据集大小:", test_data_size)

# Probe once whether PyTorch's MPS backend can be used. The rest of the script
# checks this flag before moving the model, the loss and each batch to "mps".
import torch.backends.mps
mps_available = torch.backends.mps.is_available()

# 构建神经网络模型
import torch

class Net(torch.nn.Module):
    """Small convolutional classifier that outputs 10 class scores per image.

    The stack is three 5x5 convolutions (stride 1, padding 2, so the spatial
    size is preserved), each followed by 2x2 max pooling, then a flatten and
    two linear layers. The first linear layer expects 64*4*4 features, i.e. a
    4x4 feature map after three halvings, which corresponds to 3-channel
    32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Keep the attribute name `model`: it is part of the state_dict keys
        # of any checkpoint saved from this class.
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 5, 1, 2),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(32, 32, 5, 1, 2),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(32, 64, 5, 1, 2),
            torch.nn.MaxPool2d(2),
            torch.nn.Flatten(),
            torch.nn.Linear(64*4*4, 64),
            torch.nn.Linear(64, 10)
        )

    def forward(self, x):
        """Pass the batch ``x`` through the layer stack and return the raw scores."""
        return self.model(x)

# Build the network and the loss function first, then place both on the
# MPS device in one step when it is available.
model = Net()
loss_nn = torch.nn.CrossEntropyLoss()
if mps_available:
    model = model.to("mps")
    loss_nn = loss_nn.to("mps")

# Plain SGD over all model parameters.
learning_rate = 0.01
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)


# Training bookkeeping.
total_train_step = 0  # number of training batches processed so far
total_test_step = 0  # number of test batches processed so far
epochs = 10  # number of passes over the training set

# Train the model, evaluating on the test set and saving a checkpoint after
# every epoch. Scalars are logged to TensorBoard under "cifar10_logs".
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("cifar10_logs")
try:
    for epoch in range(epochs):
        # ---- training phase ----
        model.train()
        for data in train_dataloader:
            inputs, labels = data
            if mps_available:
                inputs = inputs.to("mps")
                labels = labels.to("mps")
            outputs = model(inputs)
            loss = loss_nn(outputs, labels)

            # Standard optimisation step: clear old gradients, backpropagate, update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_step += 1
            # Report every 100 batches to keep console and log volume low.
            if total_train_step % 100 == 0:
                print("训练次数:", total_train_step, "Loss:", loss.item())
                writer.add_scalar("train_loss", loss.item(), total_train_step)

        # ---- evaluation phase ----
        model.eval()
        with torch.no_grad():  # no gradients are needed while evaluating
            total_loss = 0  # sum of the per-batch losses over the test set
            total_accuracy = 0  # number of correctly classified test samples

            for data in test_dataloader:
                inputs, labels = data
                if mps_available:
                    inputs = inputs.to("mps")
                    labels = labels.to("mps")
                outputs = model(inputs)
                loss = loss_nn(outputs, labels)
                total_loss += loss.item()
                total_test_step += 1
                # .item() turns the count into a Python int, so the accuracy
                # below prints as a plain number instead of a tensor repr.
                total_accuracy += (outputs.argmax(1) == labels).sum().item()
            print("测试次数:", total_test_step, "正确率:", total_accuracy/test_data_size*100, "%")
            print("测试次数:", total_test_step, "Loss:", total_loss)
            writer.add_scalar("test_loss", total_loss, total_test_step)

        # Save the whole model object after each epoch.
        torch.save(model, f"cifar10_model_{epoch}.pth")
finally:
    # Flush and release the event file even if training is interrupted.
    writer.close()

文章作者: Wanheng
版权声明: 本博客所有文章除特别声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Wanheng !
评论
 上一篇
个人记事本 个人记事本
记录一些容易忘记的又常用的信息
2024-03-20 Wanheng
下一篇 
redis相关 redis相关
Redis 是一种基于内存的数据库,对数据的读写操作都是在内存中完成,因此读写速度非常快,常用于缓存、消息队列、分布式锁等场景。
2024-02-28
  目录