diff --git a/llama/model.py b/llama/model.py
index 25a4bad..770526d 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -448,6 +448,8 @@ class Transformer(nn.Module):
         )

         self.freqs_cis = precompute_freqs_cis(
+            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
+            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
             self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
         )

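For context, the second argument of precompute_freqs_cis is the number of positions for which rotary frequencies are precomputed, which is why doubling self.params.max_seq_len pre-allocates headroom beyond the configured context length. Below is a minimal sketch of that precomputation, assuming the standard RoPE formulation used elsewhere in model.py; the parameter values in the usage line (dim=4096, n_heads=32, max_seq_len=2048) are illustrative only and not part of this patch.

```python
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Per-dimension rotation frequencies for rotary position embeddings.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # One row per position: `end` controls how many positions are precomputed,
    # so passing max_seq_len * 2 caches rotations past the configured sequence length.
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    # Complex rotations e^(i * t * freq), shape (end, dim // 2).
    return torch.polar(torch.ones_like(freqs), freqs)

# Illustrative values (hypothetical 7B-style config, not taken from this patch):
# dim=4096, n_heads=32, max_seq_len=2048 precomputes 4096 rows of rotations.
freqs_cis = precompute_freqs_cis(4096 // 32, 2048 * 2)
print(freqs_cis.shape)  # torch.Size([4096, 64])
```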