diff --git a/openviking/storage/vectordb/collection/collection.py b/openviking/storage/vectordb/collection/collection.py index 45954e8c..97d259fd 100644 --- a/openviking/storage/vectordb/collection/collection.py +++ b/openviking/storage/vectordb/collection/collection.py @@ -1,12 +1,23 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: Apache-2.0 +import importlib from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type from openviking.storage.vectordb.collection.result import AggregateResult, SearchResult from openviking.storage.vectordb.index.index import IIndex +def load_collection_class(class_path: str) -> Type["ICollection"]: + """Load collection class from string path""" + try: + module_name, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ImportError(f"Could not load collection class {class_path}: {e}") + + class ICollection(ABC): def __init__(self): pass diff --git a/openviking/storage/vectordb/project/vikingdb_project.py b/openviking/storage/vectordb/project/vikingdb_project.py index 50a67acf..bcf82ec9 100644 --- a/openviking/storage/vectordb/project/vikingdb_project.py +++ b/openviking/storage/vectordb/project/vikingdb_project.py @@ -2,12 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Any, Dict, List, Optional -from openviking.storage.vectordb.collection.collection import Collection +from openviking.storage.vectordb.collection.collection import ( + Collection, + load_collection_class, +) from openviking.storage.vectordb.collection.vikingdb_clients import ( VIKINGDB_APIS, VikingDBClient, ) -from openviking.storage.vectordb.collection.vikingdb_collection import VikingDBCollection from openviking_cli.utils.logger import default_logger as logger @@ -22,6 +24,8 @@ def get_or_create_vikingdb_project( config: Configuration dict with keys: - Host: VikingDB service host - Headers: Custom headers for authentication/context + - CollectionClass: Class path for collection implementation + - CollectionArgs: Optional dictionary of arguments to pass to collection constructor Returns: VikingDBProject instance @@ -31,11 +35,23 @@ def get_or_create_vikingdb_project( host = config.get("Host") headers = config.get("Headers") + collection_class_path = config.get( + "CollectionClass", + "openviking.storage.vectordb.collection.vikingdb_collection.VikingDBCollection", + ) + # Extract any other arguments that might be needed for collection initialization + collection_args = config.get("CollectionArgs", {}) if not host: raise ValueError("config must contain 'Host'") - return VikingDBProject(host=host, headers=headers, project_name=project_name) + return VikingDBProject( + host=host, + headers=headers, + project_name=project_name, + collection_class_path=collection_class_path, + collection_args=collection_args, + ) class VikingDBProject: @@ -45,7 +61,12 @@ class VikingDBProject: """ def __init__( - self, host: str, headers: Optional[Dict[str, str]] = None, project_name: str = "default" + self, + host: str, + headers: Optional[Dict[str, str]] = None, + project_name: str = "default", + collection_class_path: str = "openviking.storage.vectordb.collection.vikingdb_collection.VikingDBCollection", + collection_args: Optional[Dict[str, Any]] = None, ): """ Initialize VikingDB project. @@ -54,12 +75,19 @@ def __init__( host: VikingDB service host headers: Custom headers for requests project_name: Project name + collection_class_path: Python path to the collection class + collection_args: Optional dictionary of arguments to pass to collection constructor """ self.host = host self.headers = headers self.project_name = project_name + self.collection_class_path = collection_class_path + self.CollectionClass = load_collection_class(self.collection_class_path) + self.collection_args = collection_args or {} - logger.info(f"Initialized VikingDB project: {project_name} with host {host}") + logger.info( + f"Initialized VikingDB project: {project_name} with host {host} and collection class {collection_class_path}" + ) def close(self): """Close project""" @@ -87,9 +115,17 @@ def get_collection(self, collection_name: str) -> Optional[Collection]: meta_data = result.get("Result", {}) if not meta_data: return None - vikingdb_collection = VikingDBCollection( - host=self.host, headers=self.headers, meta_data=meta_data - ) + # Prepare arguments for collection constructor + # Default arguments + kwargs = { + "host": self.host, + "headers": self.headers, + "meta_data": meta_data, + } + # Update with user-provided arguments (can override defaults if needed, though usually additive) + kwargs.update(self.collection_args) + + vikingdb_collection = self.CollectionClass(**kwargs) return Collection(vikingdb_collection) except Exception: return None @@ -118,12 +154,24 @@ def list_collections(self) -> List[str]: def get_collections(self) -> Dict[str, Collection]: """Get all collections from server""" colls = self._get_collections() - return { - c["CollectionName"]: Collection( - VikingDBCollection(host=self.host, headers=self.headers, meta_data=c) - ) - for c in colls + + # Prepare base arguments + base_kwargs = { + "host": self.host, + "headers": self.headers, } + + collections = {} + for c in colls: + kwargs = base_kwargs.copy() + kwargs["meta_data"] = c + kwargs.update(self.collection_args) + + collections[c["CollectionName"]] = Collection( + self.CollectionClass(**kwargs) + ) + + return collections def create_collection(self, collection_name: str, meta_data: Dict[str, Any]) -> Collection: """collection should be pre-created""" diff --git a/tests/manual_test_dynamic_loading.py b/tests/manual_test_dynamic_loading.py new file mode 100644 index 00000000..8cbfe5b8 --- /dev/null +++ b/tests/manual_test_dynamic_loading.py @@ -0,0 +1,111 @@ + +import unittest +from openviking.storage.vectordb.project.vikingdb_project import get_or_create_vikingdb_project, VikingDBProject +from openviking.storage.vectordb.collection.vikingdb_collection import VikingDBCollection + +class TestDynamicLoading(unittest.TestCase): + def test_default_loading(self): + # Test with default configuration + config = {"Host": "test_host"} + project = get_or_create_vikingdb_project(config=config) + self.assertEqual(project.CollectionClass, VikingDBCollection) + print("Default loading test passed") + + def test_explicit_loading(self): + # Test with explicit configuration pointing to MockJoiner + # MockJoiner is in tests/mock_joiner.py, so we need to make sure tests module is importable + # or use a path that python can find. + # Assuming tests package is available or we use relative import if possible, + # but dynamic loader uses importlib.import_module which needs module path. + + # We'll use the MockJoiner we just created. + # Since 'tests' might not be a package in installed environment, but here we are in source. + # We might need to adjust python path or assume tests is importable. + import sys + import os + sys.path.append(os.getcwd()) + + config = { + "Host": "test_host", + "Headers": {"Auth": "Token"}, + "CollectionClass": "tests.mock_joiner.MockJoiner", + "CollectionArgs": { + "custom_param1": "custom_val", + "custom_param2": 123 + } + } + project = get_or_create_vikingdb_project(config=config) + + from tests.mock_joiner import MockJoiner + self.assertEqual(project.CollectionClass, MockJoiner) + self.assertEqual(project.host, "test_host") + self.assertEqual(project.headers, {"Auth": "Token"}) + self.assertEqual(project.collection_args, {"custom_param1": "custom_val", "custom_param2": 123}) + + # Test collection creation to verify params are passed + collection_name = "test_collection" + meta_data = { + "test_verification": True, + "Host": "metadata_host", + "Headers": {"Meta": "Header"} + } + + # The project wrapper will pass host, headers, meta_data, AND collection_args + kwargs = { + "host": project.host, + "headers": project.headers, + "meta_data": meta_data + } + kwargs.update(project.collection_args) + + collection_instance = project.CollectionClass(**kwargs) + + # Verify custom params are set correctly + self.assertEqual(collection_instance.custom_param1, "custom_val") + self.assertEqual(collection_instance.custom_param2, 123) + + # Verify host/headers are in kwargs (since init doesn't take them explicitly anymore) + self.assertEqual(collection_instance.kwargs.get("host"), "test_host") + self.assertEqual(collection_instance.kwargs.get("headers"), {"Auth": "Token"}) + + print("Explicit loading test passed (MockJoiner with custom params)") + + def test_kwargs_loading(self): + # Test with CollectionArgs + config = { + "Host": "test_host", + "CollectionClass": "tests.mock_joiner.MockJoiner", + "CollectionArgs": { + "custom_param1": "extra_value", + "custom_param2": 456 + } + } + project = get_or_create_vikingdb_project(config=config) + + self.assertEqual(project.collection_args, {"custom_param1": "extra_value", "custom_param2": 456}) + + # Manually verify instantiation with kwargs + kwargs = { + "host": project.host, + "headers": project.headers, + "meta_data": {"test_verification": True} + } + kwargs.update(project.collection_args) + + collection_instance = project.CollectionClass(**kwargs) + self.assertEqual(collection_instance.custom_param1, "extra_value") + self.assertEqual(collection_instance.custom_param2, 456) + print("Kwargs loading test passed") + + def test_invalid_loading(self): + # Test with invalid class path + config = { + "Host": "test_host", + "CollectionClass": "non.existent.module.Class" + } + with self.assertRaises(ImportError): + get_or_create_vikingdb_project(config=config) + print("Invalid loading test passed") + +if __name__ == '__main__': + unittest.main() diff --git a/tests/mock_joiner.py b/tests/mock_joiner.py new file mode 100644 index 00000000..4c08ba8b --- /dev/null +++ b/tests/mock_joiner.py @@ -0,0 +1,150 @@ + +from typing import Any, Dict, List, Optional +from openviking.storage.vectordb.collection.collection import ICollection +from openviking.storage.vectordb.collection.result import AggregateResult, SearchResult +from openviking.storage.vectordb.index.index import IIndex + +class MockJoiner(ICollection): + def __init__(self, custom_param1: str, custom_param2: int, meta_data: Optional[Dict[str, Any]] = None, **kwargs): + super().__init__() + self.meta_data = meta_data if meta_data is not None else {} + + self.custom_param1 = custom_param1 + self.custom_param2 = custom_param2 + + # Store extra kwargs (including host/headers if passed but not used explicitly) + self.kwargs = kwargs + + # Verify that we can access values passed during initialization + if self.meta_data and "test_verification" in self.meta_data: + print(f"MockJoiner initialized with custom_param1={self.custom_param1}, custom_param2={self.custom_param2}, kwargs={kwargs}") + + def update(self, fields: Optional[Dict[str, Any]] = None, description: Optional[str] = None): + raise NotImplementedError("MockJoiner.update is not supported") + + def get_meta_data(self): + raise NotImplementedError("MockJoiner.get_meta_data is not supported") + + def close(self): + raise NotImplementedError("MockJoiner.close is not supported") + + def drop(self): + raise NotImplementedError("MockJoiner.drop is not supported") + + def create_index(self, index_name: str, meta_data: Dict[str, Any]) -> IIndex: + raise NotImplementedError("MockJoiner.create_index is not supported") + + def has_index(self, index_name: str) -> bool: + raise NotImplementedError("MockJoiner.has_index is not supported") + + def get_index(self, index_name: str) -> Optional[IIndex]: + raise NotImplementedError("MockJoiner.get_index is not supported") + + def search_by_vector( + self, + index_name: str, + dense_vector: Optional[List[float]] = None, + limit: int = 10, + offset: int = 0, + filters: Optional[Dict[str, Any]] = None, + sparse_vector: Optional[Dict[str, float]] = None, + output_fields: Optional[List[str]] = None, + ) -> SearchResult: + raise NotImplementedError("MockJoiner.search_by_vector is not supported") + + def search_by_keywords( + self, + index_name: str, + keywords: Optional[List[str]] = None, + query: Optional[str] = None, + limit: int = 10, + offset: int = 0, + filters: Optional[Dict[str, Any]] = None, + output_fields: Optional[List[str]] = None, + ) -> SearchResult: + raise NotImplementedError("MockJoiner.search_by_keywords is not supported") + + def search_by_id( + self, + index_name: str, + id: Any, + limit: int = 10, + offset: int = 0, + filters: Optional[Dict[str, Any]] = None, + output_fields: Optional[List[str]] = None, + ) -> SearchResult: + raise NotImplementedError("MockJoiner.search_by_id is not supported") + + def search_by_multimodal( + self, + index_name: str, + text: Optional[str], + image: Optional[Any], + video: Optional[Any], + limit: int = 10, + offset: int = 0, + filters: Optional[Dict[str, Any]] = None, + output_fields: Optional[List[str]] = None, + ) -> SearchResult: + raise NotImplementedError("MockJoiner.search_by_multimodal is not supported") + + def search_by_random( + self, + index_name: str, + limit: int = 10, + offset: int = 0, + filters: Optional[Dict[str, Any]] = None, + output_fields: Optional[List[str]] = None, + ) -> SearchResult: + raise NotImplementedError("MockJoiner.search_by_random is not supported") + + def search_by_scalar( + self, + index_name: str, + field: str, + order: Optional[str] = "desc", + limit: int = 10, + offset: int = 0, + filters: Optional[Dict[str, Any]] = None, + output_fields: Optional[List[str]] = None, + ) -> SearchResult: + raise NotImplementedError("MockJoiner.search_by_scalar is not supported") + + def update_index( + self, + index_name: str, + scalar_index: Optional[Dict[str, Any]] = None, + description: Optional[str] = None, + ): + raise NotImplementedError("MockJoiner.update_index is not supported") + + def get_index_meta_data(self, index_name: str): + raise NotImplementedError("MockJoiner.get_index_meta_data is not supported") + + def list_indexes(self): + raise NotImplementedError("MockJoiner.list_indexes is not supported") + + def drop_index(self, index_name: str): + raise NotImplementedError("MockJoiner.drop_index is not supported") + + def upsert_data(self, data_list: List[Dict[str, Any]], ttl=0): + raise NotImplementedError("MockJoiner.upsert_data is not supported") + + def fetch_data(self, primary_keys: List[Any]): + raise NotImplementedError("MockJoiner.fetch_data is not supported") + + def delete_data(self, primary_keys: List[Any]): + raise NotImplementedError("MockJoiner.delete_data is not supported") + + def delete_all_data(self): + raise NotImplementedError("MockJoiner.delete_all_data is not supported") + + def aggregate_data( + self, + index_name: str, + op: str = "count", + field: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None, + cond: Optional[Dict[str, Any]] = None, + ) -> AggregateResult: + raise NotImplementedError("MockJoiner.aggregate_data is not supported")