Skip to content

Commit

Permalink
move ENV to server/
Browse files Browse the repository at this point in the history
  • Loading branch information
ric-evans committed Oct 28, 2024
1 parent ac245d3 commit 900ae0b
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 33 deletions.
93 changes: 93 additions & 0 deletions skymap_scanner/server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""The Skymap Scanner Central Server."""

import dataclasses as dc

from wipac_dev_tools import from_environment_as_dataclass


#
# Env var constants: set as constants & typecast
#


@dc.dataclass(frozen=True)
class EnvConfig:
"""For storing environment variables, typed."""

#
# REQUIRED
#

SKYSCAN_SKYDRIVER_SCAN_ID: str # globally unique ID

# to-client queue
SKYSCAN_MQ_TOCLIENT: str
SKYSCAN_MQ_TOCLIENT_AUTH_TOKEN: str
SKYSCAN_MQ_TOCLIENT_BROKER_TYPE: str
SKYSCAN_MQ_TOCLIENT_BROKER_ADDRESS: str
#
# from-client queue
SKYSCAN_MQ_FROMCLIENT: str
SKYSCAN_MQ_FROMCLIENT_AUTH_TOKEN: str
SKYSCAN_MQ_FROMCLIENT_BROKER_TYPE: str
SKYSCAN_MQ_FROMCLIENT_BROKER_ADDRESS: str

#
# OPTIONAL
#

SKYSCAN_PROGRESS_INTERVAL_SEC: int = 1 * 60
SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_RATIO: float = (
# The size of the sample window (a percentage of the collected/finished recos)
# used to calculate the most recent runtime rate (sec/reco), then used to make
# predictions for overall runtimes: i.e. amount of time left.
# Also, see SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_MIN.
0.1
)
SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_MIN: int = (
# NOTE: val should not be (too) below the num of workers (which is unknown, so make a good guess).
# In other words, if val is too low, then the rate is not representative of the
# worker-pool's concurrency; if val is too high, then the window is too large.
# This is only useful for the first `val/SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_RATIO`
# num of recos, afterward the ratio is used.
100
)
SKYSCAN_RESULT_INTERVAL_SEC: int = 2 * 60

SKYSCAN_KILL_SWITCH_CHECK_INTERVAL: int = 5 * 60

# TIMEOUTS
#
# seconds -- how long server waits before thinking all clients are dead
# - set to duration of first reco + client launch (condor)
# - important if clients launch *AFTER* server
# - normal expiration scenario: all clients died (bad condor submit file), otherwise never (server knows when all recos are done)
SKYSCAN_MQ_TIMEOUT_FROM_CLIENTS: int = 3 * 24 * 60 * 60 # 3 days

# SKYDRIVER VARS
SKYSCAN_SKYDRIVER_ADDRESS: str = "" # SkyDriver REST interface address
SKYSCAN_SKYDRIVER_AUTH: str = "" # SkyDriver REST interface auth token

# LOGGING VARS
SKYSCAN_LOG: str = "INFO"
SKYSCAN_LOG_THIRD_PARTY: str = "WARNING"
SKYSCAN_EWMS_PILOT_LOG: str = "INFO"
SKYSCAN_MQ_CLIENT_LOG: str = "INFO"

# TESTING/DEBUG VARS
SKYSCAN_MINI_TEST: bool = False # run minimal variations for testing (mini-scale)
SKYSCAN_CRASH_DUMMY_PROBABILITY: float = 0.5 # for reco algo: crash-dummy

def __post_init__(self) -> None:
"""Check values."""
if self.SKYSCAN_PROGRESS_INTERVAL_SEC <= 0:
raise ValueError(
f"Env Var: SKYSCAN_PROGRESS_INTERVAL_SEC is not positive: {self.SKYSCAN_PROGRESS_INTERVAL_SEC}"
)
if self.SKYSCAN_RESULT_INTERVAL_SEC <= 0:
raise ValueError(
f"Env Var: SKYSCAN_RESULT_INTERVAL_SEC is not positive: {self.SKYSCAN_RESULT_INTERVAL_SEC}"
)


ENV = from_environment_as_dataclass(EnvConfig)
22 changes: 10 additions & 12 deletions skymap_scanner/server/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from rest_tools.client import RestClient
from skyreader import EventMetadata, SkyScanResult

from . import ENV
from .utils import NSideProgression, connect_to_skydriver, nonurgent_request
from .. import config as cfg
from ..utils import to_skyscan_result
Expand Down Expand Up @@ -185,10 +186,8 @@ def __init__(self) -> None:
def on_server_recent_sec_per_reco_rate(self) -> float:
"""The sec/reco rate from server pov within a moving window."""
window_size = max(
int(
self.total_ct * cfg.ENV.SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_RATIO
),
cfg.ENV.SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_MIN,
int(self.total_ct * ENV.SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_RATIO),
ENV.SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_MIN,
)

try:
Expand Down Expand Up @@ -333,7 +332,7 @@ def __init__(

self._n_sent_by_nside: Dict[int, int] = {}

if not cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS:
if not ENV.SKYSCAN_SKYDRIVER_ADDRESS:
self.skydriver_rc_nonurgent: Optional[RestClient] = None
self.skydriver_rc_urgent: Optional[RestClient] = None
else:
Expand Down Expand Up @@ -424,24 +423,23 @@ async def make_reports_if_needed(
# check if we need to send a report to the logger
current_time = time.time()
if bypass_timers or (
current_time - self.last_time_reported
> cfg.ENV.SKYSCAN_PROGRESS_INTERVAL_SEC
current_time - self.last_time_reported > ENV.SKYSCAN_PROGRESS_INTERVAL_SEC
):
self.last_time_reported = current_time
if self.worker_stats_collection.total_ct == 0:
epilogue_msg = "I will report back when I start getting recos."
else:
epilogue_msg = (
f"I will report back again in "
f"{cfg.ENV.SKYSCAN_PROGRESS_INTERVAL_SEC} seconds if I have an update."
f"{ENV.SKYSCAN_PROGRESS_INTERVAL_SEC} seconds if I have an update."
)
await self._send_progress(summary_msg, epilogue_msg)

# check if we need to send a report to the skymap logger
current_time = time.time()
if bypass_timers or (
current_time - self.last_time_reported_skymap
> cfg.ENV.SKYSCAN_RESULT_INTERVAL_SEC
> ENV.SKYSCAN_RESULT_INTERVAL_SEC
):
self.last_time_reported_skymap = current_time
await self._send_result()
Expand Down Expand Up @@ -658,7 +656,7 @@ async def _send_progress(
"last_updated": str(dt.datetime.fromtimestamp(int(time.time()))),
}
scan_metadata = {
"scan_id": cfg.ENV.SKYSCAN_SKYDRIVER_SCAN_ID,
"scan_id": ENV.SKYSCAN_SKYDRIVER_SCAN_ID,
"nside_progression": self.nside_progression,
"position_variations": self.n_posvar,
}
Expand All @@ -674,7 +672,7 @@ async def _send_progress(
# skydriver
sd_args = dict(
method="PATCH",
path=f"/scan/{cfg.ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/manifest",
path=f"/scan/{ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/manifest",
args=body,
)
if not self.is_event_scan_done and self.skydriver_rc_nonurgent:
Expand Down Expand Up @@ -702,7 +700,7 @@ async def _send_result(self) -> SkyScanResult:
body = {"skyscan_result": serialized, "is_final": self.is_event_scan_done}
sd_args = dict(
method="PUT",
path=f"/scan/{cfg.ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/result",
path=f"/scan/{ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/result",
args=body,
)
if not self.is_event_scan_done and self.skydriver_rc_nonurgent:
Expand Down
6 changes: 3 additions & 3 deletions skymap_scanner/server/start_scan.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""The Skymap Scanner Server."""

