From ba93cf219d705be625f649181660d2ebbf130045 Mon Sep 17 00:00:00 2001 From: John Doknjas <32089502+johndoknjas@users.noreply.github.com> Date: Sat, 24 Aug 2024 15:12:18 -0700 Subject: [PATCH] Fix a bug in the FEN validation, and add more tests. (#91) --- stockfish/models.py | 30 +++++++++++++++++++++--------- tests/stockfish/test_models.py | 29 +++++++++++++++++++++++++---- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/stockfish/models.py b/stockfish/models.py index 05c8157..e47ee62 100644 --- a/stockfish/models.py +++ b/stockfish/models.py @@ -35,6 +35,8 @@ class Stockfish: "10.0": "2018-11-29", } + _PIECE_CHARS = ("P", "N", "B", "R", "Q", "K", "p", "n", "b", "r", "q", "k") + # _PARAM_RESTRICTIONS stores the types of each of the params, and any applicable min and max values, based # off the Stockfish source code: https://github.com/official-stockfish/Stockfish/blob/65ece7d985291cc787d6c804a33f1dd82b75736d/src/ucioption.cpp#L58-L82 _PARAM_RESTRICTIONS: Dict[str, Tuple[type, Optional[int], Optional[int]]] = { @@ -629,25 +631,35 @@ def _get_sf_go_command_output(self) -> List[str]: def _is_fen_syntax_valid(fen: str) -> bool: # Code for this function taken from: https://gist.github.com/Dani4kor/e1e8b439115878f8c6dcf127a4ed5d3e # Some small changes have been made to the code. - regexMatch = re.match( + if not re.match( r"\s*^(((?:[rnbqkpRNBQKP1-8]+\/){7})[rnbqkpRNBQKP1-8]+)\s([b|w])\s(-|[K|Q|k|q]{1,4})\s(-|[a-h][1-8])\s(\d+\s\d+)$", fen, - ) - if not regexMatch: + ): return False - regexList = regexMatch.groups() - if len(regexList[0].split("/")) != 8: - return False # 8 rows not present. - for fenPart in regexList[0].split("/"): + + fen_fields = fen.split() + + if any( + ( + len(fen_fields) != 6, + len(fen_fields[0].split("/")) != 8, + any(x not in fen_fields[0] for x in "Kk"), + any(not fen_fields[x].isdigit() for x in (4, 5)), + int(fen_fields[4]) >= int(fen_fields[5]) * 2, + ) + ): + return False + + for fenPart in fen_fields[0].split("/"): field_sum: int = 0 previous_was_digit: bool = False for c in fenPart: - if c in ["1", "2", "3", "4", "5", "6", "7", "8"]: + if "1" <= c <= "8": if previous_was_digit: return False # Two digits next to each other. field_sum += int(c) previous_was_digit = True - elif c.lower() in ["p", "n", "b", "r", "q", "k"]: + elif c in Stockfish._PIECE_CHARS: field_sum += 1 previous_was_digit = False else: diff --git a/tests/stockfish/test_models.py b/tests/stockfish/test_models.py index bc505eb..93a7467 100644 --- a/tests/stockfish/test_models.py +++ b/tests/stockfish/test_models.py @@ -1149,25 +1149,46 @@ def test_is_fen_valid(self, stockfish: Stockfish): old_info = stockfish.info old_depth = stockfish._depth old_fen = stockfish.get_fen_position() - correct_fens = [ + correct_fens: List[Optional[str]] = [ "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", "r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK1 b kq - 0 8", "4k3/8/4K3/8/8/8/8/8 w - - 10 50", "r1b1kb1r/ppp2ppp/3q4/8/P2Q4/8/1PP2PPP/RNB2RK1 w kq - 8 15", + "4k3/8/4K3/8/8/8/8/8 w - - 99 50", ] invalid_syntax_fens = [ "r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK b kq - 0 8", "rnbqkb1r/pppp1ppp/4pn2/8/2PP4/8/PP2PPPP/RNBQKBNR w KQkq - 3", "rn1q1rk1/pbppbppp/1p2pn2/8/2PP4/5NP1/PP2PPBP/RNBQ1RK1 w w - 5 7", "4k3/8/4K3/71/8/8/8/8 w - - 10 50", + "r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2R2 b kq - 0 8", + "r1bQ1b1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK1 b kq - 0 8", + "4k3/8/4K3/8/8/8/8/8 w - - 100 50", + "4k3/8/4K3/8/8/8/8/8 w - - 101 50", + "4k3/8/4K3/8/8/8/8/8 w - - -1 50", + "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 0", + "r1b1kb1r/ppp2ppp/3q4/8/P2Q4/8/1PP2PPP/RNB2RK1 w kq - - 8 15", + "r1b1kb1r/ppp2ppp/3q4/8/P2Q4/8/1PP2PPP/RNB2RK1 w kq 8 15", + "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR W KQkq - 0 1", + "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR - KQkq - 0 1", + "r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK1 b kq - - 8", + "r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK1 b kq - 0 -", + "r1bQkb1r/ppp2ppp/2p5/4Pn2/8/5N2/PPP2PPP/RNB2RK1 b kq - -1 8", + "4k3/8/4K3/8/8/8/8/8 w - - 99 e", + "4k3/8/4K3/8/8/8/8/8 w - - 99 ee", ] + correct_fens.extend([None] * (len(invalid_syntax_fens) - len(correct_fens))) + assert len(correct_fens) == len(invalid_syntax_fens) for correct_fen, invalid_syntax_fen in zip(correct_fens, invalid_syntax_fens): old_del_counter = Stockfish._del_counter - assert stockfish.is_fen_valid(correct_fen) + if correct_fen is not None: + assert stockfish.is_fen_valid(correct_fen) + assert stockfish._is_fen_syntax_valid(correct_fen) assert not stockfish.is_fen_valid(invalid_syntax_fen) - assert stockfish._is_fen_syntax_valid(correct_fen) assert not stockfish._is_fen_syntax_valid(invalid_syntax_fen) - assert Stockfish._del_counter == old_del_counter + 2 + assert Stockfish._del_counter == old_del_counter + ( + 2 if correct_fen is not None else 0 + ) time.sleep(2.0) assert stockfish._stockfish.poll() is None