From 4b304131a12e715b8cbd070ba830fee214185423 Mon Sep 17 00:00:00 2001 From: Matthieu Monsch <1216372+mtth@users.noreply.github.com> Date: Tue, 28 Nov 2023 17:26:38 -0800 Subject: [PATCH] Update outline generation (#122) --- opvious/client/common.py | 36 ++++++++++++++++++++++------------- opvious/client/handlers.py | 32 ++++++++++++++++++++++--------- opvious/data/outcomes.py | 4 ++-- opvious/data/queued_solves.py | 8 ++++---- pyproject.toml | 2 +- tests/test_client.py | 5 +++-- 6 files changed, 56 insertions(+), 31 deletions(-) diff --git a/opvious/client/common.py b/opvious/client/common.py index b1eb58b..e35ad17 100644 --- a/opvious/client/common.py +++ b/opvious/client/common.py @@ -121,6 +121,24 @@ def build(self) -> SolveInputs: ) +async def generate_outline( + executor: Executor, outline_data: Json, transformation_data: Json +) -> ProblemOutline: + if not transformation_data: + return outline_from_json(outline_data) + async with executor.execute( + result_type=JsonExecutorResult, + url="/outlines/transform", + method="POST", + json_data=json_dict( + outline=outline_data, + transformations=transformation_data, + ), + ) as res: + data = res.json_data() + return outline_from_json(data["outline"]) + + class ProblemOutlineGenerator: def __init__(self, executor: Executor, outline_data: Json): self._executor = executor @@ -178,20 +196,12 @@ async def generate(self) -> tuple[ProblemOutline, Json]: class Context(ProblemTransformationContext): async def fetch_outline(self) -> ProblemOutline: - transformations = self.get_json() - if not transformations: + transformation_data = self.get_json() + if not transformation_data: return pristine_outline - async with executor.execute( - result_type=JsonExecutorResult, - url="/outlines/transform", - method="POST", - json_data=json_dict( - outline=pristine_outline_data, - transformations=transformations, - ), - ) as res: - data = res.json_data() - return outline_from_json(data["outline"]) + return await generate_outline( + executor, pristine_outline_data, transformation_data + ) context = Context() for tf in self._transformations: diff --git a/opvious/client/handlers.py b/opvious/client/handlers.py index 26b74f9..95d405c 100644 --- a/opvious/client/handlers.py +++ b/opvious/client/handlers.py @@ -20,6 +20,7 @@ ) from ..data.outcomes import ( AbortedOutcome, + FailedOutcome, FeasibleOutcome, InfeasibleOutcome, SolveOutcome, @@ -29,7 +30,7 @@ feasible_outcome_from_graphql, solve_outcome_status, ) -from ..data.outlines import ProblemOutline, outline_from_json +from ..data.outlines import ProblemOutline from ..data.solves import ( ProblemSummary, SolveInputs, @@ -60,6 +61,7 @@ ProblemOutlineGenerator, SolveInputsBuilder, feasible_outcome_details, + generate_outline, log_progress, ) @@ -530,8 +532,8 @@ async def queue_solve(self, problem: Problem) -> QueuedSolve: uuid = res.json_data()["uuid"] return QueuedSolve( uuid=uuid, - started_at=datetime.now(timezone.utc), outline=outline, + started_at=datetime.now(timezone.utc), ) async def fetch_solve(self, uuid: str) -> Optional[QueuedSolve]: @@ -547,10 +549,12 @@ async def fetch_solve(self, uuid: str) -> Optional[QueuedSolve]: solve = data["queuedSolve"] if not solve: return None - return queued_solve_from_graphql( - data=solve, - outline=outline_from_json(solve["outline"]), + outline = await generate_outline( + self._executor, + solve["specification"]["outline"], + solve["transformations"], ) + return queued_solve_from_graphql(solve, outline) async def cancel_solve(self, uuid: str) -> bool: """Cancels a running solve @@ -580,6 +584,18 @@ async def poll_solve( variables=json_dict(uuid=solve.uuid), ) solve_data = data["queuedSolve"] + + error_status = solve_data["attempt"]["errorStatus"] + if error_status: + failure_data = solve_data["failure"] + if failure_data: + return failed_outcome_from_graphql(failure_data) + else: + return FailedOutcome( + error_status, + "The problem's inputs did not match its specification", + ) + outcome_data = solve_data["outcome"] if not outcome_data: edges = solve_data["notifications"]["edges"] @@ -587,6 +603,7 @@ async def poll_solve( dequeued=bool(solve_data["dequeuedAt"]), data=edges[0]["node"] if edges else None, ) + status = outcome_data["status"] if status == "ABORTED": return cast(SolveOutcome, AbortedOutcome()) @@ -596,10 +613,7 @@ async def poll_solve( return UnboundedOutcome() if status == "FEASIBLE" or status == "OPTIMAL": return feasible_outcome_from_graphql(outcome_data) - failure_data = solve_data["failure"] - if not failure_data: - raise Exception(f"Unexpected status {status} without failure") - return failed_outcome_from_graphql(failure_data) + raise Exception(f"Unexpected status {status} without failure") @backoff.on_predicate( backoff.fibo, diff --git a/opvious/data/outcomes.py b/opvious/data/outcomes.py index 6b88799..8f8c475 100644 --- a/opvious/data/outcomes.py +++ b/opvious/data/outcomes.py @@ -31,10 +31,10 @@ class FailedOutcome: message: str """The underlying error's message""" - code: Optional[str] + code: Optional[str] = None """The underlying error's error code""" - tags: Any + tags: Any = None """Structured data associated with the failure""" diff --git a/opvious/data/queued_solves.py b/opvious/data/queued_solves.py index d5993d1..3b33fc4 100644 --- a/opvious/data/queued_solves.py +++ b/opvious/data/queued_solves.py @@ -19,20 +19,20 @@ class QueuedSolve: uuid: str """The solve's unique identifier""" - started_at: datetime - """The time the solve was created""" - outline: ProblemOutline = dataclasses.field(repr=False) """The specification outline corresponding to this solve""" + started_at: datetime + """The time the solve was created""" + def queued_solve_from_graphql( data: Any, outline: ProblemOutline ) -> QueuedSolve: return QueuedSolve( uuid=data["uuid"], - started_at=datetime.fromisoformat(data["startedAt"]), outline=outline, + started_at=datetime.fromisoformat(data["attempt"]["startedAt"]), ) diff --git a/pyproject.toml b/pyproject.toml index e500c4d..2be60b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "opvious" -version = "0.18.5rc2" +version = "0.18.6rc1" description = "Opvious Python SDK" authors = ["Opvious Engineering "] readme = "README.md" diff --git a/tests/test_client.py b/tests/test_client.py index 0195b39..ef3f971 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -99,7 +99,7 @@ async def test_queue_diet_solve(self): @pytest.mark.asyncio async def test_queue_relaxed_solve(self): - solve = await client.queue_solve( + queued = await client.queue_solve( opvious.Problem( specification=opvious.FormulationSpecification("bounded"), transformations=[ @@ -113,7 +113,8 @@ async def test_queue_relaxed_solve(self): parameters={"bound": 3}, ), ) - outcome = await client.wait_for_solve_outcome(solve) + fetched = await client.fetch_solve(queued.uuid) + outcome = await client.wait_for_solve_outcome(fetched) assert isinstance(outcome, opvious.FeasibleOutcome) assert outcome.objective_value == 2