Skip to content

Commit 2ea3fd1

Browse files
SebastianAmentmeta-codesync[bot]
authored andcommitted
Updating SearchSpace._validate_derived_parameter to allow string and Boolean types for simple copies (#4851)
Summary: Updating `SearchSpace._validate_derived_parameter` in light of the recent changes to `DerivedParameter`, enabling string and Boolean types for simple copy expressions. This commit will resolve the remaining error in https://www.internalfb.com/ax/experiment/derived_parameter_string_bool/trials. Pull Request resolved: #4851 Reviewed By: sunnyshen321 Differential Revision: D92164696 fbshipit-source-id: ebd8bdf74e76d77ff6d3ab8daa1778531ff29e63
1 parent b967b03 commit 2ea3fd1

File tree

2 files changed

+99
-6
lines changed

2 files changed

+99
-6
lines changed

ax/core/search_space.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -896,24 +896,49 @@ def validate_membership(self, parameters: TParameterization) -> None:
896896

897897
def _validate_derived_parameter(self, parameter: DerivedParameter) -> None:
898898
is_int = parameter.parameter_type == ParameterType.INT
899+
is_simple_copy = parameter._is_simple_copy
900+
derived_is_numeric = parameter.parameter_type in (
901+
ParameterType.INT,
902+
ParameterType.FLOAT,
903+
)
904+
899905
for p_name in parameter.parameter_names_to_weights.keys():
900906
p = self._parameters.get(p_name)
901907
if p is None:
902908
raise ValueError(
903909
f"Parameter {p_name} is not in the search space, but is used in a "
904910
"derived parameter."
905911
)
906-
if not p.is_numeric:
912+
913+
# For arithmetic expressions, source must be numeric
914+
if not is_simple_copy and not p.is_numeric:
907915
raise ValueError(
908916
f"Parameter {p_name} is not a float or int, but is used in a "
909-
"derived parameter."
917+
"derived parameter whose expression is not a simple copy."
918+
)
919+
920+
# For simple copies, validate type compatibility
921+
# Valid: exact type match OR both numeric (int can promote to float)
922+
if is_simple_copy:
923+
types_compatible = parameter.parameter_type == p.parameter_type or (
924+
derived_is_numeric and p.is_numeric
910925
)
911-
elif is_int and p.parameter_type == ParameterType.FLOAT:
926+
if not types_compatible:
927+
raise ValueError(
928+
f"Parameter {p_name} has type {p.parameter_type.name}, but the "
929+
f"derived parameter has type {parameter.parameter_type.name}. "
930+
"Simple copy derived parameters must have the same type as "
931+
"their source parameter."
932+
)
933+
934+
# Float source cannot be used with Int derived parameter
935+
if is_int and p.parameter_type == ParameterType.FLOAT:
912936
raise ValueError(
913937
f"Parameter {p_name} is a float, but is used in a derived "
914938
"parameter with int type."
915939
)
916-
elif isinstance(p, DerivedParameter):
940+
941+
if isinstance(p, DerivedParameter):
917942
raise ValueError(
918943
"Parameter cannot be derived from another derived parameter."
919944
)

ax/core/tests/test_search_space.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,15 +578,42 @@ def test_validate_derived_parameter(self) -> None:
578578
):
579579
self.ss1._validate_derived_parameter(parameter=self.invalid_derived_param)
580580

