Skip to content

Commit

Permalink
Support PtychoNN amplitude+phase model (#67)
Browse files Browse the repository at this point in the history
- Add support for PtychoNN amplitude+phase model
  • Loading branch information
stevehenke authored Dec 18, 2023
1 parent 6e51417 commit afeee0e
Show file tree
Hide file tree
Showing 20 changed files with 324 additions and 256 deletions.
30 changes: 22 additions & 8 deletions ptychodus/api/reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,29 @@ def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput:
class TrainableReconstructor(Reconstructor):

@abstractmethod
def ingest(self, parameters: ReconstructInput) -> None:
def ingestTrainingData(self, parameters: ReconstructInput) -> None:
pass

@abstractmethod
def train(self) -> Plot2D:
def getSaveFileFilterList(self) -> Sequence[str]:
pass

@abstractmethod
def reset(self) -> None:
def getSaveFileFilter(self) -> str:
pass

@abstractmethod
def saveTrainingData(self, filePath: Path) -> None:
pass

@abstractmethod
def train(self) -> Plot2D:
pass

@abstractmethod
def clearTrainingData(self) -> None:
pass


class NullReconstructor(TrainableReconstructor):

Expand All @@ -109,18 +117,24 @@ def reconstruct(self, parameters: ReconstructInput) -> ReconstructOutput:
result=0,
)

def ingest(self, parameters: ReconstructInput) -> None:
def ingestTrainingData(self, parameters: ReconstructInput) -> None:
pass

def train(self) -> Plot2D:
return Plot2D.createNull()
def getSaveFileFilterList(self) -> Sequence[str]:
return list()

def reset(self) -> None:
pass
def getSaveFileFilter(self) -> str:
return str()

def saveTrainingData(self, filePath: Path) -> None:
pass

def train(self) -> Plot2D:
return Plot2D.createNull()

def clearTrainingData(self) -> None:
pass


class ReconstructorLibrary(Iterable[Reconstructor]):

