added remanjg docs

This commit is contained in:
rajveer43
2023-08-28 12:18:39 +05:30
parent 7bcee80f48
commit 8cd608cc01
3 changed files with 56 additions and 0 deletions

View File

@@ -17,6 +17,21 @@ def main(
max_batch_size: int = 8,
max_gen_len: Optional[int] = None,
):
"""
Entry point of the program for generating text using a pretrained model.
Args:
ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
temperature (float, optional): The temperature value for controlling randomness in generation.
Defaults to 0.6.
top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
Defaults to 0.9.
max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 512.
max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 8.
max_gen_len (int, optional): The maximum length of generated sequences. If None, it will be
set to the model's max sequence length. Defaults to None.
"""
generator = Llama.build(
ckpt_dir=ckpt_dir,
tokenizer_path=tokenizer_path,

View File

@@ -15,6 +15,20 @@ def main(
max_gen_len: int = 64,
max_batch_size: int = 4,
):
"""
Entry point of the program for generating text using a pretrained model.
Args:
ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
temperature (float, optional): The temperature value for controlling randomness in generation.
Defaults to 0.6.
top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
Defaults to 0.9.
max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 128.
max_gen_len (int, optional): The maximum length of generated sequences. Defaults to 64.
max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 4.
"""
generator = Llama.build(
ckpt_dir=ckpt_dir,
tokenizer_path=tokenizer_path,

View File

@@ -12,7 +12,14 @@ logger = getLogger()
class Tokenizer:
"""tokenizing and encoding/decoding text using SentencePiece."""
def __init__(self, model_path: str):
"""
Initializes the Tokenizer with a SentencePiece model.
Args:
model_path (str): The path to the SentencePiece model file.
"""
# reload tokenizer
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
@@ -29,6 +36,17 @@ class Tokenizer:
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
"""
Encodes a string into a list of token IDs.
Args:
s (str): The input string to be encoded.
bos (bool): Whether to prepend the beginning-of-sequence token.
eos (bool): Whether to append the end-of-sequence token.
Returns:
List[int]: A list of token IDs.
"""
assert type(s) is str
t = self.sp_model.encode(s)
if bos:
@@ -38,4 +56,13 @@ class Tokenizer:
return t
def decode(self, t: List[int]) -> str:
"""
Decodes a list of token IDs into a string.
Args:
t (List[int]): The list of token IDs to be decoded.
Returns:
str: The decoded string.
"""
return self.sp_model.decode(t)