Skip to content

Commit

Permalink
2375 convergence plots (#2399)
Browse files Browse the repository at this point in the history
  • Loading branch information
samtygier-stfc authored Nov 21, 2024
2 parents 700e3e8 + 1b8b2e9 commit d056e76
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#2375: Add Convergence Plot in AsyncTaskDialog Displayed During Reconstruction

13 changes: 9 additions & 4 deletions mantidimaging/core/reconstruct/cil_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,14 @@ def __init__(self, verbose=1, progress: Progress | None = None) -> None:

def __call__(self, algo: Algorithm) -> None:
if self.progress:
self.progress.update(steps=1,
msg=f'CIL: Iteration {self.iteration_count } of {algo.max_iteration}'
f': Objective {algo.get_last_objective():.2f}',
force_continue=False)
extra_info = {'iterations': algo.iterations, 'losses': algo.loss}
self.progress.update(
steps=1,
msg=f'CIL: Iteration {self.iteration_count } of {algo.max_iteration}'
f': Objective {algo.get_last_objective():.2f}',
force_continue=False,
extra_info=extra_info,
)
self.iteration_count += 1


Expand Down Expand Up @@ -407,6 +411,7 @@ def full(images: ImageStack,
LOG.info(f'Reconstructed 3D volume with shape: {volume.shape}')
t1 = time.perf_counter()
LOG.info(f"full reconstruction time: {t1-t0}s for shape {images.data.shape}")
ImageStack(volume).metadata['convergence'] = {'iterations': algo.iterations, 'losses': algo.loss}
return ImageStack(volume)


Expand Down
11 changes: 8 additions & 3 deletions mantidimaging/core/utility/progress_reporting/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from mantidimaging.core.utility.memory_usage import get_memory_usage_linux_str

ProgressHistory = NamedTuple('ProgressHistory', [('time', float), ('step', int), ('msg', str)])
ProgressHistory = NamedTuple('ProgressHistory', [('time', float), ('step', int), ('msg', str),
('extra_info', dict | None)])


class ProgressHandler:
Expand Down Expand Up @@ -167,7 +168,11 @@ def _format_time(t: SupportsInt) -> str:
t = int(t)
return f'{t // 3600:02}:{t % 3600 // 60:02}:{t % 60:02}'

def update(self, steps: int = 1, msg: str = "", force_continue: bool = False) -> None:
def update(self,
steps: int = 1,
msg: str = "",
force_continue: bool = False,
extra_info: dict | None = None) -> None:
"""
Updates the progress of the task.
Expand All @@ -188,7 +193,7 @@ def update(self, steps: int = 1, msg: str = "", force_continue: bool = False) ->

msg = f"{f'{msg}' if len(msg) > 0 else ''} | {self.current_step}/{self.end_step} | " \
f"Time: {self._format_time(self.execution_time())}, ETA: {self._format_time(eta)}"
step_details = ProgressHistory(time.perf_counter(), self.current_step, msg)
step_details = ProgressHistory(time.perf_counter(), self.current_step, msg, extra_info)
self.progress_history.append(step_details)

# process progress callbacks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,20 +257,20 @@ def test_format_time(self):
def test_calculate_mean_time(self):
progress_history = []

progress_history.append(ProgressHistory(100, 0, ""))
progress_history.append(ProgressHistory(100, 0, "", None))
self.assertEqual(Progress.calculate_mean_time(progress_history), 0)

# first step 5 seconds
progress_history.append(ProgressHistory(105, 1, ""))
progress_history.append(ProgressHistory(105, 1, "", None))
self.assertEqual(Progress.calculate_mean_time(progress_history), 5)

# second step 10 seconds
progress_history.append(ProgressHistory(115, 2, ""))
progress_history.append(ProgressHistory(115, 2, "", None))
self.assertEqual(Progress.calculate_mean_time(progress_history), 7.5)

for i in range(1, 50):
# add many 2 second steps
progress_history.append(ProgressHistory(115 + (i * 2), 2 + (i * 2), ""))
progress_history.append(ProgressHistory(115 + (i * 2), 2 + (i * 2), "", None))
self.assertEqual(Progress.calculate_mean_time(progress_history), 2)


Expand Down
11 changes: 11 additions & 0 deletions mantidimaging/gui/dialogs/async_task/presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ class Notification(Enum):

class AsyncTaskDialogPresenter(QObject, ProgressHandler):
progress_updated = pyqtSignal(float, str)
progress_plot_updated = pyqtSignal(list, list)

def __init__(self, view):
super().__init__()

self.view = view
self.progress_updated.connect(self.view.set_progress)
self.progress_plot_updated.connect(self.view.set_progress_plot)

self.model = AsyncTaskDialogModel()
self.model.task_done.connect(self.view.handle_completion)
Expand Down Expand Up @@ -62,10 +64,19 @@ def do_start_processing(self) -> None:
def task_is_running(self) -> bool:
return self.model.task_is_running

def update_progress_plot(self, iterations: list, losses: list) -> None:
y = [a[0] for a in losses]
self.progress_plot_updated.emit(iterations, y)

def progress_update(self) -> None:
msg = self.progress.last_status_message()
progress_info = self.progress.progress_history
extra_info = progress_info[-1].extra_info
self.progress_updated.emit(self.progress.completion(), msg if msg is not None else '')

if extra_info:
self.update_progress_plot(extra_info['iterations'], extra_info['losses'])

def show_stop_button(self, show: bool = False) -> None:
self.view.show_cancel_button(show)

Expand Down
10 changes: 10 additions & 0 deletions mantidimaging/gui/dialogs/async_task/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import Any
from collections.abc import Callable
from pyqtgraph import PlotWidget

from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.gui.mvp_base import BaseDialogView
Expand All @@ -23,6 +24,11 @@ def __init__(self, parent: QMainWindow):

self.progressBar.setMinimum(0)
self.progressBar.setMaximum(1000)
self.progress_plot = PlotWidget()
self.PlotVerticalLayout.addWidget(self.progress_plot)
self.progress_plot.hide()
self.progress_plot.setLogMode(y=True)
self.progress_plot.setMinimumHeight(300)

self.show_timer = QTimer(self)
self.cancelButton.clicked.connect(self.presenter.stop_progress)
Expand Down Expand Up @@ -63,6 +69,10 @@ def set_progress(self, progress: float, message: str):
# Update progress bar
self.progressBar.setValue(int(progress * 1000))

def set_progress_plot(self, x: list, y: list):
self.progress_plot.show()
self.progress_plot.plotItem.plot(x, y)

def show_delayed(self, timeout) -> None:
self.show_timer.singleShot(timeout, self.show_from_timer)
self.show_timer.start()
Expand Down
53 changes: 28 additions & 25 deletions mantidimaging/gui/ui/async_task_dialog.ui
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,36 @@
<string>Progress</string>
</property>
<layout class="QVBoxLayout" name="verticalLayout">
<item>
<widget class="QLabel" name="infoText">
<property name="text">
<string>Progress</string>
</property>
</widget>
</item>
<item>
<widget class="QProgressBar" name="progressBar">
<property name="value">
<number>0</number>
</property>
<property name="textVisible">
<bool>true</bool>
</property>
<property name="invertedAppearance">
<bool>false</bool>
</property>
</widget>
<item>
<widget class="QPushButton" name="cancelButton">
<property name="text">
<widget class="QLabel" name="infoText">
<property name="text">
<string>Progress</string>
</property>
</widget>
</item>
<item>
<widget class="QProgressBar" name="progressBar">
<property name="value">
<number>0</number>
</property>
<property name="textVisible">
<bool>true</bool>
</property>
<property name="invertedAppearance">
<bool>false</bool>
</property>
</widget>
</item>
<item>
<layout class="QVBoxLayout" name="PlotVerticalLayout"/>
</item>
<item>
<widget class="QPushButton" name="cancelButton">
<property name="text">
<string>Cancel</string>
</property>
</widget>
</item>
</item>
</property>
</widget>
</item>
</layout>
</widget>
<resources/>
Expand Down

0 comments on commit d056e76

Please sign in to comment.