# pylint: disable=invalid-name,import-error
# fmt:quotes-ok

import argparse
import asyncio
Expand All @@ -25,6 +24,7 @@
from skyreader import EventMetadata
from wipac_dev_tools import argparse_tools, logging_tools

from . import ENV
from .collector import Collector, ExtraRecoPixelVariationException
from .pixels import choose_pixels_to_reconstruct
from .reporter import Reporter
Expand Down Expand Up @@ -668,13 +668,13 @@ def _mkdir_if_not_exists(val: str, is_file: bool = False) -> Path:
raise NotADirectoryError(args.gcd_dir)

# check output status
if not cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS and not args.output_dir:
if not ENV.SKYSCAN_SKYDRIVER_ADDRESS and not args.output_dir:
raise RuntimeError(
"Must include either --output-dir or SKYSCAN_SKYDRIVER_ADDRESS (env var), "
"otherwise you won't see your results!"
)
# read event file
if cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS:
if ENV.SKYSCAN_SKYDRIVER_ADDRESS:
event_contents = asyncio.run(fetch_event_contents_from_skydriver())
else:
event_contents = fetch_event_contents_from_file(args.event_file)
Expand Down
36 changes: 18 additions & 18 deletions skymap_scanner/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import mqclient as mq
from rest_tools.client import CalcRetryFromWaittimeMax, RestClient

from .. import config as cfg
from . import ENV

