STAGE 3 科学智能

AlphaFold 蛋白质结构预测

用 AI 预测蛋白质的 3D 结构 — 解决生物学 50 年难题

📊 Excel 手推 ⏱️ 30 分钟 🎯 蛋白质结构预测
1
📋 本章要点
  • AlphaFold 解决了蛋白质从氨基酸序列预测 3D 结构的难题
  • Evoformer 模块融合多序列比对(MSA)和配对表示
  • 结构模块从 Evoformer 输出直接预测原子坐标
  • 预测精度达到实验级别,加速了药物研发和生命科学研究

蛋白质折叠问题

蛋白质是生命的"分子机器"。一个蛋白质由一串氨基酸序列折叠成特定的 3D 结构,而结构决定功能

核心问题:给定一个氨基酸序列(如 MVLSPADKTNVK...),能否预测它的 3D 空间结构? 这个问题困扰了生物学家 50 年

🧬
氨基酸序列
一级结构:一维字符串
MVLSPADKTNVK...
折叠
物理力驱动的折叠过程
氢键、疏水作用、范德华力
🔬
3D 结构
三级结构:空间坐标
每个原子的 (x, y, z) 位置

Levinthal 悖论

蛋白质折叠是一个极其困难的搜索问题。一条 100 个氨基酸的蛋白质,每个残基有约 3 种可能的构象状态。

可能的构象总数 = 3100 ≈ 5 × 1047

如果随机尝试,每种构象花 10-13 秒:

需要约 1027 年 — 比宇宙年龄 (138 亿年) 还长 1017 倍!

但自然界中,蛋白质在毫秒到秒级别就能完成折叠。这说明折叠不是随机搜索,而是沿着一个漏斗状的能量景观(energy landscape)快速下降。 AlphaFold 的突破在于用深度学习直接从序列预测结构,绕过了物理模拟的巨大计算量。

为什么重要?

💊 药物设计

知道蛋白质结构就能设计能与其结合的药物分子,加速新药研发。

🦠 疾病理解

很多疾病(如阿尔茨海默症)与蛋白质错误折叠有关,理解结构有助于理解病因。

🌱 酶工程

设计新型酶用于工业催化、生物燃料、塑料降解等应用。

🧬 基础科学

蛋白质结构数据库 (PDB) 已有 20 万+ 结构,AlphaFold 预测了 2 亿+ 结构。

2

AlphaFold 的输入特征 — Excel 手推

AlphaFold 的输入不是原始文本,而是经过精心设计的特征矩阵。 点击单元格可以编辑输入值!

Step 2a:氨基酸序列编码 (One-Hot)

自然界有 20 种标准氨基酸,每种用一个 one-hot 向量表示。我们用 5 种氨基酸简化演示:

氨基酸 One-Hot 编码(简化为 5 种)
位置 氨基酸 A (Ala) G (Gly) V (Val) L (Leu) I (Ile)
1 A 1 0 0 0 0
2 G 0 1 0 0 0
3 V 0 0 1 0 0
4 L 0 0 0 1 0
5 A 1 0 0 0 0

输入特征矩阵形状:序列长度 × 20 (实际为 20 种氨基酸)

Step 2b:多序列比对 (MSA) — 同源序列信息

MSA(Multiple Sequence Alignment)是 AlphaFold 最重要的输入特征之一。 通过在数据库中搜索与目标序列同源的已知序列,将它们对齐排列,可以发现哪些位置在进化中被保守保留。

MSA 矩阵(同一蛋白质在不同物种中的序列比对)
序列来源 位置 1 位置 2 位置 3 位置 4 位置 5
目标序列 (人类) A G V L A
小鼠 A G V I A
果蝇 A G L L S
酵母 S G V L A

MSA 告诉我们什么?

  • 位置 2 (G):在所有物种中完全保守 → 该位置对蛋白质功能至关重要
  • 位置 4 (L/I):在不同物种中有变异 → 该位置允许一定灵活性
  • • MSA 提供了共进化信号:如果位置 i 和 j 同时突变,说明它们在 3D 空间中可能靠近

Step 2c:配对表示 (Pair Representation)

除了 MSA 表示,AlphaFold 还维护一个配对表示矩阵,记录每对残基之间的关系。 这个矩阵类似一张"距离图",初始值来自序列上的相对位置和 MSA 的共进化信息。

