From cd0719ddb42541fa4433e12d9922528832dd6eeb Mon Sep 17 00:00:00 2001
From: flu0r1ne
Date: Mon, 13 Nov 2023 14:05:24 -0600
Subject: [PATCH] Correct KV comment seqlen -> seqlen + cache_len

Update and add comments about the shape of the key and value matrices
in the attention component. E.g., the second dimension is of length
seqlen + cache_len, not seqlen as previously stated.
---
 llama/model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llama/model.py b/llama/model.py
index f7bf64c..c78570f 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -289,12 +289,12 @@ class Attention(nn.Module):
         values = self.cache_v[:bsz, : start_pos + seqlen]
 
         # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
 
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        keys = keys.transpose(1, 2)
-        values = values.transpose(1, 2)
+        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
         if mask is not None:
             scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
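
Note (commentary, not part of the patch): the sketch below illustrates the
shape arithmetic the corrected comments describe. It uses a simplified
stand-in for the repository's repeat_kv, and the concrete sizes (bsz,
cache_len, seqlen, n_kv_heads, n_rep, head_dim) are made-up illustrative
values, not taken from the model.

    import torch

    bsz, cache_len, seqlen = 2, 10, 3       # batch size, cached positions, new tokens
    n_kv_heads, n_rep, head_dim = 4, 2, 8   # n_local_heads == n_kv_heads * n_rep

    def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Expand each KV head n_rep times along dim 2, equivalent to
        # torch.repeat_interleave(x, repeats=n_rep, dim=2).
        bs, slen, n_kv, hd = x.shape
        if n_rep == 1:
            return x
        return (
            x[:, :, :, None, :]
            .expand(bs, slen, n_kv, n_rep, hd)
            .reshape(bs, slen, n_kv * n_rep, hd)
        )

    # Slicing the cache up to start_pos + seqlen yields both the cached
    # positions and the new tokens, hence the cache_len + seqlen dimension.
    keys = torch.zeros(bsz, cache_len + seqlen, n_kv_heads, head_dim)

    keys = repeat_kv(keys, n_rep)
    assert keys.shape == (bsz, cache_len + seqlen, n_kv_heads * n_rep, head_dim)

    keys = keys.transpose(1, 2)
    assert keys.shape == (bsz, n_kv_heads * n_rep, cache_len + seqlen, head_dim)

With these toy values the pre-transpose shape is (2, 13, 8, 8) and the
post-transpose shape is (2, 8, 13, 8), matching the
(bs, n_local_heads, cache_len + seqlen, head_dim) annotation in the patch.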