82 changes: 47 additions & 35 deletions llmebench/models/Gemini.py
@@ -6,6 +6,7 @@

import vertexai
import vertexai.preview.generative_models as generative_models
from google.oauth2 import service_account
from vertexai.generative_models import FinishReason, GenerativeModel, Part

from llmebench.models.model_base import ModelBase
@@ -53,50 +54,75 @@ class GeminiModel(ModelBase):
def __init__(
self,
project_id=None,
api_key=None,
model_name=None,
location=None,
credentials_path=None, # path to JSON file
credentials_info=None, # dict or JSON string
timeout=20,
temperature=0,
tolerance=1e-7,
top_p=0.95,
max_tokens=2000,
**kwargs,
):
# API parameters
# self.api_url = api_url or os.getenv("AZURE_DEPLOYMENT_API_URL")
self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
self.project_id = project_id or os.getenv("GOOGLE_PROJECT_ID")
self.model_name = model_name or os.getenv("MODEL")
if self.api_key is None:
self.location = location or os.getenv("VERTEX_LOCATION") or "us-central1"
self.credentials = None

# 1. Prefer explicit credentials_info (dict or JSON string)
if credentials_info:
if isinstance(credentials_info, str):
credentials_info = json.loads(credentials_info)
self.credentials = service_account.Credentials.from_service_account_info(
credentials_info
)
# 2. Else, load from path (arg or env)
elif credentials_path or os.getenv("GOOGLE_APPLICATION_CREDENTIALS"):
path = credentials_path or os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
with open(path, "r") as f:
info = json.load(f)
self.credentials = service_account.Credentials.from_service_account_info(
info
)
elif os.getenv("GOOGLE_APPLICATION_CREDENTIALS") is not None:
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.getenv(
"GOOGLE_APPLICATION_CREDENTIALS"
)
# 3. Else, None: will fall back to ADC (Application Default Credentials)

if not self.project_id:
raise Exception(
"API Key must be provided as model config or environment variable (`GOOGLE_API_KEY`)"
"PROJECT_ID must be set (argument or `GOOGLE_PROJECT_ID` in .env)"
)
if self.project_id is None:
if not self.model_name:
raise Exception("MODEL must be set (argument or `MODEL` in .env)")
if not self.location:
raise Exception(
"PROJECT_ID must be provided as model config or environment variable (`GOOGLE_PROJECT_ID`)"
"LOCATION must be set (argument or `VERTEX_LOCATION` in .env)"
)
self.api_timeout = timeout

vertexai.init(
project=self.project_id,
location=self.location,
credentials=self.credentials,
)

self.tolerance = tolerance
self.temperature = max(temperature, tolerance)
self.top_p = top_p
self.max_tokens = max_tokens

self.safety_settings = {
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
}
# Parameters
tolerance = 1e-7
self.temperature = temperature
if self.temperature < tolerance:
# Currently, the model inference fails if temperature
# is exactly 0, so we nudge it slightly to work around
# the issue
self.temperature += tolerance
self.top_p = top_p
self.max_tokens = max_tokens

super(GeminiModel, self).__init__(
retry_exceptions=(TimeoutError, GeminiFailure), **kwargs
)
vertexai.init(project=self.project_id, location="us-central1")
# self.client = GenerativeModel(self.model_name)

def summarize_response(self, response):
"""Returns the "outputs" key's value, if available"""
@@ -127,20 +153,6 @@ def prompt(self, processed_input):
This method raises this exception if the server responded with a non-ok
response
"""
# headers = {
# "Content-Type": "application/json",
# "Authorization": "Bearer " + self.api_key,
# }
# body = {
# "input_data": {
# "input_string": processed_input,
# "parameters": {
# "max_tokens": self.max_tokens,
# "temperature": self.temperature,
# "top_p": self.top_p,
# },
# }
# }
generation_config = {
"max_output_tokens": self.max_tokens,
"temperature": self.temperature,
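The rest of prompt() is collapsed in this view. As context only, the sketch below shows how a generation_config dict and the safety_settings built in __init__ are typically passed to GenerativeModel.generate_content in the Vertex AI SDK; it approximates, but is not necessarily, the collapsed code, and the model name and prompt are placeholders.

# Context sketch (not part of the diff): typical Vertex AI call using the
# settings assembled above. Assumes vertexai.init(...) has already run,
# as it does in the reworked __init__.
import vertexai.preview.generative_models as generative_models
from vertexai.generative_models import GenerativeModel

safety_settings = {
    generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
}
client = GenerativeModel("gemini-1.5-pro")  # placeholder model name
response = client.generate_content(
    "placeholder prompt",
    generation_config={
        "max_output_tokens": 2000,
        "temperature": 1e-7,  # mirrors self.temperature = max(temperature, tolerance)
        "top_p": 0.95,
    },
    safety_settings=safety_settings,
)
print(response.text)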
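For reference, a minimal usage sketch (not part of the diff) of the three credential paths the reworked __init__ supports. The import path follows the tests below; the project, model, and file names are placeholders.

# Usage sketch: the three ways the new constructor can obtain credentials.
import json

from llmebench.models import GeminiModel

# 1. Explicit service-account info (dict or JSON string)
with open("service_account.json") as f:  # hypothetical key file
    sa_info = json.load(f)
model = GeminiModel(
    project_id="my-gcp-project",   # or GOOGLE_PROJECT_ID in .env
    model_name="gemini-1.5-pro",   # or MODEL in .env
    location="us-central1",        # or VERTEX_LOCATION in .env (defaults to us-central1)
    credentials_info=sa_info,
)

# 2. Path to a service-account JSON file
model = GeminiModel(
    project_id="my-gcp-project",
    model_name="gemini-1.5-pro",
    credentials_path="service_account.json",  # or set GOOGLE_APPLICATION_CREDENTIALS
)

# 3. Neither given: fall back to Application Default Credentials (ADC)
model = GeminiModel(project_id="my-gcp-project", model_name="gemini-1.5-pro")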
18 changes: 11 additions & 7 deletions tests/models/test_Gemini.py
@@ -53,18 +53,20 @@ class TestGeminiDepModelConfig(unittest.TestCase):
def test_gemini_deployed_model_config(self):
"Test if model config parameters passed as arguments are used"
model = GeminiModel(
project_id="test_project_id", api_key="secret-key", model_name="gemini-test"
project_id="test_project_id",
model_name="gemini-test",
location="us-central1",
)

self.assertEqual(model.project_id, "test_project_id")
self.assertEqual(model.api_key, "secret-key")
self.assertEqual(model.location, "us-central1")
self.assertEqual(model.model_name, "gemini-test")

@patch.dict(
"os.environ",
{
"GOOGLE_PROJECT_ID": "test_project_id",
"GOOGLE_API_KEY": "secret-key",
"LOCATION": "us-central1",
"MODEL": "gemini-test",
},
)
@@ -73,23 +75,25 @@ def test_gemini_deployed_model_config_env_var(self):
model = GeminiModel()

self.assertEqual(model.project_id, "test_project_id")
self.assertEqual(model.api_key, "secret-key")
self.assertEqual(model.location, "us-central1")
self.assertEqual(model.model_name, "gemini-test")

@patch.dict(
"os.environ",
{
"GOOGLE_PROJECT_ID": "test_project_id",
"GOOGLE_API_KEY": "secret-env-key",
"LOCATION": "us-central1",
"MODEL": "gemini-test",
},
)
def test_gemini_deployed_model_config_priority(self):
"Test if model config parameters passed directly get priority"
model = GeminiModel(
project_id="test_project_id", api_key="secret-key", model_name="gemini_test"
project_id="test_project_id",
model_name="gemini_test",
location="us-central1",
)

self.assertEqual(model.project_id, "test_project_id")
self.assertEqual(model.api_key, "secret-key")
self.assertEqual(model.location, "us-central1")
self.assertEqual(model.model_name, "gemini_test")
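Finally, a sketch (not part of the diff) of configuring the model purely through the environment, as the updated tests do. The keys match what the new __init__ reads; all values are placeholders.

# Environment-only configuration sketch.
import os

os.environ["GOOGLE_PROJECT_ID"] = "my-gcp-project"
os.environ["VERTEX_LOCATION"] = "us-central1"  # optional; __init__ defaults to us-central1
os.environ["MODEL"] = "gemini-1.5-pro"
# Point at a service-account key file, or omit to fall back to ADC:
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/service_account.json"

from llmebench.models import GeminiModel

model = GeminiModel()  # everything resolved from the environment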