深度学习模型训练的一般过程

向量数据库大模型机器学习

在上一篇 《一图讲透AI核心概念与名词术语》 中提到,多层感知机(MLP)是最简单的深度神经网络基本结构,一般也是入门深度学习的开始。有一个经典的例子经常用于各种教学:在 MNIST 手写数字数据集上训练一个MLP模型,并测试其效果。这个例子覆盖了深度学习模型训练的一般过程:

数据准备 → 模型定义 → 训练配置 → 训练循环 → 评估测试

我们下面逐步看下。

首先要准备一个可运行训练代码的环境。因为数据集非常小,CPU,GPU都可以。

环境准备

最快捷的方法,用 pytorch 官方现成的 docker 镜像即可。启动容器:

  
docker run --name train -itd --gpus '"device=0"' \  
  -e LANG=C.UTF-8 -e LC_ALL=C.UTF-8 \  
  -v /data/ai/datasets:/datasets \  
  -v /data/ai/workspace/train:/workspace \  
  pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel bash

其中

  • --gpus '"device=0"' 指定容器要使用的GPU卡,如果没有GPU,可以不写则使用CPU。如果容器运行时默认是nvidia的,则会挂载所有GPU卡
  • -e LANG=C.UTF-8 -e LC_ALL=C.UTF-8 设定容器中文环境
  • -v 行指定挂载进容器的 host 目录

进入容器:

  
docker exec -it train bash  
root@cb82d3b9af0a:/workspace# pip list | grep torch  
torch                     2.5.1+cu121  
torchaudio                2.5.1+cu121  
torchelastic              0.2.2  
torchvision               0.20.1+cu121

安装依赖包:

  
pip install jupyter matplotlib 

启动 Jupter

  
nohup jupyter notebook --allow-root --ip=0.0.0.0 > jupyter.log 2>&1 &  
  
tail -f jupyter.log  
  
[I 2025-06-30 08:48:34.259 ServerApp] Jupyter Server 2.16.0 is running at:  
[I 2025-06-30 08:48:34.259 ServerApp] http://jupyter-d4c6f6f9f-dphhp:8888/tree?token=xxx  
[I 2025-06-30 08:48:34.259 ServerApp]     http://127.0.0.1:8888/tree?token=xxx

在浏览器中访问 http://127.0.0.1:8888/tree?token=xxx 进入 Jupyter 界面

用 Jupyter 是为了有些图片图表,可以不依赖服务器图形环境,直接浏览器中就能看到。当然也方便反复调测。
训练代码

以下代码片段,可在 Jupyter 中逐步执行。

导入必要的库:

  
import torch  # PyTorch主库  
import torch.nn as nn  # 神经网络模块  
import torch.optim as optim  # 优化算法模块  
from torchvision import datasets, transforms  # 计算机视觉数据集和预处理  
from torch.utils.data import DataLoader  # 数据加载器  
import matplotlib.pyplot as plt  # 绘图库

设置计算设备:

  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
print(f"使用设备: {device}")
  • 检查是否有可用的GPU,如果有则使用GPU加速计算,否则使用CPU
  • "cuda"表示NVIDIA GPU计算平台
  • 输出当前使用的计算设备信息

数据准备

数据预处理:

  
transform = transforms.Compose([  
    transforms.ToTensor(),  # 将PIL图像或numpy数组转换为张量  
    transforms.Normalize((0.1307,), (0.3081,))  # 标准化处理  
])
  • "ToTensor()": 将图像转换为PyTorch张量,同时将像素值从[0,255]缩放到[0,1]
  • "Normalize()": 使用MNIST数据集的均值和标准差进行标准化

加载MNIST数据集:

  
train_dataset = datasets.MNIST(root='./data',   
                              train=True,  
                              download=True,  
                              transform=transform)  
  
test_dataset = datasets.MNIST(root='./data',  
                             train=False,  
                             transform=transform)
  • 创建训练集和测试集
  • "root='./data'": 数据存储路径
  • "download=True": 如果本地不存在则自动下载
  • "train=True/False": 分别表示加载训练集或测试集
  • 应用之前定义的数据预处理流程

