Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,9 @@ class EmrServerlessStartJobOperator(AwsBaseOperator[EmrServerlessHook]):
:param enable_application_ui_links: If True, the operator will generate one-time links to EMR Serverless
application UIs. The generated links will allow any user with access to the DAG to see the Spark or
Tez UI or Spark stdout logs. Defaults to False.
:param cancel_on_kill: If True, the EMR Serverless job will be cancelled when the task is killed
while in deferrable mode. This ensures that orphan jobs are not left running in EMR Serverless
when an Airflow task is cancelled. Defaults to True.
"""

aws_hook_class = EmrServerlessHook
Expand Down Expand Up @@ -1188,6 +1191,7 @@ def __init__(
waiter_delay: int | ArgNotSet = NOTSET,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
enable_application_ui_links: bool = False,
cancel_on_kill: bool = True,
**kwargs,
):
waiter_delay = 60 if waiter_delay is NOTSET else waiter_delay
Expand All @@ -1205,6 +1209,7 @@ def __init__(
self.job_id: str | None = None
self.deferrable = deferrable
self.enable_application_ui_links = enable_application_ui_links
self.cancel_on_kill = cancel_on_kill
super().__init__(**kwargs)

self.client_request_token = client_request_token or str(uuid4())
Expand Down Expand Up @@ -1269,6 +1274,7 @@ def execute(self, context: Context, event: dict[str, Any] | None = None) -> str
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
cancel_on_kill=self.cancel_on_kill,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
Expand Down Expand Up @@ -1320,7 +1326,8 @@ def on_kill(self) -> None:
"""
Cancel the submitted job run.

Note: this method will not run in deferrable mode.
Note: In deferrable mode, this method will not run. Instead, job cancellation
is handled by the trigger's cancel_on_kill parameter when the task is killed.
"""
if self.job_id:
self.log.info("Stopping job run with jobId - %s", self.job_id)
Expand Down
131 changes: 129 additions & 2 deletions providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,29 @@
# under the License.
from __future__ import annotations

import asyncio
import sys
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING

from asgiref.sync import sync_to_async

from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.triggers.base import TriggerEvent
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook

if not AIRFLOW_V_3_0_PLUS:
from airflow.models.taskinstance import TaskInstance
from airflow.utils.session import provide_session


class EmrAddStepsTrigger(AwsBaseWaiterTrigger):
"""
Expand Down Expand Up @@ -331,9 +345,10 @@ class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):

:param application_id: The ID of the application the job in being run on.
:param job_id: The ID of the job run.
:waiter_delay: polling period in seconds to check for the status
:param waiter_delay: polling period in seconds to check for the status
:param waiter_max_attempts: The maximum number of attempts to be made
:param aws_conn_id: Reference to AWS connection id
:param cancel_on_kill: Flag to indicate whether to cancel the job when the task is killed.
"""

def __init__(
Expand All @@ -343,9 +358,14 @@ def __init__(
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
aws_conn_id: str | None = "aws_default",
cancel_on_kill: bool = True,
) -> None:
super().__init__(
serialized_fields={"application_id": application_id, "job_id": job_id},
serialized_fields={
"application_id": application_id,
"job_id": job_id,
"cancel_on_kill": cancel_on_kill,
},
waiter_name="serverless_job_completed",
waiter_args={"applicationId": application_id, "jobRunId": job_id},
failure_message="Serverless Job failed",
Expand All @@ -357,10 +377,117 @@ def __init__(
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
)
self.application_id = application_id
self.job_id = job_id
self.cancel_on_kill = cancel_on_kill

def hook(self) -> AwsGenericHook:
    """Build and return the EMR Serverless hook bound to this trigger's AWS connection."""
    emr_serverless_hook = EmrServerlessHook(self.aws_conn_id)
    return emr_serverless_hook

if not AIRFLOW_V_3_0_PLUS:

    @provide_session
    def get_task_instance(self, session: Session) -> TaskInstance:
        """
        Fetch this trigger's TaskInstance row from the metadata DB (Airflow 2.x compatibility).

        :raises ValueError: if no matching row exists for the trigger's
            dag_id/task_id/run_id/map_index.
        """
        from sqlalchemy import select

        ti_ref = self.task_instance
        stmt = (
            select(TaskInstance)
            .where(TaskInstance.dag_id == ti_ref.dag_id)
            .where(TaskInstance.task_id == ti_ref.task_id)
            .where(TaskInstance.run_id == ti_ref.run_id)
            .where(TaskInstance.map_index == ti_ref.map_index)
        )
        found = session.scalars(stmt).one_or_none()
        if found is None:
            raise ValueError(
                f"TaskInstance with dag_id: {ti_ref.dag_id}, "
                f"task_id: {ti_ref.task_id}, "
                f"run_id: {ti_ref.run_id} and "
                f"map_index: {ti_ref.map_index} is not found"
            )
        return found

