Skip to content

Commit 12c53d3

Browse files
mgarrardfacebook-github-bot
authored andcommitted
Add caching to common methods (#4830)
Summary: This method is called many, many times during generation and it's computational cost adds up over time. By cacheing it we can significant improvements in computation time, especially in high trial count regimes. Differential Revision: D91552553
1 parent 463aa35 commit 12c53d3

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

ax/generation_strategy/generation_node.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,10 @@ def __init__(
190190
self.fallback_specs = (
191191
fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK
192192
)
193+
# Cache for trials_from_node property to avoid recomputation
194+
# on every access. Invalidated when trial count changes.
195+
self._trials_from_node_cache: set[int] | None = None
196+
self._cached_trial_count: int = -1
193197

194198
@property
195199
def name(self) -> str:
@@ -724,17 +728,28 @@ def _pick_fitted_adapter_to_gen_from(self) -> GeneratorSpec:
724728
def trials_from_node(self) -> set[int]:
725729
"""Returns a set containing the indices of trials generated by this node.
726730
731+
Results are cached and invalidated when the experiment's trial count changes.
732+
727733
Returns:
728734
Set[int]: A set containing all the indices of trials generated by this node.
729735
"""
736+
current_trial_count = len(self.experiment.trials)
737+
if (
738+
self._trials_from_node_cache is not None
739+
and self._cached_trial_count == current_trial_count
740+
):
741+
return self._trials_from_node_cache
742+
743+
# (re)-build cache
730744
trials_from_node = set()
731-
for _idx, trial in self.experiment.trials.items():
745+
for trial in self.experiment.trials.values():
732746
for gr in trial.generator_runs:
733-
if (
734-
gr._generation_node_name is not None
735-
and gr._generation_node_name == self.name
736-
):
747+
if gr._generation_node_name == self.name:
737748
trials_from_node.add(trial.index)
749+
break
750+
751+
self._trials_from_node_cache = trials_from_node
752+
self._cached_trial_count = current_trial_count
738753
return trials_from_node
739754

740755
@property

ax/generation_strategy/generation_strategy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,10 @@ def _unset_non_persistent_state_fields(self) -> None:
361361
n._step_index = None
362362
if len(n.generator_specs) > 1:
363363
n._generator_spec_to_gen_from = None
364+
# Reset cache fields that are used for performance optimization only
365+
# and should not affect equality comparisons.
366+
n._trials_from_node_cache = None
367+
n._cached_trial_count = -1
364368

365369
# TODO: Deprecate `steps` argument fully in Q1'26.
366370
def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None:

0 commit comments

Comments
 (0)