Skip to content

Commit f9c81dc

Browse files
committed
updated gemini tests
1 parent 70db896 commit f9c81dc

File tree

2 files changed

+19
-26
lines changed

2 files changed

+19
-26
lines changed

llmebench/models/Gemini.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66

77
import vertexai
88
import vertexai.preview.generative_models as generative_models
9-
from google.oauth2 import service_account
109
from vertexai.generative_models import FinishReason, GenerativeModel, Part
11-
10+
from google.oauth2 import service_account
1211
from llmebench.models.model_base import ModelBase
1312

1413

@@ -56,11 +55,11 @@ def __init__(
5655
project_id=None,
5756
model_name=None,
5857
location=None,
59-
credentials_path=None, # path to JSON file
60-
credentials_info=None, # dict or JSON string
58+
credentials_path=None, # path to JSON file
59+
credentials_info=None, # dict or JSON string
6160
timeout=20,
6261
temperature=0,
63-
tolerance=1e-7,
62+
tolerance = 1e-7,
6463
top_p=0.95,
6564
max_tokens=2000,
6665
**kwargs,
@@ -74,34 +73,28 @@ def __init__(
7473
if credentials_info:
7574
if isinstance(credentials_info, str):
7675
credentials_info = json.loads(credentials_info)
77-
self.credentials = service_account.Credentials.from_service_account_info(
78-
credentials_info
79-
)
76+
self.credentials = service_account.Credentials.from_service_account_info(credentials_info)
8077
# 2. Else, load from path (arg or env)
8178
elif credentials_path or os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
8279
path = credentials_path or os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
8380
with open(path, "r") as f:
8481
info = json.load(f)
85-
self.credentials = service_account.Credentials.from_service_account_info(
86-
info
87-
)
82+
self.credentials = service_account.Credentials.from_service_account_info(info)
83+
elif os.getenv("GOOGLE_APPLICATION_CREDENTIALS") is not None:
84+
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
8885
# 3. Else, None: will fall back to ADC (Application Default Credentials)
8986

9087
if not self.project_id:
91-
raise Exception(
92-
"PROJECT_ID must be set (argument or `GOOGLE_PROJECT_ID` in .env)"
93-
)
88+
raise Exception("PROJECT_ID must be set (argument or `GOOGLE_PROJECT_ID` in .env)")
9489
if not self.model_name:
9590
raise Exception("MODEL must be set (argument or `MODEL` in .env)")
9691
if not self.location:
97-
raise Exception(
98-
"LOCATION must be set (argument or `VERTEX_LOCATION` in .env)"
99-
)
92+
raise Exception("LOCATION must be set (argument or `VERTEX_LOCATION` in .env)")
10093

10194
vertexai.init(
10295
project=self.project_id,
10396
location=self.location,
104-
credentials=self.credentials,
97+
credentials=self.credentials
10598
)
10699

107100
self.tolerance = tolerance

tests/models/test_Gemini.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,18 @@ class TestGeminiDepModelConfig(unittest.TestCase):
5353
def test_gemini_deployed_model_config(self):
5454
"Test if model config parameters passed as arguments are used"
5555
model = GeminiModel(
56-
project_id="test_project_id", api_key="secret-key", model_name="gemini-test"
56+
project_id="test_project_id", model_name="gemini-test", location="us-central1"
5757
)
5858

5959
self.assertEqual(model.project_id, "test_project_id")
60-
self.assertEqual(model.api_key, "secret-key")
60+
self.assertEqual(model.location, "us-central1")
6161
self.assertEqual(model.model_name, "gemini-test")
6262

6363
@patch.dict(
6464
"os.environ",
6565
{
6666
"GOOGLE_PROJECT_ID": "test_project_id",
67-
"GOOGLE_API_KEY": "secret-key",
67+
"LOCATION": "us-central1",
6868
"MODEL": "gemini-test",
6969
},
7070
)
@@ -73,23 +73,23 @@ def test_gemini_deployed_model_config_env_var(self):
7373
model = GeminiModel()
7474

7575
self.assertEqual(model.project_id, "test_project_id")
76-
self.assertEqual(model.api_key, "secret-key")
76+
self.assertEqual(model.location, "us-central1")
7777
self.assertEqual(model.model_name, "gemini-test")
7878

7979
@patch.dict(
8080
"os.environ",
8181
{
8282
"GOOGLE_PROJECT_ID": "test_project_id",
83-
"GOOGLE_API_KEY": "secret-env-key",
83+
"LOCATION": "us-central1",
8484
"MODEL": "gemini-test",
8585
},
8686
)
8787
def test_gemini_deployed_model_config_priority(self):
8888
"Test if model config parameters passed directly get priority"
8989
model = GeminiModel(
90-
project_id="test_project_id", api_key="secret-key", model_name="gemini_test"
90+
project_id="test_project_id", model_name="gemini_test", location="us-central1"
9191
)
9292

93-
self.assertEqual(model.project_id, "test_project_id")
94-
self.assertEqual(model.api_key, "secret-key")
93+
self.assertEqual(model.project_id, "test_project_id")
94+
self.assertEqual(model.location, "us-central1")
9595
self.assertEqual(model.model_name, "gemini_test")

0 commit comments

Comments
 (0)