rwkv笔记

rwkv

transformers

Published on

RWKV-V4: https://arxiv.org/abs/2305.13048 RWKV-v5-Eagle: https://arxiv.org/abs/2404.05892 RWKV-v6-Finch: https://arxiv.org/abs/2404.05892 RWKV-v7-Gooserwkv7, RWKV-v7-G1: https://arxiv.org/abs/2503.14456

RWKV 的名字来源于其时间混合(Time-Mixing)模块中的四个核心基础向量组件:

  • R (Receptance,接受度):其作用类似于传统注意力机制中的 Query,用来决定模型在当前时间步应该接收和放行多少过去的信息。
  • W (Weight,权重):这是一个位置权重衰减向量(可训练参数)。它控制着过去的信息随着时间推移如何呈指数级衰减。
  • K (Key,键):类似于传统注意力机制中的 K,代表当前输入包含的特征信息。
  • V (Value,值):类似于传统注意力机制中的 V,代表当前输入所携带的实际内容信息。

输入

和绝大多数自然语言模型一样,RWKV 不直接理解文字,它需要先将文本转化为数字向量。

  • 分词器 (Tokenizer):RWKV 采用了特制的 RWKV World Tokenizer。它使用基于前缀树(Trie)的贪婪匹配算法,词表大小为 65536。这种设计特别优化了多语言(如中文、阿拉伯语等非欧洲语言)和代码数据的处理效率,避免了传统 BPE 分词器在小语种上效率低下的问题。
  • 小初始化嵌入 (Small Init Embedding):文本被转换为 Token 后进入嵌入层(Embedding)。RWKV 使用了一种特殊的技巧——将嵌入矩阵初始化为非常小的值,并在其后直接增加一个 LayerNorm(层归一化)操作。这能让模型在训练初期迅速摆脱噪声状态,极大地加速并稳定了深层网络的训练。

模型

RWKV 的主体由多个堆叠的残差块组成,其宏观结构与 Transformer 类似。每个残差块内部包含两个主要子模块:时间混合(Time-Mixing)通道混合(Channel-Mixing)

Token Shift

模型会将当前时间步的输入 xtx_t​ 与前一个时间步的输入 xt1x_{t−1}​ 进行线性插值混合xshifted=xtμ+xt1(1μ)x_{shifted} = x_t \odot \mu + x_{t-1} \odot (1 - \mu) 这里的 μ\mu 是一个可学习的参数。这种类似 1D 卷积的操作让模型在不增加额外计算复杂度的情况下,天然具备了局部的历史视野,在单一层内捕获相邻 token 之间的局部联系

在 RWKV-6 中,这个混合比例 μ 甚至变成了数据依赖的(Data-dependent,通过 LoRA 动态计算)。不过在最新的 RWKV-7 中,为了极致的训练速度,又退回了简单参数化的 Token Shift,而把动态计算的算力留给了更核心的模块

时间混合模块 (Time-Mixing)

在 Transformer 中,负责融合全局上下文的是 Self-Attention(自注意力机制)。在 RWKV 中,取代它的是 Time Mixing。它的计算基于四个核心概念(也就是 RWKV 名字的由来):

  • R (Receptance, 接收度): 类似 Attention 的 Query,决定当前时刻接收历史信息的“意愿”有多强。
  • W (Weight / Time Decay, 时间衰减): 控制历史信息随着时间流逝被遗忘的程度。
  • K (Key, 键): 类似 Attention 的 Key,当前时刻的特征标识。
  • V (Value, 值): 类似 Attention 的 Value,当前时刻的实际内容。

首先,使用前面做过 Token Shift 的输入 xshiftedx_{shifted},通过不同的线性变换分别生成 Rt,Kt,VtR_t, K_t, V_tRt=Wr(μrxt+(1μr)xt1)R_{t}=W_r\cdot(\mu_r\odot x_t+(1-\mu_r)\odot x_{t-1}) Kt=Wk(μkxt+(1μk)xt1)K_{t}=W_k\cdot(\mu_k\odot x_t+(1-\mu_k)\odot x_{t-1}) Vt=Wv(μvxt+(1μv)xt1)V_{t}=W_v\cdot(\mu_v\odot x_t+(1-\mu_v)\odot x_{t-1}) 接下来计算wkvtwkv_t,论文给的公式如下

wkvt=i=1t1e(t1i)w+kivi+eu+ktvti=1t1e(t1i)w+ki+eu+ktwkv_t=\frac{\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}\odot v_i+e^{u+k_t}\odot v_t}{\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}+e^{u+k_t}}

不会真的有人在看这个公式吧,我们换个方法看他

WKVt=历史(KV)的衰减累加+eKtVt历史(K)的衰减累加+eKt\text{WKV}_t = \frac{\text{历史}(K \cdot V)的衰减累加 + e^{K_t}V_t}{\text{历史}(K)的衰减累加 + e^{K_t}}

  • 分子由两项相加构成:左边是历史,右边是当前
    • 左边项: i=1t1e(t1i)w+kivi\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}\odot v_ii=1i=1t1t-1循环遍历过去所有的 Token,ekivie^{k_i} \odot v_i是历史第 ii 步的 Key 和 Value 结合产生的信息,e(t1i)we^{-(t-1-i)w}就是时间衰减 (Time Decay)。(t1i)(t-1-i) 是当前时刻 tt 距离历史时刻 ii 的“距离”。距离越远,这个值越大,乘以负数衰减参数 w-w 后,指数 ee 的结果就越趋近于 0。也就是说==越久远的记忆,权重越低==
    • 右边项: eu+ktvte^{u+k_t}\odot v_t, 这里的 ktk_tvtv_t 就是当前时刻算出来的 Key 和 Value.多出来的 uu:论文里叫它 Time First 或者是对当前 Token 的额外奖励 (Bonus)。在传统的自注意力中,当前词对自己的注意力得分通常是最高的。为了模拟这种效果,RWKV 给当前时刻的 KtK_t 额外加上了一个可学习的偏置 uu,让当前 Token 的信息在混合时占据主导地位。
  • 分母的结构和分子完全一样,只是去掉了 VtV_tViV_i
    • 左边 i=1t1e(t1i)w+ki\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i} 对应:“历史 (K)(K) 的衰减累加”
    • 右边 eu+kte^{u+k_t} 对应:当前时刻 KtK_t (带奖励 uu)的权重。

