18 changes: 17 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ ruff = "^0.11.6"
mypy = "^1.15.0"
pre-commit = "^4.2.0"
pytest = "^8.0.0"
pytest-rerunfailures = "^15.0"
python-dotenv = "^1.1.0"


190 changes: 120 additions & 70 deletions tests/unit/InferenceEngineStrategy/test.py
@@ -1,96 +1,146 @@
import os
import unittest
from typing import Any, Dict, Type

import pytest
from dotenv import load_dotenv

from vectorq.inference_engine import (
InferenceEngine,
LangChainInferenceEngine,
OpenAIInferenceEngine,
)

load_dotenv()

OPENAI_API_KEY_AVAILABLE = bool(os.environ.get("OPENAI_API_KEY"))
ANTHROPIC_API_KEY_AVAILABLE = bool(os.environ.get("ANTHROPIC_API_KEY"))

# Build parameter list dynamically so we only execute engines we can call
INFERENCE_ENGINE_PARAMS = []
if OPENAI_API_KEY_AVAILABLE:
INFERENCE_ENGINE_PARAMS.extend(
[
pytest.param(
OpenAIInferenceEngine,
{"model_name": "gpt-4.1-nano-2025-04-14", "temperature": 0},
),
pytest.param(
LangChainInferenceEngine,
{
"provider": "openai",
"model_name": "gpt-4.1-nano-2025-04-14",
"temperature": 0,
},
),
]
)

from vectorq.inference_engine import LangChainInferenceEngine, OpenAIInferenceEngine

INFERENCE_ENGINE_PARAMS = [
pytest.param(
OpenAIInferenceEngine,
{"model_name": "gpt-4o-mini", "temperature": 0},
marks=pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY"),
reason="OPENAI_API_KEY environment variable not set",
),
),
pytest.param(
LangChainInferenceEngine,
{"provider": "openai", "model_name": "gpt-4o-mini", "temperature": 0},
marks=pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY"),
reason="OPENAI_API_KEY environment variable not set",
),
),
pytest.param(
LangChainInferenceEngine,
{"provider": "anthropic", "model_name": "claude-3-5-sonnet", "temperature": 0},
marks=pytest.mark.skipif(
not os.environ.get("ANTHROPIC_API_KEY"),
reason="ANTHROPIC_API_KEY environment variable not set",
),
),
pytest.param(
LangChainInferenceEngine,
{"provider": "google", "model_name": "gemini-1.5-flash", "temperature": 0},
marks=pytest.mark.skipif(
not os.environ.get("GOOGLE_API_KEY"),
reason="GOOGLE_API_KEY environment variable not set",
),
),
]
if ANTHROPIC_API_KEY_AVAILABLE:
INFERENCE_ENGINE_PARAMS.append(
pytest.param(
LangChainInferenceEngine,
{
"provider": "anthropic",
"model_name": "claude-3-haiku-20240307",
"temperature": 0,
},
)
)


@pytest.mark.skipif(
not INFERENCE_ENGINE_PARAMS, reason="No compatible API keys found for tests."
)
class TestInferenceEngineStrategy:
"""Test all inference engine strategies using parameterization."""
# """Comprehensive tests for the inference engine strategies."""

@pytest.mark.parametrize(
"inference_engine_class, engine_params", INFERENCE_ENGINE_PARAMS
)
def test_create(self, inference_engine_class, engine_params):
"""Test creating responses from different inference engines."""
engine = inference_engine_class(**engine_params)
def test_infer(
self,
inference_engine_class: Type[InferenceEngine],
engine_params: Dict[str, Any],
) -> None:
engine: InferenceEngine = inference_engine_class(**engine_params)

prompt = "What is the capital of France?"
response = engine.create(prompt)
response = engine.infer(prompt)

# Verify the response has the expected content
assert "Paris" in response
assert response is not None
assert "paris" in response.lower()

@pytest.mark.parametrize(
"inference_engine_class, engine_params", INFERENCE_ENGINE_PARAMS
)
def test_create_with_output_format(self, inference_engine_class, engine_params):
"""Test creating responses with specified output format."""
engine = inference_engine_class(**engine_params)

prompt = "List three European capitals."
output_format = "Provide the answer as a comma-separated list."
response = engine.create(prompt, output_format)

# Verify response contains expected cities and follows the format
assert any(
city in response for city in ["Paris", "London", "Berlin", "Madrid", "Rome"]
)

# Should be in comma-separated format as requested
assert "," in response
def test_infer_with_system_prompt_override(
self,
inference_engine_class: Type[InferenceEngine],
engine_params: Dict[str, Any],
) -> None:
engine: InferenceEngine = inference_engine_class(**engine_params)

system_prompt = "ALWAYS ANSWER IN UPPERCASE."
prompt = "Say 'hello world'"
response = engine.infer(prompt, system_prompt=system_prompt)

letters = [c for c in response if c.isalpha()]
if not letters:
assert False, "No alphabetic characters found in response"
is_all_upper = sum(1 for c in letters if c.isupper()) / len(letters) > 0.9
assert is_all_upper

@pytest.mark.parametrize(
"inference_engine_class, engine_params", INFERENCE_ENGINE_PARAMS
)
def test_consistent_responses(self, inference_engine_class, engine_params):
"""Test that responses are consistent with temperature=0."""
engine = inference_engine_class(**engine_params)

prompt = "What is 2+2?"
response1 = engine.create(prompt)
response2 = engine.create(prompt)
@pytest.mark.flaky(reruns=3)
def test_consistency_with_zero_temperature(
self,
inference_engine_class: Type[InferenceEngine],
engine_params: Dict[str, Any],
) -> None:
# Set temperature to 0 when creating the engine
params = {**engine_params, "temperature": 0}
engine: InferenceEngine = inference_engine_class(**params)

prompt = "Use a short and brief sentence to describe Paris."
response1 = engine.infer(prompt)
response2 = engine.infer(prompt)

assert response1 == response2

# With temperature=0, responses to simple factual questions should be consistent
assert "4" in response1
assert "4" in response2
@pytest.mark.parametrize(
"inference_engine_class, engine_params", INFERENCE_ENGINE_PARAMS
)
@pytest.mark.flaky(reruns=3)
def test_infer_with_temperature_kwarg_override(
self,
inference_engine_class: Type[InferenceEngine],
engine_params: Dict[str, Any],
) -> None:
# Set temperature to 0.9 when creating the engine
params = {**engine_params, "temperature": 0.9}
engine: InferenceEngine = inference_engine_class(**params)

# With temperature 0.9, the response should be different
prompt = "Use a short and brief sentence to describe Paris."
response1 = engine.infer(prompt)
response2 = engine.infer(prompt)
assert response1 != response2

# With overriding temperature 0, the response should be almost the same
params = {**engine_params, "temperature": 0}
engine: InferenceEngine = inference_engine_class(**params)
response3 = engine.infer(prompt)
response4 = engine.infer(prompt)
assert response3 == response4

def test_langchain_unsupported_provider(self) -> None:
engine = LangChainInferenceEngine(provider="unknown", model_name="foo")
with pytest.raises(Exception) as exc:
_ = engine.infer("Test")
assert "Unsupported provider" in str(exc.value)


if __name__ == "__main__":
unittest.main()
pytest.main()
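
For context, the pytest-rerunfailures dev dependency added in pyproject.toml is what backs the @pytest.mark.flaky(reruns=3) markers used in this file. A minimal standalone sketch of how the marker behaves (the test below is hypothetical and not part of this PR):

import random

import pytest


@pytest.mark.flaky(reruns=3)
def test_occasionally_flaky() -> None:
    # pytest-rerunfailures reruns a failing test up to 3 times before reporting it as failed
    assert random.random() > 0.1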
6 changes: 1 addition & 5 deletions vectorq/config.py
@@ -1,7 +1,5 @@
from typing import Optional

from vectorq.inference_engine.inference_engine import InferenceEngine
from vectorq.inference_engine.strategies.open_ai import OpenAIInferenceEngine
from vectorq.inference_engine.strategies.openai import OpenAIInferenceEngine
from vectorq.vectorq_core.cache.embedding_engine import OpenAIEmbeddingEngine
from vectorq.vectorq_core.cache.embedding_engine.embedding_engine import EmbeddingEngine
from vectorq.vectorq_core.cache.embedding_store.embedding_metadata_storage.embedding_metadata_storage import (
@@ -31,11 +29,9 @@ def __init__(
vector_db: VectorDB = HNSWLibVectorDB(),
embedding_metadata_storage: EmbeddingMetadataStorage = InMemoryEmbeddingMetadataStorage(),
eviction_policy: EvictionPolicy = LRUEvictionPolicy(),
system_prompt: Optional[str] = None,
):
self.inference_engine = inference_engine
self.embedding_engine = embedding_engine
self.vector_db = vector_db
self.eviction_policy = eviction_policy
self.embedding_metadata_storage = embedding_metadata_storage
self.system_prompt = system_prompt
4 changes: 2 additions & 2 deletions vectorq/inference_engine/__init__.py
@@ -1,5 +1,5 @@
from vectorq.inference_engine.inference_engine import InferenceEngine
from vectorq.inference_engine.strategies.lang_chain import LangChainInferenceEngine
from vectorq.inference_engine.strategies.open_ai import OpenAIInferenceEngine
from vectorq.inference_engine.strategies.langchain import LangChainInferenceEngine
from vectorq.inference_engine.strategies.openai import OpenAIInferenceEngine

__all__ = ["InferenceEngine", "LangChainInferenceEngine", "OpenAIInferenceEngine"]
20 changes: 15 additions & 5 deletions vectorq/inference_engine/inference_engine.py
@@ -1,16 +1,26 @@
from abc import ABC, abstractmethod
from typing import Any, Optional


class InferenceEngine(ABC):
"""
Abstract base class for inference engines
Abstract base class for inference engines.
"""

@abstractmethod
def create(self, prompt: str, system_prompt: str = None) -> str:
def infer(
self,
prompt: str,
system_prompt: Optional[str] = None,
**kwargs: Any,
) -> str:
"""
prompt: str - The prompt to create an answer for
output_format: str - The optional output format to use for the response
returns: str - The answer to the prompt
Infer a response from the inference engine.
Args
prompt: str - The prompt to create a response for.
system_prompt: Optional[str] - The optional system prompt to use for the response.
**kwargs: Any - Additional arguments to pass to the underlying inference engine (e.g., max_tokens, temperature).
Returns
str - The response to the prompt.
"""
pass
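
For reference, a minimal concrete implementation of the new infer contract might look like the sketch below; EchoInferenceEngine is a hypothetical illustration only and is not part of this PR:

from typing import Any, Optional

from vectorq.inference_engine.inference_engine import InferenceEngine


class EchoInferenceEngine(InferenceEngine):
    # Hypothetical engine used only to illustrate the abstract interface above.
    def infer(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        # Echo the prompt, optionally prefixed with the system prompt; extra kwargs are ignored.
        prefix = f"[{system_prompt}] " if system_prompt else ""
        return prefix + prompt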
16 changes: 8 additions & 8 deletions vectorq/inference_engine/strategies/benchmark.py
@@ -1,19 +1,19 @@
from typing import override
from typing import Any, override

from vectorq.inference_engine.inference_engine import InferenceEngine


class BenchmarkInferenceEngine(InferenceEngine):
"""
An inference engine implementation that returns pre-computed responses for given prompts.
An inference engine implementation that returns pre-computed responses.
It is used for benchmarking purposes.
"""

def set_next_response(self, response: str):
self.next_response = response
def set_response(self, response: str):
self.response = response

@override
def create(self, prompt: str, system_prompt: str = None) -> str:
if self.next_response is None:
raise ValueError("No next response set")
return self.next_response
def infer(self, prompt: str, system_prompt: str = None, **kwargs: Any) -> str:
if self.response is None:
raise ValueError("No response set")
return self.response
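
A brief usage sketch of the renamed benchmark engine, assuming only the methods shown in this diff:

engine = BenchmarkInferenceEngine()
engine.set_response("canned answer")  # infer() assumes set_response() was called first
assert engine.infer("any prompt") == "canned answer"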