Skip to content

Commit 8f6859e

Browse files
Added Similarity Evaluator To Config
1 parent cc32f66 commit 8f6859e

File tree

8 files changed

+49
-29
lines changed

8 files changed

+49
-29
lines changed

benchmarks/benchmark.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@
3232
HNSWLibVectorDB,
3333
SimilarityMetricType,
3434
)
35+
from vectorq.vectorq_core.similarity_evaluator.strategies.llm_comparison import (
36+
LLMComparisonSimilarityEvaluator,
37+
)
38+
from vectorq.vectorq_core.similarity_evaluator.strategies.string_comparison import (
39+
StringComparisonSimilarityEvaluator,
40+
)
3541
from vectorq.vectorq_policy.strategies.dynamic_global_threshold import (
3642
DynamicGlobalThresholdPolicy,
3743
)
@@ -58,8 +64,10 @@
5864
########################################################################################################################
5965

6066
# Benchmark Config
61-
MAX_SAMPLES: int = 15000
67+
MAX_SAMPLES: int = 5000
6268
CONFIDENCE_INTERVALS_ITERATIONS: int = 3
69+
IS_LLM_JUDGE_BENCHMARK: bool = False
70+
6371
EMBEDDING_MODEL_1 = (
6472
"embedding_1",
6573
"GteLargeENv1_5",
@@ -96,7 +104,7 @@
96104
"ecommerce_dataset.json",
97105
"semantic_prompt_cache_benchmark.json",
98106
]
99-
DATASETS_TO_EXCLUDE: List[str] = [DATASETS[1], DATASETS[2]]
107+
DATASETS_TO_EXCLUDE: List[str] = [DATASETS[0], DATASETS[2], DATASETS[3]]
100108

101109
embedding_models: List[Tuple[str, str, str, int]] = [
102110
EMBEDDING_MODEL_1,
@@ -410,6 +418,11 @@ def __run_baseline(
410418
delta: float,
411419
threshold: float,
412420
):
421+
if IS_LLM_JUDGE_BENCHMARK:
422+
similarity_evaluator = LLMComparisonSimilarityEvaluator()
423+
else:
424+
similarity_evaluator = StringComparisonSimilarityEvaluator()
425+
413426
vectorq_config: VectorQConfig = VectorQConfig(
414427
inference_engine=BenchmarkInferenceEngine(),
415428
embedding_engine=BenchmarkEmbeddingEngine(),
@@ -418,6 +431,7 @@ def __run_baseline(
418431
max_capacity=MAX_VECTOR_DB_CAPACITY,
419432
),
420433
embedding_metadata_storage=InMemoryEmbeddingMetadataStorage(),
434+
similarity_evaluator=similarity_evaluator,
421435
)
422436
vectorQ: VectorQ = VectorQ(vectorq_config, vectorq_policy)
423437

File renamed without changes.

tests/integration/test_dynamic_threshold.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
InMemoryEmbeddingMetadataStorage,
99
LangChainEmbeddingEngine,
1010
OpenAIInferenceEngine,
11-
StringComparisonSimilarityEvaluator,
1211
VectorQ,
1312
VectorQConfig,
1413
)
@@ -29,10 +28,7 @@ def create_default_config_and_policy():
2928
embedding_metadata_storage=InMemoryEmbeddingMetadataStorage(),
3029
system_prompt="Please answer in a single word with the first letter capitalized. Example: London",
3130
)
32-
policy = DynamicLocalThresholdPolicy(
33-
delta=0.05,
34-
similarity_evaluator=StringComparisonSimilarityEvaluator(),
35-
)
31+
policy = DynamicLocalThresholdPolicy(delta=0.05)
3632
return config, policy
3733

3834

vectorq/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
from vectorq.vectorq_core.cache.eviction_policy.strategies.no_eviction import (
1919
NoEvictionPolicy,
2020
)
21+
from vectorq.vectorq_core.similarity_evaluator.similarity_evaluator import (
22+
SimilarityEvaluator,
23+
)
24+
from vectorq.vectorq_core.similarity_evaluator.strategies.string_comparison import (
25+
StringComparisonSimilarityEvaluator,
26+
)
2127

2228

2329
class VectorQConfig:
@@ -33,11 +39,14 @@ def __init__(
3339
vector_db: VectorDB = HNSWLibVectorDB(),
3440
embedding_metadata_storage: EmbeddingMetadataStorage = InMemoryEmbeddingMetadataStorage(),
3541
eviction_policy: EvictionPolicy = NoEvictionPolicy(),
42+
similarity_evaluator: SimilarityEvaluator = StringComparisonSimilarityEvaluator(),
3643
system_prompt: Optional[str] = None,
3744
):
3845
self.inference_engine = inference_engine
3946
self.embedding_engine = embedding_engine
4047
self.vector_db = vector_db
4148
self.eviction_policy = eviction_policy
4249
self.embedding_metadata_storage = embedding_metadata_storage
50+
self.similarity_evaluator = similarity_evaluator
51+
self.similarity_evaluator.set_inference_engine(self.inference_engine)
4352
self.system_prompt = system_prompt

vectorq/vectorq_core/similarity_evaluator/similarity_evaluator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
from abc import ABC, abstractmethod
22

3+
from vectorq.inference_engine import InferenceEngine
4+
35

