Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix EmrCreateJobFlowOperator using deferrable mode with wait_for_completion #41561

52 changes: 29 additions & 23 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def _emr_hook(self) -> EmrHook:
aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id, region_name=self.region_name
)

def execute(self, context: Context) -> str | None:
def execute(self, context: Context) -> str:
self.log.info(
"Creating job flow using aws_conn_id: %s, emr_conn_id: %s", self.aws_conn_id, self.emr_conn_id
)
Expand All @@ -801,13 +801,15 @@ def execute(self, context: Context) -> str | None:
self.job_flow_overrides = job_flow_overrides
else:
job_flow_overrides = self.job_flow_overrides

response = self._emr_hook.create_job_flow(job_flow_overrides)

if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Job flow creation failed: {response}")

self._job_flow_id = response["JobFlowId"]
self.log.info("Job flow with id %s created", self._job_flow_id)

EmrClusterLink.persist(
context=context,
operator=self,
Expand All @@ -824,31 +826,35 @@ def execute(self, context: Context) -> str | None:
job_flow_id=self._job_flow_id,
log_uri=get_log_uri(emr_client=self._emr_hook.conn, job_flow_id=self._job_flow_id),
)
if self.deferrable:
self.defer(
trigger=EmrCreateJobFlowTrigger(
job_flow_id=self._job_flow_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
shahar1 marked this conversation as resolved.
Show resolved Hide resolved
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)

if self.wait_for_completion:
self._emr_hook.get_waiter("job_flow_waiting").wait(
ClusterId=self._job_flow_id,
WaiterConfig=prune_dict(
{
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
}
),
)
if self.deferrable:
self.defer(
trigger=EmrCreateJobFlowTrigger(
job_flow_id=self._job_flow_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
else:
self._emr_hook.get_waiter("job_flow_waiting").wait(
ClusterId=self._job_flow_id,
WaiterConfig=prune_dict(
{
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
}
),
)

return self._job_flow_id


def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
event = validate_execute_complete_event(event)

Expand Down
102 changes: 94 additions & 8 deletions tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,21 +189,107 @@ def test_execute_with_wait(self, mock_waiter, _, mocked_hook_client):
assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY)
assert_expected_waiter_type(mock_waiter, "job_flow_waiting")

def test_create_job_flow_deferrable(self, mocked_hook_client):
@patch("airflow.providers.amazon.aws.operators.emr.EmrCreateJobFlowTrigger")
def test_create_job_flow_deferrable(self, mock_trigger, mocked_hook_client):
"""
Test to make sure that the operator raises a TaskDeferred exception
Test to ensure the operator raises a TaskDeferred exception
if run in deferrable mode.
"""
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN


# Set the deferrable flag and wait_for_completion
self.operator.deferrable = True
self.operator.wait_for_completion = True

# Check for TaskDeferred being raised
with pytest.raises(TaskDeferred) as exc:
self.operator.execute(self.mock_context)

# Ensure the trigger is created with the right parameters
mock_trigger.assert_called_once_with(
job_flow_id=JOB_FLOW_ID,
aws_conn_id=self.operator.aws_conn_id,
waiter_delay=self.operator.waiter_delay,
waiter_max_attempts=self.operator.waiter_max_attempts,
)

# Ensure the trigger is correctly set
assert exc.value.trigger == mock_trigger.return_value


class TestEmrCreateJobFlowOperatorExtended(TestEmrCreateJobFlowOperator):

@patch("airflow.providers.amazon.aws.operators.emr.EmrCreateJobFlowTrigger")
@patch("airflow.providers.amazon.aws.operators.emr.EmrCreateJobFlowOperator.defer")
def test_deferrable_and_wait_for_completion(self, mock_defer, mock_trigger, mocked_hook_client):
# Simulate successful job flow creation
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

# Set the deferrable attributes
self.operator.deferrable = True
self.operator.wait_for_completion = True
self.operator.waiter_delay = 10 # Example delay value
self.operator.waiter_max_attempts = 5 # Example max attempts value

# Execute the operator
self.operator.execute(self.mock_context)

# Ensure that the trigger was called with the correct parameters
mock_trigger.assert_called_once_with(
job_flow_id=JOB_FLOW_ID,
aws_conn_id=self.operator.aws_conn_id,
waiter_delay=self.operator.waiter_delay,
waiter_max_attempts=self.operator.waiter_max_attempts,
)

# Ensure the defer method was called with the correct arguments
mock_defer.assert_called_once_with(
trigger=mock_trigger.return_value,
method_name="execute_complete",
timeout=timedelta(seconds=self.operator.waiter_max_attempts * self.operator.waiter_delay + 60),
)

@mock.patch("airflow.providers.amazon.aws.operators.emr.EmrCreateJobFlowOperator.defer")
def test_deferrable_and_wait_for_completion(self, mock_defer, mocked_hook_client):
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

self.operator.deferrable = True
self.operator.wait_for_completion = True
self.operator.waiter_delay = 10 # Example value
self.operator.waiter_max_attempts = 5 # Example value

assert isinstance(
exc.value.trigger, EmrCreateJobFlowTrigger
), "Trigger is not a EmrCreateJobFlowTrigger"
self.operator.execute(self.mock_context)
mock_defer.assert_called_once_with(
trigger=EmrCreateJobFlowTrigger(
job_flow_id=JOB_FLOW_ID,
aws_conn_id=self.operator.aws_conn_id,
waiter_delay=self.operator.waiter_delay,
waiter_max_attempts=self.operator.waiter_max_attempts,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.operator.waiter_max_attempts * self.operator.waiter_delay + 60),
)

@mock.patch("botocore.waiter.get_service_module_name", return_value="emr")
@mock.patch.object(Waiter, "wait")
def test_non_deferrable_but_wait_for_completion(self, mock_waiter, _, mocked_hook_client):
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

self.operator.deferrable = False
self.operator.wait_for_completion = True

assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY)

def test_no_wait_for_completion(self, mocked_hook_client):
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

self.operator.deferrable = True
self.operator.wait_for_completion = False

assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
assert not mocked_hook_client.get_waiter.called

# This part comes from the main branch
def test_template_fields(self):
validate_template_fields(self.operator)
validate_template_fields(self.operator)
Loading