配对表示矩阵 (简化: 序列距离)
残基 1 (A) 残基 2 (G) 残基 3 (V) 残基 4 (L) 残基 5 (A)
残基 1 (A) 0 1 2 3 4
残基 2 (G) 1 0 1 2 3
残基 3 (V) 2 1 0 1 2
残基 4 (L) 3 2 1 0 1
残基 5 (A) 4 3 2 1 0

初始配对表示:序列距离 + 相对位置编码。经过 Evoformer 处理后,这个矩阵会编码残基之间的空间距离关系。

3

Evoformer 模块 — 核心架构

Evoformer 是 AlphaFold 最核心的创新模块。它在 MSA 表示和配对表示之间进行双向信息交换,让模型同时理解序列进化信息和残基间的空间关系。

Evoformer Block 数据流

MSA 表示
形状: (N_seq, N_res, d_msa)
多序列比对的特征矩阵
配对表示
形状: (N_res, N_res, d_pair)
残基对之间的关系矩阵
Evoformer Block × 48

MSA 表示处理

行注意力 (沿序列维度)
列注意力 (沿 MSA 维度)
MSA → 配对表示 (三角更新)

配对表示处理

三角注意力 (端点更新)
三角注意力 (起始点更新)
配对表示 → MSA 表示
更新后的 MSA 表示
融合了进化和结构信息
更新后的配对表示
编码了残基间的空间关系
MSA 行注意力

沿序列维度做注意力,让同一序列中的不同残基相互交流信息,类似于标准的 Transformer 自注意力。

MSA 列注意力

沿 MSA 维度做注意力,让不同物种的同源残基相互交流,捕获进化保守性信息。

三角注意力更新

配对表示满足三角不等式:如果残基 i 靠近 j,j 靠近 k,那么 i 和 k 的距离也有约束。这通过三角注意力来实现。

双向信息交换

MSA 和配对表示之间持续交换信息:MSA 的共进化信号更新配对矩阵,配对矩阵又指导 MSA 的注意力分布。

4

结构模块 — Excel 手推

结构模块将 Evoformer 输出的特征转化为 3D 原子坐标。其核心是不变点注意力 (Invariant Point Attention, IPA)点击单元格可以编辑输入值!

Step 4a:刚体表示

每个残基用一个刚体 (rigid body) 表示:一个旋转矩阵 R 和一个平移向量 t。 这保证了变换的等变性(旋转和平移不变)。

残基刚体参数 (简化为 2D)
残基 旋转角 θ (度) 平移 x 平移 y Cα 坐标 (x, y)
1 (A) 10 0.0 0.0 (0.00, 0.00)
2 (G) 25 3.8 0.0 (3.80, 0.00)
3 (V) -15 3.8 0.0 (7.60, 0.00)
4 (L) 5 3.8 0.0 (11.40, 0.00)
5 (A) -30 3.8 0.0 (15.20, 0.00)

每个残基的局部坐标系通过旋转和平移累积到全局坐标。相邻 Cα 之间的距离约 3.8 Å。

Step 4b:不变点注意力 (IPA)

IPA 是结构模块的关键创新。它在注意力计算中同时使用特征空间3D 空间的信息, 保证输出对旋转和平移是等变的。

IPA 注意力计算:

αij = softmaxj(qiT kj + qiT · pjpoint)

特征注意力分数 + 3D 空间中点的注意力分数

IPA 注意力权重 (简化)
Query \ Key 残基 1 残基 2 残基 3 残基 4
残基 1 → 0.40 0.30 0.20 0.10
残基 2 → 0.25 0.35 0.25 0.15
残基 3 → 0.15 0.25 0.35 0.25
残基 4 → 0.10 0.20 0.30 0.40

每个残基主要关注自身和空间上相邻的残基,注意力权重由特征相似度和 3D 距离共同决定。

Step 4c:坐标预测流程

Evoformer 输出
MSA + 配对特征
初始化刚体
R, t 参数
IPA 注意力
特征+空间
更新刚体
新的 R, t
原子坐标
(x, y, z)

结构模块迭代 8 次,逐步精化坐标预测。

5

互动实验 — 注意力可视化

模拟 Evoformer 中 MSA 行注意力的权重分布。拖动滑块选择不同的 Query 残基位置,观察注意力权重如何变化。

1
1.0
0.5

观察:

  • 温度 τ 越低,注意力越集中在最相关的残基上
  • 序列距离衰减越大,近邻残基获得更高权重
  • 在真实 AlphaFold 中,注意力权重还会受 3D 空间距离影响
  • 注意对角线上的权重(自注意力)通常最高
