|
2 | 2 | import json |
3 | 3 | import numpy as np |
4 | 4 | import sqlite3 |
5 | | -from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Tuple |
| 5 | +from typing import TYPE_CHECKING, AsyncGenerator, Dict, Tuple |
6 | 6 |
|
7 | 7 | from evalscope.constants import HEARTBEAT_INTERVAL_SEC |
8 | 8 | from evalscope.utils.logger import get_logger |
|
26 | 26 | @exception_handler |
27 | 27 | async def get_requests(args: Arguments, api_plugin: 'ApiPluginBase') -> AsyncGenerator[dict, None]: |
28 | 28 |
|
29 | | - async def generate_requests_from_prompt(): |
| 29 | + async def _generate_from_prompt(): |
| 30 | + """Generate requests by repeating a single prompt.""" |
30 | 31 | prompt = load_prompt(args.prompt) |
31 | 32 | messages = [{'role': 'user', 'content': prompt}] if args.apply_chat_template else prompt |
32 | 33 | request = api_plugin.build_request(messages) |
33 | 34 | for _ in range(args.number): |
34 | 35 | yield request |
35 | 36 |
|
36 | | - async def generate_requests_from_dataset(): |
37 | | - message_generator_class = DatasetRegistry.get_class(args.dataset) |
38 | | - message_generator = message_generator_class(args) |
39 | | - |
| 37 | + async def _generate_from_dataset(): |
| 38 | + """Generate requests by cycling through a dataset.""" |
| 39 | + message_generator = DatasetRegistry.get_class(args.dataset)(args) |
40 | 40 | dataset_messages = [] |
41 | | - try: |
42 | | - for messages in message_generator.build_messages(): |
| 41 | + |
| 42 | + # Load dataset messages into memory (limited by args.number) |
| 43 | + # We catch StopIteration implicitly via the loop |
| 44 | + with tqdm(message_generator.build_messages(), desc='Generating datasets', total=args.number, initial=1) as pbar: |
| 45 | + for messages in pbar: |
43 | 46 | dataset_messages.append(messages) |
44 | 47 | if len(dataset_messages) >= args.number: |
45 | 48 | break |
46 | | - except StopIteration: |
47 | | - pass |
48 | 49 |
|
49 | 50 | if not dataset_messages: |
50 | | - raise Exception('Dataset is empty!') |
| 51 | + raise ValueError('Dataset is empty!') |
51 | 52 |
|
| 53 | + # Yield requests cyclically until total count is reached |
52 | 54 | count = 0 |
53 | 55 | dataset_index = 0 |
| 56 | + num_messages = len(dataset_messages) |
54 | 57 |
|
55 | 58 | while count < args.number: |
56 | 59 | messages = dataset_messages[dataset_index] |
57 | 60 | request = api_plugin.build_request(messages) |
58 | 61 | if request is not None: |
59 | 62 | yield request |
60 | 63 | count += 1 |
| 64 | + dataset_index = (dataset_index + 1) % num_messages |
61 | 65 |
|
62 | | - dataset_index = (dataset_index + 1) % len(dataset_messages) |
63 | | - |
| 66 | + # Dispatch based on arguments |
64 | 67 | if args.prompt: |
65 | | - generator = generate_requests_from_prompt() |
| 68 | + generator = _generate_from_prompt() |
66 | 69 | elif args.dataset: |
67 | | - generator = generate_requests_from_dataset() |
| 70 | + generator = _generate_from_dataset() |
68 | 71 | else: |
69 | 72 | raise ValueError('Either prompt or dataset is required!') |
70 | 73 |
|
| 74 | + # Yield requests with rate limiting |
71 | 75 | async for request in generator: |
72 | 76 | yield request |
73 | 77 | if args.rate != -1: |
|
0 commit comments