Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
49489ba
refactor: Refactored HttpAsyncHook to easily support session based ru…
dabla Jan 13, 2026
493e237
Merge branch 'main' into feature/add-session-async-http-hook
dabla Jan 13, 2026
d23a94d
fix: Fixed import of LoggingMixin
dabla Jan 13, 2026
da1bfd6
Merge branch 'main' into feature/add-session-async-http-hook
dabla Jan 17, 2026
27e30e9
refactor: LivyAsyncHook now reuses logic from HttpAsyncHook which is …
dabla Jan 30, 2026
d19c593
refactor: Reformatted HttpAsyncHook
dabla Jan 30, 2026
9203bd7
Merge branch 'main' into feature/add-session-async-http-hook
dabla Jan 30, 2026
d833458
Merge branch 'main' into feature/add-session-async-http-hook
dabla Feb 1, 2026
a2cecc6
refactor: Fixed possible None types for merged_headers
dabla Feb 2, 2026
c233e8e
refactor: Changed type of _retryable_error_async method
dabla Feb 2, 2026
15754af
refactor: Removed unused import
dabla Feb 2, 2026
f9c503d
refactor: Moved SessionConfig inside AsyncHttpSession
dabla Feb 2, 2026
7e382b4
refactor: Reformatted run method of HttpAsyncHook
dabla Feb 2, 2026
7a109d8
refactor: Removed unused import from LivyHook module
dabla Feb 2, 2026
c8a10b9
Revert "refactor: Moved SessionConfig inside AsyncHttpSession"
dabla Feb 2, 2026
491e365
refactor: Added docstring for retry_limit and retry_delay parameters
dabla Feb 2, 2026
c77715e
refactor: Reformatted docstring in _retryable_error_async method
dabla Feb 2, 2026
53cb8b3
refactor: Added docstring for SessionConfig and AsyncHttpSession
dabla Feb 2, 2026
5eda002
refactor: Added warning logging when run attempt fails
dabla Feb 3, 2026
10620f4
refactor: Refactored run_method of LivyAsyncHook
dabla Feb 3, 2026
8a1f675
refactor: Refactored unit tests for LivyAsyncHook
dabla Feb 3, 2026
0beab47
refactor: Reformatted AsyncHttpSession
dabla Feb 3, 2026
d649850
refactor: Reformatted run_method of LivyAsyncHook
dabla Feb 3, 2026
bd93295
refactor: Escape aiohttp.ClientSession in docstring of session contex…
dabla Feb 3, 2026
232f0b2
refactor: Also take into extra_options from connection when building …
dabla Feb 3, 2026
d8918ef
refactor: Fixed mocking of test_run_method_success
dabla Feb 3, 2026
eb6ff25
Merge branch 'main' into feature/add-session-async-http-hook
dabla Feb 3, 2026
068493f
refactor: Removed unused imports
dabla Feb 3, 2026
3ce90aa
refactor: Reorganized imports
dabla Feb 3, 2026
2280100
refactor: Run method of LivyAsyncHook must internally use session fro…
dabla Feb 3, 2026
db2a036
Merge branch 'main' into feature/add-session-async-http-hook
dabla Feb 3, 2026
ed98651
Merge branch 'main' into feature/add-session-async-http-hook
dabla Feb 3, 2026
ff2b4b1
refactor: Escape reserved words in HttpAsyncHook
dabla Feb 3, 2026
6e96278
Merge branch 'main' into feature/add-session-async-http-hook
dabla Feb 3, 2026
ffdae83
refactor: Mock get_async_connection in TestLivyAsyncHook
dabla Feb 4, 2026
1719a39
Merge branch 'main' into feature/add-session-async-http-hook
dabla Feb 4, 2026
4af843d
refactor: Mock get_async_connection in TestLivyAsyncHook should be pa…
dabla Feb 4, 2026
2fe9900
Merge branch 'main' into feature/add-session-async-http-hook
dabla Feb 4, 2026
5d798ee
refactor: Mock get_async_connection in TestLivyAsyncHook should be pa…
Feb 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 22 additions & 106 deletions providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,18 @@
# under the License.
from __future__ import annotations

import asyncio
import json
import re
from collections.abc import Sequence
from enum import Enum
from typing import TYPE_CHECKING, Any
from typing import Any

import aiohttp
import requests
from aiohttp import ClientResponseError

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook

