Merge pull request #3290 from shnizzedy/fix/gantt-chart
FIX: Restore generate_gantt_chart functionality
effigies authored Nov 18, 2024
2 parents 5dc8701 + 7223914 commit 2e36f69
Showing 3 changed files with 111 additions and 17 deletions.
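
The restored flow has two steps: log node events during the run with log_nodes_cb, then render the log with generate_gantt_chart. A minimal sketch adapted from the test added in this commit; the workflow, node, and file names are illustrative:

    import logging

    import nipype.interfaces.utility as niu
    import nipype.pipeline.engine as pe
    from nipype.utils.draw_gantt_chart import generate_gantt_chart
    from nipype.utils.profiler import log_nodes_cb

    def noop():
        pass

    # log_nodes_cb emits one JSON record per node start/finish event on the
    # "callback" logger, so route that logger to a file first.
    logger = logging.getLogger("callback")
    logger.setLevel(logging.DEBUG)
    logger.addHandler(logging.FileHandler("callback.log"))

    wf = pe.Workflow(name="demo", base_dir=".")
    wf.add_nodes([pe.Node(niu.Function(function=noop), name="noop")])
    wf.run(plugin="MultiProc", plugin_args={"status_callback": log_nodes_cb})

    # Renders the chart to callback.log.html beside the log; the second
    # argument is the core count used to lay out concurrent nodes.
    generate_gantt_chart("callback.log", 8)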
1 change: 1 addition & 0 deletions nipype/info.py
@@ -153,6 +153,7 @@ def get_nipype_gitversion():

 TESTS_REQUIRES = [
     "coverage >= 5.2.1",
+    "pandas >= 1.5.0",
     "pytest >= 6",
     "pytest-cov >=2.11",
     "pytest-env",
53 changes: 51 additions & 2 deletions nipype/pipeline/plugins/tests/test_callback.py
@@ -1,8 +1,9 @@
 # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
 # vi: set ft=python sts=4 ts=4 sw=4 et:
-"""Tests for workflow callbacks
-"""
+"""Tests for workflow callbacks."""
+from pathlib import Path
 from time import sleep
+import json
 import pytest
 import nipype.interfaces.utility as niu
 import nipype.pipeline.engine as pe
@@ -60,3 +61,51 @@ def test_callback_exception(tmpdir, plugin, stop_on_first_crash):

     sleep(0.5)  # Wait for callback to be called (python 2.7)
     assert so.statuses == [("f_node", "start"), ("f_node", "exception")]
+
+
+@pytest.mark.parametrize("plugin", ["Linear", "MultiProc", "LegacyMultiProc"])
+def test_callback_gantt(tmp_path: Path, plugin: str) -> None:
+    import logging
+
+    from os import path
+
+    from nipype.utils.profiler import log_nodes_cb
+    from nipype.utils.draw_gantt_chart import generate_gantt_chart
+
+    log_filename = tmp_path / "callback.log"
+    logger = logging.getLogger("callback")
+    logger.setLevel(logging.DEBUG)
+    handler = logging.FileHandler(log_filename)
+    logger.addHandler(handler)
+
+    # create workflow
+    wf = pe.Workflow(name="test", base_dir=str(tmp_path))
+    f_node = pe.Node(
+        niu.Function(function=func, input_names=[], output_names=[]), name="f_node"
+    )
+    wf.add_nodes([f_node])
+    wf.config["execution"] = {"crashdump_dir": wf.base_dir, "poll_sleep_duration": 2}
+
+    plugin_args = {"status_callback": log_nodes_cb}
+    if plugin != "Linear":
+        plugin_args["n_procs"] = 8
+    wf.run(plugin=plugin, plugin_args=plugin_args)
+
+    with open(log_filename, "r") as _f:
+        loglines = _f.readlines()
+
+    # test missing duration
+    first_line = json.loads(loglines[0])
+    if "duration" in first_line:
+        del first_line["duration"]
+    loglines[0] = f"{json.dumps(first_line)}\n"
+
+    # test duplicate timestamp warning
+    loglines.append(loglines[-1])
+
+    with open(log_filename, "w") as _f:
+        _f.write("".join(loglines))
+
+    with pytest.warns(Warning):
+        generate_gantt_chart(str(log_filename), 1 if plugin == "Linear" else 8)
+    assert (tmp_path / "callback.log.html").exists()
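
The new test drives generate_gantt_chart through both degraded inputs it must now tolerate: a record with no "duration" field and two records sharing a timestamp. For reference, each line of the callback log is a standalone JSON object along these lines (field values illustrative, wrapped here for readability):

    {"name": "f_node", "id": "test.f_node",
     "start": "2024-11-18T12:00:00.000001",
     "finish": "2024-11-18T12:00:01.000001",
     "duration": 1.0}

A missing "duration" is recomputed from "start" and "finish", and a duplicate timestamp now produces a warning rather than a KeyError, which is what the pytest.warns block asserts.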
74 changes: 59 additions & 15 deletions nipype/utils/draw_gantt_chart.py
@@ -8,8 +8,10 @@
 import random
 import datetime
 import simplejson as json
+from typing import Union

 from collections import OrderedDict
+from warnings import warn

 # Pandas
 try:
@@ -66,9 +68,9 @@ def create_event_dict(start_time, nodes_list):
         finish_delta = (node["finish"] - start_time).total_seconds()

         # Populate dictionary
-        if events.get(start_delta) or events.get(finish_delta):
+        if events.get(start_delta):
             err_msg = "Event logged twice or events started at exact same time!"
-            raise KeyError(err_msg)
+            warn(err_msg, category=Warning)
         events[start_delta] = start_node
         events[finish_delta] = finish_node

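In create_event_dict, events are keyed by their offset in seconds from the workflow start, so two events logged at the same instant collide on a key; that case now warns instead of raising. A minimal sketch of the collision (values hypothetical):

    import warnings

    events = {0.0: {"event": "start", "id": "node_a"}}
    start_delta = 0.0  # second event lands on an existing key
    if events.get(start_delta):
        warnings.warn("Event logged twice or events started at exact same time!")
    events[start_delta] = {"event": "start", "id": "node_b"}  # overwrites node_a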
@@ -101,15 +103,25 @@ def log_to_dict(logfile):

     nodes_list = [json.loads(l) for l in lines]

-    def _convert_string_to_datetime(datestring):
-        try:
+    def _convert_string_to_datetime(
+        datestring: Union[str, datetime.datetime],
+    ) -> datetime.datetime:
+        """Convert a date string to a datetime object."""
+        if isinstance(datestring, datetime.datetime):
+            datetime_object = datestring
+        elif isinstance(datestring, str):
+            date_format = (
+                "%Y-%m-%dT%H:%M:%S.%f%z"
+                if "+" in datestring
+                else "%Y-%m-%dT%H:%M:%S.%f"
+            )
             datetime_object: datetime.datetime = datetime.datetime.strptime(
-                datestring, "%Y-%m-%dT%H:%M:%S.%f"
+                datestring, date_format
             )
-            return datetime_object
-        except Exception as _:
-            pass
-        return datestring
+        else:
+            msg = f"{datestring} is not a string or datetime object."
+            raise TypeError(msg)
+        return datetime_object

     date_object_node_list: list = list()
     for n in nodes_list:
@@ -154,12 +166,18 @@ def calculate_resource_timeseries(events, resource):
     # Iterate through the events
     for _, event in sorted(events.items()):
         if event["event"] == "start":
-            if resource in event and event[resource] != "Unknown":
-                all_res += float(event[resource])
+            if resource in event:
+                try:
+                    all_res += float(event[resource])
+                except ValueError:
+                    continue
             current_time = event["start"]
         elif event["event"] == "finish":
-            if resource in event and event[resource] != "Unknown":
-                all_res -= float(event[resource])
+            if resource in event:
+                try:
+                    all_res -= float(event[resource])
+                except ValueError:
+                    continue
             current_time = event["finish"]
         res[current_time] = all_res

@@ -284,7 +302,14 @@ def draw_nodes(start, nodes_list, cores, minute_scale, space_between_minutes, co
         # Left
         left = 60
         for core in range(len(end_times)):
-            if end_times[core] < node_start:
+            try:
+                end_time_condition = end_times[core] < node_start
+            except TypeError:
+                # if one has a timezone and one does not
+                end_time_condition = end_times[core].replace(
+                    tzinfo=None
+                ) < node_start.replace(tzinfo=None)
+            if end_time_condition:
                 left += core * 30
                 end_times[core] = datetime.datetime(
                     node_finish.year,
@@ -307,7 +332,7 @@
             "offset": offset,
             "scale_duration": scale_duration,
             "color": color,
-            "node_name": node["name"],
+            "node_name": node.get("name", node.get("id", "")),
             "node_dur": node["duration"] / 60.0,
             "node_start": node_start.strftime("%Y-%m-%d %H:%M:%S"),
             "node_finish": node_finish.strftime("%Y-%m-%d %H:%M:%S"),
@@ -527,6 +552,25 @@ def generate_gantt_chart(
     # Read in json-log to get list of node dicts
     nodes_list = log_to_dict(logfile)

+    # Only include nodes with timing information, and convert timestamps
+    # from strings to datetimes
+    nodes_list = [
+        {
+            k: (
+                datetime.datetime.strptime(i[k], "%Y-%m-%dT%H:%M:%S.%f")
+                if k in {"start", "finish"} and isinstance(i[k], str)
+                else i[k]
+            )
+            for k in i
+        }
+        for i in nodes_list
+        if "start" in i and "finish" in i
+    ]
+
+    for node in nodes_list:
+        if "duration" not in node:
+            node["duration"] = (node["finish"] - node["start"]).total_seconds()
+
     # Create the header of the report with useful information
     start_node = nodes_list[0]
     last_node = nodes_list[-1]
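
One detail worth calling out: the converter now accepts offset-aware timestamps, and draw_nodes strips tzinfo when a naive and an aware datetime meet, because Python refuses to compare the two. A standalone illustration with made-up timestamps:

    import datetime

    naive = datetime.datetime.strptime(
        "2024-11-18T12:00:00.000000", "%Y-%m-%dT%H:%M:%S.%f"
    )
    aware = datetime.datetime.strptime(
        "2024-11-18T12:00:00.000000+0000", "%Y-%m-%dT%H:%M:%S.%f%z"
    )
    # naive < aware raises TypeError; stripping tzinfo makes them comparable,
    # which is the fallback draw_nodes now uses.
    print(naive < aware.replace(tzinfo=None))  # False (equal instants)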
