Skip to content

Commit edbaf44

Browse files
committed
Refactor processors to use callback-based design
Replace stage parameter with callback methods (preprocess, process_after_batch, postprocess). The builder now invokes these callbacks at appropriate stages: PRE_GENERATION, POST_BATCH, and POST_GENERATION. - Remove build_stage from ProcessorConfig - Add callback methods to Processor base class - Update DropColumns and SchemaTransform to use process_after_batch - Simplify ColumnWiseBuilder processor invocation
1 parent 62bae42 commit edbaf44

File tree

12 files changed

+389
-182
lines changed

12 files changed

+389
-182
lines changed

docs/concepts/processors.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@ Each processor:
1313
- Applies its transformation
1414
- Passes the result to the next processor (or to output)
1515

16-
Currently, processors run only at the `POST_BATCH` stage, i.e., after column generation completes for each batch.
16+
Processors can run at three stages:
17+
18+
| Stage | When it runs | Use cases |
19+
|-------|--------------|-----------|
20+
| `PRE_GENERATION` | Once, on full seed data before batching | Filter seed data, validate inputs, normalize data |
21+
| `POST_BATCH` | After each batch completes (default) | Drop columns, transform schema per batch |
22+
| `POST_GENERATION` | Once, on final dataset after all batches | Deduplicate, aggregate statistics, final cleanup |
1723

1824
## Processor Types
1925

packages/data-designer-config/src/data_designer/config/processors.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,8 @@
1212
from typing_extensions import TypeAlias
1313

1414
from data_designer.config.base import ConfigBase
15-
from data_designer.config.dataset_builders import BuildStage
1615
from data_designer.config.errors import InvalidConfigError
1716

18-
SUPPORTED_STAGES = [BuildStage.POST_BATCH]
19-
2017

2118
class ProcessorType(str, Enum):
2219
"""Enumeration of available processor types.
@@ -33,33 +30,22 @@ class ProcessorType(str, Enum):
3330
class ProcessorConfig(ConfigBase, ABC):
3431
"""Abstract base class for all processor configuration types.
3532
36-
Processors are transformations that run before or after columns are generated.
37-
They can modify, reshape, or augment the dataset before it's saved.
33+
Processors are transformations that run at different stages of the generation
34+
pipeline. They can modify, reshape, or augment the dataset.
35+
36+
The processor implementation determines which stages it handles by overriding
37+
the appropriate callback methods (preprocess, process_after_batch, postprocess).
3838
3939
Attributes:
4040
name: Unique name of the processor, used to identify the processor in results
4141
and to name output artifacts on disk.
42-
build_stage: The stage at which the processor runs. Currently only `POST_BATCH`
43-
is supported, meaning processors run after each batch of columns is generated.
4442
"""
4543

4644
name: str = Field(
4745
description="The name of the processor, used to identify the processor in the results and to write the artifacts to disk.",
4846
)
49-
build_stage: BuildStage = Field(
50-
default=BuildStage.POST_BATCH,
51-
description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}",
52-
)
5347
processor_type: str
5448

55-
@field_validator("build_stage")
56-
def validate_build_stage(cls, v: BuildStage) -> BuildStage:
57-
if v not in SUPPORTED_STAGES:
58-
raise ValueError(
59-
f"Invalid dataset builder stage: {v}. Only these stages are supported: {', '.join(SUPPORTED_STAGES)}"
60-
)
61-
return v
62-
6349

6450
def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs: Any) -> ProcessorConfig:
6551
"""Create a processor configuration from a processor type and keyword arguments.

packages/data-designer-config/tests/config/test_processors.py

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pytest
55
from pydantic import ValidationError
66

