Source code for darfix.gui.shiftCorrectionWidget
__authors__ = ["J. Garriga"]
__license__ = "MIT"
__date__ = "05/07/2021"
# import os
import numpy
from silx.gui import qt
from silx.gui.colors import Colormap
from silx.gui.plot.StackView import StackViewMainWindow
import darfix
from darfix.core.dataset import Operation
from .operationThread import OperationThread
from .chooseDimensions import ChooseDimensionDock
from darfix import dtypes
from darfix.gui.utils.message import missing_dataset_msg
[docs]
class ShiftCorrectionDialog(qt.QDialog):
"""
Dialog with `ShiftCorrectionWidget` as main window and standard buttons.
"""
okSignal = qt.Signal()
def __init__(self, parent=None):
qt.QDialog.__init__(self, parent)
self.setWindowFlags(qt.Qt.Widget)
types = qt.QDialogButtonBox.Ok
self._buttons = qt.QDialogButtonBox(parent=self)
self._buttons.setStandardButtons(types)
self._buttons.setEnabled(False)
resetB = self._buttons.addButton(self._buttons.Reset)
self.mainWindow = ShiftCorrectionWidget(parent=self)
self.mainWindow.setAttribute(qt.Qt.WA_DeleteOnClose)
self.setLayout(qt.QVBoxLayout())
self.layout().addWidget(self.mainWindow)
self.layout().addWidget(self._buttons)
self._buttons.accepted.connect(self.okSignal.emit)
resetB.clicked.connect(self.mainWindow.resetStack)
self.mainWindow.computingSignal.connect(self._toggleButton)
[docs]
def setDataset(self, owdataset: dtypes.OWDataset) -> None:
if owdataset.dataset is not None:
self._buttons.setEnabled(True)
self.mainWindow.setDataset(owdataset)
def _toggleButton(self, state):
self._buttons.button(qt.QDialogButtonBox.Ok).setEnabled(not state)
[docs]
class ShiftCorrectionWidget(qt.QMainWindow):
"""
A widget to apply shift correction to a stack of images
"""
computingSignal = qt.Signal(bool)
def __init__(self, parent=None):
qt.QMainWindow.__init__(self, parent)
self.setWindowFlags(qt.Qt.Widget)
self._shift = numpy.array([0, 0])
self._filtered_shift = None
self._dimension = None
self._update_dataset = None
self.dataset = None
self.indices = None
self.bg_indices = None
self.bg_dataset = None
self._inputDock = _InputDock()
self._inputDock.widget.correctionB.setEnabled(False)
self._sv = StackViewMainWindow()
self._sv.setColormap(
Colormap(name=darfix.config.DEFAULT_COLORMAP_NAME, normalization="linear")
)
self.setCentralWidget(self._sv)
self._chooseDimensionDock = ChooseDimensionDock(self)
spacer1 = qt.QWidget(parent=self)
spacer1.setLayout(qt.QVBoxLayout())
spacer1.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Expanding)
spacer2 = qt.QWidget(parent=self)
spacer2.setLayout(qt.QVBoxLayout())
spacer2.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Expanding)
self._chooseDimensionDock.widget.layout().addWidget(spacer1)
self._inputDock.widget.layout().addWidget(spacer2)
self._chooseDimensionDock.hide()
self.addDockWidget(qt.Qt.RightDockWidgetArea, self._chooseDimensionDock)
self.addDockWidget(qt.Qt.RightDockWidgetArea, self._inputDock)
self._inputDock.widget.correctionB.clicked.connect(self.correct)
self._inputDock.widget.abortB.clicked.connect(self.abort)
self._inputDock.widget.dxLE.editingFinished.connect(self._updateShiftValue)
self._inputDock.widget.dyLE.editingFinished.connect(self._updateShiftValue)
self._inputDock.widget._findShiftB.clicked.connect(self._findShift)
self._chooseDimensionDock.widget.filterChanged.connect(self._filterStack)
self._chooseDimensionDock.widget.stateDisabled.connect(self._wholeStack)
[docs]
def setDataset(self, owdataset: dtypes.OWDataset):
"""Saves the dataset and updates the stack with the dataset data."""
self.parent = owdataset.parent
self.dataset = owdataset.dataset
self._update_dataset = owdataset.dataset
self.indices = owdataset.indices
self.bg_indices = owdataset.bg_indices
self.bg_dataset = owdataset.bg_dataset
self._inputDock.widget.correctionB.setEnabled(True)
if self.dataset.title != "":
self._sv.setTitleCallback(lambda idx: self.dataset.title)
if len(self.dataset.data.shape) > 3:
self._chooseDimensionDock.show()
self._chooseDimensionDock.widget.setDimensions(self._update_dataset.dims)
if not self._chooseDimensionDock.widget._checkbox.isChecked():
self._wholeStack()
[docs]
def getDataset(self, parent) -> dtypes.OWDataset:
return dtypes.OWDataset(
parent=parent,
dataset=self._update_dataset,
indices=self.indices,
bg_indices=self.bg_indices,
bg_dataset=self.bg_dataset,
)
def _updateDataset(self, widget, dataset):
self.parent._updateDataset(widget, dataset)
self.dataset = dataset
[docs]
def correct(self):
"""
Function that starts the thread to compute the shift given
at the input widget
"""
dx = self._inputDock.widget.getDx()
dy = self._inputDock.widget.getDy()
self.shift = numpy.array([dy, dx])
if self._filtered_shift is None or self._inputDock.widget.checkbox.isChecked():
frames = numpy.arange(
self._update_dataset.get_data(
indices=self.indices, dimension=self._dimension
).shape[0]
)
self.thread_correction = OperationThread(
self, self._update_dataset.apply_shift
)
self.thread_correction.setArgs(
numpy.outer(self.shift, frames), self._dimension, indices=self.indices
)
else:
self.thread_correction = OperationThread(
self, self._update_dataset.apply_shift_along_dimension
)
self.thread_correction.setArgs(
self._filtered_shift, self._dimension[0], indices=self.indices
)
self.thread_correction.finished.connect(self._updateData)
self._inputDock.widget.correctionB.setEnabled(False)
self._inputDock.widget.abortB.show()
self.thread_correction.start()
self.computingSignal.emit(True)
[docs]
def abort(self):
self._inputDock.widget.abortB.setEnabled(False)
self._update_dataset.stop_operation(Operation.SHIFT)
def _findShift(self):
if self.dataset is None:
missing_dataset_msg()
return
if self._filtered_shift is not None:
self.thread_detection = OperationThread(
self, self._update_dataset.find_shift_along_dimension
)
self.thread_detection.setArgs(self._dimension[0], indices=self.indices)
else:
self.thread_detection = OperationThread(
self, self._update_dataset.find_shift
)
self.thread_detection.setArgs(self._dimension, indices=self.indices)
self._inputDock.widget._findShiftB.setEnabled(False)
self.thread_detection.finished.connect(self._updateShift)
self.thread_detection.start()
self.computingSignal.emit(True)
def _updateShiftValue(self):
if self._filtered_shift is not None:
self._filtered_shift[self._dimension[1]] = [
self._inputDock.widget.getDy(),
self._inputDock.widget.getDx(),
]
def _updateShift(self):
self._inputDock.widget._findShiftB.setEnabled(True)
self.thread_detection.finished.disconnect(self._updateShift)
if self._filtered_shift is None:
self.shift = numpy.round(self.thread_detection.data[:, 1], 5)
else:
self._filtered_shift = []
for s in self.thread_detection.data:
try:
self._filtered_shift.append(numpy.round(s[:, 1], 5))
except (IndexError, TypeError):
self._filtered_shift.append([0, 0])
self._filtered_shift = numpy.array(self._filtered_shift)
self.shift = self._filtered_shift[self._dimension[1][0]]
self.computingSignal.emit(False)
def _updateData(self):
"""
Updates the stack with the data computed in the thread
"""
self.thread_correction.finished.disconnect(self._updateData)
self._inputDock.widget.abortB.hide()
self._inputDock.widget.abortB.setEnabled(True)
self._inputDock.widget.correctionB.setEnabled(True)
self.computingSignal.emit(False)
if self.thread_correction.data:
del self._update_dataset
self._update_dataset = self.thread_correction.data
self.thread_correction.data = None
self.thread_correction.func = None
assert self._update_dataset is not None
self.setStack(self._update_dataset)
else:
print("\nCorrection aborted")
[docs]
def setStack(self, dataset=None):
"""
Sets new data to the stack.
Mantains the current frame showed in the view.
:param Dataset dataset: if not None, data set to the stack will be from the given dataset.
"""
if dataset is None:
dataset = self.dataset
nframe = self._sv.getFrameNumber()
data = (
dataset.get_data(self.indices, self._dimension)
if dataset is not None
else None
)
self._sv.setStack(data if len(data) else None)
self._sv.setFrameNumber(nframe)
[docs]
def clearStack(self):
self._sv.setStack(None)
self._inputDock.widget.correctionB.setEnabled(False)
def _filterStack(self, dim=0, val=0):
self._dimension = [dim, val]
data = self._update_dataset.get_data(self.indices, self._dimension)
if self.dataset.dims.ndim == 2:
stack_size = self.dataset.dims.get(dim[0]).size
reset_shift = (
self._filtered_shift is None
or self._filtered_shift.shape[0] != stack_size
)
self._inputDock.widget.checkbox.show()
self._filtered_shift = (
numpy.zeros((stack_size, 2)) if reset_shift else self._filtered_shift
)
self.shift = self._filtered_shift[val[0]]
self._sv.setStack(data if len(data) else None)
def _wholeStack(self):
self._dimension = None
self._filtered_shift = None
self.shift = numpy.array([0, 0])
self._inputDock.widget.checkbox.hide()
self.setStack(self._update_dataset)
[docs]
def getStackViewColormap(self):
"""
Returns the colormap from the stackView
:rtype: silx.gui.colors.Colormap
"""
return self._sv.getColormap()
[docs]
def setStackViewColormap(self, colormap):
"""
Sets the stackView colormap
:param colormap: Colormap to set
:type colormap: silx.gui.colors.Colormap
"""
self._sv.setColormap(colormap)
@property
def shift(self):
return self._shift
@shift.setter
def shift(self, shift):
self._shift = numpy.array(shift)
self._inputDock.widget.setDx(shift[1])
self._inputDock.widget.setDy(shift[0])
class _InputDock(qt.QDockWidget):
def __init__(self, parent=None):
qt.QDockWidget.__init__(self, parent)
self.widget = _InputWidget()
self.setWidget(self.widget)
class _InputWidget(qt.QWidget):
"""
Widget used to obtain the double parameters for the shift correction.
"""
def __init__(self, parent=None):
super(_InputWidget, self).__init__(parent)
self._findShiftB = qt.QPushButton("Find shift")
labelx = qt.QLabel("Horizontal shift:")
labely = qt.QLabel("Vertical shift:")
self.dxLE = qt.QLineEdit("0.0")
self.dyLE = qt.QLineEdit("0.0")
self.correctionB = qt.QPushButton("Correct")
self.abortB = qt.QPushButton("Abort")
self.abortB.hide()
self.checkbox = qt.QCheckBox("Apply only to selected value")
self.checkbox.setChecked(False)
self.checkbox.hide()
self.dxLE.setValidator(qt.QDoubleValidator())
self.dyLE.setValidator(qt.QDoubleValidator())
layout = qt.QGridLayout()
layout.addWidget(self._findShiftB, 0, 0, 1, 2)
layout.addWidget(labelx, 1, 0)
layout.addWidget(labely, 2, 0)
layout.addWidget(self.dxLE, 1, 1)
layout.addWidget(self.dyLE, 2, 1)
layout.addWidget(self.correctionB, 4, 0, 1, 2)
layout.addWidget(self.abortB, 4, 0, 1, 2)
layout.addWidget(self.checkbox, 3, 1)
self.setLayout(layout)
def setDx(self, dx):
"""
Set the shift in the x axis
"""
self.dxLE.setText(str(dx))
self.dxLE.setCursorPosition(0)
def getDx(self):
"""
Get the shift in the x axis
:return float:
"""
return float(self.dxLE.text())
def setDy(self, dy):
"""
Set the shift in the x axis
"""
self.dyLE.setText(str(dy))
self.dyLE.setCursorPosition(0)
def getDy(self):
"""
Get the shift in the y axis
:return float:
"""
return float(self.dyLE.text())