diff --git a/pydgraph/async_client.py b/pydgraph/async_client.py index 2cf6987..42e0c63 100644 --- a/pydgraph/async_client.py +++ b/pydgraph/async_client.py @@ -335,9 +335,7 @@ async def __aenter__(self) -> AsyncDgraphClient: """ return self - async def __aexit__( - self, exc_type: Any, exc_val: Any, exc_tb: Any - ) -> bool: + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: """Async context manager exit. Automatically closes all client connections. diff --git a/pydgraph/async_txn.py b/pydgraph/async_txn.py index 2cf1c74..a380918 100644 --- a/pydgraph/async_txn.py +++ b/pydgraph/async_txn.py @@ -216,9 +216,11 @@ async def do_request( # noqa: C901 query_error = error if query_error is not None: - # Try to discard the transaction on error + # Try to discard the transaction on error. + # Note: We use _discard_internal() here because we already hold self._lock, + # and asyncio.Lock is not reentrant. Calling self.discard() would deadlock. try: - await self.discard( + await self._discard_internal( timeout=timeout, metadata=metadata, credentials=credentials ) except asyncio.CancelledError: @@ -458,32 +460,61 @@ async def discard( Various gRPC errors on failure """ async with self._lock: - if not self._common_discard(): - return + await self._discard_internal( + timeout=timeout, metadata=metadata, credentials=credentials + ) - new_metadata = self._dg.add_login_metadata(metadata) - try: + async def _discard_internal( + self, + timeout: float | None = None, + metadata: list[tuple[str, str]] | None = None, + credentials: grpc.CallCredentials | None = None, + ) -> None: + """Internal discard implementation that doesn't acquire the lock. + + This method must only be called when the caller already holds self._lock. + Use discard() for the public API. + + Args: + timeout: Request timeout in seconds + metadata: Request metadata + credentials: Call credentials + + Raises: + AssertionError: If called without holding self._lock + Various gRPC errors on failure + """ + # Defensive check: ensure caller holds the lock to prevent misuse + assert self._lock.locked(), ( + "_discard_internal must only be called while holding self._lock" + ) + + if not self._common_discard(): + return + + new_metadata = self._dg.add_login_metadata(metadata) + try: + await self._dc.commit_or_abort( + self._ctx, + timeout=timeout, + metadata=new_metadata, + credentials=credentials, + ) + except asyncio.CancelledError: + raise + except Exception as error: + # Handle JWT expiration with automatic retry + if util.is_jwt_expired(error): + await self._dg.retry_login() + new_metadata = self._dg.add_login_metadata(metadata) await self._dc.commit_or_abort( self._ctx, timeout=timeout, metadata=new_metadata, credentials=credentials, ) - except asyncio.CancelledError: + else: raise - except Exception as error: - # Handle JWT expiration with automatic retry - if util.is_jwt_expired(error): - await self._dg.retry_login() - new_metadata = self._dg.add_login_metadata(metadata) - await self._dc.commit_or_abort( - self._ctx, - timeout=timeout, - metadata=new_metadata, - credentials=credentials, - ) - else: - raise def _common_discard(self) -> bool: """Validates and prepares for discard. @@ -533,9 +564,7 @@ async def __aenter__(self) -> AsyncTxn: """ return self - async def __aexit__( - self, exc_type: Any, exc_val: Any, exc_tb: Any - ) -> bool: + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: """Async context manager exit. Automatically discards transaction if not already finished. diff --git a/pydgraph/client.py b/pydgraph/client.py index 3656f0e..23d286f 100755 --- a/pydgraph/client.py +++ b/pydgraph/client.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Dgraph python client.""" + from __future__ import annotations import contextlib diff --git a/pydgraph/retry.py b/pydgraph/retry.py index bfe5369..83e1315 100644 --- a/pydgraph/retry.py +++ b/pydgraph/retry.py @@ -33,6 +33,7 @@ def upsert_user(client, name: str): txn.mutate(set_obj={"name": name}) txn.commit() """ + import asyncio import functools import logging