diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
index 23b3596..1b56ca7 100644
--- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py
+++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -1,11 +1,13 @@
-from airflow.providers.amazon.aws.operators.batch import BatchOperator
-from airflow.utils.context import Context
+from typing import Any, Optional
+
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.links.batch import (
     BatchJobDefinitionLink,
     BatchJobQueueLink,
 )
 from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
+from airflow.providers.amazon.aws.operators.batch import BatchOperator
+from airflow.utils.context import Context
 
 
 class AWSBatchOperator(BatchOperator):
@@ -69,7 +71,6 @@ def monitor_job(self, context: Context):
 
         if awslogs:
             self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
-            link_builder = CloudWatchEventsLink()
             for log in awslogs:
                 self.log.info(self._format_cloudwatch_link(**log))
             if len(awslogs) > 1:
@@ -88,3 +89,51 @@ def monitor_job(self, context: Context):
 
         self.hook.check_job_success(self.job_id)
         self.log.info("AWS Batch job (%s) succeeded", self.job_id)
+
+    def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
+        """Execute when the trigger fires: fetch logs and complete the task."""
+        # Call the parent's execute_complete first
+        job_id = super().execute_complete(context, event)
+
+        # Only fetch logs when running in deferrable mode with awslogs enabled;
+        # in non-deferrable mode, logs are already fetched by monitor_job().
+        if self.deferrable and self.awslogs_enabled and job_id:
+            # Set job_id for our log-fetching methods
+            self.job_id = job_id
+
+            # Get the job's logs and display them
+            try:
+                # Use the log fetcher to display the container logs
+                log_fetcher = self._get_batch_log_fetcher(job_id)
+                if log_fetcher:
+                    # Get the last 50 log messages
+                    self.log.info("Fetching the last 50 messages from CloudWatch:")
+                    log_messages = log_fetcher.get_last_log_messages(50)
+                    for message in log_messages:
+                        self.log.info(message)
+            except Exception as e:
+                self.log.warning("Could not fetch batch job logs: %s", e)
+
+            # Get the CloudWatch log links
+            awslogs = []
+            try:
+                awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
+            except AirflowException as ae:
+                self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)
+
+            if awslogs:
+                self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
+                for log in awslogs:
+                    self.log.info(self._format_cloudwatch_link(**log))
+
+                CloudWatchEventsLink.persist(
+                    context=context,
+                    operator=self,
+                    region_name=self.hook.conn_region_name,
+                    aws_partition=self.hook.conn_partition,
+                    **awslogs[0],
+                )
+
+            self.log.info("AWS Batch job (%s) succeeded", self.job_id)
+
+        return job_id
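
For reviewers, a minimal sketch of how the new deferrable path can be exercised from a DAG. The DAG id, job name, queue, definition, and command below are hypothetical placeholders; `deferrable` and `awslogs_enabled` are standard `BatchOperator` constructor arguments inherited by `AWSBatchOperator`.

```python
from datetime import datetime

from airflow import DAG

from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator

with DAG(
    dag_id="batch_deferrable_example",  # hypothetical DAG id
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    # With deferrable=True the task releases its worker slot while a trigger
    # polls AWS Batch; once the job finishes, execute_complete() fetches the
    # last container log messages and persists the CloudWatch link.
    run_batch_job = AWSBatchOperator(
        task_id="run_batch_job",
        job_name="example-job",                               # hypothetical
        job_definition="example-job-definition",              # hypothetical
        job_queue="example-job-queue",                        # hypothetical
        container_overrides={"command": ["echo", "hello"]},   # hypothetical
        deferrable=True,
        awslogs_enabled=True,
    )
```

In non-deferrable runs (`deferrable=False`), the existing `monitor_job()` path continues to fetch and link the logs, so `execute_complete()` deliberately skips log fetching in that case.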