Skip to content

Commit 01196ab

Browse files
authored
updated gemini code (#396)
* updated gemini code
* updated gemini tests
* updated gemini tests
1 parent 0df4531 commit 01196ab

File tree

2 files changed

+58
-42
lines changed

2 files changed

+58
-42
lines changed

llmebench/models/Gemini.py

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import vertexai
88
import vertexai.preview.generative_models as generative_models
9+
from google.oauth2 import service_account
910
from vertexai.generative_models import FinishReason, GenerativeModel, Part
1011

1112
from llmebench.models.model_base import ModelBase
@@ -53,50 +54,75 @@ class GeminiModel(ModelBase):
5354
def __init__(
5455
self,
5556
project_id=None,
56-
api_key=None,
5757
model_name=None,
58+
location=None,
59+
credentials_path=None, # path to JSON file
60+
credentials_info=None, # dict or JSON string
5861
timeout=20,
5962
temperature=0,
63+
tolerance=1e-7,
6064
top_p=0.95,
6165
max_tokens=2000,
6266
**kwargs,
6367
):
64-
# API parameters
65-
# self.api_url = api_url or os.getenv("AZURE_DEPLOYMENT_API_URL")
66-
self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
6768
self.project_id = project_id or os.getenv("GOOGLE_PROJECT_ID")
6869
self.model_name = model_name or os.getenv("MODEL")
69-
if self.api_key is None:
70+
self.location = location or os.getenv("VERTEX_LOCATION") or "us-central1"
71+
self.credentials = None
72+
73+
# 1. Prefer explicit credentials_info (dict or JSON string)
74+
if credentials_info:
75+
if isinstance(credentials_info, str):
76+
credentials_info = json.loads(credentials_info)
77+
self.credentials = service_account.Credentials.from_service_account_info(
78+
credentials_info
79+
)
80+
# 2. Else, load from path (arg or env)
81+
elif credentials_path or os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
82+
path = credentials_path or os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
83+
with open(path, "r") as f:
84+
info = json.load(f)
85+
self.credentials = service_account.Credentials.from_service_account_info(
86+
info
87+
)
88+
elif os.getenv("GOOGLE_APPLICATION_CREDENTIALS") is not None:
89+
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.getenv(
90+
"GOOGLE_APPLICATION_CREDENTIALS"
91+
)
92+
# 3. Else, None: will fall back to ADC (Application Default Credentials)
93+
94+
if not self.project_id:
7095
raise Exception(
71-
"API Key must be provided as model config or environment variable (`GOOGLE_API_KEY`)"
96+
"PROJECT_ID must be set (argument or `GOOGLE_PROJECT_ID` in .env)"
7297
)
73-
if self.project_id is None:
98+
if not self.model_name:
99+
raise Exception("MODEL must be set (argument or `MODEL` in .env)")
100+
if not self.location:
74101
raise Exception(
75-
"PROJECT_ID must be provided as model config or environment variable (`GOOGLE_PROJECT_ID`)"
102+
"LOCATION must be set (argument or `VERTEX_LOCATION` in .env)"
76103
)
77-
self.api_timeout = timeout
104+
105+
vertexai.init(
106+
project=self.project_id,
107+
location=self.location,
108+
credentials=self.credentials,
109+
)
110+
111+
self.tolerance = tolerance
112+
self.temperature = max(temperature, tolerance)
113+
self.top_p = top_p
114+
self.max_tokens = max_tokens
115+
78116
self.safety_settings = {
79117
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
80118
generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
81119
generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
82120
generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
83121
}
84-
# Parameters
85-
tolerance = 1e-7
86-
self.temperature = temperature
87-
if self.temperature < tolerance:
88-
# Currently, the model inference fails if temperature
89-
# is exactly 0, so we nudge it slightly to work around
90-
# the issue
91-
self.temperature += tolerance
92-
self.top_p = top_p
93-
self.max_tokens = max_tokens
94122

95123
super(GeminiModel, self).__init__(
96124
retry_exceptions=(TimeoutError, GeminiFailure), **kwargs
97125
)
98-
vertexai.init(project=self.project_id, location="us-central1")
99-
# self.client = GenerativeModel(self.model_name)
100126

101127
def summarize_response(self, response):
102128
"""Returns the "outputs" key's value, if available"""
@@ -127,20 +153,6 @@ def prompt(self, processed_input):
127153
This method raises this exception if the server responded with a non-ok
128154
response
129155
"""
130-
# headers = {
131-
# "Content-Type": "application/json",
132-
# "Authorization": "Bearer " + self.api_key,
133-
# }
134-
# body = {
135-
# "input_data": {
136-
# "input_string": processed_input,
137-
# "parameters": {
138-
# "max_tokens": self.max_tokens,
139-
# "temperature": self.temperature,
140-
# "top_p": self.top_p,
141-
# },
142-
# }
143-
# }
144156
generation_config = {
145157
"max_output_tokens": self.max_tokens,
146158
"temperature": self.temperature,

tests/models/test_Gemini.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,20 @@ class TestGeminiDepModelConfig(unittest.TestCase):
5353
def test_gemini_deployed_model_config(self):
5454
"Test if model config parameters passed as arguments are used"
5555
model = GeminiModel(
56-
project_id="test_project_id", api_key="secret-key", model_name="gemini-test"
56+
project_id="test_project_id",
57+
model_name="gemini-test",
58+
location="us-central1",
5759
)
5860

5961
self.assertEqual(model.project_id, "test_project_id")
60-
self.assertEqual(model.api_key, "secret-key")
62+
self.assertEqual(model.location, "us-central1")
6163
self.assertEqual(model.model_name, "gemini-test")
6264

6365
@patch.dict(
6466
"os.environ",
6567
{
6668
"GOOGLE_PROJECT_ID": "test_project_id",
67-
"GOOGLE_API_KEY": "secret-key",
69+
"LOCATION": "us-central1",
6870
"MODEL": "gemini-test",
6971
},
7072
)
@@ -73,23 +75,25 @@ def test_gemini_deployed_model_config_env_var(self):
7375
model = GeminiModel()
7476

7577
self.assertEqual(model.project_id, "test_project_id")
76-
self.assertEqual(model.api_key, "secret-key")
78+
self.assertEqual(model.location, "us-central1")
7779
self.assertEqual(model.model_name, "gemini-test")
7880

7981
@patch.dict(
8082
"os.environ",
8183
{
8284
"GOOGLE_PROJECT_ID": "test_project_id",
83-
"GOOGLE_API_KEY": "secret-env-key",
85+
"LOCATION": "us-central1",
8486
"MODEL": "gemini-test",
8587
},
8688
)
8789
def test_gemini_deployed_model_config_priority(self):
8890
"Test if model config parameters passed directly get priority"
8991
model = GeminiModel(
90-
project_id="test_project_id", api_key="secret-key", model_name="gemini_test"
92+
project_id="test_project_id",
93+
model_name="gemini_test",
94+
location="us-central1",
9195
)
9296

9397
self.assertEqual(model.project_id, "test_project_id")
94-
self.assertEqual(model.api_key, "secret-key")
98+
self.assertEqual(model.location, "us-central1")
9599
self.assertEqual(model.model_name, "gemini_test")

0 commit comments

Comments
 (0)