Skip to content

Commit

Permalink
Expose Euchre tricks through pybind11.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 459827286
Change-Id: Ia21dfec7fe874eec62aa0dbf3fcb1c7f3fd54e7d
  • Loading branch information
dhennes authored and lanctot committed Jul 11, 2022
1 parent 22e08a1 commit eab604e
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 13 deletions.
4 changes: 4 additions & 0 deletions open_spiel/games/euchre.cc
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,10 @@ std::vector<double> EuchreState::Returns() const {
return points_;
}

std::vector<Trick> EuchreState::Tricks() const {
return std::vector<Trick>(tricks_.begin(), tricks_.end());
}

Trick::Trick(Player leader, Suit trump_suit, int card)
: winning_card_(card),
led_suit_(CardSuit(card, trump_suit)),
Expand Down
23 changes: 14 additions & 9 deletions open_spiel/games/euchre.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,22 +146,30 @@ class EuchreState : public State {
absl::optional<bool> DeclarerGoAlone() const { return declarer_go_alone_; }
Player LoneDefender() const { return lone_defender_; }
std::vector<bool> ActivePlayers() const { return active_players_; }
std::vector<double> Points() const { return points_; }
Player Dealer() const { return dealer_; }
int CurrentPhase() const { return static_cast<int>(phase_); }

enum class Phase {
kDealerSelection, kDeal, kBidding, kDiscard, kGoAlone, kPlay, kGameOver };
Phase CurrentPhase() const { return phase_; }

int CurrentTrickIndex() const {
return std::min(num_cards_played_ / num_active_players_,
static_cast<int>(tricks_.size()));
}

std::array<absl::optional<Player>, kNumCards> CardHolder() const {
return holder_;
}
int CardRank(int card) const { return euchre::CardRank(card); }
Suit CardSuit(int card) const { return euchre::CardSuit(card); }
std::string CardString(int card) const { return euchre::CardString(card); }

std::vector<Trick> Tricks() const;

protected:
void DoApplyAction(Action action) override;

private:
enum class Phase {
kDealerSelection, kDeal, kBidding, kDiscard, kGoAlone, kPlay, kGameOver };

std::vector<Action> DealerSelectionLegalActions() const;
std::vector<Action> DealLegalActions() const;
std::vector<Action> BiddingLegalActions() const;
Expand All @@ -176,10 +184,7 @@ class EuchreState : public State {
void ApplyPlayAction(int card);

void ComputeScore();
int CurrentTrickIndex() const {
return std::min(num_cards_played_ / num_active_players_,
static_cast<int>(tricks_.size()));
}

Trick& CurrentTrick() { return tricks_[CurrentTrickIndex()]; }
const Trick& CurrentTrick() const { return tricks_[CurrentTrickIndex()]; }
std::array<std::string, kNumSuits> FormatHand(int player,
Expand Down
32 changes: 30 additions & 2 deletions open_spiel/python/pybind11/games_euchre.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ using euchre::EuchreGame;
using euchre::EuchreState;

void init_pyspiel_games_euchre(py::module& m) {
py::classh<EuchreState, State>(m, "EuchreState")
.def("num_cards_dealt", &EuchreState::NumCardsDealt)
py::classh<EuchreState, State> state_class(m, "EuchreState");
state_class.def("num_cards_dealt", &EuchreState::NumCardsDealt)
.def("num_cards_played", &EuchreState::NumCardsPlayed)
.def("num_passes", &EuchreState::NumPasses)
.def("upcard", &EuchreState::Upcard)
Expand All @@ -53,7 +53,11 @@ void init_pyspiel_games_euchre(py::module& m) {
.def("current_phase", &EuchreState::CurrentPhase)
.def("card_holder", &EuchreState::CardHolder)
.def("card_rank", &EuchreState::CardRank)
.def("card_suit", &EuchreState::CardSuit)
.def("card_string", &EuchreState::CardString)
.def("points", &EuchreState::Points)
.def("tricks", &EuchreState::Tricks)
.def("current_trick", &EuchreState::CurrentTrickIndex)
// Pickle support
.def(py::pickle(
[](const EuchreState& state) { // __getstate__
Expand All @@ -65,6 +69,30 @@ void init_pyspiel_games_euchre(py::module& m) {
return dynamic_cast<EuchreState*>(game_and_state.second.release());
}));

py::enum_<euchre::Suit>(state_class, "Suit")
.value("INVALID_SUIT", euchre::Suit::kInvalidSuit)
.value("CLUBS", euchre::Suit::kClubs)
.value("DIAMONDS", euchre::Suit::kDiamonds)
.value("HEARTS", euchre::Suit::kHearts)
.value("SPADES", euchre::Suit::kSpades)
.export_values();

py::class_<euchre::Trick>(state_class, "Trick")
.def("led_suit", &euchre::Trick::LedSuit)
.def("winner", &euchre::Trick::Winner)
.def("cards", &euchre::Trick::Cards)
.def("leader", &euchre::Trick::Leader);

py::enum_<euchre::EuchreState::Phase>(state_class, "Phase")
.value("DEALER_SELECTION", euchre::EuchreState::Phase::kDealerSelection)
.value("DEAL", euchre::EuchreState::Phase::kDeal)
.value("BIDDING", euchre::EuchreState::Phase::kBidding)
.value("DISCARD", euchre::EuchreState::Phase::kDiscard)
.value("GO_ALONE", euchre::EuchreState::Phase::kGoAlone)
.value("PLAY", euchre::EuchreState::Phase::kPlay)
.value("GAME_OVER", euchre::EuchreState::Phase::kGameOver)
.export_values();

py::classh<EuchreGame, Game>(m, "EuchreGame")
.def("max_bids", &EuchreGame::MaxBids)
.def("num_cards", &EuchreGame::NumCards)
Expand Down
13 changes: 11 additions & 2 deletions open_spiel/python/tests/games_euchre_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,25 @@ def test_bindings(self):
self.assertEqual(state.first_defender(), pyspiel.PlayerId.INVALID)
self.assertEqual(state.declarer_partner(), pyspiel.PlayerId.INVALID)
self.assertEqual(state.second_defender(), pyspiel.PlayerId.INVALID)
self.assertIsNone(state.declarer_go_alone(), None)
self.assertIsNone(state.declarer_go_alone())
self.assertEqual(state.lone_defender(), pyspiel.PlayerId.INVALID)
self.assertEqual(state.active_players(), [True, True, True, True])
self.assertEqual(state.dealer(), pyspiel.INVALID_ACTION)
self.assertEqual(state.current_phase(), 0)
self.assertEqual(state.current_phase(), state.Phase.DEALER_SELECTION)
self.assertEqual(state.card_holder(), [None] * 24)
self.assertEqual(state.card_rank(3), 0)
self.assertEqual(state.card_rank(4), 1)
self.assertEqual(state.card_string(0), 'C9')
self.assertEqual(state.card_string(23), 'SA')
self.assertEqual(state.card_suit(0), state.Suit.CLUBS)
self.assertEqual(state.card_suit(23), state.Suit.SPADES)
self.assertEqual(state.current_trick(), 0)

trick = state.tricks()[0]
self.assertEqual(trick.leader(), pyspiel.PlayerId.INVALID)
self.assertEqual(trick.winner(), pyspiel.PlayerId.INVALID)
self.assertEqual(trick.led_suit(), state.Suit.INVALID_SUIT)
self.assertEqual(trick.cards(), [-1])


if __name__ == '__main__':
Expand Down

0 comments on commit eab604e

Please sign in to comment.