diff --git a/llama/generation.py b/llama/generation.py
index 9045c2f..5f8faf9 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -56,6 +56,7 @@ class Llama:
         max_seq_len: int,
         max_batch_size: int,
         model_parallel_size: Optional[int] = None,
+        seed: int = 1,
     ) -> "Llama":
         """
         Build a Llama instance by initializing and loading a pre-trained model.
@@ -91,7 +92,7 @@ class Llama:
             torch.cuda.set_device(local_rank)
 
         # seed must be the same in all processes
-        torch.manual_seed(1)
+        torch.manual_seed(seed)
 
         if local_rank > 0:
             sys.stdout = open(os.devnull, "w")