Skip to content

Commit c65e241

Browse files
committed
Mistral backend option
1 parent 9f7eadd commit c65e241

File tree

3 files changed

+52
-20
lines changed

3 files changed

+52
-20
lines changed

dbtai/cli.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,36 +52,42 @@ def setup():
5252
),
5353
inquirer.List('backend',
5454
message ="LLM Backend",
55-
choices = ["OpenAI", "Azure OpenAI"],
55+
choices = ["OpenAI", "Azure OpenAI", "Mistral"],
5656
default = "OpenAI"
5757
),
5858
inquirer.List("auth_type",
5959
message = "Authentication Type",
6060
choices = ["API Key", "Native Authentication (DefaultAzureCredential)"],
6161
default = "API Key",
62-
ignore = lambda answers: answers['backend'] == "OpenAI"
62+
ignore = lambda answers: answers['backend'] in ["OpenAI", "Mistral"]
6363
),
6464
inquirer.Text('api_key',
65-
message='OpenAI API Key',
65+
message='API Key',
6666
ignore = lambda answers: answers['auth_type'] == "Native Authentication (DefaultAzureCredential)"
6767
),
6868
inquirer.List("openai_model_name",
6969
message = "Model Name",
7070
choices = ["gpt-3.5-turbo", "gpt-4-turbo-preview"],
7171
default = "gpt-4-turbo-preview",
72-
ignore = lambda answers: answers['backend'] == "Azure OpenAI"
72+
ignore = lambda answers: answers['backend'] != "OpenAI"
73+
),
74+
inquirer.List("mistral_model_name",
75+
message = "Model Name",
76+
choices = ["mistral-large-latest"],
77+
default = "mistral-large-latest",
78+
ignore = lambda answers: answers['backend'] != "Mistral"
7379
),
7480
inquirer.Text("azure_endpoint",
7581
message = "Azure OpenAI Endpoint",
76-
ignore = lambda answers: answers['backend'] == "OpenAI"
82+
ignore = lambda answers: answers['backend'] != "Azure OpenAI"
7783
),
7884
inquirer.Text("azure_openai_model",
7985
message = "Azure OpenAI Model",
80-
ignore = lambda answers: answers['backend'] == "OpenAI"
86+
ignore = lambda answers: answers['backend'] != "Azure OpenAI"
8187
),
8288
inquirer.Text("azure_openai_deployment",
8389
message = "Azure OpenAI Deployment",
84-
ignore = lambda answers: answers['backend'] == "OpenAI"
90+
ignore = lambda answers: answers['backend'] != "Azure OpenAI"
8591
),
8692
]
8793
answer = inquirer.prompt(question)
@@ -147,8 +153,6 @@ def gen(model_name, description, input):
147153
@click.option("--diff", "-d", is_flag=True, help="Show the diff between existing and suggested code", default=False)
148154
def fix(model_name, description, diff):
149155
manifest = Manifest()
150-
click.echo(model_name)
151-
click.echo(description)
152156

153157
model = manifest.fix(model_name, description)
154158

@@ -211,4 +215,14 @@ def hello():
211215
\____ | |___ /__| \____|__ /___|
212216
\/ \/ \/
213217
"""
214-
click.echo(greeting)
218+
click.echo(greeting)
219+
220+
221+
@dbtai.command(help="Generate a dbt test")
222+
@click.argument("model", required=True)
223+
@click.argument("description", required=True)
224+
def test(model, description):
225+
raise NotImplementedError("Not yet implemented")
226+
manifest = Manifest()
227+
test = manifest.generate_test(model, description)
228+
click.echo(test)

dbtai/manifest.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import appdirs
1212
import yaml
1313
from openai import OpenAI
14+
from mistralai.client import MistralClient
1415
from ruamel.yaml import YAML
1516
from ruamel.yaml.comments import CommentedMap
1617
import io
@@ -40,21 +41,26 @@ def __init__(
4041
self.manifest = json.load(file)
4142

4243
self.config = self._load_config()
43-
self.client = self._make_openai_client()
44-
44+
if self.config['backend'] == "Mistral":
45+
self.client = self._make_mistral_client()
46+
elif self.config['backend'] == "Azure OpenAI":
47+
raise NotImplementedError("Azure OpenAI not yet implemented")
48+
else:
49+
self.client = self._make_openai_client()
4550
def _make_openai_client(self):
4651
"""Make the OpenAI client with auth."""
52+
4753
if self.config['backend'] == "OpenAI":
4854
api_key = self.config['api_key'] or os.getenv("OPENAI_API_KEY")
4955
return OpenAI(api_key=api_key)
5056
else:
5157
raise NotImplementedError("Azure OpenAI not yet implemented")
5258

53-
# return OpenAI(
54-
# endpoint=self.config['azure_endpoint'],
55-
# model=self.config['azure_openai_model'],
56-
# deployment=self.config['azure_openai_deployment']
57-
# )
59+
def _make_mistral_client(self):
60+
"""Make the Mistral client with auth."""
61+
client = MistralClient(api_key=self.config['api_key'])
62+
return client
63+
5864

5965
def chat_completion(self, messages, response_format_type="json_object"):
6066
"""Convenience method to call the chat completion endpoint.
@@ -67,13 +73,24 @@ def chat_completion(self, messages, response_format_type="json_object"):
6773
openai.ChatCompletion: The response from the chat API
6874
"""
6975
if self.config["backend"] == "OpenAI":
76+
if not self.config.get("openai_model_name"):
77+
raise ValueError("OpenAI model name not set in config")
78+
7079
return self.client.chat.completions.create(
71-
model=self.config['openai_model_name'],
80+
model=self.config.get('openai_model_name', 'gpt-4-turbo-preview'),
7281
messages=messages,
7382
response_format={"type": response_format_type}
7483
)
84+
elif self.config["backend"] == "Mistral":
85+
86+
return self.client.chat(
87+
model=self.config.get("mistral_model_name", "mistral-large-latest"),
88+
messages=messages,
89+
response_format={"type": response_format_type},
90+
)
91+
7592
else:
76-
raise NotImplementedError("Azure OpenAI not yet implemented")
93+
raise NotImplementedError("Your backend is set to Azure OpenAI not yet implemented")
7794

7895
def _load_config(self):
7996
"""Convenience function to load the user config from the config file."""

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "dbtai"
3-
version = "0.1.0"
3+
version = "0.2.0"
44
description = "`dbtai` is a utility CLI command to generate dbt model documentation for a given model using OpenAI."
55
authors = ["Henning Holgersen"]
66
keywords = [
@@ -13,6 +13,7 @@ license = "Apache 2.0"
1313
[tool.poetry.dependencies]
1414
python = "<3.14,>=3.8.0"
1515
openai = ">1.1.0"
16+
mistralai = ">=0.1.3,<2"
1617
click = "^8.1.3"
1718
"ruamel.yaml" = "^0.18.6"
1819
inquirer = "^3.2.4"

0 commit comments

Comments
 (0)