STAGE 1 序列建模

RNN 循环神经网络

处理序列数据的循环架构 — 从 Vanilla RNN 到 LSTM

📊 Excel 手推 ⏱️ 25 分钟 🎯 序列建模
1
📋 本章要点
  • RNN 通过隐藏状态传递实现序列信息的记忆
  • LSTM 用门控机制(遗忘门/输入门/输出门)解决长程依赖问题
  • 梯度消失/爆炸是 RNN 训练的核心挑战
  • xLSTM 通过指数门控和矩阵记忆进一步改进性能

为什么需要 RNN?

现实世界中有很多序列数据:一句话中的词语有先后顺序,股票价格随时间变化,语音信号是连续的波形。 传统的前馈神经网络(MLP)把每个输入当作独立的,无法利用"顺序"信息。

类比:想象你在读一句话 — "我今天去了___"。要预测下一个词,你需要记住前面的每个词。 人类阅读时,大脑会维护一个"上下文记忆"。RNN 做的正是同样的事情。

核心思想:RNN 有一个隐藏状态 h(hidden state),就像一个"记忆"。 每处理一个时间步的输入,它会更新这个记忆,并把它传递给下一个时间步。

ht = tanh( Wx · xt + Wh · ht-1 + b)

当前记忆 = tanh(输入权重 x 当前输入 + 记忆权重 x 上一步记忆 + 偏置)

2

📊 Vanilla RNN 手推:3 个时间步

我们用简化参数(维度 d=2)来手推 RNN 的 3 个时间步。每个时间步都执行相同的计算。 点击单元格可以编辑输入值!

时间步 t=1

输入 xt 上一步 ht-1 Wx·xt Wh·ht-1 Wx·x + Wh·h + b tanh 激活
维度 1 1.0 0.0 0.50 0.00 0.70 0.604
维度 2 0.5 0.0 0.15 0.00 0.35 0.336
公式 Wx·x Wh·h sum + b tanh(sum)

时间步 t=2

输入 xt 上一步 ht-1 Wx·xt Wh·ht-1 Wx·x + Wh·h + b tanh 激活
维度 1 0.8 0.604 0.40 0.18 0.78 0.652
维度 2 0.3 0.336 0.09 0.10 0.39 0.371
公式 Wx·x Wh·ht-1 sum + b tanh(sum)

时间步 t=3

输入 xt 上一步 ht-1 Wx·xt Wh·ht-1 Wx·x + Wh·h + b tanh 激活
维度 1 0.6 0.652 0.30 0.20 0.70 0.604
维度 2 0.9 0.371 0.27 0.11 0.58 0.523
公式 Wx·x Wh·ht-1 sum + b tanh(sum)

💡 观察:每个时间步使用相同的权重 Wx、Wh 和偏置 b。 这就是"参数共享" — RNN 在每个时间步做相同的计算,只是输入不同。 简化参数:Wx = [0.5, 0.3; 0.3, 0.3], Wh = [0.3, 0.3; 0.3, 0.3], b = [0.2, 0.2]

3

📊 LSTM 手推:三个门控机制

Vanilla RNN 有梯度消失问题 — 时间步越远,梯度越小,无法学到长期依赖。 LSTM 通过三个"门"来解决这个问题,让信息可以选择性地保留或遗忘。

🚪

遗忘门 (Forget Gate)

f = σ(W_f·[h,x] + b_f)

决定丢弃哪些旧记忆

📝

输入门 (Input Gate)

i = σ(W_i·[h,x] + b_i)

决定写入哪些新信息

📤

输出门 (Output Gate)

o = σ(W_o·[h,x] + b_o)

决定输出哪些记忆

下面展示一个时间步的完整 LSTM 计算。点击输入值可编辑:

计算步骤 公式 结果
输入 xt 用户输入 0.5
上一步 ht-1 来自上一步 0.3
上一步 ct-1 细胞状态 0.8
遗忘门 f σ(0.4×h + 0.5×x + 0.1) 0.668
输入门 i σ(0.3×h + 0.6×x + 0.2) 0.690
候选值 c̃ tanh(0.5×h + 0.4×x + 0.1) 0.462
输出门 o σ(0.3×h + 0.4×x + 0.15) 0.655
新细胞状态 ct f × ct-1 + i × c̃ 0.853
新隐藏状态 ht o × tanh(ct) 0.522

💡 关键:细胞状态 c 是 LSTM 的"高速公路",信息可以直接沿 c 传递而不经过非线性变换, 这就是 LSTM 能学到长期依赖的秘密。三个门控制着这条高速公路上的信息流。

4

