百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程网 > 正文

PyTorch 深度学习实战(3):神经网络基础与手写数字识别

yuyutoo 2025-03-06 21:00 1 浏览 0 评论



在上一篇文章中,我们学习了 PyTorch 的自动求导机制(Autograd),并实现了一个简单的线性回归模型。本文将深入探讨神经网络的基本概念,并使用 PyTorch 构建一个简单的神经网络来解决经典的 手写数字识别 问题。


一、神经网络基础

神经网络是深度学习的核心,它由多个层(Layer)组成,每一层包含若干个神经元(Neuron)。神经元通过权重(Weight)和偏置(Bias)对输入数据进行线性变换,并通过激活函数(Activation Function)引入非线性。

1. 神经网络的结构

一个典型的神经网络包括以下部分:

  • 输入层:接收输入数据。
  • 隐藏层:对数据进行非线性变换。
  • 输出层:输出最终的预测结果。

2. 激活函数

激活函数是神经网络中引入非线性的关键。常用的激活函数包括:

  • ReLU(Rectified Linear Unit):ReLU(x) = max(0, x)
  • Sigmoid:Sigmoid(x) = 1 / (1 + exp(-x))
  • Tanh:Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))

3. 损失函数

损失函数用于衡量模型预测值与真实值之间的差距。常用的损失函数包括:

  • 均方误差(MSE):用于回归问题。
  • 交叉熵损失(Cross-Entropy Loss):用于分类问题。

4. 优化器

优化器用于更新模型参数以最小化损失函数。常用的优化器包括:

  • 随机梯度下降(SGD)
  • Adam

二、手写数字识别实战

手写数字识别是深度学习中的经典问题,我们将使用 MNIST 数据集 来训练一个简单的神经网络模型。

1. 问题描述

MNIST 数据集包含 60,000 张训练图像和 10,000 张测试图像,每张图像是一个 28x28 的灰度图,表示 0 到 9 的手写数字。我们的目标是构建一个神经网络模型,能够正确识别这些手写数字。

2. 实现步骤

  1. 加载和预处理数据。
  2. 定义神经网络模型。
  3. 定义损失函数和优化器。
  4. 训练模型。
  5. 测试模型并评估性能。

3. 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置字体为 SimHei(黑体)
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 1. 加载和预处理数据
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.1307,), (0.3081,))  # 标准化
])

# 下载并加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# 2. 定义神经网络模型
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # 全连接层,输入 28x28,输出 128
        self.fc2 = nn.Linear(128, 64)  # 全连接层,输入 128,输出 64
        self.fc3 = nn.Linear(64, 10)  # 全连接层,输入 64,输出 10(10 个类别)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # 将图像展平为一维向量
        x = torch.relu(self.fc1(x))  # 第一层 + ReLU 激活
        x = torch.relu(self.fc2(x))  # 第二层 + ReLU 激活
        x = self.fc3(x)  # 输出层
        return x

model = SimpleNN()

# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器

# 4. 训练模型
num_epochs = 5
loss_history = []

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 记录损失
        if (i + 1) % 100 == 0:
            loss_history.append(loss.item())
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}")

# 5. 测试模型
model.eval()  # 设置模型为评估模式
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"测试集准确率: {100 * correct / total:.2f}%")

# 6. 可视化损失曲线
plt.plot(loss_history)
plt.xlabel("训练步数")
plt.ylabel("损失值")
plt.title("训练损失曲线")
plt.show()

三、代码解析

1.数据加载与预处理:

  • 使用 torchvision.datasets.MNIST 加载 MNIST 数据集。
  • 使用 transforms.ToTensor() 将图像转换为张量,并进行标准化。

2.神经网络模型:

  • 定义了一个简单的全连接神经网络 SimpleNN,包含两个隐藏层和一个输出层。
  • 使用 ReLU 作为激活函数。

3.训练过程:

  • 使用交叉熵损失函数和 Adam 优化器。
  • 训练 5 个 epoch,并记录损失值。

