diff --git a/azure/durable_functions/models/DurableOrchestrationClient.py b/azure/durable_functions/models/DurableOrchestrationClient.py index 8bc22786..a94f5722 100644 --- a/azure/durable_functions/models/DurableOrchestrationClient.py +++ b/azure/durable_functions/models/DurableOrchestrationClient.py @@ -442,7 +442,7 @@ async def wait_for_completion_or_create_check_status_response( lambda: self._create_http_response(500, status.to_json()), } - result = switch_statement.get(OrchestrationRuntimeStatus(status.runtime_status)) + result = switch_statement.get(status.runtime_status) if result: return result() @@ -546,3 +546,57 @@ def _get_raise_event_url( request_url += "?" + "&".join(query) return request_url + + async def rewind(self, + instance_id: str, + reason: str, + task_hub_name: Optional[str] = None, + connection_name: Optional[str] = None): + """Return / "rewind" a failed orchestration instance to a prior "healthy" state. + + Parameters + ---------- + instance_id: str + The ID of the orchestration instance to rewind. + reason: str + The reason for rewinding the orchestration instance. + task_hub_name: Optional[str] + The TaskHub of the orchestration to rewind + connection_name: Optional[str] + Name of the application setting containing the storage + connection string to use. + + Raises + ------ + Exception: + In case of a failure, it reports the reason for the exception + """ + request_url: str = "" + if self._orchestration_bindings.rpc_base_url: + path = f"instances/{instance_id}/rewind?reason={reason}" + query: List[str] = [] + if not (task_hub_name is None): + query.append(f"taskHub={task_hub_name}") + if not (connection_name is None): + query.append(f"connection={connection_name}") + if len(query) > 0: + path += "&" + "&".join(query) + + request_url = f"{self._orchestration_bindings.rpc_base_url}" + path + else: + raise Exception("The Python SDK only supports RPC endpoints." + + "Please remove the `localRpcEnabled` setting from host.json") + + response = await self._post_async_request(request_url, None) + status: int = response[0] + if status == 200 or status == 202: + return + elif status == 404: + ex_msg = f"No instance with ID {instance_id} found." + raise Exception(ex_msg) + elif status == 410: + ex_msg = "The rewind operation is only supported on failed orchestration instances." + raise Exception(ex_msg) + else: + ex_msg = response[1] + raise Exception(ex_msg) diff --git a/azure/durable_functions/models/DurableOrchestrationStatus.py b/azure/durable_functions/models/DurableOrchestrationStatus.py index 3dd20a51..dba168c4 100644 --- a/azure/durable_functions/models/DurableOrchestrationStatus.py +++ b/azure/durable_functions/models/DurableOrchestrationStatus.py @@ -1,7 +1,7 @@ from datetime import datetime from dateutil.parser import parse as dt_parse from typing import Any, List, Dict, Optional, Union - +from .OrchestrationRuntimeStatus import OrchestrationRuntimeStatus from .utils.json_utils import add_attrib, add_datetime_attrib @@ -15,7 +15,8 @@ class DurableOrchestrationStatus: def __init__(self, name: Optional[str] = None, instanceId: Optional[str] = None, createdTime: Optional[str] = None, lastUpdatedTime: Optional[str] = None, input: Optional[Any] = None, output: Optional[Any] = None, - runtimeStatus: Optional[str] = None, customStatus: Optional[Any] = None, + runtimeStatus: Optional[OrchestrationRuntimeStatus] = None, + customStatus: Optional[Any] = None, history: Optional[List[Any]] = None, **kwargs): self._name: Optional[str] = name @@ -26,7 +27,9 @@ def __init__(self, name: Optional[str] = None, instanceId: Optional[str] = None, if lastUpdatedTime is not None else None self._input: Any = input self._output: Any = output - self._runtime_status: Optional[str] = runtimeStatus # TODO: GH issue 178 + self._runtime_status: Optional[OrchestrationRuntimeStatus] = runtimeStatus + if runtimeStatus is not None: + self._runtime_status = OrchestrationRuntimeStatus(runtimeStatus) self._custom_status: Any = customStatus self._history: Optional[List[Any]] = history if kwargs is not None: @@ -82,7 +85,8 @@ def to_json(self) -> Dict[str, Union[int, str]]: add_datetime_attrib(json, self, 'last_updated_time', 'lastUpdatedTime') add_attrib(json, self, 'output') add_attrib(json, self, 'input_', 'input') - add_attrib(json, self, 'runtime_status', 'runtimeStatus') + if self.runtime_status is not None: + json["runtimeStatus"] = self.runtime_status.name add_attrib(json, self, 'custom_status', 'customStatus') add_attrib(json, self, 'history') return json @@ -129,7 +133,7 @@ def output(self) -> Any: return self._output @property - def runtime_status(self) -> Optional[str]: + def runtime_status(self) -> Optional[OrchestrationRuntimeStatus]: """Get the runtime status of the orchestration instance.""" return self._runtime_status diff --git a/azure/durable_functions/models/Task.py b/azure/durable_functions/models/Task.py index 67ee0466..c9f1b771 100644 --- a/azure/durable_functions/models/Task.py +++ b/azure/durable_functions/models/Task.py @@ -17,7 +17,7 @@ class Task: """ def __init__(self, is_completed, is_faulted, action, - result=None, timestamp=None, id_=None, exc=None): + result=None, timestamp=None, id_=None, exc=None, is_played=False): self._is_completed: bool = is_completed self._is_faulted: bool = is_faulted self._action: Action = action @@ -25,6 +25,7 @@ def __init__(self, is_completed, is_faulted, action, self._timestamp: datetime = timestamp self._id = id_ self._exception = exc + self._is_played = is_played @property def is_completed(self) -> bool: diff --git a/azure/durable_functions/models/TaskSet.py b/azure/durable_functions/models/TaskSet.py index 7f3d3b75..c9d7546d 100644 --- a/azure/durable_functions/models/TaskSet.py +++ b/azure/durable_functions/models/TaskSet.py @@ -17,13 +17,14 @@ class TaskSet: """ def __init__(self, is_completed, actions, result, is_faulted=False, - timestamp=None, exception=None): + timestamp=None, exception=None, is_played=False): self._is_completed: bool = is_completed self._actions: List[Action] = actions self._result = result self._is_faulted: bool = is_faulted self._timestamp: datetime = timestamp self._exception = exception + self._is_played = is_played @property def is_completed(self) -> bool: diff --git a/azure/durable_functions/models/utils/http_utils.py b/azure/durable_functions/models/utils/http_utils.py index 62d28f21..e45cef68 100644 --- a/azure/durable_functions/models/utils/http_utils.py +++ b/azure/durable_functions/models/utils/http_utils.py @@ -65,5 +65,5 @@ async def delete_async_request(url: str) -> List[Union[int, Any]]: """ async with aiohttp.ClientSession() as session: async with session.delete(url) as response: - data = await response.json() + data = await response.json(content_type=None) return [response.status, data] diff --git a/azure/durable_functions/orchestrator.py b/azure/durable_functions/orchestrator.py index 97705789..9bb06fcf 100644 --- a/azure/durable_functions/orchestrator.py +++ b/azure/durable_functions/orchestrator.py @@ -90,6 +90,7 @@ def handle(self, context: DurableOrchestrationContext): continue self._reset_timestamp() + self.durable_context._is_replaying = generation_state._is_played generation_state = self._generate_next(generation_state) except StopIteration as sie: diff --git a/azure/durable_functions/tasks/call_activity.py b/azure/durable_functions/tasks/call_activity.py index 7ff62b22..c5b094b7 100644 --- a/azure/durable_functions/tasks/call_activity.py +++ b/azure/durable_functions/tasks/call_activity.py @@ -40,6 +40,7 @@ def call_activity_task( is_completed=True, is_faulted=False, action=new_action, + is_played=task_completed._is_played, result=parse_history_event(task_completed), timestamp=task_completed.timestamp, id_=task_completed.TaskScheduledId) @@ -49,6 +50,7 @@ def call_activity_task( is_completed=True, is_faulted=True, action=new_action, + is_played=task_failed._is_played, result=task_failed.Reason, timestamp=task_failed.timestamp, id_=task_failed.TaskScheduledId, diff --git a/azure/durable_functions/tasks/call_activity_with_retry.py b/azure/durable_functions/tasks/call_activity_with_retry.py index f7374871..3a4b1273 100644 --- a/azure/durable_functions/tasks/call_activity_with_retry.py +++ b/azure/durable_functions/tasks/call_activity_with_retry.py @@ -1,14 +1,12 @@ from typing import List, Any -from .task_utilities import find_task_scheduled, \ - find_task_retry_timer_created, set_processed, parse_history_event, \ - find_task_completed, find_task_failed, find_task_retry_timer_fired +from .task_utilities import get_retried_task from ..models.RetryOptions import RetryOptions from ..models.Task import ( Task) from ..models.actions.CallActivityWithRetryAction import \ CallActivityWithRetryAction -from ..models.history import HistoryEvent +from ..models.history import HistoryEvent, HistoryEventType def call_activity_with_retry_task( @@ -37,38 +35,12 @@ def call_activity_with_retry_task( """ new_action = CallActivityWithRetryAction( function_name=name, retry_options=retry_options, input_=input_) - for attempt in range(retry_options.max_number_of_attempts): - task_scheduled = find_task_scheduled(state, name) - task_completed = find_task_completed(state, task_scheduled) - task_failed = find_task_failed(state, task_scheduled) - task_retry_timer = find_task_retry_timer_created(state, task_failed) - task_retry_timer_fired = find_task_retry_timer_fired( - state, task_retry_timer) - set_processed([task_scheduled, task_completed, - task_failed, task_retry_timer, task_retry_timer_fired]) - if not task_scheduled: - break - - if task_completed: - return Task( - is_completed=True, - is_faulted=False, - action=new_action, - result=parse_history_event(task_completed), - timestamp=task_completed.timestamp, - id_=task_completed.TaskScheduledId) - - if task_failed and task_retry_timer and attempt + 1 >= \ - retry_options.max_number_of_attempts: - return Task( - is_completed=True, - is_faulted=True, - action=new_action, - timestamp=task_failed.timestamp, - id_=task_failed.TaskScheduledId, - exc=Exception( - f"{task_failed.Reason} \n {task_failed.Details}") - ) - - return Task(is_completed=False, is_faulted=False, action=new_action) + return get_retried_task( + state=state, + max_number_of_attempts=retry_options.max_number_of_attempts, + scheduled_type=HistoryEventType.TASK_SCHEDULED, + completed_type=HistoryEventType.TASK_COMPLETED, + failed_type=HistoryEventType.TASK_FAILED, + action=new_action + ) diff --git a/azure/durable_functions/tasks/call_http.py b/azure/durable_functions/tasks/call_http.py index 85ec228f..c2c34b3f 100644 --- a/azure/durable_functions/tasks/call_http.py +++ b/azure/durable_functions/tasks/call_http.py @@ -57,6 +57,7 @@ def call_http(state: List[HistoryEvent], method: str, uri: str, content: Optiona is_completed=True, is_faulted=False, action=new_action, + is_played=task_completed._is_played, result=parse_history_event(task_completed), timestamp=task_completed.timestamp, id_=task_completed.TaskScheduledId) @@ -66,6 +67,7 @@ def call_http(state: List[HistoryEvent], method: str, uri: str, content: Optiona is_completed=True, is_faulted=True, action=new_action, + is_played=task_failed._is_played, result=task_failed.Reason, timestamp=task_failed.timestamp, id_=task_failed.TaskScheduledId, diff --git a/azure/durable_functions/tasks/call_suborchestrator.py b/azure/durable_functions/tasks/call_suborchestrator.py index 851ea25c..65a6c6d3 100644 --- a/azure/durable_functions/tasks/call_suborchestrator.py +++ b/azure/durable_functions/tasks/call_suborchestrator.py @@ -48,6 +48,7 @@ def call_sub_orchestrator_task( is_completed=True, is_faulted=False, action=new_action, + is_played=task_completed._is_played, result=parse_history_event(task_completed), timestamp=task_completed.timestamp, id_=task_completed.TaskScheduledId) @@ -57,6 +58,7 @@ def call_sub_orchestrator_task( is_completed=True, is_faulted=True, action=new_action, + is_played=task_failed._is_played, result=task_failed.Reason, timestamp=task_failed.timestamp, id_=task_failed.TaskScheduledId, diff --git a/azure/durable_functions/tasks/call_suborchestrator_with_retry.py b/azure/durable_functions/tasks/call_suborchestrator_with_retry.py index 3be2fa65..e27dd354 100644 --- a/azure/durable_functions/tasks/call_suborchestrator_with_retry.py +++ b/azure/durable_functions/tasks/call_suborchestrator_with_retry.py @@ -4,10 +4,8 @@ Task) from ..models.actions.CallSubOrchestratorWithRetryAction import CallSubOrchestratorWithRetryAction from ..models.RetryOptions import RetryOptions -from ..models.history import HistoryEvent -from .task_utilities import set_processed, parse_history_event, \ - find_sub_orchestration_created, find_sub_orchestration_completed, \ - find_sub_orchestration_failed, find_task_retry_timer_fired, find_task_retry_timer_created +from ..models.history import HistoryEvent, HistoryEventType +from .task_utilities import get_retried_task def call_sub_orchestrator_with_retry_task( @@ -40,40 +38,11 @@ def call_sub_orchestrator_with_retry_task( A Durable Task that completes when the called sub-orchestrator completes or fails. """ new_action = CallSubOrchestratorWithRetryAction(name, retry_options, input_, instance_id) - for attempt in range(retry_options.max_number_of_attempts): - task_scheduled = find_sub_orchestration_created( - state, name, context=context, instance_id=instance_id) - task_completed = find_sub_orchestration_completed(state, task_scheduled) - task_failed = find_sub_orchestration_failed(state, task_scheduled) - task_retry_timer = find_task_retry_timer_created(state, task_failed) - task_retry_timer_fired = find_task_retry_timer_fired( - state, task_retry_timer) - set_processed([task_scheduled, task_completed, - task_failed, task_retry_timer, task_retry_timer_fired]) - - if not task_scheduled: - break - - if task_completed is not None: - return Task( - is_completed=True, - is_faulted=False, - action=new_action, - result=parse_history_event(task_completed), - timestamp=task_completed.timestamp, - id_=task_completed.TaskScheduledId) - - if task_failed and task_retry_timer and attempt + 1 >= \ - retry_options.max_number_of_attempts: - return Task( - is_completed=True, - is_faulted=True, - action=new_action, - result=task_failed.Reason, - timestamp=task_failed.timestamp, - id_=task_failed.TaskScheduledId, - exc=Exception( - f"{task_failed.Reason} \n {task_failed.Details}") - ) - - return Task(is_completed=False, is_faulted=False, action=new_action) + return get_retried_task( + state=state, + max_number_of_attempts=retry_options.max_number_of_attempts, + scheduled_type=HistoryEventType.SUB_ORCHESTRATION_INSTANCE_CREATED, + completed_type=HistoryEventType.SUB_ORCHESTRATION_INSTANCE_COMPLETED, + failed_type=HistoryEventType.SUB_ORCHESTRATION_INSTANCE_FAILED, + action=new_action + ) diff --git a/azure/durable_functions/tasks/create_timer.py b/azure/durable_functions/tasks/create_timer.py index 00da0b61..4d67a4d5 100644 --- a/azure/durable_functions/tasks/create_timer.py +++ b/azure/durable_functions/tasks/create_timer.py @@ -33,7 +33,8 @@ def create_timer_task(state: List[HistoryEvent], return TimerTask( is_completed=True, action=new_action, timestamp=timer_fired.timestamp, - id_=timer_fired.event_id) + id_=timer_fired.event_id, + is_played=timer_fired.is_played) else: return TimerTask( is_completed=False, action=new_action, diff --git a/azure/durable_functions/tasks/task_all.py b/azure/durable_functions/tasks/task_all.py index 29e19e14..0758ae61 100644 --- a/azure/durable_functions/tasks/task_all.py +++ b/azure/durable_functions/tasks/task_all.py @@ -1,6 +1,9 @@ -from typing import List +from datetime import datetime +from typing import List, Optional, Any + from ..models.Task import Task from ..models.TaskSet import TaskSet +from ..models.actions import Action def task_all(tasks: List[Task]): @@ -16,31 +19,57 @@ def task_all(tasks: List[Task]): TaskSet A Durable Task Set that reports the state of running all of the tasks within it. """ - all_actions = [] - results = [] + # Args for constructing the output TaskSet + is_played = True + is_faulted = False is_completed = True - complete_time = None - faulted = [] + + actions: List[Action] = [] + results: List[Any] = [] + + exception: Optional[str] = None + end_time: Optional[datetime] = None + for task in tasks: + # Add actions and results if isinstance(task, TaskSet): - for action in task.actions: - all_actions.append(action) + actions.extend(task.actions) else: - all_actions.append(task.action) + # We know it's an atomic Task + actions.append(task.action) results.append(task.result) - if task.is_faulted: - faulted.append(task.exception) + # Record first exception, if it exists + if task.is_faulted and not is_faulted: + is_faulted = True + exception = task.exception + # If any task is not played, TaskSet is not played + if not task._is_played: + is_played = False + + # If any task is incomplete, TaskSet is incomplete + # If the task is complete, we can update the end_time if not task.is_completed: is_completed = False + elif end_time is None: + end_time = task.timestamp else: - complete_time = task.timestamp if complete_time is None \ - else max([task.timestamp, complete_time]) - - if len(faulted) > 0: - return TaskSet(is_completed, all_actions, results, is_faulted=True, exception=faulted[0]) - if is_completed: - return TaskSet(is_completed, all_actions, results, False, complete_time) - else: - return TaskSet(is_completed, all_actions, None) + end_time = max([task.timestamp, end_time]) + + # Incomplete TaskSets do not have results or end-time + if not is_completed: + results = None + end_time = None + + # Construct TaskSet + taskset = TaskSet( + is_completed=is_completed, + actions=actions, + result=results, + is_faulted=is_faulted, + timestamp=end_time, + exception=exception, + is_played=is_played + ) + return taskset diff --git a/azure/durable_functions/tasks/task_utilities.py b/azure/durable_functions/tasks/task_utilities.py index 99a8d4b5..3c54d776 100644 --- a/azure/durable_functions/tasks/task_utilities.py +++ b/azure/durable_functions/tasks/task_utilities.py @@ -4,6 +4,8 @@ from azure.functions._durable_functions import _deserialize_custom_object from datetime import datetime from typing import List, Optional +from ..models.actions.Action import Action +from ..models.Task import Task def should_suspend(partial_result) -> bool: @@ -410,3 +412,95 @@ def should_preserve(event: HistoryEvent) -> bool: # We should try to refactor this logic at some point event = matches[0] return event + + +def get_retried_task( + state: List[HistoryEvent], max_number_of_attempts: int, scheduled_type: HistoryEventType, + completed_type: HistoryEventType, failed_type: HistoryEventType, action: Action) -> Task: + """Determine the state of scheduling some task for execution with retry options. + + Parameters + ---------- + state: List[HistoryEvent] + The list of history events + max_number_of_ints: int + The maximum number of retrying attempts + scheduled_type: HistoryEventType + The event type corresponding to scheduling the searched-for task + completed_type: HistoryEventType + The event type corresponding to a completion of the searched-for task + failed_type: HistoryEventType + The event type coresponding to the failure of the searched-for task + action: Action + The action corresponding to the searched-for task + + Returns + ------- + Task + A Task encompassing the state of the scheduled work item, that is, + either completed, failed, or incomplete. + """ + # tasks to look for in the state array + scheduled_task, completed_task = None, None + failed_task, scheduled_timer_task = None, None + attempt = 1 + + # Note each case below is exclusive, and the order matters + for event in state: + event_type = HistoryEventType(event.event_type) + + # Skip processed events + if event.is_processed: + continue + + # first we find the scheduled_task + elif scheduled_task is None: + if event_type is scheduled_type: + scheduled_task = event + + # if the task has a correponding completion, we process the events + # and return a completed task + elif event_type == completed_type and \ + event.TaskScheduledId == scheduled_task.event_id: + completed_task = event + set_processed([scheduled_task, completed_task]) + return Task( + is_completed=True, + is_faulted=False, + action=action, + result=parse_history_event(completed_task), + timestamp=completed_task.timestamp, + id_=completed_task.TaskScheduledId + ) + + # if its failed, we'll have to wait for an upcoming timer scheduled + elif failed_task is None: + if event_type is failed_type: + if event.TaskScheduledId == scheduled_task.event_id: + failed_task = event + + # if we have a timer scheduled, we'll have to find a timer fired + elif scheduled_timer_task is None: + if event_type is HistoryEventType.TIMER_CREATED: + scheduled_timer_task = event + + # if we have a timer fired, we check if we still have more attempts for retries. + # If so, we retry again and clear our found events so far. + # If not, we process the events and return a completed task + elif event_type is HistoryEventType.TIMER_FIRED: + if event.TimerId == scheduled_timer_task.event_id: + set_processed([scheduled_task, completed_task, failed_task, scheduled_timer_task]) + if attempt >= max_number_of_attempts: + return Task( + is_completed=True, + is_faulted=True, + action=action, + timestamp=failed_task.timestamp, + id_=failed_task.TaskScheduledId, + exc=Exception( + f"{failed_task.Reason} \n {failed_task.Details}") + ) + else: + scheduled_task, failed_task, scheduled_timer_task = None, None, None + attempt += 1 + return Task(is_completed=False, is_faulted=False, action=action) diff --git a/azure/durable_functions/tasks/timer_task.py b/azure/durable_functions/tasks/timer_task.py index fdd719fe..454e6a9c 100644 --- a/azure/durable_functions/tasks/timer_task.py +++ b/azure/durable_functions/tasks/timer_task.py @@ -15,7 +15,7 @@ class TimerTask(Task): ``` """ - def __init__(self, action: CreateTimerAction, is_completed, timestamp, id_): + def __init__(self, action: CreateTimerAction, is_completed, timestamp, id_, is_played=False): self._action: CreateTimerAction = action self._is_completed = is_completed self._timestamp = timestamp @@ -23,6 +23,7 @@ def __init__(self, action: CreateTimerAction, is_completed, timestamp, id_): super().__init__(self._is_completed, False, self._action, None, self._timestamp, self._id, None) + self._is_played = is_played def is_cancelled(self) -> bool: """Check of a timer is cancelled. diff --git a/azure/durable_functions/tasks/wait_for_external_event.py b/azure/durable_functions/tasks/wait_for_external_event.py index bfcb8342..64645232 100644 --- a/azure/durable_functions/tasks/wait_for_external_event.py +++ b/azure/durable_functions/tasks/wait_for_external_event.py @@ -34,6 +34,7 @@ def wait_for_external_event_task( is_completed=True, is_faulted=False, action=new_action, + is_played=event_raised._is_played, result=parse_history_event(event_raised), timestamp=event_raised.timestamp, id_=event_raised.event_id) diff --git a/tests/models/test_DurableOrchestrationClient.py b/tests/models/test_DurableOrchestrationClient.py index 49fc9ff3..6a877568 100644 --- a/tests/models/test_DurableOrchestrationClient.py +++ b/tests/models/test_DurableOrchestrationClient.py @@ -19,6 +19,9 @@ MESSAGE_500 = 'instance failed with unhandled exception' MESSAGE_501 = "well we didn't expect that" +INSTANCE_ID = "2e2568e7-a906-43bd-8364-c81733c5891e" +REASON = "Stuff" + TEST_ORCHESTRATOR = "MyDurableOrchestrator" EXCEPTION_ORCHESTRATOR_NOT_FOUND_EXMESSAGE = "The function doesn't exist,"\ " is disabled, or is not an orchestrator function. Additional info: "\ @@ -147,7 +150,7 @@ async def test_get_202_get_status_success(binding_string): result = await client.get_status(TEST_INSTANCE_ID) assert result is not None - assert result.runtime_status == "Running" + assert result.runtime_status.name == "Running" @pytest.mark.asyncio @@ -161,7 +164,7 @@ async def test_get_200_get_status_success(binding_string): result = await client.get_status(TEST_INSTANCE_ID) assert result is not None - assert result.runtime_status == "Completed" + assert result.runtime_status.name == "Completed" @pytest.mark.asyncio @@ -540,3 +543,52 @@ async def test_start_new_orchestrator_internal_exception(binding_string): with pytest.raises(Exception) as ex: await client.start_new(TEST_ORCHESTRATOR) ex.match(status_str) + +@pytest.mark.asyncio +async def test_rewind_works_under_200_and_200_http_codes(binding_string): + """Tests that the rewind API works as expected under 'successful' http codes: 200, 202""" + client = DurableOrchestrationClient(binding_string) + for code in [200, 202]: + mock_request = MockRequest( + expected_url=f"{RPC_BASE_URL}instances/{INSTANCE_ID}/rewind?reason={REASON}", + response=[code, ""]) + client._post_async_request = mock_request.post + result = await client.rewind(INSTANCE_ID, REASON) + assert result is None + +@pytest.mark.asyncio +async def test_rewind_throws_exception_during_404_410_and_500_errors(binding_string): + """Tests the behaviour of rewind under 'exception' http codes: 404, 410, 500""" + client = DurableOrchestrationClient(binding_string) + codes = [404, 410, 500] + exception_strs = [ + f"No instance with ID {INSTANCE_ID} found.", + "The rewind operation is only supported on failed orchestration instances.", + "Something went wrong" + ] + for http_code, expected_exception_str in zip(codes, exception_strs): + mock_request = MockRequest( + expected_url=f"{RPC_BASE_URL}instances/{INSTANCE_ID}/rewind?reason={REASON}", + response=[http_code, "Something went wrong"]) + client._post_async_request = mock_request.post + + with pytest.raises(Exception) as ex: + await client.rewind(INSTANCE_ID, REASON) + ex_message = str(ex.value) + assert ex_message == expected_exception_str + +@pytest.mark.asyncio +async def test_rewind_with_no_rpc_endpoint(binding_string): + """Tests the behaviour of rewind without an RPC endpoint / under the legacy HTTP endpoint.""" + client = DurableOrchestrationClient(binding_string) + mock_request = MockRequest( + expected_url=f"{RPC_BASE_URL}instances/{INSTANCE_ID}/rewind?reason={REASON}", + response=[-1, ""]) + client._post_async_request = mock_request.post + client._orchestration_bindings._rpc_base_url = None + expected_exception_str = "The Python SDK only supports RPC endpoints."\ + + "Please remove the `localRpcEnabled` setting from host.json" + with pytest.raises(Exception) as ex: + await client.rewind(INSTANCE_ID, REASON) + ex_message = str(ex.value) + assert ex_message == expected_exception_str diff --git a/tests/models/test_DurableOrchestrationStatus.py b/tests/models/test_DurableOrchestrationStatus.py index 012a2041..8e1100ab 100644 --- a/tests/models/test_DurableOrchestrationStatus.py +++ b/tests/models/test_DurableOrchestrationStatus.py @@ -5,6 +5,7 @@ from azure.durable_functions.constants import DATETIME_STRING_FORMAT from azure.durable_functions.models.DurableOrchestrationStatus import DurableOrchestrationStatus +from azure.durable_functions.models.OrchestrationRuntimeStatus import OrchestrationRuntimeStatus from azure.durable_functions.models.history import HistoryEventType TEST_NAME = 'what ever I want it to be' @@ -38,7 +39,7 @@ def test_all_the_args(): result = DurableOrchestrationStatus.from_json(response) - assert result.runtime_status == TEST_RUNTIME_STATUS + assert result.runtime_status.name == TEST_RUNTIME_STATUS assert result.custom_status == TEST_CUSTOM_STATUS assert result.instance_id == TEST_INSTANCE_ID assert result.output == TEST_OUTPUT diff --git a/tests/orchestrator/orchestrator_test_utils.py b/tests/orchestrator/orchestrator_test_utils.py index fc75d0ae..9bdbb1b5 100644 --- a/tests/orchestrator/orchestrator_test_utils.py +++ b/tests/orchestrator/orchestrator_test_utils.py @@ -50,6 +50,17 @@ def get_orchestration_state_result( result = json.loads(result_of_handle) return result +def get_orchestration_property( + context_builder, + activity_func: Callable[[DurableOrchestrationContext], Iterator[Any]], + prop: str): + context_as_string = context_builder.to_json_string() + orchestrator = Orchestrator(activity_func) + result_of_handle = orchestrator.handle( + DurableOrchestrationContext.from_json(context_as_string)) + result = getattr(orchestrator, prop) + return result + def assert_valid_schema(orchestration_state): validation_results = validate(instance=orchestration_state, schema=schema) diff --git a/tests/orchestrator/test_create_timer.py b/tests/orchestrator/test_create_timer.py index 41a8b3aa..b4e083e1 100644 --- a/tests/orchestrator/test_create_timer.py +++ b/tests/orchestrator/test_create_timer.py @@ -28,8 +28,8 @@ def generator_function(context): return "Done!" def add_timer_action(state: OrchestratorState, fire_at: datetime): - action = CreateTimerAction(fire_at= fire_at) - state._actions.append([action]) # Todo: brackets? + action = CreateTimerAction(fire_at=fire_at) + state._actions.append([action]) def test_timers_comparison_with_relaxed_precision(): """Test if that two `datetime` different but equivalent diff --git a/tests/orchestrator/test_is_replaying_flag.py b/tests/orchestrator/test_is_replaying_flag.py new file mode 100644 index 00000000..98f47b4b --- /dev/null +++ b/tests/orchestrator/test_is_replaying_flag.py @@ -0,0 +1,72 @@ +from tests.test_utils.ContextBuilder import ContextBuilder +from .orchestrator_test_utils \ + import get_orchestration_property, assert_orchestration_state_equals, assert_valid_schema +from azure.durable_functions.models.actions.CreateTimerAction import CreateTimerAction +from azure.durable_functions.models.OrchestratorState import OrchestratorState +from azure.durable_functions.constants import DATETIME_STRING_FORMAT +from datetime import datetime, timedelta, timezone + +def generator_function(context): + # Create a timezone aware datetime object, just like a normal + # call to `context.current_utc_datetime` would create + timestamp = "2020-07-23T21:56:54.936700Z" + deadline = datetime.strptime(timestamp, DATETIME_STRING_FORMAT) + deadline = deadline.replace(tzinfo=timezone.utc) + + for _ in range(0, 3): + deadline = deadline + timedelta(seconds=30) + yield context.create_timer(deadline) + +def base_expected_state(output=None) -> OrchestratorState: + return OrchestratorState(is_done=False, actions=[], output=output) + +def add_timer_fired_events(context_builder: ContextBuilder, id_: int, timestamp: str, + is_played: bool = True): + fire_at: str = context_builder.add_timer_created_event(id_, timestamp) + context_builder.add_orchestrator_completed_event() + context_builder.add_orchestrator_started_event() + context_builder.add_timer_fired_event(id_=id_, fire_at=fire_at, is_played=is_played) + +def add_timer_action(state: OrchestratorState, fire_at: datetime): + action = CreateTimerAction(fire_at=fire_at) + state._actions.append([action]) + +def test_is_replaying_initial_value(): + + context_builder = ContextBuilder("") + result = get_orchestration_property( + context_builder, generator_function, "durable_context") + + assert result.is_replaying == False + +def test_is_replaying_one_replayed_event(): + + timestamp = "2020-07-23T21:56:54.9367Z" + fire_at = datetime.strptime(timestamp, DATETIME_STRING_FORMAT) + timedelta(seconds=30) + fire_at_str = fire_at.strftime(DATETIME_STRING_FORMAT) + + context_builder = ContextBuilder("") + add_timer_fired_events(context_builder, 0, fire_at_str, is_played=True) + + result = get_orchestration_property( + context_builder, generator_function, "durable_context") + + assert result.is_replaying == True + +def test_is_replaying_one_replayed_one_not(): + + timestamp = "2020-07-23T21:56:54.9367Z" + fire_at = datetime.strptime(timestamp, DATETIME_STRING_FORMAT) + timedelta(seconds=30) + fire_at_str = fire_at.strftime(DATETIME_STRING_FORMAT) + fire_at2 = datetime.strptime(timestamp, DATETIME_STRING_FORMAT) + timedelta(seconds=60) + fire_at_str2 = fire_at2.strftime(DATETIME_STRING_FORMAT) + + context_builder = ContextBuilder("") + add_timer_fired_events(context_builder, 0, fire_at_str, is_played=True) + add_timer_fired_events(context_builder, 1, fire_at_str2, is_played=False) + + + result = get_orchestration_property( + context_builder, generator_function, "durable_context") + + assert result.is_replaying == False diff --git a/tests/orchestrator/test_retries.py b/tests/orchestrator/test_retries.py new file mode 100644 index 00000000..6e249c50 --- /dev/null +++ b/tests/orchestrator/test_retries.py @@ -0,0 +1,263 @@ +from tests.test_utils.ContextBuilder import ContextBuilder +from azure.durable_functions.models.RetryOptions import RetryOptions +from azure.durable_functions.models.OrchestratorState import OrchestratorState +from azure.durable_functions.models.DurableOrchestrationContext import DurableOrchestrationContext +from .orchestrator_test_utils import get_orchestration_state_result +from typing import List, Tuple +from datetime import datetime + +RETRY_OPTIONS = RetryOptions(5000, 2) +REASONS = "Stuff" +DETAILS = "Things" +RESULT_PREFIX = "Hello " +CITIES = ["Tokyo", "Seattle", "London"] + +def generator_function(context: DurableOrchestrationContext): + """Orchestrator function for testing retry'ing semantics + + Parameters + ---------- + context: DurableOrchestrationContext + Durable orchestration context, exposes the Durable API + + Returns + ------- + List[str]: + Output of activities, a list of hello'd cities + """ + + outputs = [] + + retry_options = RETRY_OPTIONS + task1 = yield context.call_activity_with_retry( + "Hello", retry_options, "Tokyo") + task2 = yield context.call_activity_with_retry( + "Hello", retry_options, "Seattle") + task3 = yield context.call_activity_with_retry( + "Hello", retry_options, "London") + + outputs.append(task1) + outputs.append(task2) + outputs.append(task3) + + return outputs + +def get_context_with_retries_and_corrupted_completion() -> ContextBuilder: + """Get a ContextBuilder whose history contains a late completion event + for an event that already failed. + + Returns + ------- + ContextBuilder: + The context whose history contains the requested event sequence. + """ + context = get_context_with_retries() + context.add_orchestrator_started_event() + context.add_task_completed_event(id_=0, result="'Do not pick me up'") + context.add_orchestrator_completed_event() + return context + +def get_context_with_retries(will_fail: bool=False) -> ContextBuilder: + """Get a ContextBuilder whose history contains retried events. + + Parameters + ---------- + will_fail: (bool, optional) + If set to true, returns a context with a history where the orchestrator fails. + If false, returns a context with a history where events fail but eventually complete. + Defaults to False. + + Returns + ------- + ContextBuilder: + The context whose history contains the requested event sequence. + """ + context = ContextBuilder() + num_activities = len(CITIES) + + def _schedule_events(context: ContextBuilder, id_counter: int) -> Tuple[ContextBuilder, int, List[int]]: + """Add scheduled events to the context. + + Parameters + ---------- + context: ContextBuilder + Orchestration context mock, to which we'll add the event completion events + id_counter: int + The current event counter + + Returns + ------- + Tuple[ContextBuilder, int, List[int]]: + The updated context, the updated counter, a list of event IDs for each scheduled event + """ + scheduled_ids: List[int] = [] + for id_ in range(num_activities): + scheduled_ids.append(id_) + context.add_task_scheduled_event(name='Hello', id_=id_) + id_counter += 1 + return context, id_counter, scheduled_ids + + def _fail_events(context: ContextBuilder, id_counter: int) -> Tuple[ContextBuilder, int]: + """Add event failed to the context. + + Parameters + ---------- + context: ContextBuilder + Orchestration context mock, to which we'll add the event completion events + id_counter: int + The current event counter + + Returns + ------- + Tuple[ContextBuilder, int]: + The updated context, the updated id_counter + """ + context.add_orchestrator_started_event() + for id_ in scheduled_ids: + context.add_task_failed_event( + id_=id_, reason=REASONS, details=DETAILS) + id_counter += 1 + return context, id_counter + + def _schedule_timers(context: ContextBuilder, id_counter: int) -> Tuple[ContextBuilder, int, List[datetime]]: + """Add timer created events to the context. + + Parameters + ---------- + context: ContextBuilder + Orchestration context mock, to which we'll add the event completion events + id_counter: int + The current event counter + + Returns + ------- + Tuple[ContextBuilder, int, List[datetime]]: + The updated context, the updated counter, a list of timer deadlines + """ + deadlines: List[datetime] = [] + for _ in range(num_activities): + deadlines.append((id_counter, context.add_timer_created_event(id_counter))) + id_counter += 1 + return context, id_counter, deadlines + + def _fire_timer(context: ContextBuilder, id_counter: int, deadlines: List[datetime]) -> Tuple[ContextBuilder, int]: + """Add timer fired events to the context. + + Parameters + ---------- + context: ContextBuilder + Orchestration context mock, to which we'll add the event completion events + id_counter: int + The current event counter + deadlines: List[datetime] + List of dates at which to fire the timers + + Returns + ------- + Tuple[ContextBuilder, int]: + The updated context, the updated id_counter + """ + for id_, fire_at in deadlines: + context.add_timer_fired_event(id_=id_, fire_at=fire_at) + id_counter += 1 + return context, id_counter + + def _complete_event(context: ContextBuilder, id_counter: int) -> Tuple[ContextBuilder, int]: + """Add event / task completions to the context. + + Parameters + ---------- + context: ContextBuilder + Orchestration context mock, to which we'll add the event completion events + id_counter: int + The current event counter + + Returns + ------- + Tuple[ContextBuilder, int] + The updated context, the updated id_counter + """ + for id_, city in zip(scheduled_ids, CITIES): + result = f"\"{RESULT_PREFIX}{city}\"" + context.add_task_completed_event(id_=id_, result=result) + id_counter += 1 + return context, id_counter + + + id_counter = 0 + + # Schedule the events + context, id_counter, scheduled_ids = _schedule_events(context, id_counter) + context.add_orchestrator_completed_event() + + # Record failures, schedule timers + context, id_counter = _fail_events(context, id_counter) + context, id_counter, deadlines = _schedule_timers(context, id_counter) + context.add_orchestrator_completed_event() + + # Fire timers, re-schedule events + context.add_orchestrator_started_event() + context, id_counter = _fire_timer(context, id_counter, deadlines) + context, id_counter, scheduled_ids = _schedule_events(context, id_counter) + context.add_orchestrator_completed_event() + + context.add_orchestrator_started_event() + + # Either complete the event or, if we want all failed events, then + # fail the events, schedule timer, and fire time. + if will_fail: + context, id_counter = _fail_events(context, id_counter) + context, id_counter, deadlines = _schedule_timers(context, id_counter) + context.add_orchestrator_completed_event() + + context.add_orchestrator_started_event() + context, id_counter = _fire_timer(context, id_counter, deadlines) + else: + context, id_counter = _complete_event(context, id_counter) + + context.add_orchestrator_completed_event() + return context + +def test_redundant_completion_doesnt_get_processed(): + """Tests that our implementation processes the state array + sequentially, which previous implementations did not guarantee. In this test, + we add a completion event for a task that was cancelled, meaning that it failed and got + re-scheduled. Older implementations would pick up this completion event and cause + non-determinism. + """ + context_1 = get_context_with_retries() + context_2 = get_context_with_retries_and_corrupted_completion() + + result_1 = get_orchestration_state_result( + context_1, generator_function) + + result_2 = get_orchestration_state_result( + context_2, generator_function) + + assert "output" in result_1 + assert "output" in result_2 + assert result_1["output"] == result_2["output"] + + +def test_failed_tasks_do_not_hang_orchestrator(): + """Tests that our implementation correctly handles up re-scheduled events, + which previous implementations failed to correctly handle. """ + context = get_context_with_retries() + + result = get_orchestration_state_result( + context, generator_function) + + expected_output = list(map(lambda x: RESULT_PREFIX + x, CITIES)) + assert "output" in result + assert result["output"] == expected_output + +def test_retries_can_fail(): + """Tests the code path where a retry'ed Task fails""" + context = get_context_with_retries(will_fail=True) + + result = get_orchestration_state_result( + context, generator_function) + + expected_error = f"{REASONS} \n {DETAILS}" + assert "error" in result + assert result["error"] == expected_error \ No newline at end of file diff --git a/tests/orchestrator/test_sub_orchestrator_with_retry.py b/tests/orchestrator/test_sub_orchestrator_with_retry.py index 2a0c65f1..3052ae6c 100644 --- a/tests/orchestrator/test_sub_orchestrator_with_retry.py +++ b/tests/orchestrator/test_sub_orchestrator_with_retry.py @@ -11,6 +11,7 @@ def generator_function(context): outputs = [] retry_options = RETRY_OPTIONS + task1 = yield context.call_sub_orchestrator_with_retry("HelloSubOrchestrator", retry_options, "Tokyo") task2 = yield context.call_sub_orchestrator_with_retry("HelloSubOrchestrator", retry_options, "Seattle") task3 = yield context.call_sub_orchestrator_with_retry("HelloSubOrchestrator", retry_options, "London") diff --git a/tests/test_utils/ContextBuilder.py b/tests/test_utils/ContextBuilder.py index 753009fa..91d49358 100644 --- a/tests/test_utils/ContextBuilder.py +++ b/tests/test_utils/ContextBuilder.py @@ -13,7 +13,7 @@ class ContextBuilder: - def __init__(self, name: str): + def __init__(self, name: str=""): self.instance_id = uuid.uuid4() self.is_replaying: bool = False self.input_ = None @@ -93,8 +93,8 @@ def add_timer_created_event(self, id_: int, timestamp: str = None): self.history_events.append(event) return fire_at - def add_timer_fired_event(self, id_: int, fire_at: str): - event = self.get_base_event(HistoryEventType.TIMER_FIRED, is_played=True) + def add_timer_fired_event(self, id_: int, fire_at: str, is_played: bool = True): + event = self.get_base_event(HistoryEventType.TIMER_FIRED, is_played=is_played) event.TimerId = id_ event.FireAt = fire_at self.history_events.append(event)