
Commit 89083f1

Fix a variety of bugs + support newest OpenAI models (gpt-4o, gpt-4o-mini) [release] (#37)
* Fix notebook login bug
* Fix validate quantization config bug
* Fix prompt steps
* Allow VLLM progress bars to be enabled
* Support GPT-4o and GPT-4o-mini
1 parent 4cbaf9f commit 89083f1

10 files changed: +104 −34 lines changed
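For context, the headline change is that gpt-4o and gpt-4o-mini now work anywhere DataDreamer's OpenAI wrapper does. A hedged sketch, with usage inferred from the tests added in this commit (run inside a DataDreamer session, as the tests do):

from datadreamer.llms import OpenAI

# Model names newly supported by this commit; the expected limits below
# come from the updated tests in src/tests/llms/test_llms.py.
llm = OpenAI("gpt-4o-mini")
assert llm.get_max_context_length(max_new_tokens=0) == 127982
assert llm._get_max_output_length() == 16384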

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "DataDreamer"
-version = "0.36.0"
+version = "0.37.0"
 description = "Prompt. Generate Synthetic Data. Train & Align Models."
 license = "MIT"
 authors= [

src/llms/openai.py

Lines changed: 22 additions & 5 deletions
@@ -63,16 +63,30 @@ def _is_gpt_3_5_legacy(model_name: str):
 @lru_cache(maxsize=None)
 def _is_gpt_4(model_name: str):
     model_name = _normalize_model_name(model_name)
-    return model_name == "gpt-4" or any(
-        gpt4_name in model_name for gpt4_name in ["gpt-4-"]
+    return (
+        model_name == "gpt-4"
+        or any(gpt4_name in model_name for gpt4_name in ["gpt-4-"])
+        or _is_gpt_4o(model_name)
     )


+@lru_cache(maxsize=None)
+def _is_gpt_4o(model_name: str):
+    model_name = _normalize_model_name(model_name)
+    return any(gpt4_name in model_name for gpt4_name in ["gpt-4o"])
+
+
+@lru_cache(maxsize=None)
+def _is_gpt_mini(model_name: str):
+    model_name = _normalize_model_name(model_name)
+    return any(gpt_mini_name in model_name for gpt_mini_name in ["-mini"])
+
+
 @lru_cache(maxsize=None)
 def _is_128k_model(model_name: str):
     model_name = _normalize_model_name(model_name)
     return _is_gpt_4(model_name) and (
-        "-preview" in model_name or "2024-04-09" in model_name
+        _is_gpt_4o(model_name) or "-preview" in model_name or "2024-04-09" in model_name
     )

@@ -249,11 +263,14 @@ def get_max_context_length(self, max_new_tokens: int) -> int:  # pragma: no cover
         return max_context_length - max_new_tokens - format_tokens

     def _get_max_output_length(self) -> None | int:  # pragma: no cover
-        if _is_128k_model(self.model_name) or (
+        if _is_128k_model(self.model_name) and _is_gpt_mini(self.model_name):
+            return 16384
+        elif _is_128k_model(self.model_name) or (
             _is_gpt_3_5(self.model_name) and not (_is_gpt_3_5_legacy(self.model_name))
         ):
             return 4096
-        return None
+        else:
+            return None

     @ring.lru(maxsize=5000)
     def count_tokens(self, value: str) -> int:
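Putting the new helpers together: model families are detected by substring tests on a normalized name, and the completion cap falls out of a small decision table (16K for gpt-4o-mini, 4K for other 128K models). A simplified, self-contained sketch of that logic; `normalize` is a hypothetical stand-in for DataDreamer's `_normalize_model_name`:

from functools import lru_cache

def normalize(model_name: str) -> str:
    # Hypothetical stand-in for DataDreamer's _normalize_model_name.
    return model_name.strip().lower()

@lru_cache(maxsize=None)
def is_gpt_4o(model_name: str) -> bool:
    return "gpt-4o" in normalize(model_name)

@lru_cache(maxsize=None)
def is_gpt_mini(model_name: str) -> bool:
    return "-mini" in normalize(model_name)

def max_output_length(model_name: str) -> int | None:
    # Simplified to the gpt-4o cases exercised by the tests in this commit.
    if is_gpt_4o(model_name) and is_gpt_mini(model_name):
        return 16384  # gpt-4o-mini allows 16K completion tokens
    elif is_gpt_4o(model_name):
        return 4096   # gpt-4o caps completions at 4K tokens
    return None       # older models are handled by the full implementation

assert max_output_length("gpt-4o-mini") == 16384
assert max_output_length("gpt-4o") == 4096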

src/llms/vllm.py

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,7 @@
 import gc
 import logging
 import os
+from contextlib import nullcontext
 from functools import cached_property, partial
 from typing import Any, Callable, Generator, Iterable

@@ -115,7 +116,7 @@ def _monkey_patch_init_logger(*args, **kwargs):
             timeout=10.0,
         )
         LLM = import_module("vllm").LLM
-        with ignore_tqdm():
+        with ignore_tqdm() if datadreamer_logger.level > logging.DEBUG else nullcontext():
             self_resource.model = LLM(
                 model=self.model_name,
                 trust_remote_code=self.trust_remote_code,

@@ -216,7 +217,9 @@ def _run_batch(  # noqa: C901
             **kwargs,
         )
         generated_texts_batch = self.model.proxy.get_generated_texts_batch(
-            prompts, sampling_params, use_tqdm=False
+            prompts,
+            sampling_params,
+            use_tqdm=(datadreamer_logger.level <= logging.DEBUG),
         )

         # Post-process and return
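The toggle works because the expression after `with` can be any context manager, so one can be chosen conditionally at runtime, with `contextlib.nullcontext()` as the do-nothing branch. A minimal sketch of the pattern; `suppress_output` is a hypothetical stand-in for DataDreamer's `ignore_tqdm`:

import contextlib
import io
import logging

logger = logging.getLogger("datadreamer")

def suppress_output():
    # Hypothetical stand-in for DataDreamer's ignore_tqdm() context manager.
    return contextlib.redirect_stderr(io.StringIO())

# Progress bars stay hidden unless the logger is at DEBUG or more verbose,
# mirroring the toggle added above; nullcontext() is the no-op branch.
with suppress_output() if logger.level > logging.DEBUG else contextlib.nullcontext():
    print("model loading (and its progress bars) happens here")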

src/steps/prompt/few_shot_prompt_with_retrieval.py

Lines changed: 5 additions & 0 deletions
@@ -143,5 +143,10 @@ def output_examples_generator():

         return input_examples_generator, output_examples_generator

+    def _run_prompts(self, args, *positionalargs, **kwargs):
+        args.pop("embedder")
+        args.pop("k")
+        return super()._run_prompts(args, *positionalargs, **kwargs)
+

 __all__ = ["FewShotPromptWithRetrieval"]
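The fix follows a pop-then-delegate pattern: `embedder` and `k` are only used by the retrieval step itself, so they are removed from `args` before the dictionary is forwarded, keeping them out of the underlying prompting call. A hypothetical, self-contained sketch of the pattern:

def run_prompts(args: dict, **kwargs) -> str:
    # Hypothetical downstream consumer that rejects step-only arguments.
    assert "embedder" not in args and "k" not in args
    return f"prompting {args['llm']} with {kwargs}"

def step_run_prompts(args: dict, **kwargs) -> str:
    # Pop the step-only arguments before delegating, as in the fix above.
    args.pop("embedder", None)
    args.pop("k", None)
    return run_prompts(args, **kwargs)

print(step_run_prompts({"llm": "gpt-4o", "embedder": object(), "k": 5}, seed=42))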

src/steps/prompt/rag_prompt.py

Lines changed: 2 additions & 0 deletions
@@ -99,6 +99,8 @@ def retrieved_texts_generator():
     def run(self):
         # Get inputs and arguments
         args = self.args
+        args.pop("retriever")
+        args.pop("k")
         llm = args["llm"]
         prompts = self.inputs["prompts"]
         retrieved_text_label = args.pop("retrieved_text_label")
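This mirrors the FewShotPromptWithRetrieval fix above: `retriever` and `k` are consumed by the step itself when building prompts, so they are popped before the remaining arguments reach the LLM.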

src/tests/llms/test_llms.py

Lines changed: 22 additions & 11 deletions
@@ -1102,17 +1102,20 @@ def test_metadata(self, create_datadreamer):
         assert llm.citation[0].endswith("year={2020}\n}")
         assert llm.citation[1].startswith("@article{ouyang2022training")
         assert llm.citation[1].endswith("year={2022}\n}")
-        llm = OpenAI("gpt-4")
-        assert llm.model_card == "https://cdn.openai.com/papers/gpt-4-system-card.pdf"
-        assert llm.license == "https://openai.com/policies"
-        assert isinstance(llm.citation, list)
-        assert len(llm.citation) == 2
-        assert llm.citation[0].startswith("@article{OpenAI2023GPT4TR,")
-        assert llm.citation[0].endswith(
-            "url={https://api.semanticscholar.org/CorpusID:257532815}\n}"
-        )
-        assert llm.citation[1].startswith("@article{ouyang2022training")
-        assert llm.citation[1].endswith("year={2022}\n}")
+        for gpt_4_model_name in ["gpt-4", "gpt-4o", "gpt-4o-mini"]:
+            llm = OpenAI(gpt_4_model_name)
+            assert (
+                llm.model_card == "https://cdn.openai.com/papers/gpt-4-system-card.pdf"
+            )
+            assert llm.license == "https://openai.com/policies"
+            assert isinstance(llm.citation, list)
+            assert len(llm.citation) == 2
+            assert llm.citation[0].startswith("@article{OpenAI2023GPT4TR,")
+            assert llm.citation[0].endswith(
+                "url={https://api.semanticscholar.org/CorpusID:257532815}\n}"
+            )
+            assert llm.citation[1].startswith("@article{ouyang2022training")
+            assert llm.citation[1].endswith("year={2022}\n}")

     def test_count_tokens(self, create_datadreamer):
         with create_datadreamer():

@@ -1122,6 +1125,10 @@ def test_count_tokens(self, create_datadreamer):
     def test_get_max_context_length(self, create_datadreamer):
         with create_datadreamer():
             # Check max context length
+            llm = OpenAI("gpt-4o")
+            assert llm.get_max_context_length(max_new_tokens=0) == 127982
+            llm = OpenAI("gpt-4o-mini")
+            assert llm.get_max_context_length(max_new_tokens=0) == 127982
             llm = OpenAI("gpt-4")
             assert llm.get_max_context_length(max_new_tokens=0) == 8174
             llm = OpenAI("gpt-4-turbo-2024-04-09")

@@ -1136,6 +1143,10 @@ def test_get_max_context_length(self, create_datadreamer):
     def test_get_max_output_length(self, create_datadreamer):
         with create_datadreamer():
             # Check max output length
+            llm = OpenAI("gpt-4o")
+            assert llm._get_max_output_length() == 4096
+            llm = OpenAI("gpt-4o-mini")
+            assert llm._get_max_output_length() == 16384
             llm = OpenAI("gpt-4")
             assert llm._get_max_output_length() is None
             llm = OpenAI("gpt-4-turbo-2024-04-09")
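The expected values encode each model's advertised context window minus a fixed margin: 128,000 − 127,982 = 18 tokens, the same margin as gpt-4's 8,192 − 8,174, which `get_max_context_length` appears to reserve for chat formatting. A quick arithmetic check of that inference:

# Margin inferred from the expected values in the tests above; treat the
# "18 tokens reserved for chat formatting" reading as an assumption.
FORMAT_TOKENS = 18

assert 128_000 - FORMAT_TOKENS == 127_982  # gpt-4o and gpt-4o-mini
assert 8_192 - FORMAT_TOKENS == 8_174      # gpt-4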

src/tests/test_utils/fixtures/mock_llm.py

Lines changed: 20 additions & 1 deletion
@@ -6,9 +6,28 @@


 @pytest.fixture
-def mock_llm() -> Callable[..., LLM]:
+def mock_llm(
+    allowed_kwargs=frozenset(
+        {
+            "inputs",
+            "batch_size",
+            "max_new_tokens",
+            "temperature",
+            "top_p",
+            "n",
+            "stop",
+            "repetition_penalty",
+            "logit_bias",
+            "seed",
+            "max_length_func",
+            "cached_tokenizer",
+        }
+    ),
+) -> Callable[..., LLM]:
     def _mock_llm(llm: LLM, responses: dict[str, str]) -> LLM:
         def _run_batch_mocked(**kwargs):
+            for kwarg in kwargs:
+                assert kwarg in allowed_kwargs, f"LLM got unexpected keyword: {kwarg}"
             return [responses[prompt] for prompt in kwargs["inputs"]]

         llm._run_batch = _run_batch_mocked  # type: ignore[attr-defined]
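With the allow-list in place, the fixture fails fast whenever a step forwards a keyword the real `_run_batch` would not accept, which is exactly how leaked step arguments like `embedder`, `retriever`, or `k` get caught. A minimal sketch of the same guard outside pytest:

ALLOWED = frozenset({"inputs", "max_new_tokens", "temperature"})

def run_batch_mocked(**kwargs):
    # Reject any keyword outside the allow-list, as the fixture now does.
    for kwarg in kwargs:
        assert kwarg in ALLOWED, f"LLM got unexpected keyword: {kwarg}"
    return [f"response to {prompt}" for prompt in kwargs["inputs"]]

print(run_batch_mocked(inputs=["hello"], temperature=1.0))  # ok
# run_batch_mocked(inputs=["hello"], k=5)  # would raise AssertionError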

src/utils/hf_hub_utils.py

Lines changed: 19 additions & 13 deletions
@@ -3,8 +3,9 @@
 from io import BytesIO
 from itertools import chain
 from typing import TYPE_CHECKING, Any, Callable
+from unittest import mock

-from .import_utils import ignore_pydantic_warnings
+from .import_utils import ignore_hf_token_warnings, ignore_pydantic_warnings

 with ignore_pydantic_warnings():
     from huggingface_hub import HfApi, hf_hub_download, login

@@ -179,22 +180,27 @@ def get_citation_info(

 def hf_hub_login(token: None | str = None) -> HfApi:  # pragma: no cover
     # Login
-    api = HfApi()
-    if token is not None:
-        try:
-            login(token=token, add_to_git_credential=False, write_permission=True)
-        except ValueError:
-            pass
-    while True:
-        try:
-            api.whoami()
-            break
-        except LocalTokenNotFoundError:
+    with ignore_hf_token_warnings(), mock.patch(
+        "huggingface_hub._login.is_notebook", new=lambda: False
+    ):
+        api = HfApi()
+        if token is not None:
             try:
                 login(token=token, add_to_git_credential=False, write_permission=True)
             except ValueError:
                 pass
-    return api
+        while True:
+            try:
+                api.whoami()
+                break
+            except LocalTokenNotFoundError:
+                try:
+                    login(
+                        token=token, add_to_git_credential=False, write_permission=True
+                    )
+                except ValueError:
+                    pass
+        return api


 def prepare_to_publish(
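The notebook login bug fix works by patching huggingface_hub's internal notebook detection to always report False, so `login()` takes the non-interactive code path instead of trying to render a notebook widget; note it targets a private module attribute (`huggingface_hub._login.is_notebook`), so it is coupled to the library's internals. A self-contained sketch of the `mock.patch` technique itself:

from unittest import mock

def is_notebook() -> bool:
    # Stand-in for huggingface_hub's internal environment detector.
    return True

def login_flow() -> str:
    return "notebook widget" if is_notebook() else "terminal prompt"

# Force the non-notebook code path for the duration of the block only.
with mock.patch(f"{__name__}.is_notebook", new=lambda: False):
    assert login_flow() == "terminal prompt"
assert login_flow() == "notebook widget"  # the patch is undone on exit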

src/utils/hf_model_utils.py

Lines changed: 1 addition & 1 deletion
@@ -238,7 +238,7 @@ def validate_quantization_config(
     quantization_config = copy(quantization_config)
     if (
         getattr(quantization_config, "quant_method", None) == "bitsandbytes"
-    ):  # pragma: no cover
+    ) and dtype is not None:  # pragma: no cover
         quantization_config.bnb_4bit_compute_dtype = dtype  # type:ignore[union-attr]
         quantization_config.bnb_4bit_quant_storage = dtype  # type:ignore[union-attr]
     return quantization_config
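The quantization fix adds a `dtype is not None` guard so that `validate_quantization_config` no longer overwrites a user's `bnb_4bit_compute_dtype`/`bnb_4bit_quant_storage` with `None` when no dtype was supplied. A simplified sketch of that behavior, using a hypothetical stand-in for `transformers.BitsAndBytesConfig`:

from copy import copy
from dataclasses import dataclass

@dataclass
class BnbConfig:
    # Hypothetical, simplified stand-in for transformers.BitsAndBytesConfig.
    quant_method: str = "bitsandbytes"
    bnb_4bit_compute_dtype: str | None = "bfloat16"

def validate(config: BnbConfig, dtype: str | None) -> BnbConfig:
    config = copy(config)
    if config.quant_method == "bitsandbytes" and dtype is not None:
        config.bnb_4bit_compute_dtype = dtype  # only override when given
    return config

assert validate(BnbConfig(), dtype=None).bnb_4bit_compute_dtype == "bfloat16"
assert validate(BnbConfig(), dtype="float16").bnb_4bit_compute_dtype == "float16"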

src/utils/import_utils.py

Lines changed: 7 additions & 0 deletions
@@ -133,6 +133,13 @@ def ignore_setfit_warnings():
         yield None


+@contextlib.contextmanager
+def ignore_hf_token_warnings():  # pragma: no cover
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=UserWarning)
+        yield None
+
+
 @contextlib.contextmanager
 def ignore_faiss_warnings():
     with warnings.catch_warnings():
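`warnings.catch_warnings()` snapshots the active warning filters and restores them when the block exits, so the `UserWarning` suppression stays scoped to the context manager rather than leaking process-wide. A quick usage check:

import contextlib
import warnings

@contextlib.contextmanager
def ignore_user_warnings():
    # Filters are saved on entry and restored on exit by catch_warnings().
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        yield None

with ignore_user_warnings():
    warnings.warn("suppressed inside the block", UserWarning)
warnings.warn("emitted normally outside", UserWarning)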
