RNN 循环神经网络
处理序列数据的循环架构 — 从 Vanilla RNN 到 LSTM
- RNN 通过隐藏状态传递实现序列信息的记忆
- LSTM 用门控机制(遗忘门/输入门/输出门)解决长程依赖问题
- 梯度消失/爆炸是 RNN 训练的核心挑战
- xLSTM 通过指数门控和矩阵记忆进一步改进性能
为什么需要 RNN?
现实世界中有很多序列数据:一句话中的词语有先后顺序,股票价格随时间变化,语音信号是连续的波形。 传统的前馈神经网络(MLP)把每个输入当作独立的,无法利用"顺序"信息。
类比:想象你在读一句话 — "我今天去了___"。要预测下一个词,你需要记住前面的每个词。 人类阅读时,大脑会维护一个"上下文记忆"。RNN 做的正是同样的事情。
核心思想:RNN 有一个隐藏状态 h(hidden state),就像一个"记忆"。 每处理一个时间步的输入,它会更新这个记忆,并把它传递给下一个时间步。
当前记忆 = tanh(输入权重 x 当前输入 + 记忆权重 x 上一步记忆 + 偏置)
📊 Vanilla RNN 手推:3 个时间步
我们用简化参数(维度 d=2)来手推 RNN 的 3 个时间步。每个时间步都执行相同的计算。 点击单元格可以编辑输入值!
时间步 t=1
| 输入 xt | 上一步 ht-1 | Wx·xt | Wh·ht-1 | Wx·x + Wh·h + b | tanh 激活 | |
|---|---|---|---|---|---|---|
| 维度 1 | 1.0 | 0.0 | 0.50 | 0.00 | 0.70 | 0.604 |
| 维度 2 | 0.5 | 0.0 | 0.15 | 0.00 | 0.35 | 0.336 |
| 公式 | Wx·x | Wh·h | sum + b | tanh(sum) |
时间步 t=2
| 输入 xt | 上一步 ht-1 | Wx·xt | Wh·ht-1 | Wx·x + Wh·h + b | tanh 激活 | |
|---|---|---|---|---|---|---|
| 维度 1 | 0.8 | 0.604 | 0.40 | 0.18 | 0.78 | 0.652 |
| 维度 2 | 0.3 | 0.336 | 0.09 | 0.10 | 0.39 | 0.371 |
| 公式 | Wx·x | Wh·ht-1 | sum + b | tanh(sum) |
时间步 t=3
| 输入 xt | 上一步 ht-1 | Wx·xt | Wh·ht-1 | Wx·x + Wh·h + b | tanh 激活 | |
|---|---|---|---|---|---|---|
| 维度 1 | 0.6 | 0.652 | 0.30 | 0.20 | 0.70 | 0.604 |
| 维度 2 | 0.9 | 0.371 | 0.27 | 0.11 | 0.58 | 0.523 |
| 公式 | Wx·x | Wh·ht-1 | sum + b | tanh(sum) |
💡 观察:每个时间步使用相同的权重 Wx、Wh 和偏置 b。 这就是"参数共享" — RNN 在每个时间步做相同的计算,只是输入不同。 简化参数:Wx = [0.5, 0.3; 0.3, 0.3], Wh = [0.3, 0.3; 0.3, 0.3], b = [0.2, 0.2]
📊 LSTM 手推:三个门控机制
Vanilla RNN 有梯度消失问题 — 时间步越远,梯度越小,无法学到长期依赖。 LSTM 通过三个"门"来解决这个问题,让信息可以选择性地保留或遗忘。
遗忘门 (Forget Gate)
f = σ(W_f·[h,x] + b_f)
决定丢弃哪些旧记忆
输入门 (Input Gate)
i = σ(W_i·[h,x] + b_i)
决定写入哪些新信息
输出门 (Output Gate)
o = σ(W_o·[h,x] + b_o)
决定输出哪些记忆
下面展示一个时间步的完整 LSTM 计算。点击输入值可编辑:
| 计算步骤 | 公式 | 结果 |
|---|---|---|
| 输入 xt | 用户输入 | 0.5 |
| 上一步 ht-1 | 来自上一步 | 0.3 |
| 上一步 ct-1 | 细胞状态 | 0.8 |
| 遗忘门 f | σ(0.4×h + 0.5×x + 0.1) | 0.668 |
| 输入门 i | σ(0.3×h + 0.6×x + 0.2) | 0.690 |
| 候选值 c̃ | tanh(0.5×h + 0.4×x + 0.1) | 0.462 |
| 输出门 o | σ(0.3×h + 0.4×x + 0.15) | 0.655 |
| 新细胞状态 ct | f × ct-1 + i × c̃ | 0.853 |
| 新隐藏状态 ht | o × tanh(ct) | 0.522 |
💡 关键:细胞状态 c 是 LSTM 的"高速公路",信息可以直接沿 c 传递而不经过非线性变换, 这就是 LSTM 能学到长期依赖的秘密。三个门控制着这条高速公路上的信息流。
🚀 xLSTM 简介 (2024)
xLSTM(Extended Long Short-Term Memory)是 2024 年提出的 LSTM 改进版本, 由 LSTM 的原发明者 Sepp Hochreiter 团队开发。它在保持 LSTM 核心思想的同时,引入了两个关键创新。
指数门控 (Exponential Gating)
将传统 sigmoid 门控替换为指数函数门控,提供更灵活的信息流控制。 指数门控可以更好地处理极小或极大的梯度信号。
矩阵记忆 (Matrix Memory)
引入矩阵形式的内存结构,替代向量形式的细胞状态。 这大幅提升了模型的记忆容量和表达能力,类似 Transformer 中的键值记忆。
LSTM vs xLSTM 对比
| 特性 | LSTM | xLSTM |
|---|---|---|
| 门控函数 | sigmoid (0~1) | 指数门控 (更大范围) |
| 记忆结构 | 向量 (1D) | 矩阵 (2D) |
| 记忆容量 | O(n) | O(n²) |
| 长期依赖 | 好 | 更好 |
| 参数量 | 较少 | 较多 |
📋 RNN vs LSTM vs xLSTM 对比表
| 特性 | Vanilla RNN | LSTM | xLSTM |
|---|---|---|---|
| 门控机制 | 无 | 3 个 sigmoid 门 | 指数门控 |
| 梯度消失 | 严重 | 缓解 | 进一步缓解 |
| 长期记忆 | 差 | 好 | 更好 |
| 参数量 | 少 | 多 (约 4x) | 更多 |
| 记忆结构 | 向量 h | 向量 h + c | 矩阵记忆 |
| 典型应用 | 短序列 | NLP, 时间序列 | 长序列建模 |
| 提出年份 | 1986 | 1997 | 2024 |
🎬 RNN 时间步展开动画
点击"播放"按钮,观察数据如何在 3 个时间步中流动。隐藏状态 h 从左传递到右。
🎮 互动实验:调整初始隐藏状态
拖动滑块改变初始隐藏状态 h₀,观察它如何影响 3 个时间步的输出。
💡 观察:
- • 初始 h₀ 越大,后续时间步的 h 值也越大
- • tanh 函数会把值压缩到 (-1, 1) 范围
- • 经过几个时间步后,初始值的影响逐渐被"记忆稀释"
- • 这就是梯度消失的直观表现 — 远距离信号衰减
💻 PyTorch 代码
import torch
import torch.nn as nn
# Vanilla RNN
rnn = nn.RNN(
input_size=10, # 输入维度
hidden_size=20, # 隐藏状态维度
num_layers=1, # RNN 层数
batch_first=True
)
# 输入: (batch, seq_len, input_size)
x = torch.randn(3, 5, 10) # batch=3, seq_len=5
h0 = torch.zeros(1, 3, 20) # 初始隐藏状态
# 前向传播
output, hn = rnn(x, h0)
print(output.shape) # torch.Size([3, 5, 20])
print(hn.shape) # torch.Size([1, 3, 20])
# LSTM — 多了细胞状态 c
lstm = nn.LSTM(
input_size=10,
hidden_size=20,
num_layers=2, # 2 层 LSTM
batch_first=True,
dropout=0.1
)
x = torch.randn(3, 5, 10)
h0 = torch.zeros(2, 3, 20) # 初始 h
c0 = torch.zeros(2, 3, 20) # 初始 c
# LSTM 返回 (output, (hn, cn))
output, (hn, cn) = lstm(x, (h0, c0))
print(output.shape) # torch.Size([3, 5, 20])
print(hn.shape) # torch.Size([2, 3, 20])
print(cn.shape) # torch.Size([2, 3, 20])
# 常见用法:取最后时间步的输出做分类
last_hidden = output[:, -1, :] # (3, 20)
🧠 小测验
- Transformer — Transformer 是 RNN 的替代方案
- Loss & 反向传播 — BPTT(反向传播穿越时间)的基础
- LSTM — LSTM 用门控机制解决长程依赖
1. RNN 的"记忆"机制是什么?
2. LSTM 的三个门分别是什么?
3. 梯度消失问题如何缓解?
📝 总结
Vanilla RNN
通过隐藏状态 h 传递记忆,参数在时间步之间共享。简单但有梯度消失问题。
LSTM
三个门控 + 细胞状态,让信息可以长距离传播。解决了 RNN 的梯度消失问题。
xLSTM
指数门控 + 矩阵记忆,进一步提升长期记忆能力和表达能力。
核心思想
RNN 系列的核心是"记忆传递" — 让神经网络能处理有序的序列数据。