7-
from data_designer.config.dataset_builders import BuildStage
87
from data_designer.config.errors import InvalidConfigError
98
from data_designer.config.processors import (
109
DropColumnsProcessorConfig,
@@ -16,92 +15,64 @@
1615

1716

1817
def test_drop_columns_processor_config_creation():
19-
config = DropColumnsProcessorConfig(
20-
name="drop_columns_processor", build_stage=BuildStage.POST_BATCH, column_names=["col1", "col2"]
21-
)
18+
config = DropColumnsProcessorConfig(name="drop_columns_processor", column_names=["col1", "col2"])
2219

23-
assert config.build_stage == BuildStage.POST_BATCH
2420
assert config.column_names == ["col1", "col2"]
2521
assert config.processor_type == ProcessorType.DROP_COLUMNS
2622
assert isinstance(config, ProcessorConfig)
2723

2824

2925
def test_drop_columns_processor_config_validation():
30-
# Test unsupported stage raises error
31-
with pytest.raises(ValidationError, match="Invalid dataset builder stage"):
32-
DropColumnsProcessorConfig(
33-
name="drop_columns_processor", build_stage=BuildStage.PRE_BATCH, column_names=["col1"]
34-
)
35-
3626
# Test missing required field raises error
3727
with pytest.raises(ValidationError, match="Field required"):
38-
DropColumnsProcessorConfig(name="drop_columns_processor", build_stage=BuildStage.POST_BATCH)
28+
DropColumnsProcessorConfig(name="drop_columns_processor")
3929

4030

4131
def test_drop_columns_processor_config_serialization():
42-
config = DropColumnsProcessorConfig(
43-
name="drop_columns_processor", build_stage=BuildStage.POST_BATCH, column_names=["col1", "col2"]
44-
)
32+
config = DropColumnsProcessorConfig(name="drop_columns_processor", column_names=["col1", "col2"])
4533

4634
# Serialize to dict
4735
config_dict = config.model_dump()
48-
assert config_dict["build_stage"] == "post_batch"
4936
assert config_dict["column_names"] == ["col1", "col2"]
5037

5138
# Deserialize from dict
5239
config_restored = DropColumnsProcessorConfig.model_validate(config_dict)
53-
assert config_restored.build_stage == config.build_stage
5440
assert config_restored.column_names == config.column_names
5541

5642

5743
def test_schema_transform_processor_config_creation():
5844
config = SchemaTransformProcessorConfig(
5945
name="output_format_processor",
60-
build_stage=BuildStage.POST_BATCH,
6146
template={"text": "{{ col1 }}"},
6247
)
6348

64-
assert config.build_stage == BuildStage.POST_BATCH
6549
assert config.template == {"text": "{{ col1 }}"}
6650
assert config.processor_type == ProcessorType.SCHEMA_TRANSFORM
6751
assert isinstance(config, ProcessorConfig)
6852

6953

7054
def test_schema_transform_processor_config_validation():
71-
# Test unsupported stage raises error
72-
with pytest.raises(ValidationError, match="Invalid dataset builder stage"):
73-
SchemaTransformProcessorConfig(
74-
name="schema_transform_processor",
75-
build_stage=BuildStage.PRE_BATCH,
76-
template={"text": "{{ col1 }}"},
77-
)
78-
7955
# Test missing required field raises error
8056
with pytest.raises(ValidationError, match="Field required"):
81-
SchemaTransformProcessorConfig(name="schema_transform_processor", build_stage=BuildStage.POST_BATCH)
57+
SchemaTransformProcessorConfig(name="schema_transform_processor")
8258

8359
# Test invalid template raises error
8460
with pytest.raises(InvalidConfigError, match="Template must be JSON serializable"):
85-
SchemaTransformProcessorConfig(
86-
name="schema_transform_processor", build_stage=BuildStage.POST_BATCH, template={"text": {1, 2, 3}}
87-
)
61+
SchemaTransformProcessorConfig(name="schema_transform_processor", template={"text": {1, 2, 3}})
8862

8963

9064
def test_schema_transform_processor_config_serialization():
9165
config = SchemaTransformProcessorConfig(
9266
name="schema_transform_processor",
93-
build_stage=BuildStage.POST_BATCH,
9467
template={"text": "{{ col1 }}"},
9568
)
9669

9770
# Serialize to dict
9871
config_dict = config.model_dump()
99-
assert config_dict["build_stage"] == "post_batch"
10072
assert config_dict["template"] == {"text": "{{ col1 }}"}
10173

10274
# Deserialize from dict
10375
config_restored = SchemaTransformProcessorConfig.model_validate(config_dict)
104-
assert config_restored.build_stage == config.build_stage
10576
assert config_restored.template == config.template
10677

10778

@@ -110,7 +81,6 @@ def test_get_processor_config_from_kwargs():
11081
config_drop_columns = get_processor_config_from_kwargs(
11182
ProcessorType.DROP_COLUMNS,
11283
name="drop_columns_processor",
113-
build_stage=BuildStage.POST_BATCH,
11484
column_names=["col1"],
11585
)
11686
assert isinstance(config_drop_columns, DropColumnsProcessorConfig)
@@ -120,7 +90,6 @@ def test_get_processor_config_from_kwargs():
12090
config_schema_transform = get_processor_config_from_kwargs(
12191
ProcessorType.SCHEMA_TRANSFORM,
12292
name="output_format_processor",
123-
build_stage=BuildStage.POST_BATCH,
12493
template={"text": "{{ col1 }}"},
12594
)
12695
assert isinstance(config_schema_transform, SchemaTransformProcessorConfig)
@@ -134,6 +103,6 @@ class UnknownProcessorType(str, Enum):
134103
UNKNOWN = "unknown"
135104

136105
result = get_processor_config_from_kwargs(
137-
UnknownProcessorType.UNKNOWN, name="unknown_processor", build_stage=BuildStage.POST_BATCH, column_names=["col1"]
106+
UnknownProcessorType.UNKNOWN, name="unknown_processor", column_names=["col1"]
138107
)
139108
assert result is None

0 commit comments

Comments
 (0)