From 6b3154bfbbb56c4665ca083be1d46c1e4f1bcc33 Mon Sep 17 00:00:00 2001
From: Alex <76689481+flu0r1ne@users.noreply.github.com>
Date: Mon, 13 Nov 2023 13:41:06 -0600
Subject: [PATCH] Update transformer mask comment

Update names for consistency with code

Co-authored-by: ruanslv
---
 llama/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llama/model.py b/llama/model.py
index 63f3cf1..f7bf64c 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -481,8 +481,8 @@ class Transformer(nn.Module):
 
             # When performing key-value caching, we compute the attention scores
             # only for the new sequence. Thus, the matrix of scores is of size
-            # (seq_len, total_len), and the only masked entries are (i, j) for
-            # j > cached_len + i, since row i corresponds to token cached_len + i.
+            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
+            # j > cache_len + i, since row i corresponds to token cache_len + i.
             mask = torch.hstack([
                 torch.zeros((seqlen, start_pos), device=tokens.device),
                 mask
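
Note for reviewers: a minimal standalone sketch of the mask shape the updated
comment describes. It assumes start_pos plays the role of cache_len (the number
of already-cached tokens) and uses toy sizes in place of the model's real
dimensions; it is an illustration, not part of the patch.

    import torch

    cache_len, seqlen = 4, 3  # 4 cached tokens, 3 new tokens

    # Causal mask over the new tokens only: (seqlen, seqlen), -inf above the diagonal.
    mask = torch.full((seqlen, seqlen), float("-inf"))
    mask = torch.triu(mask, diagonal=1)

    # Prepend zero columns for the cached positions, giving shape
    # (seqlen, cache_len + seqlen). Row i corresponds to token cache_len + i.
    mask = torch.hstack([torch.zeros((seqlen, cache_len)), mask])

    assert mask.shape == (seqlen, cache_len + seqlen)

    # Entry (i, j) is masked (-inf) exactly when j > cache_len + i.
    for i in range(seqlen):
        for j in range(cache_len + seqlen):
            assert (mask[i, j] == float("-inf")) == (j > cache_len + i)

The zero columns let every new token attend to all cached positions, while the
triangular block preserves causality among the new tokens themselves, which is
why (seqlen, cache_len + seqlen) is the correct shape to name in the comment.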