From c9c493f20f6bf791081ccf8ca0e7dfd7468d604b Mon Sep 17 00:00:00 2001 From: xavierm Date: Mon, 11 Sep 2023 14:49:18 +0000 Subject: [PATCH] add seed --- llama/generation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llama/generation.py b/llama/generation.py index 9045c2f..5f8faf9 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -56,6 +56,7 @@ class Llama: max_seq_len: int, max_batch_size: int, model_parallel_size: Optional[int] = None, + seed: int = 1, ) -> "Llama": """ Build a Llama instance by initializing and loading a pre-trained model. @@ -91,7 +92,7 @@ class Llama: torch.cuda.set_device(local_rank) # seed must be the same in all processes - torch.manual_seed(1) + torch.manual_seed(seed) if local_rank > 0: sys.stdout = open(os.devnull, "w")