Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interactive disk detection #16

Draft
wants to merge 5 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 193 additions & 11 deletions src/py4D_browser/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@
QMenu,
QAction,
QHBoxLayout,
QVBoxLayout,
QSplitter,
QActionGroup,
QLabel,
QPushButton,
QTabWidget,
QDoubleSpinBox,
QSpinBox,
QComboBox,
QCheckBox,
QGraphicsItemGroup,
)

import pyqtgraph as pg
Expand Down Expand Up @@ -53,6 +60,9 @@ class DataViewer(QMainWindow):
nudge_diffraction_selector,
update_annulus_pos,
update_annulus_radii,
update_probe_template_view,
update_kernel_view,
update_disk_detection,
)

HAS_EMPAD2 = importlib.util.find_spec("empad2") is not None
Expand All @@ -71,6 +81,10 @@ def __init__(self, argv):
self.qtapp = QApplication(argv)

self.setWindowTitle("py4DSTEM")

self.central_widget = QWidget()
self.setCentralWidget(self.central_widget)
self.layout = QHBoxLayout(self.central_widget)

icon = QtGui.QIcon(str(Path(__file__).parent.absolute() / "logo.png"))
self.setWindowIcon(icon)
Expand All @@ -80,6 +94,8 @@ def __init__(self, argv):
self.setAcceptDrops(True)

self.datacube = None
self.separate_window = None
self.disk_group = QGraphicsItemGroup()

self.setup_menus()
self.setup_views()
Expand Down Expand Up @@ -370,7 +386,7 @@ def setup_menus(self):
img_ewpc_action.triggered.connect(
partial(self.update_diffraction_space_view, False)
)

self.help_menu = QMenu("&Help", self)
self.menu_bar.addMenu(self.help_menu)

Expand All @@ -391,7 +407,6 @@ def setup_views(self):
self.virtual_detector_point.sigRegionChanged.connect(
partial(self.update_real_space_view, False)
)

# Scalebar
self.diffraction_scale_bar = ScaleBar(pixel_size=1, units="px", width=10)
self.diffraction_scale_bar.setParentItem(
Expand Down Expand Up @@ -446,21 +461,74 @@ def setup_views(self):
self.fft_widget.dragEnterEvent = self.dragEnterEvent
self.fft_widget.dropEvent = self.dropEvent

layout = QHBoxLayout()
layout.addWidget(self.diffraction_space_widget, 1)
# Create a QTabWidget
self.tab_widget = QTabWidget()

# Add the ImageView and QLabel to the first tab
self.tab1 = QWidget()

self.tab1_layout = QHBoxLayout(self.tab1)
self.tab1_layout.addWidget(self.diffraction_space_widget, 1)

# add a resizeable layout for the vimg and FFT
rightside = QSplitter()
rightside.addWidget(self.real_space_widget)
rightside.addWidget(self.fft_widget)

self.tabs_bottomright = QTabWidget()

self.fft_tab = QWidget()
self.fft_tab_layout = QVBoxLayout(self.fft_tab)
self.fft_tab_layout.addWidget(self.fft_widget)

self.probe_view = pg.ImageView()
self.probe_view.setImage(np.zeros((512, 512)))

self.disk_detect_tab = QWidget()
self.disk_detect_tab_layout = QVBoxLayout(self.disk_detect_tab)

self.cross_correlation_layout = QHBoxLayout()

self.probe_template_layout = QVBoxLayout()
self.generate_probe_template_button = QPushButton("Generate probe template...")
self.generate_probe_template_button.clicked.connect(self.update_probe_template_view)
self.probe_template_layout.addWidget(self.generate_probe_template_button)
self.probe_view = pg.ImageView()
self.probe_view.setImage(np.zeros((512, 512)))
self.probe_template_layout.addWidget(self.probe_view)

self.kernel_layout = QVBoxLayout()
self.generate_kernel_button = QPushButton("Generate kernel...")
self.generate_kernel_button.clicked.connect(self.update_kernel_view)
self.kernel_layout.addWidget(self.generate_kernel_button)
self.kernel_view = pg.ImageView()
self.kernel_view.setImage(np.zeros((512, 512)))
self.kernel_layout.addWidget(self.kernel_view)
self.kernel_radius = QDoubleSpinBox()
self.kernel_radius.setPrefix("Kernel radius multiplier: ")
self.kernel_radius.setMinimum(2.0)
self.kernel_radius.setMaximum(10.0)
self.kernel_radius.valueChanged.connect(self.update_kernel_view)
self.kernel_layout.addWidget(self.kernel_radius)

self.cross_correlation_layout.addLayout(self.probe_template_layout)
self.cross_correlation_layout.addLayout(self.kernel_layout)

self.open_window_button = QPushButton("Setup disk detection parameters")
self.open_window_button.clicked.connect(self.open_separate_window)

self.disk_detect_tab_layout.addLayout(self.cross_correlation_layout)
self.disk_detect_tab_layout.addWidget(self.open_window_button)


self.tabs_bottomright.addTab(self.fft_tab, "FFT")
self.tabs_bottomright.addTab(self.disk_detect_tab, "Disk Detection")

rightside.addWidget(self.tabs_bottomright)
rightside.setOrientation(QtCore.Qt.Vertical)
rightside.setStretchFactor(0, 2)
layout.addWidget(rightside, 1)

widget = QWidget()
widget.setLayout(layout)
self.setCentralWidget(widget)


self.tab1_layout.addWidget(rightside, 1)

self.diffraction_space_widget.getView().setMenuEnabled(False)
self.real_space_widget.getView().setMenuEnabled(False)
self.fft_widget.getView().setMenuEnabled(False)
Expand Down Expand Up @@ -489,6 +557,11 @@ def setup_views(self):
self.real_space_widget.autoLevels
)
self.statusBar().addPermanentWidget(self.realspace_rescale_button)