创建数据加载器:

  
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
  • "batch_size=64": 每个训练批次包含64个样本
  • "shuffle=True": 训练数据每轮随机打乱顺序
  • "shuffle=False": 测试数据保持原始顺序

可视化样本数据:

  
def show_sample(images, labels):  
    plt.figure(figsize=(10, 5))  
    for i in range(10):  
        plt.subplot(2, 5, i+1)  
        plt.imshow(images[i][0], cmap='gray')  # 显示灰度图  
        plt.title(f"Label: {labels[i].item()}")  # 显示标签  
        plt.axis('off')  
    plt.tight_layout()  
    plt.show()  
  
# 获取并显示第一批训练数据  
sample_images, sample_labels = next(iter(train_loader))  
show_sample(sample_images, sample_labels)
  • 创建一个2行5列的图显示10个样本
  • "images[i][0]": 因为图像是单通道,取第一个通道显示
  • "cmap='gray'": 使用灰度颜色映射
  • "axis('off')": 不显示坐标轴

如果在 Jupyter 或有图像界面的环境执行,可以看到获取的数据集数据显现类似如下手写数字的图像:

picture.image

模型定义

定义多层感知机 MLP 模型

  
class MLP(nn.Module):  
    def __init__(self):  
        super(MLP, self).__init__()  
        self.fc1 = nn.Linear(28 * 28, 512)  # 输入层到隐藏层  
        self.fc2 = nn.Linear(512, 256)    # 隐藏层到隐藏层  
        self.fc3 = nn.Linear(256, 10)     # 隐藏层到输出层  
  
    def forward(self, x):  
        x = x.view(-1, 28 * 28)  # 展平图像 (批大小, 784)  
        x = torch.relu(self.fc1(x))  # ReLU激活函数  
        x = torch.relu(self.fc2(x))  
        x = self.fc3(x)  # 输出层不使用激活函数  
        return x  
  
model = MLP().to(device)  # 将模型移动到计算设备

这段代码定义一个三层的 MLP 网络:

  • 输入层: 28x28=784像素 → 512神经元
  • 隐藏层: 512 → 256神经元
  • 输出层: 256 → 10神经元 (对应0-9十个数字)
  • "view(-1, 28*28)": 将二维图像展平为一维向量
  • 使用ReLU激活函数引入非线性
  • "to(device)": 将模型参数转移到指定设备

训练配置

设置损失函数和优化器

  
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数  
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器
  • "CrossEntropyLoss()": 适用于多分类问题
  • "Adam": 自适应学习率的优化算法
  • "lr=0.001": 学习率

定义训练函数

  
def train(epoch):  
    model.train()  # 设置为训练模式  
    for batch_idx, (data, target) in enumerate(train_loader):  
        data, target = data.to(device), target.to(device)  # 数据移到设备  
  
        optimizer.zero_grad()  # 清除梯度缓存  
        output = model(data)   # 前向传播  
        loss = criterion(output, target)  # 计算损失  
        loss.backward()         # 反向传播计算梯度  
        optimizer.step()        # 更新参数  
  
        # 每100个批次打印进度  
        if batch_idx % 100 == 0:  
            print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}'  
                  f' ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
  • "model.train()": 启用dropout和batch normalization
  • "zero_grad()": 防止梯度累积
  • "loss.backward()": 自动微分计算梯度
  • "optimizer.step()": 参数更新
  • 每处理100个批次打印一次进度和损失

