Skip to content

Commit

Permalink
Fix a bug in the FEN validation, and add more tests. (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
johndoknjas authored Aug 24, 2024
1 parent f617c03 commit ba93cf2
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 13 deletions.
30 changes: 21 additions & 9 deletions stockfish/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class Stockfish:
"10.0": "2018-11-29",
}

_PIECE_CHARS = ("P", "N", "B", "R", "Q", "K", "p", "n", "b", "r", "q", "k")

# _PARAM_RESTRICTIONS stores the types of each of the params, and any applicable min and max values, based
# off the Stockfish source code: https://github.com/official-stockfish/Stockfish/blob/65ece7d985291cc787d6c804a33f1dd82b75736d/src/ucioption.cpp#L58-L82
_PARAM_RESTRICTIONS: Dict[str, Tuple[type, Optional[int], Optional[int]]] = {
Expand Down Expand Up @@ -629,25 +631,35 @@ def _get_sf_go_command_output(self) -> List[str]:
def _is_fen_syntax_valid(fen: str) -> bool:
# Code for this function taken from: https://gist.github.com/Dani4kor/e1e8b439115878f8c6dcf127a4ed5d3e
# Some small changes have been made to the code.
regexMatch = re.match(
if not re.match(
r"\s*^(((?:[rnbqkpRNBQKP1-8]+\/){7})[rnbqkpRNBQKP1-8]+)\s([b|w])\s(-|[K|Q|k|q]{1,4})\s(-|[a-h][1-8])\s(\d+\s\d+)$",
fen,
)
if not regexMatch:
):
return False
regexList = regexMatch.groups()
if len(regexList[0].split("/")) != 8:
return False # 8 rows not present.
for fenPart in regexList[0].split("/"):

fen_fields = fen.split()

if any(
(
len(fen_fields) != 6,
len(fen_fields[0].split("/")) != 8,
any(x not in fen_fields[0] for x in "Kk"),
any(not fen_fields[x].isdigit() for x in (4, 5)),
int(fen_fields[4]) >= int(fen_fields[5]) * 2,
)
):
return False

for fenPart in fen_fields[0].split("/"):
field_sum: int = 0
previous_was_digit: bool = False
for c in fenPart:
if c in ["1", "2", "3", "4", "5", "6", "7", "8"]:
if "1" <= c <= "8":
if previous_was_digit:
return False # Two digits next to each other.
field_sum += int(c)
previous_was_digit = True
elif c.lower() in ["p", "n", "b", "r", "q", "k"]:
elif c in Stockfish._PIECE_CHARS:
field_sum += 1
previous_was_digit = False
else:
Expand Down
29 changes: 25 additions & 4 deletions tests/stockfish/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,25 +1149,46 @@ def test_is_fen_valid(self, stockfish: Stockfish):
old_info = stockfish.info
old_depth = stockfish._depth
old_fen = stockfish.get_fen_position()
correct_fens = [
correct_fens: List[Optional[str]] = [
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
"r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK1 b kq - 0 8",
"4k3/8/4K3/8/8/8/8/8 w - - 10 50",
"r1b1kb1r/ppp2ppp/3q4/8/P2Q4/8/1PP2PPP/RNB2RK1 w kq - 8 15",
"4k3/8/4K3/8/8/8/8/8 w - - 99 50",
]
invalid_syntax_fens = [
"r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK b kq - 0 8",
"rnbqkb1r/pppp1ppp/4pn2/8/2PP4/8/PP2PPPP/RNBQKBNR w KQkq - 3",
"rn1q1rk1/pbppbppp/1p2pn2/8/2PP4/5NP1/PP2PPBP/RNBQ1RK1 w w - 5 7",
"4k3/8/4K3/71/8/8/8/8 w - - 10 50",
"r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2R2 b kq - 0 8",
"r1bQ1b1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK1 b kq - 0 8",
"4k3/8/4K3/8/8/8/8/8 w - - 100 50",
"4k3/8/4K3/8/8/8/8/8 w - - 101 50",
"4k3/8/4K3/8/8/8/8/8 w - - -1 50",
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 0",
"r1b1kb1r/ppp2ppp/3q4/8/P2Q4/8/1PP2PPP/RNB2RK1 w kq - - 8 15",
"r1b1kb1r/ppp2ppp/3q4/8/P2Q4/8/1PP2PPP/RNB2RK1 w kq 8 15",
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR W KQkq - 0 1",
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR - KQkq - 0 1",
"r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK1 b kq - - 8",
"r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK1 b kq - 0 -",
"r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK1 b kq - -1 8",
"4k3/8/4K3/8/8/8/8/8 w - - 99 e",
"4k3/8/4K3/8/8/8/8/8 w - - 99 ee",
]
correct_fens.extend([None] * (len(invalid_syntax_fens) - len(correct_fens)))
assert len(correct_fens) == len(invalid_syntax_fens)
for correct_fen, invalid_syntax_fen in zip(correct_fens, invalid_syntax_fens):
old_del_counter = Stockfish._del_counter
assert stockfish.is_fen_valid(correct_fen)
if correct_fen is not None:
assert stockfish.is_fen_valid(correct_fen)
assert stockfish._is_fen_syntax_valid(correct_fen)
assert not stockfish.is_fen_valid(invalid_syntax_fen)
assert stockfish._is_fen_syntax_valid(correct_fen)
assert not stockfish._is_fen_syntax_valid(invalid_syntax_fen)
assert Stockfish._del_counter == old_del_counter + 2
assert Stockfish._del_counter == old_del_counter + (
2 if correct_fen is not None else 0
)

time.sleep(2.0)
assert stockfish._stockfish.poll() is None
Expand Down

0 comments on commit ba93cf2

Please sign in to comment.