46
class SimilarityEvaluator(ABC):
7+
def __init__(self):
8+
self.inference_engine: InferenceEngine = None
9+
510
@abstractmethod
611
def answers_similar(self, a: str, b: str) -> bool:
712
"""
@@ -10,3 +15,6 @@ def answers_similar(self, a: str, b: str) -> bool:
1015
returns: bool - True if the answers are similar, False otherwise
1116
"""
1217
pass
18+
19+
def set_inference_engine(self, inference_engine: InferenceEngine):
20+
self.inference_engine = inference_engine

vectorq/vectorq_core/similarity_evaluator/strategies/llm_comparison.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,7 @@ def __init__(self):
99

1010
def answers_similar(self, a: str, b: str) -> bool:
1111
# TODO
12+
# @Alex: You can access the inference engine via:
13+
# self.inference_engine
1214
print("TODO: Embedding-based Answer similarity check not implemented")
1315
return False

vectorq/vectorq_policy/strategies/dynamic_global_threshold.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,19 @@
1010
from typing_extensions import override
1111

1212
from vectorq.config import VectorQConfig
13+
from vectorq.inference_engine import InferenceEngine
1314
from vectorq.vectorq_core.cache.cache import Cache
1415
from vectorq.vectorq_core.cache.embedding_store.embedding_metadata_storage.embedding_metadata_obj import (
1516
EmbeddingMetadataObj,
1617
)
1718
from vectorq.vectorq_core.cache.embedding_store.embedding_store import EmbeddingStore
18-
from vectorq.vectorq_core.similarity_evaluator import (
19-
SimilarityEvaluator,
20-
StringComparisonSimilarityEvaluator,
21-
)
19+
from vectorq.vectorq_core.similarity_evaluator import SimilarityEvaluator
2220
from vectorq.vectorq_policy.vectorq_policy import VectorQPolicy
2321

2422

2523
class DynamicGlobalThresholdPolicy(VectorQPolicy):
2624
def __init__(
2725
self,
28-
similarity_evaluator: SimilarityEvaluator = StringComparisonSimilarityEvaluator(),
2926
delta: float = 0.01,
3027
):
3128
"""
@@ -34,16 +31,16 @@ def __init__(
3431
This is suboptimal in cases when the embeddings cannot separate correct from incorrect responses.
3532
3633
Args
37-
similarity_evaluator: SimilarityEvaluator - The similarity evaluator to use
3834
delta: float - The delta value to use
3935
"""
40-
self.similarity_evaluator = similarity_evaluator
4136
self.bayesian = _Bayesian(delta=delta)
42-
self.inference_engine = None
43-
self.cache = None
37+
self.similarity_evaluator: SimilarityEvaluator = None
38+
self.inference_engine: InferenceEngine = None
39+
self.cache: Cache = None
4440

4541
@override
4642
def setup(self, config: VectorQConfig):
43+
self.similarity_evaluator = config.similarity_evaluator
4744
self.inference_engine = config.inference_engine
4845
self.cache = Cache(
4946
embedding_engine=config.embedding_engine,

vectorq/vectorq_policy/strategies/dynamic_local_threshold.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,34 @@
1010
from typing_extensions import override
1111

1212
from vectorq.config import VectorQConfig
13+
from vectorq.inference_engine import InferenceEngine
1314
from vectorq.vectorq_core.cache.cache import Cache
1415
from vectorq.vectorq_core.cache.embedding_store.embedding_metadata_storage.embedding_metadata_obj import (
1516
EmbeddingMetadataObj,
1617
)
1718
from vectorq.vectorq_core.cache.embedding_store.embedding_store import EmbeddingStore
18-
from vectorq.vectorq_core.similarity_evaluator import (
19-
SimilarityEvaluator,
20-
StringComparisonSimilarityEvaluator,
21-
)
19+
from vectorq.vectorq_core.similarity_evaluator import SimilarityEvaluator
2220
from vectorq.vectorq_policy.vectorq_policy import VectorQPolicy
2321

2422

2523
class DynamicLocalThresholdPolicy(VectorQPolicy):
26-
def __init__(
27-
self,
28-
similarity_evaluator: SimilarityEvaluator = StringComparisonSimilarityEvaluator(),
29-
delta: float = 0.01,
30-
):
24+
def __init__(self, delta: float = 0.01):
3125
"""
3226
This policy uses the VectorQ algorithm to compute the optimal threshold for each
3327
embedding in the cache.
3428
Each threshold is used to determine if a response is a cache hit.
3529
3630
Args
37-
similarity_evaluator: SimilarityEvaluator - The similarity evaluator to use
3831
delta: float - The delta value to use
3932
"""
40-
self.similarity_evaluator = similarity_evaluator
4133
self.bayesian = _Bayesian(delta=delta)
42-
self.inference_engine = None
43-
self.cache = None
34+
self.similarity_evaluator: SimilarityEvaluator = None
35+
self.inference_engine: InferenceEngine = None
36+
self.cache: Cache = None
4437

4538
@override
4639
def setup(self, config: VectorQConfig):
40+
self.similarity_evaluator = config.similarity_evaluator
4741
self.inference_engine = config.inference_engine
4842
self.cache = Cache(
4943
embedding_engine=config.embedding_engine,

0 commit comments

Comments
 (0)