66 changes: 66 additions & 0 deletions assets/en/QA/bd/MultiNativQA_Gemini_ZeroShot.py
@@ -0,0 +1,66 @@
import json
import re

from llmebench.datasets import MultiNativQADataset
from llmebench.models import GeminiModel
from llmebench.tasks import MultiNativQATask


def metadata():
return {
"author": "Arabic Language Technologies, QCRI, HBKU",
"model": "LLama 3 8b",
"description": "Deployed on Azure.",
"scores": {},
}


def config():
return {
"dataset": MultiNativQADataset,
"task": MultiNativQATask,
"model": GeminiModel,
"general_args": {"test_split": "english_bd"},
}


def prompt(input_sample):
# Define the question prompt
    question_prompt = f"""
    Please use your expertise to answer the following English question. Answer in English and rate your confidence level from 1 to 10.
    Provide your response in the following JSON format: {{"answer": "your answer", "score": your confidence score}}.
    Please provide JSON output only, with no additional text. Limit your answer to at most {input_sample['length']} words.

    Question: {input_sample['question']}
    """

    # Define the system instruction; GeminiModel passes the first message's
    # content to the model as its system instruction
    system_prompt = """
    You are an English AI assistant specialized in providing detailed and accurate answers across various fields.
    Your task is to deliver clear, concise, and relevant information.
    """

    return [
        {
            "role": "system",
            "content": system_prompt,
        },
        {
            "role": "user",
            "content": question_prompt,
        },
    ]


def post_process(response):
    content = response[0]["content"]["parts"][0]["text"]
    content = content.replace("\n", "").strip()
    if "```json" in content:
        # Extract the JSON payload from the fenced code block
        content = re.search(r"```json(.*)```", content).group(1)
    return json.loads(content)["answer"]
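For reference, a minimal sketch of the input post_process expects, assuming a single candidate serialized via to_dict() as GeminiModel does below (the mock text is hypothetical):

mock_response = [
    {
        "content": {
            "role": "model",
            "parts": [{"text": '```json{"answer": "Dhaka", "score": 9}```'}],
        }
    }
]

assert post_process(mock_response) == "Dhaka"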
176 changes: 176 additions & 0 deletions llmebench/models/Gemini.py
@@ -0,0 +1,176 @@
import os

import vertexai
import vertexai.preview.generative_models as generative_models
from vertexai.generative_models import GenerativeModel

from llmebench.models.model_base import ModelBase


class GeminiFailure(Exception):
"""Exception class to map various failure types from the Gemini server"""

def __init__(self, failure_type, failure_message):
self.type_mapping = {
"processing": "Model Inference failure",
"connection": "Failed to connect to Google Server",
}
self.type = failure_type
self.failure_message = failure_message

def __str__(self):
return (
f"{self.type_mapping.get(self.type, self.type)}: \n {self.failure_message}"
)


class GeminiModel(ModelBase):
"""
Gemini Model interface.

Arguments
---------
project_id : str
Google Project ID. If not provided, the implementation will
look at environment variable `GOOGLE_PROJECT_ID`
    api_key : str
        Authentication token for the API. If not provided, the implementation will
        look at environment variable `GOOGLE_API_KEY`
    model_name : str
        Name of the Gemini model to use. If not provided, the implementation will
        look at environment variable `MODEL`
    timeout : int
        Number of seconds before the request to the server is timed out
    temperature : float
        Temperature value to use for the model. Defaults to zero for reproducibility.
    top_p : float
        Top P value to use for the model. Defaults to 0.95
    max_tokens : int
        Maximum number of tokens the model may generate. Defaults to 2000
"""

def __init__(
self,
project_id=None,
api_key=None,
model_name=None,
timeout=20,
temperature=0,
top_p=0.95,
max_tokens=2000,
**kwargs,
):
# API parameters
self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
self.project_id = project_id or os.getenv("GOOGLE_PROJECT_ID")
self.model_name = model_name or os.getenv("MODEL")
if self.api_key is None:
raise Exception(
"API Key must be provided as model config or environment variable (`GOOGLE_API_KEY`)"
)
if self.project_id is None:
raise Exception(
"PROJECT_ID must be provided as model config or environment variable (`GOOGLE_PROJECT_ID`)"
)
self.api_timeout = timeout
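        # Relax content filtering: only block candidates rated high-risk
        # in each of the four harm categories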
self.safety_settings = {
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
}
# Parameters
tolerance = 1e-7
self.temperature = temperature
if self.temperature < tolerance:
# Currently, the model inference fails if temperature
# is exactly 0, so we nudge it slightly to work around
# the issue
self.temperature += tolerance
self.top_p = top_p
self.max_tokens = max_tokens

super(GeminiModel, self).__init__(
retry_exceptions=(TimeoutError, GeminiFailure), **kwargs
)
        # Initialize the Vertex AI SDK; the location is currently fixed to us-central1
        vertexai.init(project=self.project_id, location="us-central1")

def summarize_response(self, response):
"""Returns the "outputs" key's value, if available"""
if "messages" in response:
return response["messages"]

return response

def prompt(self, processed_input):
"""
Gemini API Implementation

Arguments
---------
        processed_input : list
            Must be a list of dictionaries, where each dictionary has two keys:
            "role" defines a role in the chat (e.g. "system", "user") and
            "content" defines the actual message for that turn. The first
            message is used as the system instruction and the second as the
            user message.

Returns
-------
response : Gemini API response
Response from the Gemini server

Raises
------
GeminiFailure : Exception
This method raises this exception if the server responded with a non-ok
response
"""
generation_config = {
"max_output_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
}

try:
client = GenerativeModel(
self.model_name, system_instruction=[processed_input[0]["content"]]
)
response = client.generate_content(
[processed_input[1]["content"]],
generation_config=generation_config,
safety_settings=self.safety_settings,
)

        except Exception as e:
            raise GeminiFailure(
                "processing",
                "Model inference failed: {}".format(e),
            )

        # Parse the final response; serialize each candidate to a plain dict
        try:
            response_data = [candidate.to_dict() for candidate in response.candidates]
        except Exception as e:
            raise GeminiFailure(
                "processing",
                "Failed to parse response {}: {}".format(response, e),
            )

return response_data
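A minimal usage sketch (the project ID and model name below are hypothetical; real runs can also supply them via the `GOOGLE_PROJECT_ID`, `GOOGLE_API_KEY` and `MODEL` environment variables):

from llmebench.models import GeminiModel

model = GeminiModel(
    project_id="my-gcp-project",
    api_key="secret-key",
    model_name="gemini-1.5-flash-001",
)

# The first message becomes the system instruction, the second the user turn
response = model.prompt(
    [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "What is the capital of Bangladesh?"},
    ]
)

# `response` is a list of candidate dicts; the generated text lives at
# response[0]["content"]["parts"][0]["text"]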
1 change: 1 addition & 0 deletions llmebench/models/__init__.py
@@ -1,6 +1,7 @@
from .Anthropic import AnthropicModel
from .AzureModel import AzureModel
from .FastChat import FastChatModel
from .Gemini import GeminiModel
from .HuggingFaceInferenceAPI import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes
from .OpenAI import LegacyOpenAIModel, OpenAIModel, OpenAIO1Model
from .Petals import PetalsModel
1 change: 1 addition & 0 deletions setup.cfg
@@ -34,6 +34,7 @@ install_requires =
rouge-score==0.1.2
absl-py==2.1.0
GitPython==3.1.43
google-cloud-aiplatform>=1.90.0
# For now, make sure NumPy 2 is not installed
numpy<2

95 changes: 95 additions & 0 deletions tests/models/test_Gemini.py
@@ -0,0 +1,95 @@
import unittest
from unittest.mock import patch

from llmebench import Benchmark
from llmebench.models import GeminiModel

from llmebench.utils import is_fewshot_asset


class TestAssetsForGeminiDepModelPrompts(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Load the benchmark assets
benchmark = Benchmark(benchmark_dir="assets")
all_assets = benchmark.find_assets()

        # Filter out assets not using the Gemini model
cls.assets = [
asset for asset in all_assets if asset["config"]["model"] in [GeminiModel]
]

def test_gemini_deployed_model_prompts(self):
"Test if all assets using this model return data in an appropriate format for prompting"

n_shots = 3 # Sample for few shot prompts

for asset in self.assets:
with self.subTest(msg=asset["name"]):
config = asset["config"]
dataset_args = config.get("dataset_args", {})
dataset_args["data_dir"] = ""
dataset = config["dataset"](**dataset_args)
data_sample = dataset.get_data_sample()
if is_fewshot_asset(config, asset["module"].prompt):
prompt = asset["module"].prompt(
data_sample["input"],
[data_sample for _ in range(n_shots)],
)
else:
prompt = asset["module"].prompt(data_sample["input"])

self.assertIsInstance(prompt, list)

for message in prompt:
self.assertIsInstance(message, dict)
self.assertIn("role", message)
self.assertIsInstance(message["role"], str)
self.assertIn("content", message)
self.assertIsInstance(message["content"], (str, list))


class TestGeminiDepModelConfig(unittest.TestCase):
def test_gemini_deployed_model_config(self):
"Test if model config parameters passed as arguments are used"
model = GeminiModel(
project_id="test_project_id", api_key="secret-key", model_name="gemini-test"
)

self.assertEqual(model.project_id, "test_project_id")
self.assertEqual(model.api_key, "secret-key")
self.assertEqual(model.model_name, "gemini-test")

@patch.dict(
"os.environ",
{
"GOOGLE_PROJECT_ID": "test_project_id",
"GOOGLE_API_KEY": "secret-key",
"MODEL": "gemini-test",
},
)
def test_gemini_deployed_model_config_env_var(self):
"Test if model config parameters passed as environment variables are used"
model = GeminiModel()

self.assertEqual(model.project_id, "test_project_id")
self.assertEqual(model.api_key, "secret-key")
self.assertEqual(model.model_name, "gemini-test")

@patch.dict(
"os.environ",
{
"GOOGLE_PROJECT_ID": "test_project_id",
"GOOGLE_API_KEY": "secret-env-key",
"MODEL": "gemini-test",
},
)
def test_gemini_deployed_model_config_priority(self):
"Test if model config parameters passed directly get priority"
model = GeminiModel(
project_id="test_project_id", api_key="secret-key", model_name="gemini_test"
)

self.assertEqual(model.project_id, "test_project_id")
self.assertEqual(model.api_key, "secret-key")
self.assertEqual(model.model_name, "gemini_test")
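As with the other model suites, these tests can be run on their own, e.g. `python -m unittest tests/models/test_Gemini.py` from the repository root (adjust to the project's usual test invocation if it differs).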