@@ -433,24 +433,9 @@ def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None:
433433 # Validate transition edges:
434434 # - All `transition_to` targets must exist in this GS
435435 # - All TCs on one edge must have the same `continue_trial_generation` setting
436- # All but `MaxGenerationParallelism` TCs must have a `transition_to` set
437436 for node in nodes :
438437 for next_node , tcs in node .transition_edges .items ():
439- if next_node is None :
440- # TODO[drfreund]: Handle the case of the last generation step not
441- # having any transition criteria.
442- # TODO[mgarrard]: Remove MaxGenerationParallelism check when
443- # we update TransitionCriterion always define `transition_to`
444- # NOTE: This is done in D86066476
445- for tc in tcs :
446- if "MaxGenerationParallelism" not in tc .criterion_class :
447- raise GenerationStrategyMisconfiguredException (
448- error_info = "Only MaxGenerationParallelism transition"
449- " criterion can have a null `transition_to` argument,"
450- f" but { tc .criterion_class } does not define "
451- f"`transition_to` on { node .name } ."
452- )
453- elif next_node not in node_names :
438+ if next_node not in node_names :
454439 raise GenerationStrategyMisconfiguredException (
455440 error_info = f"`transition_to` argument "
456441 f"{ next_node } does not correspond to any node in"
@@ -612,7 +597,6 @@ def _should_continue_gen_for_trial(self) -> bool:
612597 # if we will transition nodes, check if the transition criterion which define
613598 # the transition from this node to the next node indicate that we should
614599 # continue generating in the same trial, otherwise end the generation.
615- assert next_node is not None
616600 return all (
617601 tc .continue_trial_generation
618602 for tc in self ._curr .transition_edges [next_node ]
@@ -648,12 +632,5 @@ def _maybe_transition_to_next_node(
648632 f"Generation strategy { self } generated all the trials as "
649633 "specified in its nodes."
650634 )
651- if next_node is None :
652- # If the last node did not specify which node to transition to,
653- # move to the next node in the list.
654- current_node_index = self ._nodes .index (self ._curr )
655- next_node = self ._nodes [current_node_index + 1 ].name
656- for node in self ._nodes :
657- if node .name == next_node :
658- self ._curr = node
635+ self ._curr = self .nodes_by_name [next_node ]
659636 return move_to_next_node
0 commit comments