
Commit 4cbaf9f

Bring DataDreamer up-to-date (#36)
* Initial set of updates
* Support latest torch
* Support latest torch #2
* Bump requirement versions
* Hide warning
* Minor training fixes
* Fix trainer warnings
* Update trainer fingerprints
* Update TRL
* Make Bedrock test pass
* Fix warnings
* Fix coverage
* Support FAISS in foreground for Linux
* Fix retriever test
* Fix vLLM test
* Fix VLLM tests
* Bump version
1 parent c535dc4 commit 4cbaf9f

27 files changed: +587 −252 lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "DataDreamer"
-version = "0.35.0"
+version = "0.36.0"
 description = "Prompt. Generate Synthetic Data. Train & Align Models."
 license = "MIT"
 authors= [
@@ -66,7 +66,7 @@ warn_unused_ignores = true
 mypy_path = "src/_stubs"

 [[tool.mypy.overrides]]
-module = "click,wandb,wandb.*,click.testing,flaky,tensorflow,torch_xla,jax,datasets.features.features,datasets.iterable_dataset,datasets.fingerprint,datasets.builder,datasets.arrow_writer,datasets.splits,datasets.utils,datasets.utils.version,pyarrow.lib,huggingface_hub,huggingface_hub.utils._headers,huggingface_hub.utils._errors,dill,dill.source,transformers,bitsandbytes,sqlitedict,optimum.bettertransformer,optimum.bettertransformer.models,optimum.utils,transformers.utils.quantization_config,sortedcontainers,peft,psutil,ring,ctransformers,petals,petals.client.inference_session,hivemind.p2p.p2p_daemon_bindings.utils,huggingface_hub.utils,tqdm,ctransformers.transformers,vllm,litellm,litellm.llms.palm,litellm.exceptions,sentence_transformers,faiss,huggingface_hub.utils._validators,evaluate,transformers.trainer_callback,transformers.training_args,trl,guidance,sentence_transformers.models.Transformer,trl.trainer.utils,transformers.trainer_utils,setfit,joblib,setfit.modeling,transformers.utils.notebook,mistralai.models.chat_completion,accelerate.utils,accelerate.utils.constants,accelerate,transformers.trainer,sentence_transformers.util,Pyro5,Pyro5.server,Pyro5.api,Pyro5,datadreamer,huggingface_hub.repocard,transformers.trainer_pt_utils"
+module = "click,wandb,wandb.*,click.testing,flaky,tensorflow,torch_xla,jax,datasets.features.features,datasets.iterable_dataset,datasets.fingerprint,datasets.builder,datasets.arrow_writer,datasets.splits,datasets.utils,datasets.utils.version,pyarrow.lib,huggingface_hub,huggingface_hub.utils._headers,huggingface_hub.utils._errors,dill,dill.source,transformers,bitsandbytes,sqlitedict,optimum.bettertransformer,optimum.bettertransformer.models,optimum.utils,transformers.utils.quantization_config,sortedcontainers,peft,psutil,ring,ctransformers,petals,petals.client.inference_session,hivemind.p2p.p2p_daemon_bindings.utils,huggingface_hub.utils,tqdm,ctransformers.transformers,vllm,litellm,litellm.llms.palm,litellm.exceptions,sentence_transformers,faiss,huggingface_hub.utils._validators,evaluate,transformers.trainer_callback,transformers.training_args,trl,guidance,sentence_transformers.models.Transformer,trl.trainer.utils,transformers.trainer_utils,setfit,joblib,setfit.modeling,transformers.utils.notebook,mistralai.models.chat_completion,accelerate.utils,accelerate.utils.constants,accelerate,transformers.trainer,sentence_transformers.util,Pyro5,Pyro5.server,Pyro5.api,Pyro5,datadreamer,huggingface_hub.repocard,transformers.trainer_pt_utils,traitlets.utils.warnings,orjson,Pyro5.errors,sympy,tqdm.auto"
 ignore_missing_imports = true

 [tool.pyright]

src/_patches/__init__.py

Whitespace-only changes.
src/_patches/datasets_reset_state_hack.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# An update in datasets 2.20.0 adding state_dict to IterableDataset seems to have
+# broken IterableDataset. This patch is a temporary fix until the issue is resolved.
+
+import contextlib
+from unittest.mock import patch
+
+from datasets.iterable_dataset import (
+    ArrowExamplesIterable,
+    ExamplesIterable,
+    TypedExamplesIterable,
+)
+
+__original_init_state_dict = TypedExamplesIterable._init_state_dict
+__original_examples__iter__ = ExamplesIterable.__iter__
+__original_arrowexamples__iter__ = ArrowExamplesIterable.__iter__
+_should_reset_state_dict = False
+
+
+def patched_examples__iter__(self):
+    global _should_reset_state_dict
+    if _should_reset_state_dict:
+        self._init_state_dict()
+    return __original_examples__iter__(self)
+
+
+def patched_arrowexamples__iter__(self):
+    global _should_reset_state_dict
+    if _should_reset_state_dict:
+        self._init_state_dict()
+    return __original_arrowexamples__iter__(self)
+
+
+ExamplesIterable.__iter__ = patched_examples__iter__
+ArrowExamplesIterable.__iter__ = patched_arrowexamples__iter__
+
+
+@contextlib.contextmanager
+def apply_datasets_reset_state_hack():
+    def patched_init_state_dict(self):
+        self._state_dict = None  # Set to None to ensure it is reset
+        return __original_init_state_dict(self)
+
+    with patch(
+        "datasets.iterable_dataset.TypedExamplesIterable._init_state_dict",
+        patched_init_state_dict,
+    ):
+        yield None
+
+
+def start_datasets_reset_state_hack():
+    global _should_reset_state_dict
+    _should_reset_state_dict = True
+
+
+def stop_datasets_reset_state_hack():
+    global _should_reset_state_dict
+    _should_reset_state_dict = False
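
For reference, here is a minimal usage sketch (not part of the commit) of how the helpers above fit together: the start/stop toggles switch the patched __iter__ behavior on and off globally, while the context manager additionally forces TypedExamplesIterable to rebuild its state dict during an iteration. The example dataset and the import path are illustrative assumptions.

from datasets import Dataset

# Import path assumed for illustration; inside DataDreamer these helpers are
# imported relatively (from ._patches.datasets_reset_state_hack import ...).
from src._patches.datasets_reset_state_hack import (
    apply_datasets_reset_state_hack,
    start_datasets_reset_state_hack,
    stop_datasets_reset_state_hack,
)

start_datasets_reset_state_hack()  # patched __iter__ now resets state before iterating
try:
    iterable_ds = Dataset.from_dict({"text": ["a", "b", "c"]}).to_iterable_dataset()
    with apply_datasets_reset_state_hack():  # also reset TypedExamplesIterable state
        for row in iterable_ds:
            print(row)
finally:
    stop_datasets_reset_state_hack()  # restore normal datasets behavior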

src/_patches/setfit_import_hack.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# SetFit is out-of-date with huggingface_hub and throws an error when trying to import
+# from it
+# like this: ImportError: cannot import name 'DatasetFilter' from 'huggingface_hub'
+
+# To fix this, we need to monkey patch huggingface_hub to prevent the import error
+
+from ..utils.import_utils import ignore_pydantic_warnings
+
+
+def apply_setfit_import_hack():
+    with ignore_pydantic_warnings():
+        import huggingface_hub
+
+    huggingface_hub.DatasetFilter = None
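
A standalone sketch of the same workaround, without DataDreamer's ignore_pydantic_warnings helper (assumptions: a huggingface_hub release that no longer exports DatasetFilter, and a SetFit version that still imports it). The key point is that the stub attribute must exist before `import setfit` runs.

import huggingface_hub

# Newer huggingface_hub releases removed DatasetFilter; older SetFit still does
# `from huggingface_hub import DatasetFilter`, so provide a stub first.
if not hasattr(huggingface_hub, "DatasetFilter"):
    huggingface_hub.DatasetFilter = None

import setfit  # would raise ImportError without the stub above

print(setfit.__version__)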

src/datadreamer.py

Lines changed: 11 additions & 0 deletions
@@ -15,13 +15,20 @@
 from sqlitedict import SqliteDict

 from . import logging as datadreamer_logging
+from ._patches.datasets_reset_state_hack import (
+    start_datasets_reset_state_hack,
+    stop_datasets_reset_state_hack,
+)
 from .logging import DATEFMT, logger
 from .utils.background_utils import get_thread_id
 from .utils.fs_utils import safe_fn
 from .utils.import_utils import ignore_pydantic_warnings, ignore_transformers_warnings

 with ignore_transformers_warnings():
     from optimum.utils import logging as optimum_logging
+    from ._patches.setfit_import_hack import apply_setfit_import_hack  # isort:skip
+
+    apply_setfit_import_hack()
     from setfit import logging as setfit_logging
     from transformers import logging as transformers_logging

@@ -517,6 +524,9 @@ def __enter__(self): # noqa: C901
         )
         self._patch_tqdm()

+        # Activate datasets reset state hack
+        start_datasets_reset_state_hack()
+
         # Set initialized to True
         DataDreamer.ctx.instance = self
         DataDreamer.ctx.initialized = True
@@ -546,6 +556,7 @@ def __exit__(self, exc_type, exc_value, exc_tb):

         self._unpatch_loggers()
         self._unpatch_tqdm()
+        stop_datasets_reset_state_hack()
         processes_to_terminate = DataDreamer.ctx.background_processes
         DataDreamer.ctx = UserDict()
         if self.output_folder_path:
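
For context, a minimal sketch of what these hooks mean for users of the library (assuming the published `datadreamer` package and its documented session API): the reset-state hack is active only while a DataDreamer session is open.

from datadreamer import DataDreamer

# Entering the session now also calls start_datasets_reset_state_hack();
# leaving it calls stop_datasets_reset_state_hack(), restoring normal
# datasets behavior outside the session.
with DataDreamer("./output"):
    ...  # steps that iterate datasets run with the reset-state hack enabled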

src/datasets/datasets.py

Lines changed: 11 additions & 6 deletions
@@ -7,6 +7,7 @@
 from datasets.fingerprint import Hasher
 from pandas import DataFrame

+from .._patches.datasets_reset_state_hack import apply_datasets_reset_state_hack
 from ..datasets.utils import get_column_names
 from ..pickling import unpickle_transform

@@ -42,11 +43,14 @@ def _features(self) -> Features:
         return Features()

     def __iter__(self):
-        if self._pickled or self._pickled_inferred:  # type:ignore[attr-defined]
-            for row in iter(self.dataset):  # type:ignore[attr-defined]
-                yield unpickle_transform(row, features=self._features, batched=False)
-        else:
-            yield from iter(self.dataset)  # type:ignore[attr-defined]
+        with apply_datasets_reset_state_hack():
+            if self._pickled or self._pickled_inferred:  # type:ignore[attr-defined]
+                for row in iter(self.dataset):  # type:ignore[attr-defined]
+                    yield unpickle_transform(
+                        row, features=self._features, batched=False
+                    )
+            else:
+                yield from iter(self.dataset)  # type:ignore[attr-defined]

     def __getitem__(self, key: int | slice | str | Iterable[int]) -> Any:
         """Get a row or column from the dataset.
@@ -316,7 +320,8 @@ def cast_column(
         )

     def __iter__(self):
-        return iter(self.dataset)
+        with apply_datasets_reset_state_hack():
+            return iter(self.dataset)

     def __len__(self) -> int:
         return self.total_num_rows

src/embedders/sentence_transformers_embedder.py

Lines changed: 8 additions & 4 deletions
@@ -97,7 +97,7 @@ def model(self) -> SentenceTransformer:
         model = cls(
             self.model_name,
             trust_remote_code=self.trust_remote_code,
-            device=self.device,
+            device=self.device,  # type:ignore[arg-type]
             **self.kwargs,
         )
         model[0].tokenizer = get_tokenizer(
@@ -160,7 +160,11 @@ def model_max_length(self) -> int:

     @cached_property
     def dims(self) -> int:
-        return self.model.get_sentence_embedding_dimension()
+        dims = self.model.get_sentence_embedding_dimension()
+        assert (
+            dims is not None
+        ), f"Failed to get the embedding dimension for {self.model_name}."
+        return dims

     @torch.no_grad()
     def _run_batch(
@@ -181,8 +185,8 @@ def _run_batch(
             model_input = [[cast(str, instruction), t] for t in texts]

         return list(
-            self.model.encode(
-                sentences=model_input,
+            self.model.encode(  # type:ignore[arg-type]
+                sentences=model_input,  # type:ignore[arg-type]
                 batch_size=len(texts),
                 show_progress_bar=False,
                 convert_to_numpy=True,
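
The new assert reflects that recent sentence-transformers releases type get_sentence_embedding_dimension() as Optional[int]. A small standalone sketch of the same check (the model checkpoint is chosen only for illustration):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
dims = model.get_sentence_embedding_dimension()
# The return type is Optional[int], so guard before treating it as an int.
assert dims is not None, "Failed to get the embedding dimension."
print(f"Embedding dimension: {dims}")  # 384 for this checkpoint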

src/llms/_litellm.py

Lines changed: 16 additions & 2 deletions
@@ -58,6 +58,7 @@ def retry_wrapper(self):
         from litellm.exceptions import (
             APIConnectionError,
             APIError,
+            InternalServerError,
             RateLimitError,
             ServiceUnavailableError,
         )
@@ -81,6 +82,14 @@ def retry_wrapper(self):
             stop=stop_any(lambda _: not self.retry_on_fail),  # type: ignore[arg-type]
             reraise=True,
         )
+        @retry(
+            retry=retry_if_exception_type(InternalServerError),
+            wait=wait_exponential(multiplier=1, min=3, max=300),
+            before_sleep=before_sleep_log(tenacity_logger, logging.INFO),
+            after=after_log(tenacity_logger, logging.INFO),
+            stop=stop_any(lambda _: not self.retry_on_fail),  # type: ignore[arg-type]
+            reraise=True,
+        )
         @retry(
             retry=retry_if_exception_type(APIError),
             wait=wait_exponential(multiplier=1, min=3, max=300),
@@ -98,7 +107,8 @@ def retry_wrapper(self):
             reraise=True,
         )
         def _retry_wrapper(func, **kwargs):
-            return func(**kwargs)
+            with ignore_litellm_warnings():
+                return func(**kwargs)

         _retry_wrapper.__wrapped__.__module__ = None  # type: ignore[attr-defined]
         _retry_wrapper.__wrapped__.__qualname__ = f"{self.__class__.__name__}.run"  # type: ignore[attr-defined]
@@ -126,7 +136,11 @@ def get_max_context_length(self, max_new_tokens: int) -> int:
         with ignore_litellm_warnings():
             from litellm import get_max_tokens

-        return get_max_tokens(model=self._model_name_prefix + self.model_name)
+        max_tokens = get_max_tokens(model=self._model_name_prefix + self.model_name)
+        assert (
+            max_tokens is not None
+        ), f"Failed to get the maximum context length for model: {self.model_name}."
+        return max_tokens

     @ring.lru(maxsize=5000)
     def count_tokens(self, value: str) -> int:
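
The added decorator follows the file's existing pattern of stacking one tenacity @retry per exception class, each with its own exponential backoff. A self-contained sketch of that pattern (FlakyServerError and flaky_call are made up for illustration; the real code additionally uses after_log and a stop_any condition tied to retry_on_fail):

import logging

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("retry_demo")


class FlakyServerError(Exception):
    """Stand-in for litellm.exceptions.InternalServerError."""


attempts = {"n": 0}


@retry(
    retry=retry_if_exception_type(FlakyServerError),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    before_sleep=before_sleep_log(logger, logging.INFO),
    stop=stop_after_attempt(5),
    reraise=True,
)
def flaky_call() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise FlakyServerError("transient 5xx from the provider")
    return "ok"


print(flaky_call())  # logs two retries, then prints "ok"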

src/llms/vllm.py

Lines changed: 33 additions & 20 deletions
@@ -4,18 +4,25 @@
 from functools import cached_property, partial
 from typing import Any, Callable, Generator, Iterable

-import dill
 import torch
 from datasets.fingerprint import Hasher

 from .. import DataDreamer
 from ..logging import logger as datadreamer_logger
 from ..utils.arg_utils import AUTO, Default
-from ..utils.background_utils import RunIfTimeout, proxy_resource_in_background
+from ..utils.background_utils import (
+    RunIfTimeout,
+    dill_serializer,
+    proxy_resource_in_background,
+)
 from ..utils.device_utils import get_device_env_variables, is_cpu_device
 from ..utils.fs_utils import safe_fn
 from ..utils.hf_model_utils import get_tokenizer
-from ..utils.import_utils import ignore_transformers_warnings, import_module
+from ..utils.import_utils import (
+    ignore_tqdm,
+    ignore_transformers_warnings,
+    import_module,
+)
 from .hf_transformers import CachedTokenizer, HFTransformers
 from .llm import (
     DEFAULT_BATCH_SIZE,
@@ -62,6 +69,9 @@ def __init__(
             cache_folder_path=cache_folder_path,
             **kwargs,
         )
+        self.device = (
+            [self.device] if not isinstance(self.device, list) else self.device  # type:ignore[list-item]
+        )
         self.quantization = quantization
         if self.quantization is None and "-awq" in model_name.lower():
             self.quantization = "awq"
@@ -89,6 +99,10 @@ def _monkey_patch_init_logger(*args, **kwargs):

             vllm_logging.init_logger = _monkey_patch_init_logger  # type:ignore[attr-defined]
             logging.getLogger("vllm.engine.llm_engine").level = logging.ERROR
+            logging.getLogger("vllm.config").level = logging.ERROR
+            logging.getLogger(
+                "vllm.distributed.parallel_state"
+            ).level = logging.ERROR

             # Load model
             log_if_timeout = RunIfTimeout(
@@ -101,18 +115,19 @@ def _monkey_patch_init_logger(*args, **kwargs):
                 timeout=10.0,
             )
             LLM = import_module("vllm").LLM
-            self_resource.model = LLM(
-                model=self.model_name,
-                trust_remote_code=self.trust_remote_code,
-                dtype=str(self.dtype).replace("torch.", "")
-                if self.dtype is not None
-                else "auto",
-                quantization=self.quantization,
-                revision=self.revision,
-                swap_space=self.swap_space,
-                tensor_parallel_size=tensor_parallel_size,
-                **kwargs,
-            )
+            with ignore_tqdm():
+                self_resource.model = LLM(
+                    model=self.model_name,
+                    trust_remote_code=self.trust_remote_code,
+                    dtype=str(self.dtype).replace("torch.", "")
+                    if self.dtype is not None
+                    else "auto",
+                    quantization=self.quantization,
+                    revision=self.revision,
+                    swap_space=self.swap_space,
+                    tensor_parallel_size=tensor_parallel_size,
+                    **kwargs,
+                )

             # Finished loading
             log_if_timeout.stop(
@@ -124,9 +139,8 @@ def _monkey_patch_init_logger(*args, **kwargs):
                 )
             )

-        def get_generated_texts_batch(self_resource, args, kwargs):
-            args = dill.loads(args)
-            kwargs = dill.loads(kwargs)
+        @dill_serializer
+        def get_generated_texts_batch(self_resource, *args, **kwargs):
             outputs = self_resource.model.generate(*args, **kwargs)
             generated_texts_batch = [
                 [o.text for o in batch.outputs] for batch in outputs
@@ -202,8 +216,7 @@ def _run_batch( # noqa: C901
                 **kwargs,
             )
         generated_texts_batch = self.model.proxy.get_generated_texts_batch(
-            args=dill.dumps((prompts, sampling_params)),
-            kwargs=dill.dumps({"use_tqdm": False}),
+            prompts, sampling_params, use_tqdm=False
         )

         # Post-process and return
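
The dill_serializer helper comes from DataDreamer's utils.background_utils and its implementation is not shown in this diff; the sketch below is only a guess at its general shape: round-trip arguments through dill so the decorated resource method can accept plain *args/**kwargs while still surviving the hop to the background process.

import dill


def dill_serializer(func):
    # Hypothetical sketch: encode the arguments with dill (as a caller crossing a
    # process boundary would), decode them again, then call the real function.
    def wrapper(resource, *args, **kwargs):
        payload = dill.dumps((args, kwargs))
        args, kwargs = dill.loads(payload)
        return func(resource, *args, **kwargs)

    return wrapper


@dill_serializer
def shout(resource, *words, **options):
    return options.get("sep", " ").join(w.upper() for w in words)


print(shout(None, "hello", "world", sep="-"))  # prints "HELLO-WORLD"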
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-torch==2.1.2,<3.0.0
+torch>=2.1.2,<3.0.0
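
The old pin `==2.1.2,<3.0.0` could only ever resolve to torch 2.1.2 (the `<3.0.0` bound was redundant next to `==`), while the new `>=2.1.2,<3.0.0` range accepts newer 2.x releases. A quick check of the difference using the packaging library:

from packaging.specifiers import SpecifierSet
from packaging.version import Version

old_spec = SpecifierSet("==2.1.2,<3.0.0")
new_spec = SpecifierSet(">=2.1.2,<3.0.0")

print(Version("2.3.0") in old_spec)  # False: the old pin rejected newer torch
print(Version("2.3.0") in new_spec)  # True: the new range accepts it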
