diff --git a/example_chat_completion.py b/example_chat_completion.py
index 249bf6b..e5c868b 100644
--- a/example_chat_completion.py
+++ b/example_chat_completion.py
@@ -17,6 +17,21 @@ def main(
     max_batch_size: int = 8,
     max_gen_len: Optional[int] = None,
 ):
+    """
+    Entry point of the program for generating chat responses using a pretrained model.
+
+    Args:
+        ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
+        tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
+        temperature (float, optional): The temperature value for controlling randomness in generation.
+            Defaults to 0.6.
+        top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
+            Defaults to 0.9.
+        max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 512.
+        max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 8.
+        max_gen_len (int, optional): The maximum length of generated sequences. If None, it is
+            set to the model's maximum sequence length. Defaults to None.
+    """
     generator = Llama.build(
         ckpt_dir=ckpt_dir,
         tokenizer_path=tokenizer_path,
diff --git a/example_text_completion.py b/example_text_completion.py
index 4376b1e..890673e 100755
--- a/example_text_completion.py
+++ b/example_text_completion.py
@@ -15,6 +15,20 @@ def main(
     max_gen_len: int = 64,
     max_batch_size: int = 4,
 ):
+    """
+    Entry point of the program for generating text using a pretrained model.
+
+    Args:
+        ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
+        tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
+        temperature (float, optional): The temperature value for controlling randomness in generation.
+            Defaults to 0.6.
+        top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
+            Defaults to 0.9.
+        max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 128.
+        max_gen_len (int, optional): The maximum length of generated sequences. Defaults to 64.
+        max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 4.
+    """
     generator = Llama.build(
         ckpt_dir=ckpt_dir,
         tokenizer_path=tokenizer_path,
diff --git a/llama/tokenizer.py b/llama/tokenizer.py
index e3af011..3eda89a 100755
--- a/llama/tokenizer.py
+++ b/llama/tokenizer.py
@@ -12,7 +12,14 @@ logger = getLogger()
 
 
 class Tokenizer:
+    """Tokenizing and encoding/decoding text using SentencePiece."""
     def __init__(self, model_path: str):
+        """
+        Initializes the Tokenizer with a SentencePiece model.
+
+        Args:
+            model_path (str): The path to the SentencePiece model file.
+        """
         # reload tokenizer
         assert os.path.isfile(model_path), model_path
         self.sp_model = SentencePieceProcessor(model_file=model_path)
@@ -29,6 +36,17 @@
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
 
     def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        """
+        Encodes a string into a list of token IDs.
+
+        Args:
+            s (str): The input string to be encoded.
+            bos (bool): Whether to prepend the beginning-of-sequence token.
+            eos (bool): Whether to append the end-of-sequence token.
+
+        Returns:
+            List[int]: A list of token IDs.
+        """
         assert type(s) is str
         t = self.sp_model.encode(s)
         if bos:
@@ -38,4 +56,13 @@
         return t
 
     def decode(self, t: List[int]) -> str:
+        """
+        Decodes a list of token IDs into a string.
+
+        Args:
+            t (List[int]): The list of token IDs to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
         return self.sp_model.decode(t)
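A minimal usage sketch exercising the Tokenizer API documented above (not part of the diff); "tokenizer.model" is a placeholder path for the SentencePiece model file distributed with the Llama weights:

    from llama.tokenizer import Tokenizer

    # Load the SentencePiece model; __init__ asserts the file exists.
    tok = Tokenizer(model_path="tokenizer.model")  # placeholder path

    # encode() returns token IDs, optionally wrapped with BOS/EOS.
    ids = tok.encode("Hello, world!", bos=True, eos=False)

    # decode() maps the IDs back to text; control tokens such as BOS
    # are dropped by SentencePiece during decoding.
    print(tok.decode(ids))  # -> "Hello, world!"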