Skip to content

Commit bbae865

Browse files
mgarrardmeta-codesync[bot]
authored andcommitted
Simplify TransitionCriteria classes by replacing MaxGenerationParallelism with MinTrials (facebook#4819)
Summary: Pull Request resolved: facebook#4819 Pros of this change: * Reduce number of transition criterion classes --> may be easier for folks to wrap their head around the options with fewer options * In the future we want to move this type of budget to concurrency or something like that, so we can go ahead and improve the stack clarity now Cons of this change: * maxgenerationparallelism was a pretty explict name that defined what that criterion was for, without it you have to reason about the inputs of mintrials to understand how to configure parallelism * some maxgen special casing need to be checked using tc.block_gen_if_met property -- this is mainly for test cases (okay imo) and gs init where we construct the transition_to property for steps Note: Given we only have MinTrials as subclass of TrialBasedCriterion, we could actually compress this even further to only have TrialBasedCriterion if we would like. Totally open to that, but kept it as is as it could make developing new trialbased criterion easier to have the base class. Differential Revision: D91358715
1 parent 54cca40 commit bbae865

File tree

13 files changed

+81
-198
lines changed

13 files changed

+81
-198
lines changed

ax/generation_strategy/generation_node.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
from ax.generation_strategy.generator_spec import GeneratorSpec
4343
from ax.generation_strategy.transition_criterion import (
4444
AutoTransitionAfterGen,
45-
MaxGenerationParallelism,
4645
MinTrials,
4746
TransitionCriterion,
4847
TrialBasedCriterion,
@@ -1128,7 +1127,7 @@ def __new__(
11281127
)
11291128
if max_parallelism is not None:
11301129
transition_criteria.append(
1131-
MaxGenerationParallelism(
1130+
MinTrials(
11321131
threshold=max_parallelism,
11331132
transition_to=placeholder_transition_to,
11341133
only_in_statuses=[TrialStatus.RUNNING],

ax/generation_strategy/generation_strategy.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,10 @@ def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None:
395395
)
396396
)
397397
for tc in step.transition_criteria:
398-
if tc.criterion_class == "MaxGenerationParallelism":
399-
# MaxGenerationParallelism transitions to self (current step)
398+
# Only max parallelism criteria (block_gen_if_met=True AND
399+
# block_transition_if_unmet=False) transition to self; all other
400+
# criteria transition to the next step
401+
if tc.block_gen_if_met and not tc.block_transition_if_unmet:
400402
tc._transition_to = step.name
401403
else:
402404
tc._transition_to = next_step_name
@@ -434,24 +436,9 @@ def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None:
434436
# Validate transition edges:
435437
# - All `transition_to` targets must exist in this GS
436438
# - All TCs on one edge must have the same `continue_trial_generation` setting
437-
# All but `MaxGenerationParallelism` TCs must have a `transition_to` set
438439
for node in nodes:
439440
for next_node, tcs in node.transition_edges.items():
440-
if next_node is None:
441-
# TODO[drfreund]: Handle the case of the last generation step not
442-
# having any transition criteria.
443-
# TODO[mgarrard]: Remove MaxGenerationParallelism check when
444-
# we update TransitionCriterion always define `transition_to`
445-
# NOTE: This is done in D86066476
446-
for tc in tcs:
447-
if "MaxGenerationParallelism" not in tc.criterion_class:
448-
raise GenerationStrategyMisconfiguredException(
449-
error_info="Only MaxGenerationParallelism transition"
450-
" criterion can have a null `transition_to` argument,"
451-
f" but {tc.criterion_class} does not define "
452-
f"`transition_to` on {node.name}."
453-
)
454-
elif next_node not in node_names:
441+
if next_node not in node_names:
455442
raise GenerationStrategyMisconfiguredException(
456443
error_info=f"`transition_to` argument "
457444
f"{next_node} does not correspond to any node in"

ax/generation_strategy/tests/test_dispatch_utils.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222
DEFAULT_BAYESIAN_PARALLELISM,
2323
)
2424
from ax.generation_strategy.generation_node import GenerationNode
25-
from ax.generation_strategy.transition_criterion import (
26-
MaxGenerationParallelism,
27-
MinTrials,
28-
)
25+
from ax.generation_strategy.transition_criterion import MinTrials, TrialBasedCriterion
2926
from ax.generators.random.sobol import SobolGenerator
3027
from ax.generators.winsorization_config import WinsorizationConfig
3128
from ax.utils.common.testutils import TestCase
@@ -621,12 +618,12 @@ def test_enforce_sequential_optimization(self) -> None:
621618
sobol_gpei._nodes[0].transition_criteria[0], MinTrials
622619
)
623620
self.assertTrue(node0_min_trials.block_gen_if_met)
624-
# Check that max_parallelism is set by verifying MaxGenerationParallelism
625-
# criterion exists on node 1
621+
# Check that max_parallelism is set by verifying a TrialBasedCriterion
622+
# with block_gen_if_met=True exists on node 1
626623
node1_max_parallelism = [
627624
tc
628625
for tc in sobol_gpei._nodes[1].transition_criteria
629-
if isinstance(tc, MaxGenerationParallelism)
626+
if tc.block_gen_if_met and isinstance(tc, TrialBasedCriterion)
630627
]
631628
self.assertTrue(len(node1_max_parallelism) > 0)
632629
with self.subTest("False"):
@@ -647,11 +644,11 @@ def test_enforce_sequential_optimization(self) -> None:
647644
)
648645
self.assertFalse(node0_min_trials.block_gen_if_met)
649646
# Check that max_parallelism is None by verifying no
650-
# MaxGenerationParallelism criterion exists on node 1
647+
# TrialBasedCriterion with block_gen_if_met=True exists on node 1
651648
node1_max_parallelism = [
652649
tc
653650
for tc in sobol_gpei._nodes[1].transition_criteria
654-
if isinstance(tc, MaxGenerationParallelism)
651+
if tc.block_gen_if_met and isinstance(tc, TrialBasedCriterion)
655652
]
656653
self.assertEqual(len(node1_max_parallelism), 0)
657654
with self.subTest("False and max_parallelism_override"):
@@ -820,7 +817,11 @@ def test_fixed_num_initialization_trials(self) -> None:
820817
def _get_max_parallelism(self, node: GenerationNode) -> int | None:
821818
"""Helper to extract max_parallelism from transition criteria."""
822819
for tc in node.transition_criteria:
823-
if isinstance(tc, MaxGenerationParallelism):
820+
if (
821+
tc.block_gen_if_met
822+
and not tc.block_transition_if_unmet
823+
and isinstance(tc, TrialBasedCriterion)
824+
):
824825
return tc.threshold
825826
return None
826827

ax/generation_strategy/tests/test_generation_strategy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
from ax.generation_strategy.generator_spec import GeneratorSpec
5858
from ax.generation_strategy.transition_criterion import (
5959
AutoTransitionAfterGen,
60-
MaxGenerationParallelism,
6160
MinTrials,
6261
)
6362
from ax.generators.random.sobol import SobolGenerator
@@ -1233,7 +1232,7 @@ def test_gs_setup_with_nodes(self) -> None:
12331232
threshold=2,
12341233
transition_to="node_2",
12351234
),
1236-
MaxGenerationParallelism(
1235+
MinTrials(
12371236
threshold=1,
12381237
only_in_statuses=[TrialStatus.RUNNING],
12391238
block_gen_if_met=True,

ax/generation_strategy/tests/test_transition_criterion.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
AutoTransitionAfterGen,
2323
AuxiliaryExperimentCheck,
2424
IsSingleObjective,
25-
MaxGenerationParallelism,
2625
MinTrials,
2726
)
2827
from ax.utils.common.logger import get_logger
@@ -186,7 +185,7 @@ def test_default_step_criterion_setup(self) -> None:
186185
threshold=2,
187186
transition_to="GenerationStep_2_BoTorch",
188187
),
189-
MaxGenerationParallelism(
188+
MinTrials(
190189
threshold=1,
191190
only_in_statuses=[TrialStatus.RUNNING],
192191
block_gen_if_met=True,
@@ -466,7 +465,7 @@ def test_repr(self) -> None:
466465
+ "'continue_trial_generation': False, "
467466
+ "'count_only_trials_with_data': False})",
468467
)
469-
max_parallelism = MaxGenerationParallelism(
468+
max_parallelism_criterion = MinTrials(
470469
only_in_statuses=[TrialStatus.EARLY_STOPPED],
471470
threshold=3,
472471
transition_to="GenerationStep_2",
@@ -475,16 +474,17 @@ def test_repr(self) -> None:
475474
not_in_statuses=[TrialStatus.FAILED],
476475
)
477476
self.assertEqual(
478-
str(max_parallelism),
479-
"MaxGenerationParallelism({'threshold': 3, "
477+
str(max_parallelism_criterion),
478+
"MinTrials({'threshold': 3, "
480479
+ "'transition_to': 'GenerationStep_2', "
481480
+ "'only_in_statuses': "
482481
+ "[<enum 'TrialStatus'>.EARLY_STOPPED], "
483482
+ "'not_in_statuses': [<enum 'TrialStatus'>.FAILED], "
484483
+ "'block_transition_if_unmet': False, "
485484
+ "'block_gen_if_met': True, "
486485
+ "'use_all_trials_in_exp': False, "
487-
+ "'continue_trial_generation': False})",
486+
+ "'continue_trial_generation': False, "
487+
+ "'count_only_trials_with_data': False})",
488488
)
489489
auto_transition = AutoTransitionAfterGen(transition_to="GenerationStep_2")
490490
self.assertEqual(

ax/generation_strategy/transition_criterion.py

Lines changed: 35 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
DATA_REQUIRED_MSG = (
3030
"All trials for current node {node_name} have been generated, "
3131
"but not enough data has been observed to proceed to the next "
32-
"Generation node. Try again when more is are available."
32+
"Generation node. Try again when more data is available."
3333
)
3434

3535

@@ -48,9 +48,7 @@ class TransitionCriterion(SortableBase):
4848
we will raise an error, otherwise we will continue to generate trials
4949
until ``MinTrials`` is met (thus overriding MinTrials).
5050
block_transition_if_unmet: A flag to prevent the node from completing and
51-
being able to transition to another node. Ex: MaxGenerationParallelism
52-
defaults to setting this to False since we can complete and move on from
53-
this node without ever reaching its threshold.
51+
being able to transition to another node.
5452
continue_trial_generation: A flag to indicate that all generation for a given
5553
trial is not completed, and thus even after transition, the next node will
5654
continue to generate arms for the same trial. Example usage: in
@@ -234,9 +232,7 @@ class TrialBasedCriterion(TransitionCriterion):
234232
threshold: The threshold as an integer for this criterion. Ex: If we want to
235233
generate at most 3 trials, then the threshold is 3.
236234
block_transition_if_unmet: A flag to prevent the node from completing and
237-
being able to transition to another node. Ex: MaxGenerationParallelism
238-
defaults to setting this to False since we can complete and move on from
239-
this node without ever reaching its threshold.
235+
being able to transition to another node.
240236
block_gen_if_met: A flag to prevent continued generation from the
241237
associated GenerationNode if this criterion is met but other criterion
242238
remain unmet. Ex: ``MinTrials`` has not been met yet, but
@@ -383,94 +379,18 @@ def is_met(
383379
)
384380

385381

386-
class MaxGenerationParallelism(TrialBasedCriterion):
387-
"""Specific TransitionCriterion implementation which defines the maximum number
388-
of trials that can simultaneously be in the designated trial statuses. The
389-
default behavior is to block generation from the associated GenerationNode if the
390-
threshold is met. This is configured via the `block_gen_if_met` flag being set to
391-
True. This criterion defaults to not blocking transition to another node via the
392-
`block_transition_if_unmet` flag being set to False.
393-
394-
Args:
395-
threshold: The threshold as an integer for this criterion. Ex: If we want to
396-
generate at most 3 trials, then the threshold is 3.
397-
only_in_statuses: A list of trial statuses to filter on when checking the
398-
criterion threshold.
399-
not_in_statuses: A list of trial statuses to exclude when checking the
400-
criterion threshold.
401-
transition_to: The name of the GenerationNode the GenerationStrategy should
402-
transition to when this criterion is met, if it exists.
403-
block_transition_if_unmet: A flag to prevent the node from completing and
404-
being able to transition to another node. Ex: MaxGenerationParallelism
405-
defaults to setting this to False since we can complete and move on from
406-
this node without ever reaching its threshold.
407-
block_gen_if_met: A flag to prevent continued generation from the
408-
associated GenerationNode if this criterion is met but other criterion
409-
remain unmet. Ex: ``MinTrials`` has not been met yet, but
410-
MinTrials has been reached. If this flag is set to true on MinTrials then
411-
we will raise an error, otherwise we will continue to generate trials
412-
until ``MinTrials`` is met (thus overriding MinTrials).
413-
use_all_trials_in_exp: A flag to use all trials in the experiment, instead of
414-
only those generated by the current GenerationNode.
415-
continue_trial_generation: A flag to indicate that all generation for a given
416-
trial is not completed, and thus even after transition, the next node will
417-
continue to generate arms for the same trial. Example usage: in
418-
``BatchTrial``s we may enable generation of arms within a batch from
419-
different ``GenerationNodes`` by setting this flag to True. Defaults to
420-
False for MaxGenerationParallelism since this criterion isn't currently
421-
used for node -> node or trial -> trial transition.
422-
count_only_trials_with_data: If set to True, only trials with data will be
423-
counted towards the ``threshold``. Defaults to False.
382+
class MinTrials(TrialBasedCriterion):
424383
"""
384+
Simple class to enforce a threshold for the number of trials with the
385+
designated statuses being generated by a specific GenerationNode.
425386
426-
def __init__(
427-
self,
428-
threshold: int,
429-
transition_to: str,
430-
only_in_statuses: list[TrialStatus] | None = None,
431-
not_in_statuses: list[TrialStatus] | None = None,
432-
block_transition_if_unmet: bool | None = False,
433-
block_gen_if_met: bool | None = True,
434-
use_all_trials_in_exp: bool | None = False,
435-
continue_trial_generation: bool | None = False,
436-
) -> None:
437-
super().__init__(
438-
threshold=threshold,
439-
only_in_statuses=only_in_statuses,
440-
not_in_statuses=not_in_statuses,
441-
transition_to=transition_to,
442-
block_gen_if_met=block_gen_if_met,
443-
block_transition_if_unmet=block_transition_if_unmet,
444-
use_all_trials_in_exp=use_all_trials_in_exp,
445-
continue_trial_generation=continue_trial_generation,
446-
)
447-
448-
def block_continued_generation_error(
449-
self,
450-
node_name: str,
451-
experiment: Experiment,
452-
trials_from_node: set[int],
453-
) -> None:
454-
"""Raises the appropriate error (should only be called when the
455-
``GenerationNode`` is blocked from continued generation). For this
456-
class, the exception is ``MaxParallelismReachedException``.
457-
"""
458-
assert self.block_gen_if_met # Sanity check.
459-
raise MaxParallelismReachedException(
460-
node_name=node_name,
461-
num_running=self.num_contributing_to_threshold(
462-
experiment=experiment, trials_from_node=trials_from_node
463-
),
464-
)
387+
This class can be configured to behave as either:
388+
- A minimum trials criterion (default): blocks transition to next node until
389+
the threshold is met; block_transition_if_unmet=True
390+
- A maximum parallelism criterion: blocks further generation from the current node
391+
when threshold is reached; block_gen_if_met=True
465392
466393
467-
class MinTrials(TrialBasedCriterion):
468-
"""
469-
Simple class to enforce a minimum threshold for the number of trials with the
470-
designated statuses being generated by a specific GenerationNode. The default
471-
behavior is to block transition to the next node if the threshold is unmet, but
472-
not affect continued generation.
473-
474394
Args:
475395
threshold: The threshold as an integer for this criterion. Ex: If we want to
476396
generate at most 3 trials, then the threshold is 3.
@@ -481,15 +401,12 @@ class MinTrials(TrialBasedCriterion):
481401
transition_to: The name of the GenerationNode the GenerationStrategy should
482402
transition to when this criterion is met.
483403
block_transition_if_unmet: A flag to prevent the node from completing and
484-
being able to transition to another node. Ex: MaxGenerationParallelism
485-
defaults to setting this to False since we can complete and move on from
486-
this node without ever reaching its threshold.
404+
being able to transition to another node. Defaults to True for minimum
405+
trials behavior. Set to False for maximum parallelism behavior.
487406
block_gen_if_met: A flag to prevent continued generation from the
488407
associated GenerationNode if this criterion is met but other criterion
489-
remain unmet. Ex: ``MinTrials`` has not been met yet, but
490-
MinTrials has been reached. If this flag is set to true on MinTrials then
491-
we will raise an error, otherwise we will continue to generate trials
492-
until ``MinTrials`` is met (thus overriding MinTrials).
408+
remain unmet. Defaults to False for minimum trials behavior. Set to True
409+
for maximum parallelism behavior.
493410
use_all_trials_in_exp: A flag to use all trials in the experiment, instead of
494411
only those generated by the current GenerationNode.
495412
continue_trial_generation: A flag to indicate that all generation for a given
@@ -531,12 +448,27 @@ def block_continued_generation_error(
531448
experiment: Experiment,
532449
trials_from_node: set[int],
533450
) -> None:
534-
"""Raises the appropriate error (should only be called when the
535-
``GenerationNode`` is blocked from continued generation). For this
536-
class, the exception is ``DataRequiredError``.
451+
"""Raises the appropriate error when generation is blocked.
452+
453+
This method is called when block_gen_if_met=True and the criterion is met.
454+
The exception type depends on the criterion's behavior:
455+
- Max parallelism (block_transition_if_unmet=False):
456+
MaxParallelismReachedException
457+
- Enforce num_trials (block_transition_if_unmet=True): DataRequiredError
537458
"""
538459
assert self.block_gen_if_met # Sanity check.
539-
raise DataRequiredError(DATA_REQUIRED_MSG.format(node_name=node_name))
460+
if not self.block_transition_if_unmet:
461+
# Max parallelism behavior: temporarily blocked, waiting for trials
462+
# to complete
463+
raise MaxParallelismReachedException(
464+
node_name=node_name,
465+
num_running=self.num_contributing_to_threshold(
466+
experiment=experiment, trials_from_node=trials_from_node
467+
),
468+
)
469+
else:
470+
# Enforce num_trials behavior: hard limit reached, need to transition
471+
raise DataRequiredError(DATA_REQUIRED_MSG.format(node_name=node_name))
540472

541473

542474
class AuxiliaryExperimentCheck(TransitionCriterion):
@@ -571,9 +503,7 @@ class AuxiliaryExperimentCheck(TransitionCriterion):
571503
we will raise an error, otherwise we will continue to generate trials
572504
until ``MinTrials`` is met (thus overriding MinTrials).
573505
block_transition_if_unmet: A flag to prevent the node from completing and
574-
being able to transition to another node. Ex: MaxGenerationParallelism
575-
defaults to setting this to False since we can complete and move on from
576-
this node without ever reaching its threshold.
506+
being able to transition to another node.
577507
continue_trial_generation: A flag to indicate that all generation for a given
578508
trial is not completed, and thus even after transition, the next node will
579509
continue to generate arms for the same trial. Example usage: in

0 commit comments

Comments
 (0)