🚀 xLSTM 简介 (2024)

xLSTM(Extended Long Short-Term Memory)是 2024 年提出的 LSTM 改进版本, 由 LSTM 的原发明者 Sepp Hochreiter 团队开发。它在保持 LSTM 核心思想的同时,引入了两个关键创新。

📈

指数门控 (Exponential Gating)

将传统 sigmoid 门控替换为指数函数门控,提供更灵活的信息流控制。 指数门控可以更好地处理极小或极大的梯度信号。

🧮

矩阵记忆 (Matrix Memory)

引入矩阵形式的内存结构,替代向量形式的细胞状态。 这大幅提升了模型的记忆容量和表达能力,类似 Transformer 中的键值记忆。

LSTM vs xLSTM 对比

特性 LSTM xLSTM
门控函数 sigmoid (0~1) 指数门控 (更大范围)
记忆结构 向量 (1D) 矩阵 (2D)
记忆容量 O(n) O(n²)
长期依赖 更好
参数量 较少 较多
5

📋 RNN vs LSTM vs xLSTM 对比表

特性 Vanilla RNN LSTM xLSTM
门控机制 3 个 sigmoid 门 指数门控
梯度消失 严重 缓解 进一步缓解
长期记忆 更好
参数量 多 (约 4x) 更多
记忆结构 向量 h 向量 h + c 矩阵记忆
典型应用 短序列 NLP, 时间序列 长序列建模
提出年份 1986 1997 2024
6

🎬 RNN 时间步展开动画

点击"播放"按钮,观察数据如何在 3 个时间步中流动。隐藏状态 h 从左传递到右。

x₁ = [1.0, 0.5]
x₂ = [0.8, 0.3]
x₃ = [0.6, 0.9]
t = 1 h₁
t = 2 h₂
t = 3 h₃
h₁ = [0.60, 0.34]
h₂ = [0.65, 0.37]
h₃ = [0.60, 0.52]
7

🎮 互动实验:调整初始隐藏状态

拖动滑块改变初始隐藏状态 h₀,观察它如何影响 3 个时间步的输出。

0.0
0.0

💡 观察:

  • • 初始 h₀ 越大,后续时间步的 h 值也越大
  • • tanh 函数会把值压缩到 (-1, 1) 范围
  • • 经过几个时间步后,初始值的影响逐渐被"记忆稀释"
  • • 这就是梯度消失的直观表现 — 远距离信号衰减
8

💻 PyTorch 代码

rnn_example.py
import torch
import torch.nn as nn

# Vanilla RNN
rnn = nn.RNN(
    input_size=10,    # 输入维度
    hidden_size=20,   # 隐藏状态维度
    num_layers=1,     # RNN 层数
    batch_first=True
)

# 输入: (batch, seq_len, input_size)
x = torch.randn(3, 5, 10)  # batch=3, seq_len=5
h0 = torch.zeros(1, 3, 20)  # 初始隐藏状态

# 前向传播
output, hn = rnn(x, h0)
print(output.shape)  # torch.Size([3, 5, 20])
print(hn.shape)      # torch.Size([1, 3, 20])
lstm_example.py
# LSTM — 多了细胞状态 c
lstm = nn.LSTM(
    input_size=10,
    hidden_size=20,
    num_layers=2,     # 2 层 LSTM
    batch_first=True,
    dropout=0.1
)

x = torch.randn(3, 5, 10)
h0 = torch.zeros(2, 3, 20)  # 初始 h
c0 = torch.zeros(2, 3, 20)  # 初始 c

# LSTM 返回 (output, (hn, cn))
output, (hn, cn) = lstm(x, (h0, c0))
print(output.shape)  # torch.Size([3, 5, 20])
print(hn.shape)      # torch.Size([2, 3, 20])
print(cn.shape)      # torch.Size([2, 3, 20])

# 常见用法:取最后时间步的输出做分类
last_hidden = output[:, -1, :]  # (3, 20)
9

🧠 小测验

🔗 相关章节推荐

1. RNN 的"记忆"机制是什么?

2. LSTM 的三个门分别是什么?

3. 梯度消失问题如何缓解?

10

📝 总结

🔄

Vanilla RNN

通过隐藏状态 h 传递记忆,参数在时间步之间共享。简单但有梯度消失问题。

🚪

LSTM

三个门控 + 细胞状态,让信息可以长距离传播。解决了 RNN 的梯度消失问题。

🚀

xLSTM

指数门控 + 矩阵记忆,进一步提升长期记忆能力和表达能力。

🎯

核心思想

RNN 系列的核心是"记忆传递" — 让神经网络能处理有序的序列数据。