diff --git a/src/snapred/backend/dao/request/CreateArtificialNormalizationRequest.py b/src/snapred/backend/dao/request/CreateArtificialNormalizationRequest.py new file mode 100644 index 000000000..f8792ea94 --- /dev/null +++ b/src/snapred/backend/dao/request/CreateArtificialNormalizationRequest.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel, root_validator + +from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName + + +class CreateArtificialNormalizationRequest(BaseModel): + runNumber: str + useLiteMode: bool + peakWindowClippingSize: int + smoothingParameter: float + decreaseParameter: bool = True + lss: bool = True + diffractionWorkspace: WorkspaceName + outputWorkspace: WorkspaceName = None + + @root_validator(pre=True) + def set_output_workspace(cls, values): + if values.get("diffractionWorkspace") and not values.get("outputWorkspace"): + values["outputWorkspace"] = WorkspaceName(f"{values['diffractionWorkspace']}_artificialNorm") + return values + + class Config: + arbitrary_types_allowed = True # Allow arbitrary types like WorkspaceName + extra = "forbid" # Forbid extra fields + validate_assignment = True # Enable dynamic validation diff --git a/src/snapred/backend/dao/request/ReductionRequest.py b/src/snapred/backend/dao/request/ReductionRequest.py index dd2c1099d..2b77e8480 100644 --- a/src/snapred/backend/dao/request/ReductionRequest.py +++ b/src/snapred/backend/dao/request/ReductionRequest.py @@ -23,6 +23,7 @@ class ReductionRequest(BaseModel): versions: Versions = Versions(None, None) pixelMasks: List[WorkspaceName] = [] + artificialNormalization: Optional[str] = None # TODO: Move to SNAPRequest continueFlags: Optional[ContinueWarning.Type] = ContinueWarning.Type.UNSET diff --git a/src/snapred/backend/recipe/GenericRecipe.py b/src/snapred/backend/recipe/GenericRecipe.py index 437ba3ebf..22284c4ce 100644 --- a/src/snapred/backend/recipe/GenericRecipe.py +++ b/src/snapred/backend/recipe/GenericRecipe.py @@ -7,6 +7,7 @@ from snapred.backend.log.logger import snapredLogger from snapred.backend.recipe.algorithm.BufferMissingColumnsAlgo import BufferMissingColumnsAlgo from snapred.backend.recipe.algorithm.CalibrationMetricExtractionAlgorithm import CalibrationMetricExtractionAlgorithm +from snapred.backend.recipe.algorithm.CreateArtificialNormalizationAlgo import CreateArtificialNormalizationAlgo from snapred.backend.recipe.algorithm.DetectorPeakPredictor import DetectorPeakPredictor from snapred.backend.recipe.algorithm.FitMultiplePeaksAlgorithm import FitMultiplePeaksAlgorithm from snapred.backend.recipe.algorithm.FocusSpectraAlgorithm import FocusSpectraAlgorithm @@ -104,3 +105,7 @@ class ConvertTableToMatrixWorkspaceRecipe(GenericRecipe[ConvertTableToMatrixWork class BufferMissingColumnsRecipe(GenericRecipe[BufferMissingColumnsAlgo]): pass + + +class ArtificialNormalizationRecipe(GenericRecipe[CreateArtificialNormalizationAlgo]): + pass diff --git a/src/snapred/backend/service/ReductionService.py b/src/snapred/backend/service/ReductionService.py index 84d09cdbf..a44280b5f 100644 --- a/src/snapred/backend/service/ReductionService.py +++ b/src/snapred/backend/service/ReductionService.py @@ -3,9 +3,14 @@ from pathlib import Path from typing import Any, Dict, List -from snapred.backend.dao.ingredients import GroceryListItem, ReductionIngredients +from snapred.backend.dao.ingredients import ( + ArtificialNormalizationIngredients, + GroceryListItem, + ReductionIngredients, +) from snapred.backend.dao.reduction.ReductionRecord import ReductionRecord from snapred.backend.dao.request import ( + CreateArtificialNormalizationRequest, FarmFreshIngredients, ReductionExportRequest, ReductionRequest, @@ -20,6 +25,7 @@ from snapred.backend.error.StateValidationException import StateValidationException from snapred.backend.log.logger import snapredLogger from snapred.backend.recipe.algorithm.MantidSnapper import MantidSnapper +from snapred.backend.recipe.GenericRecipe import ArtificialNormalizationRecipe from snapred.backend.recipe.ReductionRecipe import ReductionRecipe from snapred.backend.service.Service import Service from snapred.backend.service.SousChef import SousChef @@ -72,6 +78,9 @@ def __init__(self): self.registerPath("checkWritePermissions", self.checkWritePermissions) self.registerPath("getSavePath", self.getSavePath) self.registerPath("getStateIds", self.getStateIds) + self.registerPath("validateReduction", self.validateReduction) + self.registerPath("artificialNormalization", self.artificialNormalization) + self.registerPath("grabDiffractionWorkspaceforArtificialNorm", self.grabDiffractionWorkspaceforArtificialNorm) return @staticmethod @@ -80,45 +89,72 @@ def name(): def validateReduction(self, request: ReductionRequest): """ - Validate the reduction request. + Validate the reduction request, providing specific messages if normalization + or calibration data is missing. Notify the user if artificial normalization + will be created when normalization is absent. :param request: a reduction request :type request: ReductionRequest """ continueFlags = ContinueWarning.Type.UNSET - # check if a normalization is present - if not self.dataFactoryService.normalizationExists(request.runNumber, request.useLiteMode): - continueFlags |= ContinueWarning.Type.MISSING_NORMALIZATION - # check if a diffraction calibration is present - if not self.dataFactoryService.calibrationExists(request.runNumber, request.useLiteMode): + message = "" + + # Check if a normalization is present + normalizationExists = self.dataFactoryService.normalizationExists(request.runNumber, request.useLiteMode) + # Check if a diffraction calibration is present + calibrationExists = self.dataFactoryService.calibrationExists(request.runNumber, request.useLiteMode) + + # Determine the action based on missing components + if not calibrationExists and normalizationExists: + # Case: No calibration but normalization exists continueFlags |= ContinueWarning.Type.MISSING_DIFFRACTION_CALIBRATION + message = ( + "Warning: diffraction calibration is missing." + "If you continue, default instrument geometry will be used." + ) + elif calibrationExists and not normalizationExists: + # Case: Calibration exists but normalization is missing + continueFlags |= ContinueWarning.Type.MISSING_NORMALIZATION + message = ( + "Warning: Reduction is missing normalization data. " + "Artificial normalization will be created in place of actual normalization. " + "Would you like to continue?" + ) + elif not calibrationExists and not normalizationExists: + # Case: No calibration and no normalization + continueFlags |= ( + ContinueWarning.Type.MISSING_DIFFRACTION_CALIBRATION | ContinueWarning.Type.MISSING_NORMALIZATION + ) + message = ( + "Warning: Reduction is missing both normalization and calibration data. " + "If you continue, default instrument geometry will be used and data will be artificially normalized. " + ) - # remove any continue flags that are present in the request by xor-ing with the flags + # Remove any continue flags that are present in the request by XOR-ing with the flags if request.continueFlags: - continueFlags = continueFlags ^ (request.continueFlags & continueFlags) + continueFlags ^= request.continueFlags & continueFlags - if continueFlags: - raise ContinueWarning( - "Reduction is missing calibration data, continue in uncalibrated mode?", continueFlags - ) + # If there are any continue flags set, raise a ContinueWarning with the appropriate message + if continueFlags and message: + raise ContinueWarning(message, continueFlags) - # ... ensure separate continue warnings ... + # Ensure separate continue warnings for permission check continueFlags = ContinueWarning.Type.UNSET - # check that the user has write permissions to the save directory + # Check that the user has write permissions to the save directory if not self.checkWritePermissions(request.runNumber): continueFlags |= ContinueWarning.Type.NO_WRITE_PERMISSIONS - # remove any continue flags that are present in the request by xor-ing with the flags + # Remove any continue flags that are present in the request by XOR-ing with the flags if request.continueFlags: - continueFlags = continueFlags ^ (request.continueFlags & continueFlags) + continueFlags ^= request.continueFlags & continueFlags if continueFlags: raise ContinueWarning( f"
It looks like you don't have permissions to write to "
f"
{self.getSavePath(request.runNumber)},
"
- + "but you can still save using the workbench tools.
Would you like to continue anyway?
", + "but you can still save using the workbench tools." + "Would you like to continue anyway?
", continueFlags, ) @@ -130,7 +166,6 @@ def reduction(self, request: ReductionRequest): :param request: a ReductionRequest object holding needed information :type request: ReductionRequest """ - self.validateReduction(request) groupingResults = self.fetchReductionGroupings(request) request.focusGroups = groupingResults["focusGroups"] @@ -424,3 +459,42 @@ def _groupByVanadiumVersion(self, requests: List[SNAPRequest]): def getCompatibleMasks(self, request: ReductionRequest) -> List[WorkspaceName]: runNumber, useLiteMode = request.runNumber, request.useLiteMode return self.dataFactoryService.getCompatibleReductionMasks(runNumber, useLiteMode) + + def artificialNormalization(self, request: CreateArtificialNormalizationRequest): + ingredients = ArtificialNormalizationIngredients( + peakWindowClippingSize=request.peakWindowClippingSize, + smoothingParameter=request.smoothingParameter, + decreaseParameter=request.decreaseParameter, + lss=request.lss, + ) + artificialNormWorkspace = ArtificialNormalizationRecipe().executeRecipe( + InputWorkspace=request.diffractionWorkspace, + Ingredients=ingredients, + OutputWorkspace=request.outputWorkspace, + ) + return artificialNormWorkspace + + def grabDiffractionWorkspaceforArtificialNorm(self, request: ReductionRequest): + try: + calVersion = None + calVersion = self.dataFactoryService.getThisOrLatestCalibrationVersion( + request.runNumber, request.useLiteMode + ) + calRecord = self.dataFactoryService.getCalibrationRecord(request.runNumber, request.useLiteMode, calVersion) + filePath = self.dataFactoryService.getCalibrationDataPath( + request.runNumber, request.useLiteMode, calVersion + ) + diffCalOutput = calRecord.workspaces[wngt.DIFFCAL_OUTPUT][0] + diffcalOutputFilePath = str(filePath) + "/" + str(diffCalOutput) + ".nxs.h5" + + groceries = self.groceryService.fetchWorkspace(diffcalOutputFilePath, "diffractionWorkspace") + diffractionWorkspace = groceries.get("workspace") + except: # noqa: E722 + raise RuntimeError( + "This feature is not yet implemented. " + "Artificial normalization cannot currently be made for uncalibrated data as we are missing peak positions. " # noqa: E501 + "We are working on a solution to this problem.\n\n " + f"No calibration record found for run number: {request.runNumber}.\n" + "Please create calibration data for this run number and try again." + ) + return diffractionWorkspace diff --git a/src/snapred/ui/presenter/WorkflowPresenter.py b/src/snapred/ui/presenter/WorkflowPresenter.py index 2a6d64321..2b543331d 100644 --- a/src/snapred/ui/presenter/WorkflowPresenter.py +++ b/src/snapred/ui/presenter/WorkflowPresenter.py @@ -141,12 +141,7 @@ def handleSkipButtonClicked(self): def advanceWorkflow(self): if self.view.currentTab >= self.view.totalNodes - 1: - QMessageBox.information( - self.view, - "‧₊Workflow Complete‧₊", - self.completionMessageLambda(), - ) - self.reset() + self.completeWorkflow() else: self.view.advanceWorkflow() @@ -196,3 +191,12 @@ def continueAnyway(self, continueInfo: ContinueWarning.Model): else: raise NotImplementedError(f"Continue anyway handler not implemented: {self.view.tabModel}") self.handleContinueButtonClicked(self.view.tabModel) + + def completeWorkflow(self): + # Directly show the completion message and reset the workflow + QMessageBox.information( + self.view, + "‧₊Workflow Complete‧₊", + self.completionMessageLambda(), + ) + self.reset() diff --git a/src/snapred/ui/view/BackendRequestView.py b/src/snapred/ui/view/BackendRequestView.py index 6def779e2..1644b17f0 100644 --- a/src/snapred/ui/view/BackendRequestView.py +++ b/src/snapred/ui/view/BackendRequestView.py @@ -6,6 +6,7 @@ from snapred.ui.widget.LabeledField import LabeledField from snapred.ui.widget.MultiSelectDropDown import MultiSelectDropDown from snapred.ui.widget.SampleDropDown import SampleDropDown +from snapred.ui.widget.TrueFalseDropDown import TrueFalseDropDown class BackendRequestView(QWidget): @@ -33,6 +34,9 @@ def _labeledCheckBox(self, label): def _sampleDropDown(self, label, items=[]): return SampleDropDown(label, items, self) + def _trueFalseDropDown(self, label): + return TrueFalseDropDown(label, self) + def _multiSelectDropDown(self, label, items=[]): return MultiSelectDropDown(label, items, self) diff --git a/src/snapred/ui/view/reduction/ArtificialNormalizationView.py b/src/snapred/ui/view/reduction/ArtificialNormalizationView.py new file mode 100644 index 000000000..f08c5a5b2 --- /dev/null +++ b/src/snapred/ui/view/reduction/ArtificialNormalizationView.py @@ -0,0 +1,192 @@ +import matplotlib.pyplot as plt +from mantid.plots.datafunctions import get_spectrum +from mantid.simpleapi import mtd +from qtpy.QtCore import Qt, Signal, Slot +from qtpy.QtWidgets import ( + QHBoxLayout, + QLabel, + QLineEdit, + QMessageBox, + QPushButton, +) +from snapred.meta.Config import Config +from snapred.meta.decorators.Resettable import Resettable +from snapred.ui.view.BackendRequestView import BackendRequestView +from snapred.ui.widget.SmoothingSlider import SmoothingSlider +from workbench.plotting.figuremanager import MantidFigureCanvas +from workbench.plotting.toolbar import WorkbenchNavigationToolbar + + +@Resettable +class ArtificialNormalizationView(BackendRequestView): + signalRunNumberUpdate = Signal(str) + signalValueChanged = Signal(float, bool, bool, int) + signalUpdateRecalculationButton = Signal(bool) + signalUpdateFields = Signal(float, bool, bool) + + def __init__(self, parent=None): + super().__init__(parent=parent) + + # create the run number fields + self.fieldRunNumber = self._labeledField("Run Number", QLineEdit()) + + # create the graph elements + self.figure = plt.figure(constrained_layout=True) + self.canvas = MantidFigureCanvas(self.figure) + self.navigationBar = WorkbenchNavigationToolbar(self.canvas, self) + + # create the other specification elements + self.lssDropdown = self._trueFalseDropDown("LSS") + self.decreaseParameterDropdown = self._trueFalseDropDown("Decrease Parameter") + + # disable run number + for x in [self.fieldRunNumber]: + x.setEnabled(False) + + # create the adjustment controls + self.smoothingSlider = self._labeledField("Smoothing", SmoothingSlider()) + self.peakWindowClippingSize = self._labeledField( + "Peak Window Clipping Size", + QLineEdit(str(Config["constants.ArtificialNormalization.peakWindowClippingSize"])), + ) + + peakControlLayout = QHBoxLayout() + peakControlLayout.addWidget(self.smoothingSlider, 2) + peakControlLayout.addWidget(self.peakWindowClippingSize) + + # a big ol recalculate button + self.recalculationButton = QPushButton("Recalculate") + self.recalculationButton.clicked.connect(self.emitValueChange) + + # add all elements to the grid layout + self.layout.addWidget(self.fieldRunNumber, 0, 0) + self.layout.addWidget(self.navigationBar, 1, 0) + self.layout.addWidget(self.canvas, 2, 0, 1, -1) + self.layout.addLayout(peakControlLayout, 3, 0, 1, 2) + self.layout.addWidget(self.lssDropdown, 4, 0) + self.layout.addWidget(self.decreaseParameterDropdown, 4, 1) + self.layout.addWidget(self.recalculationButton, 5, 0, 1, 2) + + self.layout.setRowStretch(2, 10) + + # store the initial layout without graphs + self.initialLayoutHeight = self.size().height() + + self.signalUpdateRecalculationButton.connect(self.setEnableRecalculateButton) + self.signalUpdateFields.connect(self._updateFields) + self.signalRunNumberUpdate.connect(self._updateRunNumber) + + self.messageLabel = QLabel("") + self.messageLabel.setStyleSheet("font-size: 24px; font-weight: bold; color: black;") + self.messageLabel.setAlignment(Qt.AlignCenter) + self.layout.addWidget(self.messageLabel, 0, 0, 1, 2) + self.messageLabel.hide() + + @Slot(str) + def _updateRunNumber(self, runNumber): + self.fieldRunNumber.setText(runNumber) + + def updateRunNumber(self, runNumber): + self.signalRunNumberUpdate.emit(runNumber) + + @Slot(float, bool, bool) + def _updateFields(self, smoothingParameter, lss, decreaseParameter): + self.smoothingSlider.field.setValue(smoothingParameter) + self.lssDropdown.setCurrentIndex(lss) + self.decreaseParameterDropdown.setCurrentIndex(decreaseParameter) + + def updateFields(self, smoothingParameter, lss, decreaseParameter): + self.signalUpdateFields.emit(smoothingParameter, lss, decreaseParameter) + + @Slot() + def emitValueChange(self): + # verify the fields before recalculation + try: + smoothingValue = self.smoothingSlider.field.value() + lss = self.lssDropdown.currentIndex() == "True" + decreaseParameter = self.decreaseParameterDropdown.currentIndex == "True" + peakWindowClippingSize = int(self.peakWindowClippingSize.field.text()) + except ValueError as e: + QMessageBox.warning( + self, + "Invalid Peak Parameters", + f"Smoothing or peak window clipping size is invalid: {str(e)}", + QMessageBox.Ok, + ) + return + self.signalValueChanged.emit(smoothingValue, lss, decreaseParameter, peakWindowClippingSize) + + def updateWorkspaces(self, diffractionWorkspace, artificialNormWorkspace): + self.diffractionWorkspace = diffractionWorkspace + self.artificialNormWorkspace = artificialNormWorkspace + self._updateGraphs() + + def _updateGraphs(self): + # get the updated workspaces and optimal graph grid + diffractionWorkspace = mtd[self.diffractionWorkspace] + artificialNormWorkspace = mtd[self.artificialNormWorkspace] + numGraphs = diffractionWorkspace.getNumberHistograms() + nrows, ncols = self._optimizeRowsAndCols(numGraphs) + + # now re-draw the figure + self.figure.clear() + for i in range(numGraphs): + ax = self.figure.add_subplot(nrows, ncols, i + 1, projection="mantid") + ax.plot(diffractionWorkspace, wkspIndex=i, label="Diffcal Data", normalize_by_bin_width=True) + ax.plot( + artificialNormWorkspace, + wkspIndex=i, + label="Artificial Normalization Data", + normalize_by_bin_width=True, + linestyle="--", + ) + ax.legend() + ax.tick_params(direction="in") + ax.set_title(f"Group ID: {i + 1}") + # fill in the discovered peaks for easier viewing + x, y, _, _ = get_spectrum(diffractionWorkspace, i, normalize_by_bin_width=True) + # for each detected peak in this group, shade in the peak region + + # resize window and redraw + self.setMinimumHeight(self.initialLayoutHeight + int(self.figure.get_size_inches()[1] * self.figure.dpi)) + self.canvas.draw() + + def _optimizeRowsAndCols(self, numGraphs): + # Get best size for layout + sqrtSize = int(numGraphs**0.5) + if sqrtSize == numGraphs**0.5: + rowSize = sqrtSize + colSize = sqrtSize + elif numGraphs <= ((sqrtSize + 1) * sqrtSize): + rowSize = sqrtSize + colSize = sqrtSize + 1 + else: + rowSize = sqrtSize + 1 + colSize = sqrtSize + 1 + return rowSize, colSize + + @Slot(bool) + def setEnableRecalculateButton(self, enable): + self.recalculationButton.setEnabled(enable) + + def disableRecalculateButton(self): + self.signalUpdateRecalculationButton.emit(False) + + def enableRecalculateButton(self): + self.signalUpdateRecalculationButton.emit(True) + + def verify(self): + # TODO what needs to be verified? + return True + + def showMessage(self, message: str): + self.clearView() + self.messageLabel.setText(message) + self.messageLabel.show() + + def clearView(self): + # Remove all existing widgets except the layout + for i in reversed(range(self.layout.count())): + widget = self.layout.itemAt(i).widget() + if widget is not None and widget != self.messageLabel: + widget.deleteLater() # Delete the widget diff --git a/src/snapred/ui/view/reduction/ReductionRequestView.py b/src/snapred/ui/view/reduction/ReductionRequestView.py index be181e2b4..58c1c66cc 100644 --- a/src/snapred/ui/view/reduction/ReductionRequestView.py +++ b/src/snapred/ui/view/reduction/ReductionRequestView.py @@ -136,6 +136,8 @@ def clearRunNumbers(self): def verify(self): currentText = self.runNumberDisplay.toPlainText() runNumbers = [num.strip() for num in currentText.split("\n") if num.strip()] + if not runNumbers: + raise ValueError("Please enter at least one run number.") for runNumber in runNumbers: if not runNumber.isdigit(): raise ValueError( diff --git a/src/snapred/ui/widget/TrueFalseDropDown.py b/src/snapred/ui/widget/TrueFalseDropDown.py new file mode 100644 index 000000000..bb41951f2 --- /dev/null +++ b/src/snapred/ui/widget/TrueFalseDropDown.py @@ -0,0 +1,31 @@ +from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QWidget + + +class TrueFalseDropDown(QWidget): + def __init__(self, label, parent=None): + super(TrueFalseDropDown, self).__init__(parent) + self.setStyleSheet("background-color: #F5E9E2;") + self._label = QLabel(label + ":", self) + + self.dropDown = QComboBox() + self._initItems() + + layout = QHBoxLayout() + layout.addWidget(self._label) + layout.addWidget(self.dropDown) + layout.setContentsMargins(5, 5, 5, 5) + self.setLayout(layout) + + def _initItems(self): + self.dropDown.clear() + self.dropDown.addItems(["True", "False"]) + self.dropDown.setCurrentIndex(0) + + def currentIndex(self): + return self.dropDown.currentIndex() + + def setCurrentIndex(self, index): + self.dropDown.setCurrentIndex(index) + + def currentText(self): + return self.dropDown.currentText() diff --git a/src/snapred/ui/workflow/ReductionWorkflow.py b/src/snapred/ui/workflow/ReductionWorkflow.py index ebc9edb06..f3d4a3b51 100644 --- a/src/snapred/ui/workflow/ReductionWorkflow.py +++ b/src/snapred/ui/workflow/ReductionWorkflow.py @@ -3,6 +3,7 @@ from qtpy.QtCore import Slot from snapred.backend.dao.request import ( + CreateArtificialNormalizationRequest, ReductionExportRequest, ReductionRequest, ) @@ -11,8 +12,8 @@ from snapred.backend.log.logger import snapredLogger from snapred.meta.decorators.ExceptionToErrLog import ExceptionToErrLog from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName +from snapred.ui.view.reduction.ArtificialNormalizationView import ArtificialNormalizationView from snapred.ui.view.reduction.ReductionRequestView import ReductionRequestView -from snapred.ui.view.reduction.ReductionSaveView import ReductionSaveView from snapred.ui.workflow.WorkflowBuilder import WorkflowBuilder from snapred.ui.workflow.WorkflowImplementer import WorkflowImplementer @@ -33,9 +34,7 @@ def __init__(self, parent=None): self._reductionRequestView.enterRunNumberButton.clicked.connect(lambda: self._populatePixelMaskDropdown()) self._reductionRequestView.pixelMaskDropdown.dropDown.view().pressed.connect(self._onPixelMaskSelection) - self._reductionSaveView = ReductionSaveView( - parent=parent, - ) + self._artificialNormalizationView = ArtificialNormalizationView(parent=parent) self.workflow = ( WorkflowBuilder( @@ -51,10 +50,16 @@ def __init__(self, parent=None): "Reduction", continueAnywayHandler=self._continueAnywayHandler, ) + .addNode( + self._continueWithNormalization, + self._artificialNormalizationView, + "Artificial Normalization", + ) .build() ) self._reductionRequestView.retainUnfocusedDataCheckbox.checkedChanged.connect(self._enableConvertToUnits) + self._artificialNormalizationView.signalValueChanged.connect(self.onArtificialNormalizationValueChange) def _enableConvertToUnits(self): state = self._reductionRequestView.retainUnfocusedDataCheckbox.isChecked() @@ -130,11 +135,8 @@ def _validateRunNumbers(self, runNumbers: List[str]): stateIds = self.request(path="reduction/getStateIds", payload=runNumbers).data except Exception as e: # noqa: BLE001 raise ValueError(f"Unable to get instrument state for {runNumbers}: {e}") - if len(stateIds) > 1: - stateId = stateIds[0] - for id_ in stateIds[1:]: - if id_ != stateId: - raise ValueError("all run numbers must be from the same state") + if len(stateIds) > 1 and len(set(stateIds)) > 1: + raise ValueError("All run numbers must be from the same state") def _reconstructPixelMaskNames(self, pixelMasks: List[str]) -> List[WorkspaceName]: return [self._compatibleMasks[name] for name in pixelMasks] @@ -169,26 +171,102 @@ def _triggerReduction(self, workflowPresenter): convertUnitsTo=self._reductionRequestView.convertUnitsDropdown.currentText(), ) - response = self.request(path="reduction/", payload=request_) - if response.code == ResponseCode.OK: - record, unfocusedData = response.data.record, response.data.unfocusedData + # Validate reduction; if artificial normalization is needed, handle it + response = self.request(path="reduction/validateReduction", payload=request_) + if ContinueWarning.Type.MISSING_NORMALIZATION in self.continueAnywayFlags: + self._artificialNormalizationView.updateRunNumber(runNumber) + response = self.request(path="reduction/grabDiffractionWorkspaceforArtificialNorm", payload=request_) + self._artificialNormalization(workflowPresenter, response.data, runNumber) + else: + # Proceed with reduction if artificial normalization is not needed + response = self.request(path="reduction/", payload=request_) + if response.code == ResponseCode.OK: + record, unfocusedData = response.data.record, response.data.unfocusedData + self._finalizeReduction(record, unfocusedData) + self._artificialNormalizationView.updateRunNumber(runNumber) + self._artificialNormalizationView.showMessage("Artificial Normalization not Needed") + workflowPresenter.advanceWorkflow() + return self.responses[-1] + + def _artificialNormalization(self, workflowPresenter, responseData, runNumber): + """Handles artificial normalization for the workflow.""" + view = workflowPresenter.widget.tabView # noqa: F841 + request_ = CreateArtificialNormalizationRequest( + runNumber=runNumber, + useLiteMode=self._reductionRequestView.liteModeToggle.field.getState(), + peakWindowClippingSize=int(self._artificialNormalizationView.peakWindowClippingSize.field.text()), + smoothingParameter=self._artificialNormalizationView.smoothingSlider.field.value(), + decreaseParameter=self._artificialNormalizationView.decreaseParameterDropdown.currentIndex() == 1, + lss=self._artificialNormalizationView.lssDropdown.currentIndex() == 1, + diffractionWorkspace=responseData, + ) + response = self.request(path="reduction/artificialNormalization", payload=request_) + # Update artificial normalization view with the response + if response.code == ResponseCode.OK: + self._artificialNormalizationView.updateWorkspaces(responseData, response.data) + else: + raise RuntimeError("Failed to run artificial normalization.") + + return self.responses[-1] + + @Slot(float, bool, bool, int) + def onArtificialNormalizationValueChange(self, smoothingValue, lss, decreaseParameter, peakWindowClippingSize): + """Updates artificial normalization based on user input.""" + self._artificialNormalizationView.disableRecalculateButton() + runNumber = self._artificialNormalizationView.fieldRunNumber.text() + diffractionWorkspace = self._artificialNormalizationView.diffractionWorkspace + + request_ = CreateArtificialNormalizationRequest( + runNumber=runNumber, + useLiteMode=self._reductionRequestView.liteModeToggle.field.getState(), + peakWindowClippingSize=peakWindowClippingSize, + smoothingParameter=smoothingValue, + decreaseParameter=decreaseParameter, + lss=lss, + diffractionWorkspace=diffractionWorkspace, + ) - # .. update "save" panel message: - self.savePath = self.request(path="reduction/getSavePath", payload=record.runNumber).data + response = self.request(path="reduction/artificialNormalization", payload=request_) + self._artificialNormalizationView.updateWorkspaces(diffractionWorkspace, response.data) + self._artificialNormalizationView.enableRecalculateButton() + + def _continueWithNormalization(self, workflowPresenter): # noqa: ARG002 + """Continues the workflow using the artificial normalization workspace.""" + artificialNormWorkspace = self._artificialNormalizationView.artificialNormWorkspace + pixelMasks = self._reconstructPixelMaskNames(self._reductionRequestView.getPixelMasks()) + timestamp = self.request(path="reduction/getUniqueTimestamp").data - # Save the reduced data. (This is automatic: it happens before the "save" panel opens.) - if ContinueWarning.Type.NO_WRITE_PERMISSIONS not in self.continueAnywayFlags: - self.request(path="reduction/save", payload=ReductionExportRequest(record=record)) + request_ = ReductionRequest( + runNumber=str(self._artificialNormalizationView.fieldRunNumber.text()), + useLiteMode=self._reductionRequestView.liteModeToggle.field.getState(), + timestamp=timestamp, + continueFlags=self.continueAnywayFlags, + pixelMasks=pixelMasks, + keepUnfocused=self._reductionRequestView.retainUnfocusedDataCheckbox.isChecked(), + convertUnitsTo=self._reductionRequestView.convertUnitsDropdown.currentText(), + artificialNormalization=artificialNormWorkspace, + ) - # Retain the output workspaces after the workflow is complete. - self.outputs.extend(record.workspaceNames) + response = self.request(path="reduction/", payload=request_) + if response.code == ResponseCode.OK: + record, unfocusedData = response.data.record, response.data.unfocusedData + self._finalizeReduction(record, unfocusedData) - # Also retain the unfocused data after the workflow is complete (if the box was checked), - # but do not actually save it as part of the reduction-data file. - # The unfocused data does not get added to the response.workspaces list. - if unfocusedData is not None: - self.outputs.append(unfocusedData) + return self.responses[-1] + def _finalizeReduction(self, record, unfocusedData): + """Handles post-reduction tasks, including saving and workspace management.""" + self.savePath = self.request(path="reduction/getSavePath", payload=record.runNumber).data + # Save the reduced data. (This is automatic: it happens before the "save" panel opens.) + if ContinueWarning.Type.NO_WRITE_PERMISSIONS not in self.continueAnywayFlags: + self.request(path="reduction/save", payload=ReductionExportRequest(record=record)) + # Retain the output workspaces after the workflow is complete. + self.outputs.extend(record.workspaceNames) + # Also retain the unfocused data after the workflow is complete (if the box was checked), + # but do not actually save it as part of the reduction-data file. + # The unfocused data does not get added to the response.workspaces list. + if unfocusedData: + self.outputs.append(unfocusedData) # Note that the run number is deliberately not deleted from the run numbers list. # Almost certainly it should be moved to a "completed run numbers" list. @@ -197,8 +275,6 @@ def _triggerReduction(self, workflowPresenter): # TODO: make '_clearWorkspaces' a public method (i.e make this combination a special `cleanup` method). self._clearWorkspaces(exclude=self.outputs, clearCachedWorkspaces=True) - return self.responses[-1] - @property def widget(self): return self.workflow.presenter.widget diff --git a/tests/unit/backend/service/test_ReductionService.py b/tests/unit/backend/service/test_ReductionService.py index c8a157f9c..4790097a3 100644 --- a/tests/unit/backend/service/test_ReductionService.py +++ b/tests/unit/backend/service/test_ReductionService.py @@ -14,6 +14,7 @@ from snapred.backend.dao.ingredients.ReductionIngredients import ReductionIngredients from snapred.backend.dao.reduction.ReductionRecord import ReductionRecord from snapred.backend.dao.request import ( + CreateArtificialNormalizationRequest, ReductionExportRequest, ReductionRequest, ) @@ -300,17 +301,16 @@ def test_validateReduction_no_permissions_and_no_calibrations(self): fakeDataService.calibrationExists.return_value = False fakeDataService.normalizationExists.return_value = False self.instance.dataFactoryService = fakeDataService + fakeExportService = mock.Mock() fakeExportService.checkWritePermissions.return_value = False self.instance.dataExportService = fakeExportService + with pytest.raises(ContinueWarning) as excInfo: self.instance.validateReduction(self.request) - # Note: this tests the _first_ continue-anyway check, - # which _only_ deals with the calibrations. - assert ( - excInfo.value.model.flags - == ContinueWarning.Type.MISSING_DIFFRACTION_CALIBRATION | ContinueWarning.Type.MISSING_NORMALIZATION + assert excInfo.value.model.flags == ( + ContinueWarning.Type.MISSING_DIFFRACTION_CALIBRATION | ContinueWarning.Type.MISSING_NORMALIZATION ) def test_validateReduction_no_permissions_and_no_calibrations_first_reentry(self): @@ -328,8 +328,6 @@ def test_validateReduction_no_permissions_and_no_calibrations_first_reentry(self with pytest.raises(ContinueWarning) as excInfo: self.instance.validateReduction(self.request) - # Note: this tests re-entry for the _first_ continue-anyway check, - # but with no re-entry for the second continue-anyway check. assert excInfo.value.model.flags == ContinueWarning.Type.NO_WRITE_PERMISSIONS def test_validateReduction_no_permissions_and_no_calibrations_second_reentry(self): @@ -350,6 +348,67 @@ def test_validateReduction_no_permissions_and_no_calibrations_second_reentry(sel # and in addition, re-entry for the second continue-anyway check. self.instance.validateReduction(self.request) + @mock.patch(thisService + "ArtificialNormalizationRecipe") + def test_artificialNormalization(self, mockArtificialNormalizationRecipe): + mockArtificialNormalizationRecipe.return_value = mock.Mock() + mockResult = mock.Mock() + mockArtificialNormalizationRecipe.return_value.executeRecipe.return_value = mockResult + + request = CreateArtificialNormalizationRequest( + runNumber="123", + useLiteMode=False, + peakWindowClippingSize=5, + smoothingParameter=0.1, + decreaseParameter=True, + lss=True, + diffractionWorkspace="mock_diffraction_workspace", + outputWorkspace="mock_output_workspace", + ) + + result = self.instance.artificialNormalization(request) + + mockArtificialNormalizationRecipe.return_value.executeRecipe.assert_called_once_with( + InputWorkspace=request.diffractionWorkspace, + Ingredients=mock.ANY, + OutputWorkspace=request.outputWorkspace, + ) + assert result == mockResult + + @mock.patch(thisService + "GroceryService") + @mock.patch(thisService + "DataFactoryService") + def test_grabDiffractionWorkspaceforArtificialNorm(self, mockDataFactoryService, mockGroceryService): + self.instance.groceryService = mockGroceryService + self.instance.dataFactoryService = mockDataFactoryService + + request = ReductionRequest( + runNumber="123", + useLiteMode=False, + timestamp=self.instance.getUniqueTimestamp(), + versions=(1, 2), + pixelMasks=[], + focusGroups=[FocusGroup(name="apple", definition="path/to/grouping")], + ) + + mockCalVersion = 1 + mockDataFactoryService.getThisOrLatestCalibrationVersion = mock.Mock(return_value=mockCalVersion) + + mockCalRecord = mock.Mock() + mockCalRecord.workspaces = {"diffCalOutput": ["mock_diffraction_workspace"]} + + mockDataFactoryService.getCalibrationRecord = mock.Mock(return_value=mockCalRecord) + + mockDataFactoryService.getCalibrationDataPath = mock.Mock(return_value="mock/path/to/calibration") + + mockGroceryService.fetchWorkspace = mock.Mock(return_value={"workspace": "mock_diffraction_workspace"}) + + result = self.instance.grabDiffractionWorkspaceforArtificialNorm(request) + + expected_file_path = "mock/path/to/calibration/mock_diffraction_workspace.nxs.h5" + mockGroceryService.fetchWorkspace.assert_called_once_with(expected_file_path, "diffractionWorkspace") + + # Verify the result + assert result == "mock_diffraction_workspace" + class TestReductionServiceMasks: @pytest.fixture(autouse=True, scope="class")