Skip to content

Commit 33276b0

Browse files
mgarrardmeta-codesync[bot]
authored andcommitted
Simplify TransitionCriteria classes by replacing MaxGenerationParallelism with MinTrials (#4819)
Summary: Pull Request resolved: #4819 Pros of this change: * Reduce number of transition criterion classes --> may be easier for folks to wrap their head around the options with fewer options * In the future we want to move this type of budget to concurrency or something like that, so we can go ahead and improve the stack clarity now Cons of this change: * maxgenerationparallelism was a pretty explict name that defined what that criterion was for, without it you have to reason about the inputs of mintrials to understand how to configure parallelism * some maxgen special casing need to be checked using tc.block_gen_if_met property -- this is mainly for test cases (okay imo) and gs init where we construct the transition_to property for steps Note: Given we only have MinTrials as subclass of TrialBasedCriterion, we could actually compress this even further to only have TrialBasedCriterion if we would like. Totally open to that, but kept it as is as it could make developing new trialbased criterion easier to have the base class. Differential Revision: D91358715
1 parent 54cca40 commit 33276b0

File tree

13 files changed

+66
-198
lines changed

13 files changed

+66
-198
lines changed

ax/generation_strategy/generation_node.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
from ax.generation_strategy.generator_spec import GeneratorSpec
4343
from ax.generation_strategy.transition_criterion import (
4444
AutoTransitionAfterGen,
45-
MaxGenerationParallelism,
4645
MinTrials,
4746
TransitionCriterion,
4847
TrialBasedCriterion,
@@ -1128,7 +1127,7 @@ def __new__(
11281127
)
11291128
if max_parallelism is not None:
11301129
transition_criteria.append(
1131-
MaxGenerationParallelism(
1130+
MinTrials(
11321131
threshold=max_parallelism,
11331132
transition_to=placeholder_transition_to,
11341133
only_in_statuses=[TrialStatus.RUNNING],

ax/generation_strategy/generation_strategy.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,8 @@ def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None:
395395
)
396396
)
397397
for tc in step.transition_criteria:
398-
if tc.criterion_class == "MaxGenerationParallelism":
399-
# MaxGenerationParallelism transitions to self (current step)
398+
# Criteria that block continued generation will transition to self
399+
if tc.block_gen_if_met:
400400
tc._transition_to = step.name
401401
else:
402402
tc._transition_to = next_step_name
@@ -434,24 +434,9 @@ def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None:
434434
# Validate transition edges:
435435
# - All `transition_to` targets must exist in this GS
436436
# - All TCs on one edge must have the same `continue_trial_generation` setting
437-
# All but `MaxGenerationParallelism` TCs must have a `transition_to` set
438437
for node in nodes:
439438
for next_node, tcs in node.transition_edges.items():
440-
if next_node is None:
441-
# TODO[drfreund]: Handle the case of the last generation step not
442-
# having any transition criteria.
443-
# TODO[mgarrard]: Remove MaxGenerationParallelism check when
444-
# we update TransitionCriterion always define `transition_to`
445-
# NOTE: This is done in D86066476
446-
for tc in tcs:
447-
if "MaxGenerationParallelism" not in tc.criterion_class:
448-
raise GenerationStrategyMisconfiguredException(
449-
error_info="Only MaxGenerationParallelism transition"
450-
" criterion can have a null `transition_to` argument,"
451-
f" but {tc.criterion_class} does not define "
452-
f"`transition_to` on {node.name}."
453-
)
454-
elif next_node not in node_names:
439+
if next_node not in node_names:
455440
raise GenerationStrategyMisconfiguredException(
456441
error_info=f"`transition_to` argument "
457442
f"{next_node} does not correspond to any node in"

ax/generation_strategy/tests/test_dispatch_utils.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222
DEFAULT_BAYESIAN_PARALLELISM,
2323
)
2424
from ax.generation_strategy.generation_node import GenerationNode
25-
from ax.generation_strategy.transition_criterion import (
26-
MaxGenerationParallelism,
27-
MinTrials,
28-
)
25+
from ax.generation_strategy.transition_criterion import MinTrials, TrialBasedCriterion
2926
from ax.generators.random.sobol import SobolGenerator
3027
from ax.generators.winsorization_config import WinsorizationConfig
3128
from ax.utils.common.testutils import TestCase
@@ -621,12 +618,12 @@ def test_enforce_sequential_optimization(self) -> None:
621618
sobol_gpei._nodes[0].transition_criteria[0], MinTrials
622619
)
623620
self.assertTrue(node0_min_trials.block_gen_if_met)
624-
# Check that max_parallelism is set by verifying MaxGenerationParallelism
625-
# criterion exists on node 1
621+
# Check that max_parallelism is set by verifying a TrialBasedCriterion
622+
# with block_gen_if_met=True exists on node 1
626623
node1_max_parallelism = [
627624
tc
628625
for tc in sobol_gpei._nodes[1].transition_criteria
629-
if isinstance(tc, MaxGenerationParallelism)
626+
if tc.block_gen_if_met and isinstance(tc, TrialBasedCriterion)
630627
]
631628
self.assertTrue(len(node1_max_parallelism) > 0)
632629
with self.subTest("False"):
@@ -647,11 +644,11 @@ def test_enforce_sequential_optimization(self) -> None:
647644
)
648645
self.assertFalse(node0_min_trials.block_gen_if_met)
649646
# Check that max_parallelism is None by verifying no
650-
# MaxGenerationParallelism criterion exists on node 1
647+
# TrialBasedCriterion with block_gen_if_met=True exists on node 1
651648
node1_max_parallelism = [
652649
tc
653650
for tc in sobol_gpei._nodes[1].transition_criteria
654-
if isinstance(tc, MaxGenerationParallelism)
651+
if tc.block_gen_if_met and isinstance(tc, TrialBasedCriterion)
655652
]
656653
self.assertEqual(len(node1_max_parallelism), 0)
657654
with self.subTest("False and max_parallelism_override"):
@@ -820,7 +817,7 @@ def test_fixed_num_initialization_trials(self) -> None:
820817
def _get_max_parallelism(self, node: GenerationNode) -> int | None:
821818
"""Helper to extract max_parallelism from transition criteria."""
822819
for tc in node.transition_criteria:
823-
if isinstance(tc, MaxGenerationParallelism):
820+
if tc.block_gen_if_met and isinstance(tc, TrialBasedCriterion):
824821
return tc.threshold
825822
return None
826823

ax/generation_strategy/tests/test_generation_strategy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
from ax.generation_strategy.generator_spec import GeneratorSpec
5858
from ax.generation_strategy.transition_criterion import (
5959
AutoTransitionAfterGen,
60-
MaxGenerationParallelism,
6160
MinTrials,
6261
)
6362
from ax.generators.random.sobol import SobolGenerator
@@ -1233,7 +1232,7 @@ def test_gs_setup_with_nodes(self) -> None:
12331232
threshold=2,
12341233
transition_to="node_2",
12351234
),
1236-
MaxGenerationParallelism(
1235+
MinTrials(
12371236
threshold=1,
12381237
only_in_statuses=[TrialStatus.RUNNING],
12391238
block_gen_if_met=True,

ax/generation_strategy/tests/test_transition_criterion.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
AutoTransitionAfterGen,
2323
AuxiliaryExperimentCheck,
2424
IsSingleObjective,
25-
MaxGenerationParallelism,
2625
MinTrials,
2726
)
2827
from ax.utils.common.logger import get_logger
@@ -186,7 +185,7 @@ def test_default_step_criterion_setup(self) -> None:
186185
threshold=2,
187186
transition_to="GenerationStep_2_BoTorch",
188187
),
189-
MaxGenerationParallelism(
188+
MinTrials(
190189
threshold=1,
191190
only_in_statuses=[TrialStatus.RUNNING],
192191
block_gen_if_met=True,
@@ -466,7 +465,7 @@ def test_repr(self) -> None:
466465
+ "'continue_trial_generation': False, "
467466
+ "'count_only_trials_with_data': False})",
468467
)
469-
max_parallelism = MaxGenerationParallelism(
468+
max_parallelism_criterion = MinTrials(
470469
only_in_statuses=[TrialStatus.EARLY_STOPPED],
471470
threshold=3,
472471
transition_to="GenerationStep_2",
@@ -475,16 +474,17 @@ def test_repr(self) -> None:
475474
not_in_statuses=[TrialStatus.FAILED],
476475
)
477476
self.assertEqual(
478-
str(max_parallelism),
479-
"MaxGenerationParallelism({'threshold': 3, "
477+
str(max_parallelism_criterion),
478+
"MinTrials({'threshold': 3, "
480479
+ "'transition_to': 'GenerationStep_2', "
481480
+ "'only_in_statuses': "
482481
+ "[<enum 'TrialStatus'>.EARLY_STOPPED], "
483482
+ "'not_in_statuses': [<enum 'TrialStatus'>.FAILED], "
484483
+ "'block_transition_if_unmet': False, "
485484
+ "'block_gen_if_met': True, "
486485
+ "'use_all_trials_in_exp': False, "
487-
+ "'continue_trial_generation': False})",
486+
+ "'continue_trial_generation': False, "
487+
+ "'count_only_trials_with_data': False})",
488488
)
489489
auto_transition = AutoTransitionAfterGen(transition_to="GenerationStep_2")
490490
self.assertEqual(

ax/generation_strategy/transition_criterion.py

Lines changed: 26 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
DATA_REQUIRED_MSG = (
3030
"All trials for current node {node_name} have been generated, "
3131
"but not enough data has been observed to proceed to the next "
32-
"Generation node. Try again when more is are available."
32+
"Generation node. Try again when more data is available."
3333
)
3434

3535

@@ -48,9 +48,7 @@ class TransitionCriterion(SortableBase):
4848
we will raise an error, otherwise we will continue to generate trials
4949
until ``MinTrials`` is met (thus overriding MinTrials).
5050
block_transition_if_unmet: A flag to prevent the node from completing and
51-
being able to transition to another node. Ex: MaxGenerationParallelism
52-
defaults to setting this to False since we can complete and move on from
53-
this node without ever reaching its threshold.
51+
being able to transition to another node.
5452
continue_trial_generation: A flag to indicate that all generation for a given
5553
trial is not completed, and thus even after transition, the next node will
5654
continue to generate arms for the same trial. Example usage: in
@@ -234,9 +232,7 @@ class TrialBasedCriterion(TransitionCriterion):
234232
threshold: The threshold as an integer for this criterion. Ex: If we want to
235233
generate at most 3 trials, then the threshold is 3.
236234
block_transition_if_unmet: A flag to prevent the node from completing and
237-
being able to transition to another node. Ex: MaxGenerationParallelism
238-
defaults to setting this to False since we can complete and move on from
239-
this node without ever reaching its threshold.
235+
being able to transition to another node.
240236
block_gen_if_met: A flag to prevent continued generation from the
241237
associated GenerationNode if this criterion is met but other criterion
242238
remain unmet. Ex: ``MinTrials`` has not been met yet, but
@@ -383,93 +379,17 @@ def is_met(
383379
)
384380

385381

386-
class MaxGenerationParallelism(TrialBasedCriterion):
387-
"""Specific TransitionCriterion implementation which defines the maximum number
388-
of trials that can simultaneously be in the designated trial statuses. The
389-
default behavior is to block generation from the associated GenerationNode if the
390-
threshold is met. This is configured via the `block_gen_if_met` flag being set to
391-
True. This criterion defaults to not blocking transition to another node via the
392-
`block_transition_if_unmet` flag being set to False.
393-
394-
Args:
395-
threshold: The threshold as an integer for this criterion. Ex: If we want to
396-
generate at most 3 trials, then the threshold is 3.
397-
only_in_statuses: A list of trial statuses to filter on when checking the
398-
criterion threshold.
399-
not_in_statuses: A list of trial statuses to exclude when checking the
400-
criterion threshold.
401-
transition_to: The name of the GenerationNode the GenerationStrategy should
402-
transition to when this criterion is met, if it exists.
403-
block_transition_if_unmet: A flag to prevent the node from completing and
404-
being able to transition to another node. Ex: MaxGenerationParallelism
405-
defaults to setting this to False since we can complete and move on from
406-
this node without ever reaching its threshold.
407-
block_gen_if_met: A flag to prevent continued generation from the
408-
associated GenerationNode if this criterion is met but other criterion
409-
remain unmet. Ex: ``MinTrials`` has not been met yet, but
410-
MinTrials has been reached. If this flag is set to true on MinTrials then
411-
we will raise an error, otherwise we will continue to generate trials
412-
until ``MinTrials`` is met (thus overriding MinTrials).
413-
use_all_trials_in_exp: A flag to use all trials in the experiment, instead of
414-
only those generated by the current GenerationNode.
415-
continue_trial_generation: A flag to indicate that all generation for a given
416-
trial is not completed, and thus even after transition, the next node will
417-
continue to generate arms for the same trial. Example usage: in
418-
``BatchTrial``s we may enable generation of arms within a batch from
419-
different ``GenerationNodes`` by setting this flag to True. Defaults to
420-
False for MaxGenerationParallelism since this criterion isn't currently
421-
used for node -> node or trial -> trial transition.
422-
count_only_trials_with_data: If set to True, only trials with data will be
423-
counted towards the ``threshold``. Defaults to False.
382+
class MinTrials(TrialBasedCriterion):
424383
"""
384+
Simple class to enforce a threshold for the number of trials with the
385+
designated statuses being generated by a specific GenerationNode.
425386
426-
def __init__(
427-
self,
428-
threshold: int,
429-
transition_to: str,
430-
only_in_statuses: list[TrialStatus] | None = None,
431-
not_in_statuses: list[TrialStatus] | None = None,
432-
block_transition_if_unmet: bool | None = False,
433-
block_gen_if_met: bool | None = True,
434-
use_all_trials_in_exp: bool | None = False,
435-
continue_trial_generation: bool | None = False,
436-
) -> None:
437-
super().__init__(
438-
threshold=threshold,
439-
only_in_statuses=only_in_statuses,
440-
not_in_statuses=not_in_statuses,
441-
transition_to=transition_to,
442-
block_gen_if_met=block_gen_if_met,
443-
block_transition_if_unmet=block_transition_if_unmet,
444-
use_all_trials_in_exp=use_all_trials_in_exp,
445-
continue_trial_generation=continue_trial_generation,
446-
)
387+
This class can be configured to behave as either:
388+
- A minimum trials criterion (default): blocks transition to next node until
389+
the threshold is met; block_transition_if_unmet=True
390+
- A maximum parallelism criterion: blocks further generation from the current node
391+
when threshold is reached; block_gen_if_met=True
447392
448-
def block_continued_generation_error(
449-
self,
450-
node_name: str,
451-
experiment: Experiment,
452-
trials_from_node: set[int],
453-
) -> None:
454-
"""Raises the appropriate error (should only be called when the
455-
``GenerationNode`` is blocked from continued generation). For this
456-
class, the exception is ``MaxParallelismReachedException``.
457-
"""
458-
assert self.block_gen_if_met # Sanity check.
459-
raise MaxParallelismReachedException(
460-
node_name=node_name,
461-
num_running=self.num_contributing_to_threshold(
462-
experiment=experiment, trials_from_node=trials_from_node
463-
),
464-
)
465-
466-
467-
class MinTrials(TrialBasedCriterion):
468-
"""
469-
Simple class to enforce a minimum threshold for the number of trials with the
470-
designated statuses being generated by a specific GenerationNode. The default
471-
behavior is to block transition to the next node if the threshold is unmet, but
472-
not affect continued generation.
473393
474394
Args:
475395
threshold: The threshold as an integer for this criterion. Ex: If we want to
@@ -481,15 +401,12 @@ class MinTrials(TrialBasedCriterion):
481401
transition_to: The name of the GenerationNode the GenerationStrategy should
482402
transition to when this criterion is met.
483403
block_transition_if_unmet: A flag to prevent the node from completing and
484-
being able to transition to another node. Ex: MaxGenerationParallelism
485-
defaults to setting this to False since we can complete and move on from
486-
this node without ever reaching its threshold.
404+
being able to transition to another node. Defaults to True for minimum
405+
trials behavior. Set to False for maximum parallelism behavior.
487406
block_gen_if_met: A flag to prevent continued generation from the
488407
associated GenerationNode if this criterion is met but other criterion
489-
remain unmet. Ex: ``MinTrials`` has not been met yet, but
490-
MinTrials has been reached. If this flag is set to true on MinTrials then
491-
we will raise an error, otherwise we will continue to generate trials
492-
until ``MinTrials`` is met (thus overriding MinTrials).
408+
remain unmet. Defaults to False for minimum trials behavior. Set to True
409+
for maximum parallelism behavior.
493410
use_all_trials_in_exp: A flag to use all trials in the experiment, instead of
494411
only those generated by the current GenerationNode.
495412
continue_trial_generation: A flag to indicate that all generation for a given
@@ -531,12 +448,18 @@ def block_continued_generation_error(
531448
experiment: Experiment,
532449
trials_from_node: set[int],
533450
) -> None:
534-
"""Raises the appropriate error (should only be called when the
535-
``GenerationNode`` is blocked from continued generation). For this
536-
class, the exception is ``DataRequiredError``.
451+
"""Raises MaxParallelismReachedException when generation is blocked.
452+
453+
This method is only called when block_gen_if_met=True (max parallelism
454+
behavior), which always has block_transition_if_unmet=False.
537455
"""
538456
assert self.block_gen_if_met # Sanity check.
539-
raise DataRequiredError(DATA_REQUIRED_MSG.format(node_name=node_name))
457+
raise MaxParallelismReachedException(
458+
node_name=node_name,
459+
num_running=self.num_contributing_to_threshold(
460+
experiment=experiment, trials_from_node=trials_from_node
461+
),
462+
)
540463

541464

542465
class AuxiliaryExperimentCheck(TransitionCriterion):
@@ -571,9 +494,7 @@ class AuxiliaryExperimentCheck(TransitionCriterion):
571494
we will raise an error, otherwise we will continue to generate trials
572495
until ``MinTrials`` is met (thus overriding MinTrials).
573496
block_transition_if_unmet: A flag to prevent the node from completing and
574-
being able to transition to another node. Ex: MaxGenerationParallelism
575-
defaults to setting this to False since we can complete and move on from
576-
this node without ever reaching its threshold.
497+
being able to transition to another node.
577498
continue_trial_generation: A flag to indicate that all generation for a given
578499
trial is not completed, and thus even after transition, the next node will
579500
continue to generate arms for the same trial. Example usage: in

ax/orchestration/tests/test_orchestrator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
GenerationStep,
5050
GenerationStrategy,
5151
)
52-
from ax.generation_strategy.transition_criterion import MaxGenerationParallelism
52+
from ax.generation_strategy.transition_criterion import TrialBasedCriterion
5353
from ax.metrics.branin import BraninMetric
5454
from ax.metrics.branin_map import BraninTimestampMapMetric
5555
from ax.orchestration.orchestrator import (
@@ -1171,7 +1171,7 @@ def test_run_trials_and_yield_results_with_early_stopper(self) -> None:
11711171
# Extract max_parallelism from transition criteria
11721172
node0_max_parallelism = None
11731173
for tc in self.two_sobol_steps_GS._nodes[0].transition_criteria:
1174-
if isinstance(tc, MaxGenerationParallelism):
1174+
if tc.block_gen_if_met and isinstance(tc, TrialBasedCriterion):
11751175
node0_max_parallelism = tc.threshold
11761176
break
11771177
self.assertEqual(

0 commit comments

Comments
 (0)