RNN 与 LSTM 的 Python 实战对比揭秘

引言

假如现在你在读一段文本,当你看到这样一句话:

“小明打开了门,他走了______。”

如果让我们来填空,我们应该都会填“进去”,这是因为我们的大脑记住了前面“打开了门”这个信息,从而根据这个信息做出了判断。而对于机器来说,如果能让它们也拥有类似人类的记忆,就能做出更正确的预测。于是循环神经网络应运而生,而在之后,由于它的局限性,又发展出来了更为智能的长短期记忆网络。


RNN:基础版记忆网络

循环神经网络(Recurrent Neural Network, 简称 RNN)是一类专门用于处理序列数据的人工神经网络模型,其核心特点在于具备“记忆”能力:通过隐藏状态在时间维度上传递信息,使模型能够捕捉上下文中的时序依赖关系。与传统的前馈神经网络不同,RNN 采用循环连接结构,即当前时刻的输出不仅依赖于当前输入,还依赖于前一时刻的隐藏状态,从而能够对动态序列中的前后关联进行建模。

主要缺陷:只能记住最近几步的信息;处理长序列时易遗忘开头内容。

RNN模型构建

import torch
import torch.nn as nn
# 简单RNN实现
class SimpleRNN(nn.Module):
     def _init_(self, input_size, hidden _size):
         super().__init__()
         self.rnn 
= nn.RNN(input_size, hidden_size,
batch first=True)
     def forward(self, x, hidden):
         return self.rnn(x,hidden)
         
# RNN的问题:处理长序列时容易"
遗忘rnn =nn.RNN(10,16)
long_input=torch.randn(1,50,10)#长度50的序列
hidden = torch.zeros(1,1,16)
output, hidden = rnn(long input, hidden)
print(f"RNN输出形状:{output.shape}")
#问题:序列后面的步骤几乎忘记了开头的信息

LSTM:解决遗忘问题的智能RNN

长短期网络记忆(Long Short-Term Memory,简称LSTM)是一种时间循环神经网络,是为了解决一般的RNN存在的长期依赖问题而专门设计出来的,讨论所有的RNN都具有一种重复神经网络模块的链式形式。在标准RNN中,这个重复的结构模块只有一个非常简单的结构,例如一个tanh层。其适合于处理和预测时间序列中间隔和延迟非常长的重要事件。

LSTM模型构建

#LSTM = RNN + 三个智能门
class SimpleLSTM(nn.Module):
     def __init__ (self, input size, hidden_size):
         super()._init ()
         self.lstm =nn.LSTM(input size, hidden size,
batch first=True)
    def forward(self,x, hidden):
        return self.lstm(x, hidden)
        
#LSTM的三个门:
# 1.遗忘门:决定忘记什么
#2.输入门:决定记住什么
#3.输出门:决定输出什么
lstm = nn.LSTM(10,16)
output,(hidden, cell)= lstm(long_input)
print(f"LSTM输出形状:{output.shape}")
print(f"LSTM有细胞状态:{cell.shape}")#这是长期记忆的关键

实战对比

训练数据准备

text = "hello world hello python hello deep learning'
chars = list(set(text))
char to idx =fch:i for i,ch in enumerate(chars)}

训练结果对比

训练100轮后:
RNN损失: 1.8523(收敛慢)
(收敛快)LSTM损失:0.9234
生成文本:
RNN生成:hello world hello pythooooo # 容易重复
LSTM生成:hello world hello python hello # 更连贯

RNN与LSTM的区别

基本结构差异

RNN的核心是一个简单的循环单元,每个时间步的隐藏状态仅由当前输入和前一时刻的隐藏状态共同决定。这种结构在短序列任务中表现良好,但缺乏对长期信息的精细控制。而LSTM在RNN的基础上增加了三个关键门控结构(‌遗忘门、‌输入门、‌输出门),通过门控信号动态调节信息的保留与遗忘。

长期依赖问题的处理能力

RNN在处理长序列时容易因梯度连乘导致梯度消失或爆炸,使得模型难以学习远距离的时序依赖。LSTM则通过以下机制显著缓解这一问题:

‌‌细胞状态(Cell State)‌:作为一条“直通路径”,允许信息在不同时间步间直接传递,减少梯度衰减

‌‌门控的线性调控‌:遗忘门和输入门通过sigmoid函数输出0-1的权重,与细胞状态进行逐元素相乘,避免梯度在反向传播中剧烈变化。‌‌

性能与计算成本

计算复杂度‌:LSTM因门控结构参数量更多,训练和推理的计算成本通常高于RNN‌‌

任务适应性‌:对于简单序列任务(如短文本分类),RNN可能因结构简单而更高效;但对于需要长期记忆的任务(如‌机器翻译、‌语音识别),LSTM的性能优势明显。


总结

本次推文介绍了RNN与LSTM,并通过Python代码对两者进行了实战对比。‌RNN和‌LSTM的核心区别在于‌LSTM通过引入‌门控机制,有效解决了RNN在处理长序列数据时的‌梯度消失问题,从而能够更好地捕捉长期依赖关系。这两者都是处理序列数据的利器,而LSTM因其强大的记忆功能,成为了更主流的选择。


参考资料:百度词条——“循环神经网络”、“长短期记忆人工神经网络”

本文转自:SUIBE数据科学系,文案:贾婧怡,转载此文目的在于传递更多信息,版权归原作者所有。如不支持转载,请联系小编demi@eetrend.com删除。

最新文章