async def get_task_state(self):
    """
    Return the current state of this trigger's task instance (Airflow 3.x).

    Looks the state up via the Task SDK runtime API; the blocking call is
    wrapped with ``sync_to_async`` so it can run on the triggerer's event loop.

    :raises ValueError: if the response does not contain this task instance.
    """
    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

    task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
        dag_id=self.task_instance.dag_id,
        task_ids=[self.task_instance.task_id],
        run_ids=[self.task_instance.run_id],
        map_index=self.task_instance.map_index,
    )
    try:
        task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
    except (KeyError, TypeError) as err:
        # Only the lookup failures we expect: missing run/task key, or an
        # unexpected/None response shape. Chain the original cause instead of
        # swallowing it (ruff B904) so the real failure stays in the traceback.
        raise ValueError(
            f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
            f"task_id: {self.task_instance.task_id}, "
            f"run_id: {self.task_instance.run_id} and "
            f"map_index: {self.task_instance.map_index} is not found"
        ) from err
    return task_state

async def safe_to_cancel(self) -> bool:
    """
    Decide whether cancelling the EMR Serverless job is safe.

    A task still in DEFERRED state means the triggerer is shutting down or
    restarting rather than the user killing the task, so the remote job must
    be left alone. Any other state indicates a user-initiated cancellation.
    """
    if AIRFLOW_V_3_0_PLUS:
        current_state = await self.get_task_state()
    else:
        current_state = self.get_task_instance().state  # type: ignore[call-arg]
    is_deferred = current_state == TaskInstanceState.DEFERRED
    return not is_deferred

async def run(self) -> AsyncIterator[TriggerEvent]:
    """
    Run the trigger and wait for the job to complete.

    If the task is cancelled while waiting, attempt to cancel the EMR Serverless job
    if cancel_on_kill is enabled and it's safe to do so.
    """
    hook = self.hook()
    try:
        async with await hook.get_async_conn() as client:
            waiter = hook.get_waiter(
                self.waiter_name,
                deferrable=True,
                client=client,
                config_overrides=self.waiter_config_overrides,
            )
            # Poll the waiter until the job reaches a terminal state (or the
            # attempt budget is exhausted); async_wait raises on failure.
            await async_wait(
                waiter,
                self.waiter_delay,
                self.attempts,
                self.waiter_args,
                self.failure_message,
                self.status_message,
                self.status_queries,
            )
            yield TriggerEvent({"status": "success", self.return_key: self.return_value})
    except asyncio.CancelledError:
        # The trigger task was cancelled. Only stop the remote job for a genuine
        # user-initiated kill: cancel_on_kill must be enabled AND safe_to_cancel()
        # must confirm the task is not merely DEFERRED (DEFERRED implies a
        # triggerer restart, in which case the job should keep running).
        if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
            self.log.info(
                "Task was cancelled. Cancelling EMR Serverless job. Application ID: %s, Job ID: %s",
                self.application_id,
                self.job_id,
            )
            # NOTE(review): blocking boto3 call from async context; acceptable
            # here since the trigger is terminating anyway — confirm.
            hook.conn.cancel_job_run(applicationId=self.application_id, jobRunId=self.job_id)
            self.log.info("EMR Serverless job %s cancelled successfully.", self.job_id)
        else:
            self.log.info(
                "Trigger may have shutdown or cancel_on_kill is disabled. "
                "Skipping job cancellation. Application ID: %s, Job ID: %s",
                self.application_id,
                self.job_id,
            )
        # Re-raise so the cancellation propagates to the event loop.
        raise
    except Exception as e:
        # Any other failure is surfaced to the scheduler as a failure event
        # instead of crashing the triggerer process.
        yield TriggerEvent({"status": "failure", "message": str(e)})


