Skip to content

Commit

Permalink
was just missing data
Browse files Browse the repository at this point in the history
  • Loading branch information
cnheider committed May 28, 2024
1 parent 1fe78ea commit 49ba3b0
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 71 deletions.
163 changes: 95 additions & 68 deletions heimdallr/entry_points/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import json
import logging
import socket
import time
from typing import Any

import dash
Expand All @@ -20,7 +21,7 @@
from dash.dash_table import DataTable
from dash.dependencies import Input, Output
from dash.html import Div
from draugr.writers import LogWriter, MockWriter, Writer
from draugr.writers import LogWriter
from flask import Response
from paho import mqtt
from paho.mqtt.client import Client, MQTTv5
Expand All @@ -37,15 +38,13 @@
)
from heimdallr.server.board_layout import get_root_layout
from heimdallr.utilities.server import (
get_calender_df,
per_machine_per_device_pie_charts,
to_overall_gpu_process_df,
)

__all__ = ["main"]

from heimdallr.utilities.server.du_utilities import to_overall_du_process_df
from heimdallr.utilities.server.teams_status import team_members_status
import dash_bootstrap_components

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,7 +87,7 @@
MQTT_CLIENT = Client(
client_id=CLIENT_ID,
protocol=MQTTv5,
callback_api_version=CallbackAPIVersion.VERSION2
callback_api_version=CallbackAPIVersion.VERSION2,
# clean_session=True
)
MQTT_CLIENT.tls_set(tls_version=paho.mqtt.client.ssl.PROTOCOL_TLS)
Expand All @@ -100,6 +99,8 @@

DEVELOPMENT = False
DASH_APP.layout = get_root_layout(DEVELOPMENT)


# LOG_WRITER: Writer = MockWriter()


Expand All @@ -111,6 +112,7 @@ def update_time(n: int) -> str:
"""description"""
global GPU_STATS
global KEEP_ALIVE

try:
for k, v in KEEP_ALIVE.items():
if v > ALL_CONSTANTS.TIMEOUT_MACHINES_SEC:
Expand All @@ -131,45 +133,45 @@ def update_time(n: int) -> str:
return default_datetime_repr(datetime.datetime.now())


@DASH_APP.callback(
    Output(ALL_CONSTANTS.CALENDAR_ID, "children"),
    [Input(ALL_CONSTANTS.CALENDAR_INTERVAL_ID, "n_intervals")],
)
def update_calendar_live(n: int) -> DataTable:
    """Rebuild the calendar table each time the calendar interval fires.

    :param n: Tick counter supplied by the interval component; not used.
    :return: A ``DataTable`` built from the calendar data frame. If anything
        raises while fetching or rendering, a ``Div`` holding the error text
        is returned instead so the failure is visible on the page.
    """
    # Cells in the "start" column matching the filter query are drawn
    # white-on-green.
    # TODO: MAKE GRADIENT TO ORANGE FOR WHEN NEARING START, and GREEN WHEN IN PROGRESS
    highlight_rule = {
        "if": {"column_id": "start", "filter_query": "{start} > 3.9"},
        "backgroundColor": "green",
        "color": "white",
    }

    try:
        calendar_df = get_calender_df(
            HeimdallrSettings().google_calendar_id,
            HeimdallrSettings()._credentials_base_path,
            num_entries=ALL_CONSTANTS.TABLE_PAGE_SIZE,
        )

        column_specs = [
            {"name": column, "id": column} for column in calendar_df.columns
        ]
        records = calendar_df.to_dict("records")

        return DataTable(
            id="calender-table-0",
            columns=column_specs,
            data=records,
            page_size=ALL_CONSTANTS.TABLE_PAGE_SIZE,
            style_as_list_view=True,
            style_data_conditional=[highlight_rule],
        )
    except Exception as e:
        # Deliberate catch-all: show the problem in the dashboard rather than
        # breaking the callback.
        return Div([f"Error: {e}"])


