dagger/dag_creator/airflow/operators/awsbatch_operator.py (52 additions, 3 deletions)
@@ -1,11 +1,13 @@
-from airflow.providers.amazon.aws.operators.batch import BatchOperator
-from airflow.utils.context import Context
+from typing import Any, Optional
+
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.links.batch import (
     BatchJobDefinitionLink,
     BatchJobQueueLink,
 )
 from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
+from airflow.providers.amazon.aws.operators.batch import BatchOperator
+from airflow.utils.context import Context
 
 
 class AWSBatchOperator(BatchOperator):
@@ -69,7 +71,6 @@ def monitor_job(self, context: Context):
 
         if awslogs:
             self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
-            link_builder = CloudWatchEventsLink()
             for log in awslogs:
                 self.log.info(self._format_cloudwatch_link(**log))
             if len(awslogs) > 1:
@@ -88,3 +89,51 @@ def monitor_job(self, context: Context):
 
         self.hook.check_job_success(self.job_id)
         self.log.info("AWS Batch job (%s) succeeded", self.job_id)
+
+    def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
+        """Execute when the trigger fires: fetch logs and complete the task."""
+        # Call the parent's execute_complete first
+        job_id = super().execute_complete(context, event)
+
+        # Only fetch logs if we are in deferrable mode and awslogs are enabled.
+        # In non-deferrable mode, logs are already fetched by monitor_job().
+        if self.deferrable and self.awslogs_enabled and job_id:
+            # Set job_id for the log-fetching methods below
+            self.job_id = job_id
+
+            # Fetch the job's logs and display them
+            try:
+                # Use the log fetcher to display container logs
+                log_fetcher = self._get_batch_log_fetcher(job_id)
+                if log_fetcher:
+                    # Show the last 50 log messages
+                    self.log.info("Fetching the latest 50 messages from CloudWatch:")
+                    log_messages = log_fetcher.get_last_log_messages(50)
+                    for message in log_messages:
+                        self.log.info(message)
+            except Exception as e:
+                self.log.warning("Could not fetch batch job logs: %s", e)
+
+            # Get CloudWatch log links
+            awslogs = []
+            try:
+                awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
+            except AirflowException as ae:
+                self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)
+
+            if awslogs:
+                self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
+                for log in awslogs:
+                    self.log.info(self._format_cloudwatch_link(**log))
+
+                CloudWatchEventsLink.persist(
+                    context=context,
+                    operator=self,
+                    region_name=self.hook.conn_region_name,
+                    aws_partition=self.hook.conn_partition,
+                    **awslogs[0],
+                )
+
+            self.log.info("AWS Batch job (%s) succeeded", self.job_id)
+
+        return job_id
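
For context, a minimal usage sketch (not part of this PR) of how the new deferrable path might be exercised from a DAG, assuming Airflow 2.x and that the constructor mirrors the upstream BatchOperator's arguments; the DAG id, job name, job definition, and queue values are placeholders:

from datetime import datetime

from airflow import DAG

from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator

with DAG(dag_id="batch_logs_example", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    run_batch = AWSBatchOperator(
        task_id="run_batch_job",
        job_name="example-job",            # placeholder
        job_definition="example-job-def",  # placeholder
        job_queue="example-queue",         # placeholder
        awslogs_enabled=True,  # required for the log fetching added in this PR
        deferrable=True,       # defer while the job runs; execute_complete() resumes the task
    )

With deferrable=True the worker slot is released while the Batch job runs, and execute_complete() restores the log output that monitor_job() would otherwise have produced in the blocking path.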