# Make a virtual imaging tab, in the future more tabs can be added for
# different views
self.tab_widget.addTab(self.tab1, "Virtual Imaging")
self.layout.addWidget(self.tab_widget)

# Handle dragging and dropping a file on the window
def dragEnterEvent(self, event):
Expand Down Expand Up @@ -536,3 +609,112 @@ def keyPressEvent(self, event):
-1 if key == QtCore.Qt.Key_J else 1 if key == QtCore.Qt.Key_L else 0
),
)

def open_separate_window(self):
self.separate_window = SeparateWindow()
self.separate_window.on_fit_current_clicked(self.update_disk_detection)

class SeparateWindow(QWidget):
def __init__(self):
super().__init__()

self.layout = QVBoxLayout()
# Create the fields
label = QLabel("Disk Detection Parameters:")
self.layout.addWidget(label)

label = QLabel("Minimum absolute intensity:")
self.layout.addWidget(label)
self.min_intensity = QSpinBox()
self.min_intensity.setMinimum(0)
self.min_intensity.setMaximum(1000)
self.min_intensity.setValue(0)
self.layout.addWidget(self.min_intensity)

label = QLabel("Minimum relative intensity:")
self.layout.addWidget(label)
self.rel_intensity = QDoubleSpinBox()
self.rel_intensity.setMinimum(0.0)
self.rel_intensity.setMaximum(1.0)
self.rel_intensity.setSingleStep(0.001)
self.rel_intensity.setValue(0.005)
self.layout.addWidget(self.rel_intensity)

label = QLabel("Minimum peak spacing (pixels):")
self.layout.addWidget(label)
self.min_peak_spacing = QSpinBox()
self.min_peak_spacing.setMinimum(0)
self.min_peak_spacing.setValue(60)
self.layout.addWidget(self.min_peak_spacing)

label = QLabel("Edge Boundary (pixels):")
self.layout.addWidget(label)
self.edge_boundary = QSpinBox()
self.edge_boundary.setMinimum(0)
self.edge_boundary.setValue(20)
self.layout.addWidget(self.edge_boundary)

label = QLabel("Sigma:")
self.layout.addWidget(label)
self.sigma = QDoubleSpinBox()
self.sigma.setMinimum(0.0)
self.layout.addWidget(self.sigma)

label = QLabel("Maximum number of peaks:")
self.layout.addWidget(label)
self.max_num_peaks = QSpinBox()
self.max_num_peaks.setMinimum(0)
self.max_num_peaks.setValue(70)
self.layout.addWidget(self.max_num_peaks)

label = QLabel("Correlation power:")
self.layout.addWidget(label)
self.corr_power = QDoubleSpinBox()
self.corr_power.setValue(1.0)
self.corr_power.setMaximum(1.0)
self.corr_power.setMinimum(0.0)
self.layout.addWidget(self.corr_power)

label = QLabel("Subpixel:")
self.layout.addWidget(label)
self.subpixel = QComboBox()
self.subpixel.addItems(["none", "poly", "multicorr"])
self.subpixel.setCurrentText("multicorr")
self.layout.addWidget(self.subpixel)

label = QLabel("Check CUDA:")
self.layout.addWidget(label)
self.check_cuda = QCheckBox("Checklist")
self.layout.addWidget(self.check_cuda)

button_layout = QHBoxLayout()
self.run_current = QPushButton("Fit Current View")
self.run_all = LatchingButton('Enable/Disable Disk Detection')
button_layout.addWidget(self.run_current)
button_layout.addWidget(self.run_all)
self.layout.addLayout(button_layout)

self.setWindowTitle("Disk Detection Parameters:")

# Set the layout for the SeparateWindow
self.setLayout(self.layout)
self.show()

def on_fit_current_clicked(self, slot):
self.run_current.clicked.connect(slot)