@DASH_APP.callback(
    Output(ALL_CONSTANTS.TEAMS_STATUS_ID, "children"),
    [Input(ALL_CONSTANTS.TEAMS_STATUS_INTERVAL_ID, "n_intervals")],
)
def update_teams_status_live(n: int) -> Div:
    """Refresh the teams-status container when its interval fires.

    :param n: Tick counter supplied by the interval component; not used.
    :return: A ``Div`` with class ``row`` wrapping whatever
        ``team_members_status(None)`` produces.
    """
    status_children = team_members_status(None)
    return Div(status_children, className="row")
# @DASH_APP.callback(
# Output(ALL_CONSTANTS.CALENDAR_ID, "children"),
# [Input(ALL_CONSTANTS.CALENDAR_INTERVAL_ID, "n_intervals")],
# )
# def update_calendar_live(n: int) -> DataTable:
# """description"""
# try:
# df = get_calender_df(
# HeimdallrSettings().google_calendar_id,
# HeimdallrSettings()._credentials_base_path,
# num_entries=ALL_CONSTANTS.TABLE_PAGE_SIZE,
# )
#
# return DataTable(
# id="calender-table-0",
# columns=[{"name": i, "id": i} for i in df.columns],
# data=df.to_dict("records"),
# page_size=ALL_CONSTANTS.TABLE_PAGE_SIZE,
# style_as_list_view=True,
# style_data_conditional=[
# {
# "if": {"column_id": "start", "filter_query": "{start} > 3.9"},
# "backgroundColor": "green",
# "color": "white",
# },
# ], # TODO: MAKE GRADIENT TO ORANGE FOR WHEN NEARING START, and GREEN WHEN IN PROGRESS
# )
# except Exception as e:
# return Div([f"Error: {e}"])


# @DASH_APP.callback(
# Output(ALL_CONSTANTS.TEAMS_STATUS_ID, "children"),
# [Input(ALL_CONSTANTS.TEAMS_STATUS_INTERVAL_ID, "n_intervals")],
# )
# def update_teams_status_live(n: int) -> Div:
# """description"""
#
# return Div(team_members_status(None), className="row")


@DASH_APP.callback(
Expand All @@ -178,13 +180,20 @@ def update_teams_status_live(n: int) -> Div:
)
def update_graph(n: int) -> Div:
    """Build the per-machine, per-device GPU charts for the graph container.

    :param n: Tick counter supplied by the interval component; not used.
    :return: A ``Div`` holding one entry per chart, or an empty ``Div`` when
        no GPU statistics have been received yet.
    """
    # The module-level GPU_STATS / KEEP_ALIVE / MQTT_CLIENT are only read
    # here, never rebound, so no ``global`` declarations are needed.

    # Let the MQTT client process pending network traffic before rendering.
    MQTT_CLIENT.loop()

    compute_machines = []

    # Explicit length check (rather than bare truthiness) as introduced by the
    # commit; presumably GPU_STATS is a custom mapping — TODO confirm whether
    # its truthiness is reliable.
    if len(GPU_STATS) > 0:
        compute_machines.extend(
            per_machine_per_device_pie_charts(
                GPU_STATS.as_dict(), KEEP_ALIVE.as_dict()
            )
        )

    return Div(compute_machines)


Expand All @@ -194,18 +203,20 @@ def update_graph(n: int) -> Div:
)
def update_table(n: int) -> Div:
"""description"""

global GPU_STATS
global MQTT_CLIENT

MQTT_CLIENT.loop()

compute_machines = []

if GPU_STATS:
if len(GPU_STATS) > 0:
df = to_overall_gpu_process_df(GPU_STATS.as_dict())

else:
df = DataFrame(["No data"], columns=("data",))

logger.warning(df)

compute_machines.append(
DataTable(
id="gpu-table-0",
Expand Down Expand Up @@ -233,17 +244,18 @@ def update_table(n: int) -> Div:
)
def update_table(n: int) -> Div:
"""description"""
global GPU_STATS
global MQTT_CLIENT

MQTT_CLIENT.loop()

compute_machines = []

if DU_STATS:
if len(DU_STATS) > 0:
df = to_overall_du_process_df(DU_STATS.as_dict())
else:
df = DataFrame(["No data"], columns=("data",))

logger.warning(df)

compute_machines.append(
DataTable(
id="du-table-0",
Expand Down Expand Up @@ -296,8 +308,10 @@ def on_post_config() -> Response:
for k, v in flask.request.form.items():
if v != "":
setattr(settings, k, v)

else:
print("Not in development mode")

return flask.redirect("/")


Expand All @@ -306,35 +320,48 @@ def on_message(client: Any, userdata: Any, result: mqtt.client.MQTTMessage) -> N
global GPU_STATS
global KEEP_ALIVE

d = json.loads(result.payload)
mapping_message = json.loads(result.payload)

keys = d.keys()

logger.warning("received message")
keys = mapping_message.keys()

for key in keys:

if "gpu_stats" in d[key]:
GPU_STATS[key] = d[key]["gpu_stats"]
DU_STATS[key] = d[key]["du_stats"]
if "gpu_stats" in mapping_message[key]:
GPU_STATS[key] = mapping_message[key]["gpu_stats"]
DU_STATS[key] = mapping_message[key]["du_stats"]

else:
GPU_STATS[key] = d[key] # ["gpu_stats"]
GPU_STATS[key] = mapping_message[key] # ["gpu_stats"]
DU_STATS[key] = {}

KEEP_ALIVE[key] = 0

logger.warning(
logger.info(
f"received payload for {keys}, retain:{result.retain}, timestamp:{result.timestamp}"
)


def on_disconnect(client: Any, userdata: Any, rc: Any, *_) -> None:
    """Reconnect to the broker after a disconnect and re-subscribe.

    Retries ``client.reconnect()`` once per second until it reports success
    (a falsy / zero return code), then subscribes to the configured topic
    again.

    :param client: The MQTT client instance that lost its connection.
    :param userdata: User data attached to the client; not used.
    :param rc: Third positional callback argument; not inspected.
        NOTE(review): the client in this file is created with
        ``CallbackAPIVersion.VERSION2``, where this position presumably
        carries the disconnect flags rather than a return code — confirm
        before adding any ``rc``-based logic.
    :param _: Any further positional arguments passed by the callback API.
    """
    # Announce the retry loop up front; previously this was printed only
    # after a reconnect had already succeeded.
    print("Unexpected MQTT disconnection. Will auto-reconnect")

    while True:
        try:
            # reconnect() returns 0 on success, so a falsy result means the
            # client is connected again.
            if not client.reconnect():
                break
        except OSError:
            # Covers ConnectionRefusedError (broker not running) as well as
            # timeouts, DNS failures and unreachable networks; all of these
            # are transient, so keep retrying instead of letting the
            # exception escape the callback and end the retry loop.
            pass

        # The attempt failed; wait one second before trying again.
        # NOTE(review): this blocks the thread that runs the callback.
        time.sleep(1)

    client.subscribe(ALL_CONSTANTS.MQTT_TOPIC, ALL_CONSTANTS.MQTT_QOS)


def setup_mqtt_connection(settings) -> None:
Expand Down Expand Up @@ -368,14 +395,15 @@ def setup_mqtt_connection(settings) -> None:
logger.error(f"MQTT connection error: {e}")
# raise e


def on_unsubscribe(client, userdata, mid, reason_code_list, properties):
    """Print the outcome of an unsubscribe request.

    :param client: The MQTT client instance; not used.
    :param userdata: User data attached to the client; not used.
    :param mid: Message id of the unsubscribe request; not used.
    :param reason_code_list: Reason codes returned by the broker. Only the
        first entry is inspected. The list is only populated for MQTTv5; in
        MQTTv3 it is always empty, which is treated as success.
    :param properties: MQTTv5 properties of the acknowledgement; not used.
    """
    broker_refused = (
        len(reason_code_list) != 0 and reason_code_list[0].is_failure
    )

    if broker_refused:
        print(f"Broker replied with failure: {reason_code_list[0]}")
        # client.disconnect()
        return

    print("unsubscribe succeeded (if SUBACK is received in MQTTv3 it success)")


def on_subscribe(client, userdata, mid, reason_code_list, properties):
Expand All @@ -397,7 +425,6 @@ def on_connect(client, userdata, flags, reason_code, properties):
MQTT_CLIENT.subscribe(ALL_CONSTANTS.MQTT_TOPIC, qos=ALL_CONSTANTS.MQTT_QOS)



def main(
*args,
setting_scope: SettingScopeEnum = SettingScopeEnum.user,
Expand Down Expand Up @@ -426,11 +453,11 @@ def main(
MQTT_CLIENT.on_subscribe = on_subscribe
MQTT_CLIENT.on_unsubscribe = on_unsubscribe
MQTT_CLIENT.on_connect = on_connect
MQTT_CLIENT.on_disconnect = on_disconnect

crystallised_heimdallr_settings = HeimdallrSettings(setting_scope)
setup_mqtt_connection(settings=crystallised_heimdallr_settings)


DASH_APP.title = ALL_CONSTANTS.HTML_TITLE
DASH_APP.update_title = ALL_CONSTANTS.HTML_TITLE
host = ALL_CONSTANTS.SERVER_ADDRESS
Expand Down
7 changes: 5 additions & 2 deletions heimdallr/utilities/server/du_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@

import numpy
import pandas
from pandas import DataFrame

from heimdallr.configuration.heimdallr_config import (
DROP_COLUMNS,
INT_COLUMNS,
MB_COLUMNS,
PERCENT_COLUMNS,
)
from heimdallr.utilities.date_tools import timestamp_to_datetime
from pandas import DataFrame

MB_DIVISOR = int(1024**2)

Expand All @@ -39,9 +40,10 @@ def to_overall_du_process_df(gpu_stats: Mapping) -> DataFrame:
for part_i in v2["partitions"]:
df = pandas.DataFrame(data=part_i)
resulta.append(df)

if len(resulta):
out_df = pandas.concat(resulta, sort=False)
out_df.sort_values(by="used_gpu_mem", axis=0, ascending=False, inplace=True)
out_df.sort_values(by="used", axis=0, ascending=False, inplace=True)
if len(out_df) == 0:
return pandas.DataFrame()

Expand All @@ -61,4 +63,5 @@ def to_overall_du_process_df(gpu_stats: Mapping) -> DataFrame:
out_df = out_df[out_cols]

return out_df

return pandas.DataFrame(data=["no data"], columns=["no data"])
13 changes: 12 additions & 1 deletion heimdallr/utilities/server/gpu_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,31 @@ def to_overall_gpu_process_df(
"""
resulta = []
columns = []

logger.warning(gpu_stats)

if len(gpu_stats):

for k2, v2 in gpu_stats.items():
device_info = v2["devices"]

for device_i in device_info:

processes = device_i["processes"]

if len(processes) > 0:
columns = list(processes[0].keys())

df = pandas.DataFrame(data=processes)
df["machine"] = [k2] * len(processes)

resulta.append(df)

out_df = pandas.concat(resulta, sort=False)

if sort_by_key in out_df.columns:
logger.warning(out_df)

if sort_by_key:
out_df.sort_values(by=sort_by_key, axis=0, ascending=False, inplace=True)
else:
logger.warning(f"{sort_by_key} was not found in {out_df.columns}")
Expand Down

0 comments on commit 49ba3b0

Please sign in to comment.