diff --git a/ptychodus/api/reconstructor.py b/ptychodus/api/reconstructor.py index ccc84dcf..95fd56d7 100644 --- a/ptychodus/api/reconstructor.py +++ b/ptychodus/api/reconstructor.py @@ -74,21 +74,29 @@ def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: class TrainableReconstructor(Reconstructor): @abstractmethod - def ingest(self, parameters: ReconstructInput) -> None: + def ingestTrainingData(self, parameters: ReconstructInput) -> None: pass @abstractmethod - def train(self) -> Plot2D: + def getSaveFileFilterList(self) -> Sequence[str]: pass @abstractmethod - def reset(self) -> None: + def getSaveFileFilter(self) -> str: pass @abstractmethod def saveTrainingData(self, filePath: Path) -> None: pass + @abstractmethod + def train(self) -> Plot2D: + pass + + @abstractmethod + def clearTrainingData(self) -> None: + pass + class NullReconstructor(TrainableReconstructor): @@ -109,18 +117,24 @@ def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: result=0, ) - def ingest(self, parameters: ReconstructInput) -> None: + def ingestTrainingData(self, parameters: ReconstructInput) -> None: pass - def train(self) -> Plot2D: - return Plot2D.createNull() + def getSaveFileFilterList(self) -> Sequence[str]: + return list() - def reset(self) -> None: - pass + def getSaveFileFilter(self) -> str: + return str() def saveTrainingData(self, filePath: Path) -> None: pass + def train(self) -> Plot2D: + return Plot2D.createNull() + + def clearTrainingData(self) -> None: + pass + class ReconstructorLibrary(Iterable[Reconstructor]): diff --git a/ptychodus/controller/core.py b/ptychodus/controller/core.py index 2f1a640d..3e2b33e8 100644 --- a/ptychodus/controller/core.py +++ b/ptychodus/controller/core.py @@ -68,6 +68,7 @@ def __init__(self, model: ModelCore, view: ViewCore) -> None: model.objectPresenter, view.reconstructorParametersView, view.reconstructorPlotView, + self._fileDialogFactory, [ self._ptychopyViewControllerFactory, self._ptychonnViewControllerFactory, self._tikeViewControllerFactory diff --git a/ptychodus/controller/ptychonn/model.py b/ptychodus/controller/ptychonn/model.py index 2f977285..1938bb8d 100644 --- a/ptychodus/controller/ptychonn/model.py +++ b/ptychodus/controller/ptychonn/model.py @@ -24,8 +24,8 @@ def createInstance(cls, presenter: PtychoNNModelPresenter, view: PtychoNNModelPa view.modelStateLineEdit.editingFinished.connect(controller._syncModelStateFilePathToModel) view.modelStateBrowseButton.clicked.connect(controller._openModelState) - view.numberOfConvolutionChannelsSpinBox.valueChanged.connect( - presenter.setNumberOfConvolutionChannels) + view.numberOfConvolutionKernelsSpinBox.valueChanged.connect( + presenter.setNumberOfConvolutionKernels) view.batchSizeSpinBox.valueChanged.connect(presenter.setBatchSize) view.useBatchNormalizationCheckBox.toggled.connect(presenter.setBatchNormalizationEnabled) @@ -54,13 +54,13 @@ def _syncModelToView(self) -> None: else: self._view.modelStateLineEdit.clear() - self._view.numberOfConvolutionChannelsSpinBox.blockSignals(True) - self._view.numberOfConvolutionChannelsSpinBox.setRange( - self._presenter.getNumberOfConvolutionChannelsLimits().lower, - self._presenter.getNumberOfConvolutionChannelsLimits().upper) - self._view.numberOfConvolutionChannelsSpinBox.setValue( - self._presenter.getNumberOfConvolutionChannels()) - self._view.numberOfConvolutionChannelsSpinBox.blockSignals(False) + self._view.numberOfConvolutionKernelsSpinBox.blockSignals(True) + self._view.numberOfConvolutionKernelsSpinBox.setRange( + self._presenter.getNumberOfConvolutionKernelsLimits().lower, + self._presenter.getNumberOfConvolutionKernelsLimits().upper) + self._view.numberOfConvolutionKernelsSpinBox.setValue( + self._presenter.getNumberOfConvolutionKernels()) + self._view.numberOfConvolutionKernelsSpinBox.blockSignals(False) self._view.batchSizeSpinBox.blockSignals(True) self._view.batchSizeSpinBox.setRange(self._presenter.getBatchSizeLimits().lower, diff --git a/ptychodus/controller/ptychonn/training.py b/ptychodus/controller/ptychonn/training.py index 0d58a87e..c4085d17 100644 --- a/ptychodus/controller/ptychonn/training.py +++ b/ptychodus/controller/ptychonn/training.py @@ -6,7 +6,6 @@ from ...api.observer import Observable, Observer from ...model.ptychonn import PtychoNNTrainingPresenter from ...view.ptychonn import PtychoNNOutputParametersView, PtychoNNTrainingParametersView -from ...view.widgets import ExceptionDialog from ..data import FileDialogFactory logger = logging.getLogger(__name__) @@ -80,7 +79,6 @@ def __init__(self, presenter: PtychoNNTrainingPresenter, view: PtychoNNTrainingP super().__init__() self._presenter = presenter self._view = view - self._fileDialogFactory = fileDialogFactory self._outputParametersController = PtychoNNOutputParametersController.createInstance( presenter, view.outputParametersView, fileDialogFactory) @@ -99,26 +97,11 @@ def createInstance( view.minimumLearningRateLineEdit.valueChanged.connect(presenter.setMinimumLearningRate) view.trainingEpochsSpinBox.valueChanged.connect(presenter.setTrainingEpochs) view.statusIntervalSpinBox.valueChanged.connect(presenter.setStatusIntervalInEpochs) - view.saveTrainingDataButton.clicked.connect(controller._saveTrainingData) controller._syncModelToView() return controller - def _saveTrainingData(self) -> None: - filePath, _ = self._fileDialogFactory.getSaveFilePath( - self._view, - 'Save Training Data', - nameFilters=self._presenter.getSaveFileFilterList(), - selectedNameFilter=self._presenter.getSaveFileFilter()) - - if filePath: - try: - self._presenter.saveTrainingData(filePath) - except Exception as err: - logger.exception(err) - ExceptionDialog.showException('File writer', err) - def _syncModelToView(self) -> None: self._view.validationSetFractionalSizeSlider.setValueAndRange( self._presenter.getValidationSetFractionalSize(), diff --git a/ptychodus/controller/reconstructor.py b/ptychodus/controller/reconstructor.py index 452711b3..a2216634 100644 --- a/ptychodus/controller/reconstructor.py +++ b/ptychodus/controller/reconstructor.py @@ -14,6 +14,7 @@ from ..model.scan import ScanPresenter from ..view.reconstructor import ReconstructorParametersView, ReconstructorPlotView from ..view.widgets import ExceptionDialog +from .data import FileDialogFactory logger = logging.getLogger(__name__) @@ -32,16 +33,11 @@ def createViewController(self, reconstructorName: str) -> QWidget: class ReconstructorParametersController(Observer): - def __init__( - self, - presenter: ReconstructorPresenter, - scanPresenter: ScanPresenter, - probePresenter: ProbePresenter, - objectPresenter: ObjectPresenter, - view: ReconstructorParametersView, - plotView: ReconstructorPlotView, - viewControllerFactoryList: Iterable[ReconstructorViewControllerFactory], - ) -> None: + def __init__(self, presenter: ReconstructorPresenter, scanPresenter: ScanPresenter, + probePresenter: ProbePresenter, objectPresenter: ObjectPresenter, + view: ReconstructorParametersView, plotView: ReconstructorPlotView, + fileDialogFactory: FileDialogFactory, + viewControllerFactoryList: Iterable[ReconstructorViewControllerFactory]) -> None: super().__init__() self._presenter = presenter self._scanPresenter = scanPresenter @@ -49,6 +45,7 @@ def __init__( self._objectPresenter = objectPresenter self._view = view self._plotView = plotView + self._fileDialogFactory = fileDialogFactory self._viewControllerFactoryDict: dict[str, ReconstructorViewControllerFactory] = \ { vcf.backendName: vcf for vcf in viewControllerFactoryList } self._scanListModel = QStringListModel() @@ -57,17 +54,14 @@ def __init__( @classmethod def createInstance( - cls, - presenter: ReconstructorPresenter, - scanPresenter: ScanPresenter, - probePresenter: ProbePresenter, - objectPresenter: ObjectPresenter, - view: ReconstructorParametersView, - plotView: ReconstructorPlotView, - viewControllerFactoryList: list[ReconstructorViewControllerFactory], + cls, presenter: ReconstructorPresenter, scanPresenter: ScanPresenter, + probePresenter: ProbePresenter, objectPresenter: ObjectPresenter, + view: ReconstructorParametersView, plotView: ReconstructorPlotView, + fileDialogFactory: FileDialogFactory, + viewControllerFactoryList: list[ReconstructorViewControllerFactory] ) -> ReconstructorParametersController: controller = cls(presenter, scanPresenter, probePresenter, objectPresenter, view, plotView, - viewControllerFactoryList) + fileDialogFactory, viewControllerFactoryList) presenter.addObserver(controller) scanPresenter.addObserver(controller) probePresenter.addObserver(controller) @@ -91,9 +85,10 @@ def createInstance( view.reconstructorView.reconstructButton.clicked.connect(controller._reconstruct) view.reconstructorView.reconstructSplitButton.clicked.connect(controller._reconstructSplit) - view.reconstructorView.ingestButton.clicked.connect(controller._ingest) + view.reconstructorView.ingestButton.clicked.connect(controller._ingestTrainingData) + view.reconstructorView.saveButton.clicked.connect(controller._saveTrainingData) view.reconstructorView.trainButton.clicked.connect(controller._train) - view.reconstructorView.resetButton.clicked.connect(controller._reset) + view.reconstructorView.clearButton.clicked.connect(controller._clearTrainingData) controller._syncModelToView() controller._syncScanToView() @@ -130,13 +125,27 @@ def _reconstructSplit(self) -> None: logger.exception(err) ExceptionDialog.showException('Split Reconstructor', err) - def _ingest(self) -> None: + def _ingestTrainingData(self) -> None: try: - self._presenter.ingest() + self._presenter.ingestTrainingData() except Exception as err: logger.exception(err) ExceptionDialog.showException('Ingester', err) + def _saveTrainingData(self) -> None: + filePath, _ = self._fileDialogFactory.getSaveFilePath( + self._view, + 'Save Training Data', + nameFilters=self._presenter.getSaveFileFilterList(), + selectedNameFilter=self._presenter.getSaveFileFilter()) + + if filePath: + try: + self._presenter.saveTrainingData(filePath) + except Exception as err: + logger.exception(err) + ExceptionDialog.showException('File writer', err) + def _train(self) -> None: try: self._presenter.train() @@ -144,12 +153,12 @@ def _train(self) -> None: logger.exception(err) ExceptionDialog.showException('Trainer', err) - def _reset(self) -> None: + def _clearTrainingData(self) -> None: try: - self._presenter.reset() + self._presenter.clearTrainingData() except Exception as err: logger.exception(err) - ExceptionDialog.showException('Reset', err) + ExceptionDialog.showException('Clear', err) def _syncScanToView(self) -> None: self._view.reconstructorView.scanComboBox.blockSignals(True) @@ -221,8 +230,9 @@ def _syncModelToView(self) -> None: isTrainable = self._presenter.isTrainable self._view.reconstructorView.ingestButton.setVisible(isTrainable) + self._view.reconstructorView.saveButton.setVisible(isTrainable) self._view.reconstructorView.trainButton.setVisible(isTrainable) - self._view.reconstructorView.resetButton.setVisible(isTrainable) + self._view.reconstructorView.clearButton.setVisible(isTrainable) self._redrawPlot() diff --git a/ptychodus/controller/tike/basic.py b/ptychodus/controller/tike/basic.py index 5b7bdaf6..01df3aba 100644 --- a/ptychodus/controller/tike/basic.py +++ b/ptychodus/controller/tike/basic.py @@ -39,7 +39,6 @@ def createInstance(cls, presenter: TikePresenter, view.numIterSpinBox.valueChanged.connect(presenter.setNumIter) view.convergenceWindowSpinBox.valueChanged.connect(presenter.setConvergenceWindow) - view.cgIterSpinBox.valueChanged.connect(presenter.setCgIter) view.alphaSlider.valueChanged.connect(presenter.setAlpha) view.stepLengthSlider.valueChanged.connect(presenter.setStepLength) @@ -81,12 +80,6 @@ def _syncModelToView(self) -> None: self._view.convergenceWindowSpinBox.setValue(self._presenter.getConvergenceWindow()) self._view.convergenceWindowSpinBox.blockSignals(False) - self._view.cgIterSpinBox.blockSignals(True) - self._view.cgIterSpinBox.setRange(self._presenter.getCgIterLimits().lower, - self._presenter.getCgIterLimits().upper) - self._view.cgIterSpinBox.setValue(self._presenter.getCgIter()) - self._view.cgIterSpinBox.blockSignals(False) - self._view.alphaSlider.setValueAndRange(self._presenter.getAlpha(), self._presenter.getAlphaLimits(), blockValueChangedSignal=True) diff --git a/ptychodus/controller/tike/factory.py b/ptychodus/controller/tike/factory.py index fe3b7ac0..426ab55f 100644 --- a/ptychodus/controller/tike/factory.py +++ b/ptychodus/controller/tike/factory.py @@ -21,21 +21,13 @@ def createViewController(self, reconstructorName: str) -> QWidget: view = None if reconstructorName == 'rpie': - view = TikeParametersView.createInstance(showCgIter=False, - showAlpha=True, - showStepLength=False) + view = TikeParametersView.createInstance(showAlpha=True, showStepLength=False) elif reconstructorName == 'lstsq_grad': - view = TikeParametersView.createInstance(showCgIter=False, - showAlpha=False, - showStepLength=False) + view = TikeParametersView.createInstance(showAlpha=False, showStepLength=False) elif reconstructorName == 'dm': - view = TikeParametersView.createInstance(showCgIter=False, - showAlpha=False, - showStepLength=False) + view = TikeParametersView.createInstance(showAlpha=False, showStepLength=False) else: - view = TikeParametersView.createInstance(showCgIter=True, - showAlpha=True, - showStepLength=True) + view = TikeParametersView.createInstance(showAlpha=True, showStepLength=True) controller = TikeParametersController.createInstance(self._model, view) self._controllerList.append(controller) diff --git a/ptychodus/model/automation/workflow.py b/ptychodus/model/automation/workflow.py index bcc7fc55..35fca68b 100644 --- a/ptychodus/model/automation/workflow.py +++ b/ptychodus/model/automation/workflow.py @@ -85,5 +85,5 @@ def __init__(self, registry: StateDataRegistry, reconstructorAPI: ReconstructorA def execute(self, filePath: Path) -> None: # TODO watch for ptychodus NPZ files self._registry.openStateData(filePath) - self._reconstructorAPI.ingest() + self._reconstructorAPI.ingestTrainingData() self._reconstructorAPI.train() diff --git a/ptychodus/model/core.py b/ptychodus/model/core.py index 31f2436a..70af6b14 100644 --- a/ptychodus/model/core.py +++ b/ptychodus/model/core.py @@ -248,7 +248,7 @@ def batchModeTrain(self, directoryPath: Path) -> float: for filePath in directoryPath.glob('*.npz'): # TODO sort by filePath.stat().st_mtime self._stateDataRegistry.openStateData(filePath) - self._reconstructorCore.reconstructorAPI.ingest() + self._reconstructorCore.reconstructorAPI.ingestTrainingData() self._reconstructorCore.reconstructorAPI.train() diff --git a/ptychodus/model/ptychonn/core.py b/ptychodus/model/ptychonn/core.py index fc08a9cf..5419ea80 100644 --- a/ptychodus/model/ptychonn/core.py +++ b/ptychodus/model/ptychonn/core.py @@ -19,15 +19,15 @@ class PtychoNNModelPresenter(Observable, Observer): MAX_INT: Final[int] = 0x7FFFFFFF - def __init__(self, modelSettings: PtychoNNModelSettings) -> None: + def __init__(self, settings: PtychoNNModelSettings) -> None: super().__init__() - self._modelSettings = modelSettings + self._settings = settings self._fileFilterList: list[str] = ['PyTorch Model State Files (*.pt *.pth)'] @classmethod - def createInstance(cls, modelSettings: PtychoNNModelSettings) -> PtychoNNModelPresenter: - presenter = cls(modelSettings) - modelSettings.addObserver(presenter) + def createInstance(cls, settings: PtychoNNModelSettings) -> PtychoNNModelPresenter: + presenter = cls(settings) + settings.addObserver(presenter) return presenter def getStateFileFilterList(self) -> Sequence[str]: @@ -37,56 +37,53 @@ def getStateFileFilter(self) -> str: return self._fileFilterList[0] def getStateFilePath(self) -> Path: - return self._modelSettings.stateFilePath.value + return self._settings.stateFilePath.value def setStateFilePath(self, directory: Path) -> None: - self._modelSettings.stateFilePath.value = directory + self._settings.stateFilePath.value = directory - def getNumberOfConvolutionChannelsLimits(self) -> Interval[int]: + def getNumberOfConvolutionKernelsLimits(self) -> Interval[int]: return Interval[int](1, self.MAX_INT) - def getNumberOfConvolutionChannels(self) -> int: - limits = self.getNumberOfConvolutionChannelsLimits() - return limits.clamp(self._modelSettings.numberOfConvolutionChannels.value) + def getNumberOfConvolutionKernels(self) -> int: + limits = self.getNumberOfConvolutionKernelsLimits() + return limits.clamp(self._settings.numberOfConvolutionKernels.value) - def setNumberOfConvolutionChannels(self, value: int) -> None: - self._modelSettings.numberOfConvolutionChannels.value = value + def setNumberOfConvolutionKernels(self, value: int) -> None: + self._settings.numberOfConvolutionKernels.value = value def getBatchSizeLimits(self) -> Interval[int]: return Interval[int](1, self.MAX_INT) def getBatchSize(self) -> int: limits = self.getBatchSizeLimits() - return limits.clamp(self._modelSettings.batchSize.value) + return limits.clamp(self._settings.batchSize.value) def setBatchSize(self, value: int) -> None: - self._modelSettings.batchSize.value = value + self._settings.batchSize.value = value def isBatchNormalizationEnabled(self) -> bool: - return self._modelSettings.useBatchNormalization.value + return self._settings.useBatchNormalization.value def setBatchNormalizationEnabled(self, enabled: bool) -> None: - self._modelSettings.useBatchNormalization.value = enabled + self._settings.useBatchNormalization.value = enabled def update(self, observable: Observable) -> None: - if observable is self._modelSettings: + if observable is self._settings: self.notifyObservers() class PtychoNNTrainingPresenter(Observable, Observer): MAX_INT: Final[int] = 0x7FFFFFFF - def __init__(self, modelSettings: PtychoNNTrainingSettings, - trainer: TrainableReconstructor) -> None: + def __init__(self, settings: PtychoNNTrainingSettings) -> None: super().__init__() - self._modelSettings = modelSettings - self._trainer = trainer + self._settings = settings @classmethod - def createInstance(cls, modelSettings: PtychoNNTrainingSettings, - trainer: TrainableReconstructor) -> PtychoNNTrainingPresenter: - presenter = cls(modelSettings, trainer) - modelSettings.addObserver(presenter) + def createInstance(cls, settings: PtychoNNTrainingSettings) -> PtychoNNTrainingPresenter: + presenter = cls(settings) + settings.addObserver(presenter) return presenter def getValidationSetFractionalSizeLimits(self) -> Interval[Decimal]: @@ -94,94 +91,81 @@ def getValidationSetFractionalSizeLimits(self) -> Interval[Decimal]: def getValidationSetFractionalSize(self) -> Decimal: limits = self.getValidationSetFractionalSizeLimits() - return limits.clamp(self._modelSettings.validationSetFractionalSize.value) + return limits.clamp(self._settings.validationSetFractionalSize.value) def setValidationSetFractionalSize(self, value: Decimal) -> None: - self._modelSettings.validationSetFractionalSize.value = value + self._settings.validationSetFractionalSize.value = value def getOptimizationEpochsPerHalfCycleLimits(self) -> Interval[int]: return Interval[int](1, self.MAX_INT) def getOptimizationEpochsPerHalfCycle(self) -> int: limits = self.getOptimizationEpochsPerHalfCycleLimits() - return limits.clamp(self._modelSettings.optimizationEpochsPerHalfCycle.value) + return limits.clamp(self._settings.optimizationEpochsPerHalfCycle.value) def setOptimizationEpochsPerHalfCycle(self, value: int) -> None: - self._modelSettings.optimizationEpochsPerHalfCycle.value = value + self._settings.optimizationEpochsPerHalfCycle.value = value def getMaximumLearningRateLimits(self) -> Interval[Decimal]: return Interval[Decimal](Decimal(0), Decimal(1)) def getMaximumLearningRate(self) -> Decimal: limits = self.getMaximumLearningRateLimits() - return limits.clamp(self._modelSettings.maximumLearningRate.value) + return limits.clamp(self._settings.maximumLearningRate.value) def setMaximumLearningRate(self, value: Decimal) -> None: - self._modelSettings.maximumLearningRate.value = value + self._settings.maximumLearningRate.value = value def getMinimumLearningRateLimits(self) -> Interval[Decimal]: return Interval[Decimal](Decimal(0), Decimal(1)) def getMinimumLearningRate(self) -> Decimal: limits = self.getMinimumLearningRateLimits() - return limits.clamp(self._modelSettings.minimumLearningRate.value) + return limits.clamp(self._settings.minimumLearningRate.value) def setMinimumLearningRate(self, value: Decimal) -> None: - self._modelSettings.minimumLearningRate.value = value + self._settings.minimumLearningRate.value = value def getTrainingEpochsLimits(self) -> Interval[int]: return Interval[int](1, self.MAX_INT) def getTrainingEpochs(self) -> int: limits = self.getTrainingEpochsLimits() - return limits.clamp(self._modelSettings.trainingEpochs.value) + return limits.clamp(self._settings.trainingEpochs.value) def setTrainingEpochs(self, value: int) -> None: - self._modelSettings.trainingEpochs.value = value + self._settings.trainingEpochs.value = value def isSaveTrainingArtifactsEnabled(self) -> bool: - return self._modelSettings.saveTrainingArtifacts.value + return self._settings.saveTrainingArtifacts.value def setSaveTrainingArtifactsEnabled(self, enabled: bool) -> None: - self._modelSettings.saveTrainingArtifacts.value = enabled + self._settings.saveTrainingArtifacts.value = enabled def getOutputPath(self) -> Path: - return self._modelSettings.outputPath.value + return self._settings.outputPath.value def setOutputPath(self, directory: Path) -> None: - self._modelSettings.outputPath.value = directory + self._settings.outputPath.value = directory def getOutputSuffix(self) -> str: - return self._modelSettings.outputSuffix.value + return self._settings.outputSuffix.value def setOutputSuffix(self, suffix: str) -> None: - self._modelSettings.outputSuffix.value = suffix + self._settings.outputSuffix.value = suffix def getStatusIntervalInEpochsLimits(self) -> Interval[int]: return Interval[int](1, self.MAX_INT) def getStatusIntervalInEpochs(self) -> int: limits = self.getStatusIntervalInEpochsLimits() - return limits.clamp(self._modelSettings.statusIntervalInEpochs.value) + return limits.clamp(self._settings.statusIntervalInEpochs.value) def setStatusIntervalInEpochs(self, value: int) -> None: - self._modelSettings.statusIntervalInEpochs.value = value - - def train(self) -> None: - self._trainer.train() - - def getSaveFileFilterList(self) -> Sequence[str]: - return [self.getSaveFileFilter()] - - def getSaveFileFilter(self) -> str: - return 'NumPy Zipped Archive (*.npz)' - - def saveTrainingData(self, filePath: Path) -> None: - logger.debug(f'Writing \"{filePath}\" as \"NPZ\"') - self._trainer.saveTrainingData(filePath) + self._settings.statusIntervalInEpochs.value = value def update(self, observable: Observable) -> None: - if observable is self._modelSettings: + if observable is self._settings: self.notifyObservers() @@ -189,14 +173,12 @@ class PtychoNNReconstructorLibrary(ReconstructorLibrary): def __init__(self, modelSettings: PtychoNNModelSettings, trainingSettings: PtychoNNTrainingSettings, - phaseOnlyTrainableReconstructor: TrainableReconstructor, reconstructors: Sequence[Reconstructor]) -> None: super().__init__() self._modelSettings = modelSettings self._trainingSettings = trainingSettings self.modelPresenter = PtychoNNModelPresenter.createInstance(modelSettings) - self.trainingPresenter = PtychoNNTrainingPresenter.createInstance( - trainingSettings, phaseOnlyTrainableReconstructor) + self.trainingPresenter = PtychoNNTrainingPresenter.createInstance(trainingSettings) self._reconstructors = reconstructors @classmethod @@ -204,23 +186,31 @@ def createInstance(cls, settingsRegistry: SettingsRegistry, objectAPI: ObjectAPI isDeveloperModeEnabled: bool) -> PtychoNNReconstructorLibrary: modelSettings = PtychoNNModelSettings.createInstance(settingsRegistry) trainingSettings = PtychoNNTrainingSettings.createInstance(settingsRegistry) - phaseOnlyTrainableReconstructor: TrainableReconstructor = NullReconstructor('PhaseOnly') + phaseOnlyReconstructor: TrainableReconstructor = NullReconstructor('PhaseOnly') + amplitudePhaseReconstructor: TrainableReconstructor = NullReconstructor('AmplitudePhase') reconstructors: list[TrainableReconstructor] = list() try: - from .phaseOnly import PtychoNNPhaseOnlyTrainableReconstructor + from .reconstructor import PtychoNNTrainableReconstructor except ModuleNotFoundError: logger.info('PtychoNN not found.') if isDeveloperModeEnabled: - reconstructors.append(phaseOnlyTrainableReconstructor) + reconstructors.append(phaseOnlyReconstructor) + reconstructors.append(amplitudePhaseReconstructor) else: - phaseOnlyTrainableReconstructor = PtychoNNPhaseOnlyTrainableReconstructor( - modelSettings, trainingSettings, objectAPI) - reconstructors.append(phaseOnlyTrainableReconstructor) - - return cls(modelSettings, trainingSettings, phaseOnlyTrainableReconstructor, - reconstructors) + phaseOnlyReconstructor = PtychoNNTrainableReconstructor(modelSettings, + trainingSettings, + objectAPI, + enableAmplitude=False) + amplitudePhaseReconstructor = PtychoNNTrainableReconstructor(modelSettings, + trainingSettings, + objectAPI, + enableAmplitude=True) + reconstructors.append(phaseOnlyReconstructor) + reconstructors.append(amplitudePhaseReconstructor) + + return cls(modelSettings, trainingSettings, reconstructors) @property def name(self) -> str: diff --git a/ptychodus/model/ptychonn/phaseOnly.py b/ptychodus/model/ptychonn/reconstructor.py similarity index 65% rename from ptychodus/model/ptychonn/phaseOnly.py rename to ptychodus/model/ptychonn/reconstructor.py index f1484237..e53241e3 100644 --- a/ptychodus/model/ptychonn/phaseOnly.py +++ b/ptychodus/model/ptychonn/reconstructor.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections.abc import Sequence from importlib.metadata import version from pathlib import Path from typing import Any, Mapping, TypeAlias @@ -10,7 +11,7 @@ import numpy.typing from ...api.image import ImageExtent -from ...api.object import ObjectPatchAxis +from ...api.object import ObjectArrayType, ObjectPatchAxis from ...api.plot import Plot2D, PlotAxis, PlotSeries from ...api.reconstructor import ReconstructInput, ReconstructOutput, TrainableReconstructor from ..object import ObjectAPI @@ -21,15 +22,18 @@ logger = logging.getLogger(__name__) -class CircularBuffer: +class PatternCircularBuffer: def __init__(self, extent: ImageExtent, maxSize: int) -> None: - self._buffer: FloatArrayType = numpy.zeros((maxSize, *extent.shape), dtype=numpy.float32) + self._buffer: FloatArrayType = numpy.zeros( + (maxSize, *extent.shape), + dtype=numpy.float32, + ) self._pos = 0 self._full = False @classmethod - def createZeroSized(cls) -> CircularBuffer: + def createZeroSized(cls) -> PatternCircularBuffer: return cls(ImageExtent(0, 0), 0) @property @@ -48,28 +52,66 @@ def getBuffer(self) -> FloatArrayType: return self._buffer if self._full else self._buffer[:self._pos] -class PtychoNNPhaseOnlyTrainableReconstructor(TrainableReconstructor): +class ObjectPatchCircularBuffer: + + def __init__(self, extent: ImageExtent, channels: int, maxSize: int) -> None: + self._buffer: FloatArrayType = numpy.zeros( + (maxSize, channels, *extent.shape), + dtype=numpy.float32, + ) + self._pos = 0 + self._full = False + + @classmethod + def createZeroSized(cls) -> ObjectPatchCircularBuffer: + return cls(ImageExtent(0, 0), 0, 0) + + @property + def isZeroSized(self) -> bool: + return (self._buffer.size == 0) + + def append(self, array: ObjectArrayType) -> None: + self._buffer[self._pos, 0, :, :] = numpy.angle(array).astype(numpy.float32) + + if self._buffer.shape[1] > 1: + self._buffer[self._pos, 1, :, :] = numpy.absolute(array).astype(numpy.float32) + + self._pos += 1 + + if self._pos == self._buffer.shape[0]: + self._pos = 0 + self._full = True + + def getBuffer(self) -> FloatArrayType: + return self._buffer if self._full else self._buffer[:self._pos] + + +class PtychoNNTrainableReconstructor(TrainableReconstructor): - def __init__(self, settings: PtychoNNModelSettings, trainingSettings: PtychoNNTrainingSettings, - objectAPI: ObjectAPI) -> None: - self._settings = settings + def __init__(self, modelSettings: PtychoNNModelSettings, + trainingSettings: PtychoNNTrainingSettings, objectAPI: ObjectAPI, *, + enableAmplitude: bool) -> None: + self._modelSettings = modelSettings self._trainingSettings = trainingSettings self._objectAPI = objectAPI - self._diffractionPatternBuffer = CircularBuffer.createZeroSized() - self._objectPatchBuffer = CircularBuffer.createZeroSized() + self._patternBuffer = PatternCircularBuffer.createZeroSized() + self._objectPatchBuffer = ObjectPatchCircularBuffer.createZeroSized() + self._enableAmplitude = enableAmplitude + self._fileFilterList: list[str] = ['NumPy Zipped Archive (*.npz)'] ptychonnVersion = version('ptychonn') logger.info(f'\tPtychoNN {ptychonnVersion}') @property def name(self) -> str: - return 'PhaseOnly' + return 'AmplitudePhase' if self._enableAmplitude else 'PhaseOnly' def _createModel(self) -> ReconSmallModel: logger.debug('Building model...') return ReconSmallModel( - nconv=self._settings.numberOfConvolutionChannels.value, - use_batch_norm=self._settings.useBatchNormalization.value, + nconv=self._modelSettings.numberOfConvolutionKernels.value, + use_batch_norm=self._modelSettings.useBatchNormalization.value, + enable_amplitude=self._enableAmplitude, ) def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: @@ -103,12 +145,12 @@ def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: logger.debug('Loading model state...') tester = Tester( model=self._createModel(), - model_params_path=self._settings.stateFilePath.value, + model_params_path=self._modelSettings.stateFilePath.value, ) logger.debug('Inferring...') tester.setTestData(binnedData.astype(numpy.float32), - batch_size=self._settings.batchSize.value) + batch_size=self._modelSettings.batchSize.value) npzSavePath = None # TODO self._trainingSettings.outputPath.value / 'preds.npz' objectPatches = tester.predictTestData(npz_save_path=npzSavePath) @@ -124,8 +166,13 @@ def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: height=objectPatches.shape[-2], ) - for scanPoint, objectPatchReals in zip(parameters.scan.values(), objectPatches): - objectPatch = 0.5 * numpy.exp(1j * objectPatchReals[0]) + for scanPoint, objectPatchChannels in zip(parameters.scan.values(), objectPatches): + objectPatch = numpy.exp(1j * objectPatchChannels[0]) + + if objectPatchChannels.shape[0] == 2: + objectPatch *= objectPatchChannels[1] + else: + objectPatch *= 0.5 patchAxisX = ObjectPatchAxis(objectGrid.axisX, scanPoint.x, patchExtent.width) patchAxisY = ObjectPatchAxis(objectGrid.axisY, scanPoint.y, patchExtent.height) @@ -152,23 +199,24 @@ def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput: result=0, ) - def ingest(self, parameters: ReconstructInput) -> None: + def ingestTrainingData(self, parameters: ReconstructInput) -> None: objectInterpolator = parameters.objectInterpolator - if self._diffractionPatternBuffer.isZeroSized: + if self._patternBuffer.isZeroSized: diffractionPatternExtent = parameters.diffractionPatternExtent maximumSize = max(1, self._trainingSettings.maximumTrainingDatasetSize.value) - self._diffractionPatternBuffer = CircularBuffer(diffractionPatternExtent, maximumSize) - self._objectPatchBuffer = CircularBuffer(diffractionPatternExtent, maximumSize) + channels = 2 if self._enableAmplitude else 1 + self._patternBuffer = PatternCircularBuffer(diffractionPatternExtent, maximumSize) + self._objectPatchBuffer = ObjectPatchCircularBuffer(diffractionPatternExtent, channels, + maximumSize) for scanIndex, scanPoint in parameters.scan.items(): objectPatch = objectInterpolator.getPatch(scanPoint, parameters.probeExtent) - objectPhasePatch = numpy.angle(objectPatch.array).astype(numpy.float32) - self._objectPatchBuffer.append(objectPhasePatch) + self._objectPatchBuffer.append(objectPatch.array) for pattern in parameters.diffractionPatternArray.astype(numpy.float32): - self._diffractionPatternBuffer.append(pattern) + self._patternBuffer.append(pattern) def _plotMetrics(self, metrics: Mapping[str, Any]) -> Plot2D: trainingLoss = [losses[0] for losses in metrics['losses']] @@ -182,23 +230,34 @@ def _plotMetrics(self, metrics: Mapping[str, Any]) -> Plot2D: axisY=PlotAxis(label='Loss', series=[trainingLossSeries, validationLossSeries]), ) + def getSaveFileFilterList(self) -> Sequence[str]: + return self._fileFilterList + + def getSaveFileFilter(self) -> str: + return self._fileFilterList[0] + + def saveTrainingData(self, filePath: Path) -> None: + logger.debug(f'Writing \"{filePath}\" as \"NPZ\"') + trainingData = { + 'diffractionPatterns': self._patternBuffer.getBuffer(), + 'objectPatches': self._objectPatchBuffer.getBuffer(), + } + numpy.savez(filePath, **trainingData) + def train(self) -> Plot2D: outputPath = self._trainingSettings.outputPath.value \ if self._trainingSettings.saveTrainingArtifacts.value else None trainer = Trainer( model=self._createModel(), - batch_size=self._settings.batchSize.value, + batch_size=self._modelSettings.batchSize.value, output_path=outputPath, output_suffix=self._trainingSettings.outputSuffix.value, ) - X_train_full = self._diffractionPatternBuffer.getBuffer() - Y_ph_train_full = numpy.expand_dims(self._objectPatchBuffer.getBuffer(), 1) - trainer.setTrainingData( - X_train_full=X_train_full, - Y_ph_train_full=Y_ph_train_full, + X_train_full=self._patternBuffer.getBuffer(), + Y_ph_train_full=self._objectPatchBuffer.getBuffer(), valid_data_ratio=float(self._trainingSettings.validationSetFractionalSize.value), ) trainer.setOptimizationParams( @@ -218,13 +277,6 @@ def train(self) -> Plot2D: return self._plotMetrics(trainer.metrics) - def reset(self) -> None: - self._diffractionPatternBuffer = CircularBuffer.createZeroSized() - self._objectPatchBuffer = CircularBuffer.createZeroSized() - - def saveTrainingData(self, filePath: Path) -> None: - trainingData = { - 'diffractionPatterns': self._diffractionPatternBuffer.getBuffer(), - 'objectPatches': self._objectPatchBuffer.getBuffer(), - } - numpy.savez(filePath, **trainingData) + def clearTrainingData(self) -> None: + self._patternBuffer = PatternCircularBuffer.createZeroSized() + self._objectPatchBuffer = ObjectPatchCircularBuffer.createZeroSized() diff --git a/ptychodus/model/ptychonn/settings.py b/ptychodus/model/ptychonn/settings.py index edc701f1..a1fe5ebd 100644 --- a/ptychodus/model/ptychonn/settings.py +++ b/ptychodus/model/ptychonn/settings.py @@ -12,16 +12,17 @@ def __init__(self, settingsGroup: SettingsGroup) -> None: self._settingsGroup = settingsGroup self.stateFilePath = settingsGroup.createPathEntry('StateFilePath', Path('/path/to/best_model.pth')) - self.numberOfConvolutionChannels = settingsGroup.createIntegerEntry( - 'NumberOfConvolutionChannels', 16) + self.numberOfConvolutionKernels = settingsGroup.createIntegerEntry( + 'NumberOfConvolutionKernels', 16) self.batchSize = settingsGroup.createIntegerEntry('BatchSize', 64) self.useBatchNormalization = settingsGroup.createBooleanEntry( 'UseBatchNormalization', False) @classmethod def createInstance(cls, settingsRegistry: SettingsRegistry) -> PtychoNNModelSettings: - settings = cls(settingsRegistry.createGroup('PtychoNN')) - settings._settingsGroup.addObserver(settings) + settingsGroup = settingsRegistry.createGroup('PtychoNN') + settings = cls(settingsGroup) + settingsGroup.addObserver(settings) return settings def update(self, observable: Observable) -> None: @@ -51,8 +52,9 @@ def __init__(self, settingsGroup: SettingsGroup) -> None: @classmethod def createInstance(cls, settingsRegistry: SettingsRegistry) -> PtychoNNTrainingSettings: - settings = cls(settingsRegistry.createGroup('PtychoNNTraining')) - settings._settingsGroup.addObserver(settings) + settingsGroup = settingsRegistry.createGroup('PtychoNNTraining') + settings = cls(settingsGroup) + settingsGroup.addObserver(settings) return settings def update(self, observable: Observable) -> None: diff --git a/ptychodus/model/reconstructor/active.py b/ptychodus/model/reconstructor/active.py index 621a106e..0d9dbdd4 100644 --- a/ptychodus/model/reconstructor/active.py +++ b/ptychodus/model/reconstructor/active.py @@ -1,5 +1,6 @@ from __future__ import annotations from collections.abc import Iterable, Sequence +from pathlib import Path import logging import time @@ -123,7 +124,7 @@ def isTrainable(self) -> bool: reconstructor = self._pluginChooser.currentPlugin.strategy return isinstance(reconstructor, TrainableReconstructor) - def ingest(self) -> None: + def ingestTrainingData(self) -> None: reconstructor = self._pluginChooser.currentPlugin.strategy if isinstance(reconstructor, TrainableReconstructor): @@ -135,11 +136,43 @@ def ingest(self) -> None: logger.info('Ingesting...') tic = time.perf_counter() - reconstructor.ingest(parameters) + reconstructor.ingestTrainingData(parameters) toc = time.perf_counter() logger.info(f'Ingest time {toc - tic:.4f} seconds.') else: - logger.error('Reconstructor is not trainable!') + logger.warning('Reconstructor is not trainable!') + + def getSaveFileFilterList(self) -> Sequence[str]: + reconstructor = self._pluginChooser.currentPlugin.strategy + + if isinstance(reconstructor, TrainableReconstructor): + return reconstructor.getSaveFileFilterList() + else: + logger.warning('Reconstructor is not trainable!') + + return list() + + def getSaveFileFilter(self) -> str: + reconstructor = self._pluginChooser.currentPlugin.strategy + + if isinstance(reconstructor, TrainableReconstructor): + return reconstructor.getSaveFileFilter() + else: + logger.warning('Reconstructor is not trainable!') + + return str() + + def saveTrainingData(self, filePath: Path) -> None: + reconstructor = self._pluginChooser.currentPlugin.strategy + + if isinstance(reconstructor, TrainableReconstructor): + logger.info('Saving...') + tic = time.perf_counter() + reconstructor.saveTrainingData(filePath) + toc = time.perf_counter() + logger.info(f'Save time {toc - tic:.4f} seconds.') + else: + logger.warning('Reconstructor is not trainable!') def train(self) -> Plot2D: reconstructor = self._pluginChooser.currentPlugin.strategy @@ -152,21 +185,21 @@ def train(self) -> Plot2D: toc = time.perf_counter() logger.info(f'Training time {toc - tic:.4f} seconds.') else: - logger.error('Reconstructor is not trainable!') + logger.warning('Reconstructor is not trainable!') return plot2D - def reset(self) -> None: + def clearTrainingData(self) -> None: reconstructor = self._pluginChooser.currentPlugin.strategy if isinstance(reconstructor, TrainableReconstructor): logger.info('Resetting...') tic = time.perf_counter() - reconstructor.reset() + reconstructor.clearTrainingData() toc = time.perf_counter() logger.info(f'Reset time {toc - tic:.4f} seconds.') else: - logger.error('Reconstructor is not trainable!') + logger.warning('Reconstructor is not trainable!') def selectReconstructor(self, name: str) -> None: self._pluginChooser.setCurrentPluginByName(name) diff --git a/ptychodus/model/reconstructor/api.py b/ptychodus/model/reconstructor/api.py index ef0fe1f2..ce813ed5 100644 --- a/ptychodus/model/reconstructor/api.py +++ b/ptychodus/model/reconstructor/api.py @@ -1,3 +1,4 @@ +from pathlib import Path import logging from ...api.plot import Plot2D @@ -38,11 +39,14 @@ def reconstructSplit(self) -> tuple[ReconstructOutput, ReconstructOutput]: return resultOdd, resultEven - def ingest(self) -> None: - self._activeReconstructor.ingest() + def ingestTrainingData(self) -> None: + self._activeReconstructor.ingestTrainingData() + + def saveTrainingData(self, filePath: Path) -> None: + self._activeReconstructor.saveTrainingData(filePath) def train(self) -> Plot2D: return self._activeReconstructor.train() - def reset(self) -> None: - self._activeReconstructor.reset() + def clearTrainingData(self) -> None: + self._activeReconstructor.clearTrainingData() diff --git a/ptychodus/model/reconstructor/core.py b/ptychodus/model/reconstructor/core.py index 4aac0912..f94ef300 100644 --- a/ptychodus/model/reconstructor/core.py +++ b/ptychodus/model/reconstructor/core.py @@ -1,5 +1,6 @@ from __future__ import annotations from collections.abc import Sequence +from pathlib import Path import logging from ...api.observer import Observable, Observer @@ -72,15 +73,24 @@ def getPlot(self) -> Plot2D: def isTrainable(self) -> bool: return self._activeReconstructor.isTrainable - def ingest(self) -> None: - self._reconstructorAPI.ingest() + def ingestTrainingData(self) -> None: + self._reconstructorAPI.ingestTrainingData() + + def getSaveFileFilterList(self) -> Sequence[str]: + return self._activeReconstructor.getSaveFileFilterList() + + def getSaveFileFilter(self) -> str: + return self._activeReconstructor.getSaveFileFilter() + + def saveTrainingData(self, filePath: Path) -> None: + self._reconstructorAPI.saveTrainingData(filePath) def train(self) -> None: self._plot2D = self._reconstructorAPI.train() self.notifyObservers() - def reset(self) -> None: - self._reconstructorAPI.reset() + def clearTrainingData(self) -> None: + self._reconstructorAPI.clearTrainingData() def update(self, observable: Observable) -> None: if observable is self._activeReconstructor: diff --git a/ptychodus/model/tike/core.py b/ptychodus/model/tike/core.py index a350134e..2f3beeb3 100644 --- a/ptychodus/model/tike/core.py +++ b/ptychodus/model/tike/core.py @@ -85,16 +85,6 @@ def getConvergenceWindow(self) -> int: def setConvergenceWindow(self, value: int) -> None: self._settings.convergenceWindow.value = value - def getCgIterLimits(self) -> Interval[int]: - return Interval[int](1, 64) - - def getCgIter(self) -> int: - limits = self.getCgIterLimits() - return limits.clamp(self._settings.cgIter.value) - - def setCgIter(self, value: int) -> None: - self._settings.cgIter.value = value - def getAlphaLimits(self) -> Interval[Decimal]: return Interval[Decimal](Decimal(0), Decimal(1)) diff --git a/ptychodus/model/tike/settings.py b/ptychodus/model/tike/settings.py index 5adf35a3..7c082f83 100644 --- a/ptychodus/model/tike/settings.py +++ b/ptychodus/model/tike/settings.py @@ -15,7 +15,6 @@ def __init__(self, settingsGroup: SettingsGroup) -> None: self.batchMethod = settingsGroup.createStringEntry('BatchMethod', 'wobbly_center') self.numIter = settingsGroup.createIntegerEntry('NumIter', 1) self.convergenceWindow = settingsGroup.createIntegerEntry('ConvergenceWindow', 0) - self.cgIter = settingsGroup.createIntegerEntry('CgIter', 2) self.alpha = settingsGroup.createRealEntry('Alpha', '0.05') self.stepLength = settingsGroup.createRealEntry('StepLength', '1') diff --git a/ptychodus/view/ptychonn.py b/ptychodus/view/ptychonn.py index f19d2c7f..ace52bfc 100644 --- a/ptychodus/view/ptychonn.py +++ b/ptychodus/view/ptychonn.py @@ -15,8 +15,8 @@ def __init__(self, parent: Optional[QWidget]) -> None: self.modelStateLabel = QLabel('Model State:') self.modelStateLineEdit = QLineEdit() self.modelStateBrowseButton = QPushButton('Browse') - self.numberOfConvolutionChannelsLabel = QLabel('Convolution Channels:') - self.numberOfConvolutionChannelsSpinBox = QSpinBox() + self.numberOfConvolutionKernelsLabel = QLabel('Convolution Kernels:') + self.numberOfConvolutionKernelsSpinBox = QSpinBox() self.batchSizeLabel = QLabel('Batch Size:') self.batchSizeSpinBox = QSpinBox() self.useBatchNormalizationCheckBox = QCheckBox('Use Batch Normalization') @@ -29,8 +29,8 @@ def createInstance(cls, parent: Optional[QWidget] = None) -> PtychoNNModelParame layout.addWidget(view.modelStateLabel, 0, 0) layout.addWidget(view.modelStateLineEdit, 0, 1) layout.addWidget(view.modelStateBrowseButton, 0, 2) - layout.addWidget(view.numberOfConvolutionChannelsLabel, 1, 0) - layout.addWidget(view.numberOfConvolutionChannelsSpinBox, 1, 1, 1, 2) + layout.addWidget(view.numberOfConvolutionKernelsLabel, 1, 0) + layout.addWidget(view.numberOfConvolutionKernelsSpinBox, 1, 1, 1, 2) layout.addWidget(view.batchSizeLabel, 2, 0) layout.addWidget(view.batchSizeSpinBox, 2, 1, 1, 2) layout.addWidget(view.useBatchNormalizationCheckBox, 3, 0, 1, 3) @@ -79,7 +79,6 @@ def __init__(self, parent: Optional[QWidget]) -> None: self.trainingEpochsSpinBox = QSpinBox() self.statusIntervalSpinBox = QSpinBox() self.outputParametersView = PtychoNNOutputParametersView.createInstance() - self.saveTrainingDataButton = QPushButton("Save Training Data") @classmethod def createInstance(cls, parent: Optional[QWidget] = None) -> PtychoNNTrainingParametersView: @@ -94,7 +93,6 @@ def createInstance(cls, parent: Optional[QWidget] = None) -> PtychoNNTrainingPar layout.addRow('Training Epochs:', view.trainingEpochsSpinBox) layout.addRow('Status Interval:', view.statusIntervalSpinBox) layout.addRow(view.outputParametersView) - layout.addRow(view.saveTrainingDataButton) view.setLayout(layout) return view diff --git a/ptychodus/view/reconstructor.py b/ptychodus/view/reconstructor.py index eef37ab5..bf11f2e6 100644 --- a/ptychodus/view/reconstructor.py +++ b/ptychodus/view/reconstructor.py @@ -24,36 +24,42 @@ def __init__(self, parent: Optional[QWidget]) -> None: self.objectLabel = QLabel('Object:') self.objectComboBox = QComboBox() self.objectValidationLabel = QLabel() - self.reconstructButton = QPushButton('Reconstruct') - self.reconstructSplitButton = QPushButton('Split') self.ingestButton = QPushButton('Ingest') + self.saveButton = QPushButton('Save') self.trainButton = QPushButton('Train') - self.resetButton = QPushButton('Reset') + self.clearButton = QPushButton('Clear') + self.reconstructButton = QPushButton('Reconstruct') + self.reconstructSplitButton = QPushButton('Split') @classmethod def createInstance(cls, parent: Optional[QWidget] = None) -> ReconstructorView: view = cls(parent) + view.ingestButton.setToolTip('Ingest Training Data') + view.saveButton.setToolTip('Save Training Data') + view.trainButton.setToolTip('Train Model') + view.clearButton.setToolTip('Reset Training Data Buffers') view.reconstructButton.setToolTip('Reconstruct Full Dataset') view.reconstructSplitButton.setToolTip('Reconstruct Odd/Even Split Dataset') layout = QGridLayout() layout.addWidget(view.algorithmLabel, 0, 0) - layout.addWidget(view.algorithmComboBox, 0, 1, 1, 3) + layout.addWidget(view.algorithmComboBox, 0, 1, 1, 4) layout.addWidget(view.scanLabel, 1, 0) - layout.addWidget(view.scanComboBox, 1, 1, 1, 3) - layout.addWidget(view.scanValidationLabel, 1, 4) + layout.addWidget(view.scanComboBox, 1, 1, 1, 4) + layout.addWidget(view.scanValidationLabel, 1, 5) layout.addWidget(view.probeLabel, 2, 0) - layout.addWidget(view.probeComboBox, 2, 1, 1, 3) - layout.addWidget(view.probeValidationLabel, 2, 4) + layout.addWidget(view.probeComboBox, 2, 1, 1, 4) + layout.addWidget(view.probeValidationLabel, 2, 5) layout.addWidget(view.objectLabel, 3, 0) - layout.addWidget(view.objectComboBox, 3, 1, 1, 3) - layout.addWidget(view.objectValidationLabel, 3, 4) + layout.addWidget(view.objectComboBox, 3, 1, 1, 4) + layout.addWidget(view.objectValidationLabel, 3, 5) layout.addWidget(view.ingestButton, 4, 1) - layout.addWidget(view.trainButton, 4, 2) - layout.addWidget(view.resetButton, 4, 3) - layout.addWidget(view.reconstructButton, 5, 1, 1, 2) - layout.addWidget(view.reconstructSplitButton, 5, 3) + layout.addWidget(view.saveButton, 4, 2) + layout.addWidget(view.trainButton, 4, 3) + layout.addWidget(view.clearButton, 4, 4) + layout.addWidget(view.reconstructButton, 5, 1, 1, 3) + layout.addWidget(view.reconstructSplitButton, 5, 4) layout.setColumnStretch(1, 1) layout.setColumnStretch(2, 1) layout.setColumnStretch(3, 1) diff --git a/ptychodus/view/tike.py b/ptychodus/view/tike.py index 8b3b9355..2d864fe1 100644 --- a/ptychodus/view/tike.py +++ b/ptychodus/view/tike.py @@ -18,14 +18,12 @@ def __init__(self, parent: Optional[QWidget]) -> None: self.batchMethodComboBox = QComboBox() self.numIterSpinBox = QSpinBox() self.convergenceWindowSpinBox = QSpinBox() - self.cgIterSpinBox = QSpinBox() self.alphaSlider = DecimalSlider.createInstance(Qt.Horizontal) self.stepLengthSlider = DecimalSlider.createInstance(Qt.Horizontal) self.logLevelComboBox = QComboBox() @classmethod def createInstance(cls, - showCgIter: bool, showAlpha: bool, showStepLength: bool, parent: Optional[QWidget] = None) -> TikeBasicParametersView: @@ -42,8 +40,6 @@ def createInstance(cls, view.convergenceWindowSpinBox.setToolTip( 'The number of epochs to consider for convergence monitoring. ' 'Set to any value less than 2 to disable.') - view.cgIterSpinBox.setToolTip( - 'The number of conjugate directions to search for each update.') view.alphaSlider.setToolTip('RPIE becomes EPIE when this parameter is 1.') view.stepLengthSlider.setToolTip( 'Scales the inital search directions before the line search.') @@ -56,9 +52,6 @@ def createInstance(cls, layout.addRow('Number of Iterations:', view.numIterSpinBox) layout.addRow('Convergence Window:', view.convergenceWindowSpinBox) - if showCgIter: - layout.addRow('CG Search Directions:', view.cgIterSpinBox) - if showAlpha: layout.addRow('Alpha:', view.alphaSlider) @@ -225,11 +218,10 @@ def createInstance(cls, parent: Optional[QWidget] = None) -> TikeObjectCorrectio class TikeParametersView(QWidget): - def __init__(self, showCgIter: bool, showAlpha: bool, showStepLength: bool, - parent: Optional[QWidget]) -> None: + def __init__(self, showAlpha: bool, showStepLength: bool, parent: Optional[QWidget]) -> None: super().__init__(parent) self.basicParametersView = TikeBasicParametersView.createInstance( - showCgIter, showAlpha, showStepLength) + showAlpha, showStepLength) self.multigridView = TikeMultigridView.createInstance() self.positionCorrectionView = TikePositionCorrectionView.createInstance() self.probeCorrectionView = TikeProbeCorrectionView.createInstance() @@ -237,11 +229,10 @@ def __init__(self, showCgIter: bool, showAlpha: bool, showStepLength: bool, @classmethod def createInstance(cls, - showCgIter: bool, showAlpha: bool, showStepLength: bool, parent: Optional[QWidget] = None) -> TikeParametersView: - view = cls(showCgIter, showAlpha, showStepLength, parent) + view = cls(showAlpha, showStepLength, parent) layout = QVBoxLayout() layout.addWidget(view.basicParametersView)