Skip to content

Commit 9bc1b3e

Browse files
committed
test: use stream_infer instead of generate
1 parent 8f40d98 commit 9bc1b3e

File tree

1 file changed: +8 additions, -33 deletions

tests/test_lmdeploy/test_grammar.py

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
import asyncio
21
import json
32
import re
43

54
import pytest
65
from jsonschema import validate
76

87
from lmdeploy import pipeline
9-
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, Response, TurbomindEngineConfig
8+
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig
109

1110
MODEL_IDS = [
1211
'Qwen/Qwen3-0.6B',
@@ -98,28 +97,6 @@ def test_guided_matrix(model_id, backend_name, backend_factory, schema_type):
9897
pipe.close()
9998

10099

101-
async def collect(*aiters):
102-
results = [[] for _ in range(len(aiters))]
103-
104-
async def drain(idx, aiter):
105-
async for item in aiter:
106-
results[idx].append(item)
107-
108-
await asyncio.gather(*(drain(idx, aiter) for idx, aiter in enumerate(aiters)))
109-
110-
responses = []
111-
for r in results:
112-
resp = Response(text='', input_token_len=0, generate_token_len=0)
113-
responses.append(resp)
114-
for out in r:
115-
resp.text += out.response
116-
resp.input_token_len = out.input_token_len
117-
resp.generate_token_len = out.generate_token_len
118-
resp.finish_reason = out.finish_reason
119-
120-
return responses
121-
122-
123100
@pytest.mark.parametrize('model_id', MODEL_IDS)
124101
@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
125102
def test_mix_guided_matrix(model_id, backend_name, backend_factory):
@@ -134,17 +111,15 @@ def test_mix_guided_matrix(model_id, backend_name, backend_factory):
134111
schema = SCHEMA_MAP[schema_type]
135112
response_format[schema_type] = dict(name='test', schema=schema)
136113

137-
gen_config = GenerationConfig(response_format=response_format)
114+
prompts = ['Make a self introduction please.'] * 4
115+
config = GenerationConfig(response_format=response_format)
116+
117+
gen_config = [None if idx % 3 else config for idx in range(4)]
138118

139-
configs = [None if idx % 3 else gen_config for idx in range(4)]
140-
tasks = [
141-
pipe.generate(messages='Make a self introduction please.', session_id=session_id, gen_config=gen_config)
142-
for session_id, gen_config in enumerate(configs)
143-
]
119+
responses = pipe.stream_infer(prompts, gen_config=gen_config)
144120

145-
responses = asyncio.run(collect(*tasks))
146-
for resp, config in zip(responses, configs):
147-
if config is None:
121+
for resp, c in zip(responses, gen_config):
122+
if c is None:
148123
assert '}' not in resp.text
149124
else:
150125
validate(instance=json.loads(resp.text), schema=schema)

0 commit comments

Comments (0)