Mirror of https://github.com/meta-llama/llama.git
Merge pull request #900 from flu0r1ne/main
Fix key-value caching for seqlen != 1 (Issue #899)
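For context (my reading of the failure mode, not text from the PR): before this change the causal mask was built as a (seqlen, seqlen) square, while with a non-empty cache the score matrix is (bs, n_local_heads, seqlen, cache_len + seqlen), so scores + mask cannot broadcast whenever seqlen > 1 and start_pos > 0. A small sketch of the mismatch, with illustrative sizes:

import torch

# Illustrative sizes: 5 tokens already in the cache (start_pos), a 3-token chunk (seqlen).
bsz, n_local_heads, start_pos, seqlen = 1, 8, 5, 3

scores = torch.randn(bsz, n_local_heads, seqlen, start_pos + seqlen)

# Mask as constructed before this patch.
old_mask = torch.full((1, 1, seqlen, seqlen), float("-inf"))
old_mask = torch.triu(old_mask, diagonal=start_pos + 1)

try:
    scores + old_mask  # trailing dims (start_pos + seqlen) vs seqlen do not broadcast
except RuntimeError as err:
    print(err)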
@@ -289,12 +289,12 @@ class Attention(nn.Module):
         values = self.cache_v[:bsz, : start_pos + seqlen]

         # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        keys = keys.transpose(1, 2)
-        values = values.transpose(1, 2)
+        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
         if mask is not None:
             scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
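For reference, a minimal standalone sketch of the shape bookkeeping the updated comments describe (bsz, cache_len, and seqlen below are made up for illustration): queries exist only for the new tokens, while keys and values come back from the cache with the prefix included, so the score matrix ends up (bs, n_local_heads, seqlen, cache_len + seqlen).

import math
import torch

bsz, n_local_heads, head_dim = 2, 8, 64   # illustrative sizes, not from the patch
cache_len, seqlen = 5, 3                  # 5 cached tokens, 3 new tokens

# Queries cover only the new tokens.
xq = torch.randn(bsz, seqlen, n_local_heads, head_dim)
# Keys/values are sliced from the cache, so they include the cached prefix.
keys = torch.randn(bsz, cache_len + seqlen, n_local_heads, head_dim)
values = torch.randn(bsz, cache_len + seqlen, n_local_heads, head_dim)

xq = xq.transpose(1, 2)          # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2)      # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)

scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
assert scores.shape == (bsz, n_local_heads, seqlen, cache_len + seqlen)

output = torch.matmul(scores.softmax(dim=-1), values)
assert output.shape == (bsz, n_local_heads, seqlen, head_dim)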
@@ -474,9 +474,19 @@ class Transformer(nn.Module):
         mask = None
         if seqlen > 1:
             mask = torch.full(
-                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
+                (seqlen, seqlen), float("-inf"), device=tokens.device
             )
-            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+            mask = torch.triu(mask, diagonal=1)
+
+            # When performing key-value caching, we compute the attention scores
+            # only for the new sequence. Thus, the matrix of scores is of size
+            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
+            # j > cache_len + i, since row i corresponds to token cache_len + i.
+            mask = torch.hstack([
+                torch.zeros((seqlen, start_pos), device=tokens.device),
+                mask
+            ]).type_as(h)

         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
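To make the new masking comment concrete, here is a small sketch of the mask this hunk builds (start_pos and seqlen values are illustrative): the left start_pos columns are all zeros so every new token can attend to the entire cache, and the right block is the usual causal triangle over the new tokens.

import torch

start_pos, seqlen = 5, 3   # 5 tokens already cached, 3 new tokens (illustrative)

mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])

print(mask)
# tensor([[0., 0., 0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0., 0., 0., 0.]])
# Row i is the query for token start_pos + i; entry (i, j) is -inf only when
# j > start_pos + i, matching the (seqlen, cache_len + seqlen) score matrix.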