6

代码实现 (PyTorch)

简化版 Evoformer block,展示 MSA 注意力和配对表示更新的核心逻辑。

evoformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MSARowAttention(nn.Module):
    """MSA 行注意力:沿序列维度做注意力"""
    def __init__(self, d_model=64, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, msa):
        # msa: (batch, N_seq, N_res, d_model)
        B, S, R, D = msa.shape
        # 对每个 MSA 序列,沿残基维度做注意力
        qkv = self.qkv(msa).reshape(B, S, R, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(3)  # 每个 (B, S, R, heads, d_k)

        # 沿残基维度做注意力 (R 维度)
        scores = torch.einsum('bsihd,bsjhd->bsij', q, k) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum('bsij,bsjhd->bsihd', attn, v)
        return self.out(out.reshape(B, S, R, D))


class MSAColumnAttention(nn.Module):
    """MSA 列注意力:沿 MSA 维度做注意力"""
    def __init__(self, d_model=64, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, msa):
        # msa: (batch, N_seq, N_res, d_model)
        B, S, R, D = msa.shape
        msa_t = msa.transpose(1, 2)  # (B, R, S, D) — 换轴
        qkv = self.qkv(msa_t).reshape(B, R, S, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(3)

        # 沿 MSA 序列维度做注意力 (S 维度)
        scores = torch.einsum('brihd,brjhd->brij', q, k) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum('brij,brjhd->brihd', attn, v)
        return self.out(out.reshape(B, R, S, D)).transpose(1, 2)


class OuterProductMean(nn.Module):
    """MSA → 配对表示:外积均值"""
    def __init__(self, d_msa=64, d_pair=128, d_hidden=32):
        super().__init__()
        self.linear_a = nn.Linear(d_msa, d_hidden)
        self.linear_b = nn.Linear(d_msa, d_hidden)
        self.linear_out = nn.Linear(d_hidden * d_hidden, d_pair)

    def forward(self, msa):
        # msa: (B, S, R, d_msa)
        a = self.linear_a(msa)  # (B, S, R, d_hidden)
        b = self.linear_b(msa)  # (B, S, R, d_hidden)
        # 外积沿 MSA 维度求均值 → (B, R, R, d_h, d_h)
        outer = torch.einsum('bsid,bsjd->bijcd', a, b)
        outer = outer.mean(dim=1)  # 平均 MSA 维度
        B, R1, R2, H1, H2 = outer.shape
        return self.linear_out(outer.reshape(B, R1, R2, H1 * H2))


class EvoformerBlock(nn.Module):
    """一个 Evoformer Block"""
    def __init__(self, d_msa=64, d_pair=128, n_heads=8):
        super().__init__()
        self.msa_row_attn = MSARowAttention(d_msa, n_heads)
        self.msa_col_attn = MSAColumnAttention(d_msa, n_heads)
        self.outer_product = OuterProductMean(d_msa, d_pair)
        self.pair_update = nn.Sequential(
            nn.Linear(d_pair, d_pair),
            nn.ReLU(),
            nn.Linear(d_pair, d_pair),
        )
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)

    def forward(self, msa_repr, pair_repr):
        # MSA 处理
        msa_repr = msa_repr + self.msa_row_attn(msa_repr)
        msa_repr = msa_repr + self.msa_col_attn(msa_repr)
        msa_repr = self.norm_msa(msa_repr)

        # MSA → 配对表示
        pair_repr = pair_repr + self.outer_product(msa_repr)
        pair_repr = pair_repr + self.pair_update(pair_repr)
        pair_repr = self.norm_pair(pair_repr)

        return msa_repr, pair_repr


# 使用示例
block = EvoformerBlock(d_msa=64, d_pair=128)
msa = torch.randn(1, 16, 50, 64)   # 16 条序列, 50 个残基
pair = torch.randn(1, 50, 50, 128) # 50×50 配对矩阵
msa_out, pair_out = block(msa, pair)
print(f"MSA: {msa_out.shape}, Pair: {pair_out.shape}")
7

小测验

🔗 相关章节推荐
  • TransformerEvoformer 基于 Transformer 架构
  • CNN结构预测中的空间特征提取
  • VAE结构模块中的概率建模思想

1. Levinthal 悖论说明了什么?

2. 多序列比对 (MSA) 在 AlphaFold 中的作用是什么?

3. Evoformer 模块的核心创新是什么?

Easy Deep Learning · 用 Excel 手推理解每一个公式