问:为什么要除以分母? 答:和 Transformer 里 Attention 最后要过一个 Softmax 是一样的道理,为了把所有的权重归一化到 010 \sim 1 之间,防止数值爆炸。

最后,把算出来的 WKV 历史总结,乘以当前的接收度 RtR_t(通常会经过 Sigmoid 激活),再经过一个线性输出层,得到当前模块的输出:

Outputt=Wout(σ(Rt)WKVt)Output_t = W_{out} \cdot (\sigma(R_t) \odot \text{WKV}_t)

通道混合模块 (Channel-Mixing Block)

这个模块的作用是替代 Transformer 中的前馈神经网络(FFN),用于在特征维度(Channel)上进行信息的深度整合

在基础的 RWKV-4 之后,RWKV-5 (Eagle) 引入了矩阵值状态(Matrix-Valued States)以提升表达能力;RWKV-6 (Finch) 引入了数据依赖的动态衰减机制和数据依赖的 Token Shift (LoRA机制);而最新的 RWKV-7 (Goose) 则进一步引入了广义的 Delta Rule(误差修正规则)和动态向量门控,使其具备了追踪复杂状态和识别所有正则语言的强大理论能力

Channel Mixing 的结构比 Time Mixing 简单很多,它不维护跨步的累加长序列记忆,只关注当前步特征的维度变换

同样先做 Token Shift,混合 xtx_txt1x_{t-1}。对于接收度 RR 和键 KK,我们会学习两组独立的时间混合系数(μr\mu_rμk\mu_kxr=xtμr+xt1(1μr)x'_r = x_t \cdot \mu_r + x_{t-1} \cdot (1 - \mu_r)

xk=xtμk+xt1(1μk)x'_k = x_t \cdot \mu_k + x_{t-1} \cdot (1 - \mu_k)

用混合后的输入,通过两个不同的权重矩阵 WrW_rWkW_k 进行普通的矩阵相乘生成接收度 RtR_t 和键 KtK_t

Rt=WrxrR_t = W_r \cdot x'_r

Kt=WkxkK_t = W_k \cdot x'_k

在标准的 RWKV 中,这一步会把 KtK_t 的维度放大(通常放大 4 倍,比如隐藏层维度从 1024 放大到 4096),这就和 Transformer 的 FFN 第一层放大维度是一样的。

这里把 KtK_t 通过一个激活函数,RWKV 经典设计是使用 Squared ReLU,即 max(0,x)2\max(0, x)^2。平方操作可以增强大特征的表达能力,让网络更平滑。Kt=max(Kt,0)2K'_{t} = \max(K_t, 0)^2 拿激活后的 KtK'_{t} 乘以第三个权重矩阵 WvW_v,将激活后的结果映射为 VtV_tVt=WvKtV_t = W_v \cdot K'_{t} 最后同样用 RtR_t 去门控 VtV_t 输出:Outputt=σ(Rt)VtOutput_t = \sigma(R_t) \odot V_t

组装

现在我们有了 Time Mixing 和 Channel Mixing,就可以像搭积木一样组装出一层完整的 RWKV Block。对每一层

# 假设输入为 x
# 1. Time Mixing 分支
x = x + TimeMixing(LayerNorm1(x))

# 2. Channel Mixing 分支
x = x + ChannelMixing(LayerNorm2(x))

一个完整的 RWKV 模型,就是将上述这个 Block 重复堆叠 LL

输出

经过 LL 层时间步 tt 对应的最后一个隐藏状态 hth_t 已经包含了丰富的上下文信息和语义。

  1. 最终层归一化 (Final LayerNorm): 对 hth_t 做最后一次 LayerNorm,使数值更加稳定。
  2. 预测头 (Head/Output Projection): 将维度从 d_model 线性映射到 vocab_size(词表大小)。这相当于打分器,给词表里的每一个词打分(得到 Logits)。
  3. Softmax 与采样: 如果是推理生成,我们把 Logits 扔进 Softmax 变成概率分布,根据 Temperature、Top-K、Top-P 等策略,从中采样出下一个要说的词。

训练与推理

RWKV 能够无缝在两种模式间切换,这是它最核心的优势:

  • 训练时(时间并行模式 / Transformer-like):由于 WKV 的计算中时间的依赖只存在于元素的累加和指数衰减中,因此可以通过并行的扫描(Serial Scan)或类似 CUDA 自定义内核的方式,把时间维度展开,像 Transformer 一样实现全序列的高度并行训练。训练单层的复杂度为 O(BTd2)O(BTd^2)BB 为 batch,TT 为序列长度,dd 为维度)。
  • 推理时(时间序列模式 / RNN-like):在生成下一个 token 时,RWKV 不需要像 Transformer 那样保留庞大的 KV Cache(键值缓存)。它只需要将上一步计算出的几个向量/矩阵状态(例如在 RWKV-4 中是 5 个大小为 D 的状态向量,在 RWKV-5/6 中是包含历史累积信息的矩阵状态)传递给当前步即可。这种机制使得其推理时的显存占用是恒定的 O(1),生成速度不会随着上下文长度的增加而变慢