|
6 | 6 |
|
7 | 7 | import vertexai |
8 | 8 | import vertexai.preview.generative_models as generative_models |
| 9 | +from google.oauth2 import service_account |
9 | 10 | from vertexai.generative_models import FinishReason, GenerativeModel, Part |
10 | 11 |
|
11 | 12 | from llmebench.models.model_base import ModelBase |
@@ -53,50 +54,72 @@ class GeminiModel(ModelBase): |
53 | 54 | def __init__( |
54 | 55 | self, |
55 | 56 | project_id=None, |
56 | | - api_key=None, |
57 | 57 | model_name=None, |
| 58 | + location=None, |
| 59 | + credentials_path=None, # path to JSON file |
| 60 | + credentials_info=None, # dict or JSON string |
58 | 61 | timeout=20, |
59 | 62 | temperature=0, |
| 63 | + tolerance=1e-7, |
60 | 64 | top_p=0.95, |
61 | 65 | max_tokens=2000, |
62 | 66 | **kwargs, |
63 | 67 | ): |
64 | | - # API parameters |
65 | | - # self.api_url = api_url or os.getenv("AZURE_DEPLOYMENT_API_URL") |
66 | | - self.api_key = api_key or os.getenv("GOOGLE_API_KEY") |
67 | 68 | self.project_id = project_id or os.getenv("GOOGLE_PROJECT_ID") |
68 | 69 | self.model_name = model_name or os.getenv("MODEL") |
69 | | - if self.api_key is None: |
| 70 | + self.location = location or os.getenv("VERTEX_LOCATION") or "us-central1" |
| 71 | + self.credentials = None |
| 72 | + |
| 73 | + # 1. Prefer explicit credentials_info (dict or JSON string) |
| 74 | + if credentials_info: |
| 75 | + if isinstance(credentials_info, str): |
| 76 | + credentials_info = json.loads(credentials_info) |
| 77 | + self.credentials = service_account.Credentials.from_service_account_info( |
| 78 | + credentials_info |
| 79 | + ) |
| 80 | + # 2. Else, load from path (arg or env) |
| 81 | + elif credentials_path or os.getenv("GOOGLE_APPLICATION_CREDENTIALS"): |
| 82 | + path = credentials_path or os.getenv("GOOGLE_APPLICATION_CREDENTIALS") |
| 83 | + with open(path, "r") as f: |
| 84 | + info = json.load(f) |
| 85 | + self.credentials = service_account.Credentials.from_service_account_info( |
| 86 | + info |
| 87 | + ) |
| 88 | + # 3. Else, None: will fall back to ADC (Application Default Credentials) |
| 89 | + |
| 90 | + if not self.project_id: |
70 | 91 | raise Exception( |
71 | | - "API Key must be provided as model config or environment variable (`GOOGLE_API_KEY`)" |
| 92 | + "PROJECT_ID must be set (argument or `GOOGLE_PROJECT_ID` in .env)" |
72 | 93 | ) |
73 | | - if self.project_id is None: |
| 94 | + if not self.model_name: |
| 95 | + raise Exception("MODEL must be set (argument or `MODEL` in .env)") |
| 96 | + if not self.location: |
74 | 97 | raise Exception( |
75 | | - "PROJECT_ID must be provided as model config or environment variable (`GOOGLE_PROJECT_ID`)" |
| 98 | + "LOCATION must be set (argument or `VERTEX_LOCATION` in .env)" |
76 | 99 | ) |
77 | 100 | self.api_timeout = timeout |
| 101 | + |
| 102 | + vertexai.init( |
| 103 | + project=self.project_id, |
| 104 | + location=self.location, |
| 105 | + credentials=self.credentials, |
| 106 | + ) |
| 107 | + |
| 108 | + self.tolerance = tolerance |
| 109 | + self.temperature = max(temperature, tolerance) |
| 110 | + self.top_p = top_p |
| 111 | + self.max_tokens = max_tokens |
| 112 | + |
78 | 113 | self.safety_settings = { |
79 | 114 | generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, |
80 | 115 | generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, |
81 | 116 | generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, |
82 | 117 | generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH, |
83 | 118 | } |
84 | | - # Parameters |
85 | | - tolerance = 1e-7 |
86 | | - self.temperature = temperature |
87 | | - if self.temperature < tolerance: |
88 | | - # Currently, the model inference fails if temperature |
89 | | - # is exactly 0, so we nudge it slightly to work around |
90 | | - # the issue |
91 | | - self.temperature += tolerance |
92 | | - self.top_p = top_p |
93 | | - self.max_tokens = max_tokens |
94 | 119 |
|
95 | 120 | super(GeminiModel, self).__init__( |
96 | 121 | retry_exceptions=(TimeoutError, GeminiFailure), **kwargs |
97 | 122 | ) |
98 | | - vertexai.init(project=self.project_id, location="us-central1") |
99 | | - # self.client = GenerativeModel(self.model_name) |
100 | 123 |
|
101 | 124 | def summarize_response(self, response): |
102 | 125 | """Returns the "outputs" key's value, if available""" |
@@ -127,20 +153,6 @@ def prompt(self, processed_input): |
127 | 153 | This method raises this exception if the server responded with a non-ok |
128 | 154 | response |
129 | 155 | """ |
130 | | - # headers = { |
131 | | - # "Content-Type": "application/json", |
132 | | - # "Authorization": "Bearer " + self.api_key, |
133 | | - # } |
134 | | - # body = { |
135 | | - # "input_data": { |
136 | | - # "input_string": processed_input, |
137 | | - # "parameters": { |
138 | | - # "max_tokens": self.max_tokens, |
139 | | - # "temperature": self.temperature, |
140 | | - # "top_p": self.top_p, |
141 | | - # }, |
142 | | - # } |
143 | | - # } |
144 | 156 | generation_config = { |
145 | 157 | "max_output_tokens": self.max_tokens, |
146 | 158 | "temperature": self.temperature, |
|
0 commit comments