mirror of
https://github.com/meta-llama/llama.git
synced 2026-01-15 16:32:54 -03:00
add seed
This commit is contained in:
@@ -56,6 +56,7 @@ class Llama:
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
model_parallel_size: Optional[int] = None,
|
||||
seed: int = 1,
|
||||
) -> "Llama":
|
||||
"""
|
||||
Build a Llama instance by initializing and loading a pre-trained model.
|
||||
@@ -91,7 +92,7 @@ class Llama:
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# seed must be the same in all processes
|
||||
torch.manual_seed(1)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
Reference in New Issue
Block a user