
Commit

Rename generic_run_query_task to run_query_task; Make search_task's make_command exception safe; Fix some PEP violations; Some clean-up.
kirkrodrigues committed Jun 28, 2024
1 parent bd7083f commit 4a4d5b2
Showing 3 changed files with 20 additions and 23 deletions.
File 1 of 3 (the IR extraction task module)

@@ -10,8 +10,8 @@
 from clp_py_utils.sql_adapter import SQL_Adapter
 from job_orchestration.executor.query.celery import app
 from job_orchestration.executor.query.utils import (
-    generic_run_query_task,
     report_command_creation_failure,
+    run_query_task,
 )
 from job_orchestration.scheduler.job_config import ExtractIrJobConfig
 from job_orchestration.scheduler.scheduler_data import QueryTaskStatus
@@ -63,15 +63,14 @@ def extract_ir(
     clp_metadata_db_conn_params: dict,
     results_cache_uri: str,
 ) -> Dict[str, Any]:
-    # Task name
-    TASK_NAME = "IR extraction"
+    task_name = "IR extraction"
 
     # Setup logging to file
     clp_logs_dir = Path(os.getenv("CLP_LOGS_DIR"))
     clp_logging_level = str(os.getenv("CLP_LOGGING_LEVEL"))
     set_logging_level(logger, clp_logging_level)
 
-    logger.info(f"Started {TASK_NAME} task for job {job_id}")
+    logger.info(f"Started {task_name} task for job {job_id}")
 
     start_time = datetime.datetime.now()
     task_status: QueryTaskStatus
@@ -95,22 +94,21 @@ def extract_ir(
         results_cache_uri=results_cache_uri,
         ir_collection=ir_collection,
     )
-
     if not task_command:
         return report_command_creation_failure(
             sql_adapter=sql_adapter,
             logger=logger,
-            task_name=TASK_NAME,
+            task_name=task_name,
             task_id=task_id,
             start_time=start_time,
         )
 
-    return generic_run_query_task(
+    return run_query_task(
         sql_adapter=sql_adapter,
         logger=logger,
         clp_logs_dir=clp_logs_dir,
         task_command=task_command,
-        task_name=TASK_NAME,
+        task_name=task_name,
         job_id=job_id,
         task_id=task_id,
         start_time=start_time,
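Both task modules now share the same control flow: build the task command, report a structured failure if command creation returns nothing, and otherwise hand off to the shared run_query_task helper. Below is a minimal, self-contained sketch of that shape; the helper names mirror the diff, but the simplified bodies and the example task are illustrative stand-ins, not the repository's actual implementations.

    import logging
    import subprocess
    from typing import Any, Dict, List, Optional

    logger = logging.getLogger(__name__)


    def make_command() -> Optional[List[str]]:
        # Returns the command to run, or None if it couldn't be built.
        return ["echo", "hello"]


    def report_command_creation_failure(task_name: str, task_id: int) -> Dict[str, Any]:
        # Stand-in for the real helper: record the failure and return a result dict.
        logger.error(f"Failed to create command for {task_name} task {task_id}")
        return {"task_id": task_id, "status": "FAILED"}


    def run_query_task(task_command: List[str], task_name: str, task_id: int) -> Dict[str, Any]:
        # Stand-in for the renamed helper: run the command and report its outcome.
        proc = subprocess.Popen(task_command)
        proc.communicate()
        status = "SUCCEEDED" if 0 == proc.returncode else "FAILED"
        return {"task_id": task_id, "status": status}


    def example_task(task_id: int) -> Dict[str, Any]:
        task_name = "example"
        task_command = make_command()
        if not task_command:
            return report_command_creation_failure(task_name=task_name, task_id=task_id)
        return run_query_task(task_command=task_command, task_name=task_name, task_id=task_id)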
File 2 of 3 (the search task module)

@@ -1,7 +1,7 @@
 import datetime
 import os
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, List, Optional
 
 from celery.app.task import Task
 from celery.utils.log import get_task_logger
@@ -10,8 +10,8 @@
 from clp_py_utils.sql_adapter import SQL_Adapter
 from job_orchestration.executor.query.celery import app
 from job_orchestration.executor.query.utils import (
-    generic_run_query_task,
     report_command_creation_failure,
+    run_query_task,
 )
 from job_orchestration.scheduler.job_config import SearchJobConfig
 from job_orchestration.scheduler.scheduler_data import QueryTaskStatus
@@ -28,7 +28,7 @@ def make_command(
     search_config: SearchJobConfig,
     results_cache_uri: str,
     results_collection: str,
-):
+) -> Optional[List[str]]:
     if StorageEngine.CLP == storage_engine:
         command = [str(clp_home / "bin" / "clo"), "s", str(archives_dir / archive_id)]
         if search_config.path_filter is not None:
@@ -43,7 +43,8 @@ def make_command(
             archive_id,
         ]
     else:
-        raise ValueError(f"Unsupported storage engine {storage_engine}")
+        logger.error(f"Unsupported storage engine {storage_engine}")
+        return None
 
     command.append(search_config.query_string)
     if search_config.begin_timestamp is not None:
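Before this change, an unsupported storage engine raised a ValueError out of make_command; now the function logs the error and returns None, and the new Optional[List[str]] annotation makes that contract explicit to callers. Returning None keeps the failure on the task's normal reporting path (the `if not task_command` branch) instead of surfacing as an uncaught exception in the Celery task. A hedged sketch of the pattern, with placeholder engine strings rather than the project's real StorageEngine enum:

    import logging
    from typing import List, Optional

    logger = logging.getLogger(__name__)


    def make_command(storage_engine: str) -> Optional[List[str]]:
        # Build the search command, or return None (after logging) on failure.
        if "clp" == storage_engine:
            return ["clo", "s"]
        elif "clp-s" == storage_engine:
            return ["clp-s", "s"]
        else:
            # Log and bail out instead of raising, so the caller can report the
            # failure through its normal result path.
            logger.error(f"Unsupported storage engine {storage_engine}")
            return None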
@@ -102,15 +103,14 @@ def search(
     clp_metadata_db_conn_params: dict,
     results_cache_uri: str,
 ) -> Dict[str, Any]:
-    # Task name
-    TASK_NAME = "search"
+    task_name = "search"
 
     # Setup logging to file
     clp_logs_dir = Path(os.getenv("CLP_LOGS_DIR"))
     clp_logging_level = str(os.getenv("CLP_LOGGING_LEVEL"))
     set_logging_level(logger, clp_logging_level)
 
-    logger.info(f"Started {TASK_NAME} task for job {job_id}")
+    logger.info(f"Started {task_name} task for job {job_id}")
 
     start_time = datetime.datetime.now()
     task_status: QueryTaskStatus
@@ -131,22 +131,21 @@ def search(
         results_cache_uri=results_cache_uri,
         results_collection=str(task_id),
     )
-
     if not task_command:
         return report_command_creation_failure(
             sql_adapter=sql_adapter,
             logger=logger,
-            task_name=TASK_NAME,
+            task_name=task_name,
             task_id=task_id,
             start_time=start_time,
         )
 
-    return generic_run_query_task(
+    return run_query_task(
         sql_adapter=sql_adapter,
         logger=logger,
         clp_logs_dir=clp_logs_dir,
         task_command=task_command,
-        task_name=TASK_NAME,
+        task_name=task_name,
         job_id=job_id,
         task_id=task_id,
         start_time=start_time,
File 3 of 3 (job_orchestration/executor/query/utils.py)

@@ -41,7 +41,7 @@ def report_command_creation_failure(
     ).dict()
 
 
-def generic_run_query_task(
+def run_query_task(
     sql_adapter: SQL_Adapter,
     logger: Logger,
     clp_logs_dir: Path,
@@ -72,7 +72,7 @@ def sigterm_handler(_signo, _stack_frame):
         logger.debug("Entered sigterm handler")
         if task_proc.poll() is None:
             logger.debug(f"Trying to kill {task_name} process")
-            # Kill the process group in case the search process also forked
+            # Kill the process group in case the task process also forked
             os.killpg(os.getpgid(task_proc.pid), signal.SIGTERM)
             os.waitpid(task_proc.pid, 0)
         logger.info(f"Cancelling {task_name} task.")
@@ -84,8 +84,8 @@ def sigterm_handler(_signo, _stack_frame):
     signal.signal(signal.SIGTERM, sigterm_handler)
 
     logger.info(f"Waiting for {task_name} to finish")
-    # communicate is equivalent to wait in this case, but avoids deadlocks if we switch to piping
-    # stdout/stderr in the future.
+    # `communicate` is equivalent to `wait` in this case, but avoids deadlocks if we switch to
+    # piping stdout/stderr in the future.
     task_proc.communicate()
     return_code = task_proc.returncode
     if 0 != return_code:
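The rewrapped comment preserves the original rationale: with no pipes attached, communicate() behaves like wait(), but if stdout/stderr were ever switched to subprocess.PIPE, wait() could deadlock once the child fills the OS pipe buffer, whereas communicate() drains the pipes while waiting. A small, self-contained illustration (not from the repository):

    import subprocess

    # seq produces far more output than a pipe buffer holds; with wait() alone the
    # child could block on write while the parent blocked waiting for it to exit.
    proc = subprocess.Popen(["seq", "1", "1000000"], stdout=subprocess.PIPE, text=True)
    stdout, _ = proc.communicate()  # drains the pipe while waiting: no deadlock
    print(f"exit code: {proc.returncode}, lines read: {len(stdout.splitlines())}")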
