# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import fire

from llama import Llama
from typing import List

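# Example invocation (paths and checkpoint name are illustrative; per the repo
# README, --nproc_per_node should match the checkpoint's model-parallel size,
# e.g. 1 for the 7B models):
#
#   torchrun --nproc_per_node 1 example_text_completion.py \
#       --ckpt_dir llama-2-7b/ \
#       --tokenizer_path tokenizer.model \
#       --max_seq_len 128 --max_batch_size 4
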
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 128,
    max_gen_len: int = 64,
    max_batch_size: int = 4,
):
    """
    Entry point of the program for generating text using a pretrained model.

    Args:
        ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
        tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
        temperature (float, optional): The temperature value for controlling randomness in generation.
            Defaults to 0.6.
        top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
            Defaults to 0.9.
        max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 128.
        max_gen_len (int, optional): The maximum length of generated sequences. Defaults to 64.
        max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 4.
    """
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    prompts: List[str] = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        "I believe the meaning of life is",
        "Simply put, the theory of relativity states that ",
        """A brief message congratulating the team on the launch:

        Hi everyone,

        I just """,
        # Few shot prompt (providing a few examples before asking model to complete more);
        """Translate English to French:

        sea otter => loutre de mer
        peppermint => menthe poivrée
        plush girafe => girafe peluche
        cheese =>""",
    ]
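    # Generate a continuation for every prompt in a single batched call,
    # using the requested sampling parameters.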
    results = generator.text_completion(
        prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
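    # Print each prompt followed by its generated continuation, with a separator line.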
    for prompt, result in zip(prompts, results):
        print(prompt)
        print(f"> {result['generation']}")
        print("\n==================================\n")


if __name__ == "__main__":
    fire.Fire(main)