定义测试函数

  
def test():  
    model.eval()  # 设置为评估模式  
    test_loss = 0  
    correct = 0  
  
    with torch.no_grad():  # 禁用梯度计算  
        for data, target in test_loader:  
            data, target = data.to(device), target.to(device)  
            output = model(data)  
            test_loss += criterion(output, target).item()  # 累加损失  
            pred = output.argmax(dim=1, keepdim=True)  # 获取预测结果  
            # 统计正确预测数  
            correct += pred.eq(target.view_as(pred)).sum().item()  
  
    test_loss /= len(test_loader.dataset)  # 计算平均损失  
    accuracy = 100. * correct / len(test_loader.dataset)  
  
    # 打印测试结果  
    print(f'\n测试集: 平均损失: {test_loss:.4f}, 准确率: {correct}/{len(test_loader.dataset)}'  
          f' ({accuracy:.2f}%)\n')  
    return accuracy
  • "model.eval()": 禁用dropout和batch normalization
  • "torch.no_grad()": 节省内存和计算资源
  • "argmax(dim=1)": 获取概率最大的类别作为预测结果
  • "view_as(pred)": 调整目标张量形状以匹配预测结果
  • "eq()": 比较预测与真实标签
  • 计算整体准确率

训练循环

执行训练和测试循环

  
accuracies = []  
for epoch in range(1, 6):  # 训练5个epoch  
    train(epoch)  
    acc = test()  
    accuracies.append(acc)
  • 进行5轮(epoch)训练
  • 每轮结束后测试并记录准确率

可以看到类似如下的输出:

  
使用设备: cuda  
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz  
Failed to download (trying next):  
HTTP Error 404: Not Found  
  
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz  
...  
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw  
  
Epoch: 1 [0/60000 (0%)] Loss: 2.313594  
Epoch: 1 [6400/60000 (11%)] Loss: 0.219006  
Epoch: 1 [12800/60000 (21%)] Loss: 0.106987  
Epoch: 1 [19200/60000 (32%)] Loss: 0.149337  
Epoch: 1 [25600/60000 (43%)] Loss: 0.237175  
Epoch: 1 [32000/60000 (53%)] Loss: 0.137358  
Epoch: 1 [38400/60000 (64%)] Loss: 0.273473  
Epoch: 1 [44800/60000 (75%)] Loss: 0.046474  
Epoch: 1 [51200/60000 (85%)] Loss: 0.028811  
Epoch: 1 [57600/60000 (96%)] Loss: 0.039038  
  
。。。  
  
Epoch: 5 [0/60000 (0%)] Loss: 0.004618  
Epoch: 5 [6400/60000 (11%)] Loss: 0.124851  
Epoch: 5 [12800/60000 (21%)] Loss: 0.039882  
Epoch: 5 [19200/60000 (32%)] Loss: 0.011858  
Epoch: 5 [25600/60000 (43%)] Loss: 0.006973  
Epoch: 5 [32000/60000 (53%)] Loss: 0.008163  
Epoch: 5 [38400/60000 (64%)] Loss: 0.025966  
Epoch: 5 [44800/60000 (75%)] Loss: 0.094705  
Epoch: 5 [51200/60000 (85%)] Loss: 0.030636  
Epoch: 5 [57600/60000 (96%)] Loss: 0.008534  
  
测试集: 平均损失: 0.0001, 准确率: 9799/10000 (97.99%)

可以看到,简单的一个模型,实现了近98%的识别准确率。因网络简单,数据集迷你,训练消耗也比较小:

picture.image

结果可视化

  
plt.plot(range(1, 6), accuracies, 'o-')  
plt.title("Model Accuracy per Epoch") # 模型准确率随训练轮次变化  
plt.xlabel('Epoch')                   # 训练轮次  
plt.ylabel('Accuracy (%)')            # 准确率  
plt.grid(True)  
plt.show()
  • 绘制准确率随训练轮次变化的折线图
  • "o-": 带圆点的实线
  • 添加标题、坐标轴标签和网格线

在 Jupyter 中执行,可以看到如下图像:

picture.image

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动 XR 技术的探索与实践
火山引擎开发者社区技术大讲堂第二期邀请到了火山引擎 XR 技术负责人和火山引擎创作 CV 技术负责人,为大家分享字节跳动积累的前沿视觉技术及内外部的应用实践,揭秘现代炫酷的视觉效果背后的技术实现。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论