Update transformer mask comment

Update names for consistency with code

Co-authored-by: ruanslv <ruanslv@gmail.com>
Author: Alex
Date: 2023-11-13 13:41:06 -06:00
Committer: GitHub
Parent: e9077bd241
Commit: 6b3154bfbb

@@ -481,8 +481,8 @@ class Transformer(nn.Module):
         # When performing key-value caching, we compute the attention scores
         # only for the new sequence. Thus, the matrix of scores is of size
-        # (seq_len, total_len), and the only masked entries are (i, j) for
-        # j > cached_len + i, since row i corresponds to token cached_len + i.
+        # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
+        # j > cache_len + i, since row i corresponds to token cache_len + i.
         mask = torch.hstack([
             torch.zeros((seqlen, start_pos), device=tokens.device),
             mask
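
For illustration, a minimal standalone sketch of the mask construction the updated comment describes (plain PyTorch, not code from the commit; the sizes are hypothetical, and `cache_len` here plays the role of `start_pos` in the hunk above):

import torch

# Hypothetical sizes: 4 tokens already in the KV cache, 3 new tokens.
cache_len, seqlen = 4, 3

# Causal mask over the new tokens only: entry (i, j) is -inf when j > i.
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)

# Prepend zero columns for the cached positions, mirroring the hstack in
# the hunk above. Row i now corresponds to token cache_len + i, so column
# j is masked exactly when j > cache_len + i.
mask = torch.hstack([torch.zeros((seqlen, cache_len)), mask])

print(mask.shape)  # torch.Size([3, 7]), i.e. (seqlen, cache_len + seqlen)
print(mask)
# tensor([[0., 0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0., 0., 0.]])

The zero columns leave every cached position visible to every new token, while the triangular block keeps the new tokens causal among themselves, which is exactly the condition j > cache_len + i stated in the renamed comment.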