Expand Down
1 change: 1 addition & 0 deletions ptychodus/controller/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(self, model: ModelCore, view: ViewCore) -> None:
model.objectPresenter,
view.reconstructorParametersView,
view.reconstructorPlotView,
self._fileDialogFactory,
[
self._ptychopyViewControllerFactory, self._ptychonnViewControllerFactory,
self._tikeViewControllerFactory
Expand Down
18 changes: 9 additions & 9 deletions ptychodus/controller/ptychonn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def createInstance(cls, presenter: PtychoNNModelPresenter, view: PtychoNNModelPa

view.modelStateLineEdit.editingFinished.connect(controller._syncModelStateFilePathToModel)
view.modelStateBrowseButton.clicked.connect(controller._openModelState)
view.numberOfConvolutionChannelsSpinBox.valueChanged.connect(
presenter.setNumberOfConvolutionChannels)
view.numberOfConvolutionKernelsSpinBox.valueChanged.connect(
presenter.setNumberOfConvolutionKernels)
view.batchSizeSpinBox.valueChanged.connect(presenter.setBatchSize)
view.useBatchNormalizationCheckBox.toggled.connect(presenter.setBatchNormalizationEnabled)

Expand Down Expand Up @@ -54,13 +54,13 @@ def _syncModelToView(self) -> None:
else:
self._view.modelStateLineEdit.clear()

self._view.numberOfConvolutionChannelsSpinBox.blockSignals(True)
self._view.numberOfConvolutionChannelsSpinBox.setRange(
self._presenter.getNumberOfConvolutionChannelsLimits().lower,
self._presenter.getNumberOfConvolutionChannelsLimits().upper)
self._view.numberOfConvolutionChannelsSpinBox.setValue(
self._presenter.getNumberOfConvolutionChannels())
self._view.numberOfConvolutionChannelsSpinBox.blockSignals(False)
self._view.numberOfConvolutionKernelsSpinBox.blockSignals(True)
self._view.numberOfConvolutionKernelsSpinBox.setRange(
self._presenter.getNumberOfConvolutionKernelsLimits().lower,
self._presenter.getNumberOfConvolutionKernelsLimits().upper)
self._view.numberOfConvolutionKernelsSpinBox.setValue(
self._presenter.getNumberOfConvolutionKernels())
self._view.numberOfConvolutionKernelsSpinBox.blockSignals(False)

self._view.batchSizeSpinBox.blockSignals(True)
self._view.batchSizeSpinBox.setRange(self._presenter.getBatchSizeLimits().lower,
Expand Down
17 changes: 0 additions & 17 deletions ptychodus/controller/ptychonn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from ...api.observer import Observable, Observer
from ...model.ptychonn import PtychoNNTrainingPresenter
from ...view.ptychonn import PtychoNNOutputParametersView, PtychoNNTrainingParametersView
from ...view.widgets import ExceptionDialog
from ..data import FileDialogFactory

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -80,7 +79,6 @@ def __init__(self, presenter: PtychoNNTrainingPresenter, view: PtychoNNTrainingP
super().__init__()
self._presenter = presenter
self._view = view
self._fileDialogFactory = fileDialogFactory
self._outputParametersController = PtychoNNOutputParametersController.createInstance(
presenter, view.outputParametersView, fileDialogFactory)

Expand All @@ -99,26 +97,11 @@ def createInstance(
view.minimumLearningRateLineEdit.valueChanged.connect(presenter.setMinimumLearningRate)
view.trainingEpochsSpinBox.valueChanged.connect(presenter.setTrainingEpochs)
view.statusIntervalSpinBox.valueChanged.connect(presenter.setStatusIntervalInEpochs)
view.saveTrainingDataButton.clicked.connect(controller._saveTrainingData)

controller._syncModelToView()

return controller

def _saveTrainingData(self) -> None:
filePath, _ = self._fileDialogFactory.getSaveFilePath(
self._view,
'Save Training Data',
nameFilters=self._presenter.getSaveFileFilterList(),
selectedNameFilter=self._presenter.getSaveFileFilter())

if filePath:
try:
self._presenter.saveTrainingData(filePath)
except Exception as err:
logger.exception(err)
ExceptionDialog.showException('File writer', err)

def _syncModelToView(self) -> None:
self._view.validationSetFractionalSizeSlider.setValueAndRange(
self._presenter.getValidationSetFractionalSize(),
Expand Down
64 changes: 37 additions & 27 deletions ptychodus/controller/reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..model.scan import ScanPresenter
from ..view.reconstructor import ReconstructorParametersView, ReconstructorPlotView
from ..view.widgets import ExceptionDialog
from .data import FileDialogFactory

logger = logging.getLogger(__name__)

Expand All @@ -32,23 +33,19 @@ def createViewController(self, reconstructorName: str) -> QWidget:

class ReconstructorParametersController(Observer):

def __init__(
self,
presenter: ReconstructorPresenter,
scanPresenter: ScanPresenter,
probePresenter: ProbePresenter,
objectPresenter: ObjectPresenter,
view: ReconstructorParametersView,
plotView: ReconstructorPlotView,
viewControllerFactoryList: Iterable[ReconstructorViewControllerFactory],
) -> None:
def __init__(self, presenter: ReconstructorPresenter, scanPresenter: ScanPresenter,
probePresenter: ProbePresenter, objectPresenter: ObjectPresenter,
view: ReconstructorParametersView, plotView: ReconstructorPlotView,
fileDialogFactory: FileDialogFactory,
viewControllerFactoryList: Iterable[ReconstructorViewControllerFactory]) -> None:
super().__init__()
self._presenter = presenter
self._scanPresenter = scanPresenter
self._probePresenter = probePresenter
self._objectPresenter = objectPresenter
self._view = view
self._plotView = plotView
self._fileDialogFactory = fileDialogFactory
self._viewControllerFactoryDict: dict[str, ReconstructorViewControllerFactory] = \
{ vcf.backendName: vcf for vcf in viewControllerFactoryList }
self._scanListModel = QStringListModel()
Expand All @@ -57,17 +54,14 @@ def __init__(

@classmethod
def createInstance(
cls,
presenter: ReconstructorPresenter,
scanPresenter: ScanPresenter,
probePresenter: ProbePresenter,
objectPresenter: ObjectPresenter,
view: ReconstructorParametersView,
plotView: ReconstructorPlotView,
viewControllerFactoryList: list[ReconstructorViewControllerFactory],
cls, presenter: ReconstructorPresenter, scanPresenter: ScanPresenter,
probePresenter: ProbePresenter, objectPresenter: ObjectPresenter,
view: ReconstructorParametersView, plotView: ReconstructorPlotView,
fileDialogFactory: FileDialogFactory,
viewControllerFactoryList: list[ReconstructorViewControllerFactory]
) -> ReconstructorParametersController:
controller = cls(presenter, scanPresenter, probePresenter, objectPresenter, view, plotView,
viewControllerFactoryList)
fileDialogFactory, viewControllerFactoryList)
presenter.addObserver(controller)
scanPresenter.addObserver(controller)
probePresenter.addObserver(controller)
Expand All @@ -91,9 +85,10 @@ def createInstance(

view.reconstructorView.reconstructButton.clicked.connect(controller._reconstruct)
view.reconstructorView.reconstructSplitButton.clicked.connect(controller._reconstructSplit)
view.reconstructorView.ingestButton.clicked.connect(controller._ingest)
view.reconstructorView.ingestButton.clicked.connect(controller._ingestTrainingData)
view.reconstructorView.saveButton.clicked.connect(controller._saveTrainingData)
view.reconstructorView.trainButton.clicked.connect(controller._train)
view.reconstructorView.resetButton.clicked.connect(controller._reset)
view.reconstructorView.clearButton.clicked.connect(controller._clearTrainingData)

controller._syncModelToView()
controller._syncScanToView()
Expand Down Expand Up @@ -130,26 +125,40 @@ def _reconstructSplit(self) -> None:
logger.exception(err)
ExceptionDialog.showException('Split Reconstructor', err)

def _ingest(self) -> None:
def _ingestTrainingData(self) -> None:
try:
self._presenter.ingest()
self._presenter.ingestTrainingData()
except Exception as err:
logger.exception(err)
ExceptionDialog.showException('Ingester', err)

def _saveTrainingData(self) -> None:
filePath, _ = self._fileDialogFactory.getSaveFilePath(
self._view,
'Save Training Data',
nameFilters=self._presenter.getSaveFileFilterList(),
selectedNameFilter=self._presenter.getSaveFileFilter())

if filePath:
try:
self._presenter.saveTrainingData(filePath)
except Exception as err:
logger.exception(err)
ExceptionDialog.showException('File writer', err)

def _train(self) -> None:
try:
self._presenter.train()
except Exception as err:
logger.exception(err)
ExceptionDialog.showException('Trainer', err)

def _reset(self) -> None:
def _clearTrainingData(self) -> None:
try:
self._presenter.reset()
self._presenter.clearTrainingData()
except Exception as err:
logger.exception(err)
ExceptionDialog.showException('Reset', err)
ExceptionDialog.showException('Clear', err)

def _syncScanToView(self) -> None:
self._view.reconstructorView.scanComboBox.blockSignals(True)
Expand Down Expand Up @@ -221,8 +230,9 @@ def _syncModelToView(self) -> None:

isTrainable = self._presenter.isTrainable
self._view.reconstructorView.ingestButton.setVisible(isTrainable)
self._view.reconstructorView.saveButton.setVisible(isTrainable)
self._view.reconstructorView.trainButton.setVisible(isTrainable)
self._view.reconstructorView.resetButton.setVisible(isTrainable)
self._view.reconstructorView.clearButton.setVisible(isTrainable)

self._redrawPlot()

Expand Down
7 changes: 0 additions & 7 deletions ptychodus/controller/tike/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def createInstance(cls, presenter: TikePresenter,

view.numIterSpinBox.valueChanged.connect(presenter.setNumIter)
view.convergenceWindowSpinBox.valueChanged.connect(presenter.setConvergenceWindow)
view.cgIterSpinBox.valueChanged.connect(presenter.setCgIter)

view.alphaSlider.valueChanged.connect(presenter.setAlpha)
view.stepLengthSlider.valueChanged.connect(presenter.setStepLength)
Expand Down Expand Up @@ -81,12 +80,6 @@ def _syncModelToView(self) -> None:
self._view.convergenceWindowSpinBox.setValue(self._presenter.getConvergenceWindow())
self._view.convergenceWindowSpinBox.blockSignals(False)

self._view.cgIterSpinBox.blockSignals(True)
self._view.cgIterSpinBox.setRange(self._presenter.getCgIterLimits().lower,
self._presenter.getCgIterLimits().upper)
self._view.cgIterSpinBox.setValue(self._presenter.getCgIter())
self._view.cgIterSpinBox.blockSignals(False)

self._view.alphaSlider.setValueAndRange(self._presenter.getAlpha(),
self._presenter.getAlphaLimits(),
blockValueChangedSignal=True)
Expand Down
16 changes: 4 additions & 12 deletions ptychodus/controller/tike/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,13 @@ def createViewController(self, reconstructorName: str) -> QWidget:
view = None

if reconstructorName == 'rpie':
view = TikeParametersView.createInstance(showCgIter=False,
showAlpha=True,
showStepLength=False)
view = TikeParametersView.createInstance(showAlpha=True, showStepLength=False)
elif reconstructorName == 'lstsq_grad':
view = TikeParametersView.createInstance(showCgIter=False,
showAlpha=False,
showStepLength=False)
view = TikeParametersView.createInstance(showAlpha=False, showStepLength=False)
elif reconstructorName == 'dm':
view = TikeParametersView.createInstance(showCgIter=False,
showAlpha=False,
showStepLength=False)
view = TikeParametersView.createInstance(showAlpha=False, showStepLength=False)
else:
view = TikeParametersView.createInstance(showCgIter=True,
showAlpha=True,
showStepLength=True)
view = TikeParametersView.createInstance(showAlpha=True, showStepLength=True)

controller = TikeParametersController.createInstance(self._model, view)
self._controllerList.append(controller)
Expand Down
2 changes: 1 addition & 1 deletion ptychodus/model/automation/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,5 @@ def __init__(self, registry: StateDataRegistry, reconstructorAPI: ReconstructorA
def execute(self, filePath: Path) -> None:
# TODO watch for ptychodus NPZ files
self._registry.openStateData(filePath)
self._reconstructorAPI.ingest()
self._reconstructorAPI.ingestTrainingData()
self._reconstructorAPI.train()
2 changes: 1 addition & 1 deletion ptychodus/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def batchModeTrain(self, directoryPath: Path) -> float:
for filePath in directoryPath.glob('*.npz'):
# TODO sort by filePath.stat().st_mtime
self._stateDataRegistry.openStateData(filePath)
self._reconstructorCore.reconstructorAPI.ingest()
self._reconstructorCore.reconstructorAPI.ingestTrainingData()

self._reconstructorCore.reconstructorAPI.train()

Expand Down
Loading

0 comments on commit afeee0e

Please sign in to comment.