diff --git a/.gitignore b/.gitignore index 73b1d2d..11a1fc8 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ __pycache__ # ignore any CSV files written to the data directory: +data/*.csv data/*/*.csv diff --git a/app/ai/README.md b/app/ai/README.md deleted file mode 100644 index ab5ba79..0000000 --- a/app/ai/README.md +++ /dev/null @@ -1,56 +0,0 @@ - -# Tic Tac Toe (AI) - -> NOTE: this is currently a work in progress! - -The [computer players](/app/player.py) in the Tic Tac Toe app use a predefined algorithm to select the best move. Let's see if we can train machine learning models to play the game instead. - -## Generating Moves Data - -Generates moves dataset for model training and evaluation. The test/train splits are out of scope for this step, and will be done later, during model training. - -The inputs for each move are the board state (e.g. "-X-O-X-OX" and the active player (i.e. "X" or "O"). The output is the player's selected move, represented by the selected square's index position in the board notation string (0-8). - -We simulate alternating moves until we reach an outcome (win, lose, tie). After reaching an outcome, we can assign the eventual outcome value to all moves that player made leading up until the outcome: - + all winning player's moves get assigned a positive value (+1) - + all losing player's moves get assigned a negative value (-1) - + all moves resulting in a tie get a neutral score (0) - -Generating the datasets: - -```sh -GAME_COUNT=100000 X_STRATEGY="RANDOM" O_STRATEGY="RANDOM" python -m app.jobs.play_moves -``` - -Exports a CSV file in the "data" directory (e.g "/data/moves/x_random_vs_o_random_10000.csv"). 
Example results: - -|game_id|move_id|board_state|player|square_idx|reward| -|-------|-------|-----------|------|---------|------| -|1 |1 | `---------` |`X` |5 | 0 | -|1 |2 | `-----X---` |`O` |7 | 0 | -|1 |3 | `-----X-O-` |`X` |8 | 0 | -|1 |4 | `-----X-OX` |`O` |3 | 0 | -|1 |5 | `---O-X-OX` |`X` |1 | 0 | -|1 |6 | `-X-O-X-OX` |`O` |2 | 0 | -|1 |7 | `-XOO-X-OX` |`X` |6 | 0 | -|1 |8 | `-XOO-XXOX` |`O` |4 | 0 | -|1 |9 | `-XOOOXXOX` |`X` |0 | 0 | -|2 |1 |`---------` |`X` |1 |1 | -|2 |2 |`-X-------` |`O` |8 |-1 | -|2 |3 |`-X------O` |`X` |0 |1 | -|2 |4 |`XX------O` |`O` |4 |-1 | -|2 |5 |`XX--O---O` |`X` |2 |1 | -|3 |1 |`---------` |`X` |4 |-1 | -|3 |2 |`----X----` |`O` |7 |1 | -|3 |3 |`----X--O-` |`X` |2 |-1 | -|3 |4 |`--X-X--O-` |`O` |8 |1 | -|3 |5 |`--X-X--OO` |`X` |5 |-1 | -|3 |6 |`--X-XX-OO` |`O` |6 |1 | - -## Model Selection and Evaluation - -TBD - -## Model Training - -TBD diff --git a/app/board.py b/app/board.py index 7fd3f75..697d7fb 100644 --- a/app/board.py +++ b/app/board.py @@ -39,14 +39,14 @@ def __repr__(self): """ - @property - def notation(self) -> str: - """ - Represents the board's current state in simple string format like "-X-O-X-OX". - - Position corresponds with square names ['A1','B1','C1','A2','B2','C2','A3','B3','C3'] and indices [0,1,2,3,4,5,6,7,8]. - """ - return "".join([square.notation for square in self.squares]) + #@property + #def notation(self) -> str: + # """ + # Represents the board's current state in simple string format like "-X-O-X-OX". +# + # Position corresponds with square names ['A1','B1','C1','A2','B2','C2','A3','B3','C3'] and indices [0,1,2,3,4,5,6,7,8]. 
+ # """ + # return "".join([square.notation for square in self.squares]) def get_square(self, square_name): @@ -57,7 +57,13 @@ def get_square(self, square_name): def get_squares(self, square_names): return [square for square in self.squares if square.name in square_names] - def set_square(self, square_name, player_letter): + def set_square(self, square_name: str, player_letter: str): + """ + Params: + square_name + + player_letter + """ square = self.get_square(square_name) if not square.letter: square.letter = player_letter diff --git a/app/game.py b/app/game.py index b5b5c82..ccd6546 100644 --- a/app/game.py +++ b/app/game.py @@ -1,5 +1,6 @@ from itertools import cycle +from copy import deepcopy from app.board import Board from app.player import select_player @@ -46,9 +47,8 @@ def take_turn(self, turn: tuple): Pass the turn param as a tuple in the form of (player_letter, square_name). """ player_letter, square_name = turn - initial_board_state = self.board.notation # important to note this before changing the board - - move = Move(board_state=initial_board_state, active_player=player_letter, selected_square=square_name) + initial_board = deepcopy(self.board) # important to note this before changing the board + move = Move(board=initial_board, active_player=player_letter, selected_square=square_name) # make the move / change the board state: self.board.set_square(square_name, player_letter) diff --git a/app/jobs/generate_moves.py b/app/jobs/generate_moves.py new file mode 100644 index 0000000..13671cc --- /dev/null +++ b/app/jobs/generate_moves.py @@ -0,0 +1,113 @@ + + + +import os + +from pandas import DataFrame + +from app import OPPOSITE_LETTERS +#from app.board import SQUARE_NAMES +from app.game import Game +from app.player import select_player +from app.jobs.timer import Timer + + +# for the strategies, use "RANDOM" for random moves, or "MINIMAX-AB" for expert moves +X_STRATEGY = os.getenv("X_STRATEGY", default="RANDOM") +O_STRATEGY = os.getenv("O_STRATEGY", 
default="RANDOM") + +GAME_COUNT = int(os.getenv("GAME_COUNT", default="1_000")) + + + +class EvaluatedGame(Game): + @property + def player_rewards(self): + if self.winner: + # reward the winner and punish the loser: + winning_letter = self.winner["letter"] + losing_letter = OPPOSITE_LETTERS[winning_letter] + return {winning_letter: 1, losing_letter: 0} + else: + # give neutral scores to both players: + return {"X": 0.5, "O": 0.5} + + +class MoveEvaluator: + def __init__(self): + self.timer = Timer() + self.players = [ + select_player(letter="X", strategy=X_STRATEGY), + select_player(letter="O", strategy=O_STRATEGY), + ] + self.GAME_COUNT = GAME_COUNT + self.moves_df = None + + def perform(self, export=True): + self.timer.start() + records = [] + for game_counter in range(0, self.GAME_COUNT): + game = EvaluatedGame(players=self.players) + game.play() + + player_rewards = game.player_rewards + for move_counter, move in enumerate(game.move_history): + active_player = move.active_player + + # if the active player takes this move they will get the outcome + move.board.set_square(square_name=move.selected_square, player_letter=active_player) + + records.append({ + "game_id": game_counter + 1, # start ids at 1 instead of 0 + "move_id": move_counter + 1, # start ids at 1 instead of 0 + #"board_state": move.board.notation, + "a1": move.board.get_square("A1").notation, + "b1": move.board.get_square("B1").notation, + "c1": move.board.get_square("C1").notation, + "a2": move.board.get_square("A2").notation, + "b2": move.board.get_square("B2").notation, + "c2": move.board.get_square("C2").notation, + "a3": move.board.get_square("A3").notation, + "b3": move.board.get_square("B3").notation, + "c3": move.board.get_square("C3").notation, + "player": active_player, + #"square_name": move.selected_square, + #"square_idx": SQUARE_NAMES.index(move.selected_square), # translate squares to index 0-8 to match board notation (maybe) + "outcome": player_rewards[active_player], + }) + + 
self.timer.end() + print("------------------------") + print("PLAYED", self.GAME_COUNT, "GAMES", f"IN {self.timer.duration_seconds} SECONDS") + print("TOTAL MOVES:", len(records)) + self.moves_df = DataFrame(records) + print(self.moves_df.head()) + + if export: + print("------------------------") + print("SAVING DATA TO FILE...") + #csv_filename = f"{self.players[0].letter}_{self.players[0].player_type.replace('-','')}" + #csv_filename += "_vs_" + #csv_filename += f"_{self.players[1].letter}_{self.players[1].player_type.replace('-','')}" + + csv_filename = self.players[0].player_type.replace('-','') + "_" + csv_filename += self.players[1].player_type.replace('-','') + csv_filename += f"_{self.GAME_COUNT}.csv" + csv_filename = csv_filename.lower() + csv_filepath = os.path.join(os.path.dirname(__file__), "..", "..", "data", "moves", csv_filename) + + self.moves_df.to_csv(csv_filepath, index=False) + print(os.path.abspath(csv_filepath)) + + return self.moves_df + + + + +if __name__ == "__main__": + + + + job = MoveEvaluator() + + job.perform() diff --git a/app/jobs/play_moves.py b/app/jobs/play_moves.py deleted file mode 100644 index 2e6e4ed..0000000 --- a/app/jobs/play_moves.py +++ /dev/null @@ -1,92 +0,0 @@ - - - -import os - -from pandas import DataFrame - -from app import OPPOSITE_LETTERS -from app.board import SQUARE_NAMES -from app.game import Game -from app.player import select_player -from app.jobs.timer import Timer - - -# for the strategies, use "RANDOM" for random moves, or "MINIMAX-AB" for expert moves -X_STRATEGY = os.getenv("X_STRATEGY", default="RANDOM") -O_STRATEGY = os.getenv("O_STRATEGY", default="RANDOM") - -GAME_COUNT = int(os.getenv("GAME_COUNT", default="100_000")) - -if __name__ == "__main__": - - timer = Timer() - timer.start() - - records = [] - for game_counter in range(0, GAME_COUNT): - game = Game(players=[ - select_player(letter="X", strategy=X_STRATEGY), - select_player(letter="O", strategy=O_STRATEGY), - ]) - - # - # PLAY - # - - 
game.play() - - # - # OUTCOME EVAL - # - - # determine reward values for each player - if game.winner: - winning_letter = game.winner["letter"] - losing_letter = OPPOSITE_LETTERS[winning_letter] - # reward the winner and punish the loser: - rewards = {winning_letter: 1, losing_letter: -1} - else: - # give neutral scores to both players: - rewards = {"X": 0, "O": 0} - print("------------------------") - print("REWARDS:", rewards) - - # - # PLAYBACK - # - - for move_counter, move in enumerate(game.move_history): - active_player = move.active_player - records.append({ - "game_id": game_counter + 1, # start ids at 1 instead of 0 - "move_id": move_counter + 1, # start ids at 1 instead of 0 - "board_state": move.board_state, - "player": active_player, - "square_name": move.selected_square, - "square_idx": SQUARE_NAMES.index(move.selected_square), # translate squares to index 0-8 to match board notation (maybe) - "reward": rewards[active_player], - }) - - timer.end() - print("------------------------") - print("PLAYED", GAME_COUNT, "GAMES", f"IN {timer.duration_seconds} SECONDS") - print("TOTAL MOVES:", len(records)) - - df = DataFrame(records) - print(df.head()) - - print("------------------------") - print("SAVING DATA TO FILE...") - - csv_filename = f"{game.players[0].letter}_{game.players[0].player_type}" - csv_filename += "_vs_" - csv_filename += f"{game.players[1].letter}_{game.players[1].player_type}" - csv_filename += f"_{GAME_COUNT}.csv" - csv_filename = csv_filename.lower() - csv_filepath = os.path.join(os.path.dirname(__file__), "..", "..", "data", "moves", csv_filename) - - df.to_csv(csv_filepath, index=False) - print(os.path.abspath(csv_filepath)) - - #breakpoint() diff --git a/app/move.py b/app/move.py index e065407..6b544de 100644 --- a/app/move.py +++ b/app/move.py @@ -3,19 +3,19 @@ class Move: - def __init__(self, board_state, active_player, selected_square): + def __init__(self, board, active_player, selected_square): """ Params - board_state (str) the 
initial board state before the player made the move + board (Board) the initial board state before the player made the move active_player (str) the letter of the player who made the move ("X" or "O") selected_square (str) the name of the square the player selected (e.g "A1") """ - self.board_state = board_state #> "XX-OO----" - self.active_player = active_player #> "X" - self.selected_square = selected_square #> "C1" + self.board = board + self.active_player = active_player + self.selected_square = selected_square def __repr__(self): return f"" diff --git a/ml/README.md b/ml/README.md new file mode 100644 index 0000000..1fac79a --- /dev/null +++ b/ml/README.md @@ -0,0 +1,51 @@ + +# Tic Tac Toe (AI/ML) + +> NOTE: this is currently a work in progress! + +The [computer players](/app/player.py) in the Tic Tac Toe app use a predefined algorithm to select the best move. Let's see if we can train machine learning models to play the game instead. + +## Existing Datasets + +There is a [dataset of tic-tac-toe endgames](https://archive.ics.uci.edu/ml/datasets/Tic-Tac-Toe+Endgame) from UC Irvine. The inputs are the terminal board states, and the output is the game outcome ("positive" or "negative") for the "X" player. It is possible to [train a nice classifier](/ml/endgames/Endgame_Model_Training.ipynb) on their data. But it represents terminal game states only, whereas an AI agent would need to assess non-terminal game states as well. And it assumes the perspective of the "X" player only, whereas an AI agent would need to be able to play as "O" as well. + +## Dataset Generation + +So let's generate a dataset of terminal and non-terminal game states, for both players (in case there are strategy differences between "X" going first and "O" going second). + +In this step, we generate moves datasets for model training and evaluation. The test/train splits are out of scope for this step, and will be done later, during model training. 
+ +The inputs are the board state after the active player makes a move (e.g. "-X-O-X-OX"). The output is the eventual outcome of that move for the given player (win, loss, or tie). + +We simulate alternating moves until we reach an outcome (win, loss, or tie). After reaching an outcome, we assign the eventual outcome value to all moves each player made leading up to that outcome: + + all of the winning player's moves are assigned a positive value (1.0) + + all of the losing player's moves are assigned zero (0.0) + + all moves resulting in a tie are assigned a neutral score (0.5) + +Generating the datasets: + +```sh +X_STRATEGY="RANDOM" O_STRATEGY="RANDOM" GAME_COUNT=1000 python -m app.jobs.generate_moves + +X_STRATEGY="RANDOM" O_STRATEGY="MINIMAX-AB" GAME_COUNT=1000 python -m app.jobs.generate_moves + +X_STRATEGY="MINIMAX-AB" O_STRATEGY="RANDOM" GAME_COUNT=1000 python -m app.jobs.generate_moves + +X_STRATEGY="MINIMAX-AB" O_STRATEGY="MINIMAX-AB" GAME_COUNT=1000 python -m app.jobs.generate_moves +``` + +After generating the datasets, we uploaded the CSV files to GitHub, and used a [notebook](/ml/data_prep/Training_Data_Prep.ipynb) to combine the datasets into a single CSV file ("move_values.csv"), which we also uploaded to GitHub: + +Datasets: + + + [random_random_1000.csv](https://github.com/s2t2/tic-tac-toe-py/files/7921041/random_random_1000.csv) + + [random_minimaxab_1000.csv](https://github.com/s2t2/tic-tac-toe-py/files/7921043/random_minimaxab_1000.csv) + + [minimaxab_random_1000.csv](https://github.com/s2t2/tic-tac-toe-py/files/7921050/minimaxab_random_1000.csv) + + [minimaxab_minimaxab_1000.csv](https://github.com/s2t2/tic-tac-toe-py/files/7921045/minimaxab_minimaxab_1000.csv) + +Combined Datasets: + + + [move_values.csv](https://github.com/s2t2/tic-tac-toe-py/files/7921159/move_values.csv) + + [move_values_x.csv](https://github.com/s2t2/tic-tac-toe-py/files/7921160/move_values_x.csv) + + 
[move_values_o.csv](https://github.com/s2t2/tic-tac-toe-py/files/7921161/move_values_o.csv) + + [move_values_normalized.csv](https://github.com/s2t2/tic-tac-toe-py/files/7921162/move_values_normalized.csv) diff --git a/ml/data_prep/Training_Data_Prep.ipynb b/ml/data_prep/Training_Data_Prep.ipynb new file mode 100644 index 0000000..6e6d1f6 --- /dev/null +++ b/ml/data_prep/Training_Data_Prep.ipynb @@ -0,0 +1,697 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Tic Tac Toe - Training Data Prep", + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "## Import datasets" + ], + "metadata": { + "id": "0pG_Kg24wWz6" + } + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Hj6PchBDokUF", + "outputId": "8353ecc3-0174-4805-a565-e2d6372af1ad" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--------------------\n", + "X RANDOM vs O RANDOM:\n", + "1.0 0.460388\n", + "0.0 0.386149\n", + "0.5 0.153463\n", + "Name: outcome, dtype: float64\n", + "--------------------\n", + "X RANDOM vs O EXPERT:\n", + "1.0 0.381239\n", + "0.0 0.381239\n", + "0.5 0.237522\n", + "Name: outcome, dtype: float64\n", + "--------------------\n", + "X EXPERT vs O RANDOM:\n", + "1.0 0.568935\n", + "0.0 0.398080\n", + "0.5 0.032984\n", + "Name: outcome, dtype: float64\n", + "--------------------\n", + "X EXPERT vs O EXPERT:\n", + "0.5 1.0\n", + "Name: outcome, dtype: float64\n" + ] + } + ], + "source": [ + "from pandas import read_csv\n", + "\n", + "print(\"--------------------\")\n", + "print(\"X RANDOM vs O RANDOM:\")\n", + "rr = read_csv(\"https://github.com/s2t2/tic-tac-toe-py/files/7921041/random_random_1000.csv\")\n", + "rr[\"dataset_id\"] = \"random_vs_random\"\n", + 
"print(rr[\"outcome\"].value_counts(normalize=True))\n", + "\n", + "print(\"--------------------\")\n", + "print(\"X RANDOM vs O EXPERT:\")\n", + "rm = read_csv(\"https://github.com/s2t2/tic-tac-toe-py/files/7921043/random_minimaxab_1000.csv\")\n", + "rm[\"dataset_id\"] = \"random_vs_expert\"\n", + "print(rm[\"outcome\"].value_counts(normalize=True))\n", + "\n", + "print(\"--------------------\")\n", + "print(\"X EXPERT vs O RANDOM:\")\n", + "mr = read_csv(\"https://github.com/s2t2/tic-tac-toe-py/files/7921050/minimaxab_random_1000.csv\")\n", + "mr[\"dataset_id\"] = \"expert_vs_random\"\n", + "print(mr[\"outcome\"].value_counts(normalize=True))\n", + "\n", + "print(\"--------------------\")\n", + "print(\"X EXPERT vs O EXPERT:\")\n", + "mm = read_csv(\"https://github.com/s2t2/tic-tac-toe-py/files/7921045/minimaxab_minimaxab_1000.csv\")\n", + "print(mm[\"outcome\"].value_counts(normalize=True))\n", + "mm[\"dataset_id\"] = \"expert_vs_expert\"\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Combine Datasets" + ], + "metadata": { + "id": "FVHOoqIYwf39" + } + }, + { + "cell_type": "code", + "source": [ + "from pandas import concat\n", + "\n", + "combined_df = concat([rr, rm, mr, mm])\n", + "\n", + "print(combined_df)\n", + "\n", + "print(\"-------------\")\n", + "print(\"OUTCOMES:\")\n", + "print(combined_df[\"outcome\"].value_counts(normalize=True))\n", + "\n", + "print(\"-------------\")\n", + "print(\"DATASETS:\")\n", + "print(combined_df[\"dataset_id\"].value_counts(normalize=True))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SZb0fWyQpshF", + "outputId": "fdbc27a6-cfb2-4ad0-ef58-33c1181bdd2c" + }, + "execution_count": 78, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " game_id move_id a1 b1 c1 a2 ... a3 b3 c3 player outcome dataset_id\n", + "0 1 1 X - - - ... - - - X 1.0 random_vs_random\n", + "1 1 2 X - - - ... - O - O 0.0 random_vs_random\n", + "2 1 3 X X - - ... 
- O - X 1.0 random_vs_random\n", + "3 1 4 X X - - ... - O - O 0.0 random_vs_random\n", + "4 1 5 X X - X ... - O - X 1.0 random_vs_random\n", + "... ... ... .. .. .. .. ... .. .. .. ... ... ...\n", + "8995 1000 5 X X O - ... X - - X 0.5 expert_vs_expert\n", + "8996 1000 6 X X O O ... X - - O 0.5 expert_vs_expert\n", + "8997 1000 7 X X O O ... X - - X 0.5 expert_vs_expert\n", + "8998 1000 8 X X O O ... X O - O 0.5 expert_vs_expert\n", + "8999 1000 9 X X O O ... X O X X 0.5 expert_vs_expert\n", + "\n", + "[29326 rows x 14 columns]\n", + "-------------\n", + "OUTCOMES:\n", + "0.5 0.409705\n", + "1.0 0.321489\n", + "0.0 0.268806\n", + "Name: outcome, dtype: float64\n", + "-------------\n", + "DATASETS:\n", + "expert_vs_expert 0.306895\n", + "random_vs_random 0.259974\n", + "random_vs_expert 0.237741\n", + "expert_vs_random 0.195390\n", + "Name: dataset_id, dtype: float64\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "combined_df.to_csv(\"move_values.csv\", index=False)\n" + ], + "metadata": { + "id": "S9oYL8hFqALA" + }, + "execution_count": 79, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Separate Player Perspectives" + ], + "metadata": { + "id": "oiKMMt0BMyKk" + } + }, + { + "cell_type": "code", + "source": [ + "df = combined_df.copy()\n", + "\n", + "x_rows = df[df[\"player\"] == \"X\"]\n", + "o_rows = df[df[\"player\"] == \"O\"]\n", + "print(\"X ROWS:\", len(x_rows))\n", + "print(\"O ROWS:\", len(o_rows))\n", + "print(\"TOTAL :\", len(df))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Yb5xL0mcM0T8", + "outputId": "0b270a88-ffa2-40a5-96fc-2fd3ef84770f" + }, + "execution_count": 80, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "X ROWS: 16103\n", + "O ROWS: 13223\n", + "TOTAL : 29326\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "o_rows.to_csv(\"move_values_o.csv\")\n", + "x_rows.to_csv(\"move_values_x.csv\")" + ], + "metadata": { + 
"id": "rz6yYomRM9aJ" + }, + "execution_count": 81, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Normalize O Player Perspective" + ], + "metadata": { + "id": "XG1xWzn4Gvwk" + } + }, + { + "cell_type": "code", + "source": [ + "OPPOSITE_PERSPECTIVES = {\"X\":\"O\", \"O\":\"X\", \"-\":\"-\"}\n", + "OPPOSITE_OUTCOMES = {1:0, 0:1, 0.5:0.5}\n", + "\n", + "def normalize_square(val):\n", + " return OPPOSITE_PERSPECTIVES[val]\n", + "\n", + "#def normalize_board(board_notation):\n", + "# return \"\".join([normalize(char) for char in board_notation])\n", + "\n", + "def normalize_outcome(val):\n", + " return OPPOSITE_OUTCOMES[val]\n", + "\n", + "assert normalize(\"X\") == \"O\"\n", + "assert normalize(\"O\") == \"X\"\n", + "assert normalize(\"-\") == \"-\"\n", + "#assert normalize_board(\"XO-XO----\") == \"OX-OX----\"\n", + "\n", + "assert normalize_outcome(1) == 0\n", + "assert normalize_outcome(0.5) == 0.5\n", + "assert normalize_outcome(0) == 1\n" + ], + "metadata": { + "id": "530ZTDGEGz2u" + }, + "execution_count": 82, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(\"\\n--------------------\")\n", + "print(\"ORIGINAL PERSPECTIVES:\")\n", + "print(o_rows[[\"a1\",\"b1\", \"c1\", \"a2\",\"b2\", \"c2\", \"a3\",\"b3\", \"c3\", \"player\", \"outcome\"]].head())\n", + "print(o_rows[\"outcome\"].value_counts(normalize=True))\n", + "\n", + "# NORMALIZING...\n", + "new_o_rows = o_rows.copy()\n", + "\n", + "#new_o_rows[\"board_state\"] = o_rows[\"board_state\"].apply(normalize_board)\n", + "for col in [\"a1\",\"b1\",\"c1\",\"a2\",\"b2\",\"c2\",\"a3\",\"b3\",\"c3\"]:\n", + " new_o_rows[col] = o_rows[col].apply(normalize)\n", + "new_o_rows[\"outcome\"] = o_rows[\"outcome\"].apply(normalize_outcome)\n", + "new_o_rows[\"player\"] = \"X\" # normalize player (or drop it)\n", + "\n", + "print(\"\\n--------------------\")\n", + "print(\"NORMALIZED PERSPECTIVES:\")\n", + "print(new_o_rows[[\"a1\",\"b1\", \"c1\", \"a2\",\"b2\", \"c2\", 
\"a3\",\"b3\", \"c3\", \"player\", \"outcome\"]].head())\n", + "print(new_o_rows[\"outcome\"].value_counts(normalize=True))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5KusQOKcHXq3", + "outputId": "bde4dea7-6a7b-4289-c472-007756bc231a" + }, + "execution_count": 83, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "--------------------\n", + "ORIGINAL PERSPECTIVES:\n", + " a1 b1 c1 a2 b2 c2 a3 b3 c3 player outcome\n", + "1 X - - - - - - O - O 0.0\n", + "3 X X - - O - - O - O 0.0\n", + "5 X X - X O - - O O O 0.0\n", + "8 - - - O - - - X - O 0.0\n", + "10 X - - O - O - X - O 0.0\n", + "0.5 0.403842\n", + "0.0 0.310141\n", + "1.0 0.286017\n", + "Name: outcome, dtype: float64\n", + "\n", + "--------------------\n", + "NORMALIZED PERSPECTIVES:\n", + " a1 b1 c1 a2 b2 c2 a3 b3 c3 player outcome\n", + "1 O - - - - - - X - X 1.0\n", + "3 O O - - X - - X - X 1.0\n", + "5 O O - O X - - X X X 1.0\n", + "8 - - - X - - - O - X 1.0\n", + "10 O - - X - X - O - X 1.0\n", + "0.5 0.403842\n", + "1.0 0.310141\n", + "0.0 0.286017\n", + "Name: outcome, dtype: float64\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# https://pandas.pydata.org/docs/reference/api/pandas.concat.html\n", + "from pandas import concat\n", + "\n", + "# single-player df\n", + "print(\"SINGLE PLAYER'S PERSPECTIVE:\")\n", + "df_normalized = concat([x_rows, new_o_rows])\n", + "df_normalized.drop(columns=[\"player\"], inplace=True)\n", + "df_normalized.sort_values(by=[\"dataset_id\", \"game_id\", \"move_id\"], inplace=True)\n", + "print(len(df_normalized))\n", + "\n", + "df_normalized" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 494 + }, + "id": "LlHj10zTHhcZ", + "outputId": "d6e555b4-4a48-4bd9-b00c-17fa85e6753e" + }, + "execution_count": 84, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "SINGLE PLAYER'S PERSPECTIVE:\n", + "29326\n" + ] + }, + 
{ + "output_type": "execute_result", + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
game_idmove_ida1b1c1a2b2c2a3b3c3outcomedataset_id
011-----X---0.5expert_vs_expert
112--X--O---0.5expert_vs_expert
213X-O--X---0.5expert_vs_expert
314O-XX-O---0.5expert_vs_expert
415XXOO-X---0.5expert_vs_expert
..........................................
761910001-X-------1.0random_vs_random
762010002-O---X---1.0random_vs_random
762110003-XX--O---1.0random_vs_random
762210004-OO--X--X1.0random_vs_random
762310005XXX--O--O1.0random_vs_random
\n", + "

29326 rows × 13 columns

\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + " game_id move_id a1 b1 c1 a2 b2 c2 a3 b3 c3 outcome dataset_id\n", + "0 1 1 - - - - - X - - - 0.5 expert_vs_expert\n", + "1 1 2 - - X - - O - - - 0.5 expert_vs_expert\n", + "2 1 3 X - O - - X - - - 0.5 expert_vs_expert\n", + "3 1 4 O - X X - O - - - 0.5 expert_vs_expert\n", + "4 1 5 X X O O - X - - - 0.5 expert_vs_expert\n", + "... ... ... .. .. .. .. .. .. .. .. .. ... ...\n", + "7619 1000 1 - X - - - - - - - 1.0 random_vs_random\n", + "7620 1000 2 - O - - - X - - - 1.0 random_vs_random\n", + "7621 1000 3 - X X - - O - - - 1.0 random_vs_random\n", + "7622 1000 4 - O O - - X - - X 1.0 random_vs_random\n", + "7623 1000 5 X X X - - O - - O 1.0 random_vs_random\n", + "\n", + "[29326 rows x 13 columns]" + ] + }, + "metadata": {}, + "execution_count": 84 + } + ] + }, + { + "cell_type": "code", + "source": [ + "df_normalized.to_csv(\"move_values_normalized.csv\")" + ], + "metadata": { + "id": "L2hwCS19P3F3" + }, + "execution_count": 85, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/ml/endgames/Endgame_Model_Training.ipynb b/ml/endgames/Endgame_Model_Training.ipynb new file mode 100644 index 0000000..1cbf1c1 --- /dev/null +++ b/ml/endgames/Endgame_Model_Training.ipynb @@ -0,0 +1,1831 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Tic Tac Toe Model Training (END STATES) - 20220102 ", + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Dataset Info\n" + ], + "metadata": { + "id": "u6mV0e_fHOCu" + } + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "http://archive.ics.uci.edu/ml/datasets/Tic-Tac-Toe+Endgame\n", + "\n", + "\n", + "This database encodes the **complete set of possible board configurations at the end of tic-tac-toe games**, where \"x\" is assumed to have played first. 
\n", + "\n", + "The target concept is: \"win for x\" (i.e. positive when \"x\" has a way to win).\n", + "\n", + "\n", + "Attributes / Columns:\n", + "\n", + "1. top-left-square: {x,o,b}\n", + "2. top-middle-square: {x,o,b}\n", + "3. top-right-square: {x,o,b}\n", + "4. middle-left-square: {x,o,b}\n", + "5. middle-middle-square: {x,o,b}\n", + "6. middle-right-square: {x,o,b}\n", + "7. bottom-left-square: {x,o,b}\n", + "8. bottom-middle-square: {x,o,b}\n", + "9. bottom-right-square: {x,o,b}\n", + "10. Class: {positive,negative}\n", + "\n" + ], + "metadata": { + "id": "2a0xpH6f4kM7" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Formatter Functions" + ], + "metadata": { + "id": "q3-WGcWYHGAP" + } + }, + { + "cell_type": "code", + "source": [ + "\n", + "\n", + "# HELPER FUNCTIONS\n", + "\n", + "def fmt_n(large_number):\n", + " \"\"\"\n", + " Formats a large number with thousands separator, for printing and logging.\n", + " \n", + " Param large_number (int) like 1_000_000_000\n", + " \n", + " Returns (str) like '1,000,000,000'\n", + " \"\"\"\n", + " return f\"{large_number:,.0f}\"\n", + "\n", + "\n", + "def fmt_pct(decimal_number):\n", + " \"\"\"\n", + " Formats a percentage value rounded to two decimal places with a percent sign, for printing and logging.\n", + " \n", + " Param decimal_number (float) already on the 0-100 scale, like 9.67890987\n", + " \n", + " Returns (str) like '9.68%'\n", + " \"\"\"\n", + " return f\"{decimal_number:.2f}%\"\n", + "\n", + "\n", + "assert fmt_n(1_000_000_000) == '1,000,000,000'\n", + "assert fmt_pct(9.67890987) == '9.68%'" + ], + "metadata": { + "id": "b7416I9bELB5" + }, + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Importing the Data\n" + ], + "metadata": { + "id": "SY8YP2-p53Mh" + } + }, + { + "cell_type": "code", + "source": [ + "#\n", + "# IMPORT DATA\n", + "# https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html\n", + "#\n", + "\n", + "from pandas import 
read_csv\n", + "\n", + "url = 
\"http://archive.ics.uci.edu/ml/machine-learning-databases/tic-tac-toe/tic-tac-toe.data\"\n", + "\n", + "columns = [\"a1\", \"a2\", \"a3\", \"b1\", \"b2\", \"b3\", \"c1\", \"c2\", \"c3\", \"outcome\"]\n", + "\n", + "raw_df = read_csv(url, names=columns)\n", + "raw_df.head()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 + }, + "id": "IYxmq5QT564m", + "outputId": "54096f45-7388-4864-dad2-9188a4ae575d" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
a1a2a3b1b2b3c1c2c3outcome
0xxxxooxoopositive
1xxxxoooxopositive
2xxxxooooxpositive
3xxxxooobbpositive
4xxxxoobobpositive
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + " a1 a2 a3 b1 b2 b3 c1 c2 c3 outcome\n", + "0 x x x x o o x o o positive\n", + "1 x x x x o o o x o positive\n", + "2 x x x x o o o o x positive\n", + "3 x x x x o o o b b positive\n", + "4 x x x x o o b o b positive" + ] + }, + "metadata": {}, + "execution_count": 2 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "The positive outcome rows are first and the negative outcome rows are later, so we'll have to shuffle the row order / randomize when splitting the data." + ], + "metadata": { + "id": "7kEqVWLm56T1" + } + }, + { + "cell_type": "code", + "source": [ + "print(raw_df[\"outcome\"].value_counts(normalize=True))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FQ_XwcoX-5vW", + "outputId": "2a698c41-e09a-4ed7-86db-eb76d77220b7" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "positive 0.653445\n", + "negative 0.346555\n", + "Name: outcome, dtype: float64\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "there is a 2:1 ratio between positive and negative outcomes. this is generally imbalanced. 
\n", + "\n", + "https://machinelearningmastery.com/types-of-classification-in-machine-learning/" + ], + "metadata": { + "id": "-jSzc1cm-8rq" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Splitting the Data" + ], + "metadata": { + "id": "ABHaHgU9_h2o" + } + }, + { + "cell_type": "code", + "source": [ + "\n", + "\n", + "# https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html\n", + "\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "\n", + "train_df, test_df = train_test_split(raw_df, train_size=0.8, shuffle=True, random_state=99, stratify= raw_df[\"outcome\"])\n", + "\n", + "print(\"--------------------\")\n", + "print(\"TRAIN:\", len(train_df))\n", + "print(train_df[\"outcome\"].value_counts(normalize=True))\n", + "\n", + "print(\"--------------------\")\n", + "print(\"TEST:\", len(test_df))\n", + "print(test_df[\"outcome\"].value_counts(normalize=True))\n", + "\n", + "print(\"--------------------\")\n", + "train_df" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 606 + }, + "id": "oFWpPzJi7I1M", + "outputId": "9a31a5cd-cda5-48ec-de87-b6b737926f23" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--------------------\n", + "TRAIN: 766\n", + "positive 0.654047\n", + "negative 0.345953\n", + "Name: outcome, dtype: float64\n", + "--------------------\n", + "TEST: 192\n", + "positive 0.651042\n", + "negative 0.348958\n", + "Name: outcome, dtype: float64\n", + "--------------------\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
a1a2a3b1b2b3c1c2c3outcome
783oxoobxoxxnegative
468oboobxxxxpositive
405ooxbbxoxxpositive
820oooxoxxxbnegative
874obboxxoxbnegative
.................................
232xbxoxbooxpositive
449obxbxoxbbpositive
381ooxoxxxxopositive
838ooobxxoxxnegative
403ooxbobxxxpositive
\n", + "

766 rows × 10 columns

\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + " a1 a2 a3 b1 b2 b3 c1 c2 c3 outcome\n", + "783 o x o o b x o x x negative\n", + "468 o b o o b x x x x positive\n", + "405 o o x b b x o x x positive\n", + "820 o o o x o x x x b negative\n", + "874 o b b o x x o x b negative\n", + ".. .. .. .. .. .. .. .. .. .. ...\n", + "232 x b x o x b o o x positive\n", + "449 o b x b x o x b b positive\n", + "381 o o x o x x x x o positive\n", + "838 o o o b x x o x x negative\n", + "403 o o x b o b x x x positive\n", + "\n", + "[766 rows x 10 columns]" + ] + }, + "metadata": {}, + "execution_count": 4 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Feature Selection" + ], + "metadata": { + "id": "z84Eh1FOFUA0" + } + }, + { + "cell_type": "code", + "source": [ + "target_col = \"outcome\"\n", + "\n", + "feature_cols = columns\n", + "if target_col in feature_cols:\n", + " feature_cols.remove(target_col)\n", + "print(\"FEATURE COLS:\", feature_cols)\n", + "\n", + "train_x = train_df[feature_cols]\n", + "train_y = train_df[[target_col]]\n", + "\n", + "test_x = test_df[feature_cols]\n", + "test_y = test_df[[target_col]]" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4Z6IYvoS_7rr", + "outputId": "28e1b4db-2003-4d74-b57f-fa9aa89c4be1" + }, + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "FEATURE COLS: ['a1', 'a2', 'a3', 'b1', 'b2', 'b3', 'c1', 'c2', 'c3']\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "train_x" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 419 + }, + "id": "5Ty17vtsCl89", + "outputId": "df1ed55f-45ca-4253-9a50-22710c4abd48" + }, + "execution_count": 6, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
a1a2a3b1b2b3c1c2c3
783oxoobxoxx
468oboobxxxx
405ooxbbxoxx
820oooxoxxxb
874obboxxoxb
..............................
232xbxoxboox
449obxbxoxbb
381ooxoxxxxo
838ooobxxoxx
403ooxbobxxx
\n", + "

766 rows × 9 columns

\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + " a1 a2 a3 b1 b2 b3 c1 c2 c3\n", + "783 o x o o b x o x x\n", + "468 o b o o b x x x x\n", + "405 o o x b b x o x x\n", + "820 o o o x o x x x b\n", + "874 o b b o x x o x b\n", + ".. .. .. .. .. .. .. .. .. ..\n", + "232 x b x o x b o o x\n", + "449 o b x b x o x b b\n", + "381 o o x o x x x x o\n", + "838 o o o b x x o x x\n", + "403 o o x b o b x x x\n", + "\n", + "[766 rows x 9 columns]" + ] + }, + "metadata": {}, + "execution_count": 6 + } + ] + }, + { + "cell_type": "code", + "source": [ + "train_y" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 419 + }, + "id": "mXbpOPMpCnZb", + "outputId": "69f32a6d-6521-435f-b6ce-9aece60eece9" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
outcome
783negative
468positive
405positive
820negative
874negative
......
232positive
449positive
381positive
838negative
403positive
\n", + "

766 rows × 1 columns

\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + " outcome\n", + "783 negative\n", + "468 positive\n", + "405 positive\n", + "820 negative\n", + "874 negative\n", + ".. ...\n", + "232 positive\n", + "449 positive\n", + "381 positive\n", + "838 negative\n", + "403 positive\n", + "\n", + "[766 rows x 1 columns]" + ] + }, + "metadata": {}, + "execution_count": 7 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Feature Matrix Encoding" + ], + "metadata": { + "id": "9aridLI6Hf6b" + } + }, + { + "cell_type": "code", + "source": [ + "# https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html\n", + "\n", + "from sklearn.preprocessing import OneHotEncoder\n", + "\n", + "encoder = OneHotEncoder()\n", + "\n", + "encoder.fit(train_x) \n", + "\n", + "features = encoder.get_feature_names_out()\n", + "print(\"FEATURES :\", fmt_n(len(features)))\n", + "print(features)\n", + "\n", + "train_m = encoder.transform(train_x)\n", + "test_m = encoder.transform(test_x)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AYKvZl3GHeWm", + "outputId": "d0f9475f-9924-45fb-fb9d-02177bff3f45" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "FEATURES : 27\n", + "['a1_b' 'a1_o' 'a1_x' 'a2_b' 'a2_o' 'a2_x' 'a3_b' 'a3_o' 'a3_x' 'b1_b'\n", + " 'b1_o' 'b1_x' 'b2_b' 'b2_o' 'b2_x' 'b3_b' 'b3_o' 'b3_x' 'c1_b' 'c1_o'\n", + " 'c1_x' 'c2_b' 'c2_o' 'c2_x' 'c3_b' 'c3_o' 'c3_x']\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "print(train_m.shape)\n", + "print(test_m.shape)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4mInxKSoKCse", + "outputId": "3f7c71ac-814c-43ca-df84-a8154608eab5" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(766, 27)\n", + "(192, 27)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "print(train_m[0])" + ], + "metadata": { 
+ "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "J9_I06q3KXEq", + "outputId": "b1e18e15-9d6c-4ba7-cc20-9eadbff6fb72" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " (0, 1)\t1.0\n", + " (0, 5)\t1.0\n", + " (0, 7)\t1.0\n", + " (0, 10)\t1.0\n", + " (0, 12)\t1.0\n", + " (0, 17)\t1.0\n", + " (0, 19)\t1.0\n", + " (0, 23)\t1.0\n", + " (0, 26)\t1.0\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Model Selection and Training" + ], + "metadata": { + "id": "r4gEro9jBL6r" + } + }, + { + "cell_type": "code", + "source": [ + "from pprint import pprint\n", + "from sklearn.metrics import classification_report #, accuracy_score\n", + "\n", + "#def train_and_score(model, train_m, train_y, test_m, test_y):\n", + "def train_and_score(model):\n", + "\n", + " print(\"-------------------\")\n", + " print(\"TRAINING...\")\n", + " model.fit(train_m, train_y)\n", + "\n", + " train_y_pred = model.predict(train_m)\n", + " train_scores = classification_report(train_y, train_y_pred, output_dict=True)\n", + " #print(\"ACCY:\", fmt_pct(train_scores[\"accuracy\"]), \"GOAL:\", fmt_pct(train_scores[recall_class][\"recall\"]))\n", + " pprint(train_scores)\n", + "\n", + " print(\"-------------------\")\n", + " print(\"TESTING...\")\n", + " test_y_pred = model.predict(test_m)\n", + " test_scores = classification_report(test_y, test_y_pred, output_dict=True)\n", + " #print(\"ACCY:\", fmt_pct(test_scores[\"accuracy\"]), \"GOAL:\", fmt_pct(test_scores[recall_class][\"recall\"]))\n", + " pprint(test_scores)\n", + "\n" + ], + "metadata": { + "id": "KdlN86_CBN6r" + }, + "execution_count": 11, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html\n", + "\n", + "from sklearn.linear_model import LogisticRegression\n", + "\n", + "lr = LogisticRegression(random_state=99, max_iter=2000)\n", + "\n", + 
"train_and_score(model=lr)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jDg96coCDA2k", + "outputId": "5f29fd82-f1d9-46f6-c71e-d4bdb063a356" + }, + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "-------------------\n", + "TRAINING...\n", + "{'accuracy': 0.9817232375979112,\n", + " 'macro avg': {'f1-score': 0.9795443447476042,\n", + " 'precision': 0.9864077669902913,\n", + " 'recall': 0.9735849056603774,\n", + " 'support': 766},\n", + " 'negative': {'f1-score': 0.9728682170542635,\n", + " 'precision': 1.0,\n", + " 'recall': 0.9471698113207547,\n", + " 'support': 265},\n", + " 'positive': {'f1-score': 0.9862204724409448,\n", + " 'precision': 0.9728155339805825,\n", + " 'recall': 1.0,\n", + " 'support': 501},\n", + " 'weighted avg': {'f1-score': 0.9816012195982939,\n", + " 'precision': 0.9822200816243758,\n", + " 'recall': 0.9817232375979112,\n", + " 'support': 766}}\n", + "-------------------\n", + "TESTING...\n", + "{'accuracy': 0.9895833333333334,\n", + " 'macro avg': {'f1-score': 0.9884559884559885,\n", + " 'precision': 0.9921259842519685,\n", + " 'recall': 0.9850746268656716,\n", + " 'support': 192},\n", + " 'negative': {'f1-score': 0.9848484848484849,\n", + " 'precision': 1.0,\n", + " 'recall': 0.9701492537313433,\n", + " 'support': 67},\n", + " 'positive': {'f1-score': 0.9920634920634921,\n", + " 'precision': 0.984251968503937,\n", + " 'recall': 1.0,\n", + " 'support': 125},\n", + " 'weighted avg': {'f1-score': 0.9895457551707553,\n", + " 'precision': 0.989747375328084,\n", + " 'recall': 0.9895833333333334,\n", + " 'support': 192}}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/sklearn/utils/validation.py:993: DataConversionWarning: A column-vector y was passed when a 1d array was expected. 
Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegressionCV.html\n", + "# https://scikit-learn.org/stable/glossary.html#term-cross-validation-estimator\n", + "\n", + "from sklearn.linear_model import LogisticRegressionCV\n", + "\n", + "lr_cv = LogisticRegressionCV(random_state=99)\n", + "\n", + "train_and_score(lr_cv)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XuBbSRgQDXEt", + "outputId": "dae601dc-8cb0-46b0-c5db-d351f5ee4bfd" + }, + "execution_count": 13, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "-------------------\n", + "TRAINING...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/sklearn/utils/validation.py:993: DataConversionWarning: A column-vector y was passed when a 1d array was expected. 
Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True)\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'accuracy': 0.9817232375979112,\n", + " 'macro avg': {'f1-score': 0.9795443447476042,\n", + " 'precision': 0.9864077669902913,\n", + " 'recall': 0.9735849056603774,\n", + " 'support': 766},\n", + " 'negative': {'f1-score': 0.9728682170542635,\n", + " 'precision': 1.0,\n", + " 'recall': 0.9471698113207547,\n", + " 'support': 265},\n", + " 'positive': {'f1-score': 0.9862204724409448,\n", + " 'precision': 0.9728155339805825,\n", + " 'recall': 1.0,\n", + " 'support': 501},\n", + " 'weighted avg': {'f1-score': 0.9816012195982939,\n", + " 'precision': 0.9822200816243758,\n", + " 'recall': 0.9817232375979112,\n", + " 'support': 766}}\n", + "-------------------\n", + "TESTING...\n", + "{'accuracy': 0.9895833333333334,\n", + " 'macro avg': {'f1-score': 0.9884559884559885,\n", + " 'precision': 0.9921259842519685,\n", + " 'recall': 0.9850746268656716,\n", + " 'support': 192},\n", + " 'negative': {'f1-score': 0.9848484848484849,\n", + " 'precision': 1.0,\n", + " 'recall': 0.9701492537313433,\n", + " 'support': 67},\n", + " 'positive': {'f1-score': 0.9920634920634921,\n", + " 'precision': 0.984251968503937,\n", + " 'recall': 1.0,\n", + " 'support': 125},\n", + " 'weighted avg': {'f1-score': 0.9895457551707553,\n", + " 'precision': 0.989747375328084,\n", + " 'recall': 0.9895833333333334,\n", + " 'support': 192}}\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html\n", + "\n", + "from sklearn.tree import DecisionTreeClassifier\n", + "\n", + "dt = DecisionTreeClassifier(random_state=99)\n", + "\n", + "train_and_score(dt)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CVnedxS0JFdl", + "outputId": "63f67bda-21c4-4fbf-b808-05243c29aae3" + }, + 
"execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "-------------------\n", + "TRAINING...\n", + "{'accuracy': 1.0,\n", + " 'macro avg': {'f1-score': 1.0,\n", + " 'precision': 1.0,\n", + " 'recall': 1.0,\n", + " 'support': 766},\n", + " 'negative': {'f1-score': 1.0, 'precision': 1.0, 'recall': 1.0, 'support': 265},\n", + " 'positive': {'f1-score': 1.0, 'precision': 1.0, 'recall': 1.0, 'support': 501},\n", + " 'weighted avg': {'f1-score': 1.0,\n", + " 'precision': 1.0,\n", + " 'recall': 1.0,\n", + " 'support': 766}}\n", + "-------------------\n", + "TESTING...\n", + "{'accuracy': 0.9583333333333334,\n", + " 'macro avg': {'f1-score': 0.9547543301519972,\n", + " 'precision': 0.9494820160633222,\n", + " 'recall': 0.9610746268656716,\n", + " 'support': 192},\n", + " 'negative': {'f1-score': 0.9420289855072463,\n", + " 'precision': 0.9154929577464789,\n", + " 'recall': 0.9701492537313433,\n", + " 'support': 67},\n", + " 'positive': {'f1-score': 0.967479674796748,\n", + " 'precision': 0.9834710743801653,\n", + " 'recall': 0.952,\n", + " 'support': 125},\n", + " 'weighted avg': {'f1-score': 0.9585984446800989,\n", + " 'precision': 0.9597495440965352,\n", + " 'recall': 0.9583333333333334,\n", + " 'support': 192}}\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Feature Importances" + ], + "metadata": { + "id": "VtGGKLa3LaWF" + } + }, + { + "cell_type": "code", + "source": [ + "from pandas import Series\n", + "\n", + "model = dt\n", + "\n", + "feature_importances = Series(model.feature_importances_, features).sort_values(ascending=False)\n", + "feature_importances.name = \"encoded_feature\"\n", + "feature_importances" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NhlUStVyLcLa", + "outputId": "d67ad42c-1f44-4957-8444-2e7a20e61213" + }, + "execution_count": 15, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "b2_o 0.113211\n", 
+ "c1_x 0.089668\n", + "c1_o 0.082508\n", + "a3_o 0.074124\n", + "b3_o 0.070662\n", + "a1_o 0.070609\n", + "c3_x 0.070545\n", + "a2_o 0.060734\n", + "c2_o 0.060618\n", + "b1_o 0.052907\n", + "a2_x 0.045175\n", + "b3_x 0.044105\n", + "b1_x 0.036198\n", + "a1_x 0.032857\n", + "c2_x 0.027928\n", + "c3_o 0.026983\n", + "a3_x 0.015001\n", + "c2_b 0.009991\n", + "a3_b 0.004616\n", + "b3_b 0.003228\n", + "a1_b 0.002885\n", + "a2_b 0.002400\n", + "c1_b 0.001539\n", + "b2_b 0.000962\n", + "c3_b 0.000549\n", + "b2_x 0.000000\n", + "b1_b 0.000000\n", + "Name: encoded_feature, dtype: float64" + ] + }, + "metadata": {}, + "execution_count": 15 + } + ] + }, + { + "cell_type": "code", + "source": [ + "from plotly.express import bar\n", + "\n", + "top_features = feature_importances.sort_values(ascending=True)# [0:10]\n", + "\n", + "fig = bar(\n", + " x=top_features.values,\n", + " y=top_features.keys(),\n", + " orientation=\"h\", # horizontal flips x and y params below, but not labels\n", + " labels={\"x\": \"Relative Importance\", \"y\": \"Feature Name\"},\n", + " title=\"Importance of Board Features for X Player in Tic Tac Toe\",\n", + " height=750\n", + ")\n", + "fig.show()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 767 + }, + "id": "ZtdEBgXXMIjq", + "outputId": "3b6f618b-af39-411a-a194-39d2c4a9d8b3" + }, + "execution_count": 16, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + " \n", + "
\n", + "\n", + "" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Decision Tree Visualization" + ], + "metadata": { + "id": "9JjUhZCYN3RO" + } + }, + { + "cell_type": "code", + "source": [ + "# https://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html\n", + "from sklearn.tree import export_graphviz\n", + "\n", + "# https://pypi.org/project/graphviz/\n", + "from graphviz import Source \n", + "\n", + "#class_names = [{0:\"Not Survive\", 1:\"Survive\"}[class_name] for class_name in search.classes_] # [0, 1]\n", + "\n", + "dot_data = export_graphviz(dt,\n", + " out_file=None,\n", + " max_depth=3,\n", + " feature_names=features,\n", + " #class_names=class_names, # expects class names to be strings\n", + " impurity=False,\n", + " filled=True,\n", + " proportion=True,\n", + " rounded=True\n", + ")\n", + "graph = Source(dot_data)\n", + "png_bytes = graph.pipe(format=\"png\")\n", + "with open(\"decision_tree.png\", \"wb\") as f:\n", + " f.write(png_bytes)" + ], + "metadata": { + "id": "NBYX71tEN7Mb" + }, + "execution_count": 17, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from IPython.display import Image, display\n", + "\n", + "display(Image(filename=\"decision_tree.png\"))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 344 + }, + "id": "7nz6DXjAOo0X", + "outputId": "28273b29-3a46-445c-f3a4-6431a86b5ab1" + }, + "execution_count": 18, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "metadata": {} + } + ] + } + ] +} \ No newline at end of file diff --git a/test/board_test.py b/test/board_test.py index 6f27c10..456ebc9 100644 --- a/test/board_test.py +++ b/test/board_test.py @@ -55,19 +55,19 @@ def test_winner_determination(): -def test_board_state_notation(): - - board = Board() - assert board.notation == "---------" - - board.set_square("A1", "X") - assert board.notation == 
"X--------" - - board.set_square("A2", "O") - assert board.notation == "X--O-----" - - board.set_square("B1", "X") - assert board.notation == "XX-O-----" - - board.set_square("B2", "O") - assert board.notation == "XX-OO----" +#def test_board_state_notation(): +# +# board = Board() +# assert board.notation == "---------" +# +# board.set_square("A1", "X") +# assert board.notation == "X--------" +# +# board.set_square("A2", "O") +# assert board.notation == "X--O-----" +# +# board.set_square("B1", "X") +# assert board.notation == "XX-O-----" +# +# board.set_square("B2", "O") +# assert board.notation == "XX-OO----"