4.测试过程:

  • 在测试集上评估模型性能,计算准确率。

5.可视化:

  • 绘制训练损失曲线。

四、运行结果

运行上述代码后,你将看到以下输出:

  1. 训练过程中每 100 步打印一次损失值。
  2. 测试集准确率(通常在 95% 以上)。
  3. 训练损失曲线图。

五、总结

本文介绍了神经网络的基本概念,并使用 PyTorch 实现了一个简单的手写数字识别模型。通过这个例子,我们学习了如何定义神经网络、加载数据、训练模型以及评估性能。

在下一篇文章中,我们将学习如何使用卷积神经网络(CNN)来进一步提升手写数字识别的性能。敬请期待!


代码实例说明:

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
  • 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:model = model.to('cuda'),images = images.to('cuda')。

希望这篇文章能帮助你更好地理解神经网络的基础知识!如果有任何问题,欢迎在评论区留言讨论。

相关推荐

网站建设:从新手到高手

现代化网站应用领域非常广泛,从个人形象网站展示、企业商业网站运作、到政府公益等服务网站,各行各业都需要网站建设。大体上可以归结四类:宣传型网站设计、产品型网站制作、电子商务型网站建设、定制型功能网站开...

JetBrains 推出全新 AI 编程工具 Junie,助力高效开发

JetBrains宣布推出名为Junie的全新AI编程工具。这款工具不仅能执行简单的代码生成与检查任务,还能应对编写测试、验证结果等复杂项目,为开发者提供全方位支持。根据SWEBench...

AI也能写代码!代码生成、代码补全、注释生成、代码翻译轻松搞定

清华GLM技术团队打造的多语言代码生成模型CodeGeeX近期更新了新的开源版本「CodeGeeX2-6B」。CodeGeeX2是多语言代码生成模型CodeGeeX的第二代模型,不同于一代CodeG...

一键生成前后端代码,一个36k星的企业级低代码平台

「企业级低代码平台」前后端分离架构SpringBoot2.x,SpringCloud,AntDesign&Vue,Mybatis,Shiro,JWT。强大的代码生成器让前后端代码一键生成,无需写任...

Gitee 代码托管实战指南:5 步完成本地项目云端同步(附避坑要点)

核心流程拆解:远程仓库的搭建登录Gitee官网(注册账号比较简单,大家自行操作),点击“新建仓库”,建议勾选“初始化仓库”和“设置模板文件”(如.gitignore),避免上传临时文件。...

jeecg-boot 源码项目-强烈推荐使用

JEECGBOOT低代码开发平台...

JetBrains推出全新AI编程工具Junie,强调以开发者为中心

IT之家2月1日消息,JetBrains发文,宣布推出一款名为Junie的全新AI编程工具,官方声称这款AI工具既能执行简单的代码生成与检查等基础任务,也能应对“编写测试、验证结...

JetBrains旗下WebStorm和Rider现已加入“非商用免费”阵营

IT之家10月25日消息,软件开发商JetBrains今日宣布,旗下WebStorm(JavaScript开发工具)和Rider(.NET开发工具)现已加入“非商用免费”阵营。如果...

谈谈websocket跨域

了解websocketwebsocket是HTML5的新特性,在客户端和服务端提供了一个基于TCP连接的双向通道。...

websocket调试工具

...

利用webSocket实现消息的实时推送

1.什么是webSocketwebSocket实现实现推送消息WebSocket是HTML5开始提供的一种在单个TCP连接上进行全双工通讯的协议。以前的推送技术使用Ajax轮询,浏览器需...

Flutter UI自动化测试技术方案选型与探索

...

为 Go 开发的 WebSocket 库

#记录我的2024#...

「Java基础」Springboot+Websocket的实现后端数据实时推送

这篇文章主要就是实现这个功能,只演示一个基本的案例。使用的是websocket技术。...

【Spring Boot】WebSocket 的 6 种集成方式

介绍...

取消回复欢迎 发表评论: