Skip to content

Commit

Permalink
use a Variable to keep the Future
Browse files Browse the repository at this point in the history
  • Loading branch information
sanderegg committed Dec 5, 2024
1 parent 1ab47da commit fde7c95
Showing 1 changed file with 21 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ def _comp_sidecar_fct(
)
# NOTE: the callback is running in a secondary thread, and takes a future as arg
task_future.add_done_callback(lambda _: callback())
await distributed.Variable(job_id, client=self.backend.client).set(
task_future
)

await dask_utils.wrap_client_async_routine(
self.backend.client.publish_dataset(task_future, name=job_id)
Expand Down Expand Up @@ -450,30 +453,26 @@ def _get_pipeline_statuses(
DaskSchedulerTaskState | None, task_statuses.get(job_id, "lost")
)
if dask_status == "erred":
try:
# find out if this was a cancellation
exception = await distributed.Future(
job_id, client=self.backend.client
).exception(timeout=_DASK_DEFAULT_TIMEOUT_S)
assert isinstance(exception, Exception) # nosec

if isinstance(exception, TaskCancelledError):
running_states.append(DaskClientTaskState.ABORTED)
else:
assert exception # nosec
_logger.warning(
"Task %s completed in error:\n%s\nTrace:\n%s",
job_id,
exception,
"".join(traceback.format_exception(exception)),
)
running_states.append(DaskClientTaskState.ERRED)
except TimeoutError:
# find out if this was a cancellation
var = distributed.Variable(job_id, client=self.backend.client)
future: distributed.Future = await var.get(
timeout=_DASK_DEFAULT_TIMEOUT_S
)
exception = await future.exception(timeout=_DASK_DEFAULT_TIMEOUT_S)
assert isinstance(exception, Exception) # nosec

if isinstance(exception, TaskCancelledError):
running_states.append(DaskClientTaskState.ABORTED)
else:
assert exception # nosec
_logger.warning(
"Task %s completed in error but was lost from dask-scheduler since then."
"TIP: This can happen when the future just disappeared from the dask-scheduler when this call was done."
"Task %s completed in error:\n%s\nTrace:\n%s",
job_id,
exception,
"".join(traceback.format_exception(exception)),
)
running_states.append(DaskClientTaskState.LOST)
running_states.append(DaskClientTaskState.ERRED)

elif dask_status is None:
running_states.append(DaskClientTaskState.LOST)
else:
Expand Down

0 comments on commit fde7c95

Please sign in to comment.