Skip to content

Commit

Permalink
retrieve spark driver logs and log them in airflow in case of app fai…
Browse files Browse the repository at this point in the history
…lure (#34)

* retrieve spark driver logs and log them in airflow in case of app failure
* add flag and log size safety net
* bump version
* remove max log size

---------

Co-authored-by: Sigmar Stefansson <[email protected]>
  • Loading branch information
tcassou and sigmarkarl authored Sep 30, 2024
1 parent cca9117 commit a1e57d2
Show file tree
Hide file tree
Showing 6 changed files with 364 additions and 16 deletions.
2 changes: 1 addition & 1 deletion VERSION_NUMBER
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.1.3
1.1.4
33 changes: 23 additions & 10 deletions ocean_spark/hooks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
from datetime import timedelta
from packaging import version
from typing import Callable, Dict, Any, Tuple

from typing import Callable, Dict, Any, Tuple, Optional

from airflow.hooks.base import BaseHook

Expand All @@ -24,7 +23,6 @@
requests.get,
urljoin(API_HOST, "cluster/{cluster_id}"),
)

SUBMIT_APP_ENDPOINT = (
requests.post,
urljoin(API_HOST, "cluster/{cluster_id}/app"),
Expand All @@ -37,6 +35,10 @@
requests.delete,
urljoin(API_HOST, "cluster/{cluster_id}/app/{app_id}"),
)
# (HTTP method, URL template) pair for streaming the live Spark driver log
# of a single app; consumed by OceanSparkHook.get_driver_logs via _do_api_call.
GET_DRIVER_LOGS = (
    requests.get,
    urljoin(API_HOST, "cluster/{cluster_id}/app/{app_id}/log/live"),
)

USER_AGENT_HEADER = {"user-agent": "airflow-{v}".format(v=__version__)}

Expand Down Expand Up @@ -72,17 +74,19 @@ def __init__(
self.retry_limit = retry_limit
self.retry_delay = retry_delay

def _do_api_call(self, method: Callable, endpoint: str, payload: Dict) -> Dict:
def _do_api_call(
self, method: Callable, endpoint: str, payload: Optional[Dict]
) -> requests.models.Response:
"""
Utility function to perform an API call with retries
:param endpoint_info: Tuple of method and endpoint
:type endpoint_info: tuple[string, string]
:param payload: Parameters for this API call.
:type payload: dict
:return: If the api call returns a OK status code,
this function returns the response in JSON. Otherwise,
this function returns the response object. Otherwise,
we throw an AirflowException.
:rtype: dict
:rtype: requests.models.Response
"""

if payload is None:
Expand All @@ -100,7 +104,7 @@ def _do_api_call(self, method: Callable, endpoint: str, payload: Dict) -> Dict:
timeout=self.timeout_seconds,
)
response.raise_for_status()
return response.json()
return response
except (
requests_exceptions.ConnectionError,
requests_exceptions.Timeout,
Expand Down Expand Up @@ -148,7 +152,7 @@ def submit_app(self, payload: Dict) -> str:
),
payload,
)
return response["response"]["items"][0]["id"]
return response.json()["response"]["items"][0]["id"]

def get_app(self, app_id: str) -> Dict:
method, path = GET_APP_ENDPOINT
Expand All @@ -161,7 +165,7 @@ def get_app(self, app_id: str) -> Dict:
),
{},
)
return response["response"]["items"][0]
return response.json()["response"]["items"][0]

def kill_app(self, app_id: str) -> None:
method, path = DELETE_APP_ENDPOINT
Expand All @@ -181,6 +185,15 @@ def get_app_page_url(self, app_id: str) -> str:
f"apps/clusters/{self.cluster_id}/apps/{app_id}/overview&accountId={self.account_id}",
)

def get_driver_logs(self, app_id: str) -> str:
    """Fetch the live Spark driver log for *app_id* and return it as text.

    :param app_id: identifier of the Ocean Spark application whose driver
        log should be retrieved
    :return: the raw log body decoded as UTF-8
    """
    http_method, url_template = GET_DRIVER_LOGS
    endpoint = url_template.format(cluster_id=self.cluster_id, app_id=app_id)
    # No payload: this is a plain GET against the live-log endpoint.
    raw_response = self._do_api_call(http_method, endpoint, payload=None)
    return raw_response.content.decode("utf-8")

def test_connection(self) -> Tuple[bool, str]:
method, path = GET_CLUSTER_ENDPOINT
try:
Expand All @@ -190,7 +203,7 @@ def test_connection(self) -> Tuple[bool, str]:
cluster_id=self.cluster_id,
),
{},
)
).json()
if response["response"]["items"][0]["state"] not in [
"AVAILABLE",
"PROGRESSING",
Expand Down
8 changes: 8 additions & 0 deletions ocean_spark/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
on_spark_submit_callback: Optional[
Callable[[OceanSparkHook, str, "Context"], None]
] = None,
forward_driver_logs: bool = False,
**kwargs: Any,
):
"""
Expand All @@ -83,6 +84,7 @@ def __init__(
self.on_spark_submit_callback: Optional[
Callable[[OceanSparkHook, str, "Context"], None]
] = on_spark_submit_callback
self.forward_driver_logs = forward_driver_logs
self.payload: Dict = {}

if self.job_id is None:
Expand Down Expand Up @@ -163,6 +165,12 @@ def _monitor_app(self, hook: OceanSparkHook, context: "Context") -> None:
self.log.info("%s completed successfully.", self.task_id)
return
else:
if self.forward_driver_logs:
self.log.info(
"Ocean Spark task failure, retrieving Spark driver logs..."
)
# printing driver logs as-is to preserve formatting
print(hook.get_driver_logs(self.app_id))
error_message = "{t} failed with terminal state: {s}".format(
t=self.task_id, s=app_state.value
)
Expand Down
Loading

0 comments on commit a1e57d2

Please sign in to comment.