class EmrServerlessDeleteApplicationTrigger(AwsBaseWaiterTrigger):
"""
Expand Down
132 changes: 132 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
# under the License.
from __future__ import annotations

import asyncio
import sys
from unittest import mock

import pytest

from airflow.providers.amazon.aws.triggers.emr import (
EmrAddStepsTrigger,
Expand Down Expand Up @@ -269,8 +273,136 @@ def test_serialization(self):
"waiter_max_attempts": 60,
"job_id": "job_id",
"aws_conn_id": "aws_default",
"cancel_on_kill": True,
}

def test_serialization_cancel_on_kill_false(self):
    """A trigger built with cancel_on_kill=False must serialize the flag as False."""
    trigger = EmrServerlessStartJobTrigger(
        application_id="test_app",
        job_id="test_job",
        waiter_delay=30,
        waiter_max_attempts=60,
        aws_conn_id="aws_default",
        cancel_on_kill=False,
    )
    serialized = trigger.serialize()
    assert serialized[0] == "airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger"
    assert serialized[1]["cancel_on_kill"] is False

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger.safe_to_cancel")
async def test_emr_serverless_trigger_cancellation(self, mock_safe_to_cancel, mock_async_wait):
    """
    Verify the trigger cancels the EMR Serverless job when the task is killed
    and safe_to_cancel() reports cancellation is safe.
    """
    mock_safe_to_cancel.return_value = True
    mock_async_wait.side_effect = asyncio.CancelledError()

    trigger = EmrServerlessStartJobTrigger(
        application_id="test_app",
        job_id="test_job",
        waiter_delay=30,
        waiter_max_attempts=60,
        aws_conn_id="aws_default",
        cancel_on_kill=True,
    )

    # Fake async connection context manager yielding a fake boto client.
    fake_client = mock.MagicMock()
    async_conn_cm = mock.MagicMock()
    async_conn_cm.__aenter__ = mock.AsyncMock(return_value=fake_client)
    async_conn_cm.__aexit__ = mock.AsyncMock(return_value=None)

    fake_hook = mock.MagicMock()
    fake_hook.conn.cancel_job_run.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}}
    fake_hook.get_async_conn = mock.AsyncMock(return_value=async_conn_cm)

    with mock.patch.object(trigger, "hook", return_value=fake_hook), pytest.raises(asyncio.CancelledError):
        async for _ in trigger.run():
            pass

    fake_hook.conn.cancel_job_run.assert_called_once_with(applicationId="test_app", jobRunId="test_job")

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger.safe_to_cancel")
async def test_emr_serverless_trigger_no_cancellation_when_unsafe(
    self, mock_safe_to_cancel, mock_async_wait
):
    """
    Verify the trigger leaves the job running when safe_to_cancel() returns
    False (e.g. the triggerer is shutting down, not a user-initiated kill).
    """
    mock_safe_to_cancel.return_value = False
    mock_async_wait.side_effect = asyncio.CancelledError()

    trigger = EmrServerlessStartJobTrigger(
        application_id="test_app",
        job_id="test_job",
        waiter_delay=30,
        waiter_max_attempts=60,
        aws_conn_id="aws_default",
        cancel_on_kill=True,
    )

    # Fake async connection context manager yielding a fake boto client.
    fake_client = mock.MagicMock()
    async_conn_cm = mock.MagicMock()
    async_conn_cm.__aenter__ = mock.AsyncMock(return_value=fake_client)
    async_conn_cm.__aexit__ = mock.AsyncMock(return_value=None)

    fake_hook = mock.MagicMock()
    fake_hook.get_async_conn = mock.AsyncMock(return_value=async_conn_cm)

    with mock.patch.object(trigger, "hook", return_value=fake_hook), pytest.raises(asyncio.CancelledError):
        async for _ in trigger.run():
            pass

    fake_hook.conn.cancel_job_run.assert_not_called()

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrServerlessStartJobTrigger.safe_to_cancel")
async def test_emr_serverless_trigger_no_cancellation_when_disabled(
    self, mock_safe_to_cancel, mock_async_wait
):
    """
    Verify the trigger never cancels the job when the user opted out via
    cancel_on_kill=False, even though cancellation would otherwise be safe.
    """
    mock_safe_to_cancel.return_value = True
    mock_async_wait.side_effect = asyncio.CancelledError()

    trigger = EmrServerlessStartJobTrigger(
        application_id="test_app",
        job_id="test_job",
        waiter_delay=30,
        waiter_max_attempts=60,
        aws_conn_id="aws_default",
        cancel_on_kill=False,  # Disabled
    )

    # Fake async connection context manager yielding a fake boto client.
    fake_client = mock.MagicMock()
    async_conn_cm = mock.MagicMock()
    async_conn_cm.__aenter__ = mock.AsyncMock(return_value=fake_client)
    async_conn_cm.__aexit__ = mock.AsyncMock(return_value=None)

    fake_hook = mock.MagicMock()
    fake_hook.get_async_conn = mock.AsyncMock(return_value=async_conn_cm)

    with mock.patch.object(trigger, "hook", return_value=fake_hook), pytest.raises(asyncio.CancelledError):
        async for _ in trigger.run():
            pass

    fake_hook.conn.cancel_job_run.assert_not_called()


class TestEmrServerlessDeleteApplicationTrigger:
def test_serialization(self):
Expand Down