KV Cache原理

绫波波 发布于 2025-11-27 72 次阅读


简介

KV Cache本质上是一种缓存机制,主要应用在Transformer架构的模型中,尤其是生成式任务的推理阶段。在Transformer中,注意力机制通过计算Query、key和value来确定输入序列中不同位置的关联程度。KV Cache的作用就是存储已经计算过的K矩阵和V矩阵,避免其在处理新的输入时重复计算,从而提高推理效率。

KV Cache原理

在推理过程中,当模型处理第一个token时,正常计算K矩阵和V矩阵,将其存储到KV Cache中。当处理后续的token时,仅计算新的Query矩阵,然后进行注意力计算。假设输入序列长度为n,第一次计算K和V时,计算量和序列长度相关,为O(n)。在没有KV Cache时,如果需要生成一个长度为n的文本序列,每生成一个token,都需要对整个序列重新计算K和V,计算复杂度为Q(n^2)。加入KV Cache后,后续生成每个token时,只需要计算新的Query(每一步的 Query 向量 只由「当前刚进来的那一个 token」(加上它所在的位置信息)产生),计算量也是O(n)。使用KV Cache后,整体计算复杂度为O(n),效率明显提高。

KV Cache的优势

  • 加速推理:通过避免重复计算K和V,KV Cache能够显著减少推理时间。对于实时性较高的应用(如聊天机器人等)非常重要,能够让模型更快地响应用户输入。
  • 节省内存:由于不需要每次推理时重新计算和存储K和V,KV Cache能够减少内存使用,这使得在资源有限的设备(如移动设备)上也能够运行较大规模的模型。

KV Cache的实现

数据结构

Key Cache:通常是一个形状为(batch_size,num_heads,seq_len,k_dim)的张量,用于存储Key矩阵。其中,batch_size表示一次处理的样本数量,num_heads是注意力头的数量,seq_len是序列长度,k_dim是Key的维度。
Value Cache:形状为(batch_size,num_heads,seq_len,v_dim)的张量,用于存储Value矩阵。其中v_dim是Value的维度。

计算流程

  • 初始化阶段:输入初始token,经过线性变换得到Q、K、V。将K和V存储到KV Cache中,为后续结算准备好“原材料”。
  • 后续token处理阶段:输入新的token,计算新的query,并从KV Cache中读取之前存储的K和V,进行注意力计算,得到输出。
    将新生成的K和V追加到KV Cache中,以便后续使用。随着新token的不断输入,KV Cache将不断更新和扩充,保证模型能够始终进行高效的推理。

代码示例

import torch
import torch.nn as nn
import math

class CachedAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # 定义线性变换层,将输入映射到Query、Key和Value空间
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        # 定义输出线性变换层,将注意力计算结果映射回原维度
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        b, t, c = x.shape
        # 将输入x通过线性变换得到Query,并调整形状和维度
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        # 将输入x通过线性变换得到Key,并调整形状和维度
        k = self.k_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        # 将输入x通过线性变换得到Value,并调整形状和维度
        v = self.v_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)

        if kv_cache is not None:
            cached_k, cached_v = kv_cache
            # 将缓存中的Key和当前计算的Key拼接起来
            k = torch.cat((cached_k, k), dim=2)
            # 将缓存中的Value和当前计算的Value拼接起来
            v = torch.cat((cached_v, v), dim=2)

        # 计算注意力分数,这里除以根号下head_dim是为了缩放
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        # 对注意力分数进行softmax归一化
        attn = attn.softmax(dim=-1)
        # 根据注意力分数对Value进行加权求和
        y = (attn @ v).transpose(1, 2).contiguous().view(b, t, c)
        # 通过输出线性变换层得到最终输出
        y = self.out_proj(y)
        return y, (k, v)

应用要点

在实际工程应用中,KV Cache 还面临一些挑战和需要优化的地方。例如,缓存大小的动态调整是一个关键问题。随着输入序列长度的不断增加,缓存占用的内存也会逐渐增大,如果不加以控制,可能会导致内存溢出。因此,需要根据实际情况动态调整缓存大小,比如当缓存达到一定阈值时,对缓存进行压缩或者舍弃部分较早的缓存数据。在多 GPU 环境下,KV Cache 的管理也变得更加复杂。不同 GPU 之间需要进行有效的数据同步,确保每个 GPU 都能获取到正确的缓存数据,同时还要避免数据传输带来的额外开销。可以采用分布式缓存管理策略,合理分配缓存数据到各个 GPU,提高整体的计算效率。