if TYPE_CHECKING:
from airflow.models import Connection


class BatchState(Enum):
"""Batch session states."""
Expand Down Expand Up @@ -502,101 +496,10 @@ def __init__(
self.extra_options = extra_options or {}
self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)

async def _do_api_call_async(
self,
endpoint: str | None = None,
data: dict[str, Any] | str | None = None,
headers: dict[str, Any] | None = None,
extra_options: dict[str, Any] | None = None,
) -> Any:
"""
Perform an asynchronous HTTP request call.

:param endpoint: the endpoint to be called i.e. resource/v1/query?
:param data: payload to be uploaded or request parameters
:param headers: additional headers to be passed through as a dictionary
:param extra_options: Additional kwargs to pass when creating a request.
For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
"""
extra_options = extra_options or {}

# headers may be passed through directly or in the "extra" field in the connection
# definition
_headers = {}
auth = None

if self.http_conn_id:
conn = await get_async_connection(self.http_conn_id)

self.base_url = self._generate_base_url(conn) # type: ignore[arg-type]
if conn.login:
auth = self.auth_type(conn.login, conn.password)
if conn.extra:
try:
_headers.update(conn.extra_dejson)
except TypeError:
self.log.warning("Connection to %s has invalid extra field.", conn.host)
if headers:
_headers.update(headers)

if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
url = self.base_url + "/" + endpoint
else:
url = (self.base_url or "") + (endpoint or "")

async with aiohttp.ClientSession() as session:
if self.method == "GET":
request_func = session.get
elif self.method == "POST":
request_func = session.post
elif self.method == "PATCH":
request_func = session.patch
else:
return {"Response": f"Unexpected HTTP Method: {self.method}", "status": "error"}

for attempt_num in range(1, 1 + self.retry_limit):
response = await request_func(
url,
json=data if self.method in ("POST", "PATCH") else None,
params=data if self.method == "GET" else None,
headers=_headers or None,
auth=auth,
**extra_options,
)
try:
response.raise_for_status()
return await response.json()
except ClientResponseError as e:
self.log.warning(
"[Try %d of %d] Request to %s failed.",
attempt_num,
self.retry_limit,
url,
)
if not self._retryable_error_async(e) or attempt_num == self.retry_limit:
self.log.exception("HTTP error, status code: %s", e.status)
# In this case, the user probably made a mistake.
# Don't retry.
return {"Response": {e.message}, "Status Code": {e.status}, "status": "error"}

await asyncio.sleep(self.retry_delay)

def _generate_base_url(self, conn: Connection) -> str:
if conn.host and "://" in conn.host:
base_url: str = conn.host
else:
# schema defaults to HTTP
schema = conn.schema if conn.schema else "http"
host = conn.host if conn.host else ""
base_url = f"{schema}://{host}"
if conn.port:
base_url = f"{base_url}:{conn.port}"
return base_url

