2929DATA_REQUIRED_MSG = (
3030 "All trials for current node {node_name} have been generated, "
3131 "but not enough data has been observed to proceed to the next "
32- "Generation node. Try again when more is are available."
32+ "Generation node. Try again when more data is available."
3333)
3434
3535
@@ -48,9 +48,7 @@ class TransitionCriterion(SortableBase):
4848 we will raise an error, otherwise we will continue to generate trials
4949 until ``MinTrials`` is met (thus overriding MinTrials).
5050 block_transition_if_unmet: A flag to prevent the node from completing and
51- being able to transition to another node. Ex: MaxGenerationParallelism
52- defaults to setting this to False since we can complete and move on from
53- this node without ever reaching its threshold.
51+ being able to transition to another node.
5452 continue_trial_generation: A flag to indicate that all generation for a given
5553 trial is not completed, and thus even after transition, the next node will
5654 continue to generate arms for the same trial. Example usage: in
@@ -234,9 +232,7 @@ class TrialBasedCriterion(TransitionCriterion):
234232 threshold: The threshold as an integer for this criterion. Ex: If we want to
235233 generate at most 3 trials, then the threshold is 3.
236234 block_transition_if_unmet: A flag to prevent the node from completing and
237- being able to transition to another node. Ex: MaxGenerationParallelism
238- defaults to setting this to False since we can complete and move on from
239- this node without ever reaching its threshold.
235+ being able to transition to another node.
240236 block_gen_if_met: A flag to prevent continued generation from the
241237 associated GenerationNode if this criterion is met but other criterion
242238 remain unmet. Ex: ``MinTrials`` has not been met yet, but
@@ -383,94 +379,18 @@ def is_met(
383379 )
384380
385381
386- class MaxGenerationParallelism (TrialBasedCriterion ):
387- """Specific TransitionCriterion implementation which defines the maximum number
388- of trials that can simultaneously be in the designated trial statuses. The
389- default behavior is to block generation from the associated GenerationNode if the
390- threshold is met. This is configured via the `block_gen_if_met` flag being set to
391- True. This criterion defaults to not blocking transition to another node via the
392- `block_transition_if_unmet` flag being set to False.
393-
394- Args:
395- threshold: The threshold as an integer for this criterion. Ex: If we want to
396- generate at most 3 trials, then the threshold is 3.
397- only_in_statuses: A list of trial statuses to filter on when checking the
398- criterion threshold.
399- not_in_statuses: A list of trial statuses to exclude when checking the
400- criterion threshold.
401- transition_to: The name of the GenerationNode the GenerationStrategy should
402- transition to when this criterion is met, if it exists.
403- block_transition_if_unmet: A flag to prevent the node from completing and
404- being able to transition to another node. Ex: MaxGenerationParallelism
405- defaults to setting this to False since we can complete and move on from
406- this node without ever reaching its threshold.
407- block_gen_if_met: A flag to prevent continued generation from the
408- associated GenerationNode if this criterion is met but other criterion
409- remain unmet. Ex: ``MinTrials`` has not been met yet, but
410- MinTrials has been reached. If this flag is set to true on MinTrials then
411- we will raise an error, otherwise we will continue to generate trials
412- until ``MinTrials`` is met (thus overriding MinTrials).
413- use_all_trials_in_exp: A flag to use all trials in the experiment, instead of
414- only those generated by the current GenerationNode.
415- continue_trial_generation: A flag to indicate that all generation for a given
416- trial is not completed, and thus even after transition, the next node will
417- continue to generate arms for the same trial. Example usage: in
418- ``BatchTrial``s we may enable generation of arms within a batch from
419- different ``GenerationNodes`` by setting this flag to True. Defaults to
420- False for MaxGenerationParallelism since this criterion isn't currently
421- used for node -> node or trial -> trial transition.
422- count_only_trials_with_data: If set to True, only trials with data will be
423- counted towards the ``threshold``. Defaults to False.
382+ class MinTrials (TrialBasedCriterion ):
424383 """
384+ Simple class to enforce a threshold for the number of trials with the
385+ designated statuses being generated by a specific GenerationNode.
425386
426- def __init__ (
427- self ,
428- threshold : int ,
429- transition_to : str ,
430- only_in_statuses : list [TrialStatus ] | None = None ,
431- not_in_statuses : list [TrialStatus ] | None = None ,
432- block_transition_if_unmet : bool | None = False ,
433- block_gen_if_met : bool | None = True ,
434- use_all_trials_in_exp : bool | None = False ,
435- continue_trial_generation : bool | None = False ,
436- ) -> None :
437- super ().__init__ (
438- threshold = threshold ,
439- only_in_statuses = only_in_statuses ,
440- not_in_statuses = not_in_statuses ,
441- transition_to = transition_to ,
442- block_gen_if_met = block_gen_if_met ,
443- block_transition_if_unmet = block_transition_if_unmet ,
444- use_all_trials_in_exp = use_all_trials_in_exp ,
445- continue_trial_generation = continue_trial_generation ,
446- )
447-
448- def block_continued_generation_error (
449- self ,
450- node_name : str ,
451- experiment : Experiment ,
452- trials_from_node : set [int ],
453- ) -> None :
454- """Raises the appropriate error (should only be called when the
455- ``GenerationNode`` is blocked from continued generation). For this
456- class, the exception is ``MaxParallelismReachedException``.
457- """
458- assert self .block_gen_if_met # Sanity check.
459- raise MaxParallelismReachedException (
460- node_name = node_name ,
461- num_running = self .num_contributing_to_threshold (
462- experiment = experiment , trials_from_node = trials_from_node
463- ),
464- )
387+ This class can be configured to behave as either:
388+ - A minimum trials criterion (default): blocks transition to next node until
389+ the threshold is met; block_transition_if_unmet=True
390+ - A maximum parallelism criterion: blocks further generation from the current node
391+ when threshold is reached; block_gen_if_met=True
465392
466393
467- class MinTrials (TrialBasedCriterion ):
468- """
469- Simple class to enforce a minimum threshold for the number of trials with the
470- designated statuses being generated by a specific GenerationNode. The default
471- behavior is to block transition to the next node if the threshold is unmet, but
472- not affect continued generation.
473-
474394 Args:
475395 threshold: The threshold as an integer for this criterion. Ex: If we want to
476396 generate at most 3 trials, then the threshold is 3.
@@ -481,15 +401,12 @@ class MinTrials(TrialBasedCriterion):
481401 transition_to: The name of the GenerationNode the GenerationStrategy should
482402 transition to when this criterion is met.
483403 block_transition_if_unmet: A flag to prevent the node from completing and
484- being able to transition to another node. Ex: MaxGenerationParallelism
485- defaults to setting this to False since we can complete and move on from
486- this node without ever reaching its threshold.
404+ being able to transition to another node. Defaults to True for minimum
405+ trials behavior. Set to False for maximum parallelism behavior.
487406 block_gen_if_met: A flag to prevent continued generation from the
488407 associated GenerationNode if this criterion is met but other criterion
489- remain unmet. Ex: ``MinTrials`` has not been met yet, but
490- MinTrials has been reached. If this flag is set to true on MinTrials then
491- we will raise an error, otherwise we will continue to generate trials
492- until ``MinTrials`` is met (thus overriding MinTrials).
408+ remain unmet. Defaults to False for minimum trials behavior. Set to True
409+ for maximum parallelism behavior.
493410 use_all_trials_in_exp: A flag to use all trials in the experiment, instead of
494411 only those generated by the current GenerationNode.
495412 continue_trial_generation: A flag to indicate that all generation for a given
@@ -531,12 +448,27 @@ def block_continued_generation_error(
531448 experiment : Experiment ,
532449 trials_from_node : set [int ],
533450 ) -> None :
534- """Raises the appropriate error (should only be called when the
535- ``GenerationNode`` is blocked from continued generation). For this
536- class, the exception is ``DataRequiredError``.
451+ """Raises the appropriate error when generation is blocked.
452+
453+ This method is called when block_gen_if_met=True and the criterion is met.
454+ The exception type depends on the criterion's behavior:
455+ - Max parallelism (block_transition_if_unmet=False):
456+ MaxParallelismReachedException
457+ - Enforce num_trials (block_transition_if_unmet=True): DataRequiredError
537458 """
538459 assert self .block_gen_if_met # Sanity check.
539- raise DataRequiredError (DATA_REQUIRED_MSG .format (node_name = node_name ))
460+ if not self .block_transition_if_unmet :
461+ # Max parallelism behavior: temporarily blocked, waiting for trials
462+ # to complete
463+ raise MaxParallelismReachedException (
464+ node_name = node_name ,
465+ num_running = self .num_contributing_to_threshold (
466+ experiment = experiment , trials_from_node = trials_from_node
467+ ),
468+ )
469+ else :
470+ # Enforce num_trials behavior: hard limit reached, need to transition
471+ raise DataRequiredError (DATA_REQUIRED_MSG .format (node_name = node_name ))
540472
541473
542474class AuxiliaryExperimentCheck (TransitionCriterion ):
@@ -571,9 +503,7 @@ class AuxiliaryExperimentCheck(TransitionCriterion):
571503 we will raise an error, otherwise we will continue to generate trials
572504 until ``MinTrials`` is met (thus overriding MinTrials).
573505 block_transition_if_unmet: A flag to prevent the node from completing and
574- being able to transition to another node. Ex: MaxGenerationParallelism
575- defaults to setting this to False since we can complete and move on from
576- this node without ever reaching its threshold.
506+ being able to transition to another node.
577507 continue_trial_generation: A flag to indicate that all generation for a given
578508 trial is not completed, and thus even after transition, the next node will
579509 continue to generate arms for the same trial. Example usage: in
0 commit comments