Skip to content

Commit

Permalink
Immediately dispatch line/termination/finish (fixes #1049, fixes #1071)
Browse files Browse the repository at this point in the history
Avoids races between queued up lines and command finish callbacks.
  • Loading branch information
niklasf committed Jul 31, 2024
1 parent 71e7c31 commit 7299216
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ New features:

Bugfixes:

* Fix unsolicited engine output may cause assertion errors with regard to
command states.
* Fix handling of whitespace in UCI engine communication.
* For ``chess.Board.epd()`` and ``chess.Board.set_epd()``, require that EPD
opcodes start with a letter.
Expand Down
40 changes: 27 additions & 13 deletions chess/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ def write(self, data: bytes) -> None:
expectation, responses = self.expectations.popleft()
assert expectation == line, f"expected {expectation}, got: {line}"
if responses:
self.protocol.pipe_data_received(1, "\n".join(responses + [""]).encode("utf-8"))
self.protocol.loop.call_soon(self.protocol.pipe_data_received, 1, "\n".join(responses + [""]).encode("utf-8"))

def get_pid(self) -> int:
return id(self)
Expand Down Expand Up @@ -934,12 +934,12 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
LOGGER.debug("%s: Connection lost (exit code: %d, error: %s)", self, code, exc)

# Terminate commands.
if self.command is not None:
self.command._engine_terminated(code)
self.command = None
if self.next_command is not None:
self.next_command._engine_terminated(code)
self.next_command = None
command, self.command = self.command, None
next_command, self.next_command = self.next_command, None
if command:
command._engine_terminated(code)
if next_command:
next_command._engine_terminated(code)

self.returncode.set_result(code)

Expand All @@ -965,9 +965,9 @@ def pipe_data_received(self, fd: int, data: Union[bytes, str]) -> None:
LOGGER.warning("%s: >> %r (%s)", self, bytes(line_bytes), err)
else:
if fd == 1:
self.loop.call_soon(self._line_received, line)
self._line_received(line)
else:
self.loop.call_soon(self.error_line_received, line)
self.error_line_received(line)

def error_line_received(self, line: str) -> None:
LOGGER.warning("%s: stderr >> %s", self, line)
Expand Down Expand Up @@ -998,7 +998,7 @@ async def communicate(self, command_factory: Callable[[Self], BaseCommand[T]]) -

self.next_command = command

def previous_command_finished(_: Optional[asyncio.Future[None]]) -> None:
def previous_command_finished() -> None:
self.command, self.next_command = self.next_command, None
if self.command is not None:
cmd = self.command
Expand All @@ -1008,11 +1008,11 @@ def cancel_if_cancelled(result: asyncio.Future[T]) -> None:
cmd._cancel()

cmd.result.add_done_callback(cancel_if_cancelled)
cmd.finished.add_done_callback(previous_command_finished)
cmd._start()
cmd.add_finished_callback(previous_command_finished)

if self.command is None:
previous_command_finished(None)
previous_command_finished()
elif not self.command.result.done():
self.command.result.cancel()
elif not self.command.result.cancelled():
Expand Down Expand Up @@ -1228,13 +1228,25 @@ def __init__(self, engine: Protocol) -> None:
self.result: asyncio.Future[T] = asyncio.Future()
self.finished: asyncio.Future[None] = asyncio.Future()

self._finished_callbacks: List[Callable[[], None]] = []

def add_finished_callback(self, callback: Callable[[], None]) -> None:
self._finished_callbacks.append(callback)
self._dispatch_finished()

def _dispatch_finished(self) -> None:
if self.finished.done():
while self._finished_callbacks:
self._finished_callbacks.pop()()

def _engine_terminated(self, code: int) -> None:
hint = ", binary not compatible with cpu?" if code in [-4, 0xc000001d] else ""
exc = EngineTerminatedError(f"engine process died unexpectedly (exit code: {code}{hint})")
if self.state == CommandState.ACTIVE:
self.engine_terminated(exc)
elif self.state == CommandState.CANCELLING:
self.finished.set_result(None)
self._dispatch_finished()
elif self.state == CommandState.NEW:
self._handle_exception(exc)

Expand All @@ -1251,13 +1263,15 @@ def _handle_exception(self, exc: Exception) -> None:

if not self.finished.done():
self.finished.set_result(None)
self._dispatch_finished()

def set_finished(self) -> None:
assert self.state in [CommandState.ACTIVE, CommandState.CANCELLING], self.state
if not self.result.done():
self.result.set_exception(EngineError(f"engine command finished before returning result: {self!r}"))
self.finished.set_result(None)
self.state = CommandState.DONE
self.finished.set_result(None)
self._dispatch_finished()

def _cancel(self) -> None:
if self.state != CommandState.CANCELLING and self.state != CommandState.DONE:
Expand Down
18 changes: 18 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3527,6 +3527,24 @@ async def main():

asyncio.run(main())

def test_uci_output_after_command(self):
async def main():
protocol = chess.engine.UciProtocol()
mock = chess.engine.MockTransport(protocol)

mock.expect("uci", [
"Arasan v24.0.0-10-g367aa9f Copyright 1994-2023 by Jon Dart.",
"All rights reserved.",
"id name Arasan v24.0.0-10-g367aa9f",
"uciok",
"info string out of do_all_pending, list size=0"
])
await protocol.initialize()

mock.assert_done()

asyncio.run(main())

def test_hiarcs_bestmove(self):
async def main():
protocol = chess.engine.UciProtocol()
Expand Down

0 comments on commit 7299216

Please sign in to comment.