From d3ce8c2235600aa9617659e844913e3ed9fd2f6a Mon Sep 17 00:00:00 2001 From: stevehenke <91344068+stevehenke@users.noreply.github.com> Date: Wed, 4 Sep 2024 20:23:55 -0500 Subject: [PATCH] S31workflow (#96) - prototype VSPI algorithm for ptycho-enhanced XRF - update two-step ptycho-enhanced XRF algorithms to use coordinate conversions from object API - extract reconstructor API - implement S31 batch reconstruction workflow - add ptychoshelves product file reader --- ptychodus/__main__.py | 2 +- ptychodus/api/object.py | 47 +++--- ptychodus/api/workflow.py | 29 ++-- ptychodus/controller/automation.py | 11 +- ptychodus/controller/core.py | 23 +-- ptychodus/controller/memory.py | 19 +-- ptychodus/controller/probe/exposure.py | 2 +- ptychodus/controller/probe/fluorescence.py | 11 +- ptychodus/controller/probe/stxm.py | 2 +- ptychodus/model/analysis/core.py | 3 +- ptychodus/model/analysis/fluorescence.py | 151 ++++++++++++++++-- ptychodus/model/analysis/settings.py | 1 + ptychodus/model/automation/core.py | 25 ++- ptychodus/model/automation/watcher.py | 4 +- ptychodus/model/automation/workflow.py | 9 +- ptychodus/model/core.py | 14 +- ptychodus/model/patterns/builder.py | 2 +- ptychodus/model/patterns/core.py | 14 +- ptychodus/model/patterns/io.py | 8 +- ptychodus/model/product/api.py | 15 ++ ptychodus/model/product/object/item.py | 1 + ptychodus/model/reconstructor/__init__.py | 2 + ptychodus/model/reconstructor/api.py | 145 +++++++++++++++++ ptychodus/model/reconstructor/core.py | 7 +- ptychodus/model/reconstructor/presenter.py | 136 +++------------- ptychodus/model/tike/reconstructor.py | 15 +- ptychodus/model/workflow/api.py | 56 +++++-- ptychodus/model/workflow/core.py | 6 +- ptychodus/plugins/ptychoShelvesProductFile.py | 139 ++++++++++++++++ ptychodus/plugins/upscaling.py | 68 ++++---- ptychodus/plugins/workflow.py | 123 ++++++++++++-- ptychodus/ptychodus_bdp.py | 4 +- ptychodus/view/object.py | 145 +---------------- ptychodus/view/probe.py | 148 ++++++++++++++++- 34 files changed, 920 insertions(+), 467 deletions(-) create mode 100644 ptychodus/model/reconstructor/api.py create mode 100644 ptychodus/plugins/ptychoShelvesProductFile.py diff --git a/ptychodus/__main__.py b/ptychodus/__main__.py index 0c0f0266..663a11f9 100644 --- a/ptychodus/__main__.py +++ b/ptychodus/__main__.py @@ -103,7 +103,7 @@ def main() -> int: view = ViewCore.createInstance(parsedArgs.dev) from ptychodus.controller import ControllerCore - controller = ControllerCore.createInstance(model, view) + controller = ControllerCore(model, view) controller.showMainWindow(versionString()) return app.exec() diff --git a/ptychodus/api/object.py b/ptychodus/api/object.py index 624ff372..b84b1267 100644 --- a/ptychodus/api/object.py +++ b/ptychodus/api/object.py @@ -8,12 +8,19 @@ import numpy import numpy.typing -from .geometry import ImageExtent, PixelGeometry, Point2D +from .geometry import ImageExtent, PixelGeometry from .scan import ScanPoint ObjectArrayType: TypeAlias = numpy.typing.NDArray[numpy.complexfloating[Any, Any]] +@dataclass(frozen=True) +class ObjectPoint: + index: int + positionXInPixels: float + positionYInPixels: float + + @dataclass(frozen=True) class ObjectGeometry: widthInPixels: int @@ -39,29 +46,33 @@ def minimumXInMeters(self) -> float: def minimumYInMeters(self) -> float: return self.centerYInMeters - self.heightInMeters / 2. - @property - def _radiusX(self) -> float: - return self.widthInPixels / 2 - - @property - def _radiusY(self) -> float: - return self.heightInPixels / 2 - def getPixelGeometry(self) -> PixelGeometry: return PixelGeometry( widthInMeters=self.pixelWidthInMeters, heightInMeters=self.pixelHeightInMeters, ) - def mapObjectPointToScanPoint(self, objectPoint: Point2D) -> Point2D: - x = self.centerXInMeters + self.pixelWidthInMeters * (objectPoint.x - self._radiusX) - y = self.centerYInMeters + self.pixelHeightInMeters * (objectPoint.y - self._radiusY) - return Point2D(x, y) + def mapObjectPointToScanPoint(self, point: ObjectPoint) -> ScanPoint: + rx_px = self.widthInPixels / 2 + ry_px = self.heightInPixels / 2 + dx_m = self.pixelWidthInMeters + dy_m = self.pixelHeightInMeters + + x_m = self.centerXInMeters + dx_m * (point.positionXInPixels - rx_px) + y_m = self.centerYInMeters + dy_m * (point.positionYInPixels - ry_px) + + return ScanPoint(point.index, x_m, y_m) + + def mapScanPointToObjectPoint(self, point: ScanPoint) -> ObjectPoint: + rx_px = self.widthInPixels / 2 + ry_px = self.heightInPixels / 2 + dx_m = self.pixelWidthInMeters + dy_m = self.pixelHeightInMeters + + x_px = (point.positionXInMeters - self.centerXInMeters) / dx_m + rx_px + y_px = (point.positionYInMeters - self.centerYInMeters) / dy_m + ry_px - def mapScanPointToObjectPoint(self, scanPoint: Point2D) -> Point2D: - x = (scanPoint.x - self.centerXInMeters) / self.pixelWidthInMeters + self._radiusX - y = (scanPoint.y - self.centerYInMeters) / self.pixelHeightInMeters + self._radiusY - return Point2D(x, y) + return ObjectPoint(point.index, x_px, y_px) def contains(self, geometry: ObjectGeometry) -> bool: dx = self.centerXInMeters - geometry.centerXInMeters @@ -109,7 +120,7 @@ def __init__(self, expectedLayers = self.numberOfLayers actualLayers = len(self._layerDistanceInMeters) - if actualLayers != expectedLayers: + if actualLayers < expectedLayers: raise ValueError(f'Expected {expectedLayers} layer distances; got {actualLayers}!') self._pixelWidthInMeters = pixelWidthInMeters diff --git a/ptychodus/api/workflow.py b/ptychodus/api/workflow.py index 5110bddf..0b6f57f2 100644 --- a/ptychodus/api/workflow.py +++ b/ptychodus/api/workflow.py @@ -1,3 +1,4 @@ +from __future__ import annotations from abc import ABC, abstractmethod from collections.abc import Mapping from pathlib import Path @@ -15,7 +16,9 @@ def openScan(self, filePath: Path, *, fileType: str | None = None) -> None: pass @abstractmethod - def buildScan(self, builderName: str, builderParameters: Mapping[str, Any] = {}) -> None: + def buildScan(self, + builderName: str | None = None, + builderParameters: Mapping[str, Any] = {}) -> None: pass @abstractmethod @@ -23,11 +26,9 @@ def openProbe(self, filePath: Path, *, fileType: str | None = None) -> None: pass @abstractmethod - def buildProbe(self, builderName: str, builderParameters: Mapping[str, Any] = {}) -> None: - pass - - @abstractmethod - def buildProbeFromSettings(self) -> None: + def buildProbe(self, + builderName: str | None = None, + builderParameters: Mapping[str, Any] = {}) -> None: pass @abstractmethod @@ -35,15 +36,17 @@ def openObject(self, filePath: Path, *, fileType: str | None = None) -> None: pass @abstractmethod - def buildObject(self, builderName: str, builderParameters: Mapping[str, Any] = {}) -> None: + def buildObject(self, + builderName: str | None = None, + builderParameters: Mapping[str, Any] = {}) -> None: pass @abstractmethod - def buildObjectFromSettings(self) -> None: + def reconstructLocal(self, outputProductName: str) -> WorkflowProductAPI: pass @abstractmethod - def reconstruct(self) -> None: + def reconstructRemote(self) -> None: pass @abstractmethod @@ -103,8 +106,14 @@ def saveSettings(self, class FileBasedWorkflow(ABC): + @property + @abstractmethod + def isWatchRecursive(self) -> bool: + '''indicates whether the data directory must be watched recursively''' + pass + @abstractmethod - def getFilePattern(self) -> str: + def getWatchFilePattern(self) -> str: '''UNIX-style filename pattern. For rules see fnmatch from Python standard library.''' pass diff --git a/ptychodus/controller/automation.py b/ptychodus/controller/automation.py index 988a5e86..253e97fd 100644 --- a/ptychodus/controller/automation.py +++ b/ptychodus/controller/automation.py @@ -155,7 +155,8 @@ def __init__(self, core: AutomationCore, presenter: AutomationPresenter, self._processingPresenter = processingPresenter self._listModel = AutomationProcessingListModel(processingPresenter) self._view = view - self._timer = QTimer() + self._executeWaitingTasksTimer = QTimer() + self._automationTimer = QTimer() @classmethod def createInstance(cls, core: AutomationCore, presenter: AutomationPresenter, @@ -174,8 +175,12 @@ def createInstance(cls, core: AutomationCore, presenter: AutomationPresenter, view.clearButton.clicked.connect(presenter.clearDatasetRepository) controller._syncModelToView() - controller._timer.timeout.connect(core.executeWaitingTasks) - controller._timer.start(60 * 1000) # TODO customize (in milliseconds) + + controller._executeWaitingTasksTimer.timeout.connect(core.executeWaitingTasks) + controller._executeWaitingTasksTimer.start(60 * 1000) # TODO customize (in milliseconds) + + controller._automationTimer.timeout.connect(core.refreshDatasetRepository) + controller._automationTimer.start(10 * 1000) # TODO customize (in milliseconds) return controller diff --git a/ptychodus/controller/core.py b/ptychodus/controller/core.py index cf9ff8d0..91df2822 100644 --- a/ptychodus/controller/core.py +++ b/ptychodus/controller/core.py @@ -26,8 +26,7 @@ class ControllerCore: def __init__(self, model: ModelCore, view: ViewCore) -> None: self.view = view - self._memoryController = MemoryController.createInstance(model.memoryPresenter, - view.memoryProgressBar) + self._memoryController = MemoryController(model.memoryPresenter, view.memoryProgressBar) self._fileDialogFactory = FileDialogFactory() self._ptychonnViewControllerFactory = PtychoNNViewControllerFactory( model.ptychonnReconstructorLibrary, self._fileDialogFactory) @@ -85,27 +84,15 @@ def __init__(self, model: ModelCore, view: ViewCore) -> None: self._automationController = AutomationController.createInstance( model._automationCore, model.automationPresenter, model.automationProcessingPresenter, view.automationView, self._fileDialogFactory) - self._refreshDataTimer = QTimer() - self._automationTimer = QTimer() - self._processMessagesTimer = QTimer() - @classmethod - def createInstance(cls, model: ModelCore, view: ViewCore) -> ControllerCore: - controller = cls(model, view) + self._refreshDataTimer = QTimer() + self._refreshDataTimer.timeout.connect(model.refreshActiveDataset) + self._refreshDataTimer.start(1000) # TODO make configurable view.navigationActionGroup.triggered.connect( - lambda action: controller.swapCentralWidgets(action)) - + lambda action: self.swapCentralWidgets(action)) view.workflowAction.setVisible(model.areWorkflowsSupported) - controller._refreshDataTimer.timeout.connect(model.refreshActiveDataset) - controller._refreshDataTimer.start(1000) # TODO make configurable - - controller._automationTimer.timeout.connect(model.refreshAutomationDatasets) - controller._automationTimer.start(1000) # TODO make configurable - - return controller - def showMainWindow(self, windowTitle: str) -> None: self.view.setWindowTitle(windowTitle) self.view.show() diff --git a/ptychodus/controller/memory.py b/ptychodus/controller/memory.py index 1ca07ff7..ef758928 100644 --- a/ptychodus/controller/memory.py +++ b/ptychodus/controller/memory.py @@ -8,23 +8,14 @@ class MemoryController: - def __init__(self, presenter: MemoryPresenter, progressBar: QProgressBar, - timer: QTimer) -> None: + def __init__(self, presenter: MemoryPresenter, progressBar: QProgressBar) -> None: self._presenter = presenter self._progressBar = progressBar - self._timer = timer + self._timer = QTimer() + self._timer.timeout.connect(self._updateProgressBar) - @classmethod - def createInstance(cls, presenter: MemoryPresenter, - progressBar: QProgressBar) -> MemoryController: - timer = QTimer() - controller = cls(presenter, progressBar, timer) - controller._updateProgressBar() - - timer.timeout.connect(controller._updateProgressBar) - timer.start(10 * 1000) # TODO customize (in milliseconds) - - return controller + self._updateProgressBar() + self._timer.start(10 * 1000) # TODO customize (in milliseconds) def _updateProgressBar(self) -> None: stats = self._presenter.getStatistics() diff --git a/ptychodus/controller/probe/exposure.py b/ptychodus/controller/probe/exposure.py index a886c56d..ae090c82 100644 --- a/ptychodus/controller/probe/exposure.py +++ b/ptychodus/controller/probe/exposure.py @@ -2,7 +2,7 @@ from ...model.analysis import ExposureAnalyzer, ExposureMap from ...model.visualization import VisualizationEngine -from ...view.object import ExposureDialog +from ...view.probe import ExposureDialog from ...view.widgets import ExceptionDialog from ..data import FileDialogFactory from ..visualization import VisualizationParametersController, VisualizationWidgetController diff --git a/ptychodus/controller/probe/fluorescence.py b/ptychodus/controller/probe/fluorescence.py index f7901e36..ca950764 100644 --- a/ptychodus/controller/probe/fluorescence.py +++ b/ptychodus/controller/probe/fluorescence.py @@ -7,7 +7,7 @@ from ...model.analysis import FluorescenceEnhancer from ...model.visualization import VisualizationEngine -from ...view.object import FluorescenceDialog +from ...view.probe import FluorescenceDialog from ...view.widgets import ExceptionDialog from ..data import FileDialogFactory from ..visualization import VisualizationParametersController, VisualizationWidgetController @@ -40,6 +40,8 @@ def __init__(self, enhancer: FluorescenceEnhancer, engine: VisualizationEngine, self._engine = engine self._fileDialogFactory = fileDialogFactory self._dialog = FluorescenceDialog() + self._enhancementModel = QStringListModel() + self._enhancementModel.setStringList(self._enhancer.getEnhancementStrategyList()) self._upscalingModel = QStringListModel() self._upscalingModel.setStringList(self._enhancer.getUpscalingStrategyList()) self._deconvolutionModel = QStringListModel() @@ -49,6 +51,11 @@ def __init__(self, enhancer: FluorescenceEnhancer, engine: VisualizationEngine, self._dialog.fluorescenceParametersView.openButton.clicked.connect( self._openMeasuredDataset) + self._dialog.fluorescenceParametersView.enhancementStrategyComboBox.setModel( + self._enhancementModel) + self._dialog.fluorescenceParametersView.enhancementStrategyComboBox.textActivated.connect( + enhancer.setEnhancementStrategy) + self._dialog.fluorescenceParametersView.upscalingStrategyComboBox.setModel( self._upscalingModel) self._dialog.fluorescenceParametersView.upscalingStrategyComboBox.textActivated.connect( @@ -127,6 +134,8 @@ def _saveEnhancedDataset(self) -> None: ExceptionDialog.showException(title, err) def _syncModelToView(self) -> None: + self._dialog.fluorescenceParametersView.enhancementStrategyComboBox.setCurrentText( + self._enhancer.getEnhancementStrategy()) self._dialog.fluorescenceParametersView.upscalingStrategyComboBox.setCurrentText( self._enhancer.getUpscalingStrategy()) self._dialog.fluorescenceParametersView.deconvolutionStrategyComboBox.setCurrentText( diff --git a/ptychodus/controller/probe/stxm.py b/ptychodus/controller/probe/stxm.py index 270067fc..7b21efaf 100644 --- a/ptychodus/controller/probe/stxm.py +++ b/ptychodus/controller/probe/stxm.py @@ -4,7 +4,7 @@ from ...model.analysis import STXMSimulator from ...model.visualization import VisualizationEngine -from ...view.object import STXMDialog +from ...view.probe import STXMDialog from ...view.widgets import ExceptionDialog from ..data import FileDialogFactory from ..visualization import VisualizationParametersController, VisualizationWidgetController diff --git a/ptychodus/model/analysis/core.py b/ptychodus/model/analysis/core.py index 32f8d8ee..a0f40bbb 100644 --- a/ptychodus/model/analysis/core.py +++ b/ptychodus/model/analysis/core.py @@ -39,7 +39,8 @@ def __init__(self, settingsRegistry: SettingsRegistry, upscalingStrategyChooser, deconvolutionStrategyChooser, fluorescenceFileReaderChooser, - fluorescenceFileWriterChooser) + fluorescenceFileWriterChooser, + settingsRegistry) self.fluorescenceVisualizationEngine = VisualizationEngine(isComplex=False) self.xmcdAnalyzer = XMCDAnalyzer(objectRepository) self.xmcdVisualizationEngine = VisualizationEngine(isComplex=False) diff --git a/ptychodus/model/analysis/fluorescence.py b/ptychodus/model/analysis/fluorescence.py index bbc7b03e..54aba065 100644 --- a/ptychodus/model/analysis/fluorescence.py +++ b/ptychodus/model/analysis/fluorescence.py @@ -1,14 +1,21 @@ from __future__ import annotations from collections.abc import Sequence from pathlib import Path +from typing import Final import logging +from scipy.sparse.linalg import gmres, LinearOperator +import math +import numpy + from ptychodus.api.fluorescence import (DeconvolutionStrategy, ElementMap, FluorescenceDataset, FluorescenceFileReader, FluorescenceFileWriter, UpscalingStrategy) from ptychodus.api.geometry import PixelGeometry from ptychodus.api.observer import Observable, Observer from ptychodus.api.plugins import PluginChooser +from ptychodus.api.product import Product +from ptychodus.api.typing import RealArrayType from ..reconstructor import DiffractionPatternPositionMatcher from .settings import FluorescenceSettings @@ -16,14 +23,97 @@ logger = logging.getLogger(__name__) -class FluorescenceEnhancer(Observable, Observer): +def get_axis_weights_and_indexes(xmin_o: float, dx_o: float, xmin_p: float, dx_p: float, + N_p: int) -> tuple[Sequence[float], Sequence[int]]: + weight: list[float] = [] + index: list[int] = [] + + x_l = xmin_p + n_o = math.ceil((x_l - xmin_o) / dx_o) + + for n_p in range(N_p): + x_p = xmin_p + (n_p + 1) * dx_p + + while True: + x_o = xmin_o + n_o + dx_o + + if x_o >= x_p: + break + + weight.append((x_o - x_l) / dx_p) + index.append(n_o) + + n_o += 1 + x_l = x_o + + weight.append((x_p - x_l) / dx_p) + index.append(n_o) + x_l = x_p + + if x_o == x_p: + n_o += 1 + + return weight, index + + +class VSPILinearOperator(LinearOperator): + + def __init__(self, product: Product, xrf_nchannels: int) -> None: + ''' + M: number of XRF positions + N: number of ptychography object pixels + P: number of XRF channels + + A[M,N] * X[N,P] = B[M,P] + ''' + super().__init__(float, (len(product.scan), xrf_nchannels)) + self._product = product + + def matmat(self, X: RealArrayType) -> RealArrayType: + AX = numpy.zeros(self.shape, dtype=self.dtype) + + probeGeometry = self._product.probe.getGeometry() + dx_p_m = probeGeometry.pixelWidthInMeters + dy_p_m = probeGeometry.pixelHeightInMeters - def __init__(self, settings: FluorescenceSettings, - dataMatcher: DiffractionPatternPositionMatcher, - upscalingStrategyChooser: PluginChooser[UpscalingStrategy], - deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], - fileReaderChooser: PluginChooser[FluorescenceFileReader], - fileWriterChooser: PluginChooser[FluorescenceFileWriter]) -> None: + objectGeometry = self._product.object_.getGeometry() + objectShape = objectGeometry.heightInPixels, objectGeometry.widthInPixels + xmin_o_m = objectGeometry.minimumXInMeters + ymin_o_m = objectGeometry.minimumYInMeters + dx_o_m = objectGeometry.pixelWidthInMeters + dy_o_m = objectGeometry.pixelHeightInMeters + + for index, point in enumerate(self._product.scan): + xmin_p_m = point.positionXInMeters - probeGeometry.widthInMeters / 2 + ymin_p_m = point.positionYInMeters - probeGeometry.heightInMeters / 2 + + wx, ix = get_axis_weights_and_indexes(xmin_o_m, dx_o_m, xmin_p_m, dx_p_m, + probeGeometry.widthInPixels) + wy, iy = get_axis_weights_and_indexes(ymin_o_m, dy_o_m, ymin_p_m, dy_p_m, + probeGeometry.heightInPixels) + + IY, IX = numpy.meshgrid(iy, ix) + i_nz = numpy.ravel_multi_index(list(zip(IY.flat, IX.flat)), objectShape) + X_nz = X.take(i_nz, axis=0) + + AX[index, :] = numpy.matmul(numpy.outer(wy, wx).ravel(), X_nz) + + return AX + + +class FluorescenceEnhancer(Observable, Observer): + VSPI: Final[str] = 'Virtual Single Pixel Imaging' + TWO_STEP: Final[str] = 'Upscale and Deconvolve' + + def __init__( + self, + settings: FluorescenceSettings, + dataMatcher: DiffractionPatternPositionMatcher, # FIXME match XRF too + upscalingStrategyChooser: PluginChooser[UpscalingStrategy], + deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], + fileReaderChooser: PluginChooser[FluorescenceFileReader], + fileWriterChooser: PluginChooser[FluorescenceFileWriter], + reinitObservable: Observable) -> None: super().__init__() self._settings = settings self._dataMatcher = dataMatcher @@ -31,6 +121,7 @@ def __init__(self, settings: FluorescenceSettings, self._deconvolutionStrategyChooser = deconvolutionStrategyChooser self._fileReaderChooser = fileReaderChooser self._fileWriterChooser = fileWriterChooser + self._reinitObservable = reinitObservable self._productIndex = -1 self._measured: FluorescenceDataset | None = None @@ -42,6 +133,7 @@ def __init__(self, settings: FluorescenceSettings, deconvolutionStrategyChooser.setCurrentPluginByName(settings.deconvolutionStrategy.value) fileReaderChooser.setCurrentPluginByName(settings.fileType.value) fileWriterChooser.setCurrentPluginByName(settings.fileType.value) + reinitObservable.addObserver(self) def setProduct(self, productIndex: int) -> None: if self._productIndex != productIndex: @@ -89,6 +181,15 @@ def getMeasuredElementMap(self, channelIndex: int) -> ElementMap: return self._measured.element_maps[channelIndex] + def getEnhancementStrategyList(self) -> Sequence[str]: + return [self.VSPI, self.TWO_STEP] + + def getEnhancementStrategy(self) -> str: + return self.VSPI if self._settings.useVSPI.value else self.TWO_STEP + + def setEnhancementStrategy(self, name: str) -> None: + self._settings.useVSPI.value = (name.casefold() == self.VSPI.casefold()) + def getUpscalingStrategyList(self) -> Sequence[str]: return self._upscalingStrategyChooser.getDisplayNameList() @@ -114,14 +215,29 @@ def enhanceFluorescence(self) -> None: reconstructInput = self._dataMatcher.matchDiffractionPatternsWithPositions( self._productIndex) element_maps: list[ElementMap] = list() - upscaler = self._upscalingStrategyChooser.currentPlugin.strategy - deconvolver = self._deconvolutionStrategyChooser.currentPlugin.strategy - for emap in self._measured.element_maps: - logger.debug(f'Processing \"{emap.name}\"') - emap_upscaled = upscaler(emap, reconstructInput.product) - emap_enhanced = deconvolver(emap_upscaled, reconstructInput.product) - element_maps.append(emap_enhanced) + if self._settings.useVSPI.value: + measured_emaps = self._measured.element_maps + A = VSPILinearOperator(reconstructInput.product, len(measured_emaps)) + B = numpy.stack([b.counts_per_second.flatten() for b in measured_emaps]).T + X, info = gmres(A, B, atol=1e-5) # TODO expose atol + + if info != 0: + logger.warning(f'Convergence to tolerance not achieved! {info=}') + + for m_emap, e_cps in zip(measured_emaps, X.T): + e_emap = ElementMap(m_emap.name, e_cps.reshape(m_emap.counts_per_second.shape)) + element_maps.append(e_emap) + + else: + upscaler = self._upscalingStrategyChooser.currentPlugin.strategy + deconvolver = self._deconvolutionStrategyChooser.currentPlugin.strategy + + for emap in self._measured.element_maps: + logger.info(f'Enhancing \"{emap.name}\"') + emap_upscaled = upscaler(emap, reconstructInput.product) + emap_enhanced = deconvolver(emap_upscaled, reconstructInput.product) + element_maps.append(emap_enhanced) self._enhanced = FluorescenceDataset( element_maps=element_maps, @@ -155,8 +271,13 @@ def saveEnhancedDataset(self, filePath: Path, fileFilter: str) -> None: writer = self._fileWriterChooser.currentPlugin.strategy writer.write(filePath, self._enhanced) + def _openFluorescenceFileFromSettings(self) -> None: + self.openMeasuredDataset(self._settings.filePath.value, self._settings.fileType.value) + def update(self, observable: Observable) -> None: - if observable is self._upscalingStrategyChooser: + if observable is self._reinitObservable: + self._openFluorescenceFileFromSettings() + elif observable is self._upscalingStrategyChooser: strategy = self._upscalingStrategyChooser.currentPlugin.simpleName self._settings.upscalingStrategy.value = strategy self.notifyObservers() diff --git a/ptychodus/model/analysis/settings.py b/ptychodus/model/analysis/settings.py index f7638a0f..cec6cab3 100644 --- a/ptychodus/model/analysis/settings.py +++ b/ptychodus/model/analysis/settings.py @@ -32,6 +32,7 @@ def __init__(self, registry: SettingsRegistry) -> None: self.filePath = self._settingsGroup.createPathEntry('FilePath', Path('/path/to/dataset.h5')) self.fileType = self._settingsGroup.createStringEntry('FileType', 'XRF-Maps') + self.useVSPI = self._settingsGroup.createBooleanEntry('UseVSPI', True) self.upscalingStrategy = self._settingsGroup.createStringEntry( 'UpscalingStrategy', 'Linear') self.deconvolutionStrategy = self._settingsGroup.createStringEntry( diff --git a/ptychodus/model/automation/core.py b/ptychodus/model/automation/core.py index c6b7efd8..54d1d48e 100644 --- a/ptychodus/model/automation/core.py +++ b/ptychodus/model/automation/core.py @@ -1,5 +1,5 @@ from __future__ import annotations -from collections.abc import Generator, Sequence +from collections.abc import Sequence from pathlib import Path import queue @@ -59,10 +59,11 @@ def setProcessingIntervalInSeconds(self, value: int) -> None: def loadExistingDatasetsToRepository(self) -> None: dataDirectory = self.getDataDirectory() - scanFileGlob: Generator[Path, None, None] = \ - dataDirectory.glob(self._workflow.getFilePattern()) + pattern = '**/' if self._workflow.isWatchRecursive else '' + pattern += self._workflow.getWatchFilePattern() + scanFileList = sorted(scanFile for scanFile in dataDirectory.glob(pattern)) - for scanFile in scanFileGlob: + for scanFile in scanFileList: self._datasetBuffer.put(scanFile) def clearDatasetRepository(self) -> None: @@ -109,13 +110,8 @@ def __init__(self, settings: AutomationSettings, repository: AutomationDatasetRe self._repository = repository self._processor = processor - @classmethod - def createInstance(cls, settings: AutomationSettings, repository: AutomationDatasetRepository, - processor: AutomationDatasetProcessor) -> AutomationProcessingPresenter: - presenter = cls(settings, repository, processor) - settings.addObserver(presenter) - repository.addObserver(presenter) - return presenter + settings.addObserver(self) + repository.addObserver(self) def getDatasetLabel(self, index: int) -> str: return self._repository.getLabel(index) @@ -158,12 +154,15 @@ def __init__(self, settingsRegistry: SettingsRegistry, workflowAPI: WorkflowAPI, self._watcher = DataDirectoryWatcher(self._settings, self._workflow, self._datasetBuffer) self.presenter = AutomationPresenter(self._settings, self._workflow, self._watcher, self._datasetBuffer, self.repository) - self.processingPresenter = AutomationProcessingPresenter.createInstance( - self._settings, self.repository, self._processor) + self.processingPresenter = AutomationProcessingPresenter(self._settings, self.repository, + self._processor) def start(self) -> None: self._datasetBuffer.start() + def refreshDatasetRepository(self) -> None: + self.repository.notifyObserversIfRepositoryChanged() + def executeWaitingTasks(self) -> None: self._processor.runOnce() diff --git a/ptychodus/model/automation/watcher.py b/ptychodus/model/automation/watcher.py index 235f06cb..11383fdd 100644 --- a/ptychodus/model/automation/watcher.py +++ b/ptychodus/model/automation/watcher.py @@ -28,7 +28,7 @@ def on_created_or_modified(self, event: watchdog.events.FileSystemEvent) -> None if not event.is_directory: srcPath = Path(event.src_path) - if srcPath.match(self._workflow.getFilePattern()): + if srcPath.match(self._workflow.getWatchFilePattern()): self._datasetBuffer.put(srcPath) def on_created(self, event: watchdog.events.FileSystemEvent) -> None: @@ -63,7 +63,7 @@ def _updateWatch(self) -> None: observedWatch = self._observer.schedule( event_handler=DataDirectoryEventHandler(self._workflow, self._datasetBuffer), path=dataDirectory, - recursive=False, # TODO generalize + recursive=self._workflow.isWatchRecursive, ) logger.debug(observedWatch) else: diff --git a/ptychodus/model/automation/workflow.py b/ptychodus/model/automation/workflow.py index a04bf715..de521a47 100644 --- a/ptychodus/model/automation/workflow.py +++ b/ptychodus/model/automation/workflow.py @@ -29,9 +29,14 @@ def setWorkflow(self, name: str) -> None: self._workflowChooser.setCurrentPluginByName(name) self._settings.strategy.value = self._workflowChooser.currentPlugin.simpleName - def getFilePattern(self) -> str: + @property + def isWatchRecursive(self) -> bool: workflow = self._workflowChooser.currentPlugin.strategy - return workflow.getFilePattern() + return workflow.isWatchRecursive + + def getWatchFilePattern(self) -> str: + workflow = self._workflowChooser.currentPlugin.strategy + return workflow.getWatchFilePattern() def execute(self, api: WorkflowAPI, filePath: Path) -> None: workflow = self._workflowChooser.currentPlugin.strategy diff --git a/ptychodus/model/core.py b/ptychodus/model/core.py index ef1887b1..19706c9e 100644 --- a/ptychodus/model/core.py +++ b/ptychodus/model/core.py @@ -109,7 +109,8 @@ def __init__(self, self._pluginRegistry.fluorescenceFileWriters) self._workflowCore = WorkflowCore(self.settingsRegistry, self._patternsCore.patternsAPI, self._productCore.productAPI, self._productCore.scanAPI, - self._productCore.probeAPI, self._productCore.objectAPI) + self._productCore.probeAPI, self._productCore.objectAPI, + self._reconstructorCore.reconstructorAPI) self._automationCore = AutomationCore(self.settingsRegistry, self._workflowCore.workflowAPI, self._pluginRegistry.fileBasedWorkflows) @@ -213,9 +214,6 @@ def getDiffractionPatternAssemblyQueueSize(self) -> int: def refreshActiveDataset(self) -> None: self._patternsCore.dataset.notifyObserversIfDatasetChanged() - def refreshAutomationDatasets(self) -> None: - self._automationCore.repository.notifyObserversIfRepositoryChanged() - def batchModeExecute(self, action: str, inputFilePath: Path, outputFilePath: Path) -> int: # TODO add enum for actions; implement using workflow API inputProductIndex = self._productCore.productAPI.openProduct(inputFilePath, fileType='NPZ') @@ -226,7 +224,7 @@ def batchModeExecute(self, action: str, inputFilePath: Path, outputFilePath: Pat if action.lower() == 'reconstruct': outputProductName = self._productCore.productAPI.getItemName(inputProductIndex) - outputProductIndex = self._reconstructorCore.presenter.reconstruct( + outputProductIndex = self._reconstructorCore.reconstructorAPI.reconstruct( inputProductIndex, outputProductName) if outputProductIndex < 0: @@ -237,9 +235,9 @@ def batchModeExecute(self, action: str, inputFilePath: Path, outputFilePath: Pat outputFilePath, fileType='NPZ') elif action.lower() == 'train': - self._reconstructorCore.presenter.ingestTrainingData(inputProductIndex) - _ = self._reconstructorCore.presenter.train() - self._reconstructorCore.presenter.saveModel(outputFilePath) + self._reconstructorCore.reconstructorAPI.ingestTrainingData(inputProductIndex) + _ = self._reconstructorCore.reconstructorAPI.train() + self._reconstructorCore.reconstructorAPI.saveModel(outputFilePath) else: logger.error(f'Unknown batch mode action \"{action}\"!') return -1 diff --git a/ptychodus/model/patterns/builder.py b/ptychodus/model/patterns/builder.py index d04e3509..0634514a 100644 --- a/ptychodus/model/patterns/builder.py +++ b/ptychodus/model/patterns/builder.py @@ -48,7 +48,7 @@ def _getArrayAndAssemble(self) -> None: logger.exception('Error while assembling array!') def _assemble(self, array: DiffractionPatternArray) -> None: - logger.debug(f'Assembling {array.getLabel()}...') + logger.info(f'Assembling {array.getLabel()}...') try: data = array.getData() diff --git a/ptychodus/model/patterns/core.py b/ptychodus/model/patterns/core.py index fa58db48..84564e8a 100644 --- a/ptychodus/model/patterns/core.py +++ b/ptychodus/model/patterns/core.py @@ -46,13 +46,8 @@ def __init__(self, settings: PatternSettings, dataset: ActiveDiffractionDataset) self._settings = settings self._dataset = dataset - @classmethod - def createInstance(cls, settings: PatternSettings, - dataset: ActiveDiffractionDataset) -> DiffractionDatasetPresenter: - presenter = cls(settings, dataset) - settings.addObserver(presenter) - dataset.addObserver(presenter) - return presenter + settings.addObserver(self) + dataset.addObserver(self) def __iter__(self) -> Iterator[DiffractionPatternArrayPresenter]: for array in self._dataset: @@ -162,9 +157,8 @@ def __init__(self, settingsRegistry: SettingsRegistry, self.metadataPresenter = DiffractionMetadataPresenter(self.dataset, self.detector, self.patternSettings, self.productSettings) - self.datasetPresenter = DiffractionDatasetPresenter.createInstance( - self.patternSettings, self.dataset) - self.datasetInputOutputPresenter = DiffractionDatasetInputOutputPresenter.createInstance( + self.datasetPresenter = DiffractionDatasetPresenter(self.patternSettings, self.dataset) + self.datasetInputOutputPresenter = DiffractionDatasetInputOutputPresenter( self.patternSettings, self.dataset, self.patternsAPI, settingsRegistry) def start(self) -> None: diff --git a/ptychodus/model/patterns/io.py b/ptychodus/model/patterns/io.py index 7c8c3115..7a067fe4 100644 --- a/ptychodus/model/patterns/io.py +++ b/ptychodus/model/patterns/io.py @@ -22,13 +22,7 @@ def __init__(self, settings: PatternSettings, dataset: ActiveDiffractionDataset, self._patternsAPI = patternsAPI self._reinitObservable = reinitObservable - @classmethod - def createInstance(cls, settings: PatternSettings, dataset: ActiveDiffractionDataset, - patternsAPI: PatternsAPI, - reinitObservable: Observable) -> DiffractionDatasetInputOutputPresenter: - presenter = cls(settings, dataset, patternsAPI, reinitObservable) - reinitObservable.addObserver(presenter) - return presenter + reinitObservable.addObserver(self) def getOpenFileFilterList(self) -> Sequence[str]: return self._patternsAPI.getOpenFileFilterList() diff --git a/ptychodus/model/product/api.py b/ptychodus/model/product/api.py index 51a02d3e..9258148e 100644 --- a/ptychodus/model/product/api.py +++ b/ptychodus/model/product/api.py @@ -59,6 +59,21 @@ def buildScan(self, item.setBuilder(builder) + def buildScanFromSettings(self, index: int) -> None: + try: + item = self._repository[index] + except IndexError: + logger.warning(f'Failed to access item {index}!') + return + + try: + builder = self._builderFactory.createFromSettings() + except KeyError: + logger.warning('Failed to create builder from settings!') + return + + item.setBuilder(builder) + def getOpenFileFilterList(self) -> Sequence[str]: return self._builderFactory.getOpenFileFilterList() diff --git a/ptychodus/model/product/object/item.py b/ptychodus/model/product/object/item.py index bb65dd5c..85c6dcf7 100644 --- a/ptychodus/model/product/object/item.py +++ b/ptychodus/model/product/object/item.py @@ -81,6 +81,7 @@ def _rebuild(self) -> None: return self._object = object_ + self.layerDistanceInMeters.setValue(object_.layerDistanceInMeters) self.notifyObservers() def update(self, observable: Observable) -> None: diff --git a/ptychodus/model/reconstructor/__init__.py b/ptychodus/model/reconstructor/__init__.py index b5557422..f9db485b 100644 --- a/ptychodus/model/reconstructor/__init__.py +++ b/ptychodus/model/reconstructor/__init__.py @@ -1,9 +1,11 @@ +from .api import ReconstructorAPI from .core import ReconstructorCore from .matcher import DiffractionPatternPositionMatcher from .presenter import ReconstructorPresenter __all__ = [ 'DiffractionPatternPositionMatcher', + 'ReconstructorAPI', 'ReconstructorCore', 'ReconstructorPresenter', ] diff --git a/ptychodus/model/reconstructor/api.py b/ptychodus/model/reconstructor/api.py new file mode 100644 index 00000000..76a8bc7f --- /dev/null +++ b/ptychodus/model/reconstructor/api.py @@ -0,0 +1,145 @@ +from pathlib import Path +import logging +import time + +from ptychodus.api.plugins import PluginChooser +from ptychodus.api.reconstructor import Reconstructor, TrainableReconstructor, TrainOutput + +from ..product import ProductRepository +from .matcher import DiffractionPatternPositionMatcher, ScanIndexFilter + +logger = logging.getLogger(__name__) + + +class ReconstructorAPI: + + def __init__(self, dataMatcher: DiffractionPatternPositionMatcher, + productRepository: ProductRepository, + reconstructorChooser: PluginChooser[Reconstructor]) -> None: + self._dataMatcher = dataMatcher + self._productRepository = productRepository + self._reconstructorChooser = reconstructorChooser + + def reconstruct(self, + inputProductIndex: int, + outputProductName: str, + indexFilter: ScanIndexFilter = ScanIndexFilter.ALL) -> int: + reconstructor = self._reconstructorChooser.currentPlugin.strategy + parameters = self._dataMatcher.matchDiffractionPatternsWithPositions( + inputProductIndex, indexFilter) + + tic = time.perf_counter() + result = reconstructor.reconstruct(parameters) + toc = time.perf_counter() + logger.info(f'Reconstruction time {toc - tic:.4f} seconds. (code={result.result})') + + outputProductIndex = self._productRepository.insertProduct(result.product) + return outputProductIndex + + def reconstructSplit(self, inputProductIndex: int, outputProductName: str) -> tuple[int, int]: + outputProductIndexOdd = self.reconstruct( + inputProductIndex, + f'{outputProductName}_odd', + ScanIndexFilter.ODD, + ) + outputProductIndexEven = self.reconstruct( + inputProductIndex, + f'{outputProductName}_even', + ScanIndexFilter.EVEN, + ) + + return outputProductIndexOdd, outputProductIndexEven + + def ingestTrainingData(self, inputProductIndex: int) -> None: + reconstructor = self._reconstructorChooser.currentPlugin.strategy + + if isinstance(reconstructor, TrainableReconstructor): + logger.info('Preparing input data...') + tic = time.perf_counter() + parameters = self._dataMatcher.matchDiffractionPatternsWithPositions( + inputProductIndex, ScanIndexFilter.ALL) + toc = time.perf_counter() + logger.info(f'Data preparation time {toc - tic:.4f} seconds.') + + logger.info('Ingesting...') + tic = time.perf_counter() + reconstructor.ingestTrainingData(parameters) + toc = time.perf_counter() + logger.info(f'Ingest time {toc - tic:.4f} seconds.') + else: + logger.warning('Reconstructor is not trainable!') + + def openTrainingData(self, filePath: Path) -> None: + reconstructor = self._reconstructorChooser.currentPlugin.strategy + + if isinstance(reconstructor, TrainableReconstructor): + logger.info('Opening training data...') + tic = time.perf_counter() + reconstructor.openTrainingData(filePath) + toc = time.perf_counter() + logger.info(f'Open time {toc - tic:.4f} seconds.') + else: + logger.warning('Reconstructor is not trainable!') + + def saveTrainingData(self, filePath: Path) -> None: + reconstructor = self._reconstructorChooser.currentPlugin.strategy + + if isinstance(reconstructor, TrainableReconstructor): + logger.info('Saving training data...') + tic = time.perf_counter() + reconstructor.saveTrainingData(filePath) + toc = time.perf_counter() + logger.info(f'Save time {toc - tic:.4f} seconds.') + else: + logger.warning('Reconstructor is not trainable!') + + def train(self) -> TrainOutput: + reconstructor = self._reconstructorChooser.currentPlugin.strategy + result = TrainOutput([], [], -1) + + if isinstance(reconstructor, TrainableReconstructor): + logger.info('Training...') + tic = time.perf_counter() + result = reconstructor.train() + toc = time.perf_counter() + logger.info(f'Training time {toc - tic:.4f} seconds. (code={result.result})') + else: + logger.warning('Reconstructor is not trainable!') + + return result + + def clearTrainingData(self) -> None: + reconstructor = self._reconstructorChooser.currentPlugin.strategy + + if isinstance(reconstructor, TrainableReconstructor): + logger.info('Resetting...') + tic = time.perf_counter() + reconstructor.clearTrainingData() + toc = time.perf_counter() + logger.info(f'Reset time {toc - tic:.4f} seconds.') + else: + logger.warning('Reconstructor is not trainable!') + + def openModel(self, filePath: Path) -> None: + reconstructor = self._reconstructorChooser.currentPlugin.strategy + + if isinstance(reconstructor, TrainableReconstructor): + logger.info('Opening model...') + tic = time.perf_counter() + reconstructor.openModel(filePath) + toc = time.perf_counter() + logger.info(f'Open time {toc - tic:.4f} seconds.') + else: + logger.warning('Reconstructor is not trainable!') + + def saveModel(self, filePath: Path) -> None: + reconstructor = self._reconstructorChooser.currentPlugin.strategy + + if isinstance(reconstructor, TrainableReconstructor): + logger.info('Saving model...') + tic = time.perf_counter() + reconstructor.saveModel(filePath) + toc = time.perf_counter() + logger.info(f'Save time {toc - tic:.4f} seconds.') + else: + logger.warning('Reconstructor is not trainable!') diff --git a/ptychodus/model/reconstructor/core.py b/ptychodus/model/reconstructor/core.py index 8201a838..a7fa1e1c 100644 --- a/ptychodus/model/reconstructor/core.py +++ b/ptychodus/model/reconstructor/core.py @@ -7,6 +7,7 @@ from ..patterns import ActiveDiffractionDataset from ..product import ProductRepository +from .api import ReconstructorAPI from .matcher import DiffractionPatternPositionMatcher from .presenter import ReconstructorPresenter from .settings import ReconstructorSettings @@ -34,5 +35,7 @@ def __init__(self, settingsRegistry: SettingsRegistry, self._pluginChooser.registerPlugin(NullReconstructor('None'), displayName='None/None') self.dataMatcher = DiffractionPatternPositionMatcher(diffractionDataset, productRepository) - self.presenter = ReconstructorPresenter(self.settings, self.dataMatcher, productRepository, - self._pluginChooser, settingsRegistry) + self.reconstructorAPI = ReconstructorAPI(self.dataMatcher, productRepository, + self._pluginChooser) + self.presenter = ReconstructorPresenter(self.settings, self._pluginChooser, + self.reconstructorAPI, settingsRegistry) diff --git a/ptychodus/model/reconstructor/presenter.py b/ptychodus/model/reconstructor/presenter.py index 22521e54..eff95628 100644 --- a/ptychodus/model/reconstructor/presenter.py +++ b/ptychodus/model/reconstructor/presenter.py @@ -1,14 +1,13 @@ from collections.abc import Sequence from pathlib import Path import logging -import time from ptychodus.api.observer import Observable, Observer from ptychodus.api.plugins import PluginChooser from ptychodus.api.reconstructor import Reconstructor, TrainableReconstructor, TrainOutput -from ..product import ProductRepository -from .matcher import DiffractionPatternPositionMatcher, ScanIndexFilter +from .api import ReconstructorAPI +from .matcher import ScanIndexFilter from .settings import ReconstructorSettings logger = logging.getLogger(__name__) @@ -17,17 +16,15 @@ class ReconstructorPresenter(Observable, Observer): def __init__(self, settings: ReconstructorSettings, - dataMatcher: DiffractionPatternPositionMatcher, - productRepository: ProductRepository, reconstructorChooser: PluginChooser[Reconstructor], - reinitObservable: Observable) -> None: + reconstructorAPI: ReconstructorAPI, reinitObservable: Observable) -> None: super().__init__() self._settings = settings - self._dataMatcher = dataMatcher - self._productRepository = productRepository self._reconstructorChooser = reconstructorChooser + self._reconstructorAPI = reconstructorAPI self._reinitObservable = reinitObservable + reconstructorChooser.addObserver(self) reinitObservable.addObserver(self) self._syncFromSettings() @@ -39,41 +36,22 @@ def getReconstructor(self) -> str: def setReconstructor(self, name: str) -> None: self._reconstructorChooser.setCurrentPluginByName(name) - self._settings.algorithm.value = self._reconstructorChooser.currentPlugin.simpleName - self.notifyObservers() def _syncFromSettings(self) -> None: self.setReconstructor(self._settings.algorithm.value) + def _syncToSettings(self) -> None: + self._settings.algorithm.value = self._reconstructorChooser.currentPlugin.simpleName + def reconstruct(self, inputProductIndex: int, outputProductName: str, indexFilter: ScanIndexFilter = ScanIndexFilter.ALL) -> int: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - parameters = self._dataMatcher.matchDiffractionPatternsWithPositions( - inputProductIndex, indexFilter) - - tic = time.perf_counter() - result = reconstructor.reconstruct(parameters) - toc = time.perf_counter() - logger.info(f'Reconstruction time {toc - tic:.4f} seconds. (code={result.result})') - - outputProductIndex = self._productRepository.insertProduct(result.product) - return outputProductIndex + return self._reconstructorAPI.reconstruct(inputProductIndex, outputProductName, + indexFilter) def reconstructSplit(self, inputProductIndex: int, outputProductName: str) -> tuple[int, int]: - outputProductIndexOdd = self.reconstruct( - inputProductIndex, - f'{outputProductName} - Odd', - ScanIndexFilter.ODD, - ) - outputProductIndexEven = self.reconstruct( - inputProductIndex, - f'{outputProductName} - Even', - ScanIndexFilter.EVEN, - ) - - return outputProductIndexOdd, outputProductIndexEven + return self._reconstructorAPI.reconstructSplit(inputProductIndex, outputProductName) @property def isTrainable(self) -> bool: @@ -81,23 +59,7 @@ def isTrainable(self) -> bool: return isinstance(reconstructor, TrainableReconstructor) def ingestTrainingData(self, inputProductIndex: int) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Preparing input data...') - tic = time.perf_counter() - parameters = self._dataMatcher.matchDiffractionPatternsWithPositions( - inputProductIndex, ScanIndexFilter.ALL) - toc = time.perf_counter() - logger.info(f'Data preparation time {toc - tic:.4f} seconds.') - - logger.info('Ingesting...') - tic = time.perf_counter() - reconstructor.ingestTrainingData(parameters) - toc = time.perf_counter() - logger.info(f'Ingest time {toc - tic:.4f} seconds.') - else: - logger.warning('Reconstructor is not trainable!') + return self._reconstructorAPI.ingestTrainingData(inputProductIndex) def getOpenTrainingDataFileFilterList(self) -> Sequence[str]: reconstructor = self._reconstructorChooser.currentPlugin.strategy @@ -120,16 +82,7 @@ def getOpenTrainingDataFileFilter(self) -> str: return str() def openTrainingData(self, filePath: Path) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Opening training data...') - tic = time.perf_counter() - reconstructor.openTrainingData(filePath) - toc = time.perf_counter() - logger.info(f'Open time {toc - tic:.4f} seconds.') - else: - logger.warning('Reconstructor is not trainable!') + return self._reconstructorAPI.openTrainingData(filePath) def getSaveTrainingDataFileFilterList(self) -> Sequence[str]: reconstructor = self._reconstructorChooser.currentPlugin.strategy @@ -152,43 +105,13 @@ def getSaveTrainingDataFileFilter(self) -> str: return str() def saveTrainingData(self, filePath: Path) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Saving training data...') - tic = time.perf_counter() - reconstructor.saveTrainingData(filePath) - toc = time.perf_counter() - logger.info(f'Save time {toc - tic:.4f} seconds.') - else: - logger.warning('Reconstructor is not trainable!') + return self._reconstructorAPI.saveTrainingData(filePath) def train(self) -> TrainOutput: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - result = TrainOutput([], [], -1) - - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Training...') - tic = time.perf_counter() - result = reconstructor.train() - toc = time.perf_counter() - logger.info(f'Training time {toc - tic:.4f} seconds. (code={result.result})') - else: - logger.warning('Reconstructor is not trainable!') - - return result + return self._reconstructorAPI.train() def clearTrainingData(self) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Resetting...') - tic = time.perf_counter() - reconstructor.clearTrainingData() - toc = time.perf_counter() - logger.info(f'Reset time {toc - tic:.4f} seconds.') - else: - logger.warning('Reconstructor is not trainable!') + self._reconstructorAPI.clearTrainingData() def getOpenModelFileFilterList(self) -> Sequence[str]: reconstructor = self._reconstructorChooser.currentPlugin.strategy @@ -211,16 +134,7 @@ def getOpenModelFileFilter(self) -> str: return str() def openModel(self, filePath: Path) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Opening model...') - tic = time.perf_counter() - reconstructor.openModel(filePath) - toc = time.perf_counter() - logger.info(f'Open time {toc - tic:.4f} seconds.') - else: - logger.warning('Reconstructor is not trainable!') + return self._reconstructorAPI.openModel(filePath) def getSaveModelFileFilterList(self) -> Sequence[str]: reconstructor = self._reconstructorChooser.currentPlugin.strategy @@ -243,17 +157,11 @@ def getSaveModelFileFilter(self) -> str: return str() def saveModel(self, filePath: Path) -> None: - reconstructor = self._reconstructorChooser.currentPlugin.strategy - - if isinstance(reconstructor, TrainableReconstructor): - logger.info('Saving model...') - tic = time.perf_counter() - reconstructor.saveModel(filePath) - toc = time.perf_counter() - logger.info(f'Save time {toc - tic:.4f} seconds.') - else: - logger.warning('Reconstructor is not trainable!') + return self._reconstructorAPI.saveModel(filePath) def update(self, observable: Observable) -> None: - if observable is self._reinitObservable: + if observable is self._reconstructorChooser: + self._syncToSettings() + self.notifyObservers() + elif observable is self._reinitObservable: self._syncFromSettings() diff --git a/ptychodus/model/tike/reconstructor.py b/ptychodus/model/tike/reconstructor.py index 675c209b..79fa1f4e 100644 --- a/ptychodus/model/tike/reconstructor.py +++ b/ptychodus/model/tike/reconstructor.py @@ -8,8 +8,7 @@ import tike.ptycho -from ptychodus.api.geometry import Point2D -from ptychodus.api.object import Object +from ptychodus.api.object import Object, ObjectPoint from ptychodus.api.probe import Probe from ptychodus.api.product import Product from ptychodus.api.reconstructor import Reconstructor, ReconstructInput, ReconstructOutput @@ -128,10 +127,9 @@ def __call__(self, parameters: ReconstructInput, uy = -probeInputArray.shape[-2] / 2 for scanPoint in scanInput: - point = Point2D(scanPoint.positionXInMeters, scanPoint.positionYInMeters) - objectPoint = objectGeometry.mapScanPointToObjectPoint(point) - scanInputCoords.append(objectPoint.y + uy) - scanInputCoords.append(objectPoint.x + ux) + objectPoint = objectGeometry.mapScanPointToObjectPoint(scanPoint) + scanInputCoords.append(objectPoint.positionYInPixels + uy) + scanInputCoords.append(objectPoint.positionXInPixels + ux) scanInputArray = numpy.array( scanInputCoords, @@ -190,9 +188,8 @@ def __call__(self, parameters: ReconstructInput, scanOutputPoints: list[ScanPoint] = list() for uncorrectedPoint, xy in zip(scanInput, result.scan): - objectPoint = Point2D(x=xy[1] - ux, y=xy[0] - uy) - point = objectGeometry.mapObjectPointToScanPoint(objectPoint) - scanPoint = ScanPoint(uncorrectedPoint.index, point.x, point.y) + objectPoint = ObjectPoint(uncorrectedPoint.index, xy[1] - ux, xy[0] - uy) + scanPoint = objectGeometry.mapObjectPointToScanPoint(objectPoint) scanOutputPoints.append(scanPoint) scanOutput = Scan(scanOutputPoints) diff --git a/ptychodus/model/workflow/api.py b/ptychodus/model/workflow/api.py index b6a117e0..7688bb57 100644 --- a/ptychodus/model/workflow/api.py +++ b/ptychodus/model/workflow/api.py @@ -1,3 +1,4 @@ +from __future__ import annotations from collections.abc import Mapping from pathlib import Path from typing import Any @@ -10,6 +11,7 @@ from ..patterns import PatternsAPI from ..product import ObjectAPI, ProbeAPI, ProductAPI, ScanAPI +from ..reconstructor import ReconstructorAPI from .executor import WorkflowExecutor logger = logging.getLogger(__name__) @@ -18,39 +20,58 @@ class ConcreteWorkflowProductAPI(WorkflowProductAPI): def __init__(self, productAPI: ProductAPI, scanAPI: ScanAPI, probeAPI: ProbeAPI, - objectAPI: ObjectAPI, executor: WorkflowExecutor, productIndex: int) -> None: + objectAPI: ObjectAPI, reconstructorAPI: ReconstructorAPI, + executor: WorkflowExecutor, productIndex: int) -> None: self._productAPI = productAPI self._scanAPI = scanAPI self._probeAPI = probeAPI self._objectAPI = objectAPI + self._reconstructorAPI = reconstructorAPI self._executor = executor self._productIndex = productIndex def openScan(self, filePath: Path, *, fileType: str | None = None) -> None: self._scanAPI.openScan(self._productIndex, filePath, fileType=fileType) - def buildScan(self, builderName: str, builderParameters: Mapping[str, Any] = {}) -> None: - self._scanAPI.buildScan(self._productIndex, builderName, builderParameters) + def buildScan(self, + builderName: str | None = None, + builderParameters: Mapping[str, Any] = {}) -> None: + if builderName is None: + self._scanAPI.buildScanFromSettings(self._productIndex) + else: + self._scanAPI.buildScan(self._productIndex, builderName, builderParameters) def openProbe(self, filePath: Path, *, fileType: str | None = None) -> None: self._probeAPI.openProbe(self._productIndex, filePath, fileType=fileType) - def buildProbe(self, builderName: str, builderParameters: Mapping[str, Any] = {}) -> None: - self._probeAPI.buildProbe(self._productIndex, builderName, builderParameters) - - def buildProbeFromSettings(self) -> None: - self._probeAPI.buildProbeFromSettings(self._productIndex) + def buildProbe(self, + builderName: str | None = None, + builderParameters: Mapping[str, Any] = {}) -> None: + if builderName is None: + self._probeAPI.buildProbeFromSettings(self._productIndex) + else: + self._probeAPI.buildProbe(self._productIndex, builderName, builderParameters) def openObject(self, filePath: Path, *, fileType: str | None = None) -> None: self._objectAPI.openObject(self._productIndex, filePath, fileType=fileType) - def buildObject(self, builderName: str, builderParameters: Mapping[str, Any] = {}) -> None: - self._objectAPI.buildObject(self._productIndex, builderName, builderParameters) - - def buildObjectFromSettings(self) -> None: - self._objectAPI.buildObjectFromSettings(self._productIndex) + def buildObject(self, + builderName: str | None = None, + builderParameters: Mapping[str, Any] = {}) -> None: + if builderName is None: + self._objectAPI.buildObjectFromSettings(self._productIndex) + else: + self._objectAPI.buildObject(self._productIndex, builderName, builderParameters) + + def reconstructLocal(self, outputProductName: str) -> WorkflowProductAPI: + logger.debug(f'Reconstruct: index={self._productIndex}') + outputProductIndex = self._reconstructorAPI.reconstruct(self._productIndex, + outputProductName) + return ConcreteWorkflowProductAPI(self._productAPI, self._scanAPI, self._probeAPI, + self._objectAPI, self._reconstructorAPI, self._executor, + outputProductIndex) - def reconstruct(self) -> None: + def reconstructRemote(self) -> None: logger.debug(f'Execute Workflow: index={self._productIndex}') self._executor.runFlow(self._productIndex) @@ -62,13 +83,15 @@ class ConcreteWorkflowAPI(WorkflowAPI): def __init__(self, settingsRegistry: SettingsRegistry, patternsAPI: PatternsAPI, productAPI: ProductAPI, scanAPI: ScanAPI, probeAPI: ProbeAPI, - objectAPI: ObjectAPI, executor: WorkflowExecutor) -> None: + objectAPI: ObjectAPI, reconstructorAPI: ReconstructorAPI, + executor: WorkflowExecutor) -> None: self._settingsRegistry = settingsRegistry self._patternsAPI = patternsAPI self._productAPI = productAPI self._scanAPI = scanAPI self._probeAPI = probeAPI self._objectAPI = objectAPI + self._reconstructorAPI = reconstructorAPI self._executor = executor def _createProductAPI(self, productIndex: int) -> WorkflowProductAPI: @@ -76,7 +99,8 @@ def _createProductAPI(self, productIndex: int) -> WorkflowProductAPI: raise ValueError(f'Bad product index ({productIndex=})!') return ConcreteWorkflowProductAPI(self._productAPI, self._scanAPI, self._probeAPI, - self._objectAPI, self._executor, productIndex) + self._objectAPI, self._reconstructorAPI, self._executor, + productIndex) def openPatterns( self, diff --git a/ptychodus/model/workflow/core.py b/ptychodus/model/workflow/core.py index 6353b1a6..508d9a68 100644 --- a/ptychodus/model/workflow/core.py +++ b/ptychodus/model/workflow/core.py @@ -13,6 +13,7 @@ from ..patterns import PatternsAPI from ..product import ObjectAPI, ProbeAPI, ProductAPI, ScanAPI +from ..reconstructor import ReconstructorAPI from .api import ConcreteWorkflowAPI from .authorizer import WorkflowAuthorizer from .executor import WorkflowExecutor @@ -188,7 +189,7 @@ class WorkflowCore: def __init__(self, settingsRegistry: SettingsRegistry, patternsAPI: PatternsAPI, productAPI: ProductAPI, scanAPI: ScanAPI, probeAPI: ProbeAPI, - objectAPI: ObjectAPI) -> None: + objectAPI: ObjectAPI, reconstructorAPI: ReconstructorAPI) -> None: self._settings = WorkflowSettings(settingsRegistry) self._inputDataLocator = SimpleDataLocator.createInstance(self._settings.group, 'Input') self._computeDataLocator = SimpleDataLocator.createInstance(self._settings.group, @@ -201,7 +202,8 @@ def __init__(self, settingsRegistry: SettingsRegistry, patternsAPI: PatternsAPI, self._computeDataLocator, self._outputDataLocator, settingsRegistry, patternsAPI, productAPI) self.workflowAPI = ConcreteWorkflowAPI(settingsRegistry, patternsAPI, productAPI, scanAPI, - probeAPI, objectAPI, self._executor) + probeAPI, objectAPI, reconstructorAPI, + self._executor) self._thread: threading.Thread | None = None try: diff --git a/ptychodus/plugins/ptychoShelvesProductFile.py b/ptychodus/plugins/ptychoShelvesProductFile.py new file mode 100644 index 00000000..39e40a77 --- /dev/null +++ b/ptychodus/plugins/ptychoShelvesProductFile.py @@ -0,0 +1,139 @@ +from pathlib import Path +from typing import Final, Sequence + +import scipy.io + +from ptychodus.api.constants import ELECTRON_VOLT_J, LIGHT_SPEED_M_PER_S, PLANCK_CONSTANT_J_PER_HZ +from ptychodus.api.object import Object, ObjectArrayType, ObjectFileWriter +from ptychodus.api.plugins import PluginRegistry +from ptychodus.api.probe import Probe, ProbeFileWriter +from ptychodus.api.product import Product, ProductFileReader, ProductMetadata +from ptychodus.api.propagator import WavefieldArrayType +from ptychodus.api.scan import Scan, ScanPoint + + +class MATProductFileReader(ProductFileReader): + SIMPLE_NAME: Final[str] = 'PtychoShelves' + DISPLAY_NAME: Final[str] = 'PtychoShelves Files (*.mat)' + + def _load_probe_array(self, probeMatrix: WavefieldArrayType) -> WavefieldArrayType: + if probeMatrix.ndim == 4: + # probeMatrix[width, height, num_shared_modes, num_varying_modes] + # TODO support spatially varying probe modes + probeMatrix = probeMatrix[..., 0] + + if probeMatrix.ndim == 3: + # probeMatrix[width, height, num_shared_modes] + probeMatrix = probeMatrix + + return probeMatrix.transpose(2, 0, 1) + + def _load_object_array(self, objectMatrix: ObjectArrayType) -> ObjectArrayType: + if objectMatrix.ndim == 3: + # objectMatrix[width, height, num_layers] + objectMatrix = objectMatrix.transpose(2, 0, 1) + + return objectMatrix + + def read(self, filePath: Path) -> Product: + scanPointList: list[ScanPoint] = list() + + hc_eVm = PLANCK_CONSTANT_J_PER_HZ * LIGHT_SPEED_M_PER_S / ELECTRON_VOLT_J + matDict = scipy.io.loadmat(filePath, simplify_cells=True) + p_struct = matDict['p'] + probe_energy_eV = hc_eVm / p_struct['lambda'] + + metadata = ProductMetadata( + name=filePath.stem, + comments='', + detectorDistanceInMeters=0., # not included in file + probeEnergyInElectronVolts=probe_energy_eV, + probePhotonsPerSecond=0., # not included in file + exposureTimeInSeconds=0., # not included in file + ) + + dx_spec = p_struct['dx_spec'] + pixel_width_m = dx_spec[0] + pixel_height_m = dx_spec[1] + + outputs_struct = matDict['outputs'] + probe_positions = outputs_struct['probe_positions'] + + for idx, pos_px in enumerate(probe_positions): + point = ScanPoint( + idx, + pos_px[0] * pixel_width_m, + pos_px[1] * pixel_height_m, + ) + scanPointList.append(point) + + probe = Probe( + self._load_probe_array(matDict['probe']), + pixelWidthInMeters=pixel_width_m, + pixelHeightInMeters=pixel_height_m, + ) + + layer_distance_m: Sequence[float] | None = None + + try: + multi_slice_param = p_struct['multi_slice_param'] + except KeyError: + pass + else: + try: + z_distance = multi_slice_param['z_distance'] + except KeyError: + pass + else: + layer_distance_m = z_distance.tolist() + + object_ = Object( + self._load_object_array(matDict['object']), + layer_distance_m, + pixelWidthInMeters=pixel_width_m, + pixelHeightInMeters=pixel_height_m, + ) + costs = outputs_struct['fourier_error_out'] + + return Product( + metadata=metadata, + scan=Scan(scanPointList), + probe=probe, + object_=object_, + costs=costs, + ) + + +class MATObjectFileWriter(ObjectFileWriter): + + def write(self, filePath: Path, object_: Object) -> None: + array = object_.array + matDict = {'object': array.transpose(1, 2, 0)} + # TODO layer distance to p.z_distance + scipy.io.savemat(filePath, matDict) + + +class MATProbeFileWriter(ProbeFileWriter): + + def write(self, filePath: Path, probe: Probe) -> None: + array = probe.array + matDict = {'probe': array.transpose(1, 2, 0)} + scipy.io.savemat(filePath, matDict) + + +def registerPlugins(registry: PluginRegistry) -> None: + registry.productFileReaders.registerPlugin( + MATProductFileReader(), + simpleName=MATProductFileReader.SIMPLE_NAME, + displayName=MATProductFileReader.DISPLAY_NAME, + ) + registry.probeFileWriters.registerPlugin( + MATProbeFileWriter(), + simpleName=MATProductFileReader.SIMPLE_NAME, + displayName=MATProductFileReader.DISPLAY_NAME, + ) + registry.objectFileWriters.registerPlugin( + MATObjectFileWriter(), + simpleName=MATProductFileReader.SIMPLE_NAME, + displayName=MATProductFileReader.DISPLAY_NAME, + ) diff --git a/ptychodus/plugins/upscaling.py b/ptychodus/plugins/upscaling.py index 1d67a9eb..ef9caf06 100644 --- a/ptychodus/plugins/upscaling.py +++ b/ptychodus/plugins/upscaling.py @@ -2,36 +2,8 @@ import numpy from ptychodus.api.fluorescence import ElementMap, UpscalingStrategy -from ptychodus.api.object import Object from ptychodus.api.plugins import PluginRegistry from ptychodus.api.product import Product -from ptychodus.api.scan import Scan -from ptychodus.api.typing import RealArrayType - - -def _scan_to_array(scan: Scan) -> RealArrayType: - coords: list[float] = list() - - for point in scan: - coords.append(point.positionXInMeters) - coords.append(point.positionYInMeters) - - return numpy.reshape(coords, (-1, 2)) - - -def _object_coordinates(numberOfPixels: int, pixelSizeInMeters: float, - centerInMeters: float) -> RealArrayType: - positionInPixels = numpy.arange(numberOfPixels) - numberOfPixels / 2 - return centerInMeters + positionInPixels * pixelSizeInMeters - - -def _object_coordinates_yx(object_: Object) -> tuple[RealArrayType, RealArrayType]: - axisXInMeters = _object_coordinates(object_.widthInPixels, object_.pixelWidthInMeters, - object_.centerXInMeters) - axisYInMeters = _object_coordinates(object_.heightInPixels, object_.pixelHeightInMeters, - object_.centerYInMeters) - gridYInMeters, gridXInMeters = numpy.meshgrid(axisYInMeters, axisXInMeters) - return gridYInMeters, gridXInMeters class IdentityUpscaling(UpscalingStrategy): @@ -46,13 +18,22 @@ def __init__(self, method: str) -> None: self._method = method def __call__(self, emap: ElementMap, product: Product) -> ElementMap: - cps = griddata( - _scan_to_array(product.scan), - emap.counts_per_second.flat, - _object_coordinates_yx(product.object_), - method=self._method, - fill_value=0., - ) + objectGeometry = product.object_.getGeometry() + scanCoordinatesInPixels: list[float] = list() + + for scanPoint in product.scan: + objectPoint = objectGeometry.mapScanPointToObjectPoint(scanPoint) + scanCoordinatesInPixels.append(objectPoint.positionYInPixels) + scanCoordinatesInPixels.append(objectPoint.positionXInPixels) + + points = numpy.reshape(scanCoordinatesInPixels, (-1, 2)) + values = emap.counts_per_second.flat + YY, XX = numpy.mgrid[:objectGeometry.heightInPixels, :objectGeometry.widthInPixels] + query_points = numpy.transpose((YY.flat, XX.flat)) + + cps = griddata(points, values, query_points, method=self._method, + fill_value=0.).reshape(XX.shape) + return ElementMap(emap.name, cps.astype(emap.counts_per_second.dtype)) @@ -70,18 +51,25 @@ def __init__(self, self._degree = degree def __call__(self, emap: ElementMap, product: Product) -> ElementMap: + objectGeometry = product.object_.getGeometry() + scanCoordinatesInPixels: list[float] = list() + + for scanPoint in product.scan: + objectPoint = objectGeometry.mapScanPointToObjectPoint(scanPoint) + scanCoordinatesInPixels.append(objectPoint.positionYInPixels) + scanCoordinatesInPixels.append(objectPoint.positionXInPixels) + interpolator = RBFInterpolator( - _scan_to_array(product.scan), + numpy.reshape(scanCoordinatesInPixels, (-1, 2)), emap.counts_per_second.flat, kernel=self._kernel, neighbors=self._neighbors, epsilon=self._epsilon, degree=self._degree, ) - grid_y, grid_x = _object_coordinates_yx(product.object_) - cps = interpolator(numpy.transpose((grid_y.flat, grid_x.flat))) - return ElementMap(emap.name, - cps.astype(emap.counts_per_second.dtype).reshape(grid_x.shape)) + YY, XX = numpy.mgrid[:objectGeometry.heightInPixels, :objectGeometry.widthInPixels] + cps = interpolator(numpy.transpose((YY.flat, XX.flat))) + return ElementMap(emap.name, cps.astype(emap.counts_per_second.dtype).reshape(XX.shape)) def registerPlugins(registry: PluginRegistry) -> None: diff --git a/ptychodus/plugins/workflow.py b/ptychodus/plugins/workflow.py index 9dd06e7e..3dc08025 100644 --- a/ptychodus/plugins/workflow.py +++ b/ptychodus/plugins/workflow.py @@ -1,13 +1,24 @@ +from dataclasses import dataclass from pathlib import Path +import csv +import logging import re from ptychodus.api.plugins import PluginRegistry from ptychodus.api.workflow import FileBasedWorkflow, WorkflowAPI +logger = logging.getLogger(__name__) + +# FIXME plugin for loading products from file + class APS2IDFileBasedWorkflow(FileBasedWorkflow): - def getFilePattern(self) -> str: + @property + def isWatchRecursive(self) -> bool: + return False + + def getWatchFilePattern(self) -> str: return '*.csv' def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: @@ -18,14 +29,18 @@ def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: workflowAPI.openPatterns(diffractionFilePath, fileType='NeXus') productAPI = workflowAPI.createProduct(f'scan{scanID}') productAPI.openScan(filePath, fileType='CSV') - productAPI.buildProbe('Disk') - productAPI.buildObject('Random') - productAPI.reconstruct() + productAPI.buildProbe() + productAPI.buildObject() + productAPI.reconstructRemote() class APS26IDFileBasedWorkflow(FileBasedWorkflow): - def getFilePattern(self) -> str: + @property + def isWatchRecursive(self) -> bool: + return False + + def getWatchFilePattern(self) -> str: return '*.mda' def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: @@ -43,19 +58,107 @@ def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: workflowAPI.openPatterns(diffractionFilePath, fileType='HDF5') productAPI = workflowAPI.createProduct(f'scan_{scanID}') productAPI.openScan(filePath, fileType='MDA') - productAPI.buildProbe('Disk') - productAPI.buildObject('Random') - productAPI.reconstruct() + productAPI.buildProbe() + productAPI.buildObject() + productAPI.reconstructRemote() + + +@dataclass(frozen=True) +class APS31IDEMetadata: + scan_no: int + golden_angle: str + encoder_angle: str + measurement_id: str + subtomo_no: str + detector_position: str + label: str + + def __str__(self) -> str: + return f'''scan_no={self.scan_no} + golden_angle={self.golden_angle} + encoder_angle={self.encoder_angle} + measurement_id={self.measurement_id} + subtomo_no={self.subtomo_no} + detector_position={self.detector_position} + label={self.label} + ''' + + +class APS31IDEFileBasedWorkflow(FileBasedWorkflow): + + @property + def isWatchRecursive(self) -> bool: + return True + + def getWatchFilePattern(self) -> str: + return '*.h5' + + def execute(self, workflowAPI: WorkflowAPI, filePath: Path) -> None: + experimentDir = filePath.parents[3] + scan_no = int(re.findall(r'\d+', filePath.stem)[0]) + scanFile = experimentDir / 'scan_positions' / f'scan_{scan_no:05d}.dat' + scanNumbersFile = experimentDir / 'dat-files' / 'tomography_scannumbers.txt' + metadata: APS31IDEMetadata | None = None + + with scanNumbersFile.open(newline='') as csvFile: + csvReader = csv.reader(csvFile, delimiter=' ') + + for row in csvReader: + if row[0].startswith('#'): + continue + + if len(row) != 7: + logger.warning('Unexpected row in tomography_scannumbers.txt!') + logger.debug(row) + continue + + try: + row_no = int(row[0]) + except ValueError: + logger.warning('Failed to parse row ID in tomography_scannumbers.txt!') + logger.debug(row[0]) + continue + + if row_no == scan_no: + metadata = APS31IDEMetadata( + scan_no=scan_no, + golden_angle=str(row[1]), + encoder_angle=str(row[2]), + measurement_id=str(row[3]), + subtomo_no=str(row[4]), + detector_position=str(row[5]), + label=str(row[6]), + ) + break + + if metadata is None: + logger.warning(f'Failed to locate label for {row_no}!') + else: + productName = f'scan{scan_no:05d}_' + metadata.label + workflowAPI.openPatterns(filePath, fileType='LYNX') + inputProductAPI = workflowAPI.createProduct(productName, comments=str(metadata)) + inputProductAPI.openScan(scanFile, fileType='LYNXOrchestra') + inputProductAPI.buildProbe() + inputProductAPI.buildObject() + # TODO would prefer to write instructions and submit to queue + outputProductAPI = inputProductAPI.reconstructLocal(f'{productName}_out') + outputProductAPI.saveProduct(experimentDir / 'ptychodus' / f'{productName}.h5', + fileType='HDF5') def registerPlugins(registry: PluginRegistry) -> None: registry.fileBasedWorkflows.registerPlugin( APS2IDFileBasedWorkflow(), simpleName='APS_2ID', - displayName='LYNX Catalyst Particle', + displayName='APS 2-ID', ) registry.fileBasedWorkflows.registerPlugin( APS26IDFileBasedWorkflow(), simpleName='APS_26ID', - displayName='CNM/APS Hard X-Ray Nanoprobe', + displayName='APS 26-ID', + ) + registry.fileBasedWorkflows.registerPlugin( + APS31IDEFileBasedWorkflow(), + simpleName='APS_31IDE', + displayName='APS 31-ID-E', ) diff --git a/ptychodus/ptychodus_bdp.py b/ptychodus/ptychodus_bdp.py index 85c6f1c4..6ccb8e96 100755 --- a/ptychodus/ptychodus_bdp.py +++ b/ptychodus/ptychodus_bdp.py @@ -215,8 +215,8 @@ def main() -> int: exposureTimeInSeconds=args.exposure_time_s, ) workflowProductAPI.openScan(Path(args.scan_file_path.name)) - workflowProductAPI.buildProbeFromSettings() - workflowProductAPI.buildObjectFromSettings() + workflowProductAPI.buildProbe() + workflowProductAPI.buildObject() stagingDir = args.output_directory stagingDir.mkdir(parents=True, exist_ok=True) diff --git a/ptychodus/view/object.py b/ptychodus/view/object.py index fe8017c2..2498e33f 100644 --- a/ptychodus/view/object.py +++ b/ptychodus/view/object.py @@ -1,13 +1,11 @@ -from PyQt5.QtWidgets import (QCheckBox, QComboBox, QDialog, QFormLayout, QGridLayout, QGroupBox, - QHBoxLayout, QLabel, QListView, QPushButton, QRadioButton, QStatusBar, - QVBoxLayout, QWidget) +from PyQt5.QtWidgets import (QComboBox, QDialog, QFormLayout, QGridLayout, QGroupBox, QLabel, + QPushButton, QStatusBar, QVBoxLayout, QWidget) from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg from matplotlib.backends.backend_qtagg import NavigationToolbar2QT as NavigationToolbar from matplotlib.figure import Figure from .visualization import VisualizationParametersView, VisualizationWidget -from .widgets import DecimalLineEdit class FourierRingCorrelationDialog(QDialog): @@ -38,145 +36,6 @@ def __init__(self, parent: QWidget | None = None) -> None: self.setLayout(layout) -class STXMDialog(QDialog): - - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__(parent) - self.visualizationWidget = VisualizationWidget.createInstance('Transmission') - self.visualizationParametersView = VisualizationParametersView.createInstance() - self.saveButton = QPushButton('Save') - self.statusBar = QStatusBar() - - parameterLayout = QVBoxLayout() - parameterLayout.addWidget(self.visualizationParametersView) - parameterLayout.addStretch() - parameterLayout.addWidget(self.saveButton) - - contentsLayout = QHBoxLayout() - contentsLayout.addWidget(self.visualizationWidget, 1) - contentsLayout.addLayout(parameterLayout) - - layout = QVBoxLayout() - layout.addLayout(contentsLayout) - layout.addWidget(self.statusBar) - self.setLayout(layout) - - -class ExposureParametersView(QGroupBox): - - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__('Parameters', parent) - self.quantitativeProbeCheckBox = QCheckBox('Quantitative Probe') - self.photonFluxLineEdit = DecimalLineEdit.createInstance() - self.exposureTimeLineEdit = DecimalLineEdit.createInstance() - self.massAttenuationLabel = QLabel('Mass Attenuation [m\u00B2/kg]:') - self.massAttenuationLineEdit = DecimalLineEdit.createInstance() - - layout = QFormLayout() - layout.addRow(self.quantitativeProbeCheckBox) - layout.addRow('Photon Flux [ph/s]:', self.photonFluxLineEdit) - layout.addRow('Exposure Time [s]:', self.exposureTimeLineEdit) - layout.addRow(self.massAttenuationLabel) - layout.addRow(self.massAttenuationLineEdit) - self.setLayout(layout) - - -class ExposureQuantityView(QGroupBox): - - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__('Quantity', parent) - self.photonCountButton = QRadioButton('Photon Count') - self.photonFluxButton = QRadioButton('Photon Flux [Hz]') - self.exposureButton = QRadioButton('Exposure [J/m\u00B2]') - self.irradianceButton = QRadioButton('Irradiance [W/m\u00B2]') - self.doseButton = QRadioButton('Dose [Gy]') - self.doseRateButton = QRadioButton('Dose Rate [Gy/s]') - - layout = QVBoxLayout() - layout.addWidget(self.photonCountButton) - layout.addWidget(self.photonFluxButton) - layout.addWidget(self.exposureButton) - layout.addWidget(self.irradianceButton) - layout.addWidget(self.doseButton) - layout.addWidget(self.doseRateButton) - self.setLayout(layout) - - -class ExposureDialog(QDialog): - - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__(parent) - self.visualizationWidget = VisualizationWidget.createInstance('Visualization') - self.exposureParametersView = ExposureParametersView() - self.exposureQuantityView = ExposureQuantityView() - self.visualizationParametersView = VisualizationParametersView.createInstance() - self.saveButton = QPushButton('Save') - self.statusBar = QStatusBar() - - parameterLayout = QVBoxLayout() - parameterLayout.addWidget(self.exposureParametersView) - parameterLayout.addWidget(self.exposureQuantityView) - parameterLayout.addWidget(self.visualizationParametersView) - parameterLayout.addWidget(self.saveButton) - parameterLayout.addStretch() - - contentsLayout = QHBoxLayout() - contentsLayout.addWidget(self.visualizationWidget, 1) - contentsLayout.addLayout(parameterLayout) - - layout = QVBoxLayout() - layout.addLayout(contentsLayout) - layout.addWidget(self.statusBar) - self.setLayout(layout) - - -class FluorescenceParametersView(QGroupBox): - - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__('Parameters', parent) - self.openButton = QPushButton('Open') - self.upscalingStrategyComboBox = QComboBox() - self.deconvolutionStrategyComboBox = QComboBox() - self.enhanceButton = QPushButton('Enhance') - self.saveButton = QPushButton('Save') - - layout = QFormLayout() - layout.addRow('Measured Dataset:', self.openButton) - layout.addRow('Upscaling Strategy:', self.upscalingStrategyComboBox) - layout.addRow('Deconvolution Strategy:', self.deconvolutionStrategyComboBox) - layout.addRow(self.enhanceButton) - layout.addRow('Enhanced Dataset:', self.saveButton) - self.setLayout(layout) - - -class FluorescenceDialog(QDialog): - - def __init__(self, parent: QWidget | None = None) -> None: - super().__init__(parent) - self.measuredWidget = VisualizationWidget.createInstance('Measured') - self.enhancedWidget = VisualizationWidget.createInstance('Enhanced') - self.fluorescenceParametersView = FluorescenceParametersView() - self.fluorescenceChannelListView = QListView() - self.visualizationParametersView = VisualizationParametersView.createInstance() - self.statusBar = QStatusBar() - - parameterLayout = QVBoxLayout() - parameterLayout.addWidget(self.fluorescenceParametersView) - parameterLayout.addWidget(self.fluorescenceChannelListView, 1) - parameterLayout.addWidget(self.visualizationParametersView) - parameterLayout.addStretch() - - contentsLayout = QHBoxLayout() - contentsLayout.addWidget(self.measuredWidget, 1) - contentsLayout.addWidget(self.enhancedWidget, 1) - contentsLayout.addLayout(parameterLayout) - - layout = QVBoxLayout() - layout.addLayout(contentsLayout) - layout.addWidget(self.statusBar) - self.setLayout(layout) - - class XMCDParametersView(QGroupBox): def __init__(self, parent: QWidget | None = None) -> None: diff --git a/ptychodus/view/probe.py b/ptychodus/view/probe.py index 53048a1c..8515567e 100644 --- a/ptychodus/view/probe.py +++ b/ptychodus/view/probe.py @@ -1,9 +1,10 @@ from PyQt5.QtCore import Qt -from PyQt5.QtWidgets import (QDialog, QFormLayout, QGridLayout, QGroupBox, QHBoxLayout, QLabel, - QPushButton, QSlider, QSpinBox, QStatusBar, QVBoxLayout, QWidget) +from PyQt5.QtWidgets import (QCheckBox, QComboBox, QDialog, QFormLayout, QGridLayout, QGroupBox, + QHBoxLayout, QLabel, QListView, QPushButton, QRadioButton, QSlider, + QSpinBox, QStatusBar, QVBoxLayout, QWidget) from .visualization import VisualizationParametersView, VisualizationWidget -from .widgets import LengthWidget +from .widgets import DecimalLineEdit, LengthWidget class ProbePropagationParametersView(QGroupBox): @@ -67,3 +68,144 @@ def __init__(self, parent: QWidget | None = None) -> None: layout.addLayout(contentsLayout) layout.addWidget(self.statusBar) self.setLayout(layout) + + +class STXMDialog(QDialog): + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.visualizationWidget = VisualizationWidget.createInstance('Transmission') + self.visualizationParametersView = VisualizationParametersView.createInstance() + self.saveButton = QPushButton('Save') + self.statusBar = QStatusBar() + + parameterLayout = QVBoxLayout() + parameterLayout.addWidget(self.visualizationParametersView) + parameterLayout.addStretch() + parameterLayout.addWidget(self.saveButton) + + contentsLayout = QHBoxLayout() + contentsLayout.addWidget(self.visualizationWidget, 1) + contentsLayout.addLayout(parameterLayout) + + layout = QVBoxLayout() + layout.addLayout(contentsLayout) + layout.addWidget(self.statusBar) + self.setLayout(layout) + + +class ExposureParametersView(QGroupBox): + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__('Parameters', parent) + self.quantitativeProbeCheckBox = QCheckBox('Quantitative Probe') + self.photonFluxLineEdit = DecimalLineEdit.createInstance() + self.exposureTimeLineEdit = DecimalLineEdit.createInstance() + self.massAttenuationLabel = QLabel('Mass Attenuation [m\u00B2/kg]:') + self.massAttenuationLineEdit = DecimalLineEdit.createInstance() + + layout = QFormLayout() + layout.addRow(self.quantitativeProbeCheckBox) + layout.addRow('Photon Flux [ph/s]:', self.photonFluxLineEdit) + layout.addRow('Exposure Time [s]:', self.exposureTimeLineEdit) + layout.addRow(self.massAttenuationLabel) + layout.addRow(self.massAttenuationLineEdit) + self.setLayout(layout) + + +class ExposureQuantityView(QGroupBox): + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__('Quantity', parent) + self.photonCountButton = QRadioButton('Photon Count') + self.photonFluxButton = QRadioButton('Photon Flux [Hz]') + self.exposureButton = QRadioButton('Exposure [J/m\u00B2]') + self.irradianceButton = QRadioButton('Irradiance [W/m\u00B2]') + self.doseButton = QRadioButton('Dose [Gy]') + self.doseRateButton = QRadioButton('Dose Rate [Gy/s]') + + layout = QVBoxLayout() + layout.addWidget(self.photonCountButton) + layout.addWidget(self.photonFluxButton) + layout.addWidget(self.exposureButton) + layout.addWidget(self.irradianceButton) + layout.addWidget(self.doseButton) + layout.addWidget(self.doseRateButton) + self.setLayout(layout) + + +class ExposureDialog(QDialog): + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.visualizationWidget = VisualizationWidget.createInstance('Visualization') + self.exposureParametersView = ExposureParametersView() + self.exposureQuantityView = ExposureQuantityView() + self.visualizationParametersView = VisualizationParametersView.createInstance() + self.saveButton = QPushButton('Save') + self.statusBar = QStatusBar() + + parameterLayout = QVBoxLayout() + parameterLayout.addWidget(self.exposureParametersView) + parameterLayout.addWidget(self.exposureQuantityView) + parameterLayout.addWidget(self.visualizationParametersView) + parameterLayout.addWidget(self.saveButton) + parameterLayout.addStretch() + + contentsLayout = QHBoxLayout() + contentsLayout.addWidget(self.visualizationWidget, 1) + contentsLayout.addLayout(parameterLayout) + + layout = QVBoxLayout() + layout.addLayout(contentsLayout) + layout.addWidget(self.statusBar) + self.setLayout(layout) + + +class FluorescenceParametersView(QGroupBox): + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__('Parameters', parent) + self.openButton = QPushButton('Open') + self.enhancementStrategyComboBox = QComboBox() + self.upscalingStrategyComboBox = QComboBox() + self.deconvolutionStrategyComboBox = QComboBox() + self.enhanceButton = QPushButton('Enhance') + self.saveButton = QPushButton('Save') + + layout = QFormLayout() + layout.addRow('Measured Dataset:', self.openButton) + layout.addRow('Enhancement Strategy:', self.enhancementStrategyComboBox) + layout.addRow('Upscaling Strategy:', self.upscalingStrategyComboBox) + layout.addRow('Deconvolution Strategy:', self.deconvolutionStrategyComboBox) + layout.addRow(self.enhanceButton) + layout.addRow('Enhanced Dataset:', self.saveButton) + self.setLayout(layout) + + +class FluorescenceDialog(QDialog): + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.measuredWidget = VisualizationWidget.createInstance('Measured') + self.enhancedWidget = VisualizationWidget.createInstance('Enhanced') + self.fluorescenceParametersView = FluorescenceParametersView() + self.fluorescenceChannelListView = QListView() + self.visualizationParametersView = VisualizationParametersView.createInstance() + self.statusBar = QStatusBar() + + parameterLayout = QVBoxLayout() + parameterLayout.addWidget(self.fluorescenceParametersView) + parameterLayout.addWidget(self.fluorescenceChannelListView, 1) + parameterLayout.addWidget(self.visualizationParametersView) + parameterLayout.addStretch() + + contentsLayout = QHBoxLayout() + contentsLayout.addWidget(self.measuredWidget, 1) + contentsLayout.addWidget(self.enhancedWidget, 1) + contentsLayout.addLayout(parameterLayout) + + layout = QVBoxLayout() + layout.addLayout(contentsLayout) + layout.addWidget(self.statusBar) + self.setLayout(layout)