Mirror of https://github.com/meta-llama/llama.git
Update transformer mask comment
Update names for consistency with code

Co-authored-by: ruanslv <ruanslv@gmail.com>
@@ -481,8 +481,8 @@ class Transformer(nn.Module):
             # When performing key-value caching, we compute the attention scores
             # only for the new sequence. Thus, the matrix of scores is of size
-            # (seq_len, total_len), and the only masked entries are (i, j) for
-            # j > cached_len + i, since row i corresponds to token cached_len + i.
+            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
+            # j > cache_len + i, since row i corresponds to token cache_len + i.
             mask = torch.hstack([
                 torch.zeros((seqlen, start_pos), device=tokens.device),
                 mask
             ]).type_as(h)
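For context, here is a minimal, self-contained sketch of the masking scheme the updated comment describes. It reuses the names from llama/model.py (seqlen, start_pos, mask), but the concrete values and the omission of device/dtype handling are illustrative assumptions, not the repository's exact code:

    import torch

    # 4 new tokens to decode; 6 tokens already sit in the KV cache
    # (start_pos plays the role of cache_len in the comment).
    seqlen, start_pos = 4, 6

    # Causal mask over the new tokens only: entry (i, j) is -inf when j > i.
    mask = torch.full((seqlen, seqlen), float("-inf"))
    mask = torch.triu(mask, diagonal=1)

    # Prepend zero columns for the cached positions: every new token may
    # attend to all cached tokens, so the final mask has shape
    # (seqlen, cache_len + seqlen) and entry (i, j) is masked exactly when
    # j > cache_len + i, since row i scores token cache_len + i against all keys.
    mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])

    print(mask.shape)  # torch.Size([4, 10])

Row i of the resulting matrix is zero up through column start_pos + i and -inf afterwards, which is the "j > cache_len + i" condition the renamed comment states.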
||||