Skip to content

Commit

Permalink
Merge pull request #1100 from niklasf/stricter-subclass-typing
Browse files Browse the repository at this point in the history
Stricter subclass typing
  • Loading branch information
niklasf authored Jul 31, 2024
2 parents 32253d6 + 33eea7c commit ec399d1
Show file tree
Hide file tree
Showing 6 changed files with 454 additions and 324 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ Changes:
some 8 piece positions with decisive captures can be probed successfully.
* The string wrapper returned by ``chess.svg`` functions now also implements
``_repr_html_``.
* Significant changes to ``chess.engine`` internals:
``chess.engine.BaseCommand`` methods other than the constructor no longer
receive ``engine: Protocol``.
* Significant changes to board state internals: Subclasses of ``chess.Board``
can no longer hook into board state recording/restoration and need to
override relevant methods instead (``clear_stack``, ``copy``, ``root``,
``push``, ``pop``).

New features:

Expand Down
45 changes: 21 additions & 24 deletions chess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing import ClassVar, Callable, Counter, Dict, Generic, Hashable, Iterable, Iterator, List, Literal, Mapping, Optional, SupportsInt, Tuple, Type, TypeVar, Union

if typing.TYPE_CHECKING:
from typing_extensions import TypeAlias
from typing_extensions import Self, TypeAlias


EnPassantSpec = Literal["legal", "fen", "xfen"]
Expand Down Expand Up @@ -1455,7 +1455,7 @@ def apply_transform(self, f: Callable[[Bitboard], Bitboard]) -> None:
self.occupied = f(self.occupied)
self.promoted = f(self.promoted)

def transform(self: BaseBoardT, f: Callable[[Bitboard], Bitboard]) -> BaseBoardT:
def transform(self, f: Callable[[Bitboard], Bitboard]) -> Self:
"""
Returns a transformed copy of the board (without move stack)
by applying a bitboard transformation function.
Expand All @@ -1473,11 +1473,11 @@ def transform(self: BaseBoardT, f: Callable[[Bitboard], Bitboard]) -> BaseBoardT
board.apply_transform(f)
return board

def apply_mirror(self: BaseBoardT) -> None:
def apply_mirror(self) -> None:
self.apply_transform(flip_vertical)
self.occupied_co[WHITE], self.occupied_co[BLACK] = self.occupied_co[BLACK], self.occupied_co[WHITE]

def mirror(self: BaseBoardT) -> BaseBoardT:
def mirror(self) -> Self:
"""
Returns a mirrored copy of the board (without move stack).
Expand All @@ -1491,7 +1491,7 @@ def mirror(self: BaseBoardT) -> BaseBoardT:
board.apply_mirror()
return board

def copy(self: BaseBoardT) -> BaseBoardT:
def copy(self) -> Self:
"""Creates a copy of the board."""
board = type(self)(None)

Expand All @@ -1509,10 +1509,10 @@ def copy(self: BaseBoardT) -> BaseBoardT:

return board

def __copy__(self: BaseBoardT) -> BaseBoardT:
def __copy__(self) -> Self:
return self.copy()

def __deepcopy__(self: BaseBoardT, memo: Dict[int, object]) -> BaseBoardT:
def __deepcopy__(self, memo: Dict[int, object]) -> Self:
board = self.copy()
memo[id(self)] = board
return board
Expand Down Expand Up @@ -1542,9 +1542,9 @@ def from_chess960_pos(cls: Type[BaseBoardT], scharnagl: int) -> BaseBoardT:

BoardT = TypeVar("BoardT", bound="Board")

class _BoardState(Generic[BoardT]):
class _BoardState:

def __init__(self, board: BoardT) -> None:
def __init__(self, board: Board) -> None:
self.pawns = board.pawns
self.knights = board.knights
self.bishops = board.bishops
Expand All @@ -1564,7 +1564,7 @@ def __init__(self, board: BoardT) -> None:
self.halfmove_clock = board.halfmove_clock
self.fullmove_number = board.fullmove_number

