rwkv笔记
rwkv
transformers
RWKV-V4: https://arxiv.org/abs/2305.13048 RWKV-v5-Eagle: https://arxiv.org/abs/2404.05892 RWKV-v6-Finch: https://arxiv.org/abs/2404.05892 RWKV-v7-Gooserwkv7, RWKV-v7-G1: https://arxiv.org/abs/2503.14456
RWKV 的名字来源于其时间混合(Time-Mixing)模块中的四个核心基础向量组件:
- R (Receptance,接受度):其作用类似于传统注意力机制中的 Query,用来决定模型在当前时间步应该接收和放行多少过去的信息。
- W (Weight,权重):这是一个位置权重衰减向量(可训练参数)。它控制着过去的信息随着时间推移如何呈指数级衰减。
- K (Key,键):类似于传统注意力机制中的 K,代表当前输入包含的特征信息。
- V (Value,值):类似于传统注意力机制中的 V,代表当前输入所携带的实际内容信息。
输入
和绝大多数自然语言模型一样,RWKV 不直接理解文字,它需要先将文本转化为数字向量。
- 分词器 (Tokenizer):RWKV 采用了特制的 RWKV World Tokenizer。它使用基于前缀树(Trie)的贪婪匹配算法,词表大小为 65536。这种设计特别优化了多语言(如中文、阿拉伯语等非欧洲语言)和代码数据的处理效率,避免了传统 BPE 分词器在小语种上效率低下的问题。
- 小初始化嵌入 (Small Init Embedding):文本被转换为 Token 后进入嵌入层(Embedding)。RWKV 使用了一种特殊的技巧——将嵌入矩阵初始化为非常小的值,并在其后直接增加一个 LayerNorm(层归一化)操作。这能让模型在训练初期迅速摆脱噪声状态,极大地加速并稳定了深层网络的训练。
模型
RWKV 的主体由多个堆叠的残差块组成,其宏观结构与 Transformer 类似。每个残差块内部包含两个主要子模块:时间混合(Time-Mixing) 和 通道混合(Channel-Mixing)
Token Shift
模型会将当前时间步的输入 与前一个时间步的输入 进行线性插值混合 这里的 是一个可学习的参数。这种类似 1D 卷积的操作让模型在不增加额外计算复杂度的情况下,天然具备了局部的历史视野,在单一层内捕获相邻 token 之间的局部联系
在 RWKV-6 中,这个混合比例 μ 甚至变成了数据依赖的(Data-dependent,通过 LoRA 动态计算)。不过在最新的 RWKV-7 中,为了极致的训练速度,又退回了简单参数化的 Token Shift,而把动态计算的算力留给了更核心的模块
时间混合模块 (Time-Mixing)
在 Transformer 中,负责融合全局上下文的是 Self-Attention(自注意力机制)。在 RWKV 中,取代它的是 Time Mixing。它的计算基于四个核心概念(也就是 RWKV 名字的由来):
- R (Receptance, 接收度): 类似 Attention 的 Query,决定当前时刻接收历史信息的“意愿”有多强。
- W (Weight / Time Decay, 时间衰减): 控制历史信息随着时间流逝被遗忘的程度。
- K (Key, 键): 类似 Attention 的 Key,当前时刻的特征标识。
- V (Value, 值): 类似 Attention 的 Value,当前时刻的实际内容。
首先,使用前面做过 Token Shift 的输入 ,通过不同的线性变换分别生成 。 接下来计算,论文给的公式如下
不会真的有人在看这个公式吧,我们换个方法看他
- 分子由两项相加构成:左边是历史,右边是当前
- 左边项: 。 到 循环遍历过去所有的 Token,是历史第 步的 Key 和 Value 结合产生的信息,就是时间衰减 (Time Decay)。 是当前时刻 距离历史时刻 的“距离”。距离越远,这个值越大,乘以负数衰减参数 后,指数 的结果就越趋近于 0。也就是说==越久远的记忆,权重越低==
- 右边项: , 这里的 和 就是当前时刻算出来的 Key 和 Value.多出来的 :论文里叫它 Time First 或者是对当前 Token 的额外奖励 (Bonus)。在传统的自注意力中,当前词对自己的注意力得分通常是最高的。为了模拟这种效果,RWKV 给当前时刻的 额外加上了一个可学习的偏置 ,让当前 Token 的信息在混合时占据主导地位。
- 分母的结构和分子完全一样,只是去掉了 和 。
- 左边 对应:“历史 的衰减累加”。
- 右边 对应:当前时刻 (带奖励 )的权重。
问:为什么要除以分母? 答:和 Transformer 里 Attention 最后要过一个 Softmax 是一样的道理,为了把所有的权重归一化到 之间,防止数值爆炸。
最后,把算出来的 WKV 历史总结,乘以当前的接收度 (通常会经过 Sigmoid 激活),再经过一个线性输出层,得到当前模块的输出:
通道混合模块 (Channel-Mixing Block)
这个模块的作用是替代 Transformer 中的前馈神经网络(FFN),用于在特征维度(Channel)上进行信息的深度整合
在基础的 RWKV-4 之后,RWKV-5 (Eagle) 引入了矩阵值状态(Matrix-Valued States)以提升表达能力;RWKV-6 (Finch) 引入了数据依赖的动态衰减机制和数据依赖的 Token Shift (LoRA机制);而最新的 RWKV-7 (Goose) 则进一步引入了广义的 Delta Rule(误差修正规则)和动态向量门控,使其具备了追踪复杂状态和识别所有正则语言的强大理论能力
Channel Mixing 的结构比 Time Mixing 简单很多,它不维护跨步的累加长序列记忆,只关注当前步特征的维度变换
同样先做 Token Shift,混合 和 。对于接收度 和键 ,我们会学习两组独立的时间混合系数( 和 )
用混合后的输入,通过两个不同的权重矩阵 和 进行普通的矩阵相乘生成接收度 和键
在标准的 RWKV 中,这一步会把 的维度放大(通常放大 4 倍,比如隐藏层维度从 1024 放大到 4096),这就和 Transformer 的 FFN 第一层放大维度是一样的。
这里把 通过一个激活函数,RWKV 经典设计是使用 Squared ReLU,即 。平方操作可以增强大特征的表达能力,让网络更平滑。 拿激活后的 乘以第三个权重矩阵 ,将激活后的结果映射为 。 最后同样用 去门控 输出:
组装
现在我们有了 Time Mixing 和 Channel Mixing,就可以像搭积木一样组装出一层完整的 RWKV Block。对每一层
# 假设输入为 x
# 1. Time Mixing 分支
x = x + TimeMixing(LayerNorm1(x))
# 2. Channel Mixing 分支
x = x + ChannelMixing(LayerNorm2(x))
一个完整的 RWKV 模型,就是将上述这个 Block 重复堆叠 层
输出
经过 层时间步 对应的最后一个隐藏状态 已经包含了丰富的上下文信息和语义。
- 最终层归一化 (Final LayerNorm): 对 做最后一次 LayerNorm,使数值更加稳定。
- 预测头 (Head/Output Projection): 将维度从
d_model线性映射到vocab_size(词表大小)。这相当于打分器,给词表里的每一个词打分(得到 Logits)。 - Softmax 与采样: 如果是推理生成,我们把 Logits 扔进 Softmax 变成概率分布,根据 Temperature、Top-K、Top-P 等策略,从中采样出下一个要说的词。
训练与推理
RWKV 能够无缝在两种模式间切换,这是它最核心的优势:
- 训练时(时间并行模式 / Transformer-like):由于 WKV 的计算中时间的依赖只存在于元素的累加和指数衰减中,因此可以通过并行的扫描(Serial Scan)或类似 CUDA 自定义内核的方式,把时间维度展开,像 Transformer 一样实现全序列的高度并行训练。训练单层的复杂度为 ( 为 batch, 为序列长度, 为维度)。
- 推理时(时间序列模式 / RNN-like):在生成下一个 token 时,RWKV 不需要像 Transformer 那样保留庞大的 KV Cache(键值缓存)。它只需要将上一步计算出的几个向量/矩阵状态(例如在 RWKV-4 中是 5 个大小为 D 的状态向量,在 RWKV-5/6 中是包含历史累积信息的矩阵状态)传递给当前步即可。这种机制使得其推理时的显存占用是恒定的 O(1),生成速度不会随着上下文长度的增加而变慢