diff --git a/llm_bench/load_test.py b/llm_bench/load_test.py
index 174a903..fa23a74 100644
--- a/llm_bench/load_test.py
+++ b/llm_bench/load_test.py
@@ -162,13 +162,20 @@ def _create_dataset(cls, options: argparse.Namespace):
             prompt = options.prompt
             dataset_file = "code.txt"
 
+            if options.prompt_cache_max_len is not None:
+                common_tokens = options.prompt_cache_max_len
+            elif options.prompt_cache_max_pct is not None:
+                common_tokens = int(options.prompt_tokens * options.prompt_cache_max_pct / 100)
+            else:
+                common_tokens = 0
+
             return TranslationDataset(
                 path=os.path.join(os.path.dirname(os.path.abspath(__file__)), dataset_file),
                 prompt="\n\n" + prompt,
                 tokenizer_path=options.tokenizer,
                 chat=options.chat,
                 num_tokens=options.prompt_tokens,
-                common_tokens=options.prompt_cache_max_len,
+                common_tokens=common_tokens,
             )
         else:
             raise ValueError(f"Unknown dataset: {options.dataset}")
@@ -1451,8 +1458,11 @@ def init_parser(parser):
         "--prompt-cache-max-len",
         env_var="PROMPT_CACHE_MAX_LEN",
         type=int,
-        default=0,
-        help="Maximum length of the prompt cache to use. Defaults to 0 (no caching).",
+        default=None,
+        help="Maximum number of shared prefix tokens across requests. "
+        "If not specified but --prompt-cache-max-pct is set, auto-computed as "
+        "int(prompt_tokens * prompt_cache_max_pct / 100). Defaults to 0 (no shared prefix) "
+        "when neither this nor --prompt-cache-max-pct is provided.",
     )
     parser.add_argument(
         "--prompt-cache-max-pct",
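
For reviewers, a minimal standalone sketch of the precedence rule this patch introduces: an explicit --prompt-cache-max-len wins, otherwise --prompt-cache-max-pct is applied to --prompt-tokens, otherwise no prefix is shared. The resolve_common_tokens helper and the SimpleNamespace stand-ins below are illustrative only, not names from load_test.py:

from types import SimpleNamespace

def resolve_common_tokens(options) -> int:
    # Mirrors the new block in _create_dataset.
    if options.prompt_cache_max_len is not None:
        return options.prompt_cache_max_len
    if options.prompt_cache_max_pct is not None:
        return int(options.prompt_tokens * options.prompt_cache_max_pct / 100)
    return 0

# --prompt-tokens 1000 --prompt-cache-max-pct 25  ->  250 shared prefix tokens
assert resolve_common_tokens(SimpleNamespace(
    prompt_tokens=1000, prompt_cache_max_len=None, prompt_cache_max_pct=25)) == 250

# An explicit --prompt-cache-max-len 64 overrides the percentage.
assert resolve_common_tokens(SimpleNamespace(
    prompt_tokens=1000, prompt_cache_max_len=64, prompt_cache_max_pct=25)) == 64

# Neither flag set  ->  0, matching the old default (no shared prefix).
assert resolve_common_tokens(SimpleNamespace(
    prompt_tokens=1000, prompt_cache_max_len=None, prompt_cache_max_pct=None)) == 0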