fix dynamic flow. Add graph utils
gpetretto committed Nov 27, 2023
1 parent 4c81cde commit 79a05de
Showing 9 changed files with 291 additions and 47 deletions.
59 changes: 58 additions & 1 deletion src/jobflow_remote/cli/flow.py
@@ -1,4 +1,7 @@
from typing import Annotated, Optional

import typer
from jobflow.utils.graph import draw_graph
from rich.prompt import Confirm
from rich.text import Text

@@ -35,6 +38,7 @@
loading_spinner,
out_console,
)
from jobflow_remote.jobs.graph import get_graph

app_flow = JFRTyper(
name="flow", help="Commands for managing the flows", no_args_is_help=True
@@ -68,7 +72,7 @@ def flows_list(

start_date = get_start_date(start_date, days, hours)

- sort = [(sort.query_field, 1 if reverse_sort else -1)]
+ sort = [(sort.value, 1 if reverse_sort else -1)]

with loading_spinner():
flows_info = jc.get_flows_info(
@@ -185,3 +189,56 @@ def flow_info(
exit_with_error_msg("No data matching the request")

out_console.print(format_flow_info(flows_info[0]))


@app_flow.command()
def graph(
flow_db_id: flow_db_id_arg,
job_id_flag: job_flow_id_flag_opt = False,
label: Annotated[
Optional[str],
typer.Option(
"--label",
"-l",
help="The label used to identify the nodes",
),
] = "name",
file_path: Annotated[
Optional[str],
typer.Option(
"--path",
"-p",
help="If defined, the graph will be dumped to a file",
),
] = None,
):
"""
Show the graph of a Flow
"""
db_id, jf_id = get_job_db_ids(flow_db_id, None)
db_ids = job_ids = flow_ids = None
if db_id is not None:
db_ids = [db_id]
elif job_id_flag:
job_ids = [jf_id]
else:
flow_ids = [jf_id]

with loading_spinner():
jc = get_job_controller()

flows_info = jc.get_flows_info(
job_ids=job_ids,
db_ids=db_ids,
flow_ids=flow_ids,
limit=1,
full=True,
)
if not flows_info:
exit_with_error_msg("No data matching the request")

plt = draw_graph(get_graph(flows_info[0], label=label))
if file_path:
plt.savefig(file_path)
else:
plt.show()
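
For reference, the new command is roughly equivalent to this standalone sketch; the `get_job_controller` import path and the flow uuid are illustrative assumptions, not part of the commit:

```python
# Minimal sketch, assuming a configured jobflow-remote project.
from jobflow.utils.graph import draw_graph

from jobflow_remote.cli.utils import get_job_controller  # assumed import path
from jobflow_remote.jobs.graph import get_graph

jc = get_job_controller()
flows_info = jc.get_flows_info(flow_ids=["<flow uuid>"], limit=1, full=True)
if flows_info:
    plt = draw_graph(get_graph(flows_info[0], label="name"))
    plt.show()  # or plt.savefig("flow_graph.png") to dump to a file
```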
2 changes: 1 addition & 1 deletion src/jobflow_remote/cli/formatting.py
@@ -107,7 +107,7 @@ def get_flow_info_table(flows_info: list[FlowInfo], verbosity: int):
table.add_column("Job states")

for fi in flows_info:
- # show the smallest fw_id as db_id
+ # show the smallest Job db_id as db_id
db_id = min(fi.db_ids)

row = [
59 changes: 56 additions & 3 deletions src/jobflow_remote/cli/job.py
@@ -1,10 +1,12 @@
import io
from pathlib import Path
from typing import Optional

import typer
from monty.json import jsanitize
from monty.serialization import dumpfn
from qtoolkit.core.data_objects import QResources
from rich.pretty import pprint
from typing_extensions import Annotated

from jobflow_remote import SETTINGS
@@ -91,7 +93,7 @@ def jobs_list(

start_date = get_start_date(start_date, days, hours)

- sort = [(sort.query_field, 1 if reverse_sort else -1)]
+ sort = [(sort.value, 1 if reverse_sort else -1)]

with loading_spinner():
if custom_query:
@@ -725,8 +727,6 @@ def resources(

@app_job.command(name="dump", hidden=True)
def job_dump(
- job_db_id: job_db_id_arg = None,
- job_index: job_index_arg = None,
job_id: job_ids_indexes_opt = None,
db_id: db_ids_opt = None,
flow_id: flow_ids_opt = None,
@@ -776,3 +776,56 @@ def job_dump(

if not jobs_doc:
exit_with_error_msg("No data matching the request")


@app_job.command()
def output(
job_db_id: job_db_id_arg,
job_index: job_index_arg = None,
file_path: Annotated[
Optional[str],
typer.Option(
"--path",
"-p",
help="If defined, the output will be dumped to this file based on the extension (json or yaml)",
),
] = None,
load: Annotated[
bool,
typer.Option(
"--load",
"-",
help="If enabled all the data from additional stores are also loaded ",
),
] = False,
):
"""
Print the output of a specific Job
"""

db_id, job_id = get_job_db_ids(job_db_id, job_index)

with loading_spinner():
jc = get_job_controller()

if db_id:
job_info = jc.get_job_info(
job_id=job_id,
job_index=job_index,
db_id=db_id,
)
if job_info:
job_id = job_info.uuid
job_index = job_info.index

job_output = None
if job_id:
job_output = jc.jobstore.get_output(job_id, job_index or "last", load=load)

if not job_output:
exit_with_error_msg("No data matching the request")

if file_path:
dumpfn(job_output, file_path)
else:
pprint(job_output)
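
The final dump-or-print branch defers format selection to monty; a small illustration of that behavior (the output dict below is a stand-in, not a real Job output):

```python
from monty.serialization import dumpfn
from rich.pretty import pprint

job_output = {"energy": -1.23, "converged": True}  # stand-in value
file_path = "out.json"  # the extension picks the serializer: .json or .yaml

if file_path:
    dumpfn(job_output, file_path)  # writes JSON here; YAML for a .yaml path
else:
    pprint(job_output)  # pretty-prints to the console instead
```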
6 changes: 0 additions & 6 deletions src/jobflow_remote/cli/utils.py
@@ -71,12 +71,6 @@ class SortOption(Enum):
UPDATED_ON = "updated_on"
DB_ID = "db_id"

- @property
- def query_field(self) -> str:
-     if self == SortOption.DB_ID:
-         return "fw_id"
-     return self.value


class SerializeFileFormat(Enum):
JSON = "json"
56 changes: 47 additions & 9 deletions src/jobflow_remote/jobs/data.py
@@ -197,10 +197,10 @@ class FlowDoc(BaseModel):
# be parents of the job with index=i+1, but will not be parents of
# the job with index i.
# index is stored as string, since mongodb needs string keys
# This dictionary includes {job uuid: {job index: [parent uuids]}}
parents: dict[str, dict[str, list[str]]] = Field(default_factory=dict)
# ids correspond to db_id, uuid, index for each JobDoc
ids: list[tuple[int, str, int]] = Field(default_factory=list)
- # jobs_states: dict[str, FlowState]

def as_db_dict(self):
d = jsanitize(
@@ -270,23 +270,37 @@ class FlowInfo(BaseModel):
workers: list[str]
job_states: list[JobState]
job_names: list[str]
parents: list[list[str]]

@classmethod
def from_query_dict(cls, d):
updated_on = d["updated_on"]
flow_id = d["uuid"]

- db_ids, job_ids, job_indexes = list(zip(*d["ids"]))

jobs_data = d.get("jobs_list") or []

workers = []
job_states = []
job_names = []
- for job_doc in jobs_data:
-     job_names.append(job_doc["job"]["name"])
-     state = job_doc["state"]
-     job_states.append(JobState(state))
-     workers.append(job_doc["worker"])
parents = []

if jobs_data:
db_ids = []
job_ids = []
job_indexes = []
for job_doc in jobs_data:
db_ids.append(job_doc["db_id"])
job_ids.append(job_doc["uuid"])
job_indexes.append(job_doc["index"])
job_names.append(job_doc["job"]["name"])
state = job_doc["state"]
job_states.append(JobState(state))
workers.append(job_doc["worker"])
parents.append(job_doc["parents"] or [])
else:
db_ids, job_ids, job_indexes = list(zip(*d["ids"]))
# parents could be determined in this case as well from the Flow document.
# However, matching the correct order would require looping over them.
# To keep the generation fast, add this only if a use case shows up.

state = FlowState(d["state"])

@@ -301,8 +315,32 @@ def from_query_dict(cls, d):
workers=workers,
job_states=job_states,
job_names=job_names,
parents=parents,
)

@cached_property
def ids_mapping(self) -> dict[str, dict[int, int]]:
d: dict = defaultdict(dict)

for db_id, job_id, index in zip(self.db_ids, self.job_ids, self.job_indexes):
d[job_id][int(index)] = db_id

return dict(d)

def iter_job_prop(self):
n_jobs = len(self.job_ids)
for i in range(n_jobs):
d = {
"db_id": self.db_ids[i],
"uuid": self.job_ids[i],
"index": self.job_indexes[i],
}
if self.job_names:
d["name"] = self.job_names[i]
d["state"] = self.job_states[i]
d["parents"] = self.parents[i]
yield d


class DynamicResponseType(Enum):
REPLACE = "replace"
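To make the new FlowInfo helpers concrete, this is the shape of the data they produce for a hypothetical flow in which job "abc" was restarted (indexes 1 and 2) and job "def" ran once; all uuids and db_ids below are made up:

```python
# ids_mapping: {job uuid: {job index: db_id}}
ids_mapping = {
    "abc": {1: 101, 2: 104},
    "def": {1: 102},
}

# iter_job_prop yields one dict per job, e.g.:
# {"db_id": 104, "uuid": "abc", "index": 2, "name": "relax",
#  "state": JobState.COMPLETED, "parents": ["def"]}
```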
32 changes: 32 additions & 0 deletions src/jobflow_remote/jobs/graph.py
@@ -0,0 +1,32 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from networkx import DiGraph

from jobflow_remote.jobs.data import FlowInfo


def get_graph(flow: FlowInfo, label: str = "name") -> DiGraph:
import networkx as nx

graph = nx.DiGraph()

ids_mapping = flow.ids_mapping

# Add nodes
for job_prop in flow.iter_job_prop():
db_id = job_prop["db_id"]
job_prop["label"] = job_prop[label]
# rename this, since "name" is used in jobflow's graph plotting util
job_prop["job_name"] = job_prop.pop("name")
graph.add_node(db_id, **job_prop)

# Add edges based on parents
for child_node, parents in zip(flow.db_ids, flow.parents):
for parent_uuid in parents:
for parent_node in ids_mapping[parent_uuid].values():
graph.add_edge(parent_node, child_node)

return graph
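
Since get_graph returns a plain networkx DiGraph keyed by db_id, the standard graph utilities apply directly; a sketch, assuming flow_info is a FlowInfo fetched with full=True:

```python
import networkx as nx

from jobflow_remote.jobs.graph import get_graph

graph = get_graph(flow_info, label="name")  # flow_info: a FlowInfo instance

# Walk the jobs in dependency order; node attributes come from iter_job_prop.
for db_id in nx.topological_sort(graph):
    node = graph.nodes[db_id]
    print(db_id, node["label"], node["state"])

# Edges run from parent jobs to their children.
print(list(graph.edges))
```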
(diff truncated: the remaining 3 changed files are not shown)
