diff --git a/.github/workflows/test-cassandra.yml b/.github/workflows/test-cassandra.yml index 6b75ec86..495ceb70 100644 --- a/.github/workflows/test-cassandra.yml +++ b/.github/workflows/test-cassandra.yml @@ -1,11 +1,10 @@ name: Test Cassandra on: - workflow_run: - workflows: - - Test - types: - - completed + push: + branches: [master] + pull_request: + branches: [master] jobs: build: diff --git a/.github/workflows/test-mongo.yml b/.github/workflows/test-mongo.yml index d4c60be4..01d8cda8 100644 --- a/.github/workflows/test-mongo.yml +++ b/.github/workflows/test-mongo.yml @@ -1,11 +1,10 @@ name: Test Mongo on: - workflow_run: - workflows: - - Test - types: - - completed + push: + branches: [master] + pull_request: + branches: [master] jobs: build: diff --git a/.github/workflows/test-redis.yml b/.github/workflows/test-redis.yml index 99b4b589..ce4d5719 100644 --- a/.github/workflows/test-redis.yml +++ b/.github/workflows/test-redis.yml @@ -1,11 +1,10 @@ name: Test Redis on: - workflow_run: - workflows: - - Test - types: - - completed + push: + branches: [master] + pull_request: + branches: [master] jobs: build: diff --git a/datasketch/__init__.py b/datasketch/__init__.py index 19c0d514..d35e3aef 100644 --- a/datasketch/__init__.py +++ b/datasketch/__init__.py @@ -23,9 +23,15 @@ WeightedMinHashLSH = MinHashLSH WeightedMinHashLSHForest = MinHashLSHForest +# Optional async export (requires motor or redis.asyncio) +try: + from datasketch.aio import AsyncMinHashLSH +except ImportError: + AsyncMinHashLSH = None # type: ignore[misc,assignment] __all__ = [ "HNSW", + "AsyncMinHashLSH", "HyperLogLog", "HyperLogLogPlusPlus", "LeanMinHash", diff --git a/datasketch/aio/__init__.py b/datasketch/aio/__init__.py new file mode 100644 index 00000000..89ad439f --- /dev/null +++ b/datasketch/aio/__init__.py @@ -0,0 +1,36 @@ +"""Async MinHash LSH module. + +This module provides asynchronous implementations of MinHash LSH for use with +async storage backends like MongoDB (via motor) and Redis (via redis.asyncio). + +Example: + .. code-block:: python + + from datasketch.aio import AsyncMinHashLSH + from datasketch import MinHash + + async def main(): + async with AsyncMinHashLSH( + storage_config={"type": "aiomongo", "mongo": {"host": "localhost", "port": 27017}}, + threshold=0.5, + num_perm=128, + prepickle=True, # Enable string keys + ) as lsh: + m = MinHash(num_perm=128) + m.update(b"data") + await lsh.insert("key", m) + result = await lsh.query(m) + +""" + +from datasketch.aio.lsh import ( + AsyncMinHashLSH, + AsyncMinHashLSHDeleteSession, + AsyncMinHashLSHInsertionSession, +) + +__all__ = [ + "AsyncMinHashLSH", + "AsyncMinHashLSHDeleteSession", + "AsyncMinHashLSHInsertionSession", +] diff --git a/datasketch/aio/lsh.py b/datasketch/aio/lsh.py new file mode 100644 index 00000000..0e3d45ba --- /dev/null +++ b/datasketch/aio/lsh.py @@ -0,0 +1,382 @@ +"""Asynchronous MinHash LSH implementation. + +This module provides AsyncMinHashLSH for use with async storage backends +like MongoDB (via motor) and Redis (via redis.asyncio). +""" + +import asyncio +import pickle +from itertools import chain +from typing import Optional + +from datasketch.aio.storage import ( + async_ordered_storage, + async_unordered_storage, +) +from datasketch.lsh import _optimal_param +from datasketch.storage import _random_name, unordered_storage + + +class AsyncMinHashLSH: + """Asynchronous MinHashLSH index. + + :param float threshold: see :class:`datasketch.MinHashLSH`. + :param int num_perm: see :class:`datasketch.MinHashLSH`. 
+    :param weights: see :class:`datasketch.MinHashLSH`.
+    :type weights: tuple(float, float)
+    :param tuple params: see :class:`datasketch.MinHashLSH`.
+    :param dict storage_config: Configuration of the asynchronous storage backend used to
+        store the hashtables and keys. Supported types are ``aiomongo`` and ``aioredis``.
+        If storage_config is None, aiomongo storage is used.
+    :param bool prepickle: If True, all keys are pickled to bytes before insertion.
+        If None (the default), a value is chosen based on the ``storage_config`` type.
+    For example usage see :ref:`minhash_lsh_async`.
+
+    Example of supported storage configuration:
+
+    .. code-block:: python
+
+        MONGO = {"type": "aiomongo", "basename": "base_name_1", "mongo": {"host": "localhost", "port": 27017}}
+
+    .. note::
+        * For the main functionality of the LSH algorithm, see :class:`datasketch.MinHashLSH`.
+        * For additional information see :ref:`minhash_lsh_at_scale` and :ref:`minhash_lsh_async`.
+    """
+
+    def __init__(
+        self,
+        threshold: float = 0.9,
+        num_perm: int = 128,
+        weights: tuple[float, float] = (0.5, 0.5),
+        params: Optional[tuple[int, int]] = None,
+        storage_config: Optional[dict] = None,
+        prepickle: Optional[bool] = None,
+    ):
+        if storage_config is None:
+            storage_config = {"type": "aiomongo", "mongo": {"host": "localhost", "port": 27017}}
+        self._storage_config = storage_config.copy()
+        self._storage_config["basename"] = self._storage_config.get("basename", _random_name(11))
+        self._basename = self._storage_config["basename"]
+        self._batch_size = 10000
+        self._threshold = threshold
+        self._num_perm = num_perm
+        self._weights = weights
+        self._params = params
+        self.prepickle = storage_config["type"] == "aioredis" if prepickle is None else prepickle
+        self._require_bytes_keys = not self.prepickle
+
+        if self._threshold > 1.0 or self._threshold < 0.0:
+            raise ValueError("threshold must be in [0.0, 1.0]")
+        if self._num_perm < 2:
+            raise ValueError("Too few permutation functions")
+        if any(w < 0.0 or w > 1.0 for w in self._weights):
+            raise ValueError("Weight must be in [0.0, 1.0]")
+        if sum(self._weights) != 1.0:
+            raise ValueError("Weights must sum to 1.0")
+        self.h = self._num_perm
+        if self._params is not None:
+            self.b, self.r = self._params
+            if self.b * self.r > self._num_perm:
+                raise ValueError("The product of b and r must be less than num_perm")
+        else:
+            false_positive_weight, false_negative_weight = self._weights
+            self.b, self.r = _optimal_param(
+                self._threshold, self._num_perm, false_positive_weight, false_negative_weight
+            )
+
+        self.hashranges = [(i * self.r, (i + 1) * self.r) for i in range(self.b)]
+        self.hashtables = None
+        self.keys = None
+
+        self._lock = asyncio.Lock()
+        self._initialized = False
+
+    async def __async_init(self):
+        async with self._lock:
+            if not self._initialized:
+                await self.init_storages()
+                self._initialized = True
+        return self
+
+    def __await__(self):
+        return self.__async_init().__await__()
+
+    async def __aenter__(self):
+        return await self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.close()
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["_initialized"] = False
+        state.pop("_lock")
+        state.pop("hashranges")
+        state.pop("hashtables")
+        state.pop("keys")
+        return state
+
+    def __setstate__(self, state):
+        state["_lock"] = asyncio.Lock()
+        self.__dict__ = state
+        self.__init__(
+            self._threshold, self._num_perm, self._weights, self._params, self._storage_config, self.prepickle
+        )
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @batch_size.setter
+    def batch_size(self, value):
+        if self.keys is not None:
+            self.keys.batch_size = value
+        else:
+            raise AttributeError("AsyncMinHashLSH is not initialized.")
+
+        for t in self.hashtables:
+            t.batch_size = value
+
+        self._batch_size = value
+
+    async def _create_storages(self):
+        if self._storage_config["type"] == "aioredis":
+            name_ordered = b"".join([self._basename, b"_keys"])
+            fs = (
+                async_unordered_storage(
+                    config=self._storage_config,
+                    name=b"".join([self._basename, b"_bucket_", bytes([i])]),
+                )
+                for i in range(self.b)
+            )
+        else:
+            name_ordered = "".join([self._basename.decode("utf-8"), "_keys"])
+            fs = (
+                async_unordered_storage(
+                    config=self._storage_config, name="".join([self._basename.decode("utf-8"), "_bucket_", str(i)])
+                )
+                for i in range(self.b)
+            )
+
+        fs = chain(fs, (async_ordered_storage(self._storage_config, name=name_ordered),))
+        storages = await asyncio.gather(*fs)
+        *self.hashtables, self.keys = storages
+
+    async def init_storages(self):
+        if self.keys is None:
+            await self._create_storages()
+
+    async def close(self):
+        """Clean up client resources and disconnect from AsyncMinHashLSH storage."""
+        async with self._lock:
+            for t in self.hashtables:
+                await t.close()
+
+            if self.keys is not None:
+                await self.keys.close()
+
+            self._initialized = False
+
+    async def insert(self, key, minhash, check_duplication=True):
+        """See :class:`datasketch.MinHashLSH`."""
+        await self._insert(key, minhash, check_duplication=check_duplication, buffer=False)
+
+    def insertion_session(self, batch_size=10000):
+        """Create an asynchronous context manager for fast insertion into the index.
+
+        :param int batch_size: the size of chunks to use in insertion_session mode (default=10000).
+
+        :return: datasketch.aio.lsh.AsyncMinHashLSHInsertionSession
+
+        Example:
+            .. code-block:: python
+
+                import asyncio
+                from datasketch.aio import AsyncMinHashLSH
+                from datasketch import MinHash
+
+                async def main():
+                    storage_config = {"type": "aiomongo", "mongo": {"host": "localhost", "port": 27017}}
+                    async with AsyncMinHashLSH(
+                        storage_config=storage_config, threshold=0.5, num_perm=16, prepickle=True
+                    ) as lsh:
+                        async with lsh.insertion_session(batch_size=1000) as session:
+                            m = MinHash(num_perm=16)
+                            m.update(b"data")
+                            await session.insert("key", m)
+
+                asyncio.run(main())
+
+        """
+        return AsyncMinHashLSHInsertionSession(self, batch_size=batch_size)
+
+    def delete_session(self, batch_size=10000):
+        """Create an asynchronous context manager for fast removal of keys
+        from the index.
+
+        :param int batch_size: the size of chunks to use in delete_session mode (default=10000).
+
+        :return: datasketch.aio.lsh.AsyncMinHashLSHDeleteSession
+
+        Example:
+            .. 
code-block:: python + + import asyncio + from datasketch.aio import AsyncMinHashLSH + from datasketch import MinHash + + async def main(): + storage_config = {"type": "aiomongo", "mongo": {"host": "localhost", "port": 27017}} + async with AsyncMinHashLSH( + storage_config=storage_config, threshold=0.5, num_perm=16, prepickle=True + ) as lsh: + # Insert some data first + m = MinHash(num_perm=16) + m.update(b"data") + await lsh.insert("key1", m) + + # Delete using session + async with lsh.delete_session(batch_size=100) as session: + await session.remove("key1") + + asyncio.run(main()) + + """ + return AsyncMinHashLSHDeleteSession(self, batch_size=batch_size) + + async def _insert(self, key, minhash, check_duplication=True, buffer=False): + if len(minhash) != self.h: + raise ValueError("Expecting minhash with length %d, got %d" % (self.h, len(minhash))) + if self._require_bytes_keys and not isinstance(key, bytes): + raise TypeError( + f"prepickle=False requires bytes keys for non-dict storage, got {type(key).__name__}. " + "Either pass bytes keys or use prepickle=True for automatic serialization." + ) + if self.prepickle: + key = pickle.dumps(key) + + if check_duplication and await self.has_key(key): + raise ValueError("The given key already exists") + Hs = [self._H(minhash.hashvalues[start:end]) for start, end in self.hashranges] + + fs = chain( + (self.keys.insert(key, *Hs, buffer=buffer),), + (hashtable.insert(H, key, buffer=buffer) for H, hashtable in zip(Hs, self.hashtables)), + ) + await asyncio.gather(*fs) + + async def query(self, minhash): + """See :class:`datasketch.MinHashLSH`.""" + if len(minhash) != self.h: + raise ValueError("Expecting minhash with length %d, got %d" % (self.h, len(minhash))) + + fs = ( + hashtable.get(self._H(minhash.hashvalues[start:end])) + for (start, end), hashtable in zip(self.hashranges, self.hashtables) + ) + candidates = frozenset(chain.from_iterable(await asyncio.gather(*fs))) + if self.prepickle: + return [pickle.loads(key) for key in candidates] + return list(candidates) + + async def has_key(self, key): + """See :class:`datasketch.MinHashLSH`.""" + return await self.keys.has_key(key) + + async def remove(self, key): + """See :class:`datasketch.MinHashLSH`.""" + await self._remove(key, buffer=False) + + async def _remove(self, key, buffer=False): + if not await self.has_key(key): + raise ValueError("The given key does not exist") + + for H, hashtable in zip(await self.keys.get(key), self.hashtables): + await hashtable.remove_val(H, key, buffer=buffer) + if not await hashtable.get(H): + await hashtable.remove(H, buffer=buffer) + + await self.keys.remove(key, buffer=buffer) + + async def is_empty(self): + """See :class:`datasketch.MinHashLSH`.""" + for t in self.hashtables: + if await t.size() == 0: + return True + return False + + @staticmethod + def _H(hs): + return bytes(hs.byteswap().data) + + async def _query_b(self, minhash, b): + if len(minhash) != self.h: + raise ValueError("Expecting minhash with length %d, got %d" % (self.h, len(minhash))) + if b > len(self.hashtables): + raise ValueError("b must be less or equal to the number of hash tables") + fs = [] + for (start, end), hashtable in zip(self.hashranges[:b], self.hashtables[:b]): + H = self._H(minhash.hashvalues[start:end]) + if await hashtable.has_key(H): + fs.append(hashtable.get(H)) + return set(chain.from_iterable(await asyncio.gather(*fs))) # candidates + + async def get_counts(self): + """See :class:`datasketch.MinHashLSH`.""" + fs = (hashtable.itemcounts() for hashtable in 
self.hashtables) + return await asyncio.gather(*fs) + + async def get_subset_counts(self, *keys): + """See :class:`datasketch.MinHashLSH`.""" + key_set = list(set(keys)) + hashtables = [unordered_storage({"type": "dict"}) for _ in range(self.b)] + Hss = await self.keys.getmany(*key_set) + for key, Hs in zip(key_set, Hss): + for H, hashtable in zip(Hs, hashtables): + hashtable.insert(H, key) + return [hashtable.itemcounts() for hashtable in hashtables] + + +class AsyncMinHashLSHInsertionSession: + """Context manager for batch insertion.""" + + def __init__(self, lsh: AsyncMinHashLSH, batch_size: int): + self.lsh = lsh + self.lsh.batch_size = batch_size + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def close(self): + fs = chain((self.lsh.keys.empty_buffer(),), (hashtable.empty_buffer() for hashtable in self.lsh.hashtables)) + await asyncio.gather(*fs) + + async def insert(self, key, minhash, check_duplication=True): + """See :class:`datasketch.MinHashLSH`.""" + await self.lsh._insert(key, minhash, check_duplication=check_duplication, buffer=True) + + +class AsyncMinHashLSHDeleteSession: + """Context manager for batch removal of keys.""" + + def __init__(self, lsh: AsyncMinHashLSH, batch_size: int): + self.lsh = lsh + self.lsh.batch_size = batch_size + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def close(self): + fs = chain((self.lsh.keys.empty_buffer(),), (hashtable.empty_buffer() for hashtable in self.lsh.hashtables)) + await asyncio.gather(*fs) + + async def remove(self, key): + """Remove key from LSH index.""" + await self.lsh._remove(key, buffer=True) diff --git a/datasketch/aio/storage.py b/datasketch/aio/storage.py new file mode 100644 index 00000000..3b42da72 --- /dev/null +++ b/datasketch/aio/storage.py @@ -0,0 +1,466 @@ +"""Async storage backends for MinHash LSH. + +This module provides async storage implementations for use with AsyncMinHashLSH: +- AsyncMongoListStorage / AsyncMongoSetStorage: MongoDB storage via motor +- AsyncRedisListStorage / AsyncRedisSetStorage: Redis storage via redis.asyncio +""" + +import asyncio +import os +from abc import ABCMeta +from itertools import chain + +from datasketch.storage import OrderedStorage, Storage, UnorderedStorage, _random_name + +# RedisStorage is only available when redis package is installed (optional dependency) +# Import it conditionally to avoid ImportError when redis is not installed +try: + from datasketch.storage import RedisStorage +except ImportError: + RedisStorage = None + +ABC = ABCMeta("ABC", (object,), {}) + +try: + import motor.motor_asyncio + from pymongo import ReturnDocument +except ImportError: + motor = None + ReturnDocument = None + +try: + import redis + + if redis.__version__ < "4.2.0rc1": + raise ImportError("Can't use AsyncMinHashLSH module. 
Redis version should be >=4.2.0rc1") + import redis.asyncio as redis +except ImportError: + redis = None + + +__all__ = [ + "async_ordered_storage", + "async_unordered_storage", +] + + +async def async_ordered_storage(config, name=None): + tp = config["type"] + if tp == "aiomongo": + if motor is None: + raise RuntimeError("motor is not installed") + return AsyncMongoListStorage(config, name=name) + if tp == "aioredis": + if redis is None: + raise RuntimeError("redis is not installed") + return AsyncRedisListStorage(config, name=name) + raise ValueError('Unknown config ["type"]') + + +async def async_unordered_storage(config, name=None): + tp = config["type"] + if tp == "aiomongo": + if motor is None: + raise RuntimeError("motor is not installed") + return AsyncMongoSetStorage(config, name=name) + if tp == "aioredis": + if redis is None: + raise RuntimeError("redis is not installed") + return AsyncRedisSetStorage(config, name=name) + raise ValueError('Unknown config ["type"]') + + +if motor is not None and ReturnDocument is not None: + + class AsyncMongoBuffer: + def __init__(self, aio_mongo_collection, batch_size): + self._batch_size = batch_size + self._insert_documents_stack = [] + self._delete_by_key_documents_stack = [] + self._delete_by_val_documents_stack = [] + self._mongo_coll = aio_mongo_collection + + @property + def batch_size(self): + return self._batch_size + + @batch_size.setter + def batch_size(self, value): + self._batch_size = value + + async def execute_command(self, **kwargs): + command = kwargs.pop("command") + if command == "insert": + if len(self._insert_documents_stack) >= self.batch_size: + await self.execute(command) + self._insert_documents_stack.append(kwargs["obj"]) + elif command == "delete_by_key": + if len(self._delete_by_key_documents_stack) >= self.batch_size: + await self.execute(command) + self._delete_by_key_documents_stack.append(kwargs["key"]) + elif command == "delete_by_val": + if len(self._delete_by_val_documents_stack) >= self.batch_size: + await self.execute(command) + self._delete_by_val_documents_stack.append(kwargs["val"]) + + async def execute(self, command): + if command == "insert" and self._insert_documents_stack: + buffer = self._insert_documents_stack + self._insert_documents_stack = [] + await self._mongo_coll.insert_many(buffer, ordered=False) + elif command == "delete_by_key" and self._delete_by_key_documents_stack: + buffer = self._delete_by_key_documents_stack + self._delete_by_key_documents_stack = [] + await self._mongo_coll.delete_many({"key": {"$in": buffer}}) + elif command == "delete_by_val" and self._delete_by_val_documents_stack: + buffer = self._delete_by_val_documents_stack + self._delete_by_val_documents_stack = [] + await self._mongo_coll.delete_many({"vals": {"$in": buffer}}) + + async def insert_one(self, **kwargs): + await self.execute_command(obj=kwargs["document"], command="insert") + + async def delete_many_by_key(self, **kwargs): + await self.execute_command(key=kwargs["key"], command="delete_by_key") + + async def delete_many_by_val(self, **kwargs): + await self.execute_command(val=kwargs["val"], command="delete_by_val") + + class AsyncMongoStorage: + """Base class for asynchronous MongoDB-based storage containers. 
+ + :param dict config: MongoDB storage units require a configuration + of the form:: + + storage_config = {"type": "aiomongo", "mongo": {"host": "localhost", "port": 27017}} + + one can refer to system environment variables via:: + + storage_config={ + 'type': 'aiomongo', + 'mongo': { + 'host': {'env': 'MONGO_HOSTNAME', + 'default':'localhost'}, + 'port': 27017} + } + } + + :param bytes name: see :class:`datasketch.storage.RedisStorage` (default = None). + """ + + def __init__(self, config, name=None): + if config["type"] != "aiomongo": + raise ValueError("Storage type <{}> not supported".format(config["type"])) + self._config = config + self._mongo_param = self._parse_config(self._config["mongo"]) + + self._name = name if name else _random_name(11).decode("utf-8") + if "collection_name" in self.mongo_param: + self._collection_name = self.mongo_param["collection_name"] + elif "collection_prefix" in self.mongo_param: + self._collection_name = self.mongo_param["collection_prefix"] + self._name + else: + self._collection_name = "lsh_" + self._name + + db_lsh = self.mongo_param.get("db", "db_0") + if "url" in self.mongo_param: + dsn = self.mongo_param["url"] + elif "replica_set" in self.mongo_param: + dsn = "mongodb://{replica_set_nodes}/?replicaSet={replica_set}".format(**self.mongo_param) + elif "username" in self.mongo_param or "password" in self.mongo_param: + dsn = "mongodb://{username}:{password}@{host}:{port}".format(**self.mongo_param) + else: + dsn = "mongodb://{host}:{port}".format(**self.mongo_param) + + additional_args = self.mongo_param.get("args", {}) + + self._batch_size = 1000 + self._mongo_client = motor.motor_asyncio.AsyncIOMotorClient(dsn, **additional_args) + self._collection = self._mongo_client.get_default_database(db_lsh).get_collection(self._collection_name) + self._collection.create_index("key", background=True) + + self._initialized = True + self._buffer = AsyncMongoBuffer(self._collection, self._batch_size) + + async def close(self): + fs = (self._buffer.execute(command) for command in ("insert", "delete_by_key", "delete_by_val")) + await asyncio.gather(*fs) + self._mongo_client.close() + + @property + def batch_size(self): + return self._batch_size + + @batch_size.setter + def batch_size(self, value): + self._batch_size = value + self._buffer.batch_size = value + + @property + def initialized(self): + return self._initialized + + @property + def mongo_param(self): + return self._mongo_param + + @staticmethod + def _parse_config(config): + cfg = {} + for key, value in config.items(): + if isinstance(value, dict) and "env" in value: + value = os.getenv(value["env"], value.get("default", None)) + cfg[key] = value + return cfg + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_mongo_client") + state.pop("_collection") + state.pop("_buffer") + state["_initialized"] = False + return state + + def __setstate__(self, state): + self.__dict__ = state + self.__init__(self._config, name=self._name) + + class AsyncMongoListStorage(OrderedStorage, AsyncMongoStorage): + async def keys(self): + return [doc["key"] async for doc in self._collection.find(projection={"_id": False, "vals": False})] + + async def get(self, key: str): + return list( + chain.from_iterable( + [ + doc["vals"] + async for doc in self._collection.find( + filter={"key": key}, projection={"_id": False, "key": False} + ) + ] + ) + ) + + async def insert(self, key, *vals, **kwargs): + buffer = kwargs.pop("buffer", False) + if buffer: + await self._insert(self._buffer, key, *vals) + else: + 
await self._insert(self._collection, key, *vals) + + async def _insert(self, obj, key, *values): + await obj.insert_one(document={"key": key, "vals": values}) + + async def remove(self, *keys, **kwargs): + buffer = kwargs.pop("buffer", False) + if buffer: + fs = (self._buffer.delete_many_by_key(key=key) for key in keys) + await asyncio.gather(*fs) + else: + await self._collection.delete_many({"key": {"$in": keys}}) + + async def remove_val(self, key, val, **kwargs): + pass + + async def size(self): + return await self._collection.count_documents({}) + + async def itemcounts(self): + return { + doc["_id"]: doc["count"] + async for doc in self._collection.aggregate([{"$group": {"_id": "$key", "count": {"$sum": 1}}}]) + } + + async def has_key(self, key): + return bool(await self._collection.find_one({"key": key})) + + async def getmany(self, *keys): + return [await self.get(key) for key in keys] + + async def status(self): + status = self._parse_config(self.config["mongo"]) + status.update({"keyspace_size": await self.size()}) + return status + + async def empty_buffer(self): + fs = (self._buffer.execute(command) for command in ("insert", "delete_by_key", "delete_by_val")) + await asyncio.gather(*fs) + + class AsyncMongoSetStorage(UnorderedStorage, AsyncMongoListStorage): + async def get(self, key): + keys = [ + doc["vals"] + async for doc in self._collection.find(filter={"key": key}, projection={"_id": False, "key": False}) + ] + return frozenset(keys) + + async def _insert(self, obj, key, *values): + await obj.insert_one(document={"key": key, "vals": values[0]}) + + async def remove(self, *keys, **kwargs): + pass + + async def remove_val(self, key, val, **kwargs): + buffer = kwargs.pop("buffer", False) + if buffer: + await self._buffer.delete_many_by_val(val=val) + else: + await self._collection.find_one_and_delete({"key": key, "vals": val}) + + +# Redis-based async storage classes are only defined when both redis package +# and RedisStorage are available (optional dependencies) +if redis is not None and RedisStorage is not None: + + class AsyncRedisBuffer(redis.client.Pipeline): + def __init__(self, connection_pool, response_callbacks, transaction, buffer_size, shard_hint=None): + self._buffer_size = buffer_size + super(AsyncRedisBuffer, self).__init__( + connection_pool, response_callbacks, transaction, shard_hint=shard_hint + ) + + @property + def buffer_size(self): + return self._buffer_size + + @buffer_size.setter + def buffer_size(self, value): + self._buffer_size = value + + async def execute_command(self, *args, **kwargs): + if len(self.command_stack) >= self._buffer_size: + await self.execute() + await super(AsyncRedisBuffer, self).execute_command(*args, **kwargs) + + class AsyncRedisStorage(RedisStorage): + def __init__(self, config, name=None): + super(AsyncRedisStorage, self).__init__(config, name) + self.config = config + self._buffer_size = 50000 + redis_param = self._parse_config(self.config["redis"]) + self._redis = redis.Redis(**redis_param) + redis_buffer_param = self._parse_config(self.config.get("redis_buffer", {})) + self._buffer = AsyncRedisBuffer( + self._redis.connection_pool, + self._redis.response_callbacks, + transaction=redis_buffer_param.get("transaction", True), + buffer_size=self._buffer_size, + ) + self._initialized = True + + async def close(self): + await self._redis.aclose() + + @property + def initialized(self): + return self._initialized + + class AsyncRedisListStorage(OrderedStorage, AsyncRedisStorage): + async def keys(self): + return await 
self._redis.hkeys(self._name) # type: ignore + + async def redis_keys(self): + return await self._redis.hvals(self._name) # type: ignore + + def status(self): + status = self._parse_config(self.config["redis"]) + status.update(Storage.status(self)) + return status + + async def get(self, key): + return await self._get_items(self._redis, self.redis_key(key)) + + async def getmany(self, *keys): + pipe = self._redis.pipeline() + for key in keys: + pipe.lrange(self.redis_key(key), 0, -1) + return await pipe.execute() + + @staticmethod + async def _get_items(r, k): + return await r.lrange(k, 0, -1) + + async def remove(self, *keys, **kwargs): + buffer = kwargs.pop("buffer", False) + if buffer: + await self._remove(self._buffer, *keys) + else: + await self._remove(self._redis, *keys) + + async def _remove(self, r, *keys): + await r.hdel(self._name, *keys) + await r.delete(*[self.redis_key(key) for key in keys]) + + async def remove_val(self, key, val, **kwargs): + buffer = kwargs.pop("buffer", False) + redis_key = self.redis_key(key) + if buffer: + await self._buffer.lrem(redis_key, val) + else: + await self._redis.lrem(redis_key, val) + if not await self._redis.exists(redis_key): # type: ignore + await self._redis.hdel(self._name, redis_key) # type: ignore + + async def insert(self, key, *vals, **kwargs): + # Using buffer=True outside of an `insertion_session` + # could lead to inconsistencies, because those + # insertion will not be processed until the + # buffer is cleared + buffer = kwargs.pop("buffer", False) + if buffer: + await self._insert(self._buffer, key, *vals) + else: + await self._insert(self._redis, key, *vals) + + async def _insert(self, r, key, *values): + redis_key = self.redis_key(key) + await r.hset(self._name, key, redis_key) + await r.rpush(redis_key, *values) + + async def size(self): + return await self._redis.hlen(self._name) # type: ignore + + async def itemcounts(self): + pipe = self._redis.pipeline() + pipe.multi() + ks = await self.keys() + for k in ks: + await self._get_len(pipe, self.redis_key(k)) + return dict(zip(ks, await pipe.execute())) + + @staticmethod + async def _get_len(r, k): + return await r.llen(k) + + async def has_key(self, key): + return await self._redis.hexists(self._name, key) # type: ignore + + async def empty_buffer(self): + await self._buffer.execute() + # To avoid broken pipes, recreate the connection + # objects upon emptying the buffer + self.__init__(self.config, name=self._name) + + class AsyncRedisSetStorage(UnorderedStorage, AsyncRedisListStorage): + @staticmethod + async def _get_items(r, k): + return await r.smembers(k) + + async def remove_val(self, key, val, **kwargs): + buffer = kwargs.pop("buffer", False) + redis_key = self.redis_key(key) + if buffer: + await self._buffer.srem(redis_key, val) + else: + await self._redis.srem(redis_key, val) + if not await self._redis.exists(redis_key): # type: ignore + await self._redis.hdel(self._name, redis_key) # type: ignore + + async def _insert(self, r, key, *values): + redis_key = self.redis_key(key) + await r.hset(self._name, key, redis_key) + await r.sadd(redis_key, *values) + + @staticmethod + async def _get_len(r, k): + return await r.scard(k) diff --git a/datasketch/experimental/__init__.py b/datasketch/experimental/__init__.py index f029f482..53a26279 100644 --- a/datasketch/experimental/__init__.py +++ b/datasketch/experimental/__init__.py @@ -1,15 +1,26 @@ -"""Warning. +"""Deprecated experimental module. 
-datasketch.experimental is dedicated to new modules that are to be merged into -the stable interface of datasketch. So their interfaces may change in future -versions. +.. deprecated:: + The `datasketch.experimental` module is deprecated and will be removed in a future version. + Please use `datasketch.aio` instead: -To add a new class or function, register it here in this file. For example: - -from new_module import NewModuleClass + Old: ``from datasketch.experimental import AsyncMinHashLSH`` + New: ``from datasketch.aio import AsyncMinHashLSH`` + Or simply: ``from datasketch import AsyncMinHashLSH`` """ -from datasketch.experimental.aio.lsh import AsyncMinHashLSH +import warnings + +warnings.warn( + "datasketch.experimental is deprecated. " + "Use 'from datasketch.aio import AsyncMinHashLSH' or " + "'from datasketch import AsyncMinHashLSH' instead.", + DeprecationWarning, + stacklevel=2, +) + +# Re-export from new location for backward compatibility +from datasketch.aio import AsyncMinHashLSH # noqa: E402 __all__ = ["AsyncMinHashLSH"] diff --git a/datasketch/experimental/aio/__init__.py b/datasketch/experimental/aio/__init__.py index e69de29b..2aeacdcf 100644 --- a/datasketch/experimental/aio/__init__.py +++ b/datasketch/experimental/aio/__init__.py @@ -0,0 +1,31 @@ +"""Deprecated experimental aio module. + +.. deprecated:: + The `datasketch.experimental.aio` module is deprecated and will be removed in a future version. + Please use `datasketch.aio` instead: + + Old: ``from datasketch.experimental.aio import AsyncMinHashLSH`` + New: ``from datasketch.aio import AsyncMinHashLSH`` +""" + +import warnings + +warnings.warn( + "datasketch.experimental.aio is deprecated. " + "Use 'from datasketch.aio import AsyncMinHashLSH' instead.", + DeprecationWarning, + stacklevel=2, +) + +# Re-export from new location for backward compatibility +from datasketch.aio import ( # noqa: E402 + AsyncMinHashLSH, + AsyncMinHashLSHDeleteSession, + AsyncMinHashLSHInsertionSession, +) + +__all__ = [ + "AsyncMinHashLSH", + "AsyncMinHashLSHDeleteSession", + "AsyncMinHashLSHInsertionSession", +] diff --git a/datasketch/experimental/aio/lsh.py b/datasketch/experimental/aio/lsh.py index ebeeb2f1..442e946f 100644 --- a/datasketch/experimental/aio/lsh.py +++ b/datasketch/experimental/aio/lsh.py @@ -1,414 +1,31 @@ -import asyncio -import pickle -from itertools import chain -from typing import Optional - -from datasketch.experimental.aio.storage import ( - async_ordered_storage, - async_unordered_storage, -) -from datasketch.lsh import _optimal_param -from datasketch.storage import _random_name, unordered_storage - - -class AsyncMinHashLSH: - """Asynchronous MinHashLSH index. - - :param float threshold: see :class:`datasketch.MinHashLSH`. - :param int num_perm: see :class:`datasketch.MinHashLSH`. - :param weights: see :class:`datasketch.MinHashLSH`. - :type weights: tuple(float, float) - :param tuple params: see :class:`datasketch.MinHashLSH`. - :param dict storage_config: New type of storage service - aiomongo - to use for storing - hashtables and keys are implemented. - If storage_config is None aiomongo storage will be used. - :param prepickle (bool, optional): If True, all keys are pickled to bytes before - insertion. If None, a default value is chosen based on the - `storage_config`. - For example usage see :ref:`minhash_lsh_async`. - - Example of supported storage configuration: - - .. code-block:: python - - MONGO = {"type": "aiomongo", "basename": "base_name_1", "mongo": {"host": "localhost", "port": 27017}} - - .. 
note:: - * The module supports Python version >=3.6, and is currently experimental. - So the interface may change slightly in the future. - * For main functionality of LSH algorithm see :class:`datasketch.MinHashLSH`. - * For additional information see :ref:`minhash_lsh_at_scale` and :ref:`minhash_lsh_async` - """ - - def __init__( - self, - threshold: float = 0.9, - num_perm: int = 128, - weights: tuple[float, float] = (0.5, 0.5), - params: Optional[tuple[int, int]] = None, - storage_config: Optional[dict] = None, - prepickle: Optional[bool] = None, - ): - if storage_config is None: - storage_config = {"type": "aiomongo", "mongo": {"host": "localhost", "port": 27017}} - self._storage_config = storage_config.copy() - self._storage_config["basename"] = self._storage_config.get("basename", _random_name(11)) - self._basename = self._storage_config["basename"] - self._batch_size = 10000 - self._threshold = threshold - self._num_perm = num_perm - self._weights = weights - self._params = params - self.prepickle = storage_config["type"] == "aioredis" if prepickle is None else prepickle - self._require_bytes_keys = not self.prepickle - - if self._threshold > 1.0 or self._threshold < 0.0: - raise ValueError("threshold must be in [0.0, 1.0]") - if self._num_perm < 2: - raise ValueError("Too few permutation functions") - if any(w < 0.0 or w > 1.0 for w in self._weights): - raise ValueError("Weight must be in [0.0, 1.0]") - if sum(self._weights) != 1.0: - raise ValueError("Weights must sum to 1.0") - self.h = self._num_perm - if self._params is not None: - self.b, self.r = self._params - if self.b * self.r > self._num_perm: - raise ValueError("The product of b and r must be less than num_perm") - else: - false_positive_weight, false_negative_weight = self._weights - self.b, self.r = _optimal_param( - self._threshold, self._num_perm, false_positive_weight, false_negative_weight - ) - - self.hashranges = [(i * self.r, (i + 1) * self.r) for i in range(self.b)] - self.hashtables = None - self.keys = None - - self._lock = asyncio.Lock() - self._initialized = False - - async def __async_init(self): - async with self._lock: - if not self._initialized: - await self.init_storages() - self._initialized = True - return self - - def __await__(self): - return self.__async_init().__await__() - - async def __aenter__(self): - return await self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - def __getstate__(self): - state = self.__dict__.copy() - state["_initialized"] = False - state.pop("_lock") - state.pop("hashranges") - state.pop("hashtables") - state.pop("keys") - return state - - def __setstate__(self, state): - state["_lock"] = asyncio.Lock() - self.__dict__ = state - self.__init__( - self._threshold, self._num_perm, self._weights, self._params, self._storage_config, self.prepickle - ) - - @property - def batch_size(self): - return self._batch_size - - @batch_size.setter - def batch_size(self, value): - if self.keys is not None: - self.keys.batch_size = value - else: - raise AttributeError("AsyncMinHash is not initialized.") - - for t in self.hashtables: - t.batch_size = value - - self._batch_size = value - - async def _create_storages(self): - if self._storage_config["type"] == "aioredis": - name_ordered = b"".join([self._basename, b"_keys"]) - fs = ( - async_unordered_storage( - config=self._storage_config, - name=b"".join([self._basename, b"_bucket_", bytes([i])]), - ) - for i in range(self.b) - ) - else: - name_ordered = "".join([self._basename.decode("utf-8"), "_keys"]) 
- fs = ( - async_unordered_storage( - config=self._storage_config, name="".join([self._basename.decode("utf-8"), "_bucket_", str(i)]) - ) - for i in range(self.b) - ) - - fs = chain(fs, (async_ordered_storage(self._storage_config, name=name_ordered),)) - storages = await asyncio.gather(*fs) - *self.hashtables, self.keys = storages - - async def init_storages(self): - if self.keys is None: - await self._create_storages() - - if not self.keys.initialized: - await self.keys - - fs = (ht for ht in self.hashtables if not ht.initialized) - await asyncio.gather(*fs) - - async def close(self): - """Cleanup client resources and disconnect from AsyncMinHashLSH storage.""" - async with self._lock: - for t in self.hashtables: - await t.close() - - if self.keys is not None: - await self.keys.close() - - self._initialized = False - - async def insert(self, key, minhash, check_duplication=True): - """See :class:`datasketch.MinHashLSH`.""" - await self._insert(key, minhash, check_duplication=check_duplication, buffer=False) - - def insertion_session(self, batch_size=10000): - """Create a asynchronous context manager for fast insertion in index. - - :param int batch_size: the size of chunks to use in insert_session mode (default=10000). - - :return: datasketch.experimental.aio.lsh.AsyncMinHashLSHSession - - Example: - .. code-block:: python - - from datasketch.experimental.aio.lsh import AsyncMinHashLSH - from datasketch import MinHash - - - def chunk(it, size): - it = iter(it) - return iter(lambda: tuple(islice(it, size)), ()) - - - _chunked_str = chunk((random.choice(string.ascii_lowercase) for _ in range(10000)), 4) - seq = frozenset( - chain( - ("".join(s) for s in _chunked_str), - ("aahhb", "aahh", "aahhc", "aac", "kld", "bhg", "kkd", "yow", "ppi", "eer"), - ) - ) - objs = [MinHash(16) for _ in range(len(seq))] - for e, obj in zip(seq, objs): - for i in e: - obj.update(i.encode("utf-8")) - data = [(e, m) for e, m in zip(seq, objs)] - - _storage_config_redis = {"type": "aiomongo", "mongo": {"host": "localhost", "port": 27017}} +"""Deprecated experimental aio lsh module. +.. deprecated:: + The `datasketch.experimental.aio.lsh` module is deprecated and will be removed in a future version. + Please use `datasketch.aio.lsh` instead: - async def func(): - async with AsyncMinHashLSH(storage_config=_storage_config_redis, threshold=0.5, num_perm=16) as lsh: - async with lsh.insertion_session(batch_size=1000) as session: - fs = (session.insert(key, minhash, check_duplication=True) for key, minhash in data) - await asyncio.gather(*fs) + Old: ``from datasketch.experimental.aio.lsh import AsyncMinHashLSH`` + New: ``from datasketch.aio import AsyncMinHashLSH`` +""" - """ - return AsyncMinHashLSHInsertionSession(self, batch_size=batch_size) +import warnings - def delete_session(self, batch_size=10000): - """Create a asynchronous context manager for fast removal of keys - from index. - - :param int batch_size: the size of chunks to use in insert_session mode (default=10000). - - :return: datasketch.experimental.aio.lsh.AsyncMinHashLSHSession - - Example: - .. 
code-block:: python - - from datasketch.experimental.aio.lsh import AsyncMinHashLSH - from datasketch import MinHash - - - def chunk(it, size): - it = iter(it) - return iter(lambda: tuple(islice(it, size)), ()) - - - _chunked_str = chunk((random.choice(string.ascii_lowercase) for _ in range(10000)), 4) - seq = frozenset( - chain( - ("".join(s) for s in _chunked_str), - ("aahhb", "aahh", "aahhc", "aac", "kld", "bhg", "kkd", "yow", "ppi", "eer"), - ) - ) - objs = [MinHash(16) for _ in range(len(seq))] - for e, obj in zip(seq, objs): - for i in e: - obj.update(i.encode("utf-8")) - data = [(e, m) for e, m in zip(seq, objs)] - - _storage_config_redis = {"type": "aiomongo", "mongo": {"host": "localhost", "port": 27017}} - - - async def func(): - async with AsyncMinHashLSH(storage_config=_storage_config_redis, threshold=0.5, num_perm=16) as lsh: - async with lsh.insertion_session(batch_size=1000) as session: - fs = (session.insert(key, minhash, check_duplication=True) for key, minhash in data) - await asyncio.gather(*fs) - - async with lsh.delete_session(batch_size=3) as session: - fs = (session.remove(key) for key in keys_to_remove) - await asyncio.gather(*fs) - - """ - return AsyncMinHashLSHDeleteSession(self, batch_size=batch_size) - - async def _insert(self, key, minhash, check_duplication=True, buffer=False): - if len(minhash) != self.h: - raise ValueError("Expecting minhash with length %d, got %d" % (self.h, len(minhash))) - if self._require_bytes_keys and not isinstance(key, bytes): - raise TypeError( - f"prepickle=False requires bytes keys for non-dict storage, got {type(key).__name__}. " - "Either pass bytes keys or use prepickle=True for automatic serialization." - ) - if self.prepickle: - key = pickle.dumps(key) - - if check_duplication and await self.has_key(key): - raise ValueError("The given key already exists") - Hs = [self._H(minhash.hashvalues[start:end]) for start, end in self.hashranges] - - fs = chain( - (self.keys.insert(key, *Hs, buffer=buffer),), - (hashtable.insert(H, key, buffer=buffer) for H, hashtable in zip(Hs, self.hashtables)), - ) - await asyncio.gather(*fs) - - async def query(self, minhash): - """See :class:`datasketch.MinHashLSH`.""" - if len(minhash) != self.h: - raise ValueError("Expecting minhash with length %d, got %d" % (self.h, len(minhash))) - - fs = ( - hashtable.get(self._H(minhash.hashvalues[start:end])) - for (start, end), hashtable in zip(self.hashranges, self.hashtables) - ) - candidates = frozenset(chain.from_iterable(await asyncio.gather(*fs))) - if self.prepickle: - return [pickle.loads(key) for key in candidates] - return list(candidates) - - async def has_key(self, key): - """See :class:`datasketch.MinHashLSH`.""" - return await self.keys.has_key(key) - - async def remove(self, key): - """See :class:`datasketch.MinHashLSH`.""" - await self._remove(key, buffer=False) - - async def _remove(self, key, buffer=False): - if not await self.has_key(key): - raise ValueError("The given key does not exist") - - for H, hashtable in zip(await self.keys.get(key), self.hashtables): - await hashtable.remove_val(H, key, buffer=buffer) - if not await hashtable.get(H): - await hashtable.remove(H, buffer=buffer) - - await self.keys.remove(key, buffer=buffer) - - async def is_empty(self): - """See :class:`datasketch.MinHashLSH`.""" - for t in self.hashtables: - if await t.size() == 0: - return True - return False - - @staticmethod - def _H(hs): - return bytes(hs.byteswap().data) - - async def _query_b(self, minhash, b): - if len(minhash) != self.h: - raise 
ValueError("Expecting minhash with length %d, got %d" % (self.h, len(minhash))) - if b > len(self.hashtables): - raise ValueError("b must be less or equal to the number of hash tables") - fs = [] - for (start, end), hashtable in zip(self.hashranges[:b], self.hashtables[:b]): - H = self._H(minhash.hashvalues[start:end]) - if await hashtable.has_key(H): - fs.append(hashtable.get(H)) - return set(chain.from_iterable(await asyncio.gather(*fs))) # candidates - - async def get_counts(self): - """See :class:`datasketch.MinHashLSH`.""" - fs = (hashtable.itemcounts() for hashtable in self.hashtables) - return await asyncio.gather(*fs) - - async def get_subset_counts(self, *keys): - """See :class:`datasketch.MinHashLSH`.""" - key_set = list(set(keys)) - hashtables = [unordered_storage({"type": "dict"}) for _ in range(self.b)] - Hss = await self.keys.getmany(*key_set) - for key, Hs in zip(key_set, Hss): - for H, hashtable in zip(Hs, hashtables): - hashtable.insert(H, key) - return [hashtable.itemcounts() for hashtable in hashtables] - - -class AsyncMinHashLSHInsertionSession: - """Context manager for batch insertion.""" - - def __init__(self, lsh: AsyncMinHashLSH, batch_size: int): - self.lsh = lsh - self.lsh.batch_size = batch_size - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - async def close(self): - fs = chain((self.lsh.keys.empty_buffer(),), (hashtable.empty_buffer() for hashtable in self.lsh.hashtables)) - await asyncio.gather(*fs) - - async def insert(self, key, minhash, check_duplication=True): - """See :class:`datasketch.MinHashLSH`.""" - await self.lsh._insert(key, minhash, check_duplication=check_duplication, buffer=True) - - -class AsyncMinHashLSHDeleteSession: - """Context manager for batch removal of keys.""" - - def __init__(self, lsh: AsyncMinHashLSH, batch_size: int): - self.lsh = lsh - self.lsh.batch_size = batch_size - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() +warnings.warn( + "datasketch.experimental.aio.lsh is deprecated. " + "Use 'from datasketch.aio import AsyncMinHashLSH' instead.", + DeprecationWarning, + stacklevel=2, +) - async def close(self): - fs = chain((self.lsh.keys.empty_buffer(),), (hashtable.empty_buffer() for hashtable in self.lsh.hashtables)) - await asyncio.gather(*fs) +# Re-export from new location for backward compatibility +from datasketch.aio.lsh import ( # noqa: E402 + AsyncMinHashLSH, + AsyncMinHashLSHDeleteSession, + AsyncMinHashLSHInsertionSession, +) - async def remove(self, key): - """Remove key from LSH index.""" - await self.lsh._remove(key, buffer=True) +__all__ = [ + "AsyncMinHashLSH", + "AsyncMinHashLSHDeleteSession", + "AsyncMinHashLSHInsertionSession", +] diff --git a/docs/documentation.rst b/docs/documentation.rst index 202b4f01..869c11e6 100644 --- a/docs/documentation.rst +++ b/docs/documentation.rst @@ -38,7 +38,7 @@ MinHash LSH Asynchronous MinHash LSH ------------------------ -.. autoclass:: datasketch.experimental.aio.lsh.AsyncMinHashLSH +.. autoclass:: datasketch.aio.AsyncMinHashLSH :members: :special-members: diff --git a/docs/lsh.rst b/docs/lsh.rst index 7685213a..4fc827a5 100644 --- a/docs/lsh.rst +++ b/docs/lsh.rst @@ -237,8 +237,7 @@ Asynchronous MinHash LSH at scale --------------------------------- .. note:: - The module supports Python version >=3.6, and is currently experimental. - So the interface may change slightly in the future. 
+ The module supports Python version >=3.6. This module may be useful if you want to process millions of text documents in streaming/batch mode using asynchronous RESTful API (for example, aiohttp) @@ -256,7 +255,7 @@ The Asynchronous MongoDB storage option can be configured using: .. code:: python - from datasketch.experimental.aio.lsh import AsyncMinHashLSH + from datasketch.aio import AsyncMinHashLSH from datasketch import MinHash _storage = {'type': 'aiomongo', 'mongo': {'host': 'localhost', 'port': 27017, 'db': 'lsh_test'}} @@ -277,7 +276,7 @@ The Asynchronous MongoDB storage option can be configured using: .. code:: python - from datasketch.experimental.aio.lsh import AsyncMinHashLSH + from datasketch.aio import AsyncMinHashLSH from datasketch import MinHash _storage = {'type': 'aiomongo', 'mongo': {'host': 'localhost', 'port': 27017, 'db': 'lsh_test'}} @@ -329,7 +328,7 @@ To create index for a large number of MinHashes using asynchronous MinHash LSH. .. code:: python - from datasketch.experimental.aio.lsh import AsyncMinHashLSH + from datasketch.aio import AsyncMinHashLSH from datasketch import MinHash def chunk(it, size): @@ -355,7 +354,7 @@ To bulk remove keys from LSH index using asynchronous MinHash LSH. .. code:: python - from datasketch.experimental.aio.lsh import AsyncMinHashLSH + from datasketch.aio import AsyncMinHashLSH from datasketch import MinHash def chunk(it, size): diff --git a/pyproject.toml b/pyproject.toml index 29487e1c..296d02c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,8 @@ test = [ "pytest-rerunfailures", "pytest-asyncio", ] -experimental_aio = ["aiounittest", "motor>3.6.0"] +aio = ["aiounittest", "motor>3.6.0"] +experimental_aio = ["aiounittest", "motor>3.6.0"] # Deprecated alias for 'aio' [project.urls] Homepage = "https://ekzhu.github.io/datasketch" @@ -194,4 +195,4 @@ reportCallIssue = "none" [tool.coverage.run] source = ["datasketch"] -omit = ["*/experimental/*", "*/tests/*", "*/test/*"] +omit = ["*/tests/*", "*/test/*"] diff --git a/test/aio/test_lsh.py b/test/aio/test_lsh.py index 33c0b2a3..2dbe22be 100644 --- a/test/aio/test_lsh.py +++ b/test/aio/test_lsh.py @@ -10,7 +10,7 @@ from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient -from datasketch.experimental.aio.lsh import AsyncMinHashLSH +from datasketch.aio import AsyncMinHashLSH from datasketch.minhash import MinHash from datasketch.weighted_minhash import WeightedMinHashGenerator @@ -356,6 +356,37 @@ async def test_get_counts(self, storage_config): for table in counts: assert sum(table.values()) == 2 + async def test_get_subset_counts(self, storage_config): + """Test get_subset_counts which uses the getmany() method.""" + async with AsyncMinHashLSH(storage_config=storage_config, threshold=0.5, num_perm=16, prepickle=False) as lsh: + m1 = MinHash(16) + m1.update(b"a") + m2 = MinHash(16) + m2.update(b"b") + m3 = MinHash(16) + m3.update(b"c") + await lsh.insert(b"a", m1) + await lsh.insert(b"b", m2) + await lsh.insert(b"c", m3) + + # Test get_subset_counts with a subset of keys + subset_counts = await lsh.get_subset_counts(b"a", b"b") + assert len(subset_counts) == lsh.b + for table in subset_counts: + assert sum(table.values()) == 2 + + # Test with all keys + all_counts = await lsh.get_subset_counts(b"a", b"b", b"c") + assert len(all_counts) == lsh.b + for table in all_counts: + assert sum(table.values()) == 3 + + # Test with single key + single_counts = await lsh.get_subset_counts(b"a") + assert len(single_counts) == lsh.b + for table in 
single_counts: + assert sum(table.values()) == 1 + @pytest.mark.skipif(not DO_TEST_MONGO, reason="MongoDB-specific test") async def test_arbitrary_url(self): config = {