From 2f58d31523e042a8be0edc9ff12c035b1c48293a Mon Sep 17 00:00:00 2001 From: Steven Henke Date: Fri, 18 Oct 2024 08:59:24 -0500 Subject: [PATCH 1/9] implement --- ptychodus/api/patterns.py | 1 + ptychodus/api/product.py | 14 +++++++++ ptychodus/api/reconstructor.py | 3 +- ptychodus/model/analysis/core.py | 2 +- ptychodus/model/analysis/fluorescence.py | 30 +++++++++----------- ptychodus/model/patterns/active.py | 10 +++++++ ptychodus/model/product/api.py | 8 +++--- ptychodus/model/product/item.py | 18 ++++++++++++ ptychodus/model/product/metadata.py | 10 ++++++- ptychodus/model/product/object/builder.py | 2 ++ ptychodus/model/product/object/item.py | 8 ++++-- ptychodus/model/product/probe/builder.py | 2 ++ ptychodus/model/product/probe/item.py | 13 +++++++-- ptychodus/model/product/probe/itemFactory.py | 4 +-- ptychodus/model/product/probe/multimodal.py | 2 ++ ptychodus/model/product/productGeometry.py | 13 +++++++-- ptychodus/model/product/productRepository.py | 12 +++----- ptychodus/model/product/scan/builder.py | 2 ++ ptychodus/model/product/scan/item.py | 9 ++++-- ptychodus/model/reconstructor/api.py | 6 +++- ptychodus/model/reconstructor/matcher.py | 4 ++- ptychodus/model/tike/settings.py | 1 - 22 files changed, 128 insertions(+), 46 deletions(-) diff --git a/ptychodus/api/patterns.py b/ptychodus/api/patterns.py index 3539ae3b..877f0446 100644 --- a/ptychodus/api/patterns.py +++ b/ptychodus/api/patterns.py @@ -13,6 +13,7 @@ from .observer import Observable from .tree import SimpleTreeNode +BooleanArrayType: TypeAlias = numpy.typing.NDArray[numpy.bool_] DiffractionPatternArrayType: TypeAlias = numpy.typing.NDArray[numpy.integer[Any]] DiffractionPatternIndexes: TypeAlias = numpy.typing.NDArray[numpy.integer[Any]] diff --git a/ptychodus/api/product.py b/ptychodus/api/product.py index aeb80cc7..73b660d4 100644 --- a/ptychodus/api/product.py +++ b/ptychodus/api/product.py @@ -4,6 +4,7 @@ from pathlib import Path from sys import getsizeof +from .constants import ELECTRON_VOLT_J, LIGHT_SPEED_M_PER_S, PLANCK_CONSTANT_J_PER_HZ from .object import Object from .probe import Probe from .scan import Scan @@ -18,6 +19,19 @@ class ProductMetadata: probePhotonsPerSecond: float exposureTimeInSeconds: float + @property + def probeEnergyInJoules(self) -> float: + return self.probeEnergyInElectronVolts * ELECTRON_VOLT_J + + @property + def probeWavelengthInMeters(self) -> float: + hc_Jm = PLANCK_CONSTANT_J_PER_HZ * LIGHT_SPEED_M_PER_S + + try: + return hc_Jm / self.probeEnergyInJoules + except ZeroDivisionError: + return 0.0 + @property def sizeInBytes(self) -> int: sz = getsizeof(self.name) diff --git a/ptychodus/api/reconstructor.py b/ptychodus/api/reconstructor.py index a23acb11..ad5eb67e 100644 --- a/ptychodus/api/reconstructor.py +++ b/ptychodus/api/reconstructor.py @@ -5,12 +5,13 @@ from pathlib import Path from .product import Product -from .patterns import DiffractionPatternArrayType +from .patterns import BooleanArrayType, DiffractionPatternArrayType @dataclass(frozen=True) class ReconstructInput: patterns: DiffractionPatternArrayType + goodPixelMask: BooleanArrayType product: Product diff --git a/ptychodus/model/analysis/core.py b/ptychodus/model/analysis/core.py index 5091b365..7225d657 100644 --- a/ptychodus/model/analysis/core.py +++ b/ptychodus/model/analysis/core.py @@ -49,7 +49,7 @@ def __init__( self._fluorescenceSettings = FluorescenceSettings(settingsRegistry) self.fluorescenceEnhancer = FluorescenceEnhancer( self._fluorescenceSettings, - dataMatcher, + productRepository, upscalingStrategyChooser, deconvolutionStrategyChooser, fluorescenceFileReaderChooser, diff --git a/ptychodus/model/analysis/fluorescence.py b/ptychodus/model/analysis/fluorescence.py index 57966c9b..19a18998 100644 --- a/ptychodus/model/analysis/fluorescence.py +++ b/ptychodus/model/analysis/fluorescence.py @@ -22,7 +22,7 @@ from ptychodus.api.product import Product from ptychodus.api.typing import RealArrayType -from ..reconstructor import DiffractionPatternPositionMatcher +from ..product import ProductRepository from .settings import FluorescenceSettings logger = logging.getLogger(__name__) @@ -63,7 +63,7 @@ def get_axis_weights_and_indexes( class VSPILinearOperator(LinearOperator): - def __init__(self, product: Product, xrf_nchannels: int) -> None: + def __init__(self, product: Product) -> None: """ M: number of XRF positions N: number of ptychography object pixels @@ -71,11 +71,11 @@ def __init__(self, product: Product, xrf_nchannels: int) -> None: A[M,N] * X[N,P] = B[M,P] """ - super().__init__(float, (len(product.scan), xrf_nchannels)) + super().__init__(float, (len(product.scan), len(product.scan))) self._product = product - def matmat(self, X: RealArrayType) -> RealArrayType: - AX = numpy.zeros(self.shape, dtype=self.dtype) + def _matvec(self, X: RealArrayType) -> RealArrayType: + AX = numpy.zeros(X.shape, dtype=self.dtype) probeGeometry = self._product.probe.getGeometry() dx_p_m = probeGeometry.pixelWidthInMeters @@ -103,7 +103,7 @@ def matmat(self, X: RealArrayType) -> RealArrayType: i_nz = numpy.ravel_multi_index(list(zip(IY.flat, IX.flat)), objectShape) X_nz = X.take(i_nz, axis=0) - AX[index, :] = numpy.matmul(numpy.outer(wy, wx).ravel(), X_nz) + AX[index] = numpy.matmul(numpy.outer(wy, wx).ravel(), X_nz) return AX @@ -115,7 +115,7 @@ class FluorescenceEnhancer(Observable, Observer): def __init__( self, settings: FluorescenceSettings, - dataMatcher: DiffractionPatternPositionMatcher, # FIXME match XRF too + productRepository: ProductRepository, upscalingStrategyChooser: PluginChooser[UpscalingStrategy], deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], fileReaderChooser: PluginChooser[FluorescenceFileReader], @@ -124,7 +124,7 @@ def __init__( ) -> None: super().__init__() self._settings = settings - self._dataMatcher = dataMatcher + self._productRepository = productRepository self._upscalingStrategyChooser = upscalingStrategyChooser self._deconvolutionStrategyChooser = deconvolutionStrategyChooser self._fileReaderChooser = fileReaderChooser @@ -152,7 +152,7 @@ def setProduct(self, productIndex: int) -> None: self.notifyObservers() def getProductName(self) -> str: - return self._dataMatcher.getProductName(self._productIndex) + return self._productRepository[self._productIndex].getName() def getOpenFileFilterList(self) -> Sequence[str]: return self._fileReaderChooser.getDisplayNameList() @@ -222,14 +222,12 @@ def enhanceFluorescence(self) -> None: if self._measured is None: raise ValueError('Fluorescence dataset not loaded!') - reconstructInput = self._dataMatcher.matchDiffractionPatternsWithPositions( - self._productIndex - ) + product = self._productRepository[self._productIndex].getProduct() element_maps: list[ElementMap] = list() if self._settings.useVSPI.getValue(): measured_emaps = self._measured.element_maps - A = VSPILinearOperator(reconstructInput.product, len(measured_emaps)) + A = VSPILinearOperator(product) B = numpy.stack([b.counts_per_second.flatten() for b in measured_emaps]).T X, info = gmres(A, B, atol=1e-5) # TODO expose atol @@ -246,8 +244,8 @@ def enhanceFluorescence(self) -> None: for emap in self._measured.element_maps: logger.info(f'Enhancing "{emap.name}"') - emap_upscaled = upscaler(emap, reconstructInput.product) - emap_enhanced = deconvolver(emap_upscaled, reconstructInput.product) + emap_upscaled = upscaler(emap, product) + emap_enhanced = deconvolver(emap_upscaled, product) element_maps.append(emap_enhanced) self._enhanced = FluorescenceDataset( @@ -258,7 +256,7 @@ def enhanceFluorescence(self) -> None: self.notifyObservers() def getPixelGeometry(self) -> PixelGeometry: - return self._dataMatcher.getObjectPlanePixelGeometry(self._productIndex) + return self._productRepository[self._productIndex].getGeometry().getPixelGeometry() def getEnhancedElementMap(self, channelIndex: int) -> ElementMap: if self._enhanced is None: diff --git a/ptychodus/model/patterns/active.py b/ptychodus/model/patterns/active.py index 0a81d618..b43cdaa2 100644 --- a/ptychodus/model/patterns/active.py +++ b/ptychodus/model/patterns/active.py @@ -10,6 +10,7 @@ from ptychodus.api.geometry import ImageExtent from ptychodus.api.patterns import ( + BooleanArrayType, DiffractionDataset, DiffractionMetadata, DiffractionPatternArray, @@ -141,6 +142,15 @@ def insertArray(self, array: DiffractionPatternArray) -> None: self._changedEvent.set() + def getGoodPixelMask(self) -> BooleanArrayType: # FIXME + return numpy.full( + ( + self._diffractionPatternSizer.getHeightInPixels(), + self._diffractionPatternSizer.getWidthInPixels(), + ), + True, + ) + def getAssembledIndexes(self) -> Sequence[int]: indexes: list[int] = list() diff --git a/ptychodus/model/product/api.py b/ptychodus/model/product/api.py index 411c088b..ee7aefe5 100644 --- a/ptychodus/model/product/api.py +++ b/ptychodus/model/product/api.py @@ -112,7 +112,7 @@ def copyScan(self, sourceIndex: int, destinationIndex: int) -> None: logger.warning(f'Failed to access destination scan {destinationIndex} for copying!') return - destinationItem.assign(sourceItem) + destinationItem.assignItem(sourceItem) def getSaveFileFilterList(self) -> Sequence[str]: return self._builderFactory.getSaveFileFilterList() @@ -220,7 +220,7 @@ def copyProbe(self, sourceIndex: int, destinationIndex: int) -> None: logger.warning(f'Failed to access destination probe {destinationIndex} for copying!') return - destinationItem.assign(sourceItem) + destinationItem.assignItem(sourceItem) def getSaveFileFilterList(self) -> Sequence[str]: return self._builderFactory.getSaveFileFilterList() @@ -328,7 +328,7 @@ def copyObject(self, sourceIndex: int, destinationIndex: int) -> None: logger.warning(f'Failed to access destination object {destinationIndex} for copying!') return - destinationItem.assign(sourceItem) + destinationItem.assignItem(sourceItem) def getSaveFileFilterList(self) -> Sequence[str]: return self._builderFactory.getSaveFileFilterList() @@ -370,7 +370,7 @@ def insertNewProduct( likeIndex: int = -1, ) -> int: return self._repository.insertNewProduct( - name, + name=name, comments=comments, detectorDistanceInMeters=detectorDistanceInMeters, probeEnergyInElectronVolts=probeEnergyInElectronVolts, diff --git a/ptychodus/model/product/item.py b/ptychodus/model/product/item.py index 52b3ec4e..096dbbbe 100644 --- a/ptychodus/model/product/item.py +++ b/ptychodus/model/product/item.py @@ -66,6 +66,24 @@ def __init__( self._addGroup('probe', self._probe, observe=True) self._addGroup('object', self._object, observe=True) + def assignItem(self, item: ProductRepositoryItem, *, notify: bool = True) -> None: + self._metadata.assignItem(item.getMetadata()) + self._scan.assignItem(item.getScan()) + self._probe.assignItem(item.getProbe()) + self._object.assignItem(item.getObject()) + self._costs = list(item.getCosts()) + + if notify: + self._parent.handleCostsChanged(self) + + def assign(self, product: Product) -> None: + self._metadata.assign(product.metadata) + self._scan.assign(product.scan) + self._probe.assign(product.probe) + self._object.assign(product.object_) + self._costs = list(product.costs) + self._parent.handleCostsChanged(self) + def syncToSettings(self) -> None: self._metadata.syncToSettings() self._scan.syncToSettings() diff --git a/ptychodus/model/product/metadata.py b/ptychodus/model/product/metadata.py index 74903c6f..7cd517b3 100644 --- a/ptychodus/model/product/metadata.py +++ b/ptychodus/model/product/metadata.py @@ -68,7 +68,7 @@ def __init__( self._index = -1 - def assign(self, item: MetadataRepositoryItem) -> None: + def assignItem(self, item: MetadataRepositoryItem, *, notify: bool = True) -> None: self.setName(item.getName()) self.comments.setValue(item.comments.getValue()) self.detectorDistanceInMeters.setValue(item.detectorDistanceInMeters.getValue()) @@ -76,6 +76,14 @@ def assign(self, item: MetadataRepositoryItem) -> None: self.probePhotonsPerSecond.setValue(item.probePhotonsPerSecond.getValue()) self.exposureTimeInSeconds.setValue(item.exposureTimeInSeconds.getValue()) + def assign(self, metadata: ProductMetadata) -> None: + self.setName(metadata.name) + self.comments.setValue(metadata.comments) + self.detectorDistanceInMeters.setValue(metadata.detectorDistanceInMeters) + self.probeEnergyInElectronVolts.setValue(metadata.probeEnergyInElectronVolts) + self.probePhotonsPerSecond.setValue(metadata.probePhotonsPerSecond) + self.exposureTimeInSeconds.setValue(metadata.exposureTimeInSeconds) + def syncToSettings(self) -> None: for parameter in self.parameters().values(): parameter.syncValueToParent() diff --git a/ptychodus/model/product/object/builder.py b/ptychodus/model/product/object/builder.py index e349395e..7a63de7b 100644 --- a/ptychodus/model/product/object/builder.py +++ b/ptychodus/model/product/object/builder.py @@ -63,8 +63,10 @@ def __init__( super().__init__(settings, 'from_file') self._settings = settings self.filePath = settings.filePath.copy() + self.filePath.setValue(filePath) self._addParameter('file_path', self.filePath) self.fileType = settings.fileType.copy() + self.fileType.setValue(fileType) self._addParameter('file_type', self.fileType) self._fileReader = fileReader diff --git a/ptychodus/model/product/object/item.py b/ptychodus/model/product/object/item.py index f2fd3e44..d4e02bce 100644 --- a/ptychodus/model/product/object/item.py +++ b/ptychodus/model/product/object/item.py @@ -7,7 +7,7 @@ from ptychodus.api.observer import Observable from ptychodus.api.parametric import ParameterGroup -from .builder import ObjectBuilder +from .builder import FromMemoryObjectBuilder, ObjectBuilder from .settings import ObjectSettings logger = logging.getLogger(__name__) @@ -32,10 +32,14 @@ def __init__( self._rebuild() - def assign(self, item: ObjectRepositoryItem) -> None: + def assignItem(self, item: ObjectRepositoryItem) -> None: self.layerDistanceInMeters.setValue(item.layerDistanceInMeters.getValue(), notify=False) self.setBuilder(item.getBuilder().copy()) + def assign(self, object_: Object) -> None: + builder = FromMemoryObjectBuilder(self._settings, object_) + self.setBuilder(builder) + def syncToSettings(self) -> None: for parameter in self.parameters().values(): parameter.syncValueToParent() diff --git a/ptychodus/model/product/probe/builder.py b/ptychodus/model/product/probe/builder.py index 87fc2d61..b8fec604 100644 --- a/ptychodus/model/product/probe/builder.py +++ b/ptychodus/model/product/probe/builder.py @@ -91,8 +91,10 @@ def __init__( super().__init__(settings, 'from_file') self._settings = settings self.filePath = settings.filePath.copy() + self.filePath.setValue(filePath) self._addParameter('file_path', self.filePath) self.fileType = settings.fileType.copy() + self.fileType.setValue(fileType) self._addParameter('file_type', self.fileType) self._fileReader = fileReader diff --git a/ptychodus/model/product/probe/item.py b/ptychodus/model/product/probe/item.py index 214a1699..b3f58615 100644 --- a/ptychodus/model/product/probe/item.py +++ b/ptychodus/model/product/probe/item.py @@ -5,8 +5,9 @@ from ptychodus.api.parametric import ParameterGroup from ptychodus.api.probe import Probe, ProbeGeometryProvider -from .builder import ProbeBuilder +from .builder import FromMemoryProbeBuilder, ProbeBuilder from .multimodal import MultimodalProbeBuilder +from .settings import ProbeSettings logger = logging.getLogger(__name__) @@ -15,11 +16,13 @@ class ProbeRepositoryItem(ParameterGroup): def __init__( self, geometryProvider: ProbeGeometryProvider, + settings: ProbeSettings, builder: ProbeBuilder, additionalModesBuilder: MultimodalProbeBuilder, ) -> None: super().__init__() self._geometryProvider = geometryProvider + self._settings = settings self._builder = builder self._additionalModesBuilder = additionalModesBuilder self._probe = Probe() @@ -29,7 +32,7 @@ def __init__( self._rebuild() - def assign(self, item: ProbeRepositoryItem) -> None: + def assignItem(self, item: ProbeRepositoryItem) -> None: self._removeGroup('additional_modes') self._additionalModesBuilder.removeObserver(self) self._additionalModesBuilder = item.getAdditionalModesBuilder().copy() @@ -37,6 +40,11 @@ def assign(self, item: ProbeRepositoryItem) -> None: self._addGroup('additional_modes', self._additionalModesBuilder, observe=True) self.setBuilder(item.getBuilder().copy()) + self._rebuild() + + def assign(self, probe: Probe) -> None: + builder = FromMemoryProbeBuilder(self._settings, probe) + self.setBuilder(builder) def syncToSettings(self) -> None: for parameter in self.parameters().values(): @@ -57,7 +65,6 @@ def setBuilder(self, builder: ProbeBuilder) -> None: self._builder = builder self._builder.addObserver(self) self._addGroup('builder', self._builder, observe=True) - self._rebuild() def _rebuild(self) -> None: try: diff --git a/ptychodus/model/product/probe/itemFactory.py b/ptychodus/model/product/probe/itemFactory.py index c29ba77e..bf49109a 100644 --- a/ptychodus/model/product/probe/itemFactory.py +++ b/ptychodus/model/product/probe/itemFactory.py @@ -33,7 +33,7 @@ def create( else FromMemoryProbeBuilder(self._settings, probe) ) multimodalBuilder = MultimodalProbeBuilder(self._rng, self._settings) - return ProbeRepositoryItem(geometryProvider, builder, multimodalBuilder) + return ProbeRepositoryItem(geometryProvider, self._settings, builder, multimodalBuilder) def createFromSettings(self, geometryProvider: ProbeGeometryProvider) -> ProbeRepositoryItem: try: @@ -43,4 +43,4 @@ def createFromSettings(self, geometryProvider: ProbeGeometryProvider) -> ProbeRe builder = self._builderFactory.createDefault() multimodalBuilder = MultimodalProbeBuilder(self._rng, self._settings) - return ProbeRepositoryItem(geometryProvider, builder, multimodalBuilder) + return ProbeRepositoryItem(geometryProvider, self._settings, builder, multimodalBuilder) diff --git a/ptychodus/model/product/probe/multimodal.py b/ptychodus/model/product/probe/multimodal.py index 1563ad6e..22c614ed 100644 --- a/ptychodus/model/product/probe/multimodal.py +++ b/ptychodus/model/product/probe/multimodal.py @@ -119,6 +119,8 @@ def _adjustRelativePower(self, probe: WavefieldArrayType) -> WavefieldArrayType: def build(self, probe: Probe) -> Probe: if self.numberOfModes.getValue() <= 1: return probe + elif self.numberOfModes.getValue() == probe.numberOfModes: + return probe array = self._initializeModes(probe.array) diff --git a/ptychodus/model/product/productGeometry.py b/ptychodus/model/product/productGeometry.py index 93c431d1..04a744cb 100644 --- a/ptychodus/model/product/productGeometry.py +++ b/ptychodus/model/product/productGeometry.py @@ -1,13 +1,14 @@ import numpy -from ptychodus.api.object import ObjectGeometry, ObjectGeometryProvider -from ptychodus.api.observer import Observable, Observer -from ptychodus.api.probe import ProbeGeometry, ProbeGeometryProvider from ptychodus.api.constants import ( ELECTRON_VOLT_J, LIGHT_SPEED_M_PER_S, PLANCK_CONSTANT_J_PER_HZ, ) +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import ObjectGeometry, ObjectGeometryProvider +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.probe import ProbeGeometry, ProbeGeometryProvider from ..patterns import PatternSizer from .metadata import MetadataRepositoryItem @@ -63,6 +64,12 @@ def objectPlanePixelWidthInMeters(self) -> float: def objectPlanePixelHeightInMeters(self) -> float: return self._lambdaZInSquareMeters / self._patternSizer.getHeightInMeters() + def getPixelGeometry(self) -> PixelGeometry: + return PixelGeometry( + widthInMeters=self.objectPlanePixelWidthInMeters, + heightInMeters=self.objectPlanePixelHeightInMeters, + ) + @property def fresnelNumber(self) -> float: widthInMeters = self._patternSizer.getWidthInMeters() diff --git a/ptychodus/model/product/productRepository.py b/ptychodus/model/product/productRepository.py index 632b31b4..71e7751c 100644 --- a/ptychodus/model/product/productRepository.py +++ b/ptychodus/model/product/productRepository.py @@ -70,8 +70,8 @@ def _insertProduct(self, item: ProductRepositoryItem) -> int: def insertNewProduct( self, - name: str, *, + name: str = '', comments: str = '', detectorDistanceInMeters: float | None = None, probeEnergyInElectronVolts: float | None = None, @@ -92,13 +92,6 @@ def insertNewProduct( probeItem = self._probeRepositoryItemFactory.create(geometry) objectItem = self._objectRepositoryItemFactory.create(geometry) - if likeIndex >= 0: - sourceItem = self._itemList[likeIndex] - metadataItem.assign(sourceItem.getMetadata()) - scanItem.assign(sourceItem.getScan()) - probeItem.assign(sourceItem.getProbe()) - objectItem.assign(sourceItem.getObject()) - item = ProductRepositoryItem( parent=self, metadata=metadataItem, @@ -110,6 +103,9 @@ def insertNewProduct( costs=list(), ) + if likeIndex >= 0: + item.assignItem(self._itemList[likeIndex], notify=False) + return self._insertProduct(item) def insertProductFromSettings(self) -> int: diff --git a/ptychodus/model/product/scan/builder.py b/ptychodus/model/product/scan/builder.py index f1e8c93c..1be11dfe 100644 --- a/ptychodus/model/product/scan/builder.py +++ b/ptychodus/model/product/scan/builder.py @@ -55,8 +55,10 @@ def __init__( super().__init__(settings, 'from_file') self._settings = settings self.filePath = settings.filePath.copy() + self.filePath.setValue(filePath) self._addParameter('file_path', self.filePath) self.fileType = settings.fileType.copy() + self.fileType.setValue(fileType) self._addParameter('file_type', self.fileType) self._fileReader = fileReader diff --git a/ptychodus/model/product/scan/item.py b/ptychodus/model/product/scan/item.py index be81986b..3585df04 100644 --- a/ptychodus/model/product/scan/item.py +++ b/ptychodus/model/product/scan/item.py @@ -9,7 +9,7 @@ from ptychodus.api.scan import Scan, ScanBoundingBox, ScanPoint from .boundingBox import ScanBoundingBoxBuilder -from .builder import ScanBuilder +from .builder import FromMemoryScanBuilder, ScanBuilder from .settings import ScanSettings from .transform import ScanPointTransform @@ -24,6 +24,7 @@ def __init__( transform: ScanPointTransform, ) -> None: super().__init__() + self._settings = settings self._builder = builder self._transform = transform @@ -60,7 +61,7 @@ def __init__( self._rebuild() - def assign(self, item: ScanRepositoryItem) -> None: + def assignItem(self, item: ScanRepositoryItem) -> None: self._removeGroup('transform') self._transform.removeObserver(self) self._transform = item.getTransform().copy() @@ -69,6 +70,10 @@ def assign(self, item: ScanRepositoryItem) -> None: self.setBuilder(item.getBuilder().copy()) + def assign(self, scan: Scan) -> None: + builder = FromMemoryScanBuilder(self._settings, scan) + self.setBuilder(builder) + def syncToSettings(self) -> None: for parameter in self.parameters().values(): parameter.syncValueToParent() diff --git a/ptychodus/model/reconstructor/api.py b/ptychodus/model/reconstructor/api.py index d657961a..c04fd699 100644 --- a/ptychodus/model/reconstructor/api.py +++ b/ptychodus/model/reconstructor/api.py @@ -37,12 +37,16 @@ def reconstruct( inputProductIndex, indexFilter ) + outputProductIndex = self._productRepository.insertNewProduct(likeIndex=inputProductIndex) + outputProduct = self._productRepository[outputProductIndex] + tic = time.perf_counter() result = reconstructor.reconstruct(parameters) toc = time.perf_counter() logger.info(f'Reconstruction time {toc - tic:.4f} seconds. (code={result.result})') - outputProductIndex = self._productRepository.insertProduct(result.product) + outputProduct.assign(result.product) + return outputProductIndex def reconstructSplit(self, inputProductIndex: int, outputProductName: str) -> tuple[int, int]: diff --git a/ptychodus/model/reconstructor/matcher.py b/ptychodus/model/reconstructor/matcher.py index 5ccd2a5d..a3c47c7f 100644 --- a/ptychodus/model/reconstructor/matcher.py +++ b/ptychodus/model/reconstructor/matcher.py @@ -52,6 +52,8 @@ def getObjectPlanePixelGeometry(self, inputProductIndex: int) -> PixelGeometry: def matchDiffractionPatternsWithPositions( self, inputProductIndex: int, indexFilter: ScanIndexFilter = ScanIndexFilter.ALL ) -> ReconstructInput: + goodPixelMask = self._diffractionDataset.getGoodPixelMask() + inputProductItem = self._productRepository[inputProductIndex] inputProduct = inputProductItem.getProduct() dataIndexes = self._diffractionDataset.getAssembledIndexes() @@ -85,4 +87,4 @@ def matchDiffractionPatternsWithPositions( costs=inputProduct.costs, ) - return ReconstructInput(patterns, product) + return ReconstructInput(patterns, goodPixelMask, product) diff --git a/ptychodus/model/tike/settings.py b/ptychodus/model/tike/settings.py index 6878dd43..0b410dbe 100644 --- a/ptychodus/model/tike/settings.py +++ b/ptychodus/model/tike/settings.py @@ -1,4 +1,3 @@ -from __future__ import annotations from collections.abc import Sequence import logging From f3ab831575e1830e399d47e7438c5bd3b4199060 Mon Sep 17 00:00:00 2001 From: Steven Henke Date: Fri, 18 Oct 2024 12:04:26 -0500 Subject: [PATCH 2/9] fix typo --- ptychodus/__main__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ptychodus/__main__.py b/ptychodus/__main__.py index 6f69b989..5fb3c484 100644 --- a/ptychodus/__main__.py +++ b/ptychodus/__main__.py @@ -103,10 +103,10 @@ def main() -> int: fluorescenceInputFilePath: Path | None = None fluorescenceOutputFilePath: Path | None = None - if parsedArgs.flourescence_input is not None: + if parsedArgs.fluorescence_input is not None: fluorescenceInputFilePath = Path(parsedArgs.fluorescence_input.name) - if parsedArgs.flourescence_output is not None: + if parsedArgs.fluorescence_output is not None: fluorescenceOutputFilePath = Path(parsedArgs.fluorescence_output.name) return model.batchModeExecute( From 24b1cc73f5c5a71c9b76aadfa5e2e227f116a5ac Mon Sep 17 00:00:00 2001 From: Steve Henke Date: Fri, 18 Oct 2024 13:26:42 -0500 Subject: [PATCH 3/9] use lsqr --- ptychodus/model/analysis/fluorescence.py | 26 ++++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/ptychodus/model/analysis/fluorescence.py b/ptychodus/model/analysis/fluorescence.py index 19a18998..43a055b8 100644 --- a/ptychodus/model/analysis/fluorescence.py +++ b/ptychodus/model/analysis/fluorescence.py @@ -4,7 +4,7 @@ from typing import Final import logging -from scipy.sparse.linalg import gmres, LinearOperator +from scipy.sparse.linalg import lsqr, LinearOperator import math import numpy @@ -71,11 +71,13 @@ def __init__(self, product: Product) -> None: A[M,N] * X[N,P] = B[M,P] """ - super().__init__(float, (len(product.scan), len(product.scan))) + num_positions = len(product.scan) + num_pixels = product.object_.heightInPixels * product.object_.widthInPixels + super().__init__(float, (num_positions, num_pixels)) self._product = product def _matvec(self, X: RealArrayType) -> RealArrayType: - AX = numpy.zeros(X.shape, dtype=self.dtype) + AX = numpy.zeros(len(self._product.scan), dtype=self.dtype) probeGeometry = self._product.probe.getGeometry() dx_p_m = probeGeometry.pixelWidthInMeters @@ -226,18 +228,16 @@ def enhanceFluorescence(self) -> None: element_maps: list[ElementMap] = list() if self._settings.useVSPI.getValue(): - measured_emaps = self._measured.element_maps A = VSPILinearOperator(product) - B = numpy.stack([b.counts_per_second.flatten() for b in measured_emaps]).T - X, info = gmres(A, B, atol=1e-5) # TODO expose atol - - if info != 0: - logger.warning(f'Convergence to tolerance not achieved! {info=}') - - for m_emap, e_cps in zip(measured_emaps, X.T): - e_emap = ElementMap(m_emap.name, e_cps.reshape(m_emap.counts_per_second.shape)) - element_maps.append(e_emap) + for emap in self._measured.element_maps: + logger.info(f'Enhancing "{emap.name}"') + m_cps = emap.counts_per_second + result = lsqr(A, m_cps.flatten()) # TODO expose parameters + logger.debug(result) + e_cps = result[0].reshape(m_cps.shape) + emap_enhanced = ElementMap(emap.name, e_cps) + element_maps.append(emap_enhanced) else: upscaler = self._upscalingStrategyChooser.currentPlugin.strategy deconvolver = self._deconvolutionStrategyChooser.currentPlugin.strategy From 4e6955aefe86332ad743716b21f59735eb56b705 Mon Sep 17 00:00:00 2001 From: Steven Henke Date: Sun, 20 Oct 2024 22:09:04 -0500 Subject: [PATCH 4/9] repair ptycho+xrf algorithm --- ptychodus/model/analysis/fluorescence.py | 109 ++++++++++++----------- 1 file changed, 59 insertions(+), 50 deletions(-) diff --git a/ptychodus/model/analysis/fluorescence.py b/ptychodus/model/analysis/fluorescence.py index 19a18998..3558b420 100644 --- a/ptychodus/model/analysis/fluorescence.py +++ b/ptychodus/model/analysis/fluorescence.py @@ -5,7 +5,6 @@ import logging from scipy.sparse.linalg import gmres, LinearOperator -import math import numpy from ptychodus.api.fluorescence import ( @@ -17,6 +16,7 @@ UpscalingStrategy, ) from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.object import ObjectPoint from ptychodus.api.observer import Observable, Observer from ptychodus.api.plugins import PluginChooser from ptychodus.api.product import Product @@ -28,38 +28,51 @@ logger = logging.getLogger(__name__) -def get_axis_weights_and_indexes( - xmin_o: float, dx_o: float, xmin_p: float, dx_p: float, N_p: int -) -> tuple[Sequence[float], Sequence[int]]: - weight: list[float] = [] - index: list[int] = [] +class ArrayPatchInterpolator: + def __init__(self, array: RealArrayType, point: ObjectPoint, shape: tuple[int, int]) -> None: + # top left corner of object support + xmin = point.positionXInPixels - shape[-1] / 2 + ymin = point.positionYInPixels - shape[-2] / 2 - x_l = xmin_p - n_o = math.ceil((x_l - xmin_o) / dx_o) + # whole components (pixel indexes) + xmin_wh = int(xmin) + ymin_wh = int(ymin) - for n_p in range(N_p): - x_p = xmin_p + (n_p + 1) * dx_p + # fractional (subpixel) components + xmin_fr = xmin - xmin_wh + ymin_fr = ymin - ymin_wh - while True: - x_o = xmin_o + n_o + dx_o + # bottom right corner of object patch support + xmax_wh = xmin_wh + shape[-1] + 1 + ymax_wh = ymin_wh + shape[-2] + 1 - if x_o >= x_p: - break + # reused quantities + xmin_fr_c = 1.0 - xmin_fr + ymin_fr_c = 1.0 - ymin_fr - weight.append((x_o - x_l) / dx_p) - index.append(n_o) + # barycentric interpolant weights + self._weight00 = ymin_fr_c * xmin_fr_c + self._weight01 = ymin_fr_c * xmin_fr + self._weight10 = ymin_fr * xmin_fr_c + self._weight11 = ymin_fr * xmin_fr - n_o += 1 - x_l = x_o + # extract patch support region from full object + self._support = array[ymin_wh:ymax_wh, xmin_wh:xmax_wh] - weight.append((x_p - x_l) / dx_p) - index.append(n_o) - x_l = x_p + def get_patch(self) -> RealArrayType: + """interpolate array support to extract patch""" + patch = self._weight00 * self._support[:-1, :-1] + patch += self._weight01 * self._support[:-1, 1:] + patch += self._weight10 * self._support[1:, :-1] + patch += self._weight11 * self._support[1:, 1:] + return patch - if x_o == x_p: - n_o += 1 - - return weight, index + def accumulate_patch(self, patch: RealArrayType) -> None: + """add patch update to array support""" + self._support[:-1, :-1] += self._weight00 * patch + self._support[:-1, 1:] += self._weight01 * patch + self._support[1:, :-1] += self._weight10 * patch + self._support[1:, 1:] += self._weight11 * patch class VSPILinearOperator(LinearOperator): @@ -74,38 +87,34 @@ def __init__(self, product: Product) -> None: super().__init__(float, (len(product.scan), len(product.scan))) self._product = product - def _matvec(self, X: RealArrayType) -> RealArrayType: - AX = numpy.zeros(X.shape, dtype=self.dtype) - - probeGeometry = self._product.probe.getGeometry() - dx_p_m = probeGeometry.pixelWidthInMeters - dy_p_m = probeGeometry.pixelHeightInMeters + def _get_psf(self, index: int) -> RealArrayType: + intensity = self._product.probe.getIntensity() + return intensity / numpy.sqrt(intensity.sum()) # FIXME verify + def _matvec(self, X: RealArrayType) -> RealArrayType: objectGeometry = self._product.object_.getGeometry() - objectShape = objectGeometry.heightInPixels, objectGeometry.widthInPixels - xmin_o_m = objectGeometry.minimumXInMeters - ymin_o_m = objectGeometry.minimumYInMeters - dx_o_m = objectGeometry.pixelWidthInMeters - dy_o_m = objectGeometry.pixelHeightInMeters + objectArray = X.reshape((objectGeometry.heightInPixels, objectGeometry.widthInPixels)) + AX = numpy.zeros(len(self._product.scan)) - for index, point in enumerate(self._product.scan): - xmin_p_m = point.positionXInMeters - probeGeometry.widthInMeters / 2 - ymin_p_m = point.positionYInMeters - probeGeometry.heightInMeters / 2 + for index, scanPoint in enumerate(self._product.scan): + objectPoint = objectGeometry.mapScanPointToObjectPoint(scanPoint) + psf = self._get_psf(index) + interpolator = ArrayPatchInterpolator(objectArray, objectPoint, psf.shape) + AX[index] = numpy.sum(psf * interpolator.get_patch()) - wx, ix = get_axis_weights_and_indexes( - xmin_o_m, dx_o_m, xmin_p_m, dx_p_m, probeGeometry.widthInPixels - ) - wy, iy = get_axis_weights_and_indexes( - ymin_o_m, dy_o_m, ymin_p_m, dy_p_m, probeGeometry.heightInPixels - ) + return AX - IY, IX = numpy.meshgrid(iy, ix) - i_nz = numpy.ravel_multi_index(list(zip(IY.flat, IX.flat)), objectShape) - X_nz = X.take(i_nz, axis=0) + def _rmatvec(self, X: RealArrayType) -> RealArrayType: + objectGeometry = self._product.object_.getGeometry() + objectArray = numpy.zeros((objectGeometry.heightInPixels, objectGeometry.widthInPixels)) - AX[index] = numpy.matmul(numpy.outer(wy, wx).ravel(), X_nz) + for index, scanPoint in enumerate(self._product.scan): + objectPoint = objectGeometry.mapScanPointToObjectPoint(scanPoint) + psf = self._get_psf(index) + interpolator = ArrayPatchInterpolator(objectArray, objectPoint, psf.shape) + interpolator.accumulate_patch(X[index] * psf) - return AX + return objectArray.flatten() class FluorescenceEnhancer(Observable, Observer): From 408f4281c34962837e98a8d1b7b1ad8c11bb8023 Mon Sep 17 00:00:00 2001 From: Steven Henke Date: Sun, 20 Oct 2024 22:29:24 -0500 Subject: [PATCH 5/9] fix LinearOperator shape --- ptychodus/model/analysis/fluorescence.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ptychodus/model/analysis/fluorescence.py b/ptychodus/model/analysis/fluorescence.py index 6b8a807f..1a2799cd 100644 --- a/ptychodus/model/analysis/fluorescence.py +++ b/ptychodus/model/analysis/fluorescence.py @@ -84,7 +84,10 @@ def __init__(self, product: Product) -> None: A[M,N] * X[N,P] = B[M,P] """ - super().__init__(float, (len(product.scan), len(product.scan))) + objectGeometry = product.object_.getGeometry() + M = len(product.scan) + N = objectGeometry.heightInPixels * objectGeometry.widthInPixels + super().__init__(float, (M, N)) self._product = product def _get_psf(self, index: int) -> RealArrayType: @@ -104,7 +107,7 @@ def _matvec(self, X: RealArrayType) -> RealArrayType: return AX - def _rmatvec(self, X: RealArrayType) -> RealArrayType: + def _adjoint(self, X: RealArrayType) -> RealArrayType: objectGeometry = self._product.object_.getGeometry() objectArray = numpy.zeros((objectGeometry.heightInPixels, objectGeometry.widthInPixels)) From bc6ffc8dfde0b40fcea931488252addbf80c5f10 Mon Sep 17 00:00:00 2001 From: Steve Henke Date: Mon, 21 Oct 2024 14:21:14 -0500 Subject: [PATCH 6/9] vspi algorithm working --- ptychodus/model/analysis/fluorescence.py | 61 ++++++++++++++---------- ptychodus/plugins/xrfMapsFile.py | 2 +- 2 files changed, 38 insertions(+), 25 deletions(-) diff --git a/ptychodus/model/analysis/fluorescence.py b/ptychodus/model/analysis/fluorescence.py index 1a2799cd..7425c9f1 100644 --- a/ptychodus/model/analysis/fluorescence.py +++ b/ptychodus/model/analysis/fluorescence.py @@ -3,8 +3,9 @@ from pathlib import Path from typing import Final import logging +import time -from scipy.sparse.linalg import lsqr, LinearOperator +from scipy.sparse.linalg import lsmr, LinearOperator import numpy from ptychodus.api.fluorescence import ( @@ -29,7 +30,7 @@ class ArrayPatchInterpolator: - def __init__(self, array: RealArrayType, point: ObjectPoint, shape: tuple[int, int]) -> None: + def __init__(self, array: RealArrayType, point: ObjectPoint, shape: tuple[int, ...]) -> None: # top left corner of object support xmin = point.positionXInPixels - shape[-1] / 2 ymin = point.positionYInPixels - shape[-2] / 2 @@ -84,40 +85,42 @@ def __init__(self, product: Product) -> None: A[M,N] * X[N,P] = B[M,P] """ - objectGeometry = product.object_.getGeometry() + object_geometry = product.object_.getGeometry() M = len(product.scan) - N = objectGeometry.heightInPixels * objectGeometry.widthInPixels + N = object_geometry.heightInPixels * object_geometry.widthInPixels super().__init__(float, (M, N)) self._product = product - def _get_psf(self, index: int) -> RealArrayType: + def _get_psf(self) -> RealArrayType: intensity = self._product.probe.getIntensity() - return intensity / numpy.sqrt(intensity.sum()) # FIXME verify + return intensity / intensity.sum() def _matvec(self, X: RealArrayType) -> RealArrayType: - objectGeometry = self._product.object_.getGeometry() - objectArray = X.reshape((objectGeometry.heightInPixels, objectGeometry.widthInPixels)) + object_geometry = self._product.object_.getGeometry() + object_array = X.reshape((object_geometry.heightInPixels, object_geometry.widthInPixels)) + psf = self._get_psf() AX = numpy.zeros(len(self._product.scan)) - for index, scanPoint in enumerate(self._product.scan): - objectPoint = objectGeometry.mapScanPointToObjectPoint(scanPoint) - psf = self._get_psf(index) - interpolator = ArrayPatchInterpolator(objectArray, objectPoint, psf.shape) + for index, scan_point in enumerate(self._product.scan): + object_point = object_geometry.mapScanPointToObjectPoint(scan_point) + interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) AX[index] = numpy.sum(psf * interpolator.get_patch()) return AX - def _adjoint(self, X: RealArrayType) -> RealArrayType: - objectGeometry = self._product.object_.getGeometry() - objectArray = numpy.zeros((objectGeometry.heightInPixels, objectGeometry.widthInPixels)) + def _rmatvec(self, X: RealArrayType) -> RealArrayType: + object_geometry = self._product.object_.getGeometry() + object_array = numpy.zeros((object_geometry.heightInPixels, object_geometry.widthInPixels)) + psf = self._get_psf() - for index, scanPoint in enumerate(self._product.scan): - objectPoint = objectGeometry.mapScanPointToObjectPoint(scanPoint) - psf = self._get_psf(index) - interpolator = ArrayPatchInterpolator(objectArray, objectPoint, psf.shape) + for index, scan_point in enumerate(self._product.scan): + object_point = object_geometry.mapScanPointToObjectPoint(scan_point) + interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) interpolator.accumulate_patch(X[index] * psf) - return objectArray.flatten() + HX = object_array.flatten() + + return HX class FluorescenceEnhancer(Observable, Observer): @@ -235,27 +238,37 @@ def enhanceFluorescence(self) -> None: raise ValueError('Fluorescence dataset not loaded!') product = self._productRepository[self._productIndex].getProduct() + object_geometry = product.object_.getGeometry() + e_cps_shape = object_geometry.heightInPixels, object_geometry.widthInPixels element_maps: list[ElementMap] = list() if self._settings.useVSPI.getValue(): A = VSPILinearOperator(product) for emap in self._measured.element_maps: - logger.info(f'Enhancing "{emap.name}"') + logger.info(f'Enhancing "{emap.name}"...') + tic = time.perf_counter() m_cps = emap.counts_per_second - result = lsqr(A, m_cps.flatten()) # TODO expose parameters + result = lsmr(A, m_cps.flatten(), maxiter=100, show=True) # TODO expose parameters logger.debug(result) - e_cps = result[0].reshape(m_cps.shape) + e_cps = result[0].reshape(e_cps_shape) emap_enhanced = ElementMap(emap.name, e_cps) + toc = time.perf_counter() + logger.info(f'Enhanced "{emap.name}" in {toc - tic:.4f} seconds.') + element_maps.append(emap_enhanced) else: upscaler = self._upscalingStrategyChooser.currentPlugin.strategy deconvolver = self._deconvolutionStrategyChooser.currentPlugin.strategy for emap in self._measured.element_maps: - logger.info(f'Enhancing "{emap.name}"') + logger.info(f'Enhancing "{emap.name}"...') + tic = time.perf_counter() emap_upscaled = upscaler(emap, product) emap_enhanced = deconvolver(emap_upscaled, product) + toc = time.perf_counter() + logger.info(f'Enhanced "{emap.name}" in {toc - tic:.4f} seconds.') + element_maps.append(emap_enhanced) self._enhanced = FluorescenceDataset( diff --git a/ptychodus/plugins/xrfMapsFile.py b/ptychodus/plugins/xrfMapsFile.py index b86678f0..8234cd00 100644 --- a/ptychodus/plugins/xrfMapsFile.py +++ b/ptychodus/plugins/xrfMapsFile.py @@ -16,7 +16,7 @@ class XRFMapsFileIO(FluorescenceFileReader, FluorescenceFileWriter): SIMPLE_NAME: Final[str] = 'XRF-Maps' - DISPLAY_NAME: Final[str] = 'XRF-Maps Fluorescence Dataset (*.h5 *.hdf5)' + DISPLAY_NAME: Final[str] = 'XRF-Maps Fluorescence Dataset (*.h5 *.h5*)' @staticmethod def _split_path(data_path: str) -> tuple[str, str]: From 6583bb00baca0f17353de32e4f81caf8f604a6c6 Mon Sep 17 00:00:00 2001 From: Steve Henke Date: Mon, 21 Oct 2024 15:40:57 -0500 Subject: [PATCH 7/9] expose more vspi parameters --- ptychodus/controller/probe/fluorescence.py | 2 ++ ptychodus/model/analysis/fluorescence.py | 8 +++++++- ptychodus/model/analysis/settings.py | 6 ++++++ ptychodus/model/product/probe/multimodal.py | 1 + ptychodus/view/probe.py | 4 ++++ 5 files changed, 20 insertions(+), 1 deletion(-) diff --git a/ptychodus/controller/probe/fluorescence.py b/ptychodus/controller/probe/fluorescence.py index 18f3c01e..236f1e7b 100644 --- a/ptychodus/controller/probe/fluorescence.py +++ b/ptychodus/controller/probe/fluorescence.py @@ -47,6 +47,8 @@ def __init__( self._dialog = FluorescenceDialog() self._enhancementModel = QStringListModel() self._enhancementModel.setStringList(self._enhancer.getEnhancementStrategyList()) + # FIXME add vspiDampingFactor + # FIXME add vspiMaxIterations self._upscalingModel = QStringListModel() self._upscalingModel.setStringList(self._enhancer.getUpscalingStrategyList()) self._deconvolutionModel = QStringListModel() diff --git a/ptychodus/model/analysis/fluorescence.py b/ptychodus/model/analysis/fluorescence.py index 7425c9f1..29812299 100644 --- a/ptychodus/model/analysis/fluorescence.py +++ b/ptychodus/model/analysis/fluorescence.py @@ -249,7 +249,13 @@ def enhanceFluorescence(self) -> None: logger.info(f'Enhancing "{emap.name}"...') tic = time.perf_counter() m_cps = emap.counts_per_second - result = lsmr(A, m_cps.flatten(), maxiter=100, show=True) # TODO expose parameters + result = lsmr( + A, + m_cps.flatten(), + damp=self._settings.vspiDampingFactor.getValue(), + maxiter=self._settings.vspiMaximumIterations.getValue(), + show=True, + ) logger.debug(result) e_cps = result[0].reshape(e_cps_shape) emap_enhanced = ElementMap(emap.name, e_cps) diff --git a/ptychodus/model/analysis/settings.py b/ptychodus/model/analysis/settings.py index 670601f9..2edcac95 100644 --- a/ptychodus/model/analysis/settings.py +++ b/ptychodus/model/analysis/settings.py @@ -34,6 +34,12 @@ def __init__(self, registry: SettingsRegistry) -> None: ) self.fileType = self._settingsGroup.createStringParameter('FileType', 'XRF-Maps') self.useVSPI = self._settingsGroup.createBooleanParameter('UseVSPI', True) + self.vspiDampingFactor = self._settingsGroup.createRealParameter( + 'VSPIDampingFactor', 0.0, minimum=0.0 + ) + self.vspiMaximumIterations = self._settingsGroup.createIntegerParameter( + 'VSPIMaximumIterations', 100, minimum=1 + ) self.upscalingStrategy = self._settingsGroup.createStringParameter( 'UpscalingStrategy', 'Linear' ) diff --git a/ptychodus/model/product/probe/multimodal.py b/ptychodus/model/product/probe/multimodal.py index 22c614ed..20fb4358 100644 --- a/ptychodus/model/product/probe/multimodal.py +++ b/ptychodus/model/product/probe/multimodal.py @@ -120,6 +120,7 @@ def build(self, probe: Probe) -> Probe: if self.numberOfModes.getValue() <= 1: return probe elif self.numberOfModes.getValue() == probe.numberOfModes: + # FIXME accomplish this differently return probe array = self._initializeModes(probe.array) diff --git a/ptychodus/view/probe.py b/ptychodus/view/probe.py index e646a3d7..4739359a 100644 --- a/ptychodus/view/probe.py +++ b/ptychodus/view/probe.py @@ -176,6 +176,8 @@ def __init__(self, parent: QWidget | None = None) -> None: super().__init__('Parameters', parent) self.openButton = QPushButton('Open') self.enhancementStrategyComboBox = QComboBox() + self.vspiDampingFactorLineEdit = DecimalLineEdit.createInstance() + self.vspiMaxIterationsSpinBox = QSpinBox() self.upscalingStrategyComboBox = QComboBox() self.deconvolutionStrategyComboBox = QComboBox() self.enhanceButton = QPushButton('Enhance') @@ -184,6 +186,8 @@ def __init__(self, parent: QWidget | None = None) -> None: layout = QFormLayout() layout.addRow('Measured Dataset:', self.openButton) layout.addRow('Enhancement Strategy:', self.enhancementStrategyComboBox) + layout.addRow('VSPI Damping Factor:', self.vspiDampingFactorLineEdit) + layout.addRow('VSPI Max Iterations:', self.vspiMaxIterationsSpinBox) layout.addRow('Upscaling Strategy:', self.upscalingStrategyComboBox) layout.addRow('Deconvolution Strategy:', self.deconvolutionStrategyComboBox) layout.addRow(self.enhanceButton) From 3c522b2e089b2f1c88a73ae30d65e3fce38ecbb0 Mon Sep 17 00:00:00 2001 From: Steven Henke Date: Mon, 21 Oct 2024 20:32:03 -0500 Subject: [PATCH 8/9] start on patterns screen; rename HDF5 file reader; prompt user to choose data format --- ptychodus/controller/core.py | 5 ++++- ptychodus/model/product/probe/multimodal.py | 3 --- ptychodus/plugins/h5DiffractionFile.py | 4 ++-- ptychodus/view/core.py | 2 -- ptychodus/view/patterns.py | 4 +++- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ptychodus/controller/core.py b/ptychodus/controller/core.py index d9902541..e709cbd8 100644 --- a/ptychodus/controller/core.py +++ b/ptychodus/controller/core.py @@ -136,9 +136,12 @@ def __init__(self, model: ModelCore, view: ViewCore) -> None: self._refreshDataTimer.timeout.connect(model.refreshActiveDataset) self._refreshDataTimer.start(1000) # TODO make configurable - view.navigationActionGroup.triggered.connect(lambda action: self.swapCentralWidgets(action)) view.workflowAction.setVisible(model.areWorkflowsSupported) + self.swapCentralWidgets(view.patternsAction) + view.patternsAction.setChecked(True) + view.navigationActionGroup.triggered.connect(lambda action: self.swapCentralWidgets(action)) + def showMainWindow(self, windowTitle: str) -> None: self.view.setWindowTitle(windowTitle) self.view.show() diff --git a/ptychodus/model/product/probe/multimodal.py b/ptychodus/model/product/probe/multimodal.py index 20fb4358..1563ad6e 100644 --- a/ptychodus/model/product/probe/multimodal.py +++ b/ptychodus/model/product/probe/multimodal.py @@ -119,9 +119,6 @@ def _adjustRelativePower(self, probe: WavefieldArrayType) -> WavefieldArrayType: def build(self, probe: Probe) -> Probe: if self.numberOfModes.getValue() <= 1: return probe - elif self.numberOfModes.getValue() == probe.numberOfModes: - # FIXME accomplish this differently - return probe array = self._initializeModes(probe.array) diff --git a/ptychodus/plugins/h5DiffractionFile.py b/ptychodus/plugins/h5DiffractionFile.py index 9e4f19f6..8c9a978b 100644 --- a/ptychodus/plugins/h5DiffractionFile.py +++ b/ptychodus/plugins/h5DiffractionFile.py @@ -191,8 +191,8 @@ def write(self, filePath: Path, dataset: DiffractionDataset) -> None: def registerPlugins(registry: PluginRegistry) -> None: registry.diffractionFileReaders.registerPlugin( H5DiffractionFileReader(dataPath='/entry/data/data'), - simpleName='HDF5', - displayName='Hierarchical Data Format 5 Files (*.h5 *.hdf5)', + simpleName='APS_HXN', + displayName='CNM/APS HXN Diffraction Files (*.h5 *.hdf5)', ) registry.diffractionFileReaders.registerPlugin( H5DiffractionFileReader(dataPath='/entry/measurement/Eiger/data'), diff --git a/ptychodus/view/core.py b/ptychodus/view/core.py index 98b11113..3a57fb33 100644 --- a/ptychodus/view/core.py +++ b/ptychodus/view/core.py @@ -108,8 +108,6 @@ def createInstance( action.setData(index) view.navigationActionGroup.addAction(action) - view.settingsAction.setChecked(True) - # maintain same order as navigationToolBar buttons view.parametersWidget.addWidget(view.settingsView) view.parametersWidget.addWidget(view.patternsView) diff --git a/ptychodus/view/patterns.py b/ptychodus/view/patterns.py index a2037b8c..218b9452 100644 --- a/ptychodus/view/patterns.py +++ b/ptychodus/view/patterns.py @@ -65,16 +65,18 @@ def __init__(self, parent: QWidget | None) -> None: super().__init__(parent) self.directoryComboBox = QComboBox() self.fileSystemTableView = QTableView() + self.fileTypeLabel = QLabel('Choose File Type:') self.fileTypeComboBox = QComboBox() @classmethod def createInstance(cls, parent: QWidget | None = None) -> OpenDatasetWizardFilesPage: view = cls(parent) - view.setTitle('Choose File(s)') + view.setTitle('Choose Dataset File(s)') layout = QVBoxLayout() layout.addWidget(view.directoryComboBox) layout.addWidget(view.fileSystemTableView) + layout.addWidget(view.fileTypeLabel) layout.addWidget(view.fileTypeComboBox) view.setLayout(layout) From 1b0f30d01ae930e755f2515b1e2edc4bbcd92a0d Mon Sep 17 00:00:00 2001 From: Steven Henke Date: Wed, 23 Oct 2024 13:41:17 -0500 Subject: [PATCH 9/9] ptycho+xrf GUI complete --- ptychodus/api/fluorescence.py | 6 + ptychodus/controller/probe/core.py | 2 +- ptychodus/controller/probe/fluorescence.py | 142 ++++++--- ptychodus/model/analysis/__init__.py | 2 - ptychodus/model/analysis/core.py | 42 +-- ptychodus/model/analysis/fluorescence.py | 327 --------------------- ptychodus/model/analysis/settings.py | 31 -- ptychodus/model/core.py | 18 +- ptychodus/model/fluorescence/__init__.py | 10 + ptychodus/model/fluorescence/core.py | 223 ++++++++++++++ ptychodus/model/fluorescence/settings.py | 33 +++ ptychodus/model/fluorescence/two_step.py | 110 +++++++ ptychodus/model/fluorescence/vspi.py | 180 ++++++++++++ ptychodus/view/probe.py | 49 ++- 14 files changed, 718 insertions(+), 457 deletions(-) delete mode 100644 ptychodus/model/analysis/fluorescence.py create mode 100644 ptychodus/model/fluorescence/__init__.py create mode 100644 ptychodus/model/fluorescence/core.py create mode 100644 ptychodus/model/fluorescence/settings.py create mode 100644 ptychodus/model/fluorescence/two_step.py create mode 100644 ptychodus/model/fluorescence/vspi.py diff --git a/ptychodus/api/fluorescence.py b/ptychodus/api/fluorescence.py index 4a27f737..b1352e4c 100644 --- a/ptychodus/api/fluorescence.py +++ b/ptychodus/api/fluorescence.py @@ -24,6 +24,12 @@ class FluorescenceDataset: # scan_indexes: IntegerArray +class FluorescenceEnhancingAlgorithm(ABC): + @abstractmethod + def enhance(self, dataset: FluorescenceDataset, product: Product) -> FluorescenceDataset: + pass + + class FluorescenceFileReader(ABC): @abstractmethod def read(self, filePath: Path) -> FluorescenceDataset: diff --git a/ptychodus/controller/probe/core.py b/ptychodus/controller/probe/core.py index b16b76fb..509b7ef9 100644 --- a/ptychodus/controller/probe/core.py +++ b/ptychodus/controller/probe/core.py @@ -8,10 +8,10 @@ from ...model.analysis import ( ExposureAnalyzer, - FluorescenceEnhancer, ProbePropagator, STXMSimulator, ) +from ...model.fluorescence import FluorescenceEnhancer from ...model.product import ProbeAPI, ProbeRepository from ...model.product.probe import ProbeRepositoryItem from ...model.visualization import VisualizationEngine diff --git a/ptychodus/controller/probe/fluorescence.py b/ptychodus/controller/probe/fluorescence.py index 236f1e7b..a2a49c1b 100644 --- a/ptychodus/controller/probe/fluorescence.py +++ b/ptychodus/controller/probe/fluorescence.py @@ -1,13 +1,23 @@ -from typing import Any +from decimal import Decimal +from typing import Any, Final import logging from PyQt5.QtCore import Qt, QAbstractListModel, QModelIndex, QObject, QStringListModel +from PyQt5.QtWidgets import QWidget from ptychodus.api.observer import Observable, Observer -from ...model.analysis import FluorescenceEnhancer +from ...model.fluorescence import ( + FluorescenceEnhancer, + TwoStepFluorescenceEnhancingAlgorithm, + VSPIFluorescenceEnhancingAlgorithm, +) from ...model.visualization import VisualizationEngine -from ...view.probe import FluorescenceDialog +from ...view.probe import ( + FluorescenceDialog, + FluorescenceTwoStepParametersView, + FluorescenceVSPIParametersView, +) from ...view.widgets import ExceptionDialog from ..data import FileDialogFactory from ..visualization import ( @@ -33,6 +43,71 @@ def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: return self._enhancer.getNumberOfChannels() +class FluorescenceTwoStepViewController(Observer): + def __init__(self, algorithm: TwoStepFluorescenceEnhancingAlgorithm) -> None: + super().__init__() + self._algorithm = algorithm + self._view = FluorescenceTwoStepParametersView() + + self._upscalingModel = QStringListModel() + self._upscalingModel.setStringList(self._algorithm.getUpscalingStrategyList()) + self._view.upscalingStrategyComboBox.setModel(self._upscalingModel) + self._view.upscalingStrategyComboBox.textActivated.connect(algorithm.setUpscalingStrategy) + + self._deconvolutionModel = QStringListModel() + self._deconvolutionModel.setStringList(self._algorithm.getDeconvolutionStrategyList()) + self._view.deconvolutionStrategyComboBox.setModel(self._deconvolutionModel) + self._view.deconvolutionStrategyComboBox.textActivated.connect( + algorithm.setDeconvolutionStrategy + ) + + self._syncModelToView() + algorithm.addObserver(self) + + def getWidget(self) -> QWidget: + return self._view + + def _syncModelToView(self) -> None: + self._view.upscalingStrategyComboBox.setCurrentText(self._algorithm.getUpscalingStrategy()) + self._view.deconvolutionStrategyComboBox.setCurrentText( + self._algorithm.getDeconvolutionStrategy() + ) + + def update(self, observable: Observable) -> None: + if observable is self._algorithm: + self._syncModelToView() + + +class FluorescenceVSPIViewController(Observer): + MAX_INT: Final[int] = 0x7FFFFFFF + + def __init__(self, algorithm: VSPIFluorescenceEnhancingAlgorithm) -> None: + super().__init__() + self._algorithm = algorithm + self._view = FluorescenceVSPIParametersView() + + self._view.dampingFactorLineEdit.valueChanged.connect(self._syncDampingFactorToModel) + self._view.maxIterationsSpinBox.setRange(1, self.MAX_INT) + self._view.maxIterationsSpinBox.valueChanged.connect(algorithm.setMaxIterations) + + algorithm.addObserver(self) + self._syncModelToView() + + def getWidget(self) -> QWidget: + return self._view + + def _syncDampingFactorToModel(self, value: Decimal) -> None: + self._algorithm.setDampingFactor(float(value)) + + def _syncModelToView(self) -> None: + self._view.dampingFactorLineEdit.setValue(Decimal(repr(self._algorithm.getDampingFactor()))) + self._view.maxIterationsSpinBox.setValue(self._algorithm.getMaxIterations()) + + def update(self, observable: Observable) -> None: + if observable is self._algorithm: + self._syncModelToView() + + class FluorescenceViewController(Observer): def __init__( self, @@ -46,43 +121,42 @@ def __init__( self._fileDialogFactory = fileDialogFactory self._dialog = FluorescenceDialog() self._enhancementModel = QStringListModel() - self._enhancementModel.setStringList(self._enhancer.getEnhancementStrategyList()) - # FIXME add vspiDampingFactor - # FIXME add vspiMaxIterations - self._upscalingModel = QStringListModel() - self._upscalingModel.setStringList(self._enhancer.getUpscalingStrategyList()) - self._deconvolutionModel = QStringListModel() - self._deconvolutionModel.setStringList(self._enhancer.getDeconvolutionStrategyList()) + self._enhancementModel.setStringList(self._enhancer.getAlgorithmList()) self._channelListModel = FluorescenceChannelListModel(enhancer) self._dialog.fluorescenceParametersView.openButton.clicked.connect( self._openMeasuredDataset ) - self._dialog.fluorescenceParametersView.enhancementStrategyComboBox.setModel( - self._enhancementModel + twoStepViewController = FluorescenceTwoStepViewController( + enhancer.twoStepEnhancingAlgorithm + ) + self._dialog.fluorescenceParametersView.algorithmComboBox.addItem( + TwoStepFluorescenceEnhancingAlgorithm.DISPLAY_NAME, + self._dialog.fluorescenceParametersView.algorithmComboBox.count(), ) - self._dialog.fluorescenceParametersView.enhancementStrategyComboBox.textActivated.connect( - enhancer.setEnhancementStrategy + self._dialog.fluorescenceParametersView.stackedWidget.addWidget( + twoStepViewController.getWidget() ) - self._dialog.fluorescenceParametersView.upscalingStrategyComboBox.setModel( - self._upscalingModel + vspiViewController = FluorescenceVSPIViewController(enhancer.vspiEnhancingAlgorithm) + self._dialog.fluorescenceParametersView.algorithmComboBox.addItem( + VSPIFluorescenceEnhancingAlgorithm.DISPLAY_NAME, + self._dialog.fluorescenceParametersView.algorithmComboBox.count(), ) - self._dialog.fluorescenceParametersView.upscalingStrategyComboBox.textActivated.connect( - enhancer.setUpscalingStrategy + self._dialog.fluorescenceParametersView.stackedWidget.addWidget( + vspiViewController.getWidget() ) - self._dialog.fluorescenceParametersView.deconvolutionStrategyComboBox.setModel( - self._deconvolutionModel + self._dialog.fluorescenceParametersView.algorithmComboBox.textActivated.connect( + enhancer.setAlgorithm ) - self._dialog.fluorescenceParametersView.deconvolutionStrategyComboBox.textActivated.connect( - enhancer.setDeconvolutionStrategy + self._dialog.fluorescenceParametersView.algorithmComboBox.currentIndexChanged.connect( + self._dialog.fluorescenceParametersView.stackedWidget.setCurrentIndex ) - - self._dialog.fluorescenceChannelListView.setModel(self._channelListModel) - self._dialog.fluorescenceChannelListView.selectionModel().currentChanged.connect( - self._updateView + self._dialog.fluorescenceParametersView.algorithmComboBox.setModel(self._enhancementModel) + self._dialog.fluorescenceParametersView.algorithmComboBox.textActivated.connect( + enhancer.setAlgorithm ) self._dialog.fluorescenceParametersView.enhanceButton.clicked.connect( @@ -92,6 +166,11 @@ def __init__( self._saveEnhancedDataset ) + self._dialog.fluorescenceChannelListView.setModel(self._channelListModel) + self._dialog.fluorescenceChannelListView.selectionModel().currentChanged.connect( + self._updateView + ) + self._measuredWidgetController = VisualizationWidgetController( engine, self._dialog.measuredWidget, @@ -162,16 +241,9 @@ def _saveEnhancedDataset(self) -> None: ExceptionDialog.showException(title, err) def _syncModelToView(self) -> None: - self._dialog.fluorescenceParametersView.enhancementStrategyComboBox.setCurrentText( - self._enhancer.getEnhancementStrategy() - ) - self._dialog.fluorescenceParametersView.upscalingStrategyComboBox.setCurrentText( - self._enhancer.getUpscalingStrategy() + self._dialog.fluorescenceParametersView.algorithmComboBox.setCurrentText( + self._enhancer.getAlgorithm() ) - self._dialog.fluorescenceParametersView.deconvolutionStrategyComboBox.setCurrentText( - self._enhancer.getDeconvolutionStrategy() - ) - self._channelListModel.beginResetModel() self._channelListModel.endResetModel() diff --git a/ptychodus/model/analysis/__init__.py b/ptychodus/model/analysis/__init__.py index aeae296f..250c79ca 100644 --- a/ptychodus/model/analysis/__init__.py +++ b/ptychodus/model/analysis/__init__.py @@ -1,6 +1,5 @@ from .core import AnalysisCore from .exposure import ExposureAnalyzer, ExposureMap -from .fluorescence import FluorescenceEnhancer from .frc import FourierRingCorrelator from .objectInterpolator import ObjectLinearInterpolator from .objectStitcher import ObjectStitcher @@ -12,7 +11,6 @@ 'AnalysisCore', 'ExposureAnalyzer', 'ExposureMap', - 'FluorescenceEnhancer', 'FourierRingCorrelator', 'ObjectLinearInterpolator', 'ObjectStitcher', diff --git a/ptychodus/model/analysis/core.py b/ptychodus/model/analysis/core.py index 7225d657..baf8d098 100644 --- a/ptychodus/model/analysis/core.py +++ b/ptychodus/model/analysis/core.py @@ -1,23 +1,14 @@ -from pathlib import Path import logging -from ptychodus.api.fluorescence import ( - DeconvolutionStrategy, - FluorescenceFileReader, - FluorescenceFileWriter, - UpscalingStrategy, -) -from ptychodus.api.plugins import PluginChooser from ptychodus.api.settings import SettingsRegistry from ..product import ObjectRepository, ProductRepository from ..reconstructor import DiffractionPatternPositionMatcher from ..visualization import VisualizationEngine from .exposure import ExposureAnalyzer -from .fluorescence import FluorescenceEnhancer from .frc import FourierRingCorrelator from .propagator import ProbePropagator -from .settings import FluorescenceSettings, ProbePropagationSettings +from .settings import ProbePropagationSettings from .stxm import STXMSimulator from .xmcd import XMCDAnalyzer @@ -31,10 +22,6 @@ def __init__( dataMatcher: DiffractionPatternPositionMatcher, productRepository: ProductRepository, objectRepository: ObjectRepository, - upscalingStrategyChooser: PluginChooser[UpscalingStrategy], - deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], - fluorescenceFileReaderChooser: PluginChooser[FluorescenceFileReader], - fluorescenceFileWriterChooser: PluginChooser[FluorescenceFileWriter], ) -> None: self.stxmSimulator = STXMSimulator(dataMatcher) self.stxmVisualizationEngine = VisualizationEngine(isComplex=False) @@ -46,32 +33,5 @@ def __init__( self.exposureVisualizationEngine = VisualizationEngine(isComplex=False) self.fourierRingCorrelator = FourierRingCorrelator(objectRepository) - self._fluorescenceSettings = FluorescenceSettings(settingsRegistry) - self.fluorescenceEnhancer = FluorescenceEnhancer( - self._fluorescenceSettings, - productRepository, - upscalingStrategyChooser, - deconvolutionStrategyChooser, - fluorescenceFileReaderChooser, - fluorescenceFileWriterChooser, - settingsRegistry, - ) - self.fluorescenceVisualizationEngine = VisualizationEngine(isComplex=False) self.xmcdAnalyzer = XMCDAnalyzer(objectRepository) self.xmcdVisualizationEngine = VisualizationEngine(isComplex=False) - - def enhanceFluorescence( - self, productIndex: int, inputFilePath: Path, outputFilePath: Path - ) -> int: - fileType = 'XRF-Maps' - - try: - self.fluorescenceEnhancer.setProduct(productIndex) - self.fluorescenceEnhancer.openMeasuredDataset(inputFilePath, fileType) - self.fluorescenceEnhancer.enhanceFluorescence() - self.fluorescenceEnhancer.saveEnhancedDataset(outputFilePath, fileType) - except Exception as exc: - logger.exception(exc) - return -1 - - return 0 diff --git a/ptychodus/model/analysis/fluorescence.py b/ptychodus/model/analysis/fluorescence.py deleted file mode 100644 index 29812299..00000000 --- a/ptychodus/model/analysis/fluorescence.py +++ /dev/null @@ -1,327 +0,0 @@ -from __future__ import annotations -from collections.abc import Sequence -from pathlib import Path -from typing import Final -import logging -import time - -from scipy.sparse.linalg import lsmr, LinearOperator -import numpy - -from ptychodus.api.fluorescence import ( - DeconvolutionStrategy, - ElementMap, - FluorescenceDataset, - FluorescenceFileReader, - FluorescenceFileWriter, - UpscalingStrategy, -) -from ptychodus.api.geometry import PixelGeometry -from ptychodus.api.object import ObjectPoint -from ptychodus.api.observer import Observable, Observer -from ptychodus.api.plugins import PluginChooser -from ptychodus.api.product import Product -from ptychodus.api.typing import RealArrayType - -from ..product import ProductRepository -from .settings import FluorescenceSettings - -logger = logging.getLogger(__name__) - - -class ArrayPatchInterpolator: - def __init__(self, array: RealArrayType, point: ObjectPoint, shape: tuple[int, ...]) -> None: - # top left corner of object support - xmin = point.positionXInPixels - shape[-1] / 2 - ymin = point.positionYInPixels - shape[-2] / 2 - - # whole components (pixel indexes) - xmin_wh = int(xmin) - ymin_wh = int(ymin) - - # fractional (subpixel) components - xmin_fr = xmin - xmin_wh - ymin_fr = ymin - ymin_wh - - # bottom right corner of object patch support - xmax_wh = xmin_wh + shape[-1] + 1 - ymax_wh = ymin_wh + shape[-2] + 1 - - # reused quantities - xmin_fr_c = 1.0 - xmin_fr - ymin_fr_c = 1.0 - ymin_fr - - # barycentric interpolant weights - self._weight00 = ymin_fr_c * xmin_fr_c - self._weight01 = ymin_fr_c * xmin_fr - self._weight10 = ymin_fr * xmin_fr_c - self._weight11 = ymin_fr * xmin_fr - - # extract patch support region from full object - self._support = array[ymin_wh:ymax_wh, xmin_wh:xmax_wh] - - def get_patch(self) -> RealArrayType: - """interpolate array support to extract patch""" - patch = self._weight00 * self._support[:-1, :-1] - patch += self._weight01 * self._support[:-1, 1:] - patch += self._weight10 * self._support[1:, :-1] - patch += self._weight11 * self._support[1:, 1:] - return patch - - def accumulate_patch(self, patch: RealArrayType) -> None: - """add patch update to array support""" - self._support[:-1, :-1] += self._weight00 * patch - self._support[:-1, 1:] += self._weight01 * patch - self._support[1:, :-1] += self._weight10 * patch - self._support[1:, 1:] += self._weight11 * patch - - -class VSPILinearOperator(LinearOperator): - def __init__(self, product: Product) -> None: - """ - M: number of XRF positions - N: number of ptychography object pixels - P: number of XRF channels - - A[M,N] * X[N,P] = B[M,P] - """ - object_geometry = product.object_.getGeometry() - M = len(product.scan) - N = object_geometry.heightInPixels * object_geometry.widthInPixels - super().__init__(float, (M, N)) - self._product = product - - def _get_psf(self) -> RealArrayType: - intensity = self._product.probe.getIntensity() - return intensity / intensity.sum() - - def _matvec(self, X: RealArrayType) -> RealArrayType: - object_geometry = self._product.object_.getGeometry() - object_array = X.reshape((object_geometry.heightInPixels, object_geometry.widthInPixels)) - psf = self._get_psf() - AX = numpy.zeros(len(self._product.scan)) - - for index, scan_point in enumerate(self._product.scan): - object_point = object_geometry.mapScanPointToObjectPoint(scan_point) - interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) - AX[index] = numpy.sum(psf * interpolator.get_patch()) - - return AX - - def _rmatvec(self, X: RealArrayType) -> RealArrayType: - object_geometry = self._product.object_.getGeometry() - object_array = numpy.zeros((object_geometry.heightInPixels, object_geometry.widthInPixels)) - psf = self._get_psf() - - for index, scan_point in enumerate(self._product.scan): - object_point = object_geometry.mapScanPointToObjectPoint(scan_point) - interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) - interpolator.accumulate_patch(X[index] * psf) - - HX = object_array.flatten() - - return HX - - -class FluorescenceEnhancer(Observable, Observer): - VSPI: Final[str] = 'Virtual Single Pixel Imaging' - TWO_STEP: Final[str] = 'Upscale and Deconvolve' - - def __init__( - self, - settings: FluorescenceSettings, - productRepository: ProductRepository, - upscalingStrategyChooser: PluginChooser[UpscalingStrategy], - deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], - fileReaderChooser: PluginChooser[FluorescenceFileReader], - fileWriterChooser: PluginChooser[FluorescenceFileWriter], - reinitObservable: Observable, - ) -> None: - super().__init__() - self._settings = settings - self._productRepository = productRepository - self._upscalingStrategyChooser = upscalingStrategyChooser - self._deconvolutionStrategyChooser = deconvolutionStrategyChooser - self._fileReaderChooser = fileReaderChooser - self._fileWriterChooser = fileWriterChooser - self._reinitObservable = reinitObservable - - self._productIndex = -1 - self._measured: FluorescenceDataset | None = None - self._enhanced: FluorescenceDataset | None = None - - upscalingStrategyChooser.addObserver(self) - upscalingStrategyChooser.setCurrentPluginByName(settings.upscalingStrategy.getValue()) - deconvolutionStrategyChooser.addObserver(self) - deconvolutionStrategyChooser.setCurrentPluginByName( - settings.deconvolutionStrategy.getValue() - ) - fileReaderChooser.setCurrentPluginByName(settings.fileType.getValue()) - fileWriterChooser.setCurrentPluginByName(settings.fileType.getValue()) - reinitObservable.addObserver(self) - - def setProduct(self, productIndex: int) -> None: - if self._productIndex != productIndex: - self._productIndex = productIndex - self._enhanced = None - self.notifyObservers() - - def getProductName(self) -> str: - return self._productRepository[self._productIndex].getName() - - def getOpenFileFilterList(self) -> Sequence[str]: - return self._fileReaderChooser.getDisplayNameList() - - def getOpenFileFilter(self) -> str: - return self._fileReaderChooser.currentPlugin.displayName - - def openMeasuredDataset(self, filePath: Path, fileFilter: str) -> None: - if filePath.is_file(): - self._fileReaderChooser.setCurrentPluginByName(fileFilter) - fileType = self._fileReaderChooser.currentPlugin.simpleName - logger.debug(f'Reading "{filePath}" as "{fileType}"') - fileReader = self._fileReaderChooser.currentPlugin.strategy - - try: - measured = fileReader.read(filePath) - except Exception as exc: - raise RuntimeError(f'Failed to read "{filePath}"') from exc - else: - self._measured = measured - self._enhanced = None - - self._settings.filePath.setValue(filePath) - self._settings.fileType.setValue(fileType) - - self.notifyObservers() - else: - logger.warning(f'Refusing to load dataset from invalid file path "{filePath}"') - - def getNumberOfChannels(self) -> int: - return 0 if self._measured is None else len(self._measured.element_maps) - - def getMeasuredElementMap(self, channelIndex: int) -> ElementMap: - if self._measured is None: - raise ValueError('Fluorescence dataset not loaded!') - - return self._measured.element_maps[channelIndex] - - def getEnhancementStrategyList(self) -> Sequence[str]: - return [self.VSPI, self.TWO_STEP] - - def getEnhancementStrategy(self) -> str: - return self.VSPI if self._settings.useVSPI.getValue() else self.TWO_STEP - - def setEnhancementStrategy(self, name: str) -> None: - self._settings.useVSPI.setValue(name.casefold() == self.VSPI.casefold()) - - def getUpscalingStrategyList(self) -> Sequence[str]: - return self._upscalingStrategyChooser.getDisplayNameList() - - def getUpscalingStrategy(self) -> str: - return self._upscalingStrategyChooser.currentPlugin.displayName - - def setUpscalingStrategy(self, name: str) -> None: - self._upscalingStrategyChooser.setCurrentPluginByName(name) - - def getDeconvolutionStrategyList(self) -> Sequence[str]: - return self._deconvolutionStrategyChooser.getDisplayNameList() - - def getDeconvolutionStrategy(self) -> str: - return self._deconvolutionStrategyChooser.currentPlugin.displayName - - def setDeconvolutionStrategy(self, name: str) -> None: - self._deconvolutionStrategyChooser.setCurrentPluginByName(name) - - def enhanceFluorescence(self) -> None: - if self._measured is None: - raise ValueError('Fluorescence dataset not loaded!') - - product = self._productRepository[self._productIndex].getProduct() - object_geometry = product.object_.getGeometry() - e_cps_shape = object_geometry.heightInPixels, object_geometry.widthInPixels - element_maps: list[ElementMap] = list() - - if self._settings.useVSPI.getValue(): - A = VSPILinearOperator(product) - - for emap in self._measured.element_maps: - logger.info(f'Enhancing "{emap.name}"...') - tic = time.perf_counter() - m_cps = emap.counts_per_second - result = lsmr( - A, - m_cps.flatten(), - damp=self._settings.vspiDampingFactor.getValue(), - maxiter=self._settings.vspiMaximumIterations.getValue(), - show=True, - ) - logger.debug(result) - e_cps = result[0].reshape(e_cps_shape) - emap_enhanced = ElementMap(emap.name, e_cps) - toc = time.perf_counter() - logger.info(f'Enhanced "{emap.name}" in {toc - tic:.4f} seconds.') - - element_maps.append(emap_enhanced) - else: - upscaler = self._upscalingStrategyChooser.currentPlugin.strategy - deconvolver = self._deconvolutionStrategyChooser.currentPlugin.strategy - - for emap in self._measured.element_maps: - logger.info(f'Enhancing "{emap.name}"...') - tic = time.perf_counter() - emap_upscaled = upscaler(emap, product) - emap_enhanced = deconvolver(emap_upscaled, product) - toc = time.perf_counter() - logger.info(f'Enhanced "{emap.name}" in {toc - tic:.4f} seconds.') - - element_maps.append(emap_enhanced) - - self._enhanced = FluorescenceDataset( - element_maps=element_maps, - counts_per_second_path=self._measured.counts_per_second_path, - channel_names_path=self._measured.channel_names_path, - ) - self.notifyObservers() - - def getPixelGeometry(self) -> PixelGeometry: - return self._productRepository[self._productIndex].getGeometry().getPixelGeometry() - - def getEnhancedElementMap(self, channelIndex: int) -> ElementMap: - if self._enhanced is None: - raise ValueError('Fluorescence dataset not enhanced!') - - return self._enhanced.element_maps[channelIndex] - - def getSaveFileFilterList(self) -> Sequence[str]: - return self._fileWriterChooser.getDisplayNameList() - - def getSaveFileFilter(self) -> str: - return self._fileWriterChooser.currentPlugin.displayName - - def saveEnhancedDataset(self, filePath: Path, fileFilter: str) -> None: - if self._enhanced is None: - raise ValueError('Fluorescence dataset not enhanced!') - - self._fileWriterChooser.setCurrentPluginByName(fileFilter) - fileType = self._fileWriterChooser.currentPlugin.simpleName - logger.debug(f'Writing "{filePath}" as "{fileType}"') - writer = self._fileWriterChooser.currentPlugin.strategy - writer.write(filePath, self._enhanced) - - def _openFluorescenceFileFromSettings(self) -> None: - self.openMeasuredDataset( - self._settings.filePath.getValue(), self._settings.fileType.getValue() - ) - - def update(self, observable: Observable) -> None: - if observable is self._reinitObservable: - self._openFluorescenceFileFromSettings() - elif observable is self._upscalingStrategyChooser: - strategy = self._upscalingStrategyChooser.currentPlugin.simpleName - self._settings.upscalingStrategy.setValue(strategy) - self.notifyObservers() - elif observable is self._deconvolutionStrategyChooser: - strategy = self._deconvolutionStrategyChooser.currentPlugin.simpleName - self._settings.deconvolutionStrategy.setValue(strategy) - self.notifyObservers() diff --git a/ptychodus/model/analysis/settings.py b/ptychodus/model/analysis/settings.py index 2edcac95..ea0fcd44 100644 --- a/ptychodus/model/analysis/settings.py +++ b/ptychodus/model/analysis/settings.py @@ -1,5 +1,3 @@ -from pathlib import Path - from ptychodus.api.observer import Observable, Observer from ptychodus.api.settings import SettingsRegistry @@ -21,32 +19,3 @@ def __init__(self, registry: SettingsRegistry) -> None: def update(self, observable: Observable) -> None: if observable is self._settingsGroup: self.notifyObservers() - - -class FluorescenceSettings(Observable, Observer): - def __init__(self, registry: SettingsRegistry) -> None: - super().__init__() - self._settingsGroup = registry.createGroup('Fluorescence') - self._settingsGroup.addObserver(self) - - self.filePath = self._settingsGroup.createPathParameter( - 'FilePath', Path('/path/to/dataset.h5') - ) - self.fileType = self._settingsGroup.createStringParameter('FileType', 'XRF-Maps') - self.useVSPI = self._settingsGroup.createBooleanParameter('UseVSPI', True) - self.vspiDampingFactor = self._settingsGroup.createRealParameter( - 'VSPIDampingFactor', 0.0, minimum=0.0 - ) - self.vspiMaximumIterations = self._settingsGroup.createIntegerParameter( - 'VSPIMaximumIterations', 100, minimum=1 - ) - self.upscalingStrategy = self._settingsGroup.createStringParameter( - 'UpscalingStrategy', 'Linear' - ) - self.deconvolutionStrategy = self._settingsGroup.createStringParameter( - 'DeconvolutionStrategy', 'Richardson-Lucy' - ) - - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() diff --git a/ptychodus/model/core.py b/ptychodus/model/core.py index 2faa7467..e8502692 100644 --- a/ptychodus/model/core.py +++ b/ptychodus/model/core.py @@ -24,7 +24,6 @@ from .analysis import ( AnalysisCore, ExposureAnalyzer, - FluorescenceEnhancer, FourierRingCorrelator, ProbePropagator, STXMSimulator, @@ -35,6 +34,7 @@ AutomationPresenter, AutomationProcessingPresenter, ) +from .fluorescence import FluorescenceCore, FluorescenceEnhancer from .memory import MemoryPresenter from .patterns import ( Detector, @@ -142,16 +142,20 @@ def __init__( self.ptychonnReconstructorLibrary, ], ) - self._analysisCore = AnalysisCore( + self._fluorescenceCore = FluorescenceCore( self.settingsRegistry, - self._reconstructorCore.dataMatcher, self._productCore.productRepository, - self._productCore.objectRepository, self._pluginRegistry.upscalingStrategies, self._pluginRegistry.deconvolutionStrategies, self._pluginRegistry.fluorescenceFileReaders, self._pluginRegistry.fluorescenceFileWriters, ) + self._analysisCore = AnalysisCore( + self.settingsRegistry, + self._reconstructorCore.dataMatcher, + self._productCore.productRepository, + self._productCore.objectRepository, + ) self._workflowCore = WorkflowCore( self.settingsRegistry, self._patternsCore.patternsAPI, @@ -307,7 +311,7 @@ def batchModeExecute( ) if fluorescenceInputFilePath is not None and fluorescenceOutputFilePath is not None: - self._analysisCore.enhanceFluorescence( + self._fluorescenceCore.enhanceFluorescence( outputProductIndex, fluorescenceInputFilePath, fluorescenceOutputFilePath, @@ -361,11 +365,11 @@ def fourierRingCorrelator(self) -> FourierRingCorrelator: @property def fluorescenceEnhancer(self) -> FluorescenceEnhancer: - return self._analysisCore.fluorescenceEnhancer + return self._fluorescenceCore.enhancer @property def fluorescenceVisualizationEngine(self) -> VisualizationEngine: - return self._analysisCore.fluorescenceVisualizationEngine + return self._fluorescenceCore.visualizationEngine @property def xmcdAnalyzer(self) -> XMCDAnalyzer: diff --git a/ptychodus/model/fluorescence/__init__.py b/ptychodus/model/fluorescence/__init__.py new file mode 100644 index 00000000..dee43986 --- /dev/null +++ b/ptychodus/model/fluorescence/__init__.py @@ -0,0 +1,10 @@ +from .core import FluorescenceCore, FluorescenceEnhancer +from .two_step import TwoStepFluorescenceEnhancingAlgorithm +from .vspi import VSPIFluorescenceEnhancingAlgorithm + +__all__ = [ + 'FluorescenceCore', + 'FluorescenceEnhancer', + 'TwoStepFluorescenceEnhancingAlgorithm', + 'VSPIFluorescenceEnhancingAlgorithm', +] diff --git a/ptychodus/model/fluorescence/core.py b/ptychodus/model/fluorescence/core.py new file mode 100644 index 00000000..02c396b6 --- /dev/null +++ b/ptychodus/model/fluorescence/core.py @@ -0,0 +1,223 @@ +from __future__ import annotations +from collections.abc import Sequence +from pathlib import Path +import logging + + +from ptychodus.api.fluorescence import ( + DeconvolutionStrategy, + ElementMap, + FluorescenceDataset, + FluorescenceEnhancingAlgorithm, + FluorescenceFileReader, + FluorescenceFileWriter, + UpscalingStrategy, +) +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.plugins import PluginChooser +from ptychodus.api.settings import SettingsRegistry + +from ..product import ProductRepository, ProductRepositoryItem +from ..visualization import VisualizationEngine +from .settings import FluorescenceSettings +from .two_step import TwoStepFluorescenceEnhancingAlgorithm +from .vspi import VSPIFluorescenceEnhancingAlgorithm + +logger = logging.getLogger(__name__) + + +class FluorescenceEnhancer(Observable, Observer): + def __init__( + self, + settings: FluorescenceSettings, + productRepository: ProductRepository, + twoStepEnhancingAlgorithm: TwoStepFluorescenceEnhancingAlgorithm, + vspiEnhancingAlgorithm: VSPIFluorescenceEnhancingAlgorithm, + fileReaderChooser: PluginChooser[FluorescenceFileReader], + fileWriterChooser: PluginChooser[FluorescenceFileWriter], + reinitObservable: Observable, + ) -> None: + super().__init__() + self._settings = settings + self._productRepository = productRepository + self.twoStepEnhancingAlgorithm = twoStepEnhancingAlgorithm + self.vspiEnhancingAlgorithm = vspiEnhancingAlgorithm + self._fileReaderChooser = fileReaderChooser + self._fileWriterChooser = fileWriterChooser + self._reinitObservable = reinitObservable + + self._algorithmChooser = PluginChooser[FluorescenceEnhancingAlgorithm]() + self._algorithmChooser.registerPlugin( + twoStepEnhancingAlgorithm, + simpleName=TwoStepFluorescenceEnhancingAlgorithm.SIMPLE_NAME, + displayName=TwoStepFluorescenceEnhancingAlgorithm.DISPLAY_NAME, + ) + self._algorithmChooser.registerPlugin( + vspiEnhancingAlgorithm, + simpleName=VSPIFluorescenceEnhancingAlgorithm.SIMPLE_NAME, + displayName=VSPIFluorescenceEnhancingAlgorithm.DISPLAY_NAME, + ) + self._syncAlgorithmFromSettings() + self._algorithmChooser.addObserver(self) + + self._productIndex = -1 + self._measured: FluorescenceDataset | None = None + self._enhanced: FluorescenceDataset | None = None + + fileReaderChooser.setCurrentPluginByName(settings.fileType.getValue()) + fileWriterChooser.setCurrentPluginByName(settings.fileType.getValue()) + reinitObservable.addObserver(self) + + @property + def _product(self) -> ProductRepositoryItem: + return self._productRepository[self._productIndex] + + def setProduct(self, productIndex: int) -> None: + if self._productIndex != productIndex: + self._productIndex = productIndex + self._enhanced = None + self.notifyObservers() + + def getProductName(self) -> str: + return self._product.getName() + + def getPixelGeometry(self) -> PixelGeometry: + return self._product.getGeometry().getPixelGeometry() + + def getOpenFileFilterList(self) -> Sequence[str]: + return self._fileReaderChooser.getDisplayNameList() + + def getOpenFileFilter(self) -> str: + return self._fileReaderChooser.currentPlugin.displayName + + def openMeasuredDataset(self, filePath: Path, fileFilter: str) -> None: + if filePath.is_file(): + self._fileReaderChooser.setCurrentPluginByName(fileFilter) + fileType = self._fileReaderChooser.currentPlugin.simpleName + logger.debug(f'Reading "{filePath}" as "{fileType}"') + fileReader = self._fileReaderChooser.currentPlugin.strategy + + try: + measured = fileReader.read(filePath) + except Exception as exc: + raise RuntimeError(f'Failed to read "{filePath}"') from exc + else: + self._measured = measured + self._enhanced = None + + self._settings.filePath.setValue(filePath) + self._settings.fileType.setValue(fileType) + + self.notifyObservers() + else: + logger.warning(f'Refusing to load dataset from invalid file path "{filePath}"') + + def getNumberOfChannels(self) -> int: + return 0 if self._measured is None else len(self._measured.element_maps) + + def getMeasuredElementMap(self, channelIndex: int) -> ElementMap: + if self._measured is None: + raise ValueError('Fluorescence dataset not loaded!') + + return self._measured.element_maps[channelIndex] + + def getAlgorithmList(self) -> Sequence[str]: + return self._algorithmChooser.getDisplayNameList() + + def getAlgorithm(self) -> str: + return self._algorithmChooser.currentPlugin.displayName + + def setAlgorithm(self, name: str) -> None: + self._algorithmChooser.setCurrentPluginByName(name) + self._settings.algorithm.setValue(self._algorithmChooser.currentPlugin.simpleName) + + def _syncAlgorithmFromSettings(self) -> None: + self.setAlgorithm(self._settings.algorithm.getValue()) + + def enhanceFluorescence(self) -> None: + if self._measured is None: + raise ValueError('Fluorescence dataset not loaded!') + else: + algorithm = self._algorithmChooser.currentPlugin.strategy + product = self._product.getProduct() + self._enhanced = algorithm.enhance(self._measured, product) + self.notifyObservers() + + def getEnhancedElementMap(self, channelIndex: int) -> ElementMap: + if self._enhanced is None: + return self.getMeasuredElementMap(channelIndex) + + return self._enhanced.element_maps[channelIndex] + + def getSaveFileFilterList(self) -> Sequence[str]: + return self._fileWriterChooser.getDisplayNameList() + + def getSaveFileFilter(self) -> str: + return self._fileWriterChooser.currentPlugin.displayName + + def saveEnhancedDataset(self, filePath: Path, fileFilter: str) -> None: + if self._enhanced is None: + raise ValueError('Fluorescence dataset not enhanced!') + + self._fileWriterChooser.setCurrentPluginByName(fileFilter) + fileType = self._fileWriterChooser.currentPlugin.simpleName + logger.debug(f'Writing "{filePath}" as "{fileType}"') + writer = self._fileWriterChooser.currentPlugin.strategy + writer.write(filePath, self._enhanced) + + def _openFluorescenceFileFromSettings(self) -> None: + self.openMeasuredDataset( + self._settings.filePath.getValue(), self._settings.fileType.getValue() + ) + + def update(self, observable: Observable) -> None: + if observable is self._algorithmChooser: + self.notifyObservers() + elif observable is self._reinitObservable: + self._syncAlgorithmFromSettings() + self._openFluorescenceFileFromSettings() + + +class FluorescenceCore: + def __init__( + self, + settingsRegistry: SettingsRegistry, + productRepository: ProductRepository, + upscalingStrategyChooser: PluginChooser[UpscalingStrategy], + deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], + fileReaderChooser: PluginChooser[FluorescenceFileReader], + fileWriterChooser: PluginChooser[FluorescenceFileWriter], + ) -> None: + self._settings = FluorescenceSettings(settingsRegistry) + self._twoStepEnhancingAlgorithm = TwoStepFluorescenceEnhancingAlgorithm( + self._settings, upscalingStrategyChooser, deconvolutionStrategyChooser, settingsRegistry + ) + self._vspiEnhancingAlgorithm = VSPIFluorescenceEnhancingAlgorithm(self._settings) + + self.enhancer = FluorescenceEnhancer( + self._settings, + productRepository, + self._twoStepEnhancingAlgorithm, + self._vspiEnhancingAlgorithm, + fileReaderChooser, + fileWriterChooser, + settingsRegistry, + ) + self.visualizationEngine = VisualizationEngine(isComplex=False) + + def enhanceFluorescence( + self, productIndex: int, inputFilePath: Path, outputFilePath: Path + ) -> int: + fileType = 'XRF-Maps' + + try: + self.enhancer.setProduct(productIndex) + self.enhancer.openMeasuredDataset(inputFilePath, fileType) + self.enhancer.enhanceFluorescence() + self.enhancer.saveEnhancedDataset(outputFilePath, fileType) + except Exception as exc: + logger.exception(exc) + return -1 + + return 0 diff --git a/ptychodus/model/fluorescence/settings.py b/ptychodus/model/fluorescence/settings.py new file mode 100644 index 00000000..71540f6b --- /dev/null +++ b/ptychodus/model/fluorescence/settings.py @@ -0,0 +1,33 @@ +from pathlib import Path + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.settings import SettingsRegistry + + +class FluorescenceSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._settingsGroup = registry.createGroup('Fluorescence') + self._settingsGroup.addObserver(self) + + self.filePath = self._settingsGroup.createPathParameter( + 'FilePath', Path('/path/to/dataset.h5') + ) + self.fileType = self._settingsGroup.createStringParameter('FileType', 'XRF-Maps') + self.algorithm = self._settingsGroup.createStringParameter('Algorithm', 'VSPI') + self.vspiDampingFactor = self._settingsGroup.createRealParameter( + 'VSPIDampingFactor', 0.0, minimum=0.0 + ) + self.vspiMaxIterations = self._settingsGroup.createIntegerParameter( + 'VSPIMaxIterations', 100, minimum=1 + ) + self.upscalingStrategy = self._settingsGroup.createStringParameter( + 'UpscalingStrategy', 'Linear' + ) + self.deconvolutionStrategy = self._settingsGroup.createStringParameter( + 'DeconvolutionStrategy', 'Richardson-Lucy' + ) + + def update(self, observable: Observable) -> None: + if observable is self._settingsGroup: + self.notifyObservers() diff --git a/ptychodus/model/fluorescence/two_step.py b/ptychodus/model/fluorescence/two_step.py new file mode 100644 index 00000000..28c08e98 --- /dev/null +++ b/ptychodus/model/fluorescence/two_step.py @@ -0,0 +1,110 @@ +from __future__ import annotations +from collections.abc import Sequence +from typing import Final +import logging +import time + +from ptychodus.api.fluorescence import ( + DeconvolutionStrategy, + ElementMap, + FluorescenceDataset, + FluorescenceEnhancingAlgorithm, + UpscalingStrategy, +) +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.plugins import PluginChooser +from ptychodus.api.product import Product + +from .settings import FluorescenceSettings + +logger = logging.getLogger(__name__) + +__all__ = [ + 'TwoStepFluorescenceEnhancingAlgorithm', +] + + +class TwoStepFluorescenceEnhancingAlgorithm(FluorescenceEnhancingAlgorithm, Observable, Observer): + SIMPLE_NAME: Final[str] = 'TwoStep' + DISPLAY_NAME: Final[str] = 'Upscale and Deconvolve' + + def __init__( + self, + settings: FluorescenceSettings, + upscalingStrategyChooser: PluginChooser[UpscalingStrategy], + deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], + reinitObservable: Observable, + ) -> None: + super().__init__() + self._settings = settings + self._upscalingStrategyChooser = upscalingStrategyChooser + self._deconvolutionStrategyChooser = deconvolutionStrategyChooser + self._reinitObservable = reinitObservable + + self._syncUpscalingStrategyFromSettings() + upscalingStrategyChooser.addObserver(self) + + self._syncDeconvolutionStrategyFromSettings() + deconvolutionStrategyChooser.addObserver(self) + + reinitObservable.addObserver(self) + + def enhance(self, dataset: FluorescenceDataset, product: Product) -> FluorescenceDataset: + upscaler = self._upscalingStrategyChooser.currentPlugin.strategy + deconvolver = self._deconvolutionStrategyChooser.currentPlugin.strategy + element_maps: list[ElementMap] = list() + + for emap in dataset.element_maps: + logger.info(f'Enhancing "{emap.name}"...') + tic = time.perf_counter() + emap_upscaled = upscaler(emap, product) + emap_enhanced = deconvolver(emap_upscaled, product) + toc = time.perf_counter() + logger.info(f'Enhanced "{emap.name}" in {toc - tic:.4f} seconds.') + + element_maps.append(emap_enhanced) + + return FluorescenceDataset( + element_maps=element_maps, + counts_per_second_path=dataset.counts_per_second_path, + channel_names_path=dataset.channel_names_path, + ) + + def getUpscalingStrategyList(self) -> Sequence[str]: + return self._upscalingStrategyChooser.getDisplayNameList() + + def getUpscalingStrategy(self) -> str: + return self._upscalingStrategyChooser.currentPlugin.displayName + + def setUpscalingStrategy(self, name: str) -> None: + self._upscalingStrategyChooser.setCurrentPluginByName(name) + self._settings.upscalingStrategy.setValue( + self._upscalingStrategyChooser.currentPlugin.simpleName + ) + + def _syncUpscalingStrategyFromSettings(self) -> None: + self.setUpscalingStrategy(self._settings.upscalingStrategy.getValue()) + + def getDeconvolutionStrategyList(self) -> Sequence[str]: + return self._deconvolutionStrategyChooser.getDisplayNameList() + + def getDeconvolutionStrategy(self) -> str: + return self._deconvolutionStrategyChooser.currentPlugin.displayName + + def setDeconvolutionStrategy(self, name: str) -> None: + self._deconvolutionStrategyChooser.setCurrentPluginByName(name) + self._settings.deconvolutionStrategy.setValue( + self._deconvolutionStrategyChooser.currentPlugin.simpleName + ) + + def _syncDeconvolutionStrategyFromSettings(self) -> None: + self.setDeconvolutionStrategy(self._settings.deconvolutionStrategy.getValue()) + + def update(self, observable: Observable) -> None: + if observable is self._reinitObservable: + self._syncUpscalingStrategyFromSettings() + self._syncDeconvolutionStrategyFromSettings() + elif observable is self._upscalingStrategyChooser: + self.notifyObservers() + elif observable is self._deconvolutionStrategyChooser: + self.notifyObservers() diff --git a/ptychodus/model/fluorescence/vspi.py b/ptychodus/model/fluorescence/vspi.py new file mode 100644 index 00000000..15247c66 --- /dev/null +++ b/ptychodus/model/fluorescence/vspi.py @@ -0,0 +1,180 @@ +from __future__ import annotations +from typing import Final +import logging +import time + +from scipy.sparse.linalg import lsmr, LinearOperator +import numpy + +from ptychodus.api.fluorescence import ( + ElementMap, + FluorescenceDataset, + FluorescenceEnhancingAlgorithm, +) +from ptychodus.api.object import ObjectPoint +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.product import Product +from ptychodus.api.typing import RealArrayType + +from .settings import FluorescenceSettings + +logger = logging.getLogger(__name__) + +__all__ = [ + 'VSPIFluorescenceEnhancingAlgorithm', +] + + +class ArrayPatchInterpolator: + def __init__(self, array: RealArrayType, point: ObjectPoint, shape: tuple[int, ...]) -> None: + # top left corner of object support + xmin = point.positionXInPixels - shape[-1] / 2 + ymin = point.positionYInPixels - shape[-2] / 2 + + # whole components (pixel indexes) + xmin_wh = int(xmin) + ymin_wh = int(ymin) + + # fractional (subpixel) components + xmin_fr = xmin - xmin_wh + ymin_fr = ymin - ymin_wh + + # bottom right corner of object patch support + xmax_wh = xmin_wh + shape[-1] + 1 + ymax_wh = ymin_wh + shape[-2] + 1 + + # reused quantities + xmin_fr_c = 1.0 - xmin_fr + ymin_fr_c = 1.0 - ymin_fr + + # barycentric interpolant weights + self._weight00 = ymin_fr_c * xmin_fr_c + self._weight01 = ymin_fr_c * xmin_fr + self._weight10 = ymin_fr * xmin_fr_c + self._weight11 = ymin_fr * xmin_fr + + # extract patch support region from full object + self._support = array[ymin_wh:ymax_wh, xmin_wh:xmax_wh] + + def get_patch(self) -> RealArrayType: + """interpolate array support to extract patch""" + patch = self._weight00 * self._support[:-1, :-1] + patch += self._weight01 * self._support[:-1, 1:] + patch += self._weight10 * self._support[1:, :-1] + patch += self._weight11 * self._support[1:, 1:] + return patch + + def accumulate_patch(self, patch: RealArrayType) -> None: + """add patch update to array support""" + self._support[:-1, :-1] += self._weight00 * patch + self._support[:-1, 1:] += self._weight01 * patch + self._support[1:, :-1] += self._weight10 * patch + self._support[1:, 1:] += self._weight11 * patch + + +class VSPILinearOperator(LinearOperator): + def __init__(self, product: Product) -> None: + """ + M: number of XRF positions + N: number of ptychography object pixels + P: number of XRF channels + + A[M,N] * X[N,P] = B[M,P] + """ + object_geometry = product.object_.getGeometry() + M = len(product.scan) + N = object_geometry.heightInPixels * object_geometry.widthInPixels + super().__init__(float, (M, N)) + self._product = product + + def _get_psf(self) -> RealArrayType: + intensity = self._product.probe.getIntensity() + return intensity / intensity.sum() + + def _matvec(self, X: RealArrayType) -> RealArrayType: + object_geometry = self._product.object_.getGeometry() + object_array = X.reshape((object_geometry.heightInPixels, object_geometry.widthInPixels)) + psf = self._get_psf() + AX = numpy.zeros(len(self._product.scan)) + + for index, scan_point in enumerate(self._product.scan): + object_point = object_geometry.mapScanPointToObjectPoint(scan_point) + interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) + AX[index] = numpy.sum(psf * interpolator.get_patch()) + + return AX + + def _rmatvec(self, X: RealArrayType) -> RealArrayType: + object_geometry = self._product.object_.getGeometry() + object_array = numpy.zeros((object_geometry.heightInPixels, object_geometry.widthInPixels)) + psf = self._get_psf() + + for index, scan_point in enumerate(self._product.scan): + object_point = object_geometry.mapScanPointToObjectPoint(scan_point) + interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) + interpolator.accumulate_patch(X[index] * psf) + + HX = object_array.flatten() + + return HX + + +class VSPIFluorescenceEnhancingAlgorithm(FluorescenceEnhancingAlgorithm, Observable, Observer): + SIMPLE_NAME: Final[str] = 'VSPI' + DISPLAY_NAME: Final[str] = 'Virtual Single Pixel Imaging' + + def __init__(self, settings: FluorescenceSettings) -> None: + super().__init__() + self._settings = settings + + settings.vspiDampingFactor.addObserver(self) + settings.vspiMaxIterations.addObserver(self) + + def enhance(self, dataset: FluorescenceDataset, product: Product) -> FluorescenceDataset: + object_geometry = product.object_.getGeometry() + e_cps_shape = object_geometry.heightInPixels, object_geometry.widthInPixels + element_maps: list[ElementMap] = list() + A = VSPILinearOperator(product) + + for emap in dataset.element_maps: + logger.info(f'Enhancing "{emap.name}"...') + tic = time.perf_counter() + m_cps = emap.counts_per_second + result = lsmr( + A, + m_cps.flatten(), + damp=self._settings.vspiDampingFactor.getValue(), + maxiter=self._settings.vspiMaxIterations.getValue(), + show=True, + ) + logger.debug(result) + e_cps = result[0].reshape(e_cps_shape) + emap_enhanced = ElementMap(emap.name, e_cps) + toc = time.perf_counter() + logger.info(f'Enhanced "{emap.name}" in {toc - tic:.4f} seconds.') + + element_maps.append(emap_enhanced) + + return FluorescenceDataset( + element_maps=element_maps, + counts_per_second_path=dataset.counts_per_second_path, + channel_names_path=dataset.channel_names_path, + ) + + def getDampingFactor(self) -> float: + return self._settings.vspiDampingFactor.getValue() + + def setDampingFactor(self, factor: float) -> None: + self._settings.vspiDampingFactor.setValue(factor) + + def getMaxIterations(self) -> int: + return self._settings.vspiMaxIterations.getValue() + + def setMaxIterations(self, number: int) -> None: + self._settings.vspiMaxIterations.setValue(number) + + def update(self, observable: Observable) -> None: + if observable is self._settings.vspiDampingFactor: + self.notifyObservers() + elif observable is self._settings.vspiMaxIterations: + self.notifyObservers() diff --git a/ptychodus/view/probe.py b/ptychodus/view/probe.py index 4739359a..ff7cc83d 100644 --- a/ptychodus/view/probe.py +++ b/ptychodus/view/probe.py @@ -13,6 +13,7 @@ QRadioButton, QSlider, QSpinBox, + QStackedWidget, QStatusBar, QVBoxLayout, QWidget, @@ -171,27 +172,49 @@ def __init__(self, parent: QWidget | None = None) -> None: self.setLayout(layout) -class FluorescenceParametersView(QGroupBox): +class FluorescenceVSPIParametersView(QWidget): def __init__(self, parent: QWidget | None = None) -> None: - super().__init__('Parameters', parent) - self.openButton = QPushButton('Open') - self.enhancementStrategyComboBox = QComboBox() - self.vspiDampingFactorLineEdit = DecimalLineEdit.createInstance() - self.vspiMaxIterationsSpinBox = QSpinBox() + super().__init__(parent) + self.dampingFactorLineEdit = DecimalLineEdit.createInstance() + self.maxIterationsSpinBox = QSpinBox() + + layout = QFormLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addRow('Damping Factor:', self.dampingFactorLineEdit) + layout.addRow('Max Iterations:', self.maxIterationsSpinBox) + self.setLayout(layout) + + +class FluorescenceTwoStepParametersView(QWidget): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) self.upscalingStrategyComboBox = QComboBox() self.deconvolutionStrategyComboBox = QComboBox() - self.enhanceButton = QPushButton('Enhance') - self.saveButton = QPushButton('Save') layout = QFormLayout() - layout.addRow('Measured Dataset:', self.openButton) - layout.addRow('Enhancement Strategy:', self.enhancementStrategyComboBox) - layout.addRow('VSPI Damping Factor:', self.vspiDampingFactorLineEdit) - layout.addRow('VSPI Max Iterations:', self.vspiMaxIterationsSpinBox) + layout.setContentsMargins(0, 0, 0, 0) layout.addRow('Upscaling Strategy:', self.upscalingStrategyComboBox) layout.addRow('Deconvolution Strategy:', self.deconvolutionStrategyComboBox) + self.setLayout(layout) + + +class FluorescenceParametersView(QGroupBox): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__('Enhancement Strategy', parent) + self.openButton = QPushButton('Open Measured Dataset') + self.algorithmComboBox = QComboBox() + self.stackedWidget = QStackedWidget() + self.enhanceButton = QPushButton('Enhance') + self.saveButton = QPushButton('Save Enhanced Dataset') + + self.stackedWidget.layout().setContentsMargins(0, 0, 0, 0) + + layout = QFormLayout() + layout.addRow(self.openButton) + layout.addRow('Algorithm:', self.algorithmComboBox) + layout.addRow(self.stackedWidget) layout.addRow(self.enhanceButton) - layout.addRow('Enhanced Dataset:', self.saveButton) + layout.addRow(self.saveButton) self.setLayout(layout)