Skip to content

Commit d70407d

Browse files
mgarrard and facebook-github-bot
authored and committed
improve optimization complete logic (facebook#4828)
Summary: This change updates the completion-state logic to assume that if a node can transition, and that transition is to itself, then the optimization is complete. This works because `should_transition_to_next_node` only considers transition-blocking criteria (i.e., not max parallelism) when deciding whether to transition. And if a node points to itself, we can assume that this signifies the end of the optimization (steps are initialized this way earlier in this stack). This also allows the generation strategy to be called into again and the transition criterion to change, thus putting it back into a non-complete state.

An alternative I considered is to check whether all transition edges are completed and at least one points to self. This would look something like the snippet below. It would be much more expensive to evaluate, and it only guards against a malformed strategy. Edges are already known to be created in order of importance, and self-transition edges should be considered ending edges when their importance is considered.

```
@property
def optimization_complete(self) -> bool:
    if len(self._curr.transition_criteria) == 0:
        return False
    # Check ALL transition edges, not just the first matching one
    for next_node, all_tc in self._curr.transition_edges.items():
        transition_blocking = [tc for tc in all_tc if tc.block_transition_if_unmet]
        if not transition_blocking:
            continue
        all_met = all(
            tc.is_met(experiment=self.experiment, curr_node=self._curr)
            for tc in transition_blocking
        )
        if all_met:
            # An edge's criteria are met - check where it points
            if next_node != self._curr.name:
                return False  # Can transition to different node, not complete
    # All met edges (if any) point to self
    # Check if we actually have any met criteria pointing to self
    can_transition, next_node = self._curr.should_transition_to_next_node(
        raise_data_required_error=False
    )
    return can_transition and next_node == self._curr.name
```

The third alternative is to introduce a "completion node", which I think could be viable in the future if we have more complex generation strategies than we currently support, and the self generation logic is too cumbersome. For now though, I think this is a pretty nice simplification that should also have some compute wins, going from O(number of nodes * number of TC per node) to O(number of TC on current node).

Differential Revision: D91549954
1 parent 56711df commit d70407d

File tree

3 files changed

+42
-24
lines changed

3 files changed

+42
-24
lines changed

ax/generation_strategy/generation_node.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
)
4242
from ax.generation_strategy.generator_spec import GeneratorSpec
4343
from ax.generation_strategy.transition_criterion import (
44-
AutoTransitionAfterGen,
4544
MaxGenerationParallelism,
4645
MinTrials,
4746
TransitionCriterion,
@@ -249,21 +248,6 @@ def experiment(self) -> Experiment:
249248
"""Returns the experiment associated with this GenerationStrategy"""
250249
return self.generation_strategy.experiment
251250

252-
@property
253-
def is_completed(self) -> bool:
254-
"""Returns True if this GenerationNode is complete and should transition to
255-
the next node.
256-
"""
257-
# TODO: @mgarrard make this logic more robust and general
258-
# We won't mark a node completed if it has an AutoTransitionAfterGen criterion
259-
# as this is typically used in cyclic generation strategies
260-
should_transition, _ = self.should_transition_to_next_node(
261-
raise_data_required_error=False
262-
)
263-
return should_transition and not any(
264-
isinstance(tc, AutoTransitionAfterGen) for tc in self.transition_criteria
265-
)
266-
267251
@property
268252
def previous_node(self) -> GenerationNode | None:
269253
"""Returns the previous ``GenerationNode``, if any."""

ax/generation_strategy/generation_strategy.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,20 @@ def last_generator_run(self) -> GeneratorRun | None:
180180

181181
@property
182182
def optimization_complete(self) -> bool:
183-
"""Checks whether all nodes are completed in the generation strategy."""
184-
return all(node.is_completed for node in self._nodes)
183+
"""Checks whether optimization is complete.
184+
185+
A strategy is complete when the current node's transition criteria
186+
are met and point back to itself (self-transition).
187+
188+
Nodes with no transition_criteria are infinite by design and never complete.
189+
"""
190+
if len(self._curr.transition_criteria) == 0:
191+
return False
192+
193+
can_transition, next_node = self._curr.should_transition_to_next_node(
194+
raise_data_required_error=False
195+
)
196+
return can_transition and next_node == self._curr.name
185197

186198
def gen_single_trial(
187199
self,
@@ -612,13 +624,13 @@ def _maybe_transition_to_next_node(
612624
self,
613625
raise_data_required_error: bool = True,
614626
) -> bool:
615-
"""Moves this generation strategy to next node if the current node is completed,
616-
and it is not the last node in this generation strategy. This method is safe to
617-
use both when generating candidates or simply checking how many generator runs
618-
(to be made into trials) can currently be produced.
627+
"""Moves this generation strategy to next node if the current node's
628+
transition criteria are met. This method is safe to use both when generating
629+
candidates or simply checking how many generator runs (to be made into trials)
630+
can currently be produced.
619631
620-
NOTE: this method raises ``GenerationStrategyCompleted`` error if the current
621-
generation node is complete, but it is also the last in generation strategy.
632+
NOTE: this method raises ``GenerationStrategyCompleted`` error if the
633+
optimization is complete
622634
623635
Args:
624636
raise_data_required_error: Whether to raise ``DataRequiredError`` in the

ax/generation_strategy/tests/test_generation_strategy.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,6 +2000,28 @@ def test_gs_with_input_constructor(self) -> None:
20002000
self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_3")
20012001
self.assertEqual(len(trial.generator_runs[1].arms), 8)
20022002

2003+
def test_optimization_complete_single_node_no_criteria(self) -> None:
2004+
"""Test that a single node with no transition_criteria never completes."""
2005+
exp = get_branin_experiment()
2006+
gs = GenerationStrategy(
2007+
nodes=[
2008+
GenerationNode(
2009+
name="infinite sobol",
2010+
generator_specs=[self.sobol_generator_spec],
2011+
transition_criteria=[], # No criteria = infinite by design
2012+
),
2013+
]
2014+
)
2015+
gs.experiment = exp
2016+
2017+
# Generate many trials - never completes
2018+
for _ in range(3):
2019+
self.assertFalse(gs.optimization_complete)
2020+
gr = gs.gen_single_trial(experiment=exp)
2021+
exp.new_trial(generator_run=gr).mark_running(no_runner_required=True)
2022+
2023+
self.assertFalse(gs.optimization_complete)
2024+
20032025
# ------------- Testing helpers (put tests above this line) -------------
20042026

20052027
def _run_GS_for_N_rounds(

0 commit comments

Comments
 (0)