def restore(self, board: BoardT) -> None:
def restore(self, board: Board) -> None:
board.pawns = self.pawns
board.knights = self.knights
board.bishops = self.bishops
Expand Down Expand Up @@ -1694,14 +1694,14 @@ class Board(BaseBoard):
manipulation.
"""

def __init__(self: BoardT, fen: Optional[str] = STARTING_FEN, *, chess960: bool = False) -> None:
def __init__(self, fen: Optional[str] = STARTING_FEN, *, chess960: bool = False) -> None:
BaseBoard.__init__(self, None)

self.chess960 = chess960

self.ep_square = None
self.move_stack = []
self._stack: List[_BoardState[BoardT]] = []
self._stack: List[_BoardState] = []

if fen is None:
self.clear()
Expand Down Expand Up @@ -1786,7 +1786,7 @@ def clear_stack(self) -> None:
self.move_stack.clear()
self._stack.clear()

def root(self: BoardT) -> BoardT:
def root(self) -> Self:
"""Returns a copy of the root position."""
if self._stack:
board = type(self)(None, chess960=self.chess960)
Expand Down Expand Up @@ -2304,13 +2304,10 @@ def is_repetition(self, count: int = 3) -> bool:

return False

def _board_state(self: BoardT) -> _BoardState[BoardT]:
return _BoardState(self)

def _push_capture(self, move: Move, capture_square: Square, piece_type: PieceType, was_promoted: bool) -> None:
pass

def push(self: BoardT, move: Move) -> None:
def push(self, move: Move) -> None:
"""
Updates the position with the given *move* and puts it onto the
move stack.
Expand All @@ -2335,7 +2332,7 @@ def push(self: BoardT, move: Move) -> None:
"""
# Push move and remember board state.
move = self._to_chess960(move)
board_state = self._board_state()
board_state = _BoardState(self)
self.castling_rights = self.clean_castling_rights() # Before pushing stack
self.move_stack.append(self._from_chess960(self.chess960, move.from_square, move.to_square, move.promotion, move.drop))
self._stack.append(board_state)
Expand Down Expand Up @@ -2431,7 +2428,7 @@ def push(self: BoardT, move: Move) -> None:
# Swap turn.
self.turn = not self.turn

def pop(self: BoardT) -> Move:
def pop(self) -> Move:
"""
Restores the previous position and returns the last move from the stack.
Expand Down Expand Up @@ -2841,7 +2838,7 @@ def _validate_epd_opcode(self, opcode: str) -> None:
if blacklisted in opcode:
raise ValueError(f"invalid character {blacklisted!r} in epd opcode: {opcode!r}")

def _parse_epd_ops(self: BoardT, operation_part: str, make_board: Callable[[], BoardT]) -> Dict[str, Union[None, str, int, float, Move, List[Move]]]:
def _parse_epd_ops(self, operation_part: str, make_board: Callable[[], Self]) -> Dict[str, Union[None, str, int, float, Move, List[Move]]]:
operations: Dict[str, Union[None, str, int, float, Move, List[Move]]] = {}
state = "opcode"
opcode = ""
Expand Down Expand Up @@ -3834,16 +3831,16 @@ def apply_transform(self, f: Callable[[Bitboard], Bitboard]) -> None:
self.ep_square = None if self.ep_square is None else msb(f(BB_SQUARES[self.ep_square]))
self.castling_rights = f(self.castling_rights)

def transform(self: BoardT, f: Callable[[Bitboard], Bitboard]) -> BoardT:
def transform(self, f: Callable[[Bitboard], Bitboard]) -> Self:
board = self.copy(stack=False)
board.apply_transform(f)
return board

def apply_mirror(self: BoardT) -> None:
def apply_mirror(self) -> None:
super().apply_mirror()
self.turn = not self.turn

def mirror(self: BoardT) -> BoardT:
def mirror(self) -> Self:
"""
Returns a mirrored copy of the board.
Expand All @@ -3858,7 +3855,7 @@ def mirror(self: BoardT) -> BoardT:
board.apply_mirror()
return board

def copy(self: BoardT, *, stack: Union[bool, int] = True) -> BoardT:
def copy(self, *, stack: Union[bool, int] = True) -> Self:
"""
Creates a copy of the board.
Expand Down
Loading

0 comments on commit ec399d1

Please sign in to comment.