@@ -52,36 +52,42 @@ def setup():
5252 ),
5353 inquirer .List ('backend' ,
5454 message = "LLM Backend" ,
55- choices = ["OpenAI" , "Azure OpenAI" ],
55+ choices = ["OpenAI" , "Azure OpenAI" , "Mistral" ],
5656 default = "OpenAI"
5757 ),
5858 inquirer .List ("auth_type" ,
5959 message = "Authentication Type" ,
6060 choices = ["API Key" , "Native Authentication (DefaultAzureCredential)" ],
6161 default = "API Key" ,
62- ignore = lambda answers : answers ['backend' ] == "OpenAI"
62+ ignore = lambda answers : answers ['backend' ] in [ "OpenAI" , "Mistral" ]
6363 ),
6464 inquirer .Text ('api_key' ,
65- message = 'OpenAI API Key' ,
65+ message = 'API Key' ,
6666 ignore = lambda answers : answers ['auth_type' ] == "Native Authentication (DefaultAzureCredential)"
6767 ),
6868 inquirer .List ("openai_model_name" ,
6969 message = "Model Name" ,
7070 choices = ["gpt-3.5-turbo" , "gpt-4-turbo-preview" ],
7171 default = "gpt-4-turbo-preview" ,
72- ignore = lambda answers : answers ['backend' ] == "Azure OpenAI"
72+ ignore = lambda answers : answers ['backend' ] != "OpenAI"
73+ ),
74+ inquirer .List ("mistral_model_name" ,
75+ message = "Model Name" ,
76+ choices = ["mistral-large-latest" ],
77+ default = "mistral-large-latest" ,
78+ ignore = lambda answers : answers ['backend' ] != "Mistral"
7379 ),
7480 inquirer .Text ("azure_endpoint" ,
7581 message = "Azure OpenAI Endpoint" ,
76- ignore = lambda answers : answers ['backend' ] == " OpenAI"
82+ ignore = lambda answers : answers ['backend' ] != "Azure OpenAI"
7783 ),
7884 inquirer .Text ("azure_openai_model" ,
7985 message = "Azure OpenAI Model" ,
80- ignore = lambda answers : answers ['backend' ] == " OpenAI"
86+ ignore = lambda answers : answers ['backend' ] != "Azure OpenAI"
8187 ),
8288 inquirer .Text ("azure_openai_deployment" ,
8389 message = "Azure OpenAI Deployment" ,
84- ignore = lambda answers : answers ['backend' ] == " OpenAI"
90+ ignore = lambda answers : answers ['backend' ] != "Azure OpenAI"
8591 ),
8692 ]
8793 answer = inquirer .prompt (question )
@@ -147,8 +153,6 @@ def gen(model_name, description, input):
147153@click .option ("--diff" , "-d" , is_flag = True , help = "Show the diff between existing and suggested code" , default = False )
148154def fix (model_name , description , diff ):
149155 manifest = Manifest ()
150- click .echo (model_name )
151- click .echo (description )
152156
153157 model = manifest .fix (model_name , description )
154158
@@ -211,4 +215,14 @@ def hello():
211215\____ | |___ /__| \____|__ /___|
212216 \/ \/ \/
213217"""
214- click .echo (greeting )
218+ click .echo (greeting )
219+
220+
221+ @dbtai .command (help = "Generate a dbt test" )
222+ @click .argument ("model" , required = True )
223+ @click .argument ("description" , required = True )
224+ def test (model , description ):
225+ raise NotImplementedError ("Not yet implemented" )
226+ manifest = Manifest ()
227+ test = manifest .generate_test (model , description )
228+ click .echo (test )
0 commit comments