LOGGER = logging.getLogger(__name__)

Expand All @@ -25,18 +25,18 @@
def get_mqclient_connections() -> tuple[mq.Queue, mq.Queue]:
"""Establish connections to message queues."""
to_clients_queue = mq.Queue(
cfg.ENV.SKYSCAN_MQ_TOCLIENT_BROKER_TYPE,
address=cfg.ENV.SKYSCAN_MQ_TOCLIENT_BROKER_ADDRESS,
name=cfg.ENV.SKYSCAN_MQ_TOCLIENT,
auth_token=cfg.ENV.SKYSCAN_MQ_TOCLIENT_AUTH_TOKEN,
ENV.SKYSCAN_MQ_TOCLIENT_BROKER_TYPE,
address=ENV.SKYSCAN_MQ_TOCLIENT_BROKER_ADDRESS,
name=ENV.SKYSCAN_MQ_TOCLIENT,
auth_token=ENV.SKYSCAN_MQ_TOCLIENT_AUTH_TOKEN,
# timeout=-1, # NOTE: this mq only sends messages so no timeout needed
)
from_clients_queue = mq.Queue(
cfg.ENV.SKYSCAN_MQ_FROMCLIENT_BROKER_TYPE,
address=cfg.ENV.SKYSCAN_MQ_FROMCLIENT_BROKER_ADDRESS,
name=cfg.ENV.SKYSCAN_MQ_FROMCLIENT,
auth_token=cfg.ENV.SKYSCAN_MQ_FROMCLIENT_AUTH_TOKEN,
timeout=cfg.ENV.SKYSCAN_MQ_TIMEOUT_FROM_CLIENTS,
ENV.SKYSCAN_MQ_FROMCLIENT_BROKER_TYPE,
address=ENV.SKYSCAN_MQ_FROMCLIENT_BROKER_ADDRESS,
name=ENV.SKYSCAN_MQ_FROMCLIENT,
auth_token=ENV.SKYSCAN_MQ_FROMCLIENT_AUTH_TOKEN,
timeout=ENV.SKYSCAN_MQ_TIMEOUT_FROM_CLIENTS,
)

return to_clients_queue, from_clients_queue
Expand All @@ -49,16 +49,16 @@ def connect_to_skydriver(urgent: bool) -> RestClient:
"""Get REST client for SkyDriver depending on the urgency."""
if urgent:
return RestClient(
cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS,
token=cfg.ENV.SKYSCAN_SKYDRIVER_AUTH,
ENV.SKYSCAN_SKYDRIVER_ADDRESS,
token=ENV.SKYSCAN_SKYDRIVER_AUTH,
timeout=60.0,
retries=CalcRetryFromWaittimeMax(waittime_max=1 * 60 * 60),
# backoff_factor=0.3,
)
else:
return RestClient(
cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS,
token=cfg.ENV.SKYSCAN_SKYDRIVER_AUTH,
ENV.SKYSCAN_SKYDRIVER_ADDRESS,
token=ENV.SKYSCAN_SKYDRIVER_AUTH,
timeout=10.0,
retries=1,
# backoff_factor=0.3,
Expand All @@ -76,18 +76,18 @@ async def nonurgent_request(rc: RestClient, args: dict[str, Any]) -> Any:

async def kill_switch_check_from_skydriver() -> None:
"""Routinely check SkyDriver whether to continue the scan."""
if not cfg.ENV.SKYSCAN_SKYDRIVER_ADDRESS:
if not ENV.SKYSCAN_SKYDRIVER_ADDRESS:
return

logger = logging.getLogger("skyscan.kill_switch")

skydriver_rc = connect_to_skydriver(urgent=False)

while True:
await asyncio.sleep(cfg.ENV.SKYSCAN_KILL_SWITCH_CHECK_INTERVAL)
await asyncio.sleep(ENV.SKYSCAN_KILL_SWITCH_CHECK_INTERVAL)

status = await skydriver_rc.request(
"GET", f"/scan/{cfg.ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/status"
"GET", f"/scan/{ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/status"
)

if status["scan_state"].startswith("STOPPED__"):
Expand All @@ -105,7 +105,7 @@ async def fetch_event_contents_from_skydriver() -> Any:
skydriver_rc = connect_to_skydriver(urgent=True)

manifest = await skydriver_rc.request(
"GET", f"/scan/{cfg.ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/manifest"
"GET", f"/scan/{ENV.SKYSCAN_SKYDRIVER_SCAN_ID}/manifest"
)
LOGGER.info("Fetched event contents from SkyDriver")
return manifest["event_i3live_json_dict"]
Expand Down

0 comments on commit 900ae0b

Please sign in to comment.