def get_params_as_dict(self):

params_dict = {
"minAbsoluteIntensity": self.min_intensity.value(),
"minRelativeIntensity": self.rel_intensity.value(),
"minPeakSpacing": self.min_peak_spacing.value(),
"edgeBoundary": self.edge_boundary.value(),
"sigma": self.sigma.value(),
"maxNumPeaks": self.max_num_peaks.value(),
"corrPower": self.corr_power.value(),
"subpixel": self.subpixel.currentText(),
"CUDA": self.check_cuda.isChecked(),
}

return params_dict
97 changes: 96 additions & 1 deletion src/py4D_browser/update_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from py4D_browser.utils import pg_point_roi, make_detector


def update_real_space_view(self, reset=False):
scaling_mode = self.vimg_scaling_group.checkedAction().text().replace("&", "")
assert scaling_mode in ["Linear", "Log", "Square Root"], scaling_mode
Expand Down Expand Up @@ -267,6 +266,15 @@ def update_diffraction_space_view(self, reset=False):
fft.T, autoLevels=False, levels=levels, autoRange=mode_switch
)

# remove ROIs if button is unlatched to return to the default view
if self.disk_group in self.diffraction_space_widget.view.allChildren():
self.diffraction_space_widget.removeItem(self.disk_group)

# Continuously update disk positions as real space ROI moves when button
# is latched
if self.separate_window is not None:
if self.separate_window.run_all.isChecked():
self.update_disk_detection()

def update_realspace_detector(self):
# change the shape of the detector, then update the view
Expand Down Expand Up @@ -484,3 +492,90 @@ def update_annulus_radii(self):
self.virtual_detector_roi_outer.setPos(
x0 - R_inner - 3, y0 - R_inner - 3, update=False
)

def update_probe_template_view(self, _=None):

if hasattr(self, "real_space_point_selector"):
roi = self.real_space_point_selector
if hasattr(self, "real_space_rect_selector"):
roi = self.real_space_rect_selector

pos = roi.pos()
size = roi.size()

# Create a mask that is the same shape as the data and is 0 everywhere
mask = np.zeros(self.datacube.Rshape, dtype=bool)

# Set the region of the mask under the ROI to 1
mask[int(pos[1]):int(pos[1]+size[1]), int(pos[0]):int(pos[0]+size[0])] = 1
print(np.count_nonzero(mask), "pixels in ROI")
# Use current ROI in real space to generate a probe template
self.probe = self.datacube.get_vacuum_probe(mask)

self.probe_view.setImage(self.probe.probe)

self.alpha_pr, self.qx0_pr, self.qy0_pr = self.datacube.get_probe_size(self.probe.probe)
print(f"Probe size: {self.alpha_pr} px"
f" at ({self.qx0_pr}, {self.qy0_pr}) px")

def update_kernel_view(self):
# Update the kernel view
multiplier = self.kernel_radius.value()
self.probe.get_kernel(
mode='sigmoid',
origin=(self.qx0_pr, self.qy0_pr),
radii=(self.alpha_pr, multiplier*self.alpha_pr) # the inner and outer radii of the 'trench'
)

R = 24
kernel = self.probe.kernel
im_kernel = np.vstack(
[
np.hstack([kernel[-int(R) :, -int(R) :], kernel[-int(R) :, : int(R)]]),
np.hstack([kernel[: int(R), -int(R) :], kernel[: int(R), : int(R)]]),
]
)

self.kernel_view.setImage(im_kernel)

def update_disk_detection(self):
"""
Finds Bragg disks for the currently displayed diffraction pattern.
"""

# Remove existing CircleROIs
for item in self.disk_group.childItems():
self.disk_group.removeFromGroup(item)

# take current real space ROI position and turn it into pixel coordinates
# this part was taken from one of the functions above
roi_state = self.real_space_point_selector.saveState()
y0, x0 = roi_state["pos"]
xc, yc = int(x0 + 1), int(y0 + 1)

# Normalize coordinates
xc = np.clip(xc, 0, self.datacube.R_Nx - 1)
yc = np.clip(yc, 0, self.datacube.R_Ny - 1)

# get parameters in dictionary form from the parameter window
detection_params = self.separate_window.get_params_as_dict()

braggpeaks = self.datacube.find_Bragg_disks(
data=(xc, yc),
template=self.probe.kernel,
**detection_params,
)

for qx, qy in zip(braggpeaks.qx, braggpeaks.qy):

disk_roi = pg.CircleROI(
(qy-self.alpha_pr/2+0.5, qx-self.alpha_pr/2+0.5),
(self.alpha_pr, self.alpha_pr),
movable=False,
resizable=False,
pen=pg.mkPen('r', width=2, cosmetic=True)
)
self.disk_group.addToGroup(disk_roi)

self.diffraction_space_widget.addItem(self.disk_group)

Loading