Merge pull request #92 from LamaAni/reformat
Reformat to black
LamaAni authored Nov 22, 2023
2 parents 96f736c + 10f011a commit cee09aa
Showing 15 changed files with 540 additions and 173 deletions.
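Every hunk below is the same mechanical change: the repository was run through black, whose default 88-character line limit wraps any longer statement into a parenthesized multi-line form with re-indented arguments. A minimal, self-contained sketch of the pattern (the names are illustrative, not taken from the repo):

    def fetch(cfg, key, default=None, allow_empty=True):
        # Stand-in for any long call site.
        return cfg.get(key, default)

    cfg = {}

    # Before black: one logical line over the default 88-column limit.
    value = fetch(cfg, "some_rather_long_configuration_key", default="fallback", allow_empty=True)

    # After black: the same call wrapped inside its own parentheses.
    # Behavior is identical; only the layout changes.
    value = fetch(
        cfg, "some_rather_long_configuration_key", default="fallback", allow_empty=True
    )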
12 changes: 9 additions & 3 deletions airflow_kubernetes_job_operator/__init__.py
@@ -1,4 +1,10 @@
-from airflow_kubernetes_job_operator.kubernetes_job_operator import KubernetesJobOperator  # noqa F401
-from airflow_kubernetes_job_operator.kubernetes_legacy_job_operator import KubernetesLegacyJobOperator  # noqa F401
+from airflow_kubernetes_job_operator.kubernetes_job_operator import (
+    KubernetesJobOperator,
+)  # noqa F401
+from airflow_kubernetes_job_operator.kubernetes_legacy_job_operator import (
+    KubernetesLegacyJobOperator,
+)  # noqa F401
 from airflow_kubernetes_job_operator.utils import resolve_relative_path  # noqa F401
-from airflow_kubernetes_job_operator.job_runner import JobRunnerDeletePolicy  # noqa F401
+from airflow_kubernetes_job_operator.job_runner import (
+    JobRunnerDeletePolicy,
+)  # noqa F401
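A side note on this hunk (an observation, not a change the commit makes): flake8 matches noqa comments per physical line, and `# noqa F401` without a colon is treated as a blanket suppression rather than an F401-specific one. After wrapping, the comment also ends up on the closing-paren line rather than the line where flake8 may report F401, so it is worth verifying the suppression still applies. A commonly used colon form that keeps the comment on the statement's first line:

    from airflow_kubernetes_job_operator.job_runner import (  # noqa: F401
        JobRunnerDeletePolicy,
    )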
24 changes: 18 additions & 6 deletions airflow_kubernetes_job_operator/config.py
@@ -15,9 +15,15 @@
 import re
 
 
-DEFAULT_EXECUTION_OBJECT_PATHS: Dict[KubernetesJobOperatorDefaultExecutionResource, str] = {
-    KubernetesJobOperatorDefaultExecutionResource.Job: resolve_path("./templates/job_default.yaml"),
-    KubernetesJobOperatorDefaultExecutionResource.Pod: resolve_path("./templates/pod_default.yaml"),
+DEFAULT_EXECUTION_OBJECT_PATHS: Dict[
+    KubernetesJobOperatorDefaultExecutionResource, str
+] = {
+    KubernetesJobOperatorDefaultExecutionResource.Job: resolve_path(
+        "./templates/job_default.yaml"
+    ),
+    KubernetesJobOperatorDefaultExecutionResource.Pod: resolve_path(
+        "./templates/pod_default.yaml"
+    ),
 }
 
 AIRFLOW_CONFIG_SECTION_NAME = "kubernetes_job_operator"
@@ -62,7 +68,9 @@ def get(
         allow_empty = False
 
     if val is None or (not allow_empty and len(val.strip()) == 0):
-        assert default is not None, f"Airflow configuration {collection}.{key} not found, and no default value"
+        assert (
+            default is not None
+        ), f"Airflow configuration {collection}.{key} not found, and no default value"
         return default
 
     if otype == bool:
@@ -79,7 +87,9 @@ def get(
 # Airflow config values
 
 # Job runner
-DEFAULT_DELETE_POLICY: JobRunnerDeletePolicy = get("delete_policy", JobRunnerDeletePolicy.IfSucceeded)
+DEFAULT_DELETE_POLICY: JobRunnerDeletePolicy = get(
+    "delete_policy", JobRunnerDeletePolicy.IfSucceeded
+)
 
 # Default bodies
 DEFAULT_EXECTION_OBJECT: KubernetesJobOperatorDefaultExecutionResource = get(
@@ -99,7 +109,9 @@ def get(
 SHOW_RUNNER_ID_IN_LOGS: bool = get("show_runner_id", False)
 
 # Client config
-KUBE_CONFIG_EXTRA_LOCATIONS: str = get("kube_config_extra_locations", "", otype=str, allow_empty=True)
+KUBE_CONFIG_EXTRA_LOCATIONS: str = get(
+    "kube_config_extra_locations", "", otype=str, allow_empty=True
+)
 if not_empty_string(KUBE_CONFIG_EXTRA_LOCATIONS):
     for loc in KUBE_CONFIG_EXTRA_LOCATIONS.split(",").reverse():
         loc = loc.strip()
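Worth flagging in the hunk above (pre-existing behavior, untouched by this reformat): `list.reverse()` reverses in place and returns None, so iterating over `KUBE_CONFIG_EXTRA_LOCATIONS.split(",").reverse()` raises a TypeError whenever the guard lets a non-empty value through. A minimal sketch of the conventional fix, assuming reverse iteration order is what was intended:

    # Hypothetical correction, not part of this commit: reversed() returns an
    # iterator instead of mutating the list and returning None.
    KUBE_CONFIG_EXTRA_LOCATIONS = "~/.kube/config-a,~/.kube/config-b"  # example value

    for loc in reversed(KUBE_CONFIG_EXTRA_LOCATIONS.split(",")):
        loc = loc.strip()
        print(loc)  # stand-in for registering the extra kube-config location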
166 changes: 129 additions & 37 deletions airflow_kubernetes_job_operator/job_runner.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion airflow_kubernetes_job_operator/kube_api/__init__.py
@@ -4,5 +4,7 @@
 from airflow_kubernetes_job_operator.kube_api.watchers import *  # noqa F401
 from airflow_kubernetes_job_operator.kube_api.queries import *  # noqa F401
 from airflow_kubernetes_job_operator.kube_api.operations import *  # noqa F401
-from airflow_kubernetes_job_operator.kube_api.config import KubeApiConfiguration  # noqa F401
+from airflow_kubernetes_job_operator.kube_api.config import (
+    KubeApiConfiguration,
+)  # noqa F401
 from airflow_kubernetes_job_operator.kube_api.utils import kube_logger  # noqa F401
122 changes: 91 additions & 31 deletions airflow_kubernetes_job_operator/kube_api/client.py
@@ -17,8 +17,13 @@
 from kubernetes.config import kube_config
 
 from airflow_kubernetes_job_operator.kube_api.utils import kube_logger
-from airflow_kubernetes_job_operator.kube_api.exceptions import KubeApiException, KubeApiClientException
-from airflow_kubernetes_job_operator.kube_api.collections import KubeApiRestQueryConnectionState
+from airflow_kubernetes_job_operator.kube_api.exceptions import (
+    KubeApiException,
+    KubeApiClientException,
+)
+from airflow_kubernetes_job_operator.kube_api.collections import (
+    KubeApiRestQueryConnectionState,
+)
 from airflow_kubernetes_job_operator.kube_api.config import (
     KubeApiConfiguration,
     DEFAULT_AUTO_RECONNECT_MAX_ATTEMPTS,
@@ -95,7 +100,9 @@ def __init__(
             throw_on_if_first_api_call_fails (bool, optional): If true the and the first attempt to connect fails,
                 throws an error. Defaults to True.
         """
-        assert use_asyncio is not True, NotImplementedError("AsyncIO not yet implemented.")
+        assert use_asyncio is not True, NotImplementedError(
+            "AsyncIO not yet implemented."
+        )
         super().__init__(
             self._execute_query,
             use_async_loop=use_asyncio or KubeApiRestQuery.default_use_asyncio,
@@ -120,11 +127,15 @@ def __init__(
         # these event are object specific
         self.query_started_event_name = f"{self.query_started_event_name} {id(self)}"
         self.query_ended_event_name = f"{self.query_started_event_name} {id(self)}"
-        self.connection_state_changed_event_name = f"{self.connection_state_changed_event_name} {id(self)}"
+        self.connection_state_changed_event_name = (
+            f"{self.connection_state_changed_event_name} {id(self)}"
+        )
 
         self._active_responses: Set[HTTPResponse] = WeakSet()  # type:ignore
         self._is_being_stopped: bool = False
-        self._connection_state: KubeApiRestQueryConnectionState = KubeApiRestQueryConnectionState.Disconnected
+        self._connection_state: KubeApiRestQueryConnectionState = (
+            KubeApiRestQueryConnectionState.Disconnected
+        )
 
     @property
     def query_running(self) -> bool:
@@ -136,7 +147,9 @@ def connection_state(self) -> KubeApiRestQueryConnectionState:
         """The state of the connection"""
         return self._connection_state
 
-    def _set_connection_state(self, state: KubeApiRestQueryConnectionState, emit_event: bool = True):
+    def _set_connection_state(
+        self, state: KubeApiRestQueryConnectionState, emit_event: bool = True
+    ):
         if self._connection_state == state:
             return
         self._connection_state = state
@@ -281,11 +294,18 @@ def can_reconnect():
 
         # starting query.
         is_first_connect_attempt = True
-        while self.is_running and not self._is_being_stopped and (is_first_connect_attempt or self.auto_reconnect):
+        while (
+            self.is_running
+            and not self._is_being_stopped
+            and (is_first_connect_attempt or self.auto_reconnect)
+        ):
             try:
                 if not is_first_connect_attempt:
                     # error while running and has wait time
-                    if self.query_running and self.auto_reconnect_wait_between_attempts > 0:
+                    if (
+                        self.query_running
+                        and self.auto_reconnect_wait_between_attempts > 0
+                    ):
                         kube_logger.debug(
                             f"[{self.resource_path}][Reconnect] Sleeping for "
                             + f"{self.auto_reconnect_wait_between_attempts}"
@@ -298,13 +318,17 @@ def can_reconnect():
                     self.emit(self.query_before_reconnect_event_name)
 
                     # Reset the connection state.
-                    self._set_connection_state(KubeApiRestQueryConnectionState.Disconnected)
+                    self._set_connection_state(
+                        KubeApiRestQueryConnectionState.Disconnected
+                    )
 
                     # Case auto_reconnect has changed.
                     if not self.auto_reconnect or not do_reconnect:
                         break
 
-                    kube_logger.debug(f"[{self.resource_path}] Connection lost, reconnecting..")
+                    kube_logger.debug(
+                        f"[{self.resource_path}] Connection lost, reconnecting.."
+                    )
 
                 # generating the query params
                 path_params = validate_dictionary(self.path_params)
@@ -364,16 +388,25 @@ def can_reconnect():
                     except Exception:
                         pass
 
-                    exeuctor_name = f"{self.__class__.__module__}.{self.__class__.__name__}"
+                    exeuctor_name = (
+                        f"{self.__class__.__module__}.{self.__class__.__name__}"
+                    )
 
                     if isinstance(ex.body, dict):
-                        exception_message = f"{exeuctor_name}, {ex.reason}: {ex.body.get('message')}"
+                        exception_message = (
+                            f"{exeuctor_name}, {ex.reason}: {ex.body.get('message')}"
+                        )
                     else:
                         exception_message = f"{exeuctor_name}, {ex.reason}: {ex.body}"
 
-                    err = KubeApiClientException(exception_message, rest_api_exception=ex)
+                    err = KubeApiClientException(
+                        exception_message, rest_api_exception=ex
+                    )
 
-                    if is_first_connect_attempt and self.throw_on_if_first_api_call_fails:
+                    if (
+                        is_first_connect_attempt
+                        and self.throw_on_if_first_api_call_fails
+                    ):
                         raise err
 
                     # check if can reconnect.
@@ -396,7 +429,9 @@ def start(self, client: "KubeApiRestClient"):
             client (ApiClient): The api client to use.
         """
         assert not self.is_running, "Cannot start a running query"
-        assert isinstance(client, KubeApiRestClient), "client must be of class KubeApiRestClient"
+        assert isinstance(
+            client, KubeApiRestClient
+        ), "client must be of class KubeApiRestClient"
 
         self._query_running = False
         super().start(client)
@@ -439,7 +474,9 @@ def stop(self, timeout: float = None, throw_error_if_not_running: bool = False):
                     rsp.close()
                 except Exception:
                     pass
-            super().stop(timeout=timeout, throw_error_if_not_running=throw_error_if_not_running)  # type:ignore
+            super().stop(
+                timeout=timeout, throw_error_if_not_running=throw_error_if_not_running
+            )  # type:ignore
         finally:
            self._query_running = False
            self._is_being_stopped = False
@@ -453,7 +490,9 @@ def log_event(self, logger: Logger, ev: Event):
         """
         pass
 
-    def pipe_to_logger(self, logger: Logger = kube_logger, allowed_event_names=None) -> EventHandler:
+    def pipe_to_logger(
+        self, logger: Logger = kube_logger, allowed_event_names=None
+    ) -> EventHandler:
         """Called to pipe logging events to a specific logger. The log_event method
         will be called when a message is emitted.
@@ -475,9 +514,15 @@ def pipe_to_logger(self, logger: Logger = kube_logger, allowed_event_names=None)
 
         def process_log_event(ev: Event):
             if ev.name in [self.error_event_name, self.warning_event_name]:
-                err: Exception = ev.args[-1] if len(ev.args) > 0 else Exception("Unknown error")
+                err: Exception = (
+                    ev.args[-1] if len(ev.args) > 0 else Exception("Unknown error")
+                )
                 msg = (
-                    "\n".join(traceback.format_exception(err.__class__, err, err.__traceback__))
+                    "\n".join(
+                        traceback.format_exception(
+                            err.__class__, err, err.__traceback__
+                        )
+                    )
                     if isinstance(err, Exception)
                     else err
                 )
@@ -536,9 +581,13 @@ def kube_config(self) -> kube_config.Configuration:
         """
         if self._kube_config is None:
             if not self.auto_load_kube_config:
-                raise KubeApiException("Kubernetes configuration not loaded and auto load is set to false.")
+                raise KubeApiException(
+                    "Kubernetes configuration not loaded and auto load is set to false."
+                )
             self.load_kube_config()
-        assert self._kube_config is not None, "Failed to load default kubernetes configuration"
+        assert (
+            self._kube_config is not None
+        ), "Failed to load default kubernetes configuration"
         return self._kube_config
 
     @property
@@ -570,12 +619,14 @@ def load_kube_config(
             persist (bool, optional): If True, config file will be updated when changed
                 (e.g GCP token refresh).
         """
-        self._kube_config: kube_config.Configuration = KubeApiConfiguration.load_kubernetes_configuration(
-            config_file=config_file,
-            is_in_cluster=is_in_cluster,
-            extra_config_locations=extra_config_locations,
-            context=context,
-            persist=persist,
+        self._kube_config: kube_config.Configuration = (
+            KubeApiConfiguration.load_kubernetes_configuration(
+                config_file=config_file,
+                is_in_cluster=is_in_cluster,
+                extra_config_locations=extra_config_locations,
+                context=context,
+                persist=persist,
+            )
         )
 
         assert self._kube_config is not None, KubeApiClientException(
@@ -599,7 +650,9 @@ def stop(self):
 
     def _create_query_handler(self, queries: List[KubeApiRestQuery]) -> EventHandler:
         assert isinstance(queries, list), "queries Must be a list of queries"
-        assert all([isinstance(q, KubeApiRestQuery) for q in queries]), "All queries must be of type KubeApiRestQuery"
+        assert all(
+            [isinstance(q, KubeApiRestQuery) for q in queries]
+        ), "All queries must be of type KubeApiRestQuery"
         assert len(queries) > 0, "You must at least send one query"
 
         handler = EventHandler()
@@ -618,7 +671,10 @@ def remove_from_pending(q, ex: Exception = None):
         for q in queries:
             self._active_queries.add(q)
             q.on(q.error_event_name, lambda query, err: remove_from_pending(query, err))
-            q.on(q.query_ended_event_name, lambda query, client: remove_from_pending(query))
+            q.on(
+                q.query_ended_event_name,
+                lambda query, client: remove_from_pending(query),
+            )
             q.pipe(handler)
 
         return handler
@@ -628,7 +684,9 @@ def _start_execution(self, queries: List[KubeApiRestQuery]):
             self._active_queries.add(query)
             query.start(self)
 
-    def query_async(self, queries: Union[List[KubeApiRestQuery], KubeApiRestQuery]) -> EventHandler:
+    def query_async(
+        self, queries: Union[List[KubeApiRestQuery], KubeApiRestQuery]
+    ) -> EventHandler:
         """Asynchronous querying. The queries will be called in the background. Use wait_until_running
         for each query to wait for the queries to start.
@@ -710,7 +768,9 @@ def query(
             Union[List[object], object]: A single query if a single query is sent. A list if a
                 list is sent
         """
-        strm = self.stream(queries, event_name=event_name, timeout=timeout, throw_errors=throw_errors)
+        strm = self.stream(
+            queries, event_name=event_name, timeout=timeout, throw_errors=throw_errors
+        )
         rslt = [v for v in strm]
         if not isinstance(queries, list):
             return rslt[0] if len(rslt) > 0 else None
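One pattern that recurs in the hunks above (pre-existing, merely re-wrapped by black): several asserts pass an exception instance as the assert message, e.g. `assert use_asyncio is not True, NotImplementedError(...)`. A failing assert still raises AssertionError; the NotImplementedError instance is only stringified into its message, and all such checks disappear entirely under `python -O`. A self-contained illustration:

    use_asyncio = True
    try:
        # Raises AssertionError, not NotImplementedError; the instance becomes the message.
        assert use_asyncio is not True, NotImplementedError("AsyncIO not yet implemented.")
    except AssertionError as err:
        print(type(err).__name__, "->", err)  # AssertionError -> AsyncIO not yet implemented.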