Skip to content

Commit

Permalink
[Shogi] move shogi_utils.py to experimental (#1274)
Browse files: browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Nov 1, 2024
1 parent ffe939e commit 26a09bd
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 77 deletions.
2 changes: 1 addition & 1 deletion pgx/_src/dwg/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def _make_shogi_dwg(dwg, state: ShogiState, config): # noqa: C901
if state._x.turn == 1:
from pgx.shogi import _flip
from pgx._src.games.shogi import _flip

state = ShogiState(_x=_flip(state._x)) # type: ignore
# fmt: off
Expand Down
43 changes: 27 additions & 16 deletions pgx/_src/shogi_utils.py → pgx/experimental/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import jax
import jax.numpy as jnp
import numpy as np
from pgx._src.games.shogi import _flip, Game, GameState
from pgx.shogi import State


def _to_sfen(state):
def to_sfen(state):
"""Convert state into sfen expression.
- Board
Expand All @@ -37,8 +35,9 @@ def _to_sfen(state):
"""
# NOTE: input must be flipped if white turn
state = state if state._x.turn % 2 == 0 else state.replace(_x=_flip(state._x)) # type: ignore

pb = jnp.rot90(state._x.board.reshape((9, 9)), k=3)
pb = np.rot90(state._x.board.reshape((9, 9)), k=3)
sfen = ""
# fmt: off
board_char_dir = ["", "P", "L", "N", "S", "B", "R", "G", "K", "+P", "+L", "+N", "+S", "+B", "+R", "p", "l", "n", "s", "b", "r", "g", "k", "+p", "+l", "+n", "+s", "+b", "+r"]
Expand Down Expand Up @@ -69,7 +68,7 @@ def _to_sfen(state):
else:
sfen += "w "
# Hand (prisoners)
if jnp.all(state._x.hand == 0):
if np.all(state._x.hand == 0):
sfen += "-"
else:
for i in range(2):
Expand All @@ -85,14 +84,25 @@ def _to_sfen(state):
return sfen


def _from_sfen(sfen):
@jax.jit
def _from_board(turn, piece_board, hand):
    """Build a ``State`` from a turn, a piece board and a hand.

    Intended mainly as a debugging helper. ``terminated``, ``reward`` and
    ``current_player`` are left untouched.

    When ``turn`` is odd the wrapped game state is passed through ``_flip``
    before the legal-action mask is recomputed from it.
    """
    initial = State(_x=GameState(turn=turn, board=piece_board, hand=hand))  # type: ignore

    def _flipped():
        return initial.replace(_x=_flip(initial._x))  # type: ignore

    def _as_is():
        return initial

    oriented = jax.lax.cond(turn % 2 == 1, _flipped, _as_is)
    mask = Game().legal_action_mask(oriented._x)
    return oriented.replace(legal_action_mask=mask)  # type: ignore


def from_sfen(sfen):
# fmt: off
board_char_dir = ["P", "L", "N", "S", "B", "R", "G", "K", "", "", "", "", "", "", "p", "l", "n", "s", "b", "r", "g", "k"]
hand_char_dir = ["P", "L", "N", "S", "B", "R", "G", "p", "l", "n", "s", "b", "r", "g"]
# fmt: on
board, turn, hand, step_count = sfen.split()
board_ranks = board.split("/")
piece_board = jnp.zeros(81, dtype=jnp.int32)
piece_board = np.zeros(81, dtype=np.int32)
for i in range(9):
file = board_ranks[i]
rank = []
Expand All @@ -109,17 +119,18 @@ def _from_sfen(sfen):
rank.append(piece)
piece = 0
for j in range(9):
piece_board = piece_board.at[9 * i + j].set(rank[j])
s_hand = jnp.zeros(14, dtype=jnp.int32)
piece_board[9 * i + j] = rank[j]
s_hand = np.zeros(14, dtype=np.int32)
if hand != "-":
num_piece = 1
for char in hand:
if char.isdigit():
num_piece = int(char)
else:
s_hand = s_hand.at[hand_char_dir.index(char)].set(num_piece)
s_hand[hand_char_dir.index(char)] = num_piece
num_piece = 1
piece_board = jnp.rot90(piece_board.reshape((9, 9)), k=1).flatten()
hand = jnp.reshape(s_hand, (2, 7))
turn = jnp.int32(0) if turn == "b" else jnp.int32(1)
return turn, piece_board, hand, int(step_count) - 1
piece_board = np.rot90(piece_board.reshape((9, 9)), k=1).flatten()
hand = np.reshape(s_hand, (2, 7))
turn = 0 if turn == "b" else 1
turn, piece_board, hand, step_count = turn, piece_board, hand, int(step_count) - 1
return _from_board(turn, piece_board, hand).replace(_step_count=np.int32(step_count)) # type: ignore
25 changes: 1 addition & 24 deletions pgx/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,9 @@
import jax.numpy as jnp

import pgx.core as core
from pgx._src.shogi_utils import (
_from_sfen,
_to_sfen,
)
from pgx._src.struct import dataclass
from pgx._src.types import Array, PRNGKey
from pgx._src.games.shogi import MAX_TERMINATION_STEPS, GameState, Game, _observe, _flip
from pgx._src.games.shogi import MAX_TERMINATION_STEPS, GameState, Game, _observe


TRUE = jnp.bool_(True)
Expand Down Expand Up @@ -55,25 +51,6 @@ class State(core.State):
def env_id(self) -> core.EnvId:
return "shogi"

@staticmethod
def _from_board(turn, piece_board: Array, hand: Array):
    """Build a ``State`` from a turn, a piece board and a hand.

    Intended mainly as a debugging helper. ``terminated``, ``reward`` and
    ``current_player`` are left untouched.

    When ``turn`` is odd the wrapped game state is passed through ``_flip``
    before the legal-action mask is recomputed from it.
    """
    base = State(_x=GameState(turn=turn, board=piece_board, hand=hand))  # type: ignore

    def _flipped():
        return base.replace(_x=_flip(base._x))  # type: ignore

    def _as_is():
        return base

    oriented = jax.lax.cond(turn % 2 == 1, _flipped, _as_is)
    mask = Game().legal_action_mask(oriented._x)
    return oriented.replace(legal_action_mask=mask)  # type: ignore

@staticmethod
def _from_sfen(sfen):
    """Build a ``State`` from an SFEN string.

    The string is parsed by the module-level ``_from_sfen`` helper, the
    state is assembled through a jitted ``State._from_board``, and the
    parsed step count is stored in ``_step_count`` as a ``jnp.int32``.
    """
    parsed_turn, parsed_board, parsed_hand, parsed_steps = _from_sfen(sfen)
    build_state = jax.jit(State._from_board)
    built = build_state(parsed_turn, parsed_board, parsed_hand)
    return built.replace(_step_count=jnp.int32(parsed_steps))  # type: ignore

def _to_sfen(self):
    """Return the SFEN string for this state.

    When the turn counter is odd, the wrapped game state is passed through
    ``_flip`` first; the result is then handed to the module-level
    ``_to_sfen`` helper.
    """
    if self._x.turn % 2 == 0:
        oriented = self
    else:
        oriented = self.replace(_x=_flip(self._x))  # type: ignore
    return _to_sfen(oriented)


class Shogi(core.Env):

Expand Down
Loading

0 comments on commit 26a09bd

Please sign in to comment.