PyTorch深度学习框架基础——实现循环神经网络,并解析
yuyutoo 2025-03-06 21:01 1 浏览 0 评论
使用PyTorch实现一个简单的循环神经网络(Recurrent Neural Network, RNN),并对代码进行详细解析。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
# 示例数据
X = np.array([
[1, 2, 3, 4, 5],
[5, 4, 3, 2, 1],
[2, 3, 4, 5, 6],
[6, 5, 4, 3, 2]
], dtype=np.float32)
y = np.array([0, 1, 0, 1], dtype=np.float32)
X_tensor = torch.tensor(X)
y_tensor = torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
class SimpleRNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(SimpleRNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.embedding = nn.Embedding(input_size, hidden_size)
self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, num_classes)
def forward(self, x):
embedded = self.embedding(x)
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
out, (hn, cn) = self.lstm(embedded, (h0, c0))
out = out[:, -1, :]
out = self.fc(out)
return out
# 超参数
input_size = 10 # 词汇表大小
hidden_size = 128
num_layers = 2
num_classes = 2 # 二分类
model = SimpleRNN(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
num_epochs = 20
for epoch in range(num_epochs):
for inputs, labels in dataloader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
代码解析
数据准备部分
- 数据创建:创建了一个简单的数据集,其中每个输入序列由5个整数组成,对应的标签为0或1。
- 张量转换:将NumPy数组转换为PyTorch张量,以便进行后续的计算。
- 数据集和数据加载器:使用TensorDataset和DataLoader将数据组织成批次,便于训练。
模型定义部分
- 嵌入层(Embedding Layer):
- nn.Embedding(input_size, hidden_size):将输入的整数(词汇表中的索引)转换为向量表示。这里,input_size是词汇表的大小,hidden_size是嵌入向量的维度。
- RNN层(LSTM):
- nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True):使用LSTM作为RNN的实现。hidden_size是LSTM隐藏层的维度,num_layers是LSTM层的数量,batch_first=True表示输入和输出的张量第一个维度是批次大小。
- 全连接层(Fully Connected Layer):
- nn.Linear(hidden_size, num_classes):将LSTM的输出映射到最终的分类结果。
- 前向传播:
- 嵌入:将输入序列中的每个整数转换为向量。
- 初始化隐藏状态和细胞状态:为LSTM初始化隐藏状态和细胞状态,均为全零张量。
- LSTM前向传播:将嵌入后的序列输入到LSTM中,得到输出和最终的隐藏状态。
- 取最后一个时间步的输出:由于我们关注的是整个序列的表示,因此取LSTM输出的最后一个时间步的输出。
- 全连接层:将LSTM的输出通过全连接层,得到最终的分类结果。
训练部分
- 损失函数:使用交叉熵损失(nn.CrossEntropyLoss)进行分类任务的损失计算。
- 优化器:使用Adam优化器(optim.Adam)进行参数更新。
- 训练循环:遍历每个批次的数据。前向传播计算输出和损失。反向传播计算梯度。更新模型参数。每经过一个epoch,打印当前的损失。
注意事项
- 数据预处理:在实际应用中,输入数据需要进行适当的预处理,例如词汇表构建、序列填充等。
- 超参数调整:隐藏层大小、层数、学习率等超参数对模型性能有重要影响,需要根据具体任务进行调整。
- 模型评估:训练完成后,需要在验证集或测试集上评估模型性能,以防止过拟合。
相关推荐
- 网站建设:从新手到高手
-
现代化网站应用领域非常广泛,从个人形象网站展示、企业商业网站运作、到政府公益等服务网站,各行各业都需要网站建设。大体上可以归结四类:宣传型网站设计、产品型网站制作、电子商务型网站建设、定制型功能网站开...
- JetBrains 推出全新 AI 编程工具 Junie,助力高效开发
-
JetBrains宣布推出名为Junie的全新AI编程工具。这款工具不仅能执行简单的代码生成与检查任务,还能应对编写测试、验证结果等复杂项目,为开发者提供全方位支持。根据SWEBench...
- AI也能写代码!代码生成、代码补全、注释生成、代码翻译轻松搞定
-
清华GLM技术团队打造的多语言代码生成模型CodeGeeX近期更新了新的开源版本「CodeGeeX2-6B」。CodeGeeX2是多语言代码生成模型CodeGeeX的第二代模型,不同于一代CodeG...
- 一键生成前后端代码,一个36k星的企业级低代码平台
-
「企业级低代码平台」前后端分离架构SpringBoot2.x,SpringCloud,AntDesign&Vue,Mybatis,Shiro,JWT。强大的代码生成器让前后端代码一键生成,无需写任...
- Gitee 代码托管实战指南:5 步完成本地项目云端同步(附避坑要点)
-
核心流程拆解:远程仓库的搭建登录Gitee官网(注册账号比较简单,大家自行操作),点击“新建仓库”,建议勾选“初始化仓库”和“设置模板文件”(如.gitignore),避免上传临时文件。...
- jeecg-boot 源码项目-强烈推荐使用
-
JEECGBOOT低代码开发平台...
- JetBrains推出全新AI编程工具Junie,强调以开发者为中心
-
IT之家2月1日消息,JetBrains发文,宣布推出一款名为Junie的全新AI编程工具,官方声称这款AI工具既能执行简单的代码生成与检查等基础任务,也能应对“编写测试、验证结...
- JetBrains旗下WebStorm和Rider现已加入“非商用免费”阵营
-
IT之家10月25日消息,软件开发商JetBrains今日宣布,旗下WebStorm(JavaScript开发工具)和Rider(.NET开发工具)现已加入“非商用免费”阵营。如果...
- 谈谈websocket跨域
-
了解websocketwebsocket是HTML5的新特性,在客户端和服务端提供了一个基于TCP连接的双向通道。...
- websocket调试工具
-
...
- 利用webSocket实现消息的实时推送
-
1.什么是webSocketwebSocket实现实现推送消息WebSocket是HTML5开始提供的一种在单个TCP连接上进行全双工通讯的协议。以前的推送技术使用Ajax轮询,浏览器需...
- 为 Go 开发的 WebSocket 库
-
#记录我的2024#...
- 「Java基础」Springboot+Websocket的实现后端数据实时推送
-
这篇文章主要就是实现这个功能,只演示一个基本的案例。使用的是websocket技术。...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- mybatis plus (70)
- scheduledtask (71)
- css滚动条 (60)
- java学生成绩管理系统 (59)
- 结构体数组 (69)
- databasemetadata (64)
- javastatic (68)
- jsp实用教程 (53)
- fontawesome (57)
- widget开发 (57)
- vb net教程 (62)
- hibernate 教程 (63)
- case语句 (57)
- svn连接 (74)
- directoryindex (69)
- session timeout (58)
- textbox换行 (67)
- extension_dir (64)
- linearlayout (58)
- vba高级教程 (75)
- iframe用法 (58)
- sqlparameter (59)
- trim函数 (59)
- flex布局 (63)
- contextloaderlistener (56)