This commit is contained in:
xavierm
2023-09-11 14:49:18 +00:00
parent 1bc5221c2a
commit c9c493f20f

View File

@@ -56,6 +56,7 @@ class Llama:
max_seq_len: int,
max_batch_size: int,
model_parallel_size: Optional[int] = None,
seed: int = 1,
) -> "Llama":
"""
Build a Llama instance by initializing and loading a pre-trained model.
@@ -91,7 +92,7 @@ class Llama:
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(1)
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")