Skip to content

Commit e984607

Browse files
mgarrardmeta-codesync[bot]
authored andcommitted
Remove arms per node override (#4822)
Summary: Pull Request resolved: #4822 We initially provided this override during transition from legacy dispatch to GS, and at the time voted to persist the override in case folks wanted to provide this information manually. However, in the ~year+ since we added this, I'm not sure we've had any usage of this, and keeping it around increases the complexity of an already challenging part of the stack. My suggestion is to remove this, and for advanced users to directly leverage nodes or models if custom modeling is needed Reviewed By: saitcakmak, lena-kashtelyan Differential Revision: D91513177 fbshipit-source-id: ac5ceaa324a20552574b2f23eb746973c611761c
1 parent 31f5ea9 commit e984607

File tree

5 files changed

+11
-136
lines changed

5 files changed

+11
-136
lines changed

ax/generation_strategy/center_generation_node.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def gen(
7171
skip_fit: bool = False,
7272
data: Data | None = None,
7373
n: int | None = None,
74-
arms_per_node: dict[str, int] | None = None,
7574
**gs_gen_kwargs: Any,
7675
) -> GeneratorRun | None:
7776
"""Generate candidates or skip if search space is exhausted.
@@ -100,7 +99,6 @@ def gen(
10099
skip_fit=skip_fit,
101100
data=data,
102101
n=n,
103-
arms_per_node=arms_per_node,
104102
**gs_gen_kwargs,
105103
)
106104

ax/generation_strategy/generation_node.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,6 @@ def gen(
426426
skip_fit: bool = False,
427427
data: Data | None = None,
428428
n: int | None = None,
429-
arms_per_node: dict[str, int] | None = None,
430429
**gs_gen_kwargs: Any,
431430
) -> GeneratorRun | None:
432431
"""This method generates candidates using `self._gen` and handles deduplication
@@ -451,9 +450,6 @@ def gen(
451450
data: Optional override for the experiment data used to generate candidates;
452451
if not specified, will use ``experiment.lookup_data()`` (extracted in
453452
``Adapter``).
454-
arms_per_node: A manual override for users interacting with a gen. strategy
455-
via a Python API; a mapping from node name to the specific number of
456-
arms it should produce. Passed down here by `GenerationStrategy.gen`.
457453
gs_gen_kwargs: Keyword arguments, passed to ``GenerationStrategy.gen``.
458454
These might be modified by this node's input constructors, before
459455
being passed down to ``ModelSpec.gen``, where these will override any
@@ -485,13 +481,6 @@ def gen(
485481
logger.debug(f"Skipping generation for node {self.name}.")
486482
return None
487483

488-
if arms_per_node:
489-
if self.name not in arms_per_node:
490-
raise UnsupportedError(
491-
"If manually specifying arms per node, all nodes must be specified."
492-
)
493-
generator_gen_kwargs["n"] = arms_per_node[self.name]
494-
495484
# TODO[drfreund]: Move this to `Adapter` or another more suitable place.
496485
# Keeping here for now to limit the scope of the current changeset.
497486
generator_gen_kwargs["fixed_features"] = (

ax/generation_strategy/generation_node_input_constructors.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,10 @@ def _get_default_n(experiment: Experiment, next_node: GenerationNode) -> int:
257257
The default number of arms to generate from the next node, used if no n is
258258
provided to the ``GenerationStrategy``'s gen call.
259259
"""
260-
# If the generator spec contains `n` use that value first
261-
# TODO #1 [drfreund, mgarrard]: Eliminate the need to do this; the order should be:
262-
# `arms_per_node[node_name]` > `input_constuctors(n)` > `gen_spec...kwargs["n"]`
263-
# NOTE: We might need to simply disallow `n` in `gen_spec...kwargs`: it should
264-
# probably never be hardcoded there. Without it, we can just enforce that at a
265-
# point within generation strategy, an `n` is passed down to `gen_spec.gen`.
266-
# And if we keep it, we don't have a clear point in this stack at which we are
267-
# "no longer allowed to have a null `n`."
260+
# If the generator spec contains `n` use that value first.
261+
# TODO [drfreund, mgarrard]: Consider disallowing `n` in `gen_spec...kwargs`:
262+
# it should probably never be hardcoded there. This would enforce that `n`
263+
# is always passed down through the generation strategy at runtime.
268264
if next_node.generator_spec_to_gen_from.generator_gen_kwargs.get("n") is not None:
269265
return next_node.generator_spec_to_gen_from.generator_gen_kwargs["n"]
270266

ax/generation_strategy/generation_strategy.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,7 @@
1919
from ax.core.generator_run import GeneratorRun
2020
from ax.core.observation import ObservationFeatures
2121
from ax.core.utils import extend_pending_observations, extract_pending_observations
22-
from ax.exceptions.core import (
23-
AxError,
24-
DataRequiredError,
25-
UnsupportedError,
26-
UserInputError,
27-
)
22+
from ax.exceptions.core import AxError, DataRequiredError, UnsupportedError
2823
from ax.exceptions.generation_strategy import (
2924
GenerationStrategyCompleted,
3025
GenerationStrategyMisconfiguredException,
@@ -247,7 +242,6 @@ def gen(
247242
n: int | None = None,
248243
fixed_features: ObservationFeatures | None = None,
249244
num_trials: int = 1,
250-
arms_per_node: dict[str, int] | None = None,
251245
) -> list[list[GeneratorRun]]:
252246
"""Produce GeneratorRuns for multiple trials at once with the possibility of
253247
using multiple models per trial, getting multiple GeneratorRuns per trial.
@@ -275,12 +269,6 @@ def gen(
275269
important to specify all necessary fixed features.
276270
num_trials: Number of trials to generate generator runs for in this call.
277271
If not provided, defaults to 1.
278-
arms_per_node: An optional map from node name to the number of arms to
279-
generate from that node. If not provided, will default to the number
280-
of arms specified in the node's ``InputConstructors`` or n if no
281-
``InputConstructors`` are defined on the node. We expect either n or
282-
arms_per_node to be provided, but not both, and this is an advanced
283-
argument that should only be used by advanced users.
284272
285273
Returns:
286274
A list of lists of lists generator runs. Each outer list represents
@@ -306,7 +294,6 @@ def gen(
306294
data=data,
307295
n=n,
308296
pending_observations=pending_observations,
309-
arms_per_node=arms_per_node,
310297
fixed_features=fixed_features,
311298
first_generation_in_multi=len(grs_for_multiple_trials) < 1,
312299
)
@@ -467,24 +454,6 @@ def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None:
467454

468455
self._curr = nodes[0]
469456

470-
def _validate_arms_per_node(self, arms_per_node: dict[str, int] | None) -> None:
471-
"""Validate that the arms_per_node argument is valid if it is provided.
472-
473-
Args:
474-
arms_per_node: A map from node name to the number of arms to
475-
generate from that node.
476-
"""
477-
if arms_per_node is not None and not set(self.nodes_by_name).issubset(
478-
arms_per_node
479-
):
480-
raise UserInputError(
481-
"Each node defined in the `GenerationStrategy` must have an "
482-
"associated number of arms to generate from that node defined "
483-
f"in `arms_per_node`. {arms_per_node} does not include all of "
484-
f"{self.nodes_by_name.keys()}. "
485-
"It may help to double-check the spelling."
486-
)
487-
488457
def _make_default_name(self) -> str:
489458
"""Make a default name for this generation strategy; used when no name is passed
490459
to the constructor. For node-based generation strategies, the name is
@@ -515,10 +484,6 @@ def _gen_with_multiple_nodes(
515484
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
516485
data: Data | None = None,
517486
fixed_features: ObservationFeatures | None = None,
518-
# TODO: Consider naming `arms_per_node` smtg like `arms_per_node_override`,
519-
# to convey its manually-specified nature (if it's not specified, GS selects
520-
# what to do on its own).
521-
arms_per_node: dict[str, int] | None = None,
522487
first_generation_in_multi: bool = True,
523488
) -> list[GeneratorRun]:
524489
"""Produces a List of GeneratorRuns for a single trial, either ``Trial`` or
@@ -548,12 +513,6 @@ def _gen_with_multiple_nodes(
548513
passed down to the underlying nodes. Note: if provided this will
549514
override any algorithmically determined fixed features so it is
550515
important to specify all necessary fixed features.
551-
arms_per_node: An optional map from node name to the number of arms to
552-
generate from that node. If not provided, will default to the number
553-
of arms specified in the node's ``InputConstructors`` or n if no
554-
``InputConstructors`` are defined on the node. We expect either n or
555-
arms_per_node to be provided, but not both, and this is an advanced
556-
argument that should only be used by advanced users.
557516
558517
Returns:
559518
A list of ``GeneratorRuns`` for a single trial.
@@ -570,7 +529,6 @@ def _gen_with_multiple_nodes(
570529
pending_observations if pending_observations is not None else {}
571530
)
572531
self.experiment = experiment
573-
self._validate_arms_per_node(arms_per_node=arms_per_node)
574532
pack_gs_gen_kwargs = {
575533
"grs_this_gen": grs_this_gen,
576534
"fixed_features": fixed_features,
@@ -596,7 +554,6 @@ def _gen_with_multiple_nodes(
596554
pending_observations=pending_observations,
597555
skip_fit=not (first_generation_in_multi or transitioned),
598556
n=n,
599-
arms_per_node=arms_per_node,
600557
**pack_gs_gen_kwargs,
601558
)
602559
except DataRequiredError as err:

ax/generation_strategy/tests/test_generation_strategy.py

Lines changed: 6 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1448,17 +1448,12 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None:
14481448
]
14491449
)
14501450
gs.experiment = exp
1451-
arms_per_node = {
1452-
"sobol_1": 2,
1453-
"sobol_2": 1,
1454-
"sobol_3": 3,
1455-
}
14561451
with mock_patch_method_original(
14571452
mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen",
14581453
original_method=GeneratorSpec.gen,
14591454
) as gen_spec_gen_mock:
14601455
# Generate a trial that should be composed of arms from 3 nodes
1461-
grs = gs.gen(experiment=exp, arms_per_node=arms_per_node)[0]
1456+
grs = gs.gen(experiment=exp, n=6)[0]
14621457
self.assertEqual(len(grs), 3) # len == 3 due to 3 nodes contributing
14631458
self.assertEqual(gen_spec_gen_mock.call_count, 3)
14641459
pending_in_each_gen = enumerate(
@@ -1491,15 +1486,17 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None:
14911486
# check that we can pass in pending points
14921487
grs = gs.gen(
14931488
experiment=exp,
1494-
arms_per_node=arms_per_node,
1489+
n=3,
14951490
)[0]
14961491
self.assertEqual(len(grs), 3) # len == 3 due to 3 nodes contributing
14971492
pending_in_each_gen = enumerate(
14981493
call_kwargs.get("pending_observations")
14991494
for _, call_kwargs in gen_spec_gen_mock.call_args_list
15001495
)
1501-
# check pending points is now 12 (from the previous trial having 6 arms)
1502-
self.assertEqual(len(list(pending_in_each_gen)[0][1]["m1"]), 12)
1496+
# check pending points is now 27 (18 arms from previous trial with 3 nodes
1497+
# each generating n=6 arms, plus 9 arms from the new generation with 3 nodes
1498+
# each generating n=3 arms)
1499+
self.assertEqual(len(list(pending_in_each_gen)[0][1]["m1"]), 27)
15031500

15041501
def test_gs_initializes_default_props_correctly(self) -> None:
15051502
"""Test that all previous nodes are initialized to None"""
@@ -1707,68 +1704,6 @@ def test_transition_edges(self) -> None:
17071704
},
17081705
)
17091706

1710-
def test_multiple_arms_per_node(self) -> None:
1711-
"""Test that a ``GenerationStrategy`` which expects some trials to be composed
1712-
of multiple nodes can generate multiple arms per node using `arms_per_node`.
1713-
"""
1714-
exp = get_branin_experiment()
1715-
gs = self.complex_multinode_per_trial_gs
1716-
gs.experiment = exp
1717-
# first check that arms_per node validation works
1718-
arms_per_node = {
1719-
"sobol": 3,
1720-
"sobol_2": 2,
1721-
"sobol_3": 1,
1722-
"sobol_4": 4,
1723-
}
1724-
with self.assertRaisesRegex(UserInputError, "defined in `arms_per_node`"):
1725-
gs.gen(exp, arms_per_node=arms_per_node)
1726-
1727-
# now we will check that the first trial contains 3 arms, the second trial
1728-
# contains 6 arms (2 from mbm, 1 from sobol_2, 3 from sobol_3), and all
1729-
# remaining trials contain 4 arms
1730-
arms_per_node = {
1731-
"sobol": 3,
1732-
"mbm": 1,
1733-
"sobol_2": 2,
1734-
"sobol_3": 3,
1735-
"sobol_4": 4,
1736-
}
1737-
# for the first trial, we start on sobol, we generate the trial, but it hasn't
1738-
# been run yet, so we remain on sobol
1739-
trial0 = exp.new_batch_trial(
1740-
generator_runs=gs.gen(exp, arms_per_node=arms_per_node)[0]
1741-
)
1742-
self.assertEqual(len(trial0.arms_by_name), 3)
1743-
self.assertEqual(trial0.generator_runs[0]._generation_node_name, "sobol")
1744-
trial0.run()
1745-
1746-
# after trial 0 is run, we create a trial with nodes mbm, sobol_2, and sobol_3
1747-
# However, the sobol_3 criterion requires that we have two running trials. We
1748-
# don't move onto sobol_4 until we have two running trials, instead we reset
1749-
# to the last first node in a trial.
1750-
for _i in range(0, 2):
1751-
trial = exp.new_batch_trial(
1752-
generator_runs=gs.gen(exp, arms_per_node=arms_per_node)[0]
1753-
)
1754-
self.assertEqual(gs.current_node_name, "sobol_3")
1755-
self.assertEqual(len(trial.arms_by_name), 6)
1756-
self.assertEqual(len(trial.generator_runs), 3)
1757-
self.assertEqual(trial.generator_runs[0]._generation_node_name, "mbm")
1758-
self.assertEqual(len(trial.generator_runs[0].arms), 1)
1759-
self.assertEqual(trial.generator_runs[1]._generation_node_name, "sobol_2")
1760-
self.assertEqual(len(trial.generator_runs[1].arms), 2)
1761-
self.assertEqual(trial.generator_runs[2]._generation_node_name, "sobol_3")
1762-
self.assertEqual(len(trial.generator_runs[2].arms), 3)
1763-
1764-
# after running the next trial should be made from sobol 4
1765-
trial.run()
1766-
trial = exp.new_batch_trial(
1767-
generator_runs=gs.gen(exp, arms_per_node=arms_per_node)[0]
1768-
)
1769-
self.assertEqual(trial.generator_runs[0]._generation_node_name, "sobol_4")
1770-
self.assertEqual(len(trial.generator_runs[0].arms), 4)
1771-
17721707
def test_gen_with_multiple_uses_total_concurrent_arms_for_a_default(self) -> None:
17731708
exp = get_branin_experiment()
17741709
self.sobol_node._input_constructors = {

0 commit comments

Comments
 (0)