From c7f568f064b9a7391320d6271007a7746eeab198 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Thu, 18 Jul 2024 15:34:17 +0200 Subject: [PATCH] fix(airflow): fix xcom pull in _try_to_adopt_job --- spark_on_k8s/airflow/operators.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/spark_on_k8s/airflow/operators.py b/spark_on_k8s/airflow/operators.py index 6f3f0b6..cd1fbc6 100644 --- a/spark_on_k8s/airflow/operators.py +++ b/spark_on_k8s/airflow/operators.py @@ -242,10 +242,22 @@ def _persist_spark_history_ui_link(self, context: Context): def _try_to_adopt_job(self, context: Context, spark_app_manager: SparkAppManager) -> bool: from spark_on_k8s.utils.spark_app_status import SparkAppStatus - xcom_driver_namespace = context["ti"].xcom_pull(key=self._XCOM_DRIVER_POD_NAMESPACE) + xcom_driver_namespace = context["ti"].xcom_pull( + dag_id=context["ti"].dag_id, + task_ids=context["ti"].task_id, + map_indexes=context["ti"].map_index, + key=self._XCOM_DRIVER_POD_NAMESPACE, + include_prior_dates=True, + ) if not xcom_driver_namespace or xcom_driver_namespace != self.namespace: return False - xcom_driver_pod_name = context["ti"].xcom_pull(key=self._XCOM_DRIVER_POD_NAME) + xcom_driver_pod_name = context["ti"].xcom_pull( + dag_id=context["ti"].dag_id, + task_ids=context["ti"].task_id, + map_indexes=context["ti"].map_index, + key=self._XCOM_DRIVER_POD_NAME, + include_prior_dates=True, + ) if xcom_driver_pod_name: with contextlib.suppress(Exception): app_status = spark_app_manager.app_status(