2424#
2525###############################################################################
2626import argparse
27- from typing import Optional , Type
27+ from typing import Literal , Optional , Type , get_args , get_origin
2828
2929from pydantic import BaseModel
3030
@@ -96,19 +96,50 @@ def get_model_arg(cls, type_class_map: dict) -> Optional[Type[BaseModel]]:
9696 None ,
9797 )
9898
99+ @classmethod
100+ def get_literal_choices (cls , type_class_map : dict ) -> Optional [list ]:
101+ """Get the choices from a Literal type if present
102+
103+ Args:
104+ type_class_map (dict): mapping of type classes
105+
106+ Returns:
107+ Optional[list]: list of valid choices for the Literal type, or None if not a Literal
108+ """
109+ # Check if Literal is in the type_class_map
110+ literal_type = type_class_map .get (Literal )
111+ if literal_type and literal_type .inner_type is not None :
112+ # The inner_type contains the first literal value, but we need all of them
113+ # We need to get the original annotation to extract all literal values
114+ # For now, return None and we'll handle this differently
115+ return None
116+ return None
117+
99118 def add_argument (
100119 self ,
101120 type_class_map : dict ,
102121 arg_name : str ,
103122 required : bool ,
123+ annotation : Optional [Type ] = None ,
104124 ) -> None :
105125 """Add an argument to a parser with an appropriate type
106126
107127 Args:
108128 type_class_map (dict): type classes for the arg
109129 arg_name (str): argument name
110130 required (bool): whether or not the arg is required
131+ annotation (Optional[Type]): full type annotation for extracting Literal choices
111132 """
133+ # Check for Literal types and extract choices
134+ literal_choices = None
135+ if Literal in type_class_map and annotation :
136+ # Extract all arguments from the annotation
137+ args = get_args (annotation )
138+ for arg in args :
139+ if get_origin (arg ) is Literal :
140+ literal_choices = list (get_args (arg ))
141+ break
142+
112143 if list in type_class_map :
113144 type_class = type_class_map [list ]
114145 self .parser .add_argument (
@@ -125,6 +156,15 @@ def add_argument(
125156 required = required ,
126157 choices = [True , False ],
127158 )
159+ elif Literal in type_class_map and literal_choices :
160+ # Add argument with choices for Literal types
161+ self .parser .add_argument (
162+ f"--{ arg_name } " ,
163+ type = str ,
164+ required = required ,
165+ choices = literal_choices ,
166+ metavar = f"{{{ ',' .join (literal_choices )} }}" ,
167+ )
128168 elif float in type_class_map :
129169 self .parser .add_argument (
130170 f"--{ arg_name } " , type = float , required = required , metavar = META_VAR_MAP [float ]
@@ -166,6 +206,10 @@ def build_model_arg_parser(self, model: type[BaseModel], required: bool) -> list
166206 if type (None ) in type_class_map and len (attr_data .type_classes ) == 1 :
167207 continue
168208
169- self .add_argument (type_class_map , attr .replace ("_" , "-" ), required )
209+ # Get the full annotation from the model field
210+ field = model .model_fields .get (attr )
211+ annotation = field .annotation if field else None
212+
213+ self .add_argument (type_class_map , attr .replace ("_" , "-" ), required , annotation )
170214
171215 return list (type_map .keys ())
0 commit comments