From e9077bd24177a74aa79f406bef7d4b57fe393157 Mon Sep 17 00:00:00 2001 From: flu0r1ne Date: Thu, 2 Nov 2023 19:33:26 -0500 Subject: [PATCH 1/3] Fix key-value caching for seqlen != 1 This commit fixes a bug in the key-value caching. Currently, a square attention mask is misapplied to the scores matrix despite not matching the shape of the scores matrix. This results in a runtime error. In a correct implementation, the decoder mask needs to describe how the new seq_len tokens interact with all the cached tokens. That is, the attention mask needs to be of shape (seq_len, total_len), indicating how the token at row i (representing token i + cached_len in the transformer model) attends to token j. Accordingly, the matrix needs to mask entries where j > cached_len + i. This patch horizontally appends (seq_len, cached_len) zeros to an upper-triangular mask of size (seq_len, seq_len) to form the (seq_len, total_len) mask. --- llama/model.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/llama/model.py b/llama/model.py index 770526d..63f3cf1 100755 --- a/llama/model.py +++ b/llama/model.py @@ -474,9 +474,19 @@ class Transformer(nn.Module): mask = None if seqlen > 1: mask = torch.full( - (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device + (seqlen, seqlen), float("-inf"), device=tokens.device ) - mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + + mask = torch.triu(mask, diagonal=1) + + # When performing key-value caching, we compute the attention scores + # only for the new sequence. Thus, the matrix of scores is of size + # (seq_len, total_len), and the only masked entries are (i, j) for + # j > cached_len + i, since row i corresponds to token cached_len + i. + mask = torch.hstack([ + torch.zeros((seqlen, start_pos), device=tokens.device), + mask + ]).type_as(h) for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask) From 6b3154bfbbb56c4665ca083be1d46c1e4f1bcc33 Mon Sep 17 00:00:00 2001 From: Alex <76689481+flu0r1ne@users.noreply.github.com> Date: Mon, 13 Nov 2023 13:41:06 -0600 Subject: [PATCH 2/3] Update transformer mask comment Update names for consistency with code Co-authored-by: ruanslv --- llama/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama/model.py b/llama/model.py index 63f3cf1..f7bf64c 100755 --- a/llama/model.py +++ b/llama/model.py @@ -481,8 +481,8 @@ class Transformer(nn.Module): # When performing key-value caching, we compute the attention scores # only for the new sequence. Thus, the matrix of scores is of size - # (seq_len, total_len), and the only masked entries are (i, j) for - # j > cached_len + i, since row i corresponds to token cached_len + i. + # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for + # j > cache_len + i, since row i corresponds to token cache_len + i. mask = torch.hstack([ torch.zeros((seqlen, start_pos), device=tokens.device), mask From cd0719ddb42541fa4433e12d9922528832dd6eeb Mon Sep 17 00:00:00 2001 From: flu0r1ne Date: Mon, 13 Nov 2023 14:05:24 -0600 Subject: [PATCH 3/3] Correct KV comment seqlen -> seqlen + cache_len Update and add comments about the shape of the key and value matrices in the attention component. E.g., the second dimension is of length seqlen + cache_len not seqlen as previously stated. --- llama/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llama/model.py b/llama/model.py index f7bf64c..c78570f 100755 --- a/llama/model.py +++ b/llama/model.py @@ -289,12 +289,12 @@ class Attention(nn.Module): values = self.cache_v[:bsz, : start_pos + seqlen] # repeat k/v heads if n_kv_heads < n_heads - keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) + keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) if mask is not None: scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)