diff --git a/skymap_scanner/server/__init__.py b/skymap_scanner/server/__init__.py index e69de29bb..a75bc177c 100644 --- a/skymap_scanner/server/__init__.py +++ b/skymap_scanner/server/__init__.py @@ -0,0 +1,93 @@ +"""The Skymap Scanner Central Server.""" + +import dataclasses as dc + +from wipac_dev_tools import from_environment_as_dataclass + + +# +# Env var constants: set as constants & typecast +# + + +@dc.dataclass(frozen=True) +class EnvConfig: + """For storing environment variables, typed.""" + + # + # REQUIRED + # + + SKYSCAN_SKYDRIVER_SCAN_ID: str # globally unique ID + + # to-client queue + SKYSCAN_MQ_TOCLIENT: str + SKYSCAN_MQ_TOCLIENT_AUTH_TOKEN: str + SKYSCAN_MQ_TOCLIENT_BROKER_TYPE: str + SKYSCAN_MQ_TOCLIENT_BROKER_ADDRESS: str + # + # from-client queue + SKYSCAN_MQ_FROMCLIENT: str + SKYSCAN_MQ_FROMCLIENT_AUTH_TOKEN: str + SKYSCAN_MQ_FROMCLIENT_BROKER_TYPE: str + SKYSCAN_MQ_FROMCLIENT_BROKER_ADDRESS: str + + # + # OPTIONAL + # + + SKYSCAN_PROGRESS_INTERVAL_SEC: int = 1 * 60 + SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_RATIO: float = ( + # The size of the sample window (a fraction of the collected/finished recos) + # used to calculate the most recent runtime rate (sec/reco), then used to make + # predictions for overall runtimes: i.e. amount of time left. + # Also, see SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_MIN. + 0.1 + ) + SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_MIN: int = ( + # NOTE: val should not be (too far) below the number of workers (which is unknown, so make a good guess). + # In other words, if val is too low, then the rate is not representative of the + # worker-pool's concurrency; if val is too high, then the window is too large. + # This is only useful for the first `val/SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_RATIO` + # num of recos, afterward the ratio is used. 
+ 100 + ) + SKYSCAN_RESULT_INTERVAL_SEC: int = 2 * 60 + + SKYSCAN_KILL_SWITCH_CHECK_INTERVAL: int = 5 * 60 + + # TIMEOUTS + # + # seconds -- how long server waits before thinking all clients are dead + # - set to duration of first reco + client launch (condor) + # - important if clients launch *AFTER* server + # - normal expiration scenario: all clients died (bad condor submit file), otherwise never (server knows when all recos are done) + SKYSCAN_MQ_TIMEOUT_FROM_CLIENTS: int = 3 * 24 * 60 * 60 # 3 days + + # SKYDRIVER VARS + SKYSCAN_SKYDRIVER_ADDRESS: str = "" # SkyDriver REST interface address + SKYSCAN_SKYDRIVER_AUTH: str = "" # SkyDriver REST interface auth token + + # LOGGING VARS + SKYSCAN_LOG: str = "INFO" + SKYSCAN_LOG_THIRD_PARTY: str = "WARNING" + SKYSCAN_EWMS_PILOT_LOG: str = "INFO" + SKYSCAN_MQ_CLIENT_LOG: str = "INFO" + + # TESTING/DEBUG VARS + SKYSCAN_MINI_TEST: bool = False # run minimal variations for testing (mini-scale) + SKYSCAN_CRASH_DUMMY_PROBABILITY: float = 0.5 # for reco algo: crash-dummy + + def __post_init__(self) -> None: + """Check values.""" + if self.SKYSCAN_PROGRESS_INTERVAL_SEC <= 0: + raise ValueError( + f"Env Var: SKYSCAN_PROGRESS_INTERVAL_SEC is not positive: {self.SKYSCAN_PROGRESS_INTERVAL_SEC}" + ) + if self.SKYSCAN_RESULT_INTERVAL_SEC <= 0: + raise ValueError( + f"Env Var: SKYSCAN_RESULT_INTERVAL_SEC is not positive: {self.SKYSCAN_RESULT_INTERVAL_SEC}" + ) + + +ENV = from_environment_as_dataclass(EnvConfig) diff --git a/skymap_scanner/server/reporter.py b/skymap_scanner/server/reporter.py index 5c7fee7a3..2726e67e4 100644 --- a/skymap_scanner/server/reporter.py +++ b/skymap_scanner/server/reporter.py @@ -15,6 +15,7 @@ from rest_tools.client import RestClient from skyreader import EventMetadata, SkyScanResult +from . import ENV from .utils import NSideProgression, connect_to_skydriver, nonurgent_request from .. 
import config as cfg from ..utils import to_skyscan_result @@ -185,10 +186,8 @@ def __init__(self) -> None: def on_server_recent_sec_per_reco_rate(self) -> float: """The sec/reco rate from server pov within a moving window.""" window_size = max( - int( - self.total_ct * cfg.ENV.SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_RATIO - ), - cfg.ENV.SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_MIN, + int(self.total_ct * ENV.SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_RATIO), + ENV.SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_MIN, ) try: @@ -333,7 +332,7 @@ def __init__( self._n_sent_by_nside: Dict[int, int] = {} - if not cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS: + if not ENV.SKYSCAN_SKYDRIVER_ADDRESS: self.skydriver_rc_nonurgent: Optional[RestClient] = None self.skydriver_rc_urgent: Optional[RestClient] = None else: @@ -424,8 +423,7 @@ async def make_reports_if_needed( # check if we need to send a report to the logger current_time = time.time() if bypass_timers or ( - current_time - self.last_time_reported - > cfg.ENV.SKYSCAN_PROGRESS_INTERVAL_SEC + current_time - self.last_time_reported > ENV.SKYSCAN_PROGRESS_INTERVAL_SEC ): self.last_time_reported = current_time if self.worker_stats_collection.total_ct == 0: @@ -433,7 +431,7 @@ async def make_reports_if_needed( else: epilogue_msg = ( f"I will report back again in " - f"{cfg.ENV.SKYSCAN_PROGRESS_INTERVAL_SEC} seconds if I have an update." + f"{ENV.SKYSCAN_PROGRESS_INTERVAL_SEC} seconds if I have an update." 
) await self._send_progress(summary_msg, epilogue_msg) @@ -441,7 +439,7 @@ async def make_reports_if_needed( current_time = time.time() if bypass_timers or ( current_time - self.last_time_reported_skymap - > cfg.ENV.SKYSCAN_RESULT_INTERVAL_SEC + > ENV.SKYSCAN_RESULT_INTERVAL_SEC ): self.last_time_reported_skymap = current_time await self._send_result() @@ -658,7 +656,7 @@ async def _send_progress( "last_updated": str(dt.datetime.fromtimestamp(int(time.time()))), } scan_metadata = { - "scan_id": cfg.ENV.SKYSCAN_SKYDRIVER_SCAN_ID, + "scan_id": ENV.SKYSCAN_SKYDRIVER_SCAN_ID, "nside_progression": self.nside_progression, "position_variations": self.n_posvar, } @@ -674,7 +672,7 @@ async def _send_progress( # skydriver sd_args = dict( method="PATCH", - path=f"/scan/{cfg.ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/manifest", + path=f"/scan/{ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/manifest", args=body, ) if not self.is_event_scan_done and self.skydriver_rc_nonurgent: @@ -702,7 +700,7 @@ async def _send_result(self) -> SkyScanResult: body = {"skyscan_result": serialized, "is_final": self.is_event_scan_done} sd_args = dict( method="PUT", - path=f"/scan/{cfg.ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/result", + path=f"/scan/{ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/result", args=body, ) if not self.is_event_scan_done and self.skydriver_rc_nonurgent: diff --git a/skymap_scanner/server/start_scan.py b/skymap_scanner/server/start_scan.py index 5abb21ad0..00255a367 100644 --- a/skymap_scanner/server/start_scan.py +++ b/skymap_scanner/server/start_scan.py @@ -1,7 +1,6 @@ """The Skymap Scanner Server.""" # pylint: disable=invalid-name,import-error -# fmt:quotes-ok import argparse import asyncio @@ -25,6 +24,7 @@ from skyreader import EventMetadata from wipac_dev_tools import argparse_tools, logging_tools +from . 
import ENV from .collector import Collector, ExtraRecoPixelVariationException from .pixels import choose_pixels_to_reconstruct from .reporter import Reporter @@ -668,13 +668,13 @@ def _mkdir_if_not_exists(val: str, is_file: bool = False) -> Path: raise NotADirectoryError(args.gcd_dir) # check output status - if not cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS and not args.output_dir: + if not ENV.SKYSCAN_SKYDRIVER_ADDRESS and not args.output_dir: raise RuntimeError( "Must include either --output-dir or SKYSCAN_SKYDRIVER_ADDRESS (env var), " "otherwise you won't see your results!" ) # read event file - if cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS: + if ENV.SKYSCAN_SKYDRIVER_ADDRESS: event_contents = asyncio.run(fetch_event_contents_from_skydriver()) else: event_contents = fetch_event_contents_from_file(args.event_file) diff --git a/skymap_scanner/server/utils.py b/skymap_scanner/server/utils.py index 4ca00c73f..e16543cf6 100644 --- a/skymap_scanner/server/utils.py +++ b/skymap_scanner/server/utils.py @@ -14,7 +14,7 @@ import mqclient as mq from rest_tools.client import CalcRetryFromWaittimeMax, RestClient -from .. import config as cfg +from . 
import ENV LOGGER = logging.getLogger(__name__) @@ -25,18 +25,18 @@ def get_mqclient_connections() -> tuple[mq.Queue, mq.Queue]: """Establish connections to message queues.""" to_clients_queue = mq.Queue( - cfg.ENV.SKYSCAN_MQ_TOCLIENT_BROKER_TYPE, - address=cfg.ENV.SKYSCAN_MQ_TOCLIENT_BROKER_ADDRESS, - name=cfg.ENV.SKYSCAN_MQ_TOCLIENT, - auth_token=cfg.ENV.SKYSCAN_MQ_TOCLIENT_AUTH_TOKEN, + ENV.SKYSCAN_MQ_TOCLIENT_BROKER_TYPE, + address=ENV.SKYSCAN_MQ_TOCLIENT_BROKER_ADDRESS, + name=ENV.SKYSCAN_MQ_TOCLIENT, + auth_token=ENV.SKYSCAN_MQ_TOCLIENT_AUTH_TOKEN, # timeout=-1, # NOTE: this mq only sends messages so no timeout needed ) from_clients_queue = mq.Queue( - cfg.ENV.SKYSCAN_MQ_FROMCLIENT_BROKER_TYPE, - address=cfg.ENV.SKYSCAN_MQ_FROMCLIENT_BROKER_ADDRESS, - name=cfg.ENV.SKYSCAN_MQ_FROMCLIENT, - auth_token=cfg.ENV.SKYSCAN_MQ_FROMCLIENT_AUTH_TOKEN, - timeout=cfg.ENV.SKYSCAN_MQ_TIMEOUT_FROM_CLIENTS, + ENV.SKYSCAN_MQ_FROMCLIENT_BROKER_TYPE, + address=ENV.SKYSCAN_MQ_FROMCLIENT_BROKER_ADDRESS, + name=ENV.SKYSCAN_MQ_FROMCLIENT, + auth_token=ENV.SKYSCAN_MQ_FROMCLIENT_AUTH_TOKEN, + timeout=ENV.SKYSCAN_MQ_TIMEOUT_FROM_CLIENTS, ) return to_clients_queue, from_clients_queue @@ -49,16 +49,16 @@ def connect_to_skydriver(urgent: bool) -> RestClient: """Get REST client for SkyDriver depending on the urgency.""" if urgent: return RestClient( - cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS, - token=cfg.ENV.SKYSCAN_SKYDRIVER_AUTH, + ENV.SKYSCAN_SKYDRIVER_ADDRESS, + token=ENV.SKYSCAN_SKYDRIVER_AUTH, timeout=60.0, retries=CalcRetryFromWaittimeMax(waittime_max=1 * 60 * 60), # backoff_factor=0.3, ) else: return RestClient( - cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS, - token=cfg.ENV.SKYSCAN_SKYDRIVER_AUTH, + ENV.SKYSCAN_SKYDRIVER_ADDRESS, + token=ENV.SKYSCAN_SKYDRIVER_AUTH, timeout=10.0, retries=1, # backoff_factor=0.3, @@ -76,7 +76,7 @@ async def nonurgent_request(rc: RestClient, args: dict[str, Any]) -> Any: async def kill_switch_check_from_skydriver() -> None: """Routinely check SkyDriver whether to 
continue the scan.""" - if not cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS: + if not ENV.SKYSCAN_SKYDRIVER_ADDRESS: return logger = logging.getLogger("skyscan.kill_switch") @@ -84,10 +84,10 @@ async def kill_switch_check_from_skydriver() -> None: skydriver_rc = connect_to_skydriver(urgent=False) while True: - await asyncio.sleep(cfg.ENV.SKYSCAN_KILL_SWITCH_CHECK_INTERVAL) + await asyncio.sleep(ENV.SKYSCAN_KILL_SWITCH_CHECK_INTERVAL) status = await skydriver_rc.request( - "GET", f"/scan/{cfg.ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/status" + "GET", f"/scan/{ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/status" ) if status["scan_state"].startswith("STOPPED__"): @@ -105,7 +105,7 @@ async def fetch_event_contents_from_skydriver() -> Any: skydriver_rc = connect_to_skydriver(urgent=True) manifest = await skydriver_rc.request( - "GET", f"/scan/{cfg.ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/manifest" + "GET", f"/scan/{ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/manifest" ) LOGGER.info("Fetched event contents from SkyDriver") return manifest["event_i3live_json_dict"]