@@ -117,6 +117,46 @@ def __init__(self, arg_id: str, message: str):
117117 super ().__init__ (self .message )
118118
119119
120+ # This is included with argparse in Python 3.9+ and above, but we're also
121+ # supporting 3.8 so the action is inlined here
122+ class BooleanOptionalAction (argparse .Action ):
123+ def __init__ (
124+ self ,
125+ option_strings ,
126+ dest ,
127+ default = None ,
128+ type = None ,
129+ choices = None ,
130+ required = False ,
131+ help = None ,
132+ metavar = None ,
133+ ):
134+ _option_strings = []
135+ for option_string in option_strings :
136+ _option_strings .append (option_string )
137+ if option_string .startswith ("--" ):
138+ option_string = "--no-" + option_string [2 :]
139+ _option_strings .append (option_string )
140+ super ().__init__ (
141+ option_strings = _option_strings ,
142+ dest = dest ,
143+ nargs = 0 ,
144+ default = default ,
145+ type = type ,
146+ choices = choices ,
147+ required = required ,
148+ help = help ,
149+ metavar = metavar ,
150+ )
151+
152+ def __call__ (self , parser , namespace , values , option_string = None ):
153+ if option_string is not None and option_string in self .option_strings :
154+ setattr (namespace , self .dest , not option_string .startswith ("--no-" ))
155+
156+ def format_usage (self ) -> str :
157+ return " | " .join (self .option_strings )
158+
159+
120160@dataclass
121161class Arg :
122162 """Field for defining the CLI argument in the decorator."""
@@ -306,10 +346,8 @@ def get_arg(
306346 f"boolean arguments need to be flags, e.g. --{ arg .id .replace ('_' , '-' )} " ,
307347 )
308348 arg .type = None
309- if default is True :
310- raise InvalidArgumentError (arg .id , "boolean flags need to default to False" )
311- arg .default = False
312- arg .action = "store_true"
349+ arg .default = False if default is not True else True
350+ arg .action = "store_true" if arg .default is False else BooleanOptionalAction
313351 return arg
314352 if inspect .isclass (param_type ) and issubclass (param_type , Enum ):
315353 arg .choices = list (param_type .__members__ .keys ())
0 commit comments