581-
# test with non-numeric param
581+
# test with non-numeric param used in arithmetic expression
582582
derived_param = DerivedParameter(
583-
name="z", parameter_type=ParameterType.FLOAT, expression_str="c"
583+
name="z", parameter_type=ParameterType.FLOAT, expression_str="2.0 * c"
584584
)
585585
with self.assertRaisesRegex(
586586
ValueError,
587587
"Parameter c is not a float or int, but is used in a derived parameter.",
588588
):
589589
self.ss1._validate_derived_parameter(parameter=derived_param)
590+
591+
# test simple copy type incompatibility: numeric derived from non-numeric source
592+
# tests the unified type compatibility rule: types must match OR both numeric
593+
derived_param = DerivedParameter(
594+
name="z", parameter_type=ParameterType.FLOAT, expression_str="c"
595+
)
596+
with self.assertRaisesRegex(
597+
ValueError,
598+
"Parameter c has type STRING, but the derived parameter has type FLOAT. "
599+
"Simple copy derived parameters must have the same type as their source "
600+
"parameter.",
601+
):
602+
self.ss1._validate_derived_parameter(parameter=derived_param)
603+
604+
# test simple copy type incompatibility: non-numeric derived from numeric source
605+
# same validation rule as above, different type combination
606+
derived_param = DerivedParameter(
607+
name="z", parameter_type=ParameterType.STRING, expression_str="a"
608+
)
609+
with self.assertRaisesRegex(
610+
ValueError,
611+
"Parameter a has type FLOAT, but the derived parameter has type STRING. "
612+
"Simple copy derived parameters must have the same type as their source "
613+
"parameter.",
614+
):
615+
self.ss1._validate_derived_parameter(parameter=derived_param)
616+
590617
# test int derived param with float constituent param
591618
derived_param = DerivedParameter(
592619
name="z", parameter_type=ParameterType.INT, expression_str="a"
@@ -597,6 +624,16 @@ def test_validate_derived_parameter(self) -> None:
597624
):
598625
self.ss1._validate_derived_parameter(parameter=derived_param)
599626

627+
# test int derived param with float constituent param (arithmetic expression)
628+
derived_param = DerivedParameter(
629+
name="z", parameter_type=ParameterType.INT, expression_str="2.0 * a"
630+
)
631+
with self.assertRaisesRegex(
632+
ValueError,
633+
"Parameter a is a float, but is used in a derived parameter with int type.",
634+
):
635+
self.ss1._validate_derived_parameter(parameter=derived_param)
636+
600637
# test derived param with constituent derived param
601638
derived_param = DerivedParameter(
602639
name="z", parameter_type=ParameterType.FLOAT, expression_str="h"
@@ -621,6 +658,37 @@ def test_validate_derived_parameter(self) -> None:
621658
):
622659
self.ss1._validate_derived_parameter(parameter=derived_param)
623660

661+
# test simple copy of STRING parameter - should succeed
662+
string_derived_param = DerivedParameter(
663+
name="derived_c", parameter_type=ParameterType.STRING, expression_str="c"
664+
)
665+
# This should NOT raise - it's a valid simple copy
666+
self.ss1._validate_derived_parameter(parameter=string_derived_param)
667+
668+
# test simple copy of BOOL parameter - should succeed
669+
# Add a non-fixed BOOL parameter to the search space
670+
bool_choice_param = ChoiceParameter(
671+
name="bool_choice", parameter_type=ParameterType.BOOL, values=[True, False]
672+
)
673+
self.ss1.add_parameter(bool_choice_param)
674+
bool_derived_param = DerivedParameter(
675+
name="derived_bool",
676+
parameter_type=ParameterType.BOOL,
677+
expression_str="bool_choice",
678+
)
679+
# This should NOT raise - it's a valid simple copy
680+
self.ss1._validate_derived_parameter(parameter=bool_derived_param)
681+
682+
# test simple copy INT to FLOAT promotion - should succeed
683+
# INT can be promoted to FLOAT (e.g., 3 -> 3.0)
684+
int_to_float_derived_param = DerivedParameter(
685+
name="derived_f_as_float",
686+
parameter_type=ParameterType.FLOAT,
687+
expression_str="f", # f is an INT parameter
688+
)
689+
# This should NOT raise - INT can be promoted to FLOAT
690+
self.ss1._validate_derived_parameter(parameter=int_to_float_derived_param)
691+
624692
def test_get_overlapping_parameters(self) -> None:
625693
with self.subTest("full_overlap"):
626694
range_param_1 = RangeParameter(

0 commit comments

Comments
 (0)