Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Shogi] move shogi_utils.py to experimental #1274

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgx/_src/dwg/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def _make_shogi_dwg(dwg, state: ShogiState, config): # noqa: C901
if state._x.turn == 1:
from pgx.shogi import _flip
from pgx._src.games.shogi import _flip

state = ShogiState(_x=_flip(state._x)) # type: ignore
# fmt: off
Expand Down
43 changes: 27 additions & 16 deletions pgx/_src/shogi_utils.py → pgx/experimental/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import jax
import jax.numpy as jnp
import numpy as np
from pgx._src.games.shogi import _flip, Game, GameState
from pgx.shogi import State


def _to_sfen(state):
def to_sfen(state):
"""Convert state into sfen expression.

- Board
Expand All @@ -37,8 +35,9 @@ def _to_sfen(state):

"""
# NOTE: input must be flipped if white turn
state = state if state._x.turn % 2 == 0 else state.replace(_x=_flip(state._x)) # type: ignore

pb = jnp.rot90(state._x.board.reshape((9, 9)), k=3)
pb = np.rot90(state._x.board.reshape((9, 9)), k=3)
sfen = ""
# fmt: off
board_char_dir = ["", "P", "L", "N", "S", "B", "R", "G", "K", "+P", "+L", "+N", "+S", "+B", "+R", "p", "l", "n", "s", "b", "r", "g", "k", "+p", "+l", "+n", "+s", "+b", "+r"]
Expand Down Expand Up @@ -69,7 +68,7 @@ def _to_sfen(state):
else:
sfen += "w "
# Hand (prisoners)
if jnp.all(state._x.hand == 0):
if np.all(state._x.hand == 0):
sfen += "-"
else:
for i in range(2):
Expand All @@ -85,14 +84,25 @@ def _to_sfen(state):
return sfen


def _from_sfen(sfen):
@jax.jit
def _from_board(turn, piece_board, hand):
    """Build a ``State`` from a turn, a piece board and a hand.

    Mainly intended for debugging. The inputs are wrapped in a ``GameState``;
    when ``turn`` is odd that game state is passed through ``_flip``. The
    legal-action mask is then recomputed from the resulting game state.
    terminated, reward, and current_player are not changed.
    """
    base = State(_x=GameState(turn=turn, board=piece_board, hand=hand))  # type: ignore

    def _flipped():
        # Branch taken when ``turn`` is odd.
        return base.replace(_x=_flip(base._x))  # type: ignore

    def _unchanged():
        # Branch taken when ``turn`` is even.
        return base

    oriented = jax.lax.cond(turn % 2 == 1, _flipped, _unchanged)
    mask = Game().legal_action_mask(oriented._x)
    return oriented.replace(legal_action_mask=mask)  # type: ignore


def from_sfen(sfen):
# fmt: off
board_char_dir = ["P", "L", "N", "S", "B", "R", "G", "K", "", "", "", "", "", "", "p", "l", "n", "s", "b", "r", "g", "k"]
hand_char_dir = ["P", "L", "N", "S", "B", "R", "G", "p", "l", "n", "s", "b", "r", "g"]
# fmt: on
board, turn, hand, step_count = sfen.split()
board_ranks = board.split("/")
piece_board = jnp.zeros(81, dtype=jnp.int32)
piece_board = np.zeros(81, dtype=np.int32)
for i in range(9):
file = board_ranks[i]
rank = []
Expand All @@ -109,17 +119,18 @@ def _from_sfen(sfen):
rank.append(piece)
piece = 0
for j in range(9):
piece_board = piece_board.at[9 * i + j].set(rank[j])
s_hand = jnp.zeros(14, dtype=jnp.int32)
piece_board[9 * i + j] = rank[j]
s_hand = np.zeros(14, dtype=np.int32)
if hand != "-":
num_piece = 1
for char in hand:
if char.isdigit():
num_piece = int(char)
else:
s_hand = s_hand.at[hand_char_dir.index(char)].set(num_piece)
s_hand[hand_char_dir.index(char)] = num_piece
num_piece = 1
piece_board = jnp.rot90(piece_board.reshape((9, 9)), k=1).flatten()
hand = jnp.reshape(s_hand, (2, 7))
turn = jnp.int32(0) if turn == "b" else jnp.int32(1)
return turn, piece_board, hand, int(step_count) - 1
piece_board = np.rot90(piece_board.reshape((9, 9)), k=1).flatten()
hand = np.reshape(s_hand, (2, 7))
turn = 0 if turn == "b" else 1
turn, piece_board, hand, step_count = turn, piece_board, hand, int(step_count) - 1
return _from_board(turn, piece_board, hand).replace(_step_count=np.int32(step_count)) # type: ignore
25 changes: 1 addition & 24 deletions pgx/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,9 @@
import jax.numpy as jnp

import pgx.core as core
from pgx._src.shogi_utils import (
_from_sfen,
_to_sfen,
)
from pgx._src.struct import dataclass
from pgx._src.types import Array, PRNGKey
from pgx._src.games.shogi import MAX_TERMINATION_STEPS, GameState, Game, _observe, _flip
from pgx._src.games.shogi import MAX_TERMINATION_STEPS, GameState, Game, _observe


TRUE = jnp.bool_(True)
Expand Down Expand Up @@ -55,25 +51,6 @@ class State(core.State):
def env_id(self) -> core.EnvId:
return "shogi"

@staticmethod
def _from_board(turn, piece_board: Array, hand: Array):
"""Build a State from a turn, a piece board and a hand.
Mainly for debugging purposes.
terminated, reward, and current_player are not changed"""
# Wrap the raw inputs in a GameState stored in the State's `_x` field.
state = State(_x=GameState(turn=turn, board=piece_board, hand=hand)) # type: ignore
# fmt: off
# When `turn` is odd, replace `_x` with `_flip(_x)`; otherwise keep the state unchanged.
state = jax.lax.cond(turn % 2 == 1, lambda: state.replace(_x=_flip(state._x)), lambda: state) # type: ignore
# fmt: on
# Recompute the legal-action mask from the resulting game state.
return state.replace(legal_action_mask=Game().legal_action_mask(state._x)) # type: ignore

@staticmethod
def _from_sfen(sfen):
# Build a State from an sfen string.
# The `_from_sfen` called here is the module-level helper imported from
# pgx._src.shogi_utils (not this static method); it yields the turn, the
# piece board, the hand and a step count.
turn, pb, hand, step_count = _from_sfen(sfen)
# Construct the State through a jitted `_from_board`, then store the step count in `_step_count`.
return jax.jit(State._from_board)(turn, pb, hand).replace(_step_count=jnp.int32(step_count)) # type: ignore

def _to_sfen(self):
# Convert this state into its sfen expression.
# The module-level `_to_sfen` helper (imported from pgx._src.shogi_utils)
# is given the state as-is when `turn` is even, and a copy whose `_x` has
# been passed through `_flip` when `turn` is odd.
state = self if self._x.turn % 2 == 0 else self.replace(_x=_flip(self._x)) # type: ignore
return _to_sfen(state)


class Shogi(core.Env):

Expand Down
Loading
Loading