Model/gemini add #395
Merged
Changes from all 6 commits:
c3d70f7 - add gemini model and sample asset (AridHasan)
0a5b7ae - rename model class (AridHasan)
0bb4428 - add test case (AridHasan)
1d84102 - update requirements (AridHasan)
3a73e47 - update requirements (AridHasan)
4929c8d - remove hard coded parameters (AridHasan)
New file, 66 lines added: a MultiNativQA benchmark asset wiring `MultiNativQADataset` and `MultiNativQATask` to `GeminiModel`.

````python
import json
import re

from llmebench.datasets import MultiNativQADataset
from llmebench.models import GeminiModel
from llmebench.tasks import MultiNativQATask


def metadata():
    return {
        "author": "Arabic Language Technologies, QCRI, HBKU",
        "model": "LLama 3 8b",
        "description": "Deployed on Azure.",
        "scores": {},
    }


def config():
    return {
        "dataset": MultiNativQADataset,
        "task": MultiNativQATask,
        "model": GeminiModel,
        "general_args": {"test_split": "english_bd"},
    }


def prompt(input_sample):
    # Define the question prompt
    question_prompt = f"""
    Please use your expertise to answer the following English question. Answer in English and rate your confidence level from 1 to 10.
    Provide your response in the following JSON format: {{"answer": "your answer", "score": your confidence score}}.
    Please provide JSON output only. No additional text. Answer should be limited to less or equal to {input_sample['length']} words.

    Question: {input_sample['question']}
    """

    # Define the assistant (system) prompt
    assistant_prompt = """
    You are an English AI assistant specialized in providing detailed and accurate answers across various fields.
    Your task is to deliver clear, concise, and relevant information.
    """

    return [
        {
            "role": "assistant",
            "content": assistant_prompt,
        },
        {
            "role": "user",
            "content": question_prompt,
        },
    ]


def post_process(response):
    # Extract the text of the first candidate and strip newlines
    content = response[0]["content"]["parts"][0]["text"]
    content = content.replace("\n", "").strip()
    if "```json" in content:
        # Strip the Markdown code fence the model sometimes wraps around its JSON
        content = re.search(r"```json(.*)```", content).group(1)
    return json.loads(content)["answer"]
````
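A minimal sketch of how this asset's `prompt` and `post_process` functions fit together; the input sample and the mocked candidate below are made-up illustrations of the shapes the asset expects, not values from this PR:

````python
# Illustrative only: the question, length, and mocked model output are made up.
sample = {"question": "What is the capital of Bangladesh?", "length": 20}

messages = prompt(sample)
# messages[0] carries the assistant/system instruction, messages[1] the user question.

# A mocked candidate shaped like the dictionaries GeminiModel.prompt returns:
mock_response = [
    {"content": {"parts": [{"text": '```json{"answer": "Dhaka", "score": 9}```'}]}}
]
print(post_process(mock_response))  # -> Dhaka
````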
New file, 176 lines added: the `GeminiModel` interface in `llmebench.models`.

```python
import json
import logging
import os

import requests

import vertexai
import vertexai.preview.generative_models as generative_models
from vertexai.generative_models import FinishReason, GenerativeModel, Part

from llmebench.models.model_base import ModelBase


class GeminiFailure(Exception):
    """Exception class to map various failure types from the Gemini server"""

    def __init__(self, failure_type, failure_message):
        self.type_mapping = {
            "processing": "Model Inference failure",
            "connection": "Failed to connect to Google Server",
        }
        self.type = failure_type
        self.failure_message = failure_message

    def __str__(self):
        return (
            f"{self.type_mapping.get(self.type, self.type)}: \n {self.failure_message}"
        )


class GeminiModel(ModelBase):
    """
    Gemini Model interface.

    Arguments
    ---------
    project_id : str
        Google Project ID. If not provided, the implementation will
        look at environment variable `GOOGLE_PROJECT_ID`
    api_key : str
        Authentication token for the API. If not provided, the implementation will
        derive it from the environment variable `GOOGLE_API_KEY`.
    model_name : str
        Name of the Gemini model to use. If not provided, the implementation will
        look at environment variable `MODEL`
    timeout : int
        Number of seconds before the request to the server is timed out
    temperature : float
        Temperature value to use for the model. Defaults to zero for reproducibility.
    top_p : float
        Top P value to use for the model. Defaults to 0.95
    max_tokens : int
        Maximum number of tokens to pass to the model. Defaults to 2000
    """

    def __init__(
        self,
        project_id=None,
        api_key=None,
        model_name=None,
        timeout=20,
        temperature=0,
        top_p=0.95,
        max_tokens=2000,
        **kwargs,
    ):
        # API parameters
        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
        self.project_id = project_id or os.getenv("GOOGLE_PROJECT_ID")
        self.model_name = model_name or os.getenv("MODEL")
        if self.api_key is None:
            raise Exception(
                "API Key must be provided as model config or environment variable (`GOOGLE_API_KEY`)"
            )
        if self.project_id is None:
            raise Exception(
                "PROJECT_ID must be provided as model config or environment variable (`GOOGLE_PROJECT_ID`)"
            )
        self.api_timeout = timeout
        self.safety_settings = {
            generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
            generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
            generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
            generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
        }
        # Parameters
        tolerance = 1e-7
        self.temperature = temperature
        if self.temperature < tolerance:
            # Currently, the model inference fails if temperature
            # is exactly 0, so we nudge it slightly to work around
            # the issue
            self.temperature += tolerance
        self.top_p = top_p
        self.max_tokens = max_tokens

        super(GeminiModel, self).__init__(
            retry_exceptions=(TimeoutError, GeminiFailure), **kwargs
        )
        vertexai.init(project=self.project_id, location="us-central1")

    def summarize_response(self, response):
        """Returns the "messages" key's value, if available"""
        if "messages" in response:
            return response["messages"]

        return response

    def prompt(self, processed_input):
        """
        Gemini API Implementation

        Arguments
        ---------
        processed_input : list
            Must be a list of dictionaries, where each dictionary has two keys;
            "role" defines a role in the chat (e.g. "system", "user") and
            "content" defines the actual message for that turn

        Returns
        -------
        response : Gemini API response
            Response from the Gemini server

        Raises
        ------
        GeminiFailure : Exception
            This method raises this exception if the server responded with a non-ok
            response
        """
        generation_config = {
            "max_output_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
        }

        try:
            # The first message is used as the system instruction, the second as
            # the user turn sent to the model
            client = GenerativeModel(
                self.model_name, system_instruction=[processed_input[0]["content"]]
            )
            response = client.generate_content(
                [processed_input[1]["content"]],
                generation_config=generation_config,
                safety_settings=self.safety_settings,
            )

        except Exception as e:
            raise GeminiFailure(
                "processing",
                "Processing failed with status: {}".format(e),
            )

        # Parse the final response
        try:
            response_data = [candidate.to_dict() for candidate in response.candidates]
        except Exception as e:
            raise GeminiFailure(
                "processing",
                "Processing failed: {}".format(response),
            )

        return response_data
```
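A minimal usage sketch of the new interface, assuming valid Google Cloud credentials and Vertex AI access; the project ID, API key, and model name below are placeholders, not values from this PR:

```python
# Placeholder values; a real GCP project with Vertex AI enabled is required.
model = GeminiModel(
    project_id="my-gcp-project",
    api_key="my-api-key",
    model_name="gemini-1.5-pro",
)

candidates = model.prompt(
    [
        {"role": "assistant", "content": "You are a concise assistant."},
        {"role": "user", "content": "Answer in one word: what is the capital of Qatar?"},
    ]
)
# candidates is a list of candidate dictionaries (response.candidates converted
# via to_dict()), which an asset's post_process function then parses.
```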
New file, 95 lines added: unit tests covering `GeminiModel` configuration and the prompt format of assets that use it.

```python
import unittest
from unittest.mock import patch

from llmebench import Benchmark
from llmebench.models import GeminiModel

from llmebench.utils import is_fewshot_asset


class TestAssetsForGeminiDepModelPrompts(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Load the benchmark assets
        benchmark = Benchmark(benchmark_dir="assets")
        all_assets = benchmark.find_assets()

        # Filter out assets not using the Gemini model
        cls.assets = [
            asset for asset in all_assets if asset["config"]["model"] in [GeminiModel]
        ]

    def test_gemini_deployed_model_prompts(self):
        "Test if all assets using this model return data in an appropriate format for prompting"

        n_shots = 3  # Sample for few shot prompts

        for asset in self.assets:
            with self.subTest(msg=asset["name"]):
                config = asset["config"]
                dataset_args = config.get("dataset_args", {})
                dataset_args["data_dir"] = ""
                dataset = config["dataset"](**dataset_args)
                data_sample = dataset.get_data_sample()
                if is_fewshot_asset(config, asset["module"].prompt):
                    prompt = asset["module"].prompt(
                        data_sample["input"],
                        [data_sample for _ in range(n_shots)],
                    )
                else:
                    prompt = asset["module"].prompt(data_sample["input"])

                self.assertIsInstance(prompt, list)

                for message in prompt:
                    self.assertIsInstance(message, dict)
                    self.assertIn("role", message)
                    self.assertIsInstance(message["role"], str)
                    self.assertIn("content", message)
                    self.assertIsInstance(message["content"], (str, list))


class TestGeminiDepModelConfig(unittest.TestCase):
    def test_gemini_deployed_model_config(self):
        "Test if model config parameters passed as arguments are used"
        model = GeminiModel(
            project_id="test_project_id", api_key="secret-key", model_name="gemini-test"
        )

        self.assertEqual(model.project_id, "test_project_id")
        self.assertEqual(model.api_key, "secret-key")
        self.assertEqual(model.model_name, "gemini-test")

    @patch.dict(
        "os.environ",
        {
            "GOOGLE_PROJECT_ID": "test_project_id",
            "GOOGLE_API_KEY": "secret-key",
            "MODEL": "gemini-test",
        },
    )
    def test_gemini_deployed_model_config_env_var(self):
        "Test if model config parameters passed as environment variables are used"
        model = GeminiModel()

        self.assertEqual(model.project_id, "test_project_id")
        self.assertEqual(model.api_key, "secret-key")
        self.assertEqual(model.model_name, "gemini-test")

    @patch.dict(
        "os.environ",
        {
            "GOOGLE_PROJECT_ID": "test_project_id",
            "GOOGLE_API_KEY": "secret-env-key",
            "MODEL": "gemini-test",
        },
    )
    def test_gemini_deployed_model_config_priority(self):
        "Test if model config parameters passed directly get priority"
        model = GeminiModel(
            project_id="test_project_id", api_key="secret-key", model_name="gemini_test"
        )

        self.assertEqual(model.project_id, "test_project_id")
        self.assertEqual(model.api_key, "secret-key")
        self.assertEqual(model.model_name, "gemini_test")
```
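One way to run just these tests locally, as a sketch; the exact location of the new test module is not shown in the diff, so the `tests` directory and file-name pattern below are assumptions:

```python
# Assumes the new tests live under tests/ in a file matching test_Gemini*.py.
import unittest

suite = unittest.defaultTestLoader.discover("tests", pattern="test_Gemini*.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```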