async def run_method(
self,
endpoint: str,
method: str = "GET",
method: str | None = None,
data: Any | None = None,
headers: dict[str, Any] | None = None,
) -> Any:
Expand All @@ -609,16 +512,29 @@ async def run_method(
:param headers: headers
:return: http response
"""
if method not in ("GET", "POST", "PUT", "DELETE", "HEAD"):
method = method or self.method
if method not in {"GET", "PATCH", "POST", "PUT", "DELETE", "HEAD"}:
return {"status": "error", "response": f"Invalid http method {method}"}

back_method = self.method
self.method = method
endpoint = (
f"{self.endpoint_prefix}/{endpoint}"
if self.endpoint_prefix and endpoint
else endpoint or self.endpoint_prefix
)

try:
result = await self._do_api_call_async(endpoint, data, headers, self.extra_options)
finally:
self.method = back_method
return {"status": "success", "response": result}
async with self.session() as session:
response = await session.run(
endpoint=endpoint,
data=data,
headers={**self._def_headers, **self.extra_headers, **(headers or {})},
extra_options=self.extra_options,
)

result = await response.json()
return {"status": "success", "response": result}
except ClientResponseError as e:
return {"Response": {e.message}, "Status Code": {e.status}, "status": "error"}

async def get_batch_state(self, session_id: int | str) -> Any:
"""
Expand Down
159 changes: 54 additions & 105 deletions providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,159 +592,106 @@ async def test_dump_batch_logs_error(self, mock_get_batch_logs):
assert log_dump == {"id": 1, "log": ["mock_log_1", "mock_log_2"]}

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async")
async def test_run_method_success(self, mock_do_api_call_async):
@mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession")
async def test_run_method_success(self, mock_session):
"""Asserts the run_method for success response."""
mock_do_api_call_async.return_value = {"status": "error", "response": {"id": 1}}
mock_session.return_value.__aenter__.return_value.post = AsyncMock()
mock_session.return_value.__aenter__.return_value.post.return_value.json = AsyncMock(
return_value={"id": 1}
)
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
response = await hook.run_method("localhost", "GET")
assert response["status"] == "success"
assert response == {"status": "success", "response": {"id": 1}}

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async")
async def test_run_method_error(self, mock_do_api_call_async):
async def test_run_method_error(self):
"""Asserts the run_method for error response."""
mock_do_api_call_async.return_value = {"status": "error", "response": {"id": 1}}
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
response = await hook.run_method("localhost", "abc")
assert response == {"status": "error", "response": "Invalid http method abc"}

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_post_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for success response for POST method."""

async def mock_fun(arg1, arg2, arg3, arg4):
return {"status": "success"}

mock_session.return_value.__aexit__.return_value = mock_fun
@mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession")
@mock.patch("airflow.providers.common.compat.connection.get_async_connection")
async def test_run_post_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the run_method for success response for POST method."""
mock_session.return_value.__aenter__.return_value.post = AsyncMock()
mock_session.return_value.__aenter__.return_value.post.return_value.json = AsyncMock(
return_value={"status": "success"}
return_value={"hello": "world"}
)
GET_RUN_ENDPOINT = "api/jobs/runs/get"
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
hook.http_conn_id = mock_get_connection
hook.http_conn_id.host = "https://localhost"
hook.http_conn_id.login = "login"
hook.http_conn_id.password = "PASSWORD"
response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
assert response == {"status": "success"}
response = await hook.run_method("api/jobs/runs/get")
assert response["status"] == "success"
assert response["response"] == {"hello": "world"}

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_get_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for GET method."""

async def mock_fun(arg1, arg2, arg3, arg4):
return {"status": "success"}

mock_session.return_value.__aexit__.return_value = mock_fun
@mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession")
@mock.patch("airflow.providers.common.compat.connection.get_async_connection")
async def test_run_get_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the run_method for GET method."""
mock_session.return_value.__aenter__.return_value.get = AsyncMock()
mock_session.return_value.__aenter__.return_value.get.return_value.json = AsyncMock(
return_value={"status": "success"}
return_value={"hello": "world"}
)
GET_RUN_ENDPOINT = "api/jobs/runs/get"
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
hook.method = "GET"
hook.http_conn_id = mock_get_connection
hook.http_conn_id.host = "test.com"
hook.http_conn_id.login = "login"
hook.http_conn_id.password = "PASSWORD"
hook.http_conn_id.extra_dejson = ""
response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
assert response == {"status": "success"}
response = await hook.run_method("api/jobs/runs/get")
assert response["status"] == "success"
assert response["response"] == {"hello": "world"}

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_patch_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for PATCH method."""

async def mock_fun(arg1, arg2, arg3, arg4):
return {"status": "success"}

mock_session.return_value.__aexit__.return_value = mock_fun
@mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession")
@mock.patch("airflow.providers.common.compat.connection.get_async_connection")
async def test_run_patch_method_with_success(self, mock_get_connection, mock_session):
"""Asserts the run_method for PATCH method."""
mock_session.return_value.__aenter__.return_value.patch = AsyncMock()
mock_session.return_value.__aenter__.return_value.patch.return_value.json = AsyncMock(
return_value={"status": "success"}
return_value={"hello": "world"}
)
GET_RUN_ENDPOINT = "api/jobs/runs/get"
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
hook.method = "PATCH"
hook.http_conn_id = mock_get_connection
hook.http_conn_id.host = "test.com"
hook.http_conn_id.login = "login"
hook.http_conn_id.password = "PASSWORD"
hook.http_conn_id.extra_dejson = ""
response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
assert response == {"status": "success"}
response = await hook.run_method("api/jobs/runs/get")
assert response["status"] == "success"
assert response["response"] == {"hello": "world"}

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_unexpected_method_error(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for unexpected method error"""
GET_RUN_ENDPOINT = "api/jobs/runs/get"
@mock.patch("airflow.providers.common.compat.connection.get_async_connection")
async def test_run_unexpected_method_with_success(self, mock_get_connection):
"""Asserts the run_method for unexpected method error"""
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
hook.method = "abc"
hook.http_conn_id = mock_get_connection
hook.http_conn_id.host = "test.com"
hook.http_conn_id.login = "login"
hook.http_conn_id.password = "PASSWORD"
hook.http_conn_id.extra_dejson = ""
response = await hook._do_api_call_async(endpoint=GET_RUN_ENDPOINT, headers={})
assert response == {"Response": "Unexpected HTTP Method: abc", "status": "error"}
response = await hook.run_method(endpoint="api/jobs/runs/get", headers={})
assert response == {"response": "Invalid http method abc", "status": "error"}

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_with_type_error(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for TypeError."""
@mock.patch("airflow.providers.common.compat.connection.get_async_connection")
async def test_run_put_method_with_type_error(self, mock_get_connection):
"""Asserts the run_method for TypeError."""

async def mock_fun(arg1, arg2, arg3, arg4):
return {"random value"}

mock_session.return_value.__aexit__.return_value = mock_fun
mock_session.return_value.__aenter__.return_value.patch.return_value.json.return_value = {}
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
hook.method = "PATCH"
hook.retry_limit = 1
hook.retry_delay = 1
hook.http_conn_id = mock_get_connection
with pytest.raises(TypeError):
await hook._do_api_call_async(endpoint="", data="test", headers=mock_fun, extra_options=mock_fun)
await hook.run_method(endpoint="api/jobs/runs/get", data="test", headers=mock_fun)

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
@mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection")
async def test_do_api_call_async_with_client_response_error(self, mock_get_connection, mock_session):
"""Asserts the _do_api_call_async for Client Response Error."""
@mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession")
@mock.patch("airflow.providers.common.compat.connection.get_async_connection")
async def test_run_method_with_client_response_error(self, mock_get_connection, mock_session):
"""Asserts the run_method for Client Response Error."""

async def mock_fun(arg1, arg2, arg3, arg4):
return {"random value"}

mock_session.return_value.__aexit__.return_value = mock_fun
mock_session.return_value.__aenter__.return_value.patch = AsyncMock()
mock_session.return_value.__aenter__.return_value.patch.return_value.json.side_effect = (
ClientResponseError(
mock_session.return_value.__aenter__.return_value.patch = AsyncMock(
side_effect=ClientResponseError(
request_info=RequestInfo(url="example.com", method="PATCH", headers=multidict.CIMultiDict()),
status=500,
history=[],
)
)
GET_RUN_ENDPOINT = ""
hook = LivyAsyncHook(livy_conn_id="livy_default")
hook.method = "PATCH"
hook.base_url = ""
hook.http_conn_id = mock_get_connection
hook.http_conn_id.host = "test.com"
hook.http_conn_id.login = "login"
hook.http_conn_id.password = "PASSWORD"
hook.http_conn_id.extra_dejson = ""
response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
response = await hook.run_method("")
assert response["status"] == "error"

@pytest.fixture
Expand All @@ -764,7 +711,8 @@ def setup_livy_conn(self, create_connection_without_db):
create_connection_without_db(Connection(conn_id="missing_host", conn_type="http", port=1234))
create_connection_without_db(Connection(conn_id="invalid_uri", uri="http://invalid_uri:4321"))

def test_build_get_hook(self, setup_livy_conn):
@pytest.mark.asyncio
async def test_build_get_hook(self, setup_livy_conn):
connection_url_mapping = {
# id, expected
"default_port": "http://host",
Expand All @@ -776,9 +724,10 @@ def test_build_get_hook(self, setup_livy_conn):

for conn_id, expected in connection_url_mapping.items():
hook = LivyAsyncHook(livy_conn_id=conn_id)
response_conn: Connection = hook.get_connection(conn_id=conn_id)
assert isinstance(response_conn, Connection)
assert hook._generate_base_url(response_conn) == expected
async with hook.session() as session:
response_conn: Connection = hook.get_connection(conn_id=conn_id)
assert isinstance(response_conn, Connection)
assert session.base_url == expected

def test_build_body(self):
# minimal request
Expand Down
Loading
Loading