From 4094e72c87ccc518e94ef2fbe6c1c29f1a51bfbb Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 27 Sep 2024 05:15:55 -0400 Subject: [PATCH] Cancel update tasks on workflow cancellation --- temporalio/worker/_workflow_instance.py | 7 +++++++ tests/worker/test_workflow.py | 11 ++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 2cac8c2f..f04502a5 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1851,6 +1851,13 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None: err ): self._add_command().cancel_workflow_execution.SetInParent() + # Cancel update tasks, so that the update caller receives an + # update failed error. We do not currently cancel signal tasks + # since (a) doing so would require a workflow flag and (b) the + # presence of the update caller gives a strong reason to cancel + # update tasks. + for update_handler in self._in_progress_updates.values(): + update_handler.task.cancel("The workflow was cancelled.") elif self._is_workflow_failure_exception(err): # All other failure errors fail the workflow self._set_workflow_failure(err) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index a3e31e66..052a28bf 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -5646,7 +5646,16 @@ async def _run_workflow_and_get_warning(self) -> bool: with pytest.WarningsRecorder() as warnings: if self.handler_type == "-update-": assert update_task - if self.handler_waiting == "-wait-all-handlers-finish-": + + if self.workflow_termination_type == "-cancellation-": + with pytest.raises(WorkflowUpdateFailedError) as update_err: + await update_task + assert isinstance(update_err.value.cause, CancelledError) + assert ( + "the workflow was cancelled" + in str(update_err.value.cause).lower() + ) + elif self.handler_waiting == "-wait-all-handlers-finish-": await update_task else: with pytest.raises(RPCError) as update_err: