diff --git a/pgx/_src/dwg/shogi.py b/pgx/_src/dwg/shogi.py index dc75cded8..9d30c9f2f 100644 --- a/pgx/_src/dwg/shogi.py +++ b/pgx/_src/dwg/shogi.py @@ -5,7 +5,7 @@ def _make_shogi_dwg(dwg, state: ShogiState, config): # noqa: C901 if state._x.turn == 1: - from pgx.shogi import _flip + from pgx._src.games.shogi import _flip state = ShogiState(_x=_flip(state._x)) # type: ignore # fmt: off diff --git a/pgx/_src/shogi_utils.py b/pgx/experimental/shogi.py similarity index 74% rename from pgx/_src/shogi_utils.py rename to pgx/experimental/shogi.py index fac45bb7d..7b8188d2f 100644 --- a/pgx/_src/shogi_utils.py +++ b/pgx/experimental/shogi.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - import numpy as np import jax -import jax.numpy as jnp -import numpy as np +from pgx._src.games.shogi import _flip, Game, GameState +from pgx.shogi import State -def _to_sfen(state): +def to_sfen(state): """Convert state into sfen expression. - Board @@ -37,8 +35,9 @@ def _to_sfen(state): """ # NOTE: input must be flipped if white turn + state = state if state._x.turn % 2 == 0 else state.replace(_x=_flip(state._x)) # type: ignore - pb = jnp.rot90(state._x.board.reshape((9, 9)), k=3) + pb = np.rot90(state._x.board.reshape((9, 9)), k=3) sfen = "" # fmt: off board_char_dir = ["", "P", "L", "N", "S", "B", "R", "G", "K", "+P", "+L", "+N", "+S", "+B", "+R", "p", "l", "n", "s", "b", "r", "g", "k", "+p", "+l", "+n", "+s", "+b", "+r"] @@ -69,7 +68,7 @@ def _to_sfen(state): else: sfen += "w " # Hand (prisoners) - if jnp.all(state._x.hand == 0): + if np.all(state._x.hand == 0): sfen += "-" else: for i in range(2): @@ -85,14 +84,25 @@ def _to_sfen(state): return sfen -def _from_sfen(sfen): +@jax.jit +def _from_board(turn, piece_board, hand): + """Mainly for debugging purpose. + terminated, reward, and current_player are not changed""" + state = State(_x=GameState(turn=turn, board=piece_board, hand=hand)) # type: ignore + # fmt: off + state = jax.lax.cond(turn % 2 == 1, lambda: state.replace(_x=_flip(state._x)), lambda: state) # type: ignore + # fmt: on + return state.replace(legal_action_mask=Game().legal_action_mask(state._x)) # type: ignore + + +def from_sfen(sfen): # fmt: off board_char_dir = ["P", "L", "N", "S", "B", "R", "G", "K", "", "", "", "", "", "", "p", "l", "n", "s", "b", "r", "g", "k"] hand_char_dir = ["P", "L", "N", "S", "B", "R", "G", "p", "l", "n", "s", "b", "r", "g"] # fmt: on board, turn, hand, step_count = sfen.split() board_ranks = board.split("/") - piece_board = jnp.zeros(81, dtype=jnp.int32) + piece_board = np.zeros(81, dtype=np.int32) for i in range(9): file = board_ranks[i] rank = [] @@ -109,17 +119,18 @@ def _from_sfen(sfen): rank.append(piece) piece = 0 for j in range(9): - piece_board = piece_board.at[9 * i + j].set(rank[j]) - s_hand = jnp.zeros(14, dtype=jnp.int32) + piece_board[9 * i + j] = rank[j] + s_hand = np.zeros(14, dtype=np.int32) if hand != "-": num_piece = 1 for char in hand: if char.isdigit(): num_piece = int(char) else: - s_hand = s_hand.at[hand_char_dir.index(char)].set(num_piece) + s_hand[hand_char_dir.index(char)] = num_piece num_piece = 1 - piece_board = jnp.rot90(piece_board.reshape((9, 9)), k=1).flatten() - hand = jnp.reshape(s_hand, (2, 7)) - turn = jnp.int32(0) if turn == "b" else jnp.int32(1) - return turn, piece_board, hand, int(step_count) - 1 + piece_board = np.rot90(piece_board.reshape((9, 9)), k=1).flatten() + hand = np.reshape(s_hand, (2, 7)) + turn = 0 if turn == "b" else 1 + turn, piece_board, hand, step_count = turn, piece_board, hand, int(step_count) - 1 + return _from_board(turn, piece_board, hand).replace(_step_count=np.int32(step_count)) # type: ignore diff --git a/pgx/shogi.py b/pgx/shogi.py index 64cbec074..5d167afd2 100644 --- a/pgx/shogi.py +++ b/pgx/shogi.py @@ -17,13 +17,9 @@ import jax.numpy as jnp import pgx.core as core -from pgx._src.shogi_utils import ( - _from_sfen, - _to_sfen, -) from pgx._src.struct import dataclass from pgx._src.types import Array, PRNGKey -from pgx._src.games.shogi import MAX_TERMINATION_STEPS, GameState, Game, _observe, _flip +from pgx._src.games.shogi import MAX_TERMINATION_STEPS, GameState, Game, _observe TRUE = jnp.bool_(True) @@ -55,25 +51,6 @@ class State(core.State): def env_id(self) -> core.EnvId: return "shogi" - @staticmethod - def _from_board(turn, piece_board: Array, hand: Array): - """Mainly for debugging purpose. - terminated, reward, and current_player are not changed""" - state = State(_x=GameState(turn=turn, board=piece_board, hand=hand)) # type: ignore - # fmt: off - state = jax.lax.cond(turn % 2 == 1, lambda: state.replace(_x=_flip(state._x)), lambda: state) # type: ignore - # fmt: on - return state.replace(legal_action_mask=Game().legal_action_mask(state._x)) # type: ignore - - @staticmethod - def _from_sfen(sfen): - turn, pb, hand, step_count = _from_sfen(sfen) - return jax.jit(State._from_board)(turn, pb, hand).replace(_step_count=jnp.int32(step_count)) # type: ignore - - def _to_sfen(self): - state = self if self._x.turn % 2 == 0 else self.replace(_x=_flip(self._x)) # type: ignore - return _to_sfen(state) - class Shogi(core.Env): diff --git a/tests/test_shogi.py b/tests/test_shogi.py index b9c5aa402..352fda5b9 100644 --- a/tests/test_shogi.py +++ b/tests/test_shogi.py @@ -4,6 +4,7 @@ from pgx.shogi import Shogi, State from pgx._src.games.shogi import Action, HORSE, PAWN, DRAGON +from pgx.experimental.shogi import from_sfen, to_sfen env = Shogi() init = jax.jit(env.init) @@ -45,25 +46,25 @@ def test_is_legal_drop(): # 打ち歩詰 # 避けられるし金でも取れる sfen = "lnsgkgsnl/7b1/ppppppppp/9/9/9/PPPP1PPPP/1B5R1/LNSGKGSNL b P 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/legal_drops_001.svg") assert state.legal_action_mask[20 * 81 + xy2i(5, 2)] # 片側に避けられるので打ち歩詰でない sfen = "lns1kpsnl/7b1/ppppGpppp/9/9/9/PPPP1PPPP/1B5R1/LNSGKGSNL b P 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/legal_drops_002.svg") assert state.legal_action_mask[20 * 81 + xy2i(5, 2)] # 両側に避けられないので打ち歩詰 sfen = "lnspkpsnl/7b1/ppppGpppp/9/9/9/PPPP1PPPP/1B5R1/LNSGKGSNL b P 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/legal_drops_003.svg") assert not state.legal_action_mask[20 * 81 + xy2i(5, 2)] # 金で取れるので打ち歩詰でない sfen = "lnsgkpsnl/7b1/ppppGpppp/9/9/9/PPPP1PPPP/1B5R1/LNSGKGSNL b P 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/legal_drops_004.svg") assert state.legal_action_mask[20 * 81 + xy2i(5, 2)] @@ -71,7 +72,7 @@ def test_is_legal_drop(): def test_buggy_samples(): # 歩以外の持ち駒に対しての二歩判定回避 sfen = "9/9/9/9/9/9/PPPPPPPPP/9/9 b NLP 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/buggy_samples_001.svg") # 歩は二歩になるので打てない @@ -87,77 +88,77 @@ def test_buggy_samples(): # 成駒のpromotion判定 sfen = "9/2+B1G1+P2/9/9/9/9/9/9/9 b - 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/buggy_samples_002.svg") # promotionは生成されてたらダメ assert (state.legal_action_mask[10 * 81:]).sum() == 0 # 角は成れないはず sfen = "l+B6l/6k2/3pg2P1/p6p1/1pP1pB2p/2p3n2/P+r1GP3P/4KS1+s1/LNG5L b RGN2sn6p 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/buggy_samples_003.svg") assert ~state.legal_action_mask[13 * 81 + 72] # = 1125, promote + left (91角成) # #375 sfen = "lnsgkg1nl/1r5s1/pppppp1pp/6p2/8B/2P6/PP1PPPPPP/7R1/LNSGKGSNL w b 1" - s = State._from_sfen(sfen) + s = from_sfen(sfen) visualize(s, "tests/assets/shogi/buggy_samples_004.svg") assert (s.legal_action_mask.sum() == len([43, 52, 68, 196, 222, 295, 789, 1996, 2004, 2012])).all(), jnp.nonzero(s.legal_action_mask)[0] assert (jnp.nonzero(s.legal_action_mask)[0] == jnp.int32([43, 52, 68, 196, 222, 295, 789, 1996, 2004, 2012])).all() # #602 sfen = "9/4R4/9/9/9/9/9/9/9 b 2r2b4g3s4n4l17p 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/buggy_samples_005.svg") dlshogi_action = 846 state = step(state, dlshogi_action) visualize(state, "tests/assets/shogi/buggy_samples_006.svg") sfen = "4+R4/9/9/9/9/9/9/9/9 w 2r2b4g3s4n4l7p 1" - expected_state = State._from_sfen(sfen) + expected_state = from_sfen(sfen) visualize(expected_state, "tests/assets/shogi/buggy_samples_006.svg") assert (state._x.board == expected_state._x.board).all() # #603 - state = State._from_sfen("8k/9/9/5b3/9/3B5/9/9/K8 b 2r4g4s4n4l18p 1") + state = from_sfen("8k/9/9/5b3/9/3B5/9/9/K8 b 2r4g4s4n4l18p 1") visualize(state, "tests/assets/shogi/buggy_samples_007.svg") dlshogi_action = 202 a = Action._from_dlshogi_action(state._x, dlshogi_action) assert a.from_ == xy2i(6, 6) # #610 - state = State._from_sfen("+PsGg1p2+P/+B1+Pgp+N1sp/1+N5l1/P3kP1pL/3P1r3/B2KP3L/4L1SP+s/+r2+p2pgP/2P2+n+p2 b np 1") + state = from_sfen("+PsGg1p2+P/+B1+Pgp+N1sp/1+N5l1/P3kP1pL/3P1r3/B2KP3L/4L1SP+s/+r2+p2pgP/2P2+n+p2 b np 1") visualize(state, "tests/assets/shogi/buggy_samples_008.svg") dlshogi_action = 225 a = Action._from_dlshogi_action(state._x, dlshogi_action) assert a.from_ == xy2i(9, 2) assert a.piece == HORSE state = step(state, dlshogi_action) - expected_state = State._from_sfen("+P+BGg1p2+P/2+Pgp+N1sp/1+N5l1/P3kP1pL/3P1r3/B2KP3L/4L1SP+s/+r2+p2pgP/2P2+n+p2 w Snp 1") + expected_state = from_sfen("+P+BGg1p2+P/2+Pgp+N1sp/1+N5l1/P3kP1pL/3P1r3/B2KP3L/4L1SP+s/+r2+p2pgP/2P2+n+p2 w Snp 1") assert (state._x.board == expected_state._x.board).all() # #613 - state = State._from_sfen("1+N3s1n1/5k2l/l+P2g1bp1/2pP1p2p/p2ppNS2/LB6P/1pS1g2PL/3KPR2S/1R1G1NG2 b P4p 1") + state = from_sfen("1+N3s1n1/5k2l/l+P2g1bp1/2pP1p2p/p2ppNS2/LB6P/1pS1g2PL/3KPR2S/1R1G1NG2 b P4p 1") visualize(state, "tests/assets/shogi/buggy_samples_009.svg") dlshogi_action = 42 a = Action._from_dlshogi_action(state._x, dlshogi_action) assert a.from_ == xy2i(5, 8) assert a.piece == PAWN state = step(state, dlshogi_action) - expected_state = State._from_sfen("1+N3s1n1/5k2l/l+P2g1bp1/2pP1p2p/p2ppNS2/LB6P/1pS1P2PL/3K1R2S/1R1G1NG2 w GP4p 1") + expected_state = from_sfen("1+N3s1n1/5k2l/l+P2g1bp1/2pP1p2p/p2ppNS2/LB6P/1pS1P2PL/3K1R2S/1R1G1NG2 w GP4p 1") assert (state._x.board == expected_state._x.board).all() # #618 - state = State._from_sfen("2+P+P2G1+S/1P2+P+P1+Pn/+S1GK2P2/1b2PP3/1nl4PP/3k2lRL/1pg+s3L1/p2R2p2/P+n+B+p+ng1+s+p w P 1") + state = from_sfen("2+P+P2G1+S/1P2+P+P1+Pn/+S1GK2P2/1b2PP3/1nl4PP/3k2lRL/1pg+s3L1/p2R2p2/P+n+B+p+ng1+s+p w P 1") visualize(state, "tests/assets/shogi/buggy_samples_010.svg") dlshogi_action = 28 a = Action._from_dlshogi_action(state._x, dlshogi_action) assert a.from_ == xy2i(4, 3) state = step(state, dlshogi_action) - expected_state = State._from_sfen("2+P+P2G1+S/1P2+P+P1+Pn/+S1GK2P2/1b2PP3/1nl4PP/3k2lRL/1pg4L1/p2+s2p2/P+n+B+p+ng1+s+p b Pr 1") + expected_state = from_sfen("2+P+P2G1+S/1P2+P+P1+Pn/+S1GK2P2/1b2PP3/1nl4PP/3k2lRL/1pg4L1/p2+s2p2/P+n+B+p+ng1+s+p b Pr 1") assert (state._x.board == expected_state._x.board).all() # 629 - state = State._from_sfen("1ns6/+S1p+Ng1p1l/+P2pg1nNS/4k2G1/2L2R2s/p1G2+BPR1/3Pp2+p1/1+p3B1P1/1LPK2+l1+p b P4p 1") + state = from_sfen("1ns6/+S1p+Ng1p1l/+P2pg1nNS/4k2G1/2L2R2s/p1G2+BPR1/3Pp2+p1/1+p3B1P1/1LPK2+l1+p b P4p 1") visualize(state, "tests/assets/shogi/buggy_samples_011.svg") dlshogi_action = 1660 # 歩打 state = step(state, dlshogi_action) @@ -166,14 +167,14 @@ def test_buggy_samples(): # 打ち歩詰ではないが2歩 # 639 sfen = "+P2G1p2+P/1+N2+Pbk1p/3p3l+L/1gp3s1L/2SPG2p1/NK1SL3N/1p5RP/+p+r+p2+np+s1/b1P+p2+p2 w Gp 35" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/buggy_samples_013.svg") assert not state.legal_action_mask[20 * 81 + xy2i(2, 5)] # Hand crafted tests #685 # double check sfen = "8k/9/9/9/9/8r/8s/9/7GK w - 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) assert int(state.legal_action_mask.sum()) == 21 dl_action = 226 state = step(state, dl_action) @@ -181,7 +182,7 @@ def test_buggy_samples(): assert int(state.legal_action_mask.sum()) == 1 # discovered check with pin sfen = "8k/9/9/9/9/5b3/6r2/9/7GK w - 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) assert int(state.legal_action_mask.sum()) == 49 dl_action = 54 state = step(state, dl_action) @@ -189,7 +190,7 @@ def test_buggy_samples(): assert int(state.legal_action_mask.sum()) == 1 # discovered check sfen = "k8/8g/9/4r1b1K/P8/9/9/9/9 w - 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) assert int(state.legal_action_mask.sum()) == 37 dl_action = 156 state = step(state, dl_action) @@ -197,7 +198,7 @@ def test_buggy_samples(): assert int(state.legal_action_mask.sum()) == 1 # catch pieces sfen = "k1b1r4/5B2g/4p4/9/9/9/8K/4L4/9 b - 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) assert int(state.legal_action_mask.sum()) == 23 dl_action = 38 state = step(state, dl_action) @@ -217,24 +218,24 @@ def test_buggy_samples(): assert int(state.legal_action_mask.sum()) == 10 # double pin sfen = "8k/9/9/9/9/5b3/6r2/7P1/7GK w - 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) dl_action = 54 state = step(state, dl_action) visualize(state, "tests/assets/shogi/buggy_samples_021.svg") assert int(state.legal_action_mask.sum()) == 2 # drop pawn mate sfen = "8k/9/7L1/7N1/9/9/9/9/8K b P 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/buggy_samples_022.svg") assert int(state.legal_action_mask.sum()) == 76 # move pawn mate(legal) sfen = "8k/9/7LP/7N1/9/9/9/9/8K b - 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/buggy_samples_023.svg") assert int(state.legal_action_mask.sum()) == 10 # pinned same line sfen = "8k/9/9/9/4b4/9/6B2/9/8K b - 1" - state = State._from_sfen(sfen) + state = from_sfen(sfen) visualize(state, "tests/assets/shogi/buggy_samples_024.svg") assert int(state.legal_action_mask.sum()) == 6 @@ -247,7 +248,7 @@ def test_step(): except: assert False, line sfen = d["sfen_before"] - state = State._from_sfen(sfen) + state = from_sfen(sfen) expected_legal_actions = d["legal_actions"] legal_actions = jnp.nonzero(state.legal_action_mask)[0] ok = legal_actions.shape[0] == len(expected_legal_actions) @@ -263,7 +264,7 @@ def test_step(): action = int(d["action"]) state = step(state, action) sfen = d["sfen_after"] - assert state._to_sfen() == sfen + assert to_sfen(state) == sfen def test_observe(): @@ -367,7 +368,7 @@ def test_observe(): # 駒打ち sfen = "1ns4nl/1r4k2/2p1gp3/1p1pp3p/l8/2P2PP2/1PNPP3P/2G2S3/2S1KG2L b BGS3Prbnl2p 1" - s = State._from_sfen(sfen) + s = from_sfen(sfen) visualize(s, "tests/assets/shogi/observe_001.svg") obs = observe(s, s.current_player) @@ -380,13 +381,13 @@ def test_observe(): # 王手 sfen = "lnsgkg1nl/1r5s1/pppppp1pp/6p2/8B/2P6/PP1PPPPPP/7R1/LNSGKGSNL b b 1" # 先手番 - s = State._from_sfen(sfen) + s = from_sfen(sfen) visualize(s, "tests/assets/shogi/observe_002.svg") obs = observe(s, s.current_player) assert (~obs[:, :, -1]).all() sfen = "lnsgkg1nl/1r5s1/pppppp1pp/6p2/8B/2P6/PP1PPPPPP/7R1/LNSGKGSNL w b 1" # 後手番 - s = State._from_sfen(sfen) + s = from_sfen(sfen) visualize(s, "tests/assets/shogi/observe_003.svg") obs = observe(s, s.current_player) assert obs[:, :, -1].all() @@ -395,14 +396,14 @@ def test_observe(): def test_sfen(): sfen = "lnsgkg1nl/1r5s1/pppppp1pp/6p2/8B/2P6/PP1PPPPPP/7R1/LNSGKGSNL b b 1" - s = State._from_sfen(sfen) + s = from_sfen(sfen) visualize(s, "tests/assets/shogi/sfen_001.svg") - assert s._to_sfen() == sfen + assert to_sfen(s) == sfen sfen = "lnsgkg1nl/1r5s1/pppppp1pp/6p2/8B/2P6/PP1PPPPPP/7R1/LNSGKGSNL w b 1" - s = State._from_sfen(sfen) + s = from_sfen(sfen) visualize(s, "tests/assets/shogi/sfen_002.svg") - assert s._to_sfen() == sfen + assert to_sfen(s) == sfen def test_api():