diff --git a/llama/model.py b/llama/model.py
index 63f3cf1..f7bf64c 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -481,8 +481,8 @@ class Transformer(nn.Module):
             # When performing key-value caching, we compute the attention scores
             # only for the new sequence. Thus, the matrix of scores is of size
-            # (seq_len, total_len), and the only masked entries are (i, j) for
-            # j > cached_len + i, since row i corresponds to token cached_len + i.
+            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
+            # j > cache_len + i, since row i corresponds to token cache_len + i.
             mask = torch.hstack([
                 torch.zeros((seqlen, start_pos), device=tokens.device),
                 mask
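
For context, the standalone sketch below (not part of the diff) reproduces the masking logic the updated comment describes. Here cache_len stands in for start_pos (the number of tokens already in the KV cache), and the concrete values are illustrative only.

import torch

# Illustrative values: cache_len plays the role of start_pos in model.py.
cache_len = 4   # tokens already in the KV cache
seqlen = 3      # new tokens in this forward pass

# Causal mask over the new tokens only: entry (i, j) is -inf when j > i.
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)

# Prepend zero columns for the cached positions: every new token may attend
# to all cached tokens, so the final mask has shape (seqlen, cache_len + seqlen)
# and entry (i, j) is masked only when j > cache_len + i.
mask = torch.hstack([torch.zeros((seqlen, cache_len)), mask])

print(mask.shape)  # torch.Size([3, 7])
print(mask)
# tensor([[0., 0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0., 0., 0.]])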