diff --git a/src/jobflow_remote/cli/flow.py b/src/jobflow_remote/cli/flow.py index 088331e0..b5eac7f8 100644 --- a/src/jobflow_remote/cli/flow.py +++ b/src/jobflow_remote/cli/flow.py @@ -38,7 +38,7 @@ loading_spinner, out_console, ) -from jobflow_remote.jobs.graph import get_graph +from jobflow_remote.jobs.graph import get_graph, plot_dash app_flow = JFRTyper( name="flow", help="Commands for managing the flows", no_args_is_help=True @@ -211,6 +211,14 @@ def graph( help="If defined, the graph will be dumped to a file", ), ] = None, + dash_plot: Annotated[ + bool, + typer.Option( + "--dash", + "-d", + help="Show the graph in a dash app", + ), + ] = False, ): """ Provide detailed information on a Flow @@ -237,8 +245,11 @@ def graph( if not flows_info: exit_with_error_msg("No data matching the request") - plt = draw_graph(get_graph(flows_info[0], label=label)) - if file_path: - plt.savefig(file_path) + if dash_plot: + plot_dash(flows_info[0]) else: - plt.show() + plt = draw_graph(get_graph(flows_info[0], label=label)) + if file_path: + plt.savefig(file_path) + else: + plt.show() diff --git a/src/jobflow_remote/jobs/data.py b/src/jobflow_remote/jobs/data.py index 992da360..afcec745 100644 --- a/src/jobflow_remote/jobs/data.py +++ b/src/jobflow_remote/jobs/data.py @@ -271,6 +271,7 @@ class FlowInfo(BaseModel): job_states: list[JobState] job_names: list[str] parents: list[list[str]] + hosts: list[list[str]] @classmethod def from_query_dict(cls, d): @@ -282,6 +283,7 @@ def from_query_dict(cls, d): job_states = [] job_names = [] parents = [] + job_hosts = [] if jobs_data: db_ids = [] @@ -296,6 +298,7 @@ def from_query_dict(cls, d): job_states.append(JobState(state)) workers.append(job_doc["worker"]) parents.append(job_doc["parents"] or []) + job_hosts.append(job_doc["job"]["hosts"] or []) else: db_ids, job_ids, job_indexes = list(zip(*d["ids"])) # parents could be determined in this case as well from the Flow document. @@ -316,6 +319,7 @@ def from_query_dict(cls, d): job_states=job_states, job_names=job_names, parents=parents, + hosts=job_hosts, ) @cached_property @@ -339,6 +343,7 @@ def iter_job_prop(self): d["name"] = self.job_names[i] d["state"] = self.job_states[i] d["parents"] = self.parents[i] + d["hosts"] = self.hosts[i] yield d diff --git a/src/jobflow_remote/jobs/graph.py b/src/jobflow_remote/jobs/graph.py index 7adceb6f..94622c5a 100644 --- a/src/jobflow_remote/jobs/graph.py +++ b/src/jobflow_remote/jobs/graph.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING +from jobflow_remote.jobs.state import JobState + if TYPE_CHECKING: from networkx import DiGraph @@ -30,3 +32,145 @@ def get_graph(flow: FlowInfo, label: str = "name") -> DiGraph: graph.add_edge(parent_node, child_node) return graph + + +def get_graph_elements(flow: FlowInfo): + ids_mapping = flow.ids_mapping + + nodes = {} + for job_prop in flow.iter_job_prop(): + db_id = job_prop["db_id"] + nodes[db_id] = job_prop + + # edges based on parents + edges = [] + for child_node, parents in zip(flow.db_ids, flow.parents): + for parent_uuid in parents: + for parent_node in ids_mapping[parent_uuid].values(): + edges.append((parent_node, child_node)) + + # group of nodes based on hosts + # from collections import defaultdict + # groups = defaultdict(list) + hosts = {} + # for job_prop in flow.iter_job_prop(): + # for host in job_prop["hosts"]: + # groups[host].append(job_prop["db_id"]) + for job_prop in flow.iter_job_prop(): + hosts[job_prop["db_id"]] = job_prop["hosts"] + + return nodes, edges, hosts + + +def plot_dash(flow: FlowInfo): + nodes, edges, hosts = get_graph_elements(flow) + + import dash_cytoscape as cyto + from dash import Dash, Input, Output, callback, html + + app = Dash(f"{flow.name} - {flow.flow_id}") + + elements = [] + + # parent elements + hosts_hierarchy = {} + jobs_inner_hosts = {} + hosts_set = set() + for db_id, job_hosts in hosts.items(): + job_hosts = list(reversed(job_hosts)) + if len(job_hosts) < 2: + continue + for i, host in enumerate(job_hosts[1:-1], 1): + hosts_hierarchy[job_hosts[i + 1]] = host + + hosts_set.update(job_hosts[1:]) + jobs_inner_hosts[db_id] = job_hosts[-1] + + for host in hosts_set: + elements.append({"data": {"id": host, "parent": hosts_hierarchy.get(host)}}) + + for db_id, node_info in nodes.items(): + node_info["id"] = str(db_id) + node_info["label"] = node_info["name"] + node_info["parent"] = jobs_inner_hosts.get(db_id) + elements.append( + { + "data": node_info, + } + ) + + for edge in edges: + elements.append({"data": {"source": str(edge[0]), "target": str(edge[1])}}) + + stylesheet: list[dict] = [ + { + "selector": f'[state = "{state}"]', + "style": { + "background-color": color, + }, + } + for state, color in COLOR_MAPPING.items() + ] + stylesheet.append( + { + "selector": "node", + "style": { + "label": "data(name)", + }, + } + ) + stylesheet.append( + { + "selector": "node:parent", + "style": { + "background-opacity": 0.2, + "background-color": "#2B65EC", + "border-color": "#2B65EC", + }, + } + ) + + app.layout = html.Div( + [ + cyto.Cytoscape( + id="flow-graph", + layout={"name": "breadthfirst", "directed": True}, + # layout={'name': 'cose'}, + style={"width": "100%", "height": "500px"}, + elements=elements, + stylesheet=stylesheet, + ), + html.P(id="job-info-output"), + ] + ) + + @callback( + Output("job-info-output", "children"), Input("flow-graph", "mouseoverNodeData") + ) + def displayTapNodeData(data): + if data: + return str(data) + + app.run(debug=True) + + +BLUE_COLOR = "#5E6BFF" +RED_COLOR = "#fC3737" +COLOR_MAPPING = { + JobState.WAITING.value: "grey", + JobState.READY.value: "#DAF7A6", + JobState.CHECKED_OUT.value: BLUE_COLOR, + JobState.UPLOADED.value: BLUE_COLOR, + JobState.SUBMITTED.value: BLUE_COLOR, + JobState.RUNNING.value: BLUE_COLOR, + JobState.TERMINATED.value: BLUE_COLOR, + JobState.DOWNLOADED.value: BLUE_COLOR, + JobState.REMOTE_ERROR.value: RED_COLOR, + JobState.COMPLETED.value: "#47bf00", + JobState.FAILED.value: RED_COLOR, + JobState.PAUSED.value: "#EAE200", + JobState.STOPPED.value: RED_COLOR, + JobState.CANCELLED.value: RED_COLOR, + JobState.BATCH_SUBMITTED.value: BLUE_COLOR, + JobState.BATCH_RUNNING.value: BLUE_COLOR, +} diff --git a/src/jobflow_remote/jobs/jobcontroller.py b/src/jobflow_remote/jobs/jobcontroller.py index cd3df6d0..25e24009 100644 --- a/src/jobflow_remote/jobs/jobcontroller.py +++ b/src/jobflow_remote/jobs/jobcontroller.py @@ -1079,6 +1079,7 @@ def get_flows_info( # TODO reduce the projection to the bare minimum to reduce the amount of # fecthed data? projection = {f"jobs_list.{f}": 1 for f in projection_job_info} + projection["jobs_list.job.hosts"] = 1 for k in FlowDoc.model_fields.keys(): projection[k] = 1 @@ -1399,20 +1400,17 @@ def _append_flow( new_flow = get_flow(new_flow, allow_external_references=True) - # get job parents and set the previous hosts + # get job parents if response_type == DynamicResponseType.REPLACE: job_parents = job_doc.parents else: job_parents = [(job_doc.uuid, job_doc.index)] - if job_doc.job.hosts: - new_flow.add_hosts_uuids(job_doc.job.hosts) - - flow_updates: dict[str, dict[str, Any]] = {} # add new jobs to flow flow_dict = dict(flow_dict) - # flow_dict["jobs"].extend(new_flow.job_uuids) - flow_updates = {"$push": {"jobs": {"$each": new_flow.job_uuids}}} + flow_updates: dict[str, dict[str, Any]] = { + "$push": {"jobs": {"$each": new_flow.job_uuids}} + } # add new jobs jobs_list = list(new_flow.iterflow()) diff --git a/src/jobflow_remote/jobs/state.py b/src/jobflow_remote/jobs/state.py index 0e6c72ff..63f8ff45 100644 --- a/src/jobflow_remote/jobs/state.py +++ b/src/jobflow_remote/jobs/state.py @@ -41,6 +41,8 @@ def short_value(self) -> str: JobState.PAUSED: "P", JobState.STOPPED: "ST", JobState.CANCELLED: "CA", + JobState.BATCH_SUBMITTED: "BS", + JobState.BATCH_RUNNING: "BR", }