diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
index 1b56ca7..3297e85 100644
--- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py
+++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py
@@ -1,6 +1,6 @@
 from typing import Any, Optional
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferralError
 from airflow.providers.amazon.aws.links.batch import (
     BatchJobDefinitionLink,
     BatchJobQueueLink,
@@ -10,6 +10,18 @@
 from airflow.utils.context import Context
 
 
+def _format_extra_info(error_msg: str, last_logs: list[str], cloudwatch_link: Optional[str]) -> str:
+    """Format the enhanced error message with logs and link."""
+    extra_info = []
+    if cloudwatch_link:
+        extra_info.append(f"CloudWatch Logs: {cloudwatch_link}")
+    if last_logs:
+        extra_info.append("Last log lines:\n" + "\n".join(last_logs[-5:]))
+    if extra_info:
+        return f"{error_msg}\n\n" + "\n".join(extra_info)
+    return error_msg
+
+
 class AWSBatchOperator(BatchOperator):
     @staticmethod
     def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str):
@@ -90,50 +102,81 @@ def monitor_job(self, context: Context):
         self.hook.check_job_success(self.job_id)
         self.log.info("AWS Batch job (%s) succeeded", self.job_id)
 
-    def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
-        """Execute when the trigger fires - fetch logs and complete the task."""
-        # Call parent's execute_complete first
-        job_id = super().execute_complete(context, event)
-
-        # Only fetch logs if we're in deferrable mode and awslogs are enabled
-        # In non-deferrable mode, logs are already fetched by monitor_job()
-        if self.deferrable and self.awslogs_enabled and job_id:
-            # Set job_id for our log fetching methods
-            self.job_id = job_id
-
-            # Get job logs and display them
-            try:
-                # Use the log fetcher to display container logs
-                log_fetcher = self._get_batch_log_fetcher(job_id)
-                if log_fetcher:
-                    # Get the last 50 log messages
-                    self.log.info("Fetch the latest 50 messages from cloudwatch:")
-                    log_messages = log_fetcher.get_last_log_messages(50)
-                    for message in log_messages:
+    def _fetch_and_log_cloudwatch(self, job_id: str) -> tuple[list[str], Optional[str]]:
+        """Fetch CloudWatch logs for the given job_id and return (last_logs, cloudwatch_link)."""
+        last_logs: list[str] = []
+        cloudwatch_link: Optional[str] = None
+
+        if not self.awslogs_enabled:
+            return last_logs, cloudwatch_link
+
+        # Fetch the last log messages
+        try:
+            log_fetcher = self._get_batch_log_fetcher(job_id)
+            if log_fetcher:
+                last_logs = log_fetcher.get_last_log_messages(50)
+                if last_logs:
+                    self.log.info("CloudWatch logs (last 50 messages):")
+                    for message in last_logs:
                         self.log.info(message)
-            except Exception as e:
-                self.log.warning("Could not fetch batch job logs: %s", e)
-
-            # Get CloudWatch log links
-            awslogs = []
-            try:
-                awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
-            except AirflowException as ae:
-                self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae)
-
+        except Exception as e:
+            self.log.warning("Could not fetch batch job logs: %s", e)
+
+        # Get the CloudWatch log link
+        try:
+            awslogs = self.hook.get_job_all_awslogs_info(job_id)
             if awslogs:
-                self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
-                for log in awslogs:
-                    self.log.info(self._format_cloudwatch_link(**log))
-
-                CloudWatchEventsLink.persist(
-                    context=context,
-                    operator=self,
-                    region_name=self.hook.conn_region_name,
-                    aws_partition=self.hook.conn_partition,
-                    **awslogs[0],
-                )
-
-        self.log.info("AWS Batch job (%s) succeeded", self.job_id)
-
+                cloudwatch_link = self._format_cloudwatch_link(**awslogs[0])
+                self.log.info("CloudWatch link: %s", cloudwatch_link)
+        except AirflowException as e:
+            self.log.warning("Cannot determine CloudWatch log link: %s", e)
+
+        return last_logs, cloudwatch_link
+
+    def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str:
+        """Execute when the trigger fires - fetch logs first, then check job status."""
+        job_id = event.get("job_id") if event else None
+        if not job_id:
+            raise AirflowException("No job_id found in event data from trigger.")
+
+        self.job_id = job_id
+
+        # Always fetch logs before checking status
+        last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(job_id)
+
+        try:
+            self.hook.check_job_success(job_id)
+        except AirflowException as e:
+            raise AirflowException(_format_extra_info(str(e), last_logs, cloudwatch_link)) from e
+
+        self.log.info("AWS Batch job (%s) succeeded", job_id)
         return job_id
+
+    def resume_execution(self, next_method: str, next_kwargs: Optional[dict[str, Any]], context: Context):
+        """Override resume_execution to handle trigger failures and fetch logs."""
+        # Retrieve job_id from the batch_job_details XCom if it is not set on the instance
+        if not getattr(self, "job_id", None):
+            task_instance = context.get("task_instance")
+            if task_instance:
+                try:
+                    batch_job_details = task_instance.xcom_pull(task_ids=task_instance.task_id, key="batch_job_details")
+                    if batch_job_details and "job_id" in batch_job_details:
+                        self.job_id = batch_job_details["job_id"]
+                        self.log.info("Retrieved job_id from batch_job_details XCom: %s", self.job_id)
+                except Exception as e:
+                    self.log.debug("Could not retrieve job_id from batch_job_details XCom: %s", e)
+
+        try:
+            return super().resume_execution(next_method, next_kwargs, context)
+        except TaskDeferralError as e:
+            # When the trigger fails, try to fetch logs if a job_id is available
+            if getattr(self, "job_id", None) and self.awslogs_enabled:
+                self.log.info("Batch job trigger failed - fetching CloudWatch logs...")
+                last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(self.job_id)
+                # Re-raise with an enhanced error message that includes the logs
+                raise AirflowException(
+                    _format_extra_info(f"Batch job {self.job_id} failed: {e}", last_logs, cloudwatch_link)
+                ) from e
+            else:
+                self.log.warning("Cannot fetch logs for failed batch job: job_id missing or awslogs not enabled")
+                raise