From 72d91c44bbc5d5ab6c700069d78ce89334a081bb Mon Sep 17 00:00:00 2001 From: Micha Birklbauer Date: Fri, 30 Dec 2022 03:52:36 +0100 Subject: [PATCH] add colab version --- neuralnet-colab.ipynb | 1923 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1923 insertions(+) create mode 100644 neuralnet-colab.ipynb diff --git a/neuralnet-colab.ipynb b/neuralnet-colab.ipynb new file mode 100644 index 0000000..bd1f3b0 --- /dev/null +++ b/neuralnet-colab.ipynb @@ -0,0 +1,1923 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c84d618c-446a-47ea-ac6d-133eb91f9411", + "metadata": { + "tags": [] + }, + "source": [ + "# **Implementation of a Neural Network *\"from scratch\"* with NumPy**" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2dc17e4f-82ae-4023-84aa-7b3a7f36a6ac", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import numpy as np\n", + "from typing import Tuple\n", + "from typing import List" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1c6ec59d-b246-4e8e-8488-f2281ec42d1d", + "metadata": {}, + "outputs": [], + "source": [ + "class LayerInitializer:\n", + " \"\"\"\n", + " Functions for layer weight initialization.\n", + " \"\"\"\n", + "\n", + " # He normal initialization\n", + " @staticmethod\n", + " def he_normal(size: Tuple[int], fan_in: int) -> np.array:\n", + " \"\"\"\n", + " HE NORMAL INITIALIZATION\n", + " Draws samples from a truncated normal distribution centered at 0 mean\n", + " with stddev = sqrt(2 / fan_in) where fan_in is the number of input\n", + " units per unit in the layer.\n", + " Parameters:\n", + " - size: Tuple[int] (rows, columns)\n", + " shape of the initialized weight matrix\n", + " - fan_in: int\n", + " number of input units per unit in the layer\n", + " Returns:\n", + " - np.array (rows, columns)\n", + " He normal initialized weight matrix\n", + " Ref:\n", + " https://arxiv.org/abs/1502.01852\n", + " \"\"\"\n", + " return np.random.normal(0, math.sqrt(2 / fan_in), size = size)\n", + "\n", + " # Glorot / Xavier normal initialization\n", + " @staticmethod\n", + " def glorot_normal(size: Tuple[int], fan_in: int, fan_out: int) -> np.array:\n", + " \"\"\"\n", + " GLOROT / XAVIER NORMAL INITIALIZATION\n", + " Draws samples from a truncated normal distribution centered at 0 mean\n", + " with stddev = sqrt(2 / (fan_in + fan_out)) where fan_in is the number of\n", + " input units per unit in the layer and fan_out is the number of output\n", + " units per unit in the layer.\n", + " Parameters:\n", + " - size: Tuple[int] (rows, columns)\n", + " shape of the initialized weight matrix\n", + " - fan_in: int\n", + " number of input units per unit in the layer\n", + " - fan_out: int\n", + " number of output units per unit in the layer\n", + " Returns:\n", + " - np.array (rows, columns)\n", + " Glorot normal initialized weight matrix\n", + " Ref:\n", + " http://proceedings.mlr.press/v9/glorot10a.html\n", + " \"\"\"\n", + " return np.random.normal(0, math.sqrt(2 / (fan_in + fan_out)), size = size)\n", + "\n", + " # Bias initialization\n", + " @staticmethod\n", + " def bias(size: Tuple[int]):\n", + " \"\"\"\n", + " BIAS INITIALIZATION\n", + " Initializes the bias vector / matrix with zeros.\n", + " Parameters:\n", + " - size: Tuple[int] (rows, columns)\n", + " shape of the initialized bias vector / matrix\n", + " Returns:\n", + " - np.array (rows, columns)\n", + " Zero initialized bias vector / matrix\n", + " Ref:\n", + " https://cs231n.github.io/neural-networks-2/\n", + " \"\"\"\n", + " return np.zeros(shape = size)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "88017f0b-882c-4195-9e20-564626be8284", + "metadata": {}, + "outputs": [], + "source": [ + "class ActivationFunctions:\n", + " \"\"\"\n", + " Layer activation functions.\n", + " \"\"\"\n", + "\n", + " # Rectified Linear Units\n", + " @staticmethod\n", + " def relu(x: np.array, derivative: bool = False) -> np.array:\n", + " \"\"\"\n", + " RECTIFIED LINEAR UNITS\n", + " ReLU activation function.\n", + " Parameters:\n", + " - x: np.array\n", + " input matrix to apply activation function to\n", + " - derivative: bool\n", + " if set to 'True' returns the derivative instead\n", + " DEFAULT: False\n", + " Returns:\n", + " - np.array (same shape as x)\n", + " activated x / derivative of x\n", + " Ref:\n", + " https://en.wikipedia.org/wiki/Rectifier_(neural_networks)\n", + " \"\"\"\n", + " if not derivative:\n", + " return np.maximum(x, 0)\n", + " else:\n", + " return np.where(x > 0, 1, 0)\n", + "\n", + " # Sigmoid activation function\n", + " @staticmethod\n", + " def sigmoid(x: np.array, derivative: bool = False) -> np.array:\n", + " \"\"\"\n", + " SIGMOID / LOGISTIC FUNCTION\n", + " Sigmoid activation function.\n", + " Parameters:\n", + " - x: np.array\n", + " input matrix to apply activation function to\n", + " - derivative: bool\n", + " if set to 'True' returns the derivative instead\n", + " DEFAULT: False\n", + " Returns:\n", + " - np.array (same shape as x)\n", + " activated x / derivative of x\n", + " Refs:\n", + " https://en.wikipedia.org/wiki/Sigmoid_function\n", + " https://en.wikipedia.org/wiki/Activation_function\n", + " \"\"\"\n", + " def f_sigmoid(x: np.array) -> np.array:\n", + " return 1 / (1 + np.exp(-x))\n", + "\n", + " if not derivative:\n", + " return f_sigmoid(x)\n", + " else:\n", + " return f_sigmoid(x) * (1 - f_sigmoid(x))\n", + "\n", + " # Softmax activation function\n", + " @staticmethod\n", + " def softmax(x: np.array, derivative: bool = False) -> np.array:\n", + " \"\"\"\n", + " SOFTMAX FUNCTION\n", + " Stable softmax activation function.\n", + " Parameters:\n", + " - x: np.array\n", + " input matrix to apply activation function to\n", + " Returns:\n", + " - np.array (same shape as x)\n", + " activated x\n", + " Refs:\n", + " https://en.wikipedia.org/wiki/Softmax_function\n", + " https://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/\n", + " \"\"\"\n", + " if not derivative:\n", + " n = np.exp(x - np.max(x)) # stable softmax\n", + " d = np.sum(n, axis = 0)\n", + " return n / d\n", + " else:\n", + " raise NotImplementedError(\"Softmax derivative not implemented!\")\n", + " # https://stackoverflow.com/questions/54976533/derivative-of-softmax-function-in-python\n", + " # xr = x.reshape((-1, 1))\n", + " # return np.diagflat(x) - np.dot(xr, xr.T)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "654dc815-7e9e-46de-9f01-7d124736fbf5", + "metadata": {}, + "outputs": [], + "source": [ + "class LossFunctions:\n", + " \"\"\"\n", + " Loss functions for neural net fitting.\n", + " \"\"\"\n", + "\n", + " # binary cross entropy loss\n", + " @staticmethod\n", + " def binary_cross_entropy(y_true: np.array, y_predicted: np.array) -> np.array:\n", + " \"\"\"\n", + " BINARY CROSS ENTROPY LOSS\n", + " Cross entropy loss for binary-class classification.\n", + " L[BCE] = - p(i) * log(q(i)) - (1 - p(i)) * log(1 - q(i))\n", + " where\n", + " - p(i) is the true label\n", + " - q(i) is the predicted sigmoid probability\n", + " Parameters:\n", + " - y_true: np.array (1, sample_size)\n", + " true label vector\n", + " - y_predicted: np.array (1, sample_size)\n", + " the sigmoid probability\n", + " Returns:\n", + " - np.array (sample_size,)\n", + " loss for every given sample\n", + " Ref:\n", + " https://en.wikipedia.org/wiki/Cross_entropy\n", + " \"\"\"\n", + " losses = []\n", + " for i in range(y_true.shape[1]):\n", + " ## stable BCE\n", + " losses.append(float(-1 * (y_true[:, i] * np.log(y_predicted[:, i] + 1e-7) + (1 - y_true[:, i]) * np.log(1 - y_predicted[:, i] + 1e-7))))\n", + " ## unstable BCE\n", + " # losses.append(float(-1 * (y_true[:, i] * np.log(y_predicted[:, i]) + (1 - y_true[:, i]) * np.log(1 - y_predicted[:, i]))))\n", + " return np.array(losses)\n", + "\n", + " # categorical cross entropy loss\n", + " @staticmethod\n", + " def categorical_cross_entropy(y_true: np.array, y_predicted: np.array) -> np.array:\n", + " \"\"\"\n", + " CATEGORICAL CROSS ENTROPY LOSS\n", + " Cross entropy loss for binary- and multi-class class classification.\n", + " L[CCE] = - sum[from i = 0 to n]( p(i) * log(q(i)) )\n", + " where\n", + " - p(i) is the true label\n", + " - q(i) is the predicted softmax probability\n", + " - n is the number of classes\n", + " Parameters:\n", + " - y_true: np.array (n_classes, sample_size)\n", + " one-hot encoded true label vector\n", + " - y_predicted: np.array (n_classes, sample_size)\n", + " the softmax probabilities\n", + " Returns:\n", + " - np.array (sample_size,)\n", + " loss for every given sample\n", + " Ref:\n", + " https://en.wikipedia.org/wiki/Cross_entropy\n", + " \"\"\"\n", + " losses = []\n", + " for i in range(y_true.shape[1]):\n", + " ## stable CCE\n", + " # losses.append(float(-1 * np.sum(y_true[:, i] * np.log(y_predicted[:, i] + 1e-7))))\n", + " ## unstable CCE\n", + " losses.append(float(-1 * np.sum(y_true[:, i] * np.log(y_predicted[:, i]))))\n", + "\n", + " return np.array(losses)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6c6c347f-a0d8-4494-b8cf-f29295f42877", + "metadata": {}, + "outputs": [], + "source": [ + "class NeuralNetwork:\n", + " \"\"\"\n", + " Implementation of a classic feed-forward neural network that is trained via\n", + " backpropagation. Adopts a Keras-like interface for convenient usage (see\n", + " https://michabirklbauer.github.io/neuralnet for examples).\n", + " \"\"\"\n", + "\n", + " # constructor\n", + " def __init__(self, input_size: int):\n", + " \"\"\"\n", + " CONSTRUCTOR\n", + " Initializes the neural network model.\n", + " Parameters:\n", + " - input_size: int\n", + " nr. of features in the training data\n", + " Returns:\n", + " - None\n", + " Example usage:\n", + " NN = NeuralNetwork(data.shape[1])\n", + " \"\"\"\n", + " self.input_size = input_size\n", + " self.architecture = []\n", + " self.layers = []\n", + "\n", + " # adding layers\n", + " def add_layer(self, units: int, activation: str = \"relu\", initialization: str = None) -> None:\n", + " \"\"\"\n", + " LAYER MANAGEMENT\n", + " Construct the neural network architecture by adding different layers.\n", + " Parameters:\n", + " - units: int\n", + " nr. of units in the layer\n", + " - activation: str, one of (\"relu\", \"sigmoid\", \"softmax\")\n", + " activation function of the layer\n", + " DEFAULT: \"relu\"\n", + " - initialization: str, one of (\"he\", \"glorot\")\n", + " weight initialization to use\n", + " DEFAULT: None, \"relu\" layers are 'he normal' initialized,\n", + " all other layers are 'glorot normal'\n", + " initialized\n", + " Returns:\n", + " - None\n", + " Example usage:\n", + " NN = NeuralNetwork(data.shape[1])\n", + " NN.add_layer(16, \"relu\", \"glorot\")\n", + " NN.add_layer(8)\n", + " NN.add_layer(1, \"sigmoid\")\n", + " \"\"\"\n", + " if initialization == None:\n", + " if activation == \"relu\":\n", + " layer_init = \"he\"\n", + " else:\n", + " layer_init = \"glorot\"\n", + " else:\n", + " layer_init = initialization\n", + "\n", + " self.architecture.append({\"units\": units, \"activation\": activation, \"init\": layer_init})\n", + "\n", + " # compiling model\n", + " def compile(self, loss: str = \"categorical crossentropy\") -> None:\n", + " \"\"\"\n", + " MODEL INITIALIZATION\n", + " Initializes all parameters of the neural network architecture and\n", + " prepares the model for training.\n", + " Parameters:\n", + " - loss: str, one of (\"binary crossentropy\", \"categorical crossentropy\")\n", + " the loss function that should be used for training\n", + " DEFAULT: \"categorical crossentropy\"\n", + " Returns:\n", + " - None\n", + " Example usage:\n", + " NN = NeuralNetwork(data.shape[1])\n", + " NN.add_layer(16, \"relu\", \"glorot\")\n", + " NN.add_layer(8)\n", + " NN.add_layer(1, \"sigmoid\")\n", + " NN.compile(\"binary crossentropy\")\n", + " \"\"\"\n", + " self.loss = loss\n", + "\n", + " # initialize all layer weights and biases\n", + " for i in range(len(self.architecture)):\n", + " units = self.architecture[i][\"units\"]\n", + " activation = self.architecture[i][\"activation\"]\n", + " init = self.architecture[i][\"init\"]\n", + "\n", + " units_previous_layer = self.input_size\n", + " if i > 0:\n", + " units_previous_layer = self.architecture[i - 1][\"units\"]\n", + " units_next_layer = 0\n", + " if i < len(self.architecture) - 1:\n", + " units_next_layer = self.architecture[i + 1][\"units\"]\n", + "\n", + " if init == \"he\":\n", + " W = LayerInitializer.he_normal((units, units_previous_layer), fan_in = units_previous_layer)\n", + " b = LayerInitializer.bias((units, 1))\n", + " elif init == \"glorot\":\n", + " W = LayerInitializer.glorot_normal((units, units_previous_layer), fan_in = units_previous_layer, fan_out = units_next_layer)\n", + " b = LayerInitializer.bias((units, 1))\n", + " else:\n", + " raise NotImplementedError(\"Layer initialization '\" + init + \"' not implemented!\")\n", + "\n", + " self.layers.append({\"W\": W, \"b\": b, \"activation\": activation})\n", + "\n", + " # forward propagation\n", + " def __forward_propagation(self, data: np.array) -> None:\n", + " \"\"\"\n", + " FORWARD PROPAGATION (INTERNAL)\n", + " Internal function calculating the forward pass of A(Wx + b).\n", + " - The result of 'Wx + b' (L) is stored in self.layers[layer][\"L\"]\n", + " - The result of 'Activation(L)' (A) is stored in self.layers[layer][\"A\"]\n", + " Parameters:\n", + " - data: np.array\n", + " input data for the forward pass\n", + " Returns:\n", + " - None, \"L\" and \"A\" are set in the layer dictionary, to retrieve the\n", + " last layer output call 'self.layers[-1][\"A\"]'\n", + " \"\"\"\n", + "\n", + " for i in range(len(self.layers)):\n", + "\n", + " if i == 0:\n", + " A = data\n", + " else:\n", + " A = self.layers[i - 1][\"A\"]\n", + "\n", + " # Wx + b where x is the input data for the first layer and otherwise\n", + " # the output (A) of the previous layer\n", + " self.layers[i][\"L\"] = self.layers[i][\"W\"].dot(A) + self.layers[i][\"b\"]\n", + " if self.layers[i][\"activation\"] == \"relu\":\n", + " self.layers[i][\"A\"] = ActivationFunctions.relu(self.layers[i][\"L\"])\n", + " elif self.layers[i][\"activation\"] == \"sigmoid\":\n", + " self.layers[i][\"A\"] = ActivationFunctions.sigmoid(self.layers[i][\"L\"])\n", + " elif self.layers[i][\"activation\"] == \"softmax\":\n", + " self.layers[i][\"A\"] = ActivationFunctions.softmax(self.layers[i][\"L\"])\n", + " else:\n", + " raise NotImplementedError(\"Activation function '\" + layer[\"activation\"] + \"' not implemented!\")\n", + "\n", + " # back propagation\n", + " def __back_propagation(self, data: np.array, target: np.array, learning_rate: float = 0.1) -> float:\n", + " \"\"\"\n", + " BACK PROPAGATION (INTERNAL)\n", + " Internal function for learning layer weights and biases using gradient\n", + " descent and back propagation.\n", + " Parameters:\n", + " - data: np.array\n", + " input data\n", + " - target: np.array\n", + " class labels of the input data\n", + " - learning_rate: float\n", + " learning rate / how far in the direction of the gradient to\n", + " go\n", + " DEFAULT: 0.1\n", + " Returns:\n", + " - float\n", + " loss of the current forward pass\n", + " \"\"\"\n", + " # forward pass\n", + " self.__forward_propagation(data)\n", + "\n", + " output = self.layers[-1][\"A\"]\n", + " batch_size = data.shape[1]\n", + " loss = 0\n", + "\n", + " # calculate loss of the current forward pass\n", + " if self.loss == \"categorical crossentropy\":\n", + " losses = LossFunctions.categorical_cross_entropy(y_true = target, y_predicted = output)\n", + " # reduction by sum over batch size\n", + " loss = float(np.sum(losses) / batch_size)\n", + " elif self.loss == \"binary crossentropy\":\n", + " losses = LossFunctions.binary_cross_entropy(y_true = target, y_predicted = output)\n", + " # reduction by sum over batch size\n", + " loss = float(np.sum(losses) / batch_size)\n", + " else:\n", + " raise NotImplementedError(\"Loss function '\" + self.loss + \"' not implemented!\")\n", + "\n", + " # calculate and back pass the derivate of the loss w.r.t the output\n", + " # activation function\n", + " # this implementation suppports CCE + Softmax and BCE + Sigmoid in the\n", + " # output layer\n", + " if self.loss == \"categorical crossentropy\" and self.layers[-1][\"activation\"] == \"softmax\":\n", + " # for categorical cross entropy loss the derivative of softmax simplifies to\n", + " # P(i) - Y(i)\n", + " # where P(i) is the softmax output and Y(i) is the true label\n", + " # https://www.ics.uci.edu/~pjsadows/notes.pdf\n", + " # https://math.stackexchange.com/questions/945871/derivative-of-softmax-loss-function\n", + " previous_layer_activation = data.T if len(self.layers) == 1 else self.layers[len(self.layers) - 2][\"A\"].T\n", + " dL = self.layers[-1][\"A\"] - target\n", + " dW = dL.dot(previous_layer_activation) / batch_size\n", + " db = np.reshape(np.sum(dL, axis = 1), (-1, 1)) / batch_size\n", + "\n", + " # parameter tracking\n", + " previous_dL = np.copy(dL)\n", + " previous_W = np.copy(self.layers[-1][\"W\"])\n", + "\n", + " # update\n", + " self.layers[-1][\"W\"] -= learning_rate * dW\n", + " self.layers[-1][\"b\"] -= learning_rate * db\n", + " elif self.loss == \"binary crossentropy\" and self.layers[-1][\"activation\"] == \"sigmoid\":\n", + " # for binary cross entropy loss the derivative of the loss function is\n", + " # L' = -1 * (Y(i) / P(i) - (1 - Y(i)) / (1 - P(i)))\n", + " # where P(i) is the sigmoid output and Y(i) is the true label\n", + " # and we multiply that with the derivative of the sigmoid function [1]\n", + " # https://math.stackexchange.com/questions/2503428/derivative-of-binary-cross-entropy-why-are-my-signs-not-right\n", + " previous_layer_activation = data.T if len(self.layers) == 1 else self.layers[len(self.layers) - 2][\"A\"].T\n", + " # [1]\n", + " # A = np.clip(self.layers[-1][\"A\"], 1e-7, 1 - 1e-7)\n", + " # derivative_loss = -1 * np.divide(target, A) + np.divide(1 - target, 1 - A)\n", + " # dL = derivative_loss * ActivationFunctions.sigmoid(self.layers[-1][\"L\"], derivative = True)\n", + " # alternatively we can directly simplify the derivative of the binary cross entropy loss\n", + " # with sigmoid activation function to\n", + " # P(i) - Y(i)\n", + " # where P(i) is the sigmoid output and Y(i) is the true label\n", + " # done in [2]\n", + " # https://math.stackexchange.com/questions/4227931/what-is-the-derivative-of-binary-cross-entropy-loss-w-r-t-to-input-of-sigmoid-fu\n", + " # [2]\n", + " dL = (self.layers[-1][\"A\"] - target) / batch_size\n", + " dW = dL.dot(previous_layer_activation) / batch_size\n", + " db = np.reshape(np.sum(dL, axis = 1), (-1, 1)) / batch_size\n", + "\n", + " # parameter tracking\n", + " previous_dL = np.copy(dL)\n", + " previous_W = np.copy(self.layers[-1][\"W\"])\n", + "\n", + " # update\n", + " self.layers[-1][\"W\"] -= learning_rate * dW\n", + " self.layers[-1][\"b\"] -= learning_rate * db\n", + " else:\n", + " raise NotImplementedError(\"The combination of '\" + self.loss + \" loss' and '\" + self.layers[i][\"activation\"] + \" activation' is not implemented!\")\n", + "\n", + " # back propagation through the remaining hidden layers\n", + " for i in reversed(range(len(self.layers) - 1)):\n", + "\n", + " if i == 0:\n", + " if self.layers[i][\"activation\"] == \"relu\":\n", + " dL = previous_W.T.dot(previous_dL) * ActivationFunctions.relu(self.layers[i][\"L\"], derivative = True)\n", + " dW = dL.dot(data.T) / batch_size\n", + " db = np.reshape(np.sum(dL, axis = 1), (-1, 1)) / batch_size\n", + " elif self.layers[i][\"activation\"] == \"sigmoid\":\n", + " dL = previous_W.T.dot(previous_dL) * ActivationFunctions.sigmoid(self.layers[i][\"L\"], derivative = True)\n", + " dW = dL.dot(data.T) / batch_size\n", + " db = np.reshape(np.sum(dL, axis = 1), (-1, 1)) / batch_size\n", + " else:\n", + " raise NotImplementedError(\"Activation function '\" + self.layers[i][\"activation\"] + \"' not implemented for hidden layers!\")\n", + "\n", + " # parameter tracking\n", + " previous_dL = np.copy(dL)\n", + " previous_W = np.copy(self.layers[i][\"W\"])\n", + "\n", + " #update\n", + " self.layers[i][\"W\"] -= learning_rate * dW\n", + " self.layers[i][\"b\"] -= learning_rate * db\n", + " else:\n", + " if self.layers[i][\"activation\"] == \"relu\":\n", + " dL = previous_W.T.dot(previous_dL) * ActivationFunctions.relu(self.layers[i][\"L\"], derivative = True)\n", + " dW = dL.dot(self.layers[i - 1][\"A\"].T) / batch_size\n", + " db = np.reshape(np.sum(dL, axis = 1), (-1, 1)) / batch_size\n", + " elif self.layers[i][\"activation\"] == \"sigmoid\":\n", + " dL = previous_W.T.dot(previous_dL) * ActivationFunctions.sigmoid(self.layers[i][\"L\"], derivative = True)\n", + " dW = dL.dot(self.layers[i - 1][\"A\"].T) / batch_size\n", + " db = np.reshape(np.sum(dL, axis = 1), (-1, 1)) / batch_size\n", + " else:\n", + " raise NotImplementedError(\"Activation function '\" + self.layers[i][\"activation\"] + \"' not implemented for hidden layers!\")\n", + "\n", + " # parameter tracking\n", + " previous_dL = np.copy(dL)\n", + " previous_W = np.copy(self.layers[i][\"W\"])\n", + "\n", + " #update\n", + " self.layers[i][\"W\"] -= learning_rate * dW\n", + " self.layers[i][\"b\"] -= learning_rate * db\n", + "\n", + " return loss\n", + "\n", + " # neural network architecture summary\n", + " def summary(self) -> None:\n", + " \"\"\"\n", + " MODEL SUMMARY\n", + " Print a summary of the neural network architecture.\n", + " Parameters:\n", + " - None\n", + " Returns:\n", + " - None, prints a summary of the neural network architecture to\n", + " stdout\n", + " Example usage:\n", + " NN.summary()\n", + " \"\"\"\n", + " print(\"---- Model Summary ----\")\n", + " for i, layer in enumerate(self.layers):\n", + " print(\"Layer \" + str(i + 1) + \": \" + layer[\"activation\"])\n", + " if \"L\" in layer:\n", + " print(\"W: \" + str(layer[\"W\"].shape) + \" \" +\n", + " \"b: \" + str(layer[\"b\"].shape) + \" \" +\n", + " \"L: \" + str(layer[\"L\"].shape) + \" \" +\n", + " \"A: \" + str(layer[\"A\"].shape))\n", + " else:\n", + " print(\"W: \" + str(layer[\"W\"].shape) + \" \" +\n", + " \"b: \" + str(layer[\"b\"].shape))\n", + " print(\"Trainable parameters: \" + str(\n", + " layer[\"W\"].shape[0] * layer[\"W\"].shape[1] +\n", + " layer[\"b\"].shape[0] * layer[\"b\"].shape[1]))\n", + "\n", + " # train neural network on data\n", + " def fit(self, X: np.array, y: np.array, epochs: int = 100, batch_size: int = 32, learning_rate: float = 0.1, verbose: int = 1) -> List[float]:\n", + " \"\"\"\n", + " TRAIN MODEL\n", + " Train the neural network.\n", + " Parameters:\n", + " - X: np.array (samples, features)\n", + " input data to train on\n", + " - y: np.array (samples, labels) or (labels,)\n", + " labels of the input data\n", + " - epochs: int\n", + " how many iterations to train\n", + " DEFAULT: 100\n", + " - batch_size: int\n", + " how many samples to use per backward pass\n", + " DEFAULT: 32\n", + " - learning_rate: float\n", + " learning rate / how far in the direction of the gradient to\n", + " go\n", + " DEFAULT: 0.1\n", + " - verbose: int, one of (0, 1) / bool\n", + " print information for every epoch\n", + " DEFAULT: 1 (True)\n", + " Returns:\n", + " - List[float]\n", + " loss history over all epochs\n", + " Example usage:\n", + " NN.fit(data_train, labels_train)\n", + " \"\"\"\n", + " # reshaping inputs\n", + " if y.ndim == 1:\n", + " y = np.reshape(y, (-1, 1))\n", + "\n", + " data = X.T\n", + " target = y.T\n", + " sample_size = data.shape[1]\n", + "\n", + " history = []\n", + "\n", + " # train network\n", + " for i in range(epochs):\n", + " if verbose:\n", + " print(\"Training epoch \" + str(i + 1) + \"...\")\n", + " # generate random batches of size batch_size\n", + " idx = np.random.choice(sample_size, sample_size, replace = False)\n", + " batches = np.array_split(idx, math.ceil(sample_size / batch_size))\n", + " batch_losses = []\n", + " for batch in batches:\n", + " current_data = data[:, batch]\n", + " current_target = target[:, batch]\n", + " batch_loss = self.__back_propagation(current_data, current_target, learning_rate = learning_rate)\n", + " batch_losses.append(batch_loss)\n", + " history.append(np.mean(batch_losses))\n", + " if verbose:\n", + " print(\"Current loss: \", np.mean(batch_losses))\n", + " print(\"Epoch \" + str(i + 1) + \" done!\")\n", + "\n", + " print(\"Training finished after epoch \" + str(epochs) + \" with a loss of \" + str(history[-1]) + \".\")\n", + "\n", + " return history\n", + "\n", + " # predict data with fitted neural network\n", + " def predict(self, X: np.array) -> np.array:\n", + " \"\"\"\n", + " GENERATE PREDICTIONS\n", + " Predict labels for the given input data.\n", + " Parameters:\n", + " - X: np.array (samples, features) or (features,)\n", + " input data to predict\n", + " Returns:\n", + " - np.array\n", + " predictions\n", + " Example usage:\n", + " NN.predict(data_test)\n", + " \"\"\"\n", + " if X.ndim == 1:\n", + " X = np.reshape(X, (1, -1))\n", + "\n", + " self.__forward_propagation(X.T)\n", + "\n", + " return self.layers[-1][\"A\"].T" + ] + }, + { + "cell_type": "markdown", + "id": "a60f041d-d688-4b00-8bc1-3e01da0d947f", + "metadata": { + "tags": [] + }, + "source": [ + "# **Example Usage of `neuralnet.py / class NeuralNetwork`**\n", + "\n", + "### **Multi-Class Classification**\n", + "\n", + "### **Dataset: [MNIST](http://yann.lecun.com/exdb/mnist/index.html)**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb8e53f6-c5bf-4221-8d49-f3bc804b438d", + "metadata": {}, + "outputs": [], + "source": [ + "!wget https://raw.githubusercontent.com/michabirklbauer/neuralnet/master/data.zip" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c7a8280c-d0b9-41d5-88e1-993db76a73b4", + "metadata": {}, + "outputs": [], + "source": [ + "from zipfile import ZipFile as zip\n", + "\n", + "with zip(\"data.zip\") as f:\n", + " f.extractall()\n", + " f.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "579b7aa9-24c5-4dd1-b8b7-719cbb1f7b09", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from matplotlib import pyplot as plt\n", + "from sklearn.metrics import accuracy_score\n", + "from sklearn.preprocessing import OneHotEncoder\n", + "from sklearn.model_selection import train_test_split" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a8e4f9b1-140c-42ac-9b04-11e23b27d1eb", + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_csv(\"multiclass_train.csv\")\n", + "train, test = train_test_split(data, test_size = 0.3)\n", + "train_data = train.loc[:, train.columns != \"label\"].to_numpy() / 255\n", + "train_target = train[\"label\"].to_numpy()\n", + "test_data = test.loc[:, test.columns != \"label\"].to_numpy() / 255\n", + "test_target = test[\"label\"].to_numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f9a8ba9c-7255-40b3-9e5f-f999e89eb257", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelpixel0pixel1pixel2pixel3pixel4pixel5pixel6pixel7pixel8...pixel774pixel775pixel776pixel777pixel778pixel779pixel780pixel781pixel782pixel783
251647000000000...0000000000
119049000000000...0000000000
378331000000000...0000000000
61015000000000...0000000000
250193000000000...0000000000
..................................................................
213907000000000...0000000000
76013000000000...0000000000
2241000000000...0000000000
375824000000000...0000000000
129262000000000...0000000000
\n", + "

29400 rows × 785 columns

\n", + "
" + ], + "text/plain": [ + " label pixel0 pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 \\\n", + "25164 7 0 0 0 0 0 0 0 0 \n", + "11904 9 0 0 0 0 0 0 0 0 \n", + "37833 1 0 0 0 0 0 0 0 0 \n", + "6101 5 0 0 0 0 0 0 0 0 \n", + "25019 3 0 0 0 0 0 0 0 0 \n", + "... ... ... ... ... ... ... ... ... ... \n", + "21390 7 0 0 0 0 0 0 0 0 \n", + "7601 3 0 0 0 0 0 0 0 0 \n", + "224 1 0 0 0 0 0 0 0 0 \n", + "37582 4 0 0 0 0 0 0 0 0 \n", + "12926 2 0 0 0 0 0 0 0 0 \n", + "\n", + " pixel8 ... pixel774 pixel775 pixel776 pixel777 pixel778 \\\n", + "25164 0 ... 0 0 0 0 0 \n", + "11904 0 ... 0 0 0 0 0 \n", + "37833 0 ... 0 0 0 0 0 \n", + "6101 0 ... 0 0 0 0 0 \n", + "25019 0 ... 0 0 0 0 0 \n", + "... ... ... ... ... ... ... ... \n", + "21390 0 ... 0 0 0 0 0 \n", + "7601 0 ... 0 0 0 0 0 \n", + "224 0 ... 0 0 0 0 0 \n", + "37582 0 ... 0 0 0 0 0 \n", + "12926 0 ... 0 0 0 0 0 \n", + "\n", + " pixel779 pixel780 pixel781 pixel782 pixel783 \n", + "25164 0 0 0 0 0 \n", + "11904 0 0 0 0 0 \n", + "37833 0 0 0 0 0 \n", + "6101 0 0 0 0 0 \n", + "25019 0 0 0 0 0 \n", + "... ... ... ... ... ... \n", + "21390 0 0 0 0 0 \n", + "7601 0 0 0 0 0 \n", + "224 0 0 0 0 0 \n", + "37582 0 0 0 0 0 \n", + "12926 0 0 0 0 0 \n", + "\n", + "[29400 rows x 785 columns]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "6d2c5098-ba6b-4533-be14-a7e1395c944b", + "metadata": {}, + "outputs": [], + "source": [ + "one_hot = OneHotEncoder(sparse = False, categories = \"auto\")\n", + "train_target = one_hot.fit_transform(train_target.reshape(-1, 1))\n", + "test_target = one_hot.transform(test_target.reshape(-1, 1))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7417c6e4-fd30-498c-a4de-5657ffb0e5f1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---- Model Summary ----\n", + "Layer 1: relu\n", + "W: (32, 784) b: (32, 1)\n", + "Trainable parameters: 25120\n", + "Layer 2: relu\n", + "W: (16, 32) b: (16, 1)\n", + "Trainable parameters: 528\n", + "Layer 3: softmax\n", + "W: (10, 16) b: (10, 1)\n", + "Trainable parameters: 170\n" + ] + } + ], + "source": [ + "NN = NeuralNetwork(input_size = train_data.shape[1])\n", + "NN.add_layer(32, \"relu\")\n", + "NN.add_layer(16, \"relu\")\n", + "NN.add_layer(10, \"softmax\")\n", + "NN.compile(loss = \"categorical crossentropy\")\n", + "NN.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bc46f780-ff80-43ec-8eae-0e31ecd39a30", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training epoch 1...\n", + "Current loss: 0.43747370824596604\n", + "Epoch 1 done!\n", + "Training epoch 2...\n", + "Current loss: 0.21528007156258966\n", + "Epoch 2 done!\n", + "Training epoch 3...\n", + "Current loss: 0.16742503623911392\n", + "Epoch 3 done!\n", + "Training epoch 4...\n", + "Current loss: 0.13877936553368508\n", + "Epoch 4 done!\n", + "Training epoch 5...\n", + "Current loss: 0.12099309045421619\n", + "Epoch 5 done!\n", + "Training epoch 6...\n", + "Current loss: 0.1072971880634624\n", + "Epoch 6 done!\n", + "Training epoch 7...\n", + "Current loss: 0.09396355017990504\n", + "Epoch 7 done!\n", + "Training epoch 8...\n", + "Current loss: 0.08720308198194518\n", + "Epoch 8 done!\n", + "Training epoch 9...\n", + "Current loss: 0.07927159779935378\n", + "Epoch 9 done!\n", + "Training epoch 10...\n", + "Current loss: 0.07284107143112058\n", + "Epoch 10 done!\n", + "Training epoch 11...\n", + "Current loss: 0.06600162705624461\n", + "Epoch 11 done!\n", + "Training epoch 12...\n", + "Current loss: 0.06342602649693302\n", + "Epoch 12 done!\n", + "Training epoch 13...\n", + "Current loss: 0.05783998850656874\n", + "Epoch 13 done!\n", + "Training epoch 14...\n", + "Current loss: 0.05052314129523882\n", + "Epoch 14 done!\n", + "Training epoch 15...\n", + "Current loss: 0.04563600268741524\n", + "Epoch 15 done!\n", + "Training epoch 16...\n", + "Current loss: 0.04470639462592896\n", + "Epoch 16 done!\n", + "Training epoch 17...\n", + "Current loss: 0.043506537043299306\n", + "Epoch 17 done!\n", + "Training epoch 18...\n", + "Current loss: 0.03815045738567615\n", + "Epoch 18 done!\n", + "Training epoch 19...\n", + "Current loss: 0.038454017529732515\n", + "Epoch 19 done!\n", + "Training epoch 20...\n", + "Current loss: 0.034033571538281876\n", + "Epoch 20 done!\n", + "Training epoch 21...\n", + "Current loss: 0.03033063122611392\n", + "Epoch 21 done!\n", + "Training epoch 22...\n", + "Current loss: 0.02789381646483783\n", + "Epoch 22 done!\n", + "Training epoch 23...\n", + "Current loss: 0.02688368926764838\n", + "Epoch 23 done!\n", + "Training epoch 24...\n", + "Current loss: 0.02944480698302673\n", + "Epoch 24 done!\n", + "Training epoch 25...\n", + "Current loss: 0.02519994251217897\n", + "Epoch 25 done!\n", + "Training epoch 26...\n", + "Current loss: 0.02679484096626338\n", + "Epoch 26 done!\n", + "Training epoch 27...\n", + "Current loss: 0.01805071452172742\n", + "Epoch 27 done!\n", + "Training epoch 28...\n", + "Current loss: 0.021675299545706767\n", + "Epoch 28 done!\n", + "Training epoch 29...\n", + "Current loss: 0.027434799817775905\n", + "Epoch 29 done!\n", + "Training epoch 30...\n", + "Current loss: 0.024449728356841036\n", + "Epoch 30 done!\n", + "Training finished after epoch 30 with a loss of 0.024449728356841036.\n" + ] + } + ], + "source": [ + "hist = NN.fit(train_data, train_target, epochs = 30, batch_size = 16, learning_rate = 0.05)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "5d833848-9d24-47b3-b690-d736a50ebe4c", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjkAAAHHCAYAAABdm0mZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABFGUlEQVR4nO3deXxU9b3/8fdMlsm+hywQCJusklSQGBWXElmkCAqKy62Ua/Wn4kKpfVTqLai9Xqy1liqK1dalrYrFvVRxieAaRYGgKKLsgeyJZCXbzPf3R8xgBHQIM3OSyev5eJxHMmfOzHzm9GDe/Z7vYjPGGAEAAAQYu9UFAAAA+AIhBwAABCRCDgAACEiEHAAAEJAIOQAAICARcgAAQEAi5AAAgIBEyAEAAAGJkAMAAAISIQdAt2Sz2XTrrbce8+t2794tm82mxx57zOs1AehZCDkAjuqxxx6TzWaTzWbTu+++e9jzxhhlZGTIZrPpJz/5iQUVdt26detks9n0zDPPWF0KAB8h5AD4QWFhYXryyScP2//WW29p3759cjgcFlQFAN+PkAPgB5177rlatWqV2traOu1/8sknNXbsWKWmplpUGQAcHSEHwA+65JJLVFVVpddff929r6WlRc8884wuvfTSI76moaFBv/zlL5WRkSGHw6Fhw4bp7rvvljGm03HNzc36xS9+oeTkZEVHR+u8887Tvn37jvie+/fv13//938rJSVFDodDo0aN0iOPPOK9L3oEO3fu1IUXXqiEhARFRETolFNO0X/+85/Djrvvvvs0atQoRUREKD4+XuPGjevU+lVXV6cFCxYoMzNTDodDffr00TnnnKONGzf6tH6gNyPkAPhBmZmZys3N1VNPPeXe98orr6impkYXX3zxYccbY3TeeefpT3/6k6ZMmaJ77rlHw4YN069+9SstXLiw07E///nPtWzZMk2aNEl33nmnQkJCNG3atMPes6ysTKeccoreeOMNXXfddfrzn/+sIUOG6IorrtCyZcu8/p07PvPUU0/Vq6++qmuvvVZ33HGHmpqadN555+n55593H/fwww/rhhtu0MiRI7Vs2TLddtttys7O1ocffug+5uqrr9aKFSs0a9YsPfDAA7rpppsUHh6urVu3+qR2AJIMABzFo48+aiSZjz76yCxfvtxER0ebxsZGY4wxF154oTn77LONMcYMGDDATJs2zf26F154wUgy//u//9vp/WbPnm1sNpvZvn27McaYwsJCI8lce+21nY679NJLjSSzZMkS974rrrjCpKWlmcrKyk7HXnzxxSY2NtZd165du4wk8+ijj37vd1u7dq2RZFatWnXUYxYsWGAkmXfeece9r66uzgwcONBkZmYap9NpjDFmxowZZtSoUd/7ebGxsWb+/PnfewwA76IlB4BHLrroIh08eFCrV69WXV2dVq9efdRbVS+//LKCgoJ0ww03dNr/y1/+UsYYvfLKK+7jJB123IIFCzo9Nsbo2Wef1fTp02WMUWVlpXubPHmyampqfHLb5+WXX9b48eN1+umnu/dFRUXpqquu0u7du/X5559LkuLi4rRv3z599NFHR32vuLg4ffjhhyouLvZ6nQCOjJADwCPJycnKy8vTk08+qeeee05Op1OzZ88+4rF79uxRenq6oqOjO+0fMWKE+/mOn3a7XYMHD+503LBhwzo9rqio0IEDB/TQQw8pOTm50zZv3jxJUnl5uVe+53e/x3drOdL3+PWvf62oqCiNHz9eQ4cO1fz58/Xee+91es1dd92lLVu2KCMjQ+PHj9ett96qnTt3er1mAIcEW10AgJ7j0ksv1ZVXXqnS0lJNnTpVcXFxfvlcl8slSfqv//ovzZ0794jHjBkzxi+1HMmIESO0bds2rV69WmvWrNGzzz6rBx54QIsXL9Ztt90mqb0lbMKECXr++ef12muv6Q9/+IN+//vf67nnntPUqVMtqx0IZLTkAPDY+eefL7vdrg8++OCot6okacCAASouLlZdXV2n/V988YX7+Y6fLpdLO3bs6HTctm3bOj3uGHnldDqVl5d3xK1Pnz7e+IqHfY/v1nKk7yFJkZGRmjNnjh599FHt3btX06ZNc3dU7pCWlqZrr71WL7zwgnbt2qXExETdcccdXq8bQDtCDgCPRUVFacWKFbr11ls1ffr0ox537rnnyul0avny5Z32/+lPf5LNZnO3XHT8vPfeezsd993RUkFBQZo1a5aeffZZbdmy5bDPq6io6MrX+UHnnnuu1q9fr4KCAve+hoYGPfTQQ8rMzNTIkSMlSVVVVZ1eFxoaqpEjR8oYo9bWVjmdTtXU1HQ6pk+fPkpPT1dzc7NPagfA7SoAx+hot4u+bfr06Tr77LN1yy23aPfu3crKytJrr72mF198UQsWLHD3wcnOztYll1yiBx54QDU1NTr11FOVn5+v7du3H/aed955p9auXaucnBxdeeWVGjlypKqrq7Vx40a98cYbqq6u7tL3efbZZ90tM9/9njfffLOeeuopTZ06VTfccIMSEhL0+OOPa9euXXr22Wdlt7f//8RJkyYpNTVVp512mlJSUrR161YtX75c06ZNU3R0tA4cOKB+/fpp9uzZysrKUlRUlN544w199NFH+uMf/9ilugF4wNrBXQC6s28PIf8+3x1Cbkz7UOtf/OIXJj093YSEhJihQ4eaP/zhD8blcnU67uDBg+aGG24wiYmJJjIy0kyfPt0UFRUdNoTcGGPKysrM/PnzTUZGhgkJCTGpqalm4sSJ5qGHHnIfc6xDyI+2dQwb37Fjh5k9e7aJi4szYWFhZvz48Wb16tWd3usvf/mLOeOMM0xiYqJxOBxm8ODB5le/+pWpqakxxhjT3NxsfvWrX5msrCwTHR1tIiMjTVZWlnnggQe+t0YAx8dmzHemHwUAAAgA9MkBAAABiZADAAACEiEHAAAEJEIOAAAISIQcAAAQkAg5AAAgIPW6yQBdLpeKi4sVHR0tm81mdTkAAMADxhjV1dUpPT3dPRHnD+l1Iae4uFgZGRlWlwEAALqgqKhI/fr18+jYXhdyoqOjJbWfpJiYGIurAQAAnqitrVVGRob777gnel3I6bhFFRMTQ8gBAKCHOZauJnQ8BgAAAYmQAwAAAhIhBwAABCRCDgAACEiEHAAAEJAIOQAAICARcgAAQEAi5AAAgIBEyAEAAAGJkAMAAAISIQcAAAQkQg4AAAhIhBwvcbqMymubtKeqwepSAACACDle88HOKo3/v3xd+fePrS4FAACIkOM1iVGhkqTK+haLKwEAABIhx2uSohySpK8bW9TmdFlcDQAAIOR4SXxEqGw2yRipupHWHAAArEbI8ZIgu00JEe23rKq4ZQUAgOUIOV7UccuKkAMAgPUIOV50qPNxs8WVAAAAQo4XdbTkEHIAALAeIceLOlpyqhq4XQUAgNUIOV7kbsmpoyUHAACrEXK8KImWHAAAug1CjhclRnaMrqIlBwAAqxFyvIilHQAA6D4IOV707dFVxhiLqwEAoHcj5HhRR0tOc5tL9c1tFlcDAEDvRsjxoojQYEWEBkli1mMAAKxGyPEy99IODXQ+BgDASoQcL+u4ZVVRR0sOAABWIuR4GS05AAB0D4QcL3NPCEifHAAALEXI8bKOCQFZpBMAAGsRcryMlhwAALoHQo6XJUbRkgMAQHdAyPGyQ0s7EHIAALASIcfLkt2jq7hdBQCAlQg5XtZxu+pAY6tanS6LqwEAoPci5HhZXHiIguw2SVI1rTkAAFimW4Sc+++/X5mZmQoLC1NOTo7Wr1/v0etWrlwpm82mmTNn+rbAY2C325QQSb8cAACsZnnIefrpp7Vw4UItWbJEGzduVFZWliZPnqzy8vLvfd3u3bt10003acKECX6q1HOJ7pBDSw4AAFaxPOTcc889uvLKKzVv3jyNHDlSDz74oCIiIvTII48c9TVOp1OXXXaZbrvtNg0aNMiP1XomOfqbzse05AAAYBlLQ05LS4s2bNigvLw89z673a68vDwVFBQc9XW33367+vTpoyuuuOIHP6O5uVm1tbWdNl/raMlhQkAAAKxjaciprKyU0+lUSkpKp/0pKSkqLS094mveffdd/e1vf9PDDz/s0WcsXbpUsbGx7i0jI+O46/4hTAgIAID1LL9ddSzq6ur005/+VA8//LCSkpI8es2iRYtUU1Pj3oqKinxc5aGVyOmTAwCAdYKt/PCkpCQFBQWprKys0/6ysjKlpqYedvyOHTu0e/duTZ8+3b3P5WqfiyY4OFjbtm3T4MGDO73G4XDI4XD4oPqj65j1uKqBlhwAAKxiaUtOaGioxo4dq/z8fPc+l8ul/Px85ebmHnb88OHD9emnn6qwsNC9nXfeeTr77LNVWFjol1tRnkhiaQcAACxnaUuOJC1cuFBz587VuHHjNH78eC1btkwNDQ2aN2+eJOnyyy9X3759tXTpUoWFhWn06NGdXh8XFydJh+23UsftKjoeAwBgHctDzpw5c1RRUaHFixertLRU2dnZWrNmjbsz8t69e2W396iuQ+6Ox1X1LTLGyGazWVwRAAC9j80YY6wuwp9qa2sVGxurmpoaxcTE+OQzmlqdGv7bNZKkzUsmKTY8xCefAwBAb9GVv989q4mkhwgLCVK0o72RjAkBAQCwBiHHRxKjWNoBAAArEXJ85FDnY1pyAACwAiHHR9wtOQ205AAAYAVCjo+4l3aooyUHAAArEHJ8xH27ilmPAQCwBCHHRzpmPWZCQAAArEHI8ZHESFYiBwDASoQcH6ElBwAAaxFyfMTd8ZiWHAAALEHI8ZHkb0JObVObmtucFlcDAEDvQ8jxkZjwYAXb2xfmrGauHAAA/I6Q4yM2m+3QhIB1hBwAAPyNkONDHXPlVDJXDgAAfkfI8aFE9/pVtOQAAOBvhBwfSorsWImclhwAAPyNkONDSdGsRA4AgFUIOT6UGMmEgAAAWIWQ40MdfXIqaMkBAMDvCDk+xNIOAABYh5DjQx1DyKsYQg4AgN8Rcnwo6VtDyF0uY3E1AAD0LoQcH0r4puNxm8uotqnV4moAAOhdCDk+FBpsV0xYsCTmygEAwN8IOT7WMVdOJZ2PAQDwK0KOjyVFsrQDAABWIOT4mHslcm5XAQDgV4QcHzs0woqQAwCAPxFyfMzdktPA7SoAAPyJkONjHUs7VNbRkgMAgD8RcnwsuWNpB1pyAADwK0KOjyXSJwcAAEsQcnyso+Mx8+QAAOBfhBwf6+h4XN/cpqZWp8XVAADQexByfCzaEazQoPbTTL8cAAD8h5DjYzabTUkdw8gZYQUAgN8QcvzA3fm4gZADAIC/EHL8wD0hYB23qwAA8BdCjh+4R1jRkgMAgN8QcvygoyWHlcgBAPAfQo4fJEV2zJVDSw4AAP5CyPGDpGhacgAA8DdCjh8k0pIDAIDfEXL8gKUdAADwP0KOH3RMBljd0CyXy1hcDQAAvQMhxw/iI9tDjstIBw62WlwNAAC9AyHHD0KC7IqPCJFEvxwAAPyFkOMniVF0PgYAwJ8IOX6S+M0tKzofAwDgH4QcP0mK/maRTlpyAADwC0KOnyRFMiEgAAD+RMjxE/rkAADgX4QcP2FCQAAA/IuQ4yfulcgbaMkBAMAfCDl+ksTtKgAA/IqQ4ycdSzvQ8RgAAP8g5PhJR8fjxhanGlvaLK4GAIDAR8jxk8jQIIWFtJ9uWnMAAPA9Qo6f2Gw2JUbSLwcAAH8h5PhRR78chpEDAOB7hBw/6hhhxdIOAAD4HiHHjw7NlUNLDgAAvkbI8aOOEVYVdbTkAADga4QcP3LfrqIlBwAAnyPk+NGhCQFpyQEAwNcIOX7E0g4AAPgPIcePElnaAQAAvyHk+FHHZIDVjS1yuozF1QAAENgIOX6UEBkqm00yRqqm8zEAAD7VLULO/fffr8zMTIWFhSknJ0fr168/6rHPPfecxo0bp7i4OEVGRio7O1v/+Mc//Fht1wXZbUqI6Jgrh345AAD4kuUh5+mnn9bChQu1ZMkSbdy4UVlZWZo8ebLKy8uPeHxCQoJuueUWFRQU6JNPPtG8efM0b948vfrqq36uvGvolwMAgH9YHnLuueceXXnllZo3b55GjhypBx98UBEREXrkkUeOePxZZ52l888/XyNGjNDgwYN14403asyYMXr33Xf9XHnXMMIKAAD/sDTktLS0aMOGDcrLy3Pvs9vtysvLU0FBwQ++3hij/Px8bdu2TWecccYRj2lublZtbW2nzUqJ7pBDSw4AAL5kaciprKyU0+lUSkpKp/0pKSkqLS096utqamoUFRWl0NBQTZs2Tffdd5/OOeecIx67dOlSxcbGureMjAyvfodjlRjZsRI5LTkAAPiS5beruiI6OlqFhYX66KOPdMcdd2jhwoVat27dEY9dtGiRampq3FtRUZF/i/2O5GhWIgcAwB+CrfzwpKQkBQUFqaysrNP+srIypaamHvV1drtdQ4YMkSRlZ2dr69atWrp0qc4666zDjnU4HHI4HF6t+3h0tOTQ8RgAAN+ytCUnNDRUY8eOVX5+vnufy+VSfn6+cnNzPX4fl8ul5uae0TJCx2MAAPzD0pYcSVq4cKHmzp2rcePGafz48Vq2bJkaGho0b948SdLll1+uvn37aunSpZLa+9iMGzdOgwcPVnNzs15++WX94x//0IoVK6z8Gh7rGEJOx2MAAHzL8pAzZ84cVVRUaPHixSotLVV2drbWrFnj7oy8d+9e2e2HGpwaGhp07bXXat++fQoPD9fw4cP1z3/+U3PmzLHqKxyTjpacqoZmGWNks9ksrggAgMBkM8b0qkWUamtrFRsbq5qaGsXExPj98xtb2jRycfvEhVtum6woh+U5EwCAbq8rf7975OiqniwiNFgRoUGSGGEFAIAvEXIsQL8cAAB8j5BjAUZYAQDge4QcCyRGdkwISEsOAAC+QsixQFIUSzsAAOBrhBwLuIeRE3IAAPAZQo4F3B2PG7hdBQCArxByLODueFxHSw4AAL5CyLFAR0tOFS05AAD4DCHHAvTJAQDA9wg5FugIOV83tqrV6bK4GgAAAhMhxwJx4SGyf7Mu59fcsgIAwCcIORaw221KiOyY9ZiQAwCALxByLMKEgAAA+BYhxyLuzscNhBwAAHyBkGMR94SAddyuAgDAFwg5FnFPCEhLDgAAPkHIsYh7QkA6HgMA4BOEHIu4W3LoeAwAgE8QciySREsOAAA+RcixSGIkSzsAAOBLhByLJEUfmgzQGGNxNQAABB5CjkUSI9tvV7U4XaprbrO4GgAAAg8hxyJhIUGKcgRLol8OAAC+QMixEEs7AADgO4QcCyVG0fkYAABfIeRYqKNfDiuRAwDgfYQcCx0aYUVLDgAA3kbIsVBSJBMCAgDgK4QcC9GSAwCA7xByLHRo1mNacgAA8DZCjoU6ViKvbKAlBwAAbyPkWMi9EnkdIQcAAG8j5FioYzLA2qY2tbS5LK4GAIDAQsixUExYiILtNklSdQP9cgAA8CZCjoXsdtuhfjmMsAIAwKsIORbrGGFFyAEAwLsIORbraMlhGDkAAN5FyLFYchQtOQAA+EKXQk5RUZH27dvnfrx+/XotWLBADz30kNcK6y3cLTl0PAYAwKu6FHIuvfRSrV27VpJUWlqqc845R+vXr9ctt9yi22+/3asFBjrmygEAwDe6FHK2bNmi8ePHS5L+9a9/afTo0Xr//ff1xBNP6LHHHvNmfQEvsSPk0JIDAIBXdSnktLa2yuFo/+P8xhtv6LzzzpMkDR8+XCUlJd6rrhc41PGYlhwAALypSyFn1KhRevDBB/XOO+/o9ddf15QpUyRJxcXFSkxM9GqBgY6OxwAA+EaXQs7vf/97/eUvf9FZZ52lSy65RFlZWZKkl156yX0bC5759hByY4zF1QAAEDiCu/Kis846S5WVlaqtrVV8fLx7/1VXXaWIiAivFdcbJES2h5w2l1HtwTbFRoRYXBEAAIGhSy05Bw8eVHNzszvg7NmzR8uWLdO2bdvUp08frxYY6BzBQYoJa8+aFdyyAgDAa7oUcmbMmKG///3vkqQDBw4oJydHf/zjHzVz5kytWLHCqwX2Bh3DyOl8DACA93Qp5GzcuFETJkyQJD3zzDNKSUnRnj179Pe//1333nuvVwvsDdwhh2HkAAB4TZdCTmNjo6KjoyVJr732mi644ALZ7Xadcsop2rNnj1cL7A1YiRwAAO/rUsgZMmSIXnjhBRUVFenVV1/VpEmTJEnl5eWKiYnxaoG9waGQQ0sOAADe0qWQs3jxYt10003KzMzU+PHjlZubK6m9VedHP/qRVwvsDZKYKwcAAK/r0hDy2bNn6/TTT1dJSYl7jhxJmjhxos4//3yvFddbJNLxGAAAr+tSyJGk1NRUpaamulcj79evHxMBdlFS5KEJAQEAgHd06XaVy+XS7bffrtjYWA0YMEADBgxQXFycfve738nlcnm7xoCXFM3tKgAAvK1LLTm33HKL/va3v+nOO+/UaaedJkl69913deutt6qpqUl33HGHV4sMdIm05AAA4HVdCjmPP/64/vrXv7pXH5ekMWPGqG/fvrr22msJOceoo09OXXObmlqdCgsJsrgiAAB6vi7drqqurtbw4cMP2z98+HBVV1cfd1G9TUxYsEKD2v+nYEJAAAC8o0shJysrS8uXLz9s//LlyzVmzJjjLqq3sdls31qNnH45AAB4Q5duV911112aNm2a3njjDfccOQUFBSoqKtLLL7/s1QJ7i6Qoh0pqmuiXAwCAl3SpJefMM8/Ul19+qfPPP18HDhzQgQMHdMEFF+izzz7TP/7xD2/X2Ct0tOSU1jZZXAkAAIHBZowx3nqzzZs366STTpLT6fTWW3pdbW2tYmNjVVNT062WoPjja9t035vbNWFokv5xRY7V5QAA0K105e93l1py4H0XjcuQzSa981Wldlc2WF0OAAA9HiGnm8hIiNBZJyRLkp5av9fiagAA6PkIOd3IpTkDJEmrNuxTc1v3veUHAEBPcEyjqy644ILvff7AgQPHU0uvd/awZKXFhqmkpklrtpRqRnZfq0sCAKDHOqaWnNjY2O/dBgwYoMsvv9xXtQa84CC7Lj65vyTpiQ+5ZQUAwPE4ppacRx991Fd14BtzTs7QvW9+pfW7qvVVWZ2GpkRbXRIAAD0SfXK6mdTYMOWN6COJ1hwAAI5Htwg5999/vzIzMxUWFqacnBytX7/+qMc+/PDDmjBhguLj4xUfH6+8vLzvPb4nuuybDsjPbtyngy10QAYAoCssDzlPP/20Fi5cqCVLlmjjxo3KysrS5MmTVV5efsTj161bp0suuURr165VQUGBMjIyNGnSJO3fv9/PlfvO6UOS1D8hQnVNbfr3J8VWlwMAQI/k1RmPuyInJ0cnn3yye8FPl8uljIwMXX/99br55pt/8PVOp1Px8fFavny5R52eu+uMx9+1Yt0O/X7NF8rOiNML80+zuhwAACzV42Y8bmlp0YYNG5SXl+feZ7fblZeXp4KCAo/eo7GxUa2trUpISDji883Nzaqtre209QQXjuunkCCbCosOaMv+GqvLAQCgx7E05FRWVsrpdColJaXT/pSUFJWWlnr0Hr/+9a+Vnp7eKSh929KlSzsNc8/IyDjuuv0hKcqhKaPTJElPMgMyAADHzPI+Ocfjzjvv1MqVK/X8888rLCzsiMcsWrRINTU17q2oqMjPVXbdZTntc+a8uGm/6pvbLK4GAICexdKQk5SUpKCgIJWVlXXaX1ZWptTU1O997d13360777xTr732msaMGXPU4xwOh2JiYjptPUXOwAQNTo5UQ4tTL2wKnI7VAAD4g6UhJzQ0VGPHjlV+fr57n8vlUn5+vnJzc4/6urvuuku/+93vtGbNGo0bN84fpVrCZrO5h5M/8eFeWdxHHACAHsXy21ULFy7Uww8/rMcff1xbt27VNddco4aGBs2bN0+SdPnll2vRokXu43//+9/rt7/9rR555BFlZmaqtLRUpaWlqq+vt+or+NSsk/rJEWzX1pJabSo6YHU5AAD0GJaHnDlz5ujuu+/W4sWLlZ2drcLCQq1Zs8bdGXnv3r0qKSlxH79ixQq1tLRo9uzZSktLc2933323VV/Bp2IjQvSTMemSpCeZARkAAI9ZPk+Ov/WUeXK+bePer3XBA+/LEWzX+t/kKTYixOqSAADwqx43Tw4886OMOI1Ii1Fzm0vPbtxndTkAAPQIhJweoL0Dcvtw8ic+3EMHZAAAPEDI6SFm/qivIkODtKOiQR/uqra6HAAAuj1CTg8R5QjWedl9JbUPJwcAAN+PkNODdNyyWrOlRJX1zRZXAwBA90bI6UFG941VVkacWp1Gz2ygAzIAAN+HkNPDdLTmPPnhXrlcdEAGAOBoCDk9zPQx6YoOC9be6ka9u73S6nIAAOi2CDk9THhokGad1E9S+3ByAABwZIScHqjjltUbW8tVVttkcTUAAHRPhJweaGhKtMZnJsjpMnr6oyKrywEAoFsi5PRQl53S3prz1Pq9anO6LK4GAIDuh5DTQ00ZnaqEyFCV1DRp3bYKq8sBAKDbIeT0UI7gIF04lg7IAAAcDSGnB7tkfPstq3VfVqioutHiagAA6F4IOT1YZlKkJgxNkjHSyo9YzwoAgG8j5PRwl37TmvP0R/vUSgdkAADcCDk9XN7IFCVHO1RZ36zXPy+zuhwAALoNQk4PFxJk18UnZ0iS/l6w29piAADoRgg5AeDi8f0VbLfpg53V+vfmYqvLAQCgWyDkBIC+ceGaf/YQSdJvX9yi8jqWegAAgJATIK778RCNSo/RgcZWLXr2UxljrC4JAABLEXICREiQXfdclK3QILvyvyjXMxv2WV0SAACWIuQEkGGp0VpwzlBJ0u3//lzFBw5aXBEAANYh5ASYqyYM0o/6x6muuU2/fvYTblsBAHotQk6ACQ6y6+4Ls+QItuudryr1xIfMhAwA6J0IOQFocHKUfj1luCTp/17eqr1VrGsFAOh9CDkB6menZipnYIIaW5y66ZnNcrm4bQUA6F0IOQHKbrfp7guzFBEapPW7qvXo+7utLgkAAL8i5ASwjIQI3TJthCTprjVfaEdFvcUVAQDgP4ScAHfp+P6aMDRJzW0u/fJfm9XGSuUAgF6CkBPgbDab7po9RtFhwSosOqC/vL3T6pIAAPALQk4vkBYbrlunj5IkLXvjS31RWmtxRQAA+B4hp5e44KS+yhuRolan0cKnN6uljdtWAIDARsjpJWw2m/7vgtGKiwjR5yW1Wr52u9UlAQDgU4ScXqRPdJj+d+ZoSdL9a7frk30HrC0IAAAfIuT0Mj8Zk65pY9LkdBn98l+b1dTqtLokAAB8gpDTC/1uxmglRTn0VXm9/vT6l1aXAwCATxByeqGEyFAtveBESdJD7+zUhj3VFlcEAID3EXJ6qXNGpmjWSf1kjPTLf21WY0ub1SUBAOBVhJxebPH0kUqNCdPuqkbdtWab1eUAAOBVhJxeLDY8RHfNHiNJeuz93Xr98zKLKwIAwHsIOb3cGSck67Kc/pKka/65QS9s2m9xRQAAeAchB7r1vFGakZ2uNpfRgqcL9eh7u6wuCQCA40bIgUKC7PrTRdn62amZkqTb/v25/vjaNhljrC0MAIDjQMiBJMlut2nJ9JG6adIJkqT73tyuW17YIqeLoAMA6JkIOXCz2Wy67sdDdcf5o2WzSU9+uFfXPblRzW3MigwA6HkIOTjMZTkDdP+lJyk0yK5XtpRq3qMfqb6ZeXQAAD0LIQdHdO6JaXps3smKDA3S+zuqdMlDH6iyvtnqsgAA8BghB0d16pAkPXXVKUqIDNWn+2t00YMFKqputLosAAA8QsjB9xrTL07PXJ2rvnHh2lnZoNkPvq9tpXVWlwUAwA8i5OAHDUqO0rPXnKoTUqJUVtusi/5SwKKeAIBuj5ADj6TGhulf/y9XJ/WPU83BVl321w+1dlu51WUBAHBUhBx4LC4iVP/8eY7OGpasplaXrnz8Y5aBAAB0W4QcHJOI0GA9fPk4zfzWMhCPvMsyEACA7oeQg2MWEmTXPRdla95pmZKk21d/rjtf+YLZkQEA3QohB11it9u0+CeHloF48K0dmvMXhpgDALoPQg66rGMZiGVzshXlCNbHe77W1D+/o2c37GNxTwCA5Qg5OG4zf9RXr9w4QSdnxqu+uU2/XLVZ1z21STWNrVaXBgDoxQg58IqMhAitvCpXN006QcF2m/7zSYmm/Pltvb+90urSAAC9FCEHXhNkb7999ew1p2pgUqRKapp02d8+1P+9vJWVzAEAfkfIgddlZcTpPzecrkvG95cx0kNv79TM+9/Xl2UsBwEA8B9CDnwiIjRYSy84UQ9fPk4JkaHaWlKr6fe9q8fe20WnZACAXxBy4FPnjEzRmgUTdOYJyWpuc+nWf3+unz36kcprm6wuDQAQ4Ag58Lk+0WF6bN7Juu28UXIE2/XWlxWa8ud39NpnpVaXBgAIYIQc+IXNZtPcUzO1+vrTNTItRtUNLbrqHxt087OfqKG5zeryAAABiJADvxqaEq3n55+q/3fmINls0sqPijTt3ne0ae/XVpcGAAgwhBz4nSM4SIumjtATP89RWmyYdlc1ataK9/XH17appc1ldXkAgABByIFlTh2cpDU3nqEZ2elyGem+N7fr/AfeY6g5AMArLA85999/vzIzMxUWFqacnBytX7/+qMd+9tlnmjVrljIzM2Wz2bRs2TL/FQqfiI0I0Z8v/pHuv/QkxUWE6LPiWv3kvnf18Ns7WdUcAHBcLA05Tz/9tBYuXKglS5Zo48aNysrK0uTJk1VeXn7E4xsbGzVo0CDdeeedSk1N9XO18KVpY9L02oIzdPawZLW0uXTHy1t1ycMfsKo5AKDLbMbCmdlycnJ08skna/ny5ZIkl8uljIwMXX/99br55pu/97WZmZlasGCBFixYcEyfWVtbq9jYWNXU1CgmJqarpcNHjDFa+VGR/nf152pocSoyNEiLp4/UReMyZLPZrC4PAGCRrvz9tqwlp6WlRRs2bFBeXt6hYux25eXlqaCgwGuf09zcrNra2k4bui+bzaZLxvfXKzeeoZMz49XQ4tSvn/1UP3/8Y5XXMYEgAMBzloWcyspKOZ1OpaSkdNqfkpKi0lLvTRK3dOlSxcbGureMjAyvvTd8p39i+6rmvzl3uEKD7Mr/olyT//S2Xv60xOrSAAA9hOUdj31t0aJFqqmpcW9FRUVWlwQPBdltuuqMwfr3NxMIft3Yqmuf2KgFKzepprHV6vIAAN2cZSEnKSlJQUFBKisr67S/rKzMq52KHQ6HYmJiOm3oWYalRuuF+afpurOHyG6TXigs1uRlb+udryqsLg0A0I1ZFnJCQ0M1duxY5efnu/e5XC7l5+crNzfXqrLQTYUG23XT5GF65ppTNTApUqW1Tfrp39Zr8Ytb1NjCshAAgMNZertq4cKFevjhh/X4449r69atuuaaa9TQ0KB58+ZJki6//HItWrTIfXxLS4sKCwtVWFiolpYW7d+/X4WFhdq+fbtVXwF+dlL/eP3nhtN1ee4ASdLfC/bonHve1t2vbtNXTCIIAPgWS4eQS9Ly5cv1hz/8QaWlpcrOzta9996rnJwcSdJZZ52lzMxMPfbYY5Kk3bt3a+DAgYe9x5lnnql169Z59HkMIQ8c73xVoV+t+kSltYdGXQ1PjdZ52emaPiZdGQkRFlYHAPCmrvz9tjzk+BshJ7AcbHHqtc9L9e/NxXrrywq1Og9dzif1j9N5WemaNiZdydEOC6sEABwvQo4HCDmB60Bji17ZUqqXCov1wa4qdVzZdpt02pAkTc9K1+RRqYoND7G2UADAMSPkeICQ0zuU1TZp9SclemlzsTYXHXDvDw2y66xhyTovO10Th6coPDTIuiIBAB4j5HiAkNP77K5s0L83F+ulzcX6qrzevT8yNEiTRqVq9th+OnVwIstGAEA3RsjxACGn9zLG6IvSOr20uVgvFRZr/4GD7ucGJUXq0pz+mj22n+IiQi2sEgBwJIQcDxByILUHno17D+i5jfv0wqb9amhxSpIcwXb9ZEy6/uuU/srOiKN1BwC6CUKOBwg5+K765ja9WLhf//xgr7aWHFrAdWRajP7rlAGakZ2uSEewhRUCAAg5HiDk4GiMMdpUdED//GCPVn9SopY2lyQpyhGs83/UV5ed0l/DU7lmAMAKhBwPEHLgiQONLXpmwz498eFe7apscO8fNyBe/3XKAE0ZnaqwEEZmAYC/EHI8QMjBsTDG6P0dVXriwz167bMytbna/7kkRIbqwrH9dMn4/spMirS4SgAIfIQcDxBy0FXltU1a+VGRnlq/VyU1h5aSGNMvVueemKZzR6epfyJLSQCALxByPEDIwfFqc7q0dluFnvhwj97+skKub/0LGt03RueemKZpJ6ZpQCItPADgLYQcDxBy4E2V9c169bNSvfxpiQp2VHUKPCPTYjRtTJrOPTFNA7mlBQDHhZDjAUIOfKWqvlmvfV6mlz8t0fs7quT8VuIZkRajaSemauqJaRqcHGVhlQDQMxFyPEDIgT9UN7Totc9K9Z8jBJ7hqdHtfXhOTNOQPgQeAPAEIccDhBz429cNLXr98zL959MSvbe90j1CS2oPPDOy++q87HT1jQu3sEoA6N4IOR4g5MBKBxpb3Le03v2qc+AZn5mg87LTde6JaUqIZP0sAPg2Qo4HCDnoLg40tuiVLaV6sXC/PtxVrY5/icF2m844IVkzstN1zsgURYSypAQAEHI8QMhBd1RSc1D/3lysFwuL9VnxofWzwkOCNGlUimZkp2vC0GSFBNktrBIArEPI8QAhB93d9vI6vVRYrBc3F2tPVaN7f3xEiKaNSdOM7L4a2z9edjsrpAPoPQg5HiDkoKcwxqiw6IBeLCzW6k+KVVnf4n6ub1y4po5O1Qkp0cpIiFD/xAilxoQpiOADIEARcjxAyEFP1OZ0qWBnlV4sLNaaLaWqb2477JiQIJv6xoW3h55vbR0hKCYsxILKAcA7CDkeIOSgp2tqdSp/a7kKdlZqb/VBFVU3at/XjWp1fv8/5biIkEOhJyFCJ/aN1elDkwg/AHoEQo4HCDkIRE6XUVltk/ZWN2pvdaOKvvnZ8fu3b3V9W5DdprED4nX2sD46a1iyhqdGy2bjlheA7oeQ4wFCDnqjhuY2FX3dqL1V7cFnd1WDCnZUaUdFQ6fjUmPCdPbwZJ01rI9OG5KkKAfD1wF0D4QcDxBygEOKqhu1blu51m6r0Ps7KtXU6nI/FxJk08mZCe5WniF9omjlAWAZQo4HCDnAkTW1OvXhrmqt/aJc67aVa/e3hq9L7SO6zh6erLOH9VHu4EQmKQTgV4QcDxByAM/sqmzQum3lWretQgU7q9TSdqiVJzTYrkkjU3ThuAydPiSJoesAfI6Q4wFCDnDsDrY4VbCzUmu/qNDabeXa9/VB93OpMWG64KS+mj22nwYls6o6AN8g5HiAkAMcH2OMPiuu1TMb9umFwv060Njqfm7sgHjNHttPPxmTpmiGpgPwIkKOBwg5gPc0tzn15tZyrdqwT+u2latjUfWwELumjErVheMylDsokSUoABw3Qo4HCDmAb5TXNun5Tfu1asM+bS+vd+/vGxeuWSf11ayx/TQgMdLCCgH0ZIQcDxByAN8yxmjzvhqt+rhIL20uVl3ToSUoxg9M0Oyx/TR1dCq3swAcE0KOBwg5gP80tTr12udlembDPr3zVYU6/msTZLdpTL9YnTIoUbmDEjUuM54h6QC+FyHHA4QcwBolNQf13Mb9enbjPu38zkzLwXabsjLidMqgBOUOStLYAfEKDw2yqFIA3REhxwOEHMB6+75u1Ac7q/XBzioV7KjS/gMHOz0fEmRTdkacu6XnpAHxCgsh9AC9GSHHA4QcoPspqm5Uwc4qfbCzSh/sqFJxTVOn50OD7O2hZ3CixmcmKDnaoYjQIEU6ghURGiRHsJ0lJ4AAR8jxACEH6N6MMSqqPtjeyvNNS09pbdP3vibYbusUeqIcwYoIDVakI8j9MzI0WBGOYEU7ghUdFqyY8BDFhIUoJjz4m58hig4LVkiQ3U/fFMCx6Mrfb3r6AehWbDab+idGqH9ihC46OUPGGO2pamxv5dlZpc37alR7sFUNLW3uBUXbXEa1TW2q/dZIrq4KDwk6LPh0hKG48FBlZcQpd3AiK7QDPQAtOQB6LKfLqLGlTQ3NTjW0tKmx2an65rb2fS1ONTa3ffO4/fmG5jbVN7WprqlNtU2tqj3Y8bNVDS1Ojz832G7TSQPideYJyTpjaLJGpccw4SHgY9yu8gAhB8CRtDldqm9u6xR8aps6/15R16T3d1Rpz3dWaE+IDNXpQ5J0xgnJmjA0SSkxYRZ9CyBwEXI8QMgBcLz2VDXo7a8q9faXFSrYUaX65s63yYanRrsDz8mZCYwMA7yAkOMBQg4Ab2p1urRp7wG9/WWF3vmqQp/sr9G3/6saFmJXzsBETRiapDH94mSMkdNl1OZq/9nqdHV63P7Tdeixs/2nyxgNTIrUSQPilRTlsO4LAxYh5HiAkAPAl6obWvTu9kq982WF3v6qQmW1zV7/jMzECJ00IF7jBiRo7IB4De0TRZ8gBDxCjgcIOQD8xRijr8rr9faXFXrrywoVVTcqyG5TsN3e/jPI9s1jW+f9HY+DbAqy2xVst8lljLaW1OrLsvrDPic6LFgn9Y/X2AHtW1ZGHKO/EHAIOR4g5ADoyWoaW7Wp6Gtt2NO+FRYdUON3RobZbdKItBh36Dmpf7z6xYczYSJ6NEKOBwg5AAJJm9OlL0rr3KFnw56vD1smQ5KSox0anR6jUemxGpUeo5HpMeqfEEHwQY9ByPEAIQdAoCupOaiNew60h569X+uz/TVqcx3+n/poR7BGpMdo1LfCz5A+Ucz6jG6JkOMBQg6A3uZgi1Ofl9Tos+Jafba/Vp+X1GpbaZ1anK7Djg0NsuuE1CiNSovVqL7tAWh4aowi6eMDixFyPEDIAYD2oe/by+vbg09xewDaWlyruubDl8aw2aT02HD1jQ9Xv/hw9YuPUMY3P/vFhystNkzB3az1p9XpUsmBJlU1NGtInyhFh4VYXRKOEyHHA4QcADiyjsVRO0JPx8/yuu8fBh9ktyk1Jkz94sOVkRDhDkL9vglFqTHeD0HGGH3d2Kq91Y3aW92oom+2jsfFBw6q4w6dzSad0CdaP+of980WryHJDLvvaQg5HiDkAMCxqaxvdgeJfV8f/GZr1P5vfj/Sba/vigwNUlRYsCIdwYpyBCsyNFhRYd/87ghSlCNEUY4g9/Pt+9u32oOtKvq6UXurDoWYfV8fPGym6e9yBNsVEx6iiiOEtGhHsLIy4tzBJzsjXgmRoV0+R99mjFFji1ONLU4lRYX2mM7dxhhtLanTm1+UqaSmSZmJkRqUHKnByVHqFx9ueWsdIccDhBwA8B6Xy6iivln7vu4cgDp+3+9hCOqq1Jgw9U+IUL+EcPVPiOi0JUU5ZLfbVF7XpMK9B7Sp6IA27f1am4tqdLD18AVZMxMj9KP+8e3BJyNew9OiFRJklzFGdc1tqq5vUVVDi6rqm1Xd0P579Tdb++/N7mOa29q/85h+sbpywiBNHZ1qeUg4koMtTr2/o1L5X5Rr7RflKqlpOuJxIUE2DUiM1ODkSA1KjtLg5Kj2AJQUpdgI/9wKJOR4gJADAP7jchlVN7aovql9RfiGb1aGb//dqfrmVtU3O90rxNd/a7X4+uY2NbS0KTI0+FB4SYxQRnyE+7ZYV9YFa3O69GVZvTYVfa1Ne9uDz46KhsOOcwTbFRcRouqGFrU6j+9PZd+4cM07LVMXj+9v+USNxQcO6s0vyvXmF+V6b3ulO5BJ7cuQnD4kWSekRGlPdaN2VjRoZ0V9p2O+KykqVIOSotytPoOSIzWkT5QGJEZ6tW5CjgcIOQCA76ppbFXhvvbA0xF8aps63w6LCA1SQmSoEiND239GOdy/tz8OVULkoX1NrU7984O9+nvBblU1tEhqv012aU5//ey0TKXFhvvluzldRpv3HdCbW8uV/0W5tpbUdnq+b1y4fjy8j348oo9yByUeFhxdLqPimoPaWdGgHRX1nX6W1h655Wdonyi9vvBMr34PQo4HCDkAgB/ichntqmpQY7NTCVHtwaarq8k3tTr1/Kb9+us7O90tRsF2m34yJk0/nzBIo/vGerN0SVJdU6ve+apS+VvLtW5buTtkSe0zYp/UP15nD++jiSP6aFhKdJf7DdU3t2lXRYN2VtZrR3m9dlQ2aGdFg05IidKfL/6Rt76OJEKORwg5AAAruFxG674s10Nv79QHO6vd+08dnKgrzxiks05I7lLYqKxv1ufFtdpa0j4H0ufFtdpRUa9vz/8YHRasM09I1sQRfXTmCX281snanwg5HiDkAACstmV/jR5+Z6dWf1Ii5zdpZGifKP18wkDNyO57xFYjp8tod1VDe5gpPhRojjbEf1BypCYO76MfD0/RuMz4Hj+TNSHHA4QcAEB3sf/AQT323i49tb7IPSQ+KcqhubkDlDs4UdvK6tyB5ouSuiOOCrPZpIGJkRqR1r4m2ci0GI1Ii1FqbJi/v45PEXI8QMgBAHQ3tU2tenp9kR59b5eKjzKMW2of/TQ8tT3MjEhrDzTDU6N7xbIbhBwPEHIAAN1Vq9Ollz8t0aPv7VZJzcHDAs3ApEgF9dKZmrvy9zvwox8AAD1ESJBdM7L7akZ2X6tLCQg9uxcSAADAURByAABAQCLkAACAgETIAQAAAYmQAwAAAhIhBwAABCRCDgAACEiEHAAAEJC6Rci5//77lZmZqbCwMOXk5Gj9+vXfe/yqVas0fPhwhYWF6cQTT9TLL7/sp0oBAEBPYXnIefrpp7Vw4UItWbJEGzduVFZWliZPnqzy8vIjHv/+++/rkksu0RVXXKFNmzZp5syZmjlzprZs2eLnygEAQHdm+dpVOTk5Ovnkk7V8+XJJksvlUkZGhq6//nrdfPPNhx0/Z84cNTQ0aPXq1e59p5xyirKzs/Xggw/+4OexdhUAAD1PV/5+W9qS09LSog0bNigvL8+9z263Ky8vTwUFBUd8TUFBQafjJWny5MlHPb65uVm1tbWdNgAAEPgsDTmVlZVyOp1KSUnptD8lJUWlpaVHfE1paekxHb906VLFxsa6t4yMDO8UDwAAujXL++T42qJFi1RTU+PeioqKrC4JAAD4QbCVH56UlKSgoCCVlZV12l9WVqbU1NQjviY1NfWYjnc4HHI4HO7HHV2QuG0FAEDP0fF3+1i6ElsackJDQzV27Fjl5+dr5syZkto7Hufn5+u666474mtyc3OVn5+vBQsWuPe9/vrrys3N9egz6+rqJInbVgAA9EB1dXWKjY316FhLQ44kLVy4UHPnztW4ceM0fvx4LVu2TA0NDZo3b54k6fLLL1ffvn21dOlSSdKNN96oM888U3/84x81bdo0rVy5Uh9//LEeeughjz4vPT1dRUVFio6Ols1m8+p3qa2tVUZGhoqKihi5dQw4b8eOc9Y1nLeu4bx1Deft2H3fOTPGqK6uTunp6R6/n+UhZ86cOaqoqNDixYtVWlqq7OxsrVmzxt25eO/evbLbD3UdOvXUU/Xkk0/qf/7nf/Sb3/xGQ4cO1QsvvKDRo0d79Hl2u139+vXzyXfpEBMTwwXdBZy3Y8c56xrOW9dw3rqG83bsjnbOPG3B6WD5PDmBhDl4uobzduw4Z13DeesazlvXcN6OnbfPWcCPrgIAAL0TIceLHA6HlixZ0mk0F34Y5+3Ycc66hvPWNZy3ruG8HTtvnzNuVwEAgIBESw4AAAhIhBwAABCQCDkAACAgEXIAAEBAIuR4yf3336/MzEyFhYUpJydH69evt7qkbu3WW2+VzWbrtA0fPtzqsrqdt99+W9OnT1d6erpsNpteeOGFTs8bY7R48WKlpaUpPDxceXl5+uqrr6wpthv5ofP2s5/97LDrb8qUKdYU200sXbpUJ598sqKjo9WnTx/NnDlT27Zt63RMU1OT5s+fr8TEREVFRWnWrFmHrSXY23hy3s4666zDrrerr77aooq7hxUrVmjMmDHuSf9yc3P1yiuvuJ/31rVGyPGCp59+WgsXLtSSJUu0ceNGZWVlafLkySovL7e6tG5t1KhRKikpcW/vvvuu1SV1Ow0NDcrKytL9999/xOfvuusu3XvvvXrwwQf14YcfKjIyUpMnT1ZTU5OfK+1efui8SdKUKVM6XX9PPfWUHyvsft566y3Nnz9fH3zwgV5//XW1trZq0qRJamhocB/zi1/8Qv/+97+1atUqvfXWWyouLtYFF1xgYdXW8+S8SdKVV17Z6Xq76667LKq4e+jXr5/uvPNObdiwQR9//LF+/OMfa8aMGfrss88kefFaMzhu48ePN/Pnz3c/djqdJj093SxdutTCqrq3JUuWmKysLKvL6FEkmeeff9792OVymdTUVPOHP/zBve/AgQPG4XCYp556yoIKu6fvnjdjjJk7d66ZMWOGJfX0FOXl5UaSeeutt4wx7ddWSEiIWbVqlfuYrVu3GkmmoKDAqjK7ne+eN2OMOfPMM82NN95oXVE9RHx8vPnrX//q1WuNlpzj1NLSog0bNigvL8+9z263Ky8vTwUFBRZW1v199dVXSk9P16BBg3TZZZdp7969VpfUo+zatUulpaWdrr3Y2Fjl5ORw7Xlg3bp16tOnj4YNG6ZrrrlGVVVVVpfUrdTU1EiSEhISJEkbNmxQa2trp+tt+PDh6t+/P9fbt3z3vHV44oknlJSUpNGjR2vRokVqbGy0orxuyel0auXKlWpoaFBubq5XrzXLF+js6SorK+V0Ot0LinZISUnRF198YVFV3V9OTo4ee+wxDRs2TCUlJbrttts0YcIEbdmyRdHR0VaX1yOUlpZK0hGvvY7ncGRTpkzRBRdcoIEDB2rHjh36zW9+o6lTp6qgoEBBQUFWl2c5l8ulBQsW6LTTTnMvflxaWqrQ0FDFxcV1Opbr7ZAjnTdJuvTSSzVgwAClp6frk08+0a9//Wtt27ZNzz33nIXVWu/TTz9Vbm6umpqaFBUVpeeff14jR45UYWGh1641Qg4sMXXqVPfvY8aMUU5OjgYMGKB//etfuuKKKyysDL3BxRdf7P79xBNP1JgxYzR48GCtW7dOEydOtLCy7mH+/PnasmUL/eSO0dHO21VXXeX+/cQTT1RaWpomTpyoHTt2aPDgwf4us9sYNmyYCgsLVVNTo2eeeUZz587VW2+95dXP4HbVcUpKSlJQUNBhvb7LysqUmppqUVU9T1xcnE444QRt377d6lJ6jI7ri2vv+A0aNEhJSUlcf5Kuu+46rV69WmvXrlW/fv3c+1NTU9XS0qIDBw50Op7rrd3RztuR5OTkSFKvv95CQ0M1ZMgQjR07VkuXLlVWVpb+/Oc/e/VaI+Qcp9DQUI0dO1b5+fnufS6XS/n5+crNzbWwsp6lvr5eO3bsUFpamtWl9BgDBw5Uampqp2uvtrZWH374IdfeMdq3b5+qqqp69fVnjNF1112n559/Xm+++aYGDhzY6fmxY8cqJCSk0/W2bds27d27t1dfbz903o6ksLBQknr19XYkLpdLzc3N3r3WvNs3undauXKlcTgc5rHHHjOff/65ueqqq0xcXJwpLS21urRu65e//KVZt26d2bVrl3nvvfdMXl6eSUpKMuXl5VaX1q3U1dWZTZs2mU2bNhlJ5p577jGbNm0ye/bsMcYYc+edd5q4uDjz4osvmk8++cTMmDHDDBw40Bw8eNDiyq31feetrq7O3HTTTaagoMDs2rXLvPHGG+akk04yQ4cONU1NTVaXbplrrrnGxMbGmnXr1pmSkhL31tjY6D7m6quvNv379zdvvvmm+fjjj01ubq7Jzc21sGrr/dB52759u7n99tvNxx9/bHbt2mVefPFFM2jQIHPGGWdYXLm1br75ZvPWW2+ZXbt2mU8++cTcfPPNxmazmddee80Y471rjZDjJffdd5/p37+/CQ0NNePHjzcffPCB1SV1a3PmzDFpaWkmNDTU9O3b18yZM8ds377d6rK6nbVr1xpJh21z5841xrQPI//tb39rUlJSjMPhMBMnTjTbtm2ztuhu4PvOW2Njo5k0aZJJTk42ISEhZsCAAebKK6/s9f+n5EjnS5J59NFH3cccPHjQXHvttSY+Pt5ERESY888/35SUlFhXdDfwQ+dt79695owzzjAJCQnG4XCYIUOGmF/96lempqbG2sIt9t///d9mwIABJjQ01CQnJ5uJEye6A44x3rvWbMYY08WWJQAAgG6LPjkAACAgEXIAAEBAIuQAAICARMgBAAABiZADAAACEiEHAAAEJEIOAAAISIQcAL2SzWbTCy+8YHUZAHyIkAPA7372s5/JZrMdtk2ZMsXq0gAEkGCrCwDQO02ZMkWPPvpop30Oh8OiagAEIlpyAFjC4XAoNTW10xYfHy+p/VbSihUrNHXqVIWHh2vQoEF65plnOr3+008/1Y9//GOFh4crMTFRV111lerr6zsd88gjj2jUqFFyOBxKS0vTdddd1+n5yspKnX/++YqIiNDQoUP10ksvuZ/7+uuvddlllyk5OVnh4eEaOnToYaEMQPdGyAHQLf32t7/VrFmztHnzZl122WW6+OKLtXXrVklSQ0ODJk+erPj4eH300UdatWqV3njjjU4hZsWKFZo/f76uuuoqffrpp3rppZc0ZMiQTp9x22236aKLLtInn3yic889V5dddpmqq6vdn//555/rlVde0datW7VixQolJSX57wQAOH7eW1MUADwzd+5cExQUZCIjIzttd9xxhzGmfWXnq6++utNrcnJyzDXXXGOMMeahhx4y8fHxpr6+3v38f/7zH2O3292riaenp5tbbrnlqDVIMv/zP//jflxfX28kmVdeecUYY8z06dPNvHnzvPOFAViCPjkALHH22WdrxYoVnfYlJCS4f8/Nze30XG5urgoLCyVJW7duVVZWliIjI93Pn3baaXK5XNq2bZtsNpuKi4s1ceLE761hzJgx7t8jIyMVExOj8vJySdI111yjWbNmaePGjZo0aZJmzpypU089tUvfFYA1CDkALBEZGXnY7SNvCQ8P9+i4kJCQTo9tNptcLpckaerUqdqzZ49efvllvf7665o4caLmz5+vu+++2+v1AvAN+uQA6JY++OCDwx6PGDFCkjRixAht3rxZDQ0N7uffe+892e12DRs2TNHR0crMzFR+fv5x1ZCcnKy5c+fqn//8p5YtW6aHHnrouN4PgH/RkgPAEs3NzSotLe20Lzg42N25d9WqVRo3bpxOP/10PfHEE1q/fr3+9re/SZIuu+wyLVmyRHPnztWtt96qiooKXX/99frpT3+qlJQUSdKtt96qq6++Wn369NHUqVNVV1en9957T9dff71H9S1evFhjx47VqFGj1NzcrNWrV7tDFoCegZADwBJr1qxRWlpap33Dhg3TF198Ial95NPKlSt17bXXKi0tTU899ZRGjhwpSYqIiNCrr76qG2+8USeffLIiIiI0a9Ys3XPPPe73mjt3rpqamvSnP/1JN910k5KSkjR79myP6wsNDdWiRYu0e/duhYeHa8KECVq5cqUXvjkAf7EZY4zVRQDAt9lsNj3//POaOXOm1aUA6MHokwMAAAISIQcAAAQk+uQA6Ha4iw7AG2jJAQAAAYmQAwAAAhIhBwAABCRCDgAACEiEHAAAEJAIOQAAICARcgAAQEAi5AAAgIBEyAEAAAHp/wPcA0mX0rR8DgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def plot_history(hist):\n", + " plt.plot(hist)\n", + " plt.title(\"Model Loss\")\n", + " plt.xlabel(\"Epochs\")\n", + " plt.ylabel(\"Loss\")\n", + " plt.show()\n", + " \n", + "plot_history(hist);" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "7f30a5bf-caca-44fc-8d5a-8ed8863600e6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training accuracy: 0.9894897959183674\n", + "Test accuracy: 0.9495238095238095\n" + ] + } + ], + "source": [ + "train_predictions = np.argmax(NN.predict(train_data), axis = 1)\n", + "print(\"Training accuracy: \", accuracy_score(train[\"label\"].to_numpy(), train_predictions))\n", + "test_predictions = np.argmax(NN.predict(test_data), axis = 1)\n", + "print(\"Test accuracy: \", accuracy_score(test[\"label\"].to_numpy(), test_predictions))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "5d3ea10a-9a1a-4f7b-8dad-3a112b3e5add", + "metadata": {}, + "outputs": [], + "source": [ + "def predict_image(index):\n", + " current_image = test_data[index, :]\n", + " prediction = np.argmax(NN.predict(current_image), axis = 1)\n", + " label = test[\"label\"].to_numpy()[index]\n", + " print(\"Prediction: \", prediction)\n", + " print(\"Label: \", label)\n", + " \n", + " current_image = current_image.reshape((28, 28)) * 255\n", + " plt.gray()\n", + " plt.imshow(current_image, interpolation = \"nearest\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e143b0a5-0cf2-43b7-894c-8497e17b4461", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prediction: [5]\n", + "Label: 5\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbGklEQVR4nO3df2zU9R3H8dfx60Btj5XSXisFCyolIixW6Bq1w9G0dIvhVxZxbsPFYGDFDDp1qVHROdeNJRtzYWq2BYYDf5DxI5qNDKstuhUcCCEGaCjpRgm0CEnvoJVC2s/+IN48KeD3uOv7en0+kk/Cfb/fdz9vvnzpi+/dl099zjknAAD62CDrBgAAAxMBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABNDrBv4op6eHh0/flxpaWny+XzW7QAAPHLO6cyZM8rNzdWgQZe/z0m6ADp+/Ljy8vKs2wAAXKOWlhaNGTPmsvuT7i24tLQ06xYAAHFwte/nCQug1atX66abbtLw4cNVVFSkDz/88EvV8bYbAKSGq30/T0gAvfHGG6qqqtKKFSv00UcfaerUqSovL9fJkycTMR0AoD9yCTB9+nRXWVkZed3d3e1yc3NdTU3NVWtDoZCTxGAwGIx+PkKh0BW/38f9Duj8+fPas2ePSktLI9sGDRqk0tJSNTQ0XHJ8V1eXwuFw1AAApL64B9CpU6fU3d2t7OzsqO3Z2dlqbW295PiamhoFAoHI4Ak4ABgYzJ+Cq66uVigUioyWlhbrlgAAfSDu/w8oMzNTgwcPVltbW9T2trY2BYPBS473+/3y+/3xbgMAkOTifgc0bNgwFRYWqra2NrKtp6dHtbW1Ki4ujvd0AIB+KiErIVRVVWnhwoW68847NX36dK1atUodHR36wQ9+kIjpAAD9UEIC6P7779cnn3yiZ555Rq2trfrqV7+qbdu2XfJgAgBg4PI555x1E58XDocVCASs2wAAXKNQKKT09PTL7jd/Cg4AMDARQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMDHEugHgagoKCjzX3HnnnTHNdfDgQc81ixYtimkur0pKSjzXTJw4Maa5GhsbPdc45zzXzJ8/33PNoUOHPNcgOXEHBAAwQQABAEzEPYCeffZZ+Xy+qBHLWygAgNSWkM+AbrvtNr3zzjv/n2QIHzUBAKIlJBmGDBmiYDCYiC8NAEgRCfkM6PDhw8rNzdX48eP14IMP6ujRo5c9tqurS+FwOGoAAFJf3AOoqKhIa9eu1bZt2/TSSy+publZ99xzj86cOdPr8TU1NQoEApGRl5cX75YAAEko7gFUUVGhb3/725oyZYrKy8v1t7/9Te3t7XrzzTd7Pb66ulqhUCgyWlpa4t0SACAJJfzpgJEjR+rWW29VU1NTr/v9fr/8fn+i2wAAJJmE/z+gs2fP6siRI8rJyUn0VACAfiTuAfTYY4+pvr5e//nPf/Svf/1Lc+fO1eDBg/XAAw/EeyoAQD8W97fgjh07pgceeECnT5/W6NGjdffdd2vnzp0aPXp0vKcCAPRjPhfLCoIJFA6HFQgErNtAEvn3v//tueaOO+6Iaa5Y/jr4fL6Umqcv59q7d6/nmmnTpnmugY1QKKT09PTL7mctOACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYS/gPpgGt16tQpzzWxLKYZq76aK5Z5Dh06FNNcnZ2dnmsKCgo818TyZ4vUwR0QAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEq2Ej6X3ve9/zXNPW1hbTXAcOHPBcs3Dhwpjm6gusho1kxh0QAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEyxGiqQXy4KVf/jDH2Kayznnueajjz6Kaa5UE+vCpxi4uAMCAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABggsVIkZImTZoUU92BAwfi3AmAy+EOCABgggACAJjwHEA7duzQfffdp9zcXPl8Pm3ZsiVqv3NOzzzzjHJycjRixAiVlpbq8OHD8eoXAJAiPAdQR0eHpk6dqtWrV/e6f+XKlXrxxRf18ssva9euXbr++utVXl6uc+fOXXOzAIDU4fkhhIqKClVUVPS6zzmnVatW6amnntLs2bMlSevWrVN2dra2bNmiBQsWXFu3AICUEdfPgJqbm9Xa2qrS0tLItkAgoKKiIjU0NPRa09XVpXA4HDUAAKkvrgHU2toqScrOzo7anp2dHdn3RTU1NQoEApGRl5cXz5YAAEnK/Cm46upqhUKhyGhpabFuCQDQB+IaQMFgUJLU1tYWtb2trS2y74v8fr/S09OjBgAg9cU1gPLz8xUMBlVbWxvZFg6HtWvXLhUXF8dzKgBAP+f5KbizZ8+qqakp8rq5uVn79u1TRkaGxo4dq2XLlulnP/uZbrnlFuXn5+vpp59Wbm6u5syZE8++AQD9nOcA2r17t+69997I66qqKknSwoULtXbtWj3xxBPq6OjQI488ovb2dt19993atm2bhg8fHr+uAQD9ns8556yb+LxwOKxAIGDdBvq5np6emOpeeeUVzzWPPfaY55qCggLPNbE4ePBgTHWdnZ1x7gQDUSgUuuLn+uZPwQEABiYCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAnPP44B6A9iXeR90qRJnmvWrVvnuWb27Nmea3w+n+eazZs3e66RpKeeespzzaFDh2KaCwMXd0AAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBM+FysqzYmSDgcViAQsG4DSaSkpMRzTV1dXUxzxfLXIZZFQpN5nljnqqqq8lyzatUqzzXoP0KhkNLT0y+7nzsgAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJliMFEkvMzPTc01bW1tMc8Xy1+H06dOea37+8597rnn//fc918Tq1Vdf9VwzatQozzUzZszwXHPo0CHPNbDBYqQAgKREAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABIuRIiX19PTEVPfJJ594rlm8eLHnms2bN3uu6Utz5871XLNu3TrPNbEsLDpt2jTPNbDBYqQAgKREAAEATHgOoB07dui+++5Tbm6ufD6ftmzZErX/oYceks/nixqzZs2KV78AgBThOYA6Ojo0depUrV69+rLHzJo1SydOnIiM11577ZqaBACkniFeCyoqKlRRUXHFY/x+v4LBYMxNAQBSX0I+A6qrq1NWVpYmTpyoJUuWXPFHFnd1dSkcDkcNAEDqi3sAzZo1S+vWrVNtba1++ctfqr6+XhUVFeru7u71+JqaGgUCgcjIy8uLd0sAgCTk+S24q1mwYEHk17fffrumTJmiCRMmqK6uTjNnzrzk+OrqalVVVUVeh8NhQggABoCEP4Y9fvx4ZWZmqqmpqdf9fr9f6enpUQMAkPoSHkDHjh3T6dOnlZOTk+ipAAD9iOe34M6ePRt1N9Pc3Kx9+/YpIyNDGRkZeu655zR//nwFg0EdOXJETzzxhG6++WaVl5fHtXEAQP/mOYB2796te++9N/L6s89vFi5cqJdeekn79+/Xn//8Z7W3tys3N1dlZWV6/vnn5ff749c1AKDfYzFSpKRly5bFVBfLYqTr16+Paa5UE8tipLEsevr973/fc02yL/6aqliMFACQlAggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJuL+I7mBZLBq1SrrFgacDz74wHPNd7/7Xc81ZWVlnmtYDTs5cQcEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABIuRAjDjnPNcM3fuXM81S5Ys8VyDxOMOCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkWI0Wf+utf/+q5ZtOmTZ5r1q9f77kG/1dYWOi55vnnn/dc4/P5+qQGyYk7IACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACZYjBR9as6cOZ5rXnjhhfg3MkC8+uqrMdWVlZV5rhk1apTnmk8++cRzTUVFhecaJCfugAAAJgggAIAJTwFUU1OjadOmKS0tTVlZWZozZ44aGxujjjl37pwqKys1atQo3XDDDZo/f77a2tri2jQAoP/zFED19fWqrKzUzp07tX37dl24cEFlZWXq6OiIHLN8+XK99dZb2rhxo+rr63X8+HHNmzcv7o0DAPo3Tw8hbNu2Ler12rVrlZWVpT179qikpEShUEh/+tOftGHDBn3jG9+QJK1Zs0aTJk3Szp079bWvfS1+nQMA+rVr+gwoFApJkjIyMiRJe/bs0YULF1RaWho5pqCgQGPHjlVDQ0OvX6Orq0vhcDhqAABSX8wB1NPTo2XLlumuu+7S5MmTJUmtra0aNmyYRo4cGXVsdna2Wltbe/06NTU1CgQCkZGXlxdrSwCAfiTmAKqsrNTHH3+s119//ZoaqK6uVigUioyWlpZr+noAgP4hpv+IunTpUr399tvasWOHxowZE9keDAZ1/vx5tbe3R90FtbW1KRgM9vq1/H6//H5/LG0AAPoxT3dAzjktXbpUmzdv1rvvvqv8/Pyo/YWFhRo6dKhqa2sj2xobG3X06FEVFxfHp2MAQErwdAdUWVmpDRs2aOvWrUpLS4t8rhMIBDRixAgFAgE9/PDDqqqqUkZGhtLT0/Xoo4+quLiYJ+AAAFE8BdBLL70kSZoxY0bU9jVr1uihhx6SJP3mN7/RoEGDNH/+fHV1dam8vFy///3v49IsACB1+JxzzrqJzwuHwwoEAtZtIEFiudwOHDjguaampsZzjSQdPHjQc43P5/NcU1BQ4LnmySef9FwzceJEzzVSbL+nWP5sf/vb33quqaqq8lwDG6FQSOnp6Zfdz1pwAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATrIaNPvX88897rqmurvZcE8tqzlJsKzr31crRfTVPrHNt2rTJc82SJUs815w6dcpzDWywGjYAICkRQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwwWKkSHp///vfPdeUl5fHNFcyLxLa2dnpuSbWhTv/+Mc/eq554YUXYpoLqYvFSAEASYkAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJFiNF0svMzPRcU1hYGNNcc+bM8VwzadIkzzUHDx70XPOPf/zDc83777/vuUaKfRFT4PNYjBQAkJQIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYYDFSAEBCsBgpACApEUAAABOeAqimpkbTpk1TWlqasrKyNGfOHDU2NkYdM2PGDPl8vqixePHiuDYNAOj/PAVQfX29KisrtXPnTm3fvl0XLlxQWVmZOjo6oo5btGiRTpw4ERkrV66Ma9MAgP5viJeDt23bFvV67dq1ysrK0p49e1RSUhLZft111ykYDManQwBASrqmz4BCoZAkKSMjI2r7+vXrlZmZqcmTJ6u6ulqdnZ2X/RpdXV0Kh8NRAwAwALgYdXd3u29961vurrvuitr+yiuvuG3btrn9+/e7v/zlL+7GG290c+fOvezXWbFihZPEYDAYjBQboVDoijkScwAtXrzYjRs3zrW0tFzxuNraWifJNTU19br/3LlzLhQKRUZLS4v5SWMwGAzGtY+rBZCnz4A+s3TpUr399tvasWOHxowZc8Vji4qKJElNTU2aMGHCJfv9fr/8fn8sbQAA+jFPAeSc06OPPqrNmzerrq5O+fn5V63Zt2+fJCknJyemBgEAqclTAFVWVmrDhg3aunWr0tLS1NraKkkKBAIaMWKEjhw5og0bNuib3/ymRo0apf3792v58uUqKSnRlClTEvIbAAD0U14+99Fl3udbs2aNc865o0ePupKSEpeRkeH8fr+7+eab3eOPP37V9wE/LxQKmb9vyWAwGIxrH1f73s9ipACAhGAxUgBAUiKAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmEi6AHLOWbcAAIiDq30/T7oAOnPmjHULAIA4uNr3c59LsluOnp4eHT9+XGlpafL5fFH7wuGw8vLy1NLSovT0dKMO7XEeLuI8XMR5uIjzcFEynAfnnM6cOaPc3FwNGnT5+5whfdjTlzJo0CCNGTPmisekp6cP6AvsM5yHizgPF3EeLuI8XGR9HgKBwFWPSbq34AAAAwMBBAAw0a8CyO/3a8WKFfL7/datmOI8XMR5uIjzcBHn4aL+dB6S7iEEAMDA0K/ugAAAqYMAAgCYIIAAACYIIACAiX4TQKtXr9ZNN92k4cOHq6ioSB9++KF1S33u2Weflc/nixoFBQXWbSXcjh07dN999yk3N1c+n09btmyJ2u+c0zPPPKOcnByNGDFCpaWlOnz4sE2zCXS18/DQQw9dcn3MmjXLptkEqamp0bRp05SWlqasrCzNmTNHjY2NUcecO3dOlZWVGjVqlG644QbNnz9fbW1tRh0nxpc5DzNmzLjkeli8eLFRx73rFwH0xhtvqKqqSitWrNBHH32kqVOnqry8XCdPnrRurc/ddtttOnHiRGR88MEH1i0lXEdHh6ZOnarVq1f3un/lypV68cUX9fLLL2vXrl26/vrrVV5ernPnzvVxp4l1tfMgSbNmzYq6Pl577bU+7DDx6uvrVVlZqZ07d2r79u26cOGCysrK1NHRETlm+fLleuutt7Rx40bV19fr+PHjmjdvnmHX8fdlzoMkLVq0KOp6WLlypVHHl+H6genTp7vKysrI6+7ubpebm+tqamoMu+p7K1ascFOnTrVuw5Qkt3nz5sjrnp4eFwwG3a9+9avItvb2duf3+91rr71m0GHf+OJ5cM65hQsXutmzZ5v0Y+XkyZNOkquvr3fOXfyzHzp0qNu4cWPkmIMHDzpJrqGhwarNhPvieXDOua9//evuRz/6kV1TX0LS3wGdP39ee/bsUWlpaWTboEGDVFpaqoaGBsPObBw+fFi5ubkaP368HnzwQR09etS6JVPNzc1qbW2Nuj4CgYCKiooG5PVRV1enrKwsTZw4UUuWLNHp06etW0qoUCgkScrIyJAk7dmzRxcuXIi6HgoKCjR27NiUvh6+eB4+s379emVmZmry5Mmqrq5WZ2enRXuXlXSLkX7RqVOn1N3drezs7Kjt2dnZOnTokFFXNoqKirR27VpNnDhRJ06c0HPPPad77rlHH3/8sdLS0qzbM9Ha2ipJvV4fn+0bKGbNmqV58+YpPz9fR44c0ZNPPqmKigo1NDRo8ODB1u3FXU9Pj5YtW6a77rpLkydPlnTxehg2bJhGjhwZdWwqXw+9nQdJ+s53vqNx48YpNzdX+/fv109+8hM1NjZq06ZNht1GS/oAwv9VVFREfj1lyhQVFRVp3LhxevPNN/Xwww8bdoZksGDBgsivb7/9dk2ZMkUTJkxQXV2dZs6cadhZYlRWVurjjz8eEJ+DXsnlzsMjjzwS+fXtt9+unJwczZw5U0eOHNGECRP6us1eJf1bcJmZmRo8ePAlT7G0tbUpGAwadZUcRo4cqVtvvVVNTU3WrZj57Brg+rjU+PHjlZmZmZLXx9KlS/X222/rvffei/rxLcFgUOfPn1d7e3vU8al6PVzuPPSmqKhIkpLqekj6ABo2bJgKCwtVW1sb2dbT06Pa2loVFxcbdmbv7NmzOnLkiHJycqxbMZOfn69gMBh1fYTDYe3atWvAXx/Hjh3T6dOnU+r6cM5p6dKl2rx5s959913l5+dH7S8sLNTQoUOjrofGxkYdPXo0pa6Hq52H3uzbt0+Skut6sH4K4st4/fXXnd/vd2vXrnUHDhxwjzzyiBs5cqRrbW21bq1P/fjHP3Z1dXWuubnZ/fOf/3SlpaUuMzPTnTx50rq1hDpz5ozbu3ev27t3r5Pkfv3rX7u9e/e6//73v845537xi1+4kSNHuq1bt7r9+/e72bNnu/z8fPfpp58adx5fVzoPZ86ccY899phraGhwzc3N7p133nF33HGHu+WWW9y5c+esW4+bJUuWuEAg4Orq6tyJEycio7OzM3LM4sWL3dixY927777rdu/e7YqLi11xcbFh1/F3tfPQ1NTkfvrTn7rdu3e75uZmt3XrVjd+/HhXUlJi3Hm0fhFAzjn3u9/9zo0dO9YNGzbMTZ8+3e3cudO6pT53//33u5ycHDds2DB34403uvvvv981NTVZt5Vw7733npN0yVi4cKFz7uKj2E8//bTLzs52fr/fzZw50zU2Nto2nQBXOg+dnZ2urKzMjR492g0dOtSNGzfOLVq0KOX+kdbb71+SW7NmTeSYTz/91P3whz90X/nKV9x1113n5s6d606cOGHXdAJc7TwcPXrUlZSUuIyMDOf3+93NN9/sHn/8cRcKhWwb/wJ+HAMAwETSfwYEAEhNBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATPwPCRYJck+lSFQAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "predict_image(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7b1dd60a-05eb-4c3a-9209-ba598be45bb9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prediction: [4]\n", + "Label: 4\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAatElEQVR4nO3de2xT9/nH8Y+5xNwS0xASJ+UWoMBUINUoZBEtoyOCZBPjJhUY0mBiIFhAA9Z2ohrQdpPSUqmrOmWwPzpY1XIZ0gCBtGg0NEHrAhUUhhBrRGg2giChRcKGAIGS7+8PVP/qEi7H2HkS835JRyL2+cbPzo7y7onNweeccwIAoI11sh4AAPBoIkAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMBEF+sBvq2lpUXnzp1TamqqfD6f9TgAAI+cc7p8+bJycnLUqdPdr3PaXYDOnTun/v37W48BAHhI9fX16tev312fb3e/gktNTbUeAQAQB/f7eZ6wAJWVlWnQoEHq1q2b8vPz9cknnzzQOn7tBgDJ4X4/zxMSoO3bt2vVqlVat26dPv30U+Xl5WnKlCm6cOFCIl4OANARuQQYN26cKykpiXx969Ytl5OT40pLS++7NhQKOUlsbGxsbB18C4VC9/x5H/croBs3bujIkSMqLCyMPNapUycVFhaqurr6jv2bm5sVDoejNgBA8ot7gL788kvdunVLWVlZUY9nZWWpoaHhjv1LS0sVCAQiG5+AA4BHg/mn4FavXq1QKBTZ6uvrrUcCALSBuP89oIyMDHXu3FmNjY1Rjzc2NioYDN6xv9/vl9/vj/cYAIB2Lu5XQCkpKRozZowqKioij7W0tKiiokIFBQXxfjkAQAeVkDshrFq1SvPnz9fTTz+tcePG6e2331ZTU5N+9rOfJeLlAAAdUEICNHv2bH3xxRdau3atGhoa9NRTT6m8vPyODyYAAB5dPuecsx7im8LhsAKBgPUYAICHFAqFlJaWdtfnzT8FBwB4NBEgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmulgPACRC165dY1q3bds2z2umT5/ueU3nzp09rwGSDVdAAAATBAgAYCLuAXrllVfk8/mithEjRsT7ZQAAHVxC3gN68skn9eGHH/7/i3ThrSYAQLSElKFLly4KBoOJ+NYAgCSRkPeATp06pZycHA0ePFjz5s3TmTNn7rpvc3OzwuFw1AYASH5xD1B+fr42b96s8vJybdiwQXV1dXr22Wd1+fLlVvcvLS1VIBCIbP3794/3SACAdsjnnHOJfIFLly5p4MCBeuutt7Rw4cI7nm9ublZzc3Pk63A4TITw0Ph7QIC9UCiktLS0uz6f8E8H9O7dW8OGDVNtbW2rz/v9fvn9/kSPAQBoZxL+94CuXLmi06dPKzs7O9EvBQDoQOIeoBdeeEFVVVX673//q3/961+aMWOGOnfurLlz58b7pQAAHVjcfwV39uxZzZ07VxcvXlTfvn31zDPP6ODBg+rbt2+8XwoA0IEl/EMIXoXDYQUCAesx0MFt2bIlpnXPP/98nCdpHX85G4+C+30IgXvBAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmuCMi2r3Jkyd7XlNcXJyASVr3/vvvt9lrAcmEKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY4G7YaFNPPfWU5zVbt271vCY1NdXzGkn6+9//7nnNz3/+85heC3jUcQUEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJjgZqSIWSAQ8LxmzZo1bfI6saqsrPS85quvvor/IB3Qa6+95nnNmDFjPK9Zvny55zWff/655zVIPK6AAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAAT3IwUMRs1apTnNdOmTUvAJHeaN29eTOt27twZ50k6pqKiIs9rli5d6nnNY4895nlNWlqa5zVon7gCAgCYIEAAABOeA3TgwAFNnTpVOTk58vl82rVrV9TzzjmtXbtW2dnZ6t69uwoLC3Xq1Kl4zQsASBKeA9TU1KS8vDyVlZW1+vz69ev1zjvvaOPGjTp06JB69uypKVOm6Pr16w89LAAgeXj+EEJxcbGKi4tbfc45p7ffflu/+c1vIm82v/fee8rKytKuXbs0Z86ch5sWAJA04voeUF1dnRoaGlRYWBh5LBAIKD8/X9XV1a2uaW5uVjgcjtoAAMkvrgFqaGiQJGVlZUU9npWVFXnu20pLSxUIBCJb//794zkSAKCdMv8U3OrVqxUKhSJbfX299UgAgDYQ1wAFg0FJUmNjY9TjjY2Nkee+ze/3Ky0tLWoDACS/uAYoNzdXwWBQFRUVkcfC4bAOHTqkgoKCeL4UAKCD8/wpuCtXrqi2tjbydV1dnY4dO6b09HQNGDBAK1as0O9+9zs98cQTys3N1Zo1a5STk6Pp06fHc24AQAfnOUCHDx/Wc889F/l61apVkqT58+dr8+bNeumll9TU1KTFixfr0qVLeuaZZ1ReXq5u3brFb2oAQIfnc8456yG+KRwOKxAIWI+BB/DnP//Z85qf/vSnntfs27fP85pYr7ibm5tjWpdstm/f7nnNrFmzPK/5+OOPPa+J5Uap165d87wGDy8UCt3zfX3zT8EBAB5NBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMOH5n2MAvpaVldUmr/PGG294XhPrXa3z8vI8r/n3v/8d02u1hUGDBsW0LpY7TsfizTff9LyGO1snD66AAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAAT3IwUGjlyZEzrnnvuuThP0rrt27d7XvPVV1/F9Fo9e/b0vKapqcnzmgMHDnhe8/rrr3teE8tsUmzH4dy5c57XnDx50vMaJA+ugAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE9yMFOrSJbbTICUlJc6TtK5Pnz5t8jqx6tWrl+c1zz//vOc106ZN87zm1q1bntfE6vPPP2+TNUgeXAEBAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACa4GSl07NixmNb94x//8LymsLDQ85pr1655XtOWYrkpa7du3Tyv8fv9nte0JZ/PZz0COhiugAAAJggQAMCE5wAdOHBAU6dOVU5Ojnw+n3bt2hX1/IIFC+Tz+aK2oqKieM0LAEgSngPU1NSkvLw8lZWV3XWfoqIinT9/PrJt3br1oYYEACQfzx9CKC4uVnFx8T338fv9CgaDMQ8FAEh+CXkPqLKyUpmZmRo+fLiWLl2qixcv3nXf5uZmhcPhqA0AkPziHqCioiK99957qqio0BtvvKGqqioVFxff9d+mLy0tVSAQiGz9+/eP90gAgHYo7n8PaM6cOZE/jxo1SqNHj9aQIUNUWVmpSZMm3bH/6tWrtWrVqsjX4XCYCAHAIyDhH8MePHiwMjIyVFtb2+rzfr9faWlpURsAIPklPEBnz57VxYsXlZ2dneiXAgB0IJ5/BXflypWoq5m6ujodO3ZM6enpSk9P16uvvqpZs2YpGAzq9OnTeumllzR06FBNmTIlroMDADo2zwE6fPiwnnvuucjXX79/M3/+fG3YsEHHjx/XX/7yF126dEk5OTmaPHmyfvvb37b7+1gBANqWzznnrIf4pnA4rEAgYD0GHsCPf/xjz2tiOd327NnjeU1bevrppz2vKSgo8LwmlmM3d+5cz2skKT8/3/OaQ4cOeV7zzf+YfVA3btzwvAY2QqHQPd/X515wAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMHdsIEk9u6778a0bv78+XGepHXr16/3vObll19OwCRIBO6GDQBolwgQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE12sBwDw6Bo/frznNT169PC85urVq57XIPG4AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHAzUgBxsW7dOs9rZs+e7XnNsGHDPK85duyY5zVIPK6AAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAAT3IwUQFwcOXLE85qTJ096XhMOhz2vQfvEFRAAwAQBAgCY8BSg0tJSjR07VqmpqcrMzNT06dNVU1MTtc/169dVUlKiPn36qFevXpo1a5YaGxvjOjQAoOPzFKCqqiqVlJTo4MGD2rdvn27evKnJkyerqakpss/KlSu1Z88e7dixQ1VVVTp37pxmzpwZ98EBAB2bpw8hlJeXR329efNmZWZm6siRI5owYYJCoZDeffddbdmyRT/4wQ8kSZs2bdJ3vvMdHTx4UN/73vfiNzkAoEN7qPeAQqGQJCk9PV3S7U/B3Lx5U4WFhZF9RowYoQEDBqi6urrV79Hc3KxwOBy1AQCSX8wBamlp0YoVKzR+/HiNHDlSktTQ0KCUlBT17t07at+srCw1NDS0+n1KS0sVCAQiW//+/WMdCQDQgcQcoJKSEp04cULbtm17qAFWr16tUCgU2err6x/q+wEAOoaY/iLqsmXLtHfvXh04cED9+vWLPB4MBnXjxg1dunQp6iqosbFRwWCw1e/l9/vl9/tjGQMA0IF5ugJyzmnZsmXauXOn9u/fr9zc3Kjnx4wZo65du6qioiLyWE1Njc6cOaOCgoL4TAwASAqeroBKSkq0ZcsW7d69W6mpqZH3dQKBgLp3765AIKCFCxdq1apVSk9PV1pampYvX66CggI+AQcAiOIpQBs2bJAkTZw4MerxTZs2acGCBZKk3//+9+rUqZNmzZql5uZmTZkyRX/84x/jMiwAIHl4CpBz7r77dOvWTWVlZSorK4t5KADx8dlnn1mPcE87d+60HgGGuBccAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATMT0L6IC6Bhivdt0aWlpnCcB7sQVEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABggpuRAkksFArFtK62ttbzmuXLl3teU15e7nkNkgdXQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACW5GCiSxL774IqZ1Gzdu9LwmMzMzptfCo4srIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADAhM8556yH+KZwOKxAIGA9BgDgIYVCIaWlpd31ea6AAAAmCBAAwISnAJWWlmrs2LFKTU1VZmampk+frpqamqh9Jk6cKJ/PF7UtWbIkrkMDADo+TwGqqqpSSUmJDh48qH379unmzZuaPHmympqaovZbtGiRzp8/H9nWr18f16EBAB2fp38Rtby8POrrzZs3KzMzU0eOHNGECRMij/fo0UPBYDA+EwIAktJDvQcUCoUkSenp6VGPf/DBB8rIyNDIkSO1evVqXb169a7fo7m5WeFwOGoDADwCXIxu3brlfvSjH7nx48dHPf6nP/3JlZeXu+PHj7v333/fPf74427GjBl3/T7r1q1zktjY2NjYkmwLhUL37EjMAVqyZIkbOHCgq6+vv+d+FRUVTpKrra1t9fnr16+7UCgU2err680PGhsbGxvbw2/3C5Cn94C+tmzZMu3du1cHDhxQv3797rlvfn6+JKm2tlZDhgy543m/3y+/3x/LGACADsxTgJxzWr58uXbu3KnKykrl5ubed82xY8ckSdnZ2TENCABITp4CVFJSoi1btmj37t1KTU1VQ0ODJCkQCKh79+46ffq0tmzZoh/+8Ifq06ePjh8/rpUrV2rChAkaPXp0Qv4HAAA6KC/v++guv+fbtGmTc865M2fOuAkTJrj09HTn9/vd0KFD3Ysvvnjf3wN+UygUMv+9JRsbGxvbw2/3+9nPzUgBAAnBzUgBAO0SAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMBEuwuQc856BABAHNzv53m7C9Dly5etRwAAxMH9fp77XDu75GhpadG5c+eUmpoqn88X9Vw4HFb//v1VX1+vtLQ0owntcRxu4zjcxnG4jeNwW3s4Ds45Xb58WTk5OerU6e7XOV3acKYH0qlTJ/Xr1++e+6SlpT3SJ9jXOA63cRxu4zjcxnG4zfo4BAKB++7T7n4FBwB4NBAgAICJDhUgv9+vdevWye/3W49iiuNwG8fhNo7DbRyH2zrScWh3H0IAADwaOtQVEAAgeRAgAIAJAgQAMEGAAAAmOkyAysrKNGjQIHXr1k35+fn65JNPrEdqc6+88op8Pl/UNmLECOuxEu7AgQOaOnWqcnJy5PP5tGvXrqjnnXNau3atsrOz1b17dxUWFurUqVM2wybQ/Y7DggUL7jg/ioqKbIZNkNLSUo0dO1apqanKzMzU9OnTVVNTE7XP9evXVVJSoj59+qhXr16aNWuWGhsbjSZOjAc5DhMnTrzjfFiyZInRxK3rEAHavn27Vq1apXXr1unTTz9VXl6epkyZogsXLliP1uaefPJJnT9/PrL985//tB4p4ZqampSXl6eysrJWn1+/fr3eeecdbdy4UYcOHVLPnj01ZcoUXb9+vY0nTaz7HQdJKioqijo/tm7d2oYTJl5VVZVKSkp08OBB7du3Tzdv3tTkyZPV1NQU2WflypXas2ePduzYoaqqKp07d04zZ840nDr+HuQ4SNKiRYuizof169cbTXwXrgMYN26cKykpiXx969Ytl5OT40pLSw2nanvr1q1zeXl51mOYkuR27twZ+bqlpcUFg0H35ptvRh67dOmS8/v9buvWrQYTto1vHwfnnJs/f76bNm2ayTxWLly44CS5qqoq59zt/++7du3qduzYEdnnP//5j5PkqqurrcZMuG8fB+ec+/73v+9++ctf2g31ANr9FdCNGzd05MgRFRYWRh7r1KmTCgsLVV1dbTiZjVOnTiknJ0eDBw/WvHnzdObMGeuRTNXV1amhoSHq/AgEAsrPz38kz4/KykplZmZq+PDhWrp0qS5evGg9UkKFQiFJUnp6uiTpyJEjunnzZtT5MGLECA0YMCCpz4dvH4evffDBB8rIyNDIkSO1evVqXb161WK8u2p3NyP9ti+//FK3bt1SVlZW1ONZWVn67LPPjKaykZ+fr82bN2v48OE6f/68Xn31VT377LM6ceKEUlNTrccz0dDQIEmtnh9fP/eoKCoq0syZM5Wbm6vTp0/r5ZdfVnFxsaqrq9W5c2fr8eKupaVFK1as0Pjx4zVy5EhJt8+HlJQU9e7dO2rfZD4fWjsOkvSTn/xEAwcOVE5Ojo4fP65f//rXqqmp0d/+9jfDaaO1+wDh/xUXF0f+PHr0aOXn52vgwIH661//qoULFxpOhvZgzpw5kT+PGjVKo0eP1pAhQ1RZWalJkyYZTpYYJSUlOnHixCPxPui93O04LF68OPLnUaNGKTs7W5MmTdLp06c1ZMiQth6zVe3+V3AZGRnq3LnzHZ9iaWxsVDAYNJqqfejdu7eGDRum2tpa61HMfH0OcH7cafDgwcrIyEjK82PZsmXau3evPvroo6h/viUYDOrGjRu6dOlS1P7Jej7c7Ti0Jj8/X5La1fnQ7gOUkpKiMWPGqKKiIvJYS0uLKioqVFBQYDiZvStXruj06dPKzs62HsVMbm6ugsFg1PkRDod16NChR/78OHv2rC5evJhU54dzTsuWLdPOnTu1f/9+5ebmRj0/ZswYde3aNep8qKmp0ZkzZ5LqfLjfcWjNsWPHJKl9nQ/Wn4J4ENu2bXN+v99t3rzZnTx50i1evNj17t3bNTQ0WI/Wpn71q1+5yspKV1dX5z7++GNXWFjoMjIy3IULF6xHS6jLly+7o0ePuqNHjzpJ7q233nJHjx51//vf/5xzzr3++uuud+/ebvfu3e748eNu2rRpLjc31127ds148vi613G4fPmye+GFF1x1dbWrq6tzH374ofvud7/rnnjiCXf9+nXr0eNm6dKlLhAIuMrKSnf+/PnIdvXq1cg+S5YscQMGDHD79+93hw8fdgUFBa6goMBw6vi733Gora11r732mjt8+LCrq6tzu3fvdoMHD3YTJkwwnjxahwiQc8794Q9/cAMGDHApKSlu3Lhx7uDBg9YjtbnZs2e77Oxsl5KS4h5//HE3e/ZsV1tbaz1Wwn300UdO0h3b/PnznXO3P4q9Zs0al5WV5fx+v5s0aZKrqamxHToB7nUcrl696iZPnuz69u3runbt6gYOHOgWLVqUdP+R1tr/fklu06ZNkX2uXbvmfvGLX7jHHnvM9ejRw82YMcOdP3/ebugEuN9xOHPmjJswYYJLT093fr/fDR061L344osuFArZDv4t/HMMAAAT7f49IABAciJAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATPwfc5J/8vbOGW0AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "predict_image(2)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "5d69171e-e864-44a6-9e51-183002a47c90", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prediction: [2]\n", + "Label: 2\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAb20lEQVR4nO3df2yV9fn/8dcptEfQ9mCp7WnlhwUVNhEWELpORRxNf0iIKNtESUTnMLhiBp0yu0zRbVkdSzbnwnBZDJ2ZoJINCGxpAtWWTFsMKGFO11HSrWW0RVk4B4oUpO/vH3w9H48U8D6c0+v08Hwk76Tnvu+r99U3d86L+9x37/qcc04AAAywNOsGAACXJgIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJoZaN/B5fX19OnjwoDIzM+Xz+azbAQB45JzT0aNHVVBQoLS0c5/nJF0AHTx4UKNHj7ZuAwBwkTo6OjRq1Khzrk+6j+AyMzOtWwAAxMGF3s8TFkCrV6/WNddco8suu0xFRUV6++23v1AdH7sBQGq40Pt5QgLo1VdfVVVVlVauXKl33nlHU6ZMUVlZmQ4dOpSI3QEABiOXADNmzHCVlZWR16dPn3YFBQWupqbmgrWhUMhJYjAYDMYgH6FQ6Lzv93E/Azp58qR2796tkpKSyLK0tDSVlJSoqanprO17e3sVDoejBgAg9cU9gD766COdPn1aeXl5Ucvz8vLU1dV11vY1NTUKBAKRwR1wAHBpML8Lrrq6WqFQKDI6OjqsWwIADIC4/x5QTk6OhgwZou7u7qjl3d3dCgaDZ23v9/vl9/vj3QYAIMnF/QwoIyND06ZNU319fWRZX1+f6uvrVVxcHO/dAQAGqYQ8CaGqqkqLFi3STTfdpBkzZui5555TT0+PHnzwwUTsDgAwCCUkgO655x59+OGHeuqpp9TV1aWvfOUrqqurO+vGBADApcvnnHPWTXxWOBxWIBCwbgMAcJFCoZCysrLOud78LjgAwKWJAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmBhq3QCQCGlpsf3f6stf/rLnmqqqKs81xcXFnms++eQTzzUPPvig5xpJ2rVrV0x1gBecAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDBw0iR9CZNmuS5Zs2aNTHt65ZbbvFcE8tDQmOpSU9P91yzbds2zzWS9O1vf9tzzcaNG2PaFy5dnAEBAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwwcNIMaCmTp3quaa+vj4BnfTviSee8FwTy0M4//Wvf3mu+cY3vuG55ve//73nGkl68cUXY6rzigeYXto4AwIAmCCAAAAm4h5ATz/9tHw+X9SYOHFivHcDABjkEnIN6IYbbtD27dv/bydDudQEAIiWkGQYOnSogsFgIr41ACBFJOQa0L59+1RQUKBx48Zp4cKFam9vP+e2vb29CofDUQMAkPriHkBFRUWqra1VXV2d1qxZo7a2Nt166606evRov9vX1NQoEAhExujRo+PdEgAgCcU9gCoqKvTNb35TkydPVllZmf7617/qyJEjeu211/rdvrq6WqFQKDI6Ojri3RIAIAkl/O6AESNG6Prrr1dra2u/6/1+v/x+f6LbAAAkmYT/HtCxY8e0f/9+5efnJ3pXAIBBJO4B9Nhjj6mxsVH//ve/9dZbb+muu+7SkCFDdO+998Z7VwCAQSzuH8EdOHBA9957rw4fPqyrrrpKt9xyi5qbm3XVVVfFe1cAgEHM55xz1k18VjgcViAQsG4DX0BWVpbnmr/85S+eazIyMjzX3H///Z5rJKmlpSWmumQ1Z86cmOq2bNniuWbfvn2ea6ZPn+65hl/VGDxCodB53yd4FhwAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATCf+DdEhdCxcu9FwzdepUzzWzZ8/2XJNqDxWNVSwPf5Wk9evXe6657777PNfEcgytWbPGcw2SE2dAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATPA0bMfv73//uueatt97yXNPc3Oy5Bhdnw4YNnmtieRr2bbfd5rmGp2GnDs6AAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmPA555x1E58VDocVCASs2wAuaVlZWZ5r2traPNf897//9Vxz0003ea45efKk5xpcvFAodN5jiTMgAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJoZaNwAg+YTDYc81sTzw89ChQwOyHyQnzoAAACYIIACACc8BtGPHDs2dO1cFBQXy+XzatGlT1HrnnJ566inl5+dr2LBhKikp0b59++LVLwAgRXgOoJ6eHk2ZMkWrV6/ud/2qVav0/PPP64UXXtDOnTt1+eWXq6ysTCdOnLjoZgEAqcPzTQgVFRWqqKjod51zTs8995x+9KMf6c4775QkvfTSS8rLy9OmTZu0YMGCi+sWAJAy4noNqK2tTV1dXSopKYksCwQCKioqUlNTU781vb29CofDUQMAkPriGkBdXV2SpLy8vKjleXl5kXWfV1NTo0AgEBmjR4+OZ0sAgCRlfhdcdXW1QqFQZHR0dFi3BAAYAHENoGAwKEnq7u6OWt7d3R1Z93l+v19ZWVlRAwCQ+uIaQIWFhQoGg6qvr48sC4fD2rlzp4qLi+O5KwDAIOf5Lrhjx46ptbU18rqtrU179uxRdna2xowZo2XLlumnP/2prrvuOhUWFurJJ59UQUGB5s2bF8++AQCDnOcA2rVrl26//fbI66qqKknSokWLVFtbqxUrVqinp0cPP/ywjhw5oltuuUV1dXW67LLL4tc1AGDQ8znnnHUTnxUOhxUIBKzbAC5psdyN+v7773uu2blzp+eaz/6aB5JbKBQ673V987vgAACXJgIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACc9/jgFA6isvL/dcc8UVV3iu+eCDDzzXIHVwBgQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEDyMFcJbi4uIB2c/7778/IPtBcuIMCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkeRoqkl5eX57kmPT09pn0dOHAgprpkdfnll8dUN2fOHM81x44d81yzdetWzzVIHZwBAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMMHDSKFRo0bFVPfrX//ac83UqVM912RlZXmuGTJkiOcaSdqwYYPnmueee85zzT/+8Q/PNbFYsWJFTHW5ubmea5544gnPNR0dHZ5rkDo4AwIAmCCAAAAmPAfQjh07NHfuXBUUFMjn82nTpk1R6x944AH5fL6oUV5eHq9+AQApwnMA9fT0aMqUKVq9evU5tykvL1dnZ2dkrF+//qKaBACkHs83IVRUVKiiouK82/j9fgWDwZibAgCkvoRcA2poaFBubq4mTJigRx55RIcPHz7ntr29vQqHw1EDAJD64h5A5eXleumll1RfX6+f//znamxsVEVFhU6fPt3v9jU1NQoEApExevToeLcEAEhCcf89oAULFkS+vvHGGzV58mSNHz9eDQ0Nmj179lnbV1dXq6qqKvI6HA4TQgBwCUj4bdjjxo1TTk6OWltb+13v9/uVlZUVNQAAqS/hAXTgwAEdPnxY+fn5id4VAGAQ8fwR3LFjx6LOZtra2rRnzx5lZ2crOztbzzzzjObPn69gMKj9+/drxYoVuvbaa1VWVhbXxgEAg5vnANq1a5duv/32yOtPr98sWrRIa9as0d69e/WHP/xBR44cUUFBgUpLS/WTn/xEfr8/fl0DAAY9n3POWTfxWeFwWIFAwLqNQevqq6/2XLN9+/aY9jVx4kTPNZ2dnZ5rDh486Llm2rRpnmtiFcuvDtTV1Xmuef755z3X/OxnP/NcI0nDhw/3XHPHHXd4rvnwww8912DwCIVC572uz7PgAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAm4v4nuWGrtLTUc00sT7WWpP/9738Dsq+enh7PNYWFhZ5rJGn58uWea+6//37PNd/61rcGpCZWn/6ZFS94sjW84gwIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACR5GmmIOHTo0YPvatGmT55pwOBz/RvrR2toaU11lZaXnmoyMDM813/nOdzzXDKRly5Z5rtmyZYvnmlj/nZAaOAMCAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgwuecc9ZNfFY4HFYgELBuY9AaPny455r29vaY9hXLQzhfe+21mPY1UCZMmOC55mtf+5rnmuPHj3uuefPNNz3XxPLzSNI111zjuaajo8Nzzde//nXPNTzAdPAIhULKyso653rOgAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJjgYaTQnDlzYqp79tlnPddMmjQppn0ls61bt3quefrppz3X7N6923NNfn6+5xpJqq2t9VxTWlrquSaWB5iuWLHCc80rr7ziuQYXj4eRAgCSEgEEADDhKYBqamo0ffp0ZWZmKjc3V/PmzVNLS0vUNidOnFBlZaVGjhypK664QvPnz1d3d3dcmwYADH6eAqixsVGVlZVqbm7Wtm3bdOrUKZWWlqqnpyeyzfLly7VlyxZt2LBBjY2NOnjwoO6+++64Nw4AGNyGetm4rq4u6nVtba1yc3O1e/duzZw5U6FQSC+++KLWrVsX+UuHa9eu1Ze+9CU1Nzfrq1/9avw6BwAMahd1DSgUCkmSsrOzJZ25S+fUqVMqKSmJbDNx4kSNGTNGTU1N/X6P3t5ehcPhqAEASH0xB1BfX5+WLVumm2++OXJrbVdXlzIyMjRixIiobfPy8tTV1dXv96mpqVEgEIiM0aNHx9oSAGAQiTmAKisr9d577130/fXV1dUKhUKREcvvBQAABh9P14A+tXTpUm3dulU7duzQqFGjIsuDwaBOnjypI0eORJ0FdXd3KxgM9vu9/H6//H5/LG0AAAYxT2dAzjktXbpUGzdu1Ouvv67CwsKo9dOmTVN6errq6+sjy1paWtTe3q7i4uL4dAwASAmezoAqKyu1bt06bd68WZmZmZHrOoFAQMOGDVMgENBDDz2kqqoqZWdnKysrS48++qiKi4u5Aw4AEMVTAK1Zs0aSNGvWrKjla9eu1QMPPCBJ+tWvfqW0tDTNnz9fvb29Kisr029/+9u4NAsASB08jBQxGzrU+yXEWGqS3cmTJz3X9PX1JaCT+ElPT/dc86c//clzzdy5cz3XxDJ3sb6nHDt2LKY6nMHDSAEASYkAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYCL1Hk2MAfPJJ58MSA0G3qlTpzzXLF682HNNQ0OD55orr7zScw3HXXLiDAgAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJn3POWTfxWeFwWIFAwLoNAMBFCoVCysrKOud6zoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmPAUQDU1NZo+fboyMzOVm5urefPmqaWlJWqbWbNmyefzRY0lS5bEtWkAwODnKYAaGxtVWVmp5uZmbdu2TadOnVJpaal6enqitlu8eLE6OzsjY9WqVXFtGgAw+A31snFdXV3U69raWuXm5mr37t2aOXNmZPnw4cMVDAbj0yEAICVd1DWgUCgkScrOzo5a/vLLLysnJ0eTJk1SdXW1jh8/fs7v0dvbq3A4HDUAAJcAF6PTp0+7OXPmuJtvvjlq+e9+9ztXV1fn9u7d6/74xz+6q6++2t11113n/D4rV650khgMBoORYiMUCp03R2IOoCVLlrixY8e6jo6O825XX1/vJLnW1tZ+1584ccKFQqHI6OjoMJ80BoPBYFz8uFAAeboG9KmlS5dq69at2rFjh0aNGnXebYuKiiRJra2tGj9+/Fnr/X6//H5/LG0AAAYxTwHknNOjjz6qjRs3qqGhQYWFhRes2bNnjyQpPz8/pgYBAKnJUwBVVlZq3bp12rx5szIzM9XV1SVJCgQCGjZsmPbv369169bpjjvu0MiRI7V3714tX75cM2fO1OTJkxPyAwAABikv1310js/51q5d65xzrr293c2cOdNlZ2c7v9/vrr32Wvf4449f8HPAzwqFQuafWzIYDAbj4seF3vt9/z9YkkY4HFYgELBuAwBwkUKhkLKyss65nmfBAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMJF0AOeesWwAAxMGF3s+TLoCOHj1q3QIAIA4u9H7uc0l2ytHX16eDBw8qMzNTPp8val04HNbo0aPV0dGhrKwsow7tMQ9nMA9nMA9nMA9nJMM8OOd09OhRFRQUKC3t3Oc5Qwewpy8kLS1No0aNOu82WVlZl/QB9inm4Qzm4Qzm4Qzm4QzreQgEAhfcJuk+ggMAXBoIIACAiUEVQH6/XytXrpTf77duxRTzcAbzcAbzcAbzcMZgmoekuwkBAHBpGFRnQACA1EEAAQBMEEAAABMEEADAxKAJoNWrV+uaa67RZZddpqKiIr399tvWLQ24p59+Wj6fL2pMnDjRuq2E27Fjh+bOnauCggL5fD5t2rQpar1zTk899ZTy8/M1bNgwlZSUaN++fTbNJtCF5uGBBx446/goLy+3aTZBampqNH36dGVmZio3N1fz5s1TS0tL1DYnTpxQZWWlRo4cqSuuuELz589Xd3e3UceJ8UXmYdasWWcdD0uWLDHquH+DIoBeffVVVVVVaeXKlXrnnXc0ZcoUlZWV6dChQ9atDbgbbrhBnZ2dkfG3v/3NuqWE6+np0ZQpU7R69ep+169atUrPP/+8XnjhBe3cuVOXX365ysrKdOLEiQHuNLEuNA+SVF5eHnV8rF+/fgA7TLzGxkZVVlaqublZ27Zt06lTp1RaWqqenp7INsuXL9eWLVu0YcMGNTY26uDBg7r77rsNu46/LzIPkrR48eKo42HVqlVGHZ+DGwRmzJjhKisrI69Pnz7tCgoKXE1NjWFXA2/lypVuypQp1m2YkuQ2btwYed3X1+eCwaD7xS9+EVl25MgR5/f73fr16w06HBifnwfnnFu0aJG78847TfqxcujQISfJNTY2OufO/Nunp6e7DRs2RLb54IMPnCTX1NRk1WbCfX4enHPutttuc9/73vfsmvoCkv4M6OTJk9q9e7dKSkoiy9LS0lRSUqKmpibDzmzs27dPBQUFGjdunBYuXKj29nbrlky1tbWpq6sr6vgIBAIqKiq6JI+PhoYG5ebmasKECXrkkUd0+PBh65YSKhQKSZKys7MlSbt379apU6eijoeJEydqzJgxKX08fH4ePvXyyy8rJydHkyZNUnV1tY4fP27R3jkl3cNIP++jjz7S6dOnlZeXF7U8Ly9P//znP426slFUVKTa2lpNmDBBnZ2deuaZZ3TrrbfqvffeU2ZmpnV7Jrq6uiSp3+Pj03WXivLyct19990qLCzU/v379cMf/lAVFRVqamrSkCFDrNuLu76+Pi1btkw333yzJk2aJOnM8ZCRkaERI0ZEbZvKx0N/8yBJ9913n8aOHauCggLt3btXP/jBD9TS0qI///nPht1GS/oAwv+pqKiIfD158mQVFRVp7Nixeu211/TQQw8ZdoZksGDBgsjXN954oyZPnqzx48eroaFBs2fPNuwsMSorK/Xee+9dEtdBz+dc8/Dwww9Hvr7xxhuVn5+v2bNna//+/Ro/fvxAt9mvpP8ILicnR0OGDDnrLpbu7m4Fg0GjrpLDiBEjdP3116u1tdW6FTOfHgMcH2cbN26ccnJyUvL4WLp0qbZu3ao33ngj6s+3BINBnTx5UkeOHInaPlWPh3PNQ3+KiookKamOh6QPoIyMDE2bNk319fWRZX19faqvr1dxcbFhZ/aOHTum/fv3Kz8/37oVM4WFhQoGg1HHRzgc1s6dOy/54+PAgQM6fPhwSh0fzjktXbpUGzdu1Ouvv67CwsKo9dOmTVN6enrU8dDS0qL29vaUOh4uNA/92bNnjyQl1/FgfRfEF/HKK684v9/vamtr3fvvv+8efvhhN2LECNfV1WXd2oD6/ve/7xoaGlxbW5t78803XUlJicvJyXGHDh2ybi2hjh496t5991337rvvOknul7/8pXv33Xfdf/7zH+ecc88++6wbMWKE27x5s9u7d6+78847XWFhofv444+NO4+v883D0aNH3WOPPeaamppcW1ub2759u5s6daq77rrr3IkTJ6xbj5tHHnnEBQIB19DQ4Do7OyPj+PHjkW2WLFnixowZ415//XW3a9cuV1xc7IqLiw27jr8LzUNra6v78Y9/7Hbt2uXa2trc5s2b3bhx49zMmTONO482KALIOed+85vfuDFjxriMjAw3Y8YM19zcbN3SgLvnnntcfn6+y8jIcFdffbW75557XGtrq3VbCffGG284SWeNRYsWOefO3Ir95JNPury8POf3+93s2bNdS0uLbdMJcL55OH78uCstLXVXXXWVS09Pd2PHjnWLFy9Ouf+k9ffzS3Jr166NbPPxxx+77373u+7KK690w4cPd3fddZfr7Oy0azoBLjQP7e3tbubMmS47O9v5/X537bXXuscff9yFQiHbxj+HP8cAADCR9NeAAACpiQACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgIn/B1eq7YuwiWtRAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "predict_image(3)" + ] + }, + { + "cell_type": "markdown", + "id": "747cd3d0-29d2-444b-83dc-1c4b57687704", + "metadata": {}, + "source": [ + "### **Binary-Class Classification**\n", + "\n", + "### **Dataset: [Breast Cancer Wisconsin (Diagnostic) Data Set](https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+%28Diagnostic%29)**" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "7711d2b6-5aab-471c-aa45-06e0c5ddcb2c", + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_csv(\"binaryclass_train.csv\", header = None)\n", + "data[\"label\"] = data[1].apply(lambda x: 1 if x == \"M\" else 0)\n", + "train, test = train_test_split(data, test_size = 0.3)\n", + "train_data = train.loc[:, ~train.columns.isin([0, 1, \"label\"])].to_numpy()\n", + "train_target = train[\"label\"].to_numpy()\n", + "test_data = test.loc[:, ~test.columns.isin([0, 1, \"label\"])].to_numpy()\n", + "test_target = test[\"label\"].to_numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "eaa8f0cc-78f0-4984-b701-437195559a4a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
0123456789...232425262728293031label
361901041B13.3021.5785.24546.10.085820.063730.033440.02424...29.2092.94621.20.11400.166700.121200.056140.26370.066580
186874217M18.3118.58118.601041.00.085880.084680.081690.05814...26.36139.201410.00.12340.244500.353800.157100.32060.069381
199877500M14.4520.2294.49642.70.098720.120600.118000.05980...30.12117.901044.00.15520.405600.496700.183800.47530.101301
38990312M19.5523.21128.901174.00.101000.131800.185600.10210...30.44142.001313.00.12510.241400.382900.182500.25760.076021
388903011B11.2715.5073.38392.00.083650.111400.100700.02757...18.9379.73450.00.11020.280900.302100.082720.21570.104300
..................................................................
430907914M14.9022.53102.10685.00.099470.222500.273300.09711...27.57125.40832.70.14190.709000.901900.247500.28660.115501
3719012568B15.1913.2197.65711.80.079630.069340.033930.02657...15.73104.50819.10.11260.173700.136200.081780.24870.067660
4659113239B13.2420.1386.87542.90.082840.122300.101000.02833...25.50115.00733.50.12010.564600.655600.135700.28450.124900
60858970B10.1714.8864.55311.90.113400.080610.010840.01290...17.4569.86368.60.12750.098660.021680.025790.35570.080200
426907409B10.4814.9867.49333.60.098160.101300.063350.02218...21.5781.41440.40.13270.299600.293900.093100.30200.096460
\n", + "

398 rows × 33 columns

\n", + "
" + ], + "text/plain": [ + " 0 1 2 3 4 5 6 7 8 \\\n", + "361 901041 B 13.30 21.57 85.24 546.1 0.08582 0.06373 0.03344 \n", + "186 874217 M 18.31 18.58 118.60 1041.0 0.08588 0.08468 0.08169 \n", + "199 877500 M 14.45 20.22 94.49 642.7 0.09872 0.12060 0.11800 \n", + "389 90312 M 19.55 23.21 128.90 1174.0 0.10100 0.13180 0.18560 \n", + "388 903011 B 11.27 15.50 73.38 392.0 0.08365 0.11140 0.10070 \n", + ".. ... .. ... ... ... ... ... ... ... \n", + "430 907914 M 14.90 22.53 102.10 685.0 0.09947 0.22250 0.27330 \n", + "371 9012568 B 15.19 13.21 97.65 711.8 0.07963 0.06934 0.03393 \n", + "465 9113239 B 13.24 20.13 86.87 542.9 0.08284 0.12230 0.10100 \n", + "60 858970 B 10.17 14.88 64.55 311.9 0.11340 0.08061 0.01084 \n", + "426 907409 B 10.48 14.98 67.49 333.6 0.09816 0.10130 0.06335 \n", + "\n", + " 9 ... 23 24 25 26 27 28 29 \\\n", + "361 0.02424 ... 29.20 92.94 621.2 0.1140 0.16670 0.12120 0.05614 \n", + "186 0.05814 ... 26.36 139.20 1410.0 0.1234 0.24450 0.35380 0.15710 \n", + "199 0.05980 ... 30.12 117.90 1044.0 0.1552 0.40560 0.49670 0.18380 \n", + "389 0.10210 ... 30.44 142.00 1313.0 0.1251 0.24140 0.38290 0.18250 \n", + "388 0.02757 ... 18.93 79.73 450.0 0.1102 0.28090 0.30210 0.08272 \n", + ".. ... ... ... ... ... ... ... ... ... \n", + "430 0.09711 ... 27.57 125.40 832.7 0.1419 0.70900 0.90190 0.24750 \n", + "371 0.02657 ... 15.73 104.50 819.1 0.1126 0.17370 0.13620 0.08178 \n", + "465 0.02833 ... 25.50 115.00 733.5 0.1201 0.56460 0.65560 0.13570 \n", + "60 0.01290 ... 17.45 69.86 368.6 0.1275 0.09866 0.02168 0.02579 \n", + "426 0.02218 ... 21.57 81.41 440.4 0.1327 0.29960 0.29390 0.09310 \n", + "\n", + " 30 31 label \n", + "361 0.2637 0.06658 0 \n", + "186 0.3206 0.06938 1 \n", + "199 0.4753 0.10130 1 \n", + "389 0.2576 0.07602 1 \n", + "388 0.2157 0.10430 0 \n", + ".. ... ... ... \n", + "430 0.2866 0.11550 1 \n", + "371 0.2487 0.06766 0 \n", + "465 0.2845 0.12490 0 \n", + "60 0.3557 0.08020 0 \n", + "426 0.3020 0.09646 0 \n", + "\n", + "[398 rows x 33 columns]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "751772e4-e90b-4c11-ae63-cdb9602529e7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---- Model Summary ----\n", + "Layer 1: relu\n", + "W: (16, 30) b: (16, 1)\n", + "Trainable parameters: 496\n", + "Layer 2: relu\n", + "W: (16, 16) b: (16, 1)\n", + "Trainable parameters: 272\n", + "Layer 3: sigmoid\n", + "W: (1, 16) b: (1, 1)\n", + "Trainable parameters: 17\n" + ] + } + ], + "source": [ + "NN = NeuralNetwork(input_size = train_data.shape[1])\n", + "NN.add_layer(16, \"relu\")\n", + "NN.add_layer(16, \"relu\")\n", + "NN.add_layer(1, \"sigmoid\")\n", + "NN.compile(loss = \"binary crossentropy\")\n", + "NN.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0f010d0a-ceef-4824-b36b-9752547248f1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training finished after epoch 1000 with a loss of 0.166389194481977.\n" + ] + } + ], + "source": [ + "hist = NN.fit(train_data, train_target, epochs = 1000, batch_size = 32, learning_rate = 0.01, verbose = 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "f4a324e2-2070-43d5-8cb6-8b07cbb69078", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAHHCAYAAACRAnNyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA+RklEQVR4nO3deXiU1d3G8XuyTfaFkBUCQUBWQWQTcH1BFChuYKtiS7HVIrGIWquUutUiuNRSa0WxFW0r4ApaKiogSlX2TRAIO4QlCRCy75nz/oEMDJlhDTmD+X6uay7I85yZ+c0DITdnexzGGCMAAAA/FGC7AAAAAF8IKgAAwG8RVAAAgN8iqAAAAL9FUAEAAH6LoAIAAPwWQQUAAPgtggoAAPBbBBUAAOC3CCoAzjmHw6EnnnjitJ+3Y8cOORwOvfHGG3VeE4DzA0EFaCDeeOMNORwOORwOffXVV7XOG2OUlpYmh8OhH/3oRxYqPHNffPGFHA6H3nvvPdulAKhjBBWggQkNDdW0adNqHf/yyy+1e/duOZ1OC1UBgHcEFaCBGThwoN59911VV1d7HJ82bZq6du2q5ORkS5UBQG0EFaCBue2223Tw4EHNnTvXfayyslLvvfeebr/9dq/PKSkp0YMPPqi0tDQ5nU61adNGzz//vI6/+XpFRYXuv/9+JSQkKCoqStdff712797t9TX37NmjO++8U0lJSXI6nerQoYNef/31uvugXmzbtk233HKLGjVqpPDwcF166aX673//W6vdX//6V3Xo0EHh4eGKi4tTt27dPHqhioqKNGbMGKWnp8vpdCoxMVHXXHONVq5ceU7rBxoiggrQwKSnp6tXr16aPn26+9icOXNUUFCgW2+9tVZ7Y4yuv/56/fnPf9Z1112nF154QW3atNFDDz2kBx54wKPtL3/5S02aNEn9+/fXxIkTFRwcrEGDBtV6zZycHF166aWaN2+e7r33Xv3lL39Rq1at9Itf/EKTJk2q88985D179+6tTz/9VKNGjdL48eNVXl6u66+/XjNnznS3e+211zR69Gi1b99ekyZN0pNPPqmLL75YS5YscbcZOXKkJk+erCFDhujll1/Wb37zG4WFhWnDhg3npHagQTMAGoSpU6caSWbZsmXmpZdeMlFRUaa0tNQYY8wtt9xirr76amOMMc2bNzeDBg1yP2/WrFlGkvnjH//o8XpDhw41DofDbNmyxRhjzOrVq40kM2rUKI92t99+u5FkHn/8cfexX/ziFyYlJcUcOHDAo+2tt95qYmJi3HVt377dSDJTp0494WdbsGCBkWTeffddn23GjBljJJn//e9/7mNFRUWmRYsWJj093dTU1BhjjLnhhhtMhw4dTvh+MTExJiMj44RtANQNelSABujHP/6xysrKNHv2bBUVFWn27Nk+h30+/vhjBQYGavTo0R7HH3zwQRljNGfOHHc7SbXajRkzxuNrY4zef/99DR48WMYYHThwwP249tprVVBQcE6GUD7++GP16NFDl112mftYZGSk7r77bu3YsUPr16+XJMXGxmr37t1atmyZz9eKjY3VkiVLtHfv3jqvE4AnggrQACUkJKhfv36aNm2aPvjgA9XU1Gjo0KFe2+7cuVOpqamKioryON6uXTv3+SO/BgQEqGXLlh7t2rRp4/H1/v37lZ+frylTpighIcHjMWLECElSbm5unXzO4z/H8bV4+xwPP/ywIiMj1aNHD7Vu3VoZGRn6+uuvPZ7z7LPPat26dUpLS1OPHj30xBNPaNu2bXVeMwApyHYBAOy4/fbbdddddyk7O1sDBgxQbGxsvbyvy+WSJN1xxx0aPny41zadOnWql1q8adeunTIzMzV79mx98sknev/99/Xyyy/rscce05NPPinpcI/U5ZdfrpkzZ+qzzz7Tc889p2eeeUYffPCBBgwYYK124IeIHhWggbrpppsUEBCgxYsX+xz2kaTmzZtr7969Kioq8ji+ceNG9/kjv7pcLm3dutWjXWZmpsfXR1YE1dTUqF+/fl4fiYmJdfERa32O42vx9jkkKSIiQj/5yU80depU7dq1S4MGDXJPvj0iJSVFo0aN0qxZs7R9+3bFx8dr/PjxdV430NARVIAGKjIyUpMnT9YTTzyhwYMH+2w3cOBA1dTU6KWXXvI4/uc//1kOh8Pdg3Dk1xdffNGj3fGreAIDAzVkyBC9//77WrduXa33279//5l8nJMaOHCgli5dqkWLFrmPlZSUaMqUKUpPT1f79u0lSQcPHvR4XkhIiNq3by9jjKqqqlRTU6OCggKPNomJiUpNTVVFRcU5qR1oyBj6ARowX0Mvxxo8eLCuvvpqjRs3Tjt27FDnzp312Wef6cMPP9SYMWPcc1Iuvvhi3XbbbXr55ZdVUFCg3r17a/78+dqyZUut15w4caIWLFignj176q677lL79u2Vl5enlStXat68ecrLyzujz/P++++7e0iO/5yPPPKIpk+frgEDBmj06NFq1KiR3nzzTW3fvl3vv/++AgIO/7+tf//+Sk5OVp8+fZSUlKQNGzbopZde0qBBgxQVFaX8/Hw1bdpUQ4cOVefOnRUZGal58+Zp2bJl+tOf/nRGdQM4AbuLjgDUl2OXJ5/I8cuTjTm8jPf+++83qampJjg42LRu3do899xzxuVyebQrKyszo0ePNvHx8SYiIsIMHjzYZGVl1VqebIwxOTk5JiMjw6SlpZng4GCTnJxs+vbta6ZMmeJuc7rLk309jixJ3rp1qxk6dKiJjY01oaGhpkePHmb27Nker/Xqq6+aK664wsTHxxun02latmxpHnroIVNQUGCMMaaiosI89NBDpnPnziYqKspERESYzp07m5dffvmENQI4Mw5jjttaEgAAwE8wRwUAAPgtggoAAPBbBBUAAOC3CCoAAMBvEVQAAIDfIqgAAAC/dV5v+OZyubR3715FRUXJ4XDYLgcAAJwCY4yKioqUmprq3mzRl/M6qOzdu1dpaWm2ywAAAGcgKytLTZs2PWGb8zqoHLntfFZWlqKjoy1XAwAATkVhYaHS0tLcP8dP5LwOKkeGe6KjowkqAACcZ05l2gaTaQEAgN8iqAAAAL9FUAEAAH6LoAIAAPwWQQUAAPgtggoAAPBbBBUAAOC3CCoAAMBvEVQAAIDfIqgAAAC/RVABAAB+i6ACAAD81nl9U8JzpbSyWnkllXIGBSohymm7HAAAGix6VLyYuz5Hlz2zQPfNWGW7FAAAGjSCygkYY7sCAAAaNoKKFw6HQ5JkRFIBAMAmgooXju9/pUcFAAC7CCpefN+hQn8KAACWEVS8cIikAgCAPyCoeHG0R4WkAgCATQQVLxwnbwIAAOoBQeUEmEwLAIBdBBUvmEwLAIB/IKh49f0+KnSpAABgFUHFC3pUAADwDwQVL9jwDQAA/0BQ8eLoFvoAAMAmgooXLE8GAMA/EFROhLEfAACsIqh4wWRaAAD8A0HFC3dQIakAAGAVQcWLIzcl5F4/AADYRVDxhh4VAAD8AkHFC/ZRAQDAPxBUvDiyjwoAALCLoHICdKgAAGCX1aBSU1OjRx99VC1atFBYWJhatmypp556yvrNAI8O/RBVAACwKcjmmz/zzDOaPHmy3nzzTXXo0EHLly/XiBEjFBMTo9GjR1uri5EfAAD8g9Wg8s033+iGG27QoEGDJEnp6emaPn26li5darOso8uT6VABAMAqq0M/vXv31vz587Vp0yZJ0po1a/TVV19pwIABNss6ZmdakgoAADZZ7VF55JFHVFhYqLZt2yowMFA1NTUaP368hg0b5rV9RUWFKioq3F8XFhaek7oY+QEAwD9Y7VF555139NZbb2natGlauXKl3nzzTT3//PN68803vbafMGGCYmJi3I+0tLRzWh9DPwAA2OUwFpe2pKWl6ZFHHlFGRob72B//+Ef9+9//1saNG2u199ajkpaWpoKCAkVHR9dZXd9sPaDbX1uiVomRmvfAlXX2ugAA4PDP75iYmFP6+W116Ke0tFQBAZ6dOoGBgXK5XF7bO51OOZ3Oc17X0cm0dKkAAGCT1aAyePBgjR8/Xs2aNVOHDh20atUqvfDCC7rzzjttlnXMZFoAAGCT1aDy17/+VY8++qhGjRql3Nxcpaam6le/+pUee+wxm2UdnUxLUgEAwCqrQSUqKkqTJk3SpEmTbJZRy5F7/ZBTAACwi3v9eMHOtAAA+AeCygkwmRYAALsIKl64b0potQoAAEBQ8cK96oekAgCAVQQVr45MpiWpAABgE0HFC3pUAADwDwQVL9xzVAgqAABYRVDxwsH6ZAAA/AJBBQAA+C2CihdHh34Y+wEAwCaCihfclBAAAP9AUPHCcWR5MkkFAACrCCpeHO1RIakAAGATQeUE6FEBAMAugooXrE4GAMA/EFROgA4VAADsIqh4wWRaAAD8A0HFi6NDPyQVAABsIqh4wU0JAQDwDwQVL9xDP5brAACgoSOoeHG0R4WoAgCATQQVL1idDACAfyConAD9KQAA2EVQ8YLJtAAA+AeCildH9lEhqQAAYBNBxYujNyUEAAA2EVS8YL83AAD8A0HFCwd3JQQAwC8QVE6ADhUAAOwiqHhxpD+FybQAANhFUPGCybQAAPgHgooX7nv9kFQAALCKoOLF0R4VkgoAADYRVE6AHhUAAOwiqHjB6mQAAPwDQeUE6FABAMAugooXDpb9AADgFwgqXrj3USGpAABgFUHFC3eHCjkFAACrCCpeuPdRsVwHAAANHUHFi6M9KkQVAABsIqh4wepkAAD8A0HlBOhPAQDALoKKN0ymBQDALxBUvHAw+AMAgF8gqHhx7Bb6TKgFAMAegooXx/ankFMAALCHoOKF45guFXIKAAD2EFS8YIYKAAD+gaByEsxRAQDAHoKKFx6Tae2VAQBAg0dQ8eLY5cl0qAAAYA9BxRuPHhWSCgAAthBUvPDcR8VeHQAANHQEFS9Y9QMAgH8gqHhx7D4qAADAHoLKSTD0AwCAPQQVLzy20GcyLQAA1hBUvGAyLQAA/oGg4oXHPioW6wAAoKEjqHjh2aNCVAEAwBaCCgAA8FsElZOgPwUAAHsIKl4wmRYAAP9AUPHCIW6fDACAPyCoeOHgpoQAAPgFgooXHhu+kVMAALCGoOLFsff6IacAAGCP9aCyZ88e3XHHHYqPj1dYWJguuugiLV++3GpN3JIQAAD/EGTzzQ8dOqQ+ffro6quv1pw5c5SQkKDNmzcrLi7OZlke2PANAAB7rAaVZ555RmlpaZo6dar7WIsWLSxWdJiDRT8AAPgFq0M/H330kbp166ZbbrlFiYmJ6tKli1577TWbJUk6bo4KSQUAAGusBpVt27Zp8uTJat26tT799FPdc889Gj16tN58802v7SsqKlRYWOjxONdYngwAgD1Wh35cLpe6deump59+WpLUpUsXrVu3Tq+88oqGDx9eq/2ECRP05JNP1kttDsf3vSnkFAAArLHao5KSkqL27dt7HGvXrp127drltf3YsWNVUFDgfmRlZZ2z2o4M/pBTAACwx2qPSp8+fZSZmelxbNOmTWrevLnX9k6nU06nsz5KOzxPhQkqAABYZbVH5f7779fixYv19NNPa8uWLZo2bZqmTJmijIwMm2V5IKsAAGCP1aDSvXt3zZw5U9OnT1fHjh311FNPadKkSRo2bJjNsiQdO/RDUgEAwBarQz+S9KMf/Ug/+tGPbJdRy5EVyvSoAABgj/Ut9P2V4/s+FXIKAAD2EFR8cfeoEFUAALCFoOKDe44KOQUAAGsIKj44uIUyAADWEVQAAIDfIqj44J5My9APAADWEFR8cC9PZt0PAADWEFR8YDItAAD2EVR8cDjYRwUAANsIKj4c7VEhqgAAYAtBxReWJwMAYB1B5SToTwEAwB6Cig9MpgUAwD6Cig8O99a0JBUAAGwhqPjg3keFnAIAgDUEFR/oTwEAwD6Cig8O7koIAIB1BJWTYOgHAAB7CCo+HB36IakAAGALQcUHJtMCAGAfQcWn7+/1Q1ABAMAagooP7h4Vhn4AALCGoOIDO9MCAGAfQcUHVicDAGAfQQUAAPgtgooPDibTAgBgHUHFBybTAgBgH0HFBybTAgBgH0HFhyP3+iGnAABgD0HlJAxdKgAAWENQ8YHlyQAA2EdQOQn6UwAAsIeg4gM3JQQAwD6Cig+Oo+t+rNYBAEBDRlDxgR4VAADsI6j4QH8KAAD2EVR8cO+jQlIBAMAagooPrE4GAMA+gspJsOEbAAD2EFR8cd+UEAAA2EJQ8YGbEgIAYB9BxYejNyUkqQAAYAtBxQf3ZFpyCgAA1hBUfHAwRwUAAOvOKKhkZWVp9+7d7q+XLl2qMWPGaMqUKXVWmG0OFigDAGDdGQWV22+/XQsWLJAkZWdn65prrtHSpUs1btw4/eEPf6jTAm1jMi0AAPacUVBZt26devToIUl655131LFjR33zzTd666239MYbb9RlfdYcHfohqQAAYMsZBZWqqio5nU5J0rx583T99ddLktq2bat9+/bVXXV+gB4VAADsOaOg0qFDB73yyiv63//+p7lz5+q6666TJO3du1fx8fF1WqAtR5cnAwAAW84oqDzzzDN69dVXddVVV+m2225T586dJUkfffSRe0jofHd0wzeiCgAAtgSdyZOuuuoqHThwQIWFhYqLi3Mfv/vuuxUeHl5nxdnkYNEPAADWnVGPSllZmSoqKtwhZefOnZo0aZIyMzOVmJhYpwXaRn8KAAD2nFFQueGGG/TPf/5TkpSfn6+ePXvqT3/6k2688UZNnjy5Tgu0xd2jQlIBAMCaMwoqK1eu1OWXXy5Jeu+995SUlKSdO3fqn//8p1588cU6LdCWIxu+sTwZAAB7ziiolJaWKioqSpL02Wef6eabb1ZAQIAuvfRS7dy5s04LtMW9jwo5BQAAa84oqLRq1UqzZs1SVlaWPv30U/Xv31+SlJubq+jo6Dot0Jajq36slgEAQIN2RkHlscce029+8xulp6erR48e6tWrl6TDvStdunSp0wKtYR8VAACsO6PlyUOHDtVll12mffv2ufdQkaS+ffvqpptuqrPibGJ1MgAA9p1RUJGk5ORkJScnu++i3LRp0x/MZm/HYsM3AADsOaOhH5fLpT/84Q+KiYlR8+bN1bx5c8XGxuqpp56Sy+Wq6xqtOHpTQgAAYMsZ9aiMGzdO//jHPzRx4kT16dNHkvTVV1/piSeeUHl5ucaPH1+nRdrAZFoAAOw7o6Dy5ptv6u9//7v7rsmS1KlTJzVp0kSjRo36YQQVdnwDAMC6Mxr6ycvLU9u2bWsdb9u2rfLy8s66KH9AjwoAAPadUVDp3LmzXnrppVrHX3rpJXXq1Omsi/IHzFEBAMC+Mxr6efbZZzVo0CDNmzfPvYfKokWLlJWVpY8//rhOC7TFwQJlAACsO6MelSuvvFKbNm3STTfdpPz8fOXn5+vmm2/Wd999p3/96191XaNVDP0AAGDPGe+jkpqaWmvS7Jo1a/SPf/xDU6ZMOevCrHMP/ZBUAACw5Yx6VBoCJtMCAGCf3wSViRMnyuFwaMyYMbZLkcRkWgAA/IFfBJVly5bp1Vdf9asVQ0cm07KFPgAA9pzWHJWbb775hOfz8/NPu4Di4mINGzZMr732mv74xz+e9vPPFQeLfgAAsO60gkpMTMxJz//sZz87rQIyMjI0aNAg9evX76RBpaKiQhUVFe6vCwsLT+u9TgdBBQAA+04rqEydOrVO33zGjBlauXKlli1bdkrtJ0yYoCeffLJOazgZRn4AALDH2hyVrKws3XfffXrrrbcUGhp6Ss8ZO3asCgoK3I+srKxzVp97jgrTaQEAsOaM91E5WytWrFBubq4uueQS97GamhotXLhQL730kioqKhQYGOjxHKfTKafTWS/1uVf9kFMAALDGWlDp27ev1q5d63FsxIgRatu2rR5++OFaIcUWggoAAPZYCypRUVHq2LGjx7GIiAjFx8fXOm6Dw3Fk6AcAANjiF/uo+KOjO9MSVQAAsMVaj4o3X3zxhe0S3FieDACAffSonAT9KQAA2ENQ8cHdoUJSAQDAGoKKD0cn05JUAACwhaDiw9HJtFbLAACgQSOo+ODe8M1uGQAANGgEFZ9Y9gMAgG0ElZNg6AcAAHsIKj4cHfohqQAAYAtBxQcm0wIAYB9BxQcm0wIAYB9BxQfHkT4VulQAALCGoOIDPSoAANhHUPGBmxICAGAfQeUkGPkBAMAegooPR+aoGJIKAADWEFR8YY4KAADWEVR8YB8VAADsI6j44Ph+Ni05BQAAewgqPhztUSGqAABgC0HFB5YnAwBgH0EFAAD4LYKKD0ymBQDAPoKKD0cn05JUAACwhaDiAz0qAADYR1DxhQ3fAACwjqDiw9Et9C0XAgBAA0ZQ8YHlyQAA2EdQOQkm0wIAYA9BxQcm0wIAYB9BxQeGfgAAsI+g4sPRybR0qQAAYAtBxYcjPSrkFAAA7CGo+OBgHxUAAKwjqPjEJBUAAGwjqJwEQz8AANhDUPHh6NAPSQUAAFsIKj6wjwoAAPYRVHxgMi0AAPYRVHxwiPXJAADYRlDxgZ1pAQCwj6ByEvSnAABgD0HFBybTAgBgH0HFB8f3Yz8sTwYAwB6CyknQowIAgD0EFR9YngwAgH0EFR+OLE+mRwUAAHsIKj6wPBkAAPsIKifBZFoAAOwhqPjg7lAhpwAAYA1BxYcjQz9fbz2glbsO2S0GAIAGiqDiw5F9VNbtKdTNL39juRoAABomgooPx8+lNSz/AQCg3hFUfDkuqZBTAACofwQVHxzHJRVyCgAA9Y+gcopcdKkAAFDvCCo+HL/hGzkFAID6R1Dx4fjJtPSoAABQ/wgqPrCFPgAA9hFUfDh+Mi09KgAA1D+Cig/MUQEAwD6Cig/MUQEAwD6CyikipgAAUP8IKr4cN/ZjXJbqAACgASOo+FDrXj/0qQAAUO8IKj4cP5nWRU4BAKDeEVR8qHWvHybTAgBQ7wgqPtCjAgCAfQQVH5ijAgCAfVaDyoQJE9S9e3dFRUUpMTFRN954ozIzM22W5BMjPwAA1D+rQeXLL79URkaGFi9erLlz56qqqkr9+/dXSUmJzbIksTMtAAD+IMjmm3/yySceX7/xxhtKTEzUihUrdMUVV1iq6jCHg3v9AABgm9WgcryCggJJUqNGjbyer6ioUEVFhfvrwsLCeqlLYmdaAABs8JvJtC6XS2PGjFGfPn3UsWNHr20mTJigmJgY9yMtLe2c1VNr1Q/LfgAAqHd+E1QyMjK0bt06zZgxw2ebsWPHqqCgwP3Iyso6Z/Ucv48KAACof34x9HPvvfdq9uzZWrhwoZo2beqzndPplNPprMfKjmKOCgAA9c9qUDHG6Ne//rVmzpypL774Qi1atLBZjgdW/QAAYJ/VoJKRkaFp06bpww8/VFRUlLKzsyVJMTExCgsLs1larYEfelQAAKh/VueoTJ48WQUFBbrqqquUkpLifrz99ts2y5LkpUfFThkAADRo1od+/BU3JQQAwD6/WfXjb5ijAgCAfQQVHwJq7UxrqRAAABowgooPwYHHDf0wSwUAgHpHUPEhMMDz0rhclgoBAKABI6j4EESPCgAA1hFUfKg19ENOAQCg3hFUfDh+6IegAgBA/SOo+BAccPyqH5IKAAD1jaDiQ2DA8XNUAABAfSOo+BAcePzQD1EFAID6RlDx4fhVP2z4BgBA/SOo+HD80A+DPwAA1D+Cig/HD/3QowIAQP0jqPhQazItQQUAgHpHUPEh+Pgt9EkqAADUO4KKD7W20CenAABQ7wgqPgTVGvohqQAAUN8IKj4EHb+PiqU6AABoyAgqPhzfo8IcFQAA6h9BxQfmqAAAYB9BxYcgVv0AAGAdQcWHWpNpLdUBAEBDRlDxofbQD1EFAID6RlDxofbdky0VAgBAA0ZQ8eH4LfS51w8AAPWPoOIDG74BAGAfQcUHh4MeFQAAbCOonDKSCgAA9Y2gcoroUQEAoP4RVE4RU1QAAKh/BJVTxM60AADUP4LKKSKmAABQ/wgqp4jlyQAA1D+CyikipwAAUP8IKqeIOSoAANQ/gsopIqcAAFD/CCqniB4VAADqH0HlFBFTAACofwSVU8SqHwAA6h9B5RSRUwAAqH8ElVPEvX4AAKh/BJVTZJilAgBAvSOonKJje1RcLqPcwnJ7xQAA0EAQVE7VMZNUxs1aqx5Pz9fnG3MsFgQAwA9fkO0Czhf7Cso1Z+0+bT9YoulLsyRJf5m/Rf/XNslyZQAA/HARVE7Ry19srXUsJizYQiUAADQcDP2cBYIKAADnFkHlLMQSVAAAOKcIKmfBGcTlAwDgXOIn7VmorHHZLgEAgB80gspZqKwmqAAAcC4RVE7g3ZG91C4lWi/8uLPX8xUEFQAAzimCygl0T2+kOfddrqvbJHo9P3PVHs1ataeeqwIAoOEgqJyCCKfv7WbGvL1aheVV+vS7bJVX1dRjVQAA/PARVE5ByElW92S8tVK/+tcK/fXzzR7HK6tdMoabGQIAcKYIKnXgf5sPSJJmfL+1viTll1aq14T5GvnvFbbKAgDgvEdQqUNHhoiMMfpsfY4OllTq0+9yVOOiVwUAgDNBUDlF13VIPmmbXXmlWpCZqxZjP9Zv3/vWfXxvfplW7jqkKi/7rhRXVGvlrkO1jheUVan6+/Yul1FZJfNfAAANj8Ocx5MoCgsLFRMTo4KCAkVHR5/T96qucamwvFrPfrJRM5ZlnfwJxxh0UYr+u3afft47XU9c38H9egEOh0a8sUxfbtqv8Td11LCezSVJG/YV6sa/fa0eLRqpoKxK3+4ukCT96soLNHZAu5O+nzFGDodDb36zQ03jwtS3Xf3f4XnFzkMqr6pRn1aN6/29AQD+7XR+fhNUzsDKXYd088vfnNFzb76kiapqjDbuK9Tm3GKPc7+4rIV+P6idbnllkZbvrN3LIkmv3NFV0aFB+sv8zbqlW5oOFleoSVyYHnh7jV4edom+2JSr+Rty9adbOuv2vy+RJG2fMFAOh8P9Gltyi7R+X5EGd0rxOH4mSiqqlV9WpSaxYe5jVTUutR43R5K06tFrFBcRclbvAQD4YTmdn9++193Cp0uaxemVOy7R9gOl2nmwRB2bxOj3s9ad0nM/WOl735V/fLVdLRMilZld5LPN+yt3a+76HEnSku15Hud++c/l7t+/tGCL+/e5RRVKig51f93/zwvlMoff7+Vhl3iEjO0HSrQ1t1jd0xspJtzzpotz1u5T+9RoNY+PcB8b9vclWp2Vr/dG9lKzRuFKjA7VnkNl7vM5ReUKCnToj7M36Mfd09S1eZz73JbcIjVrFOFzVVVeSaUmztmgH3dLU7f0Rj6viQ3b9hfrjW92aOSVLZV6zPUDANtcLqOAgLP7T6g/oUelDuzNL1PviZ+7v+7fPknf7S3UyKta6tFTCDB//1k3Pf9ZpjaeIKCcjX7tkjS0a1OfK5CiQ4M0pGtTTf16R61zP+vVXBlXt9KkeZs0fWmWQgID9MzQi1RcUaPZa/Z6hKWEKKe++M1VGv/xBk1bskuSNLxXc5VXufT28sPDZSsfvUYhQQH6essB/epfK3R7z2Z6YnAHBQY4FPj9N9aRoauH3l2jd1fsliTtmDhIuw6WKjYiWNGh9XvX6vKqGi3ctF9XtUl0h6orn1ugnQdL1bNFI739q15akJmrgtIq3dilSb3WdiLGGO04WKrmjcJ/UP9oncj2AyVyGaOWCZG2SwGsGPvBWs1dn6NPxlyuxpFO2+X4xNCPBev2FKiovFodmkR7/CD9aM1eLd52UI8Oaq9J8zdp7e4CfbP1oPv8H27ooJ/1Sld+aaV6Pj3f67b8V7dJUExYsGat3lsvn+VsNIkN0578spM39OJHnVLUu2VjPfrhOk36ycV6cf5m9/BY75bx+mbrQbVOjNTs0ZfJGCk0OFA5heX67XvfyhkUoCvbJLjn+RhjNGneZoUEBeieK1tq24FiPf3xRt11+QXq1TLe43035xQpNDhQEc4gPffpRt3UpaneX7FbLRIiFBcerIffXytJuuLCBD1/SyfNXLlHE+ZsdD//v6Mv06AXv5IkzX/wyhP+kMwvrZQkxYaHqKrGpeDAs5vPvnZ3gVzGqE1ylEKDA93HK6pr9NHqvXrovW/1yIC2GnllS0nSgsxcPfzet3ruls668sIE7c0vU3FFtS5Miqr12i6X0SffZatdSrTySyv1h9nr9YfrO6pxVIicQYFqdNyQ3t78MpVX1eiCswgJWXml2pxbpKvbJJ72sGR5VY3aPvqJJGntE/0VdRqB1hij0soaj80dSyqqFeEM0q6DpdqUU6S+7U6/pvOFMUa/fe9b1Rije65sqZe/2KqL02L1s17NT/kz7z5Uqo37itSnVWP9Z81eDeyUoshjrueKnYcU6QxSm+Taf9d8Ka+qkcMhOYMCT974e4dKKvXKwq0a1qO55m7IUcfUaPW8IP7kTzxNK3YeUnJMqEePtCTlFpbrif98p19cdoFHD/K5UuMyCnDI/eeU/sh/JUm/G9hWd1/R8pRew+Uyqqh2KSzk1K/z2SKo+DFjjGau2qPu6Y0UFhKo+IgQ91+wb7Yc0C//uVwBDoceG9xer3+1Xb+8/AIN7dpU0uFvjCGTj86NaZ0Y6f5BPuHmi7Ry5yF3D8TZuLpNghZk7j/r17GhbXKUOqTG6LPvslVUUe21zaCLUuQyRl9vOaDC8qNtziZkHevxwe21cle+duWV6srWjRUdFqzm8RFauGm//rV4p6TDvViS9Nvr2iosOFCHSivVPiVan63P0d78Mn32/fDe7wa2VYDDoY/W7JUzKEDLdx5S37aJahoXLmdwgF79clut9++QGq3v9hZ6HAsPCdSt3Zvp9a+3u4/98caO+v2sdQoKcOjpmy5S00Zh6pAao50HS7S/qEJfbzno0f5YQQEOVX+/7P6SZrHaur9EBWVVkqRnh3TSf77dq14t49WpSay+3npAnZvGKr+0Uh2bxCgmLFh/W7BFwYEB6tOqsTqkRquovFo5heUa+8FaZReW65eXtdC2AyWKDQ/WuIHtdKi0UpnZxfrL/E363cB26tUyXvmlVbp1ymJtP1CiSGeQio/7874gIUIy0tQR3ZUaG6bv9haqZUKEVu7KV0FZlfq1S9R3ewsVERKkn72+RAeKK9WxSbT+/rPuWrh5v3773re6s08LLduRp7V7CnRN+ySNG9hO6/YWKDgwQONmrtONF6dqwEXJCgwI0Lo9Bbq9RzNVu4z+uWiHmsaF68KkSG0/UKJpS3apotqlkspqtU6MVExYsGLCgjVjWZbaJEVp3KB2WrHzkJ7/LFMTh3RSh5Ro5ZVWqkXjCJVV1ujF+Vu042CJuqXH6ZauaUqIcmpzTpGiw4L11uKdahIXppsvaaq3Fu/Ush2HlF1Yrt/0b6PkmFB9uztf05bs0q68Uu0rKJckpcaE6tYezTTyypbalVeqiuoad9g+1ph+rbVtf4keuraNEqKcevObHZqycJsev76DGoWHqFVipGLDg+UMCtClE+Yrp7BCIYEBqqxxaUDHZP300ua6ICFSxRVV6vfCQklSXHiwpt99qfbml+n3M9fpJ92b6c7L0hXpDFK1y+jPczdpQeZ+pceHa+WuQ6pxGY3pd6G+yMzVqKtb6V+LdmrUVS3V+phw7XIZLd5+UOVVNZq2ZJfmbcj1+Bw7Jg5y/76sskblVTX6assBdU9vpC25xfr0u2xdmBSp4MAAXdcxWbHhIcopLFdOYbk6NY11Pze/tFKTv9yqS5rF6Vf/WqHY8GCt+P01CnBILiMFBjh05xvL9PnGw+//yh1ddW2HJLnM4RWe0aFBcjgcyiks16pdhxQeEqTU2FA5HA6VVFQrLjxElTUuXdA4Qquy8uVyGU3+Yquu65isSy+IV3BggGLCglXtcikqNFgrdh7S7a8tVs8L4pUQ6dQ17ZPcPee39UhTREiQNmYXaezAtmqXfPjfhanfbNdD17ZRdY1RYXmVjJHeXpald5Zn6Z1f9VJocKBeXbhVD/ZvUyuE1SWCyg9YQWmV/rV4h27v2VyNIkJUVlkjZ1CAu2t/6/5ird6Vr/TG4YpwBunCxCgdKq3UjGVZ6pIWqw3ZRfpw9R799tq22nagWPmlVVq09aAWbTvcy/PskE76cfc0SdI3Ww/o9teWuN87OTpUL97WRTOW7fKYaxPpDFJilFPbDpTI4ZDO379RgH8ICw5UmY9bcoQEBig40KGSH9iWBSFBAad9R/qo0CDVuA73hp1M2+QoORwObdhXeNK2Z1JfdGiQUmLClJlzdkP4wYEOdU9v5NHzblPHJtGa9JOL1Srx1HvCTgVBBaclv7RSOw+WqnNa7Cm1N8aosKxa4c5A5ZVUuifqHiyukCTlFFZoV16JIp3B6tIsVrvyStUuJVp5JZVak5WvEW8sk3T4f+Kv/LSrqmuM3vxmh15deLh3YOSVLZVdUCaHw6EeLRrpf5v3KzEqVEO7NtX+4gq9OH+zVu3K96ipXUq0NmYXyhgpPiJE3dMbKTQ4QD0viNd/1uzVppwiVbuMOqbG6GBJpXq3jFdBWZXeO6YHqtcF8XIZo6y8UvVtl6S563OUXVh+wmvRtXmcCsqqtOWYFVwtEyLUPD7C/b+q412cFquo0CBtyS1WWVWN4sJDFBcerJXHfSZJinIGqaiiWm2ToxQY4KjVU+JNUrRTOYUVXs8FBzoUFRqsjk1itDuvVNsOlEiSUmJCVVJR7dHDJEkRIYEKCQrQodIqj+PXdUjWJ99ln7SWYzWODFF+aZW7J6auDeiYrMycIm3bX3JOXt8bgjkagjZJUZpz3+V1OteNoILzUkV1jSqrXac8t+DIpNuzUVpZrbDgQK+vU1ZZoze+2aHLWzdW+5ToM/om/WbL4dsrtEyMVKQzSOVVNYr3McGtoLRKq3fnq6KqRte0T5LD4VBpZbU25xSrXUq0eyJvdY1L+4srFB0arLDgQJVX1yg0KNCjPmOMNuwrUllVjTqkRmvxtoNqmRCp+MgQhYccnTfgchk5jhnfLiqv0pqsArVPjZbLGPdkPJfLaF9huZxBAYoICfI5lr1sR54ckjo1jZWRkfm+yzs4IEAx4cEqrqhWYVmVIpxBOlhcoUhnkGZ/u0/pjcN11YWJyikqV1x4iHILKxQTHqz1ewu1cPN+Dbmkiapdh19vf1GFAhwOFZRVqVfL+FpzZaTD4/YHiiu0dneBHA4pIMChjqkxWrhpvzqnxcgZFKhZq/boqjaJKq+uUWxYsEKDA9UkNkyVNS7tL6rQltxiVVS7FBTgUKemMdqTX6bsgnK1TIzUnvwyXdk6QQEBDhWVV2np9jwt2npQHZvEKDosSM0aRWjb/mKt31eon3RPU15JpZbvODynYdv+ErVOjFTHJjGHe0WrarT7UKmWbc9TXmmVdhw4PMTTKiFSO/NKVVxerbzSSjWJDVO/dknafqDk8N/bkECt2HFIndNilZldpLiIEF3TPkmzv92rfy3aqYQop7o0izs8f+vCBDWKCJHDISVEOrW/uEKZ2UUyRjpYUqngQIeqa4z6tktUcGCAVuw8pNVZ+crKK9XqrHx1TovVTV2aKCYsWB1So/Xt7gLtL6rQ/uIKXdM+Sat25Sszu1CrduUrKNChPq0aq7yqRgM6pmjmqj1atPWgkmNClRwTqrLKGi3bkacrLkxQs0bhWr7jkC5IiFBVjUsrdh7SoItSNLRrU23YV6SdeSXauK/IPTyVXVih4ACHel7QSOVVLuUWVSi/tFKJUaFKjz/ci9w8PlxXXpig4opqzd+QqxfmbtLNlzRRi8YRKq9yqWOTaG3MLtLMlXs0qFOKOjWNUfNGEQoOcmjJtjwt2nZQuw6WqnereA28KEVz1u7T4m156tgkRrlF5Qp0ONQkLkxF5dXKK6lUZY1LXZvFqbC8Sku25am0qkYdUw9/v+aVVKq0skahwQF6a8kuGXN46HnHwRJ1ahqri9Ni5DLS5xtz1blpjBZuPqCIkEDVGKnG5VKzRoeHjBOjnbrx4iZqlxKtCGegXC7plYVblRoTqnuvbq1Pv8tWRY1LReVVKq9yadqSnTpQXKkLEiJ0e49m6tIsTvM25Ki6xqXOabFauGm/Plufo/9rk6imcWGSw6GuzeO0bk+B3vhmh/YXVejBay7U3oIyzf52n975VS+1S6nbn7EEFQAA4LdO5+e3X2yh/7e//U3p6ekKDQ1Vz549tXTpUtslAQAAP2A9qLz99tt64IEH9Pjjj2vlypXq3Lmzrr32WuXmeh/fBwAADYf1oPLCCy/orrvu0ogRI9S+fXu98sorCg8P1+uvv267NAAAYJnVoFJZWakVK1aoX79+7mMBAQHq16+fFi1aVKt9RUWFCgsLPR4AAOCHy2pQOXDggGpqapSU5Hl336SkJGVn1176OGHCBMXExLgfaWlp9VUqAACwwPrQz+kYO3asCgoK3I+srCzbJQEAgHPI6t2TGzdurMDAQOXk5Hgcz8nJUXJycq32TqdTTqf/3mQJAADULas9KiEhIeratavmz5/vPuZyuTR//nz16tXLYmUAAMAfWO1RkaQHHnhAw4cPV7du3dSjRw9NmjRJJSUlGjFihO3SAACAZdaDyk9+8hPt379fjz32mLKzs3XxxRfrk08+qTXBFgAANDxsoQ8AAOrVebeFPgAAgDcEFQAA4LcIKgAAwG9Zn0x7No5Mr2ErfQAAzh9Hfm6fyjTZ8zqoFBUVSRJb6QMAcB4qKipSTEzMCduc16t+XC6X9u7dq6ioKDkcjjp97cLCQqWlpSkrK4sVRecQ17l+cJ3rD9e6fnCd68e5us7GGBUVFSk1NVUBASeehXJe96gEBASoadOm5/Q9oqOj+SaoB1zn+sF1rj9c6/rBda4f5+I6n6wn5Qgm0wIAAL9FUAEAAH6LoOKD0+nU448/zt2azzGuc/3gOtcfrnX94DrXD3+4zuf1ZFoAAPDDRo8KAADwWwQVAADgtwgqAADAbxFUAACA3yKoePG3v/1N6enpCg0NVc+ePbV06VLbJZ1XJkyYoO7duysqKkqJiYm68cYblZmZ6dGmvLxcGRkZio+PV2RkpIYMGaKcnByPNrt27dKgQYMUHh6uxMREPfTQQ6qurq7Pj3JemThxohwOh8aMGeM+xnWuG3v27NEdd9yh+Ph4hYWF6aKLLtLy5cvd540xeuyxx5SSkqKwsDD169dPmzdv9niNvLw8DRs2TNHR0YqNjdUvfvELFRcX1/dH8Ws1NTV69NFH1aJFC4WFhally5Z66qmnPO4Hw7U+fQsXLtTgwYOVmpoqh8OhWbNmeZyvq2v67bff6vLLL1doaKjS0tL07LPP1s0HMPAwY8YMExISYl5//XXz3XffmbvuusvExsaanJwc26WdN6699lozdepUs27dOrN69WozcOBA06xZM1NcXOxuM3LkSJOWlmbmz59vli9fbi699FLTu3dv9/nq6mrTsWNH069fP7Nq1Srz8ccfm8aNG5uxY8fa+Eh+b+nSpSY9Pd106tTJ3Hfffe7jXOezl5eXZ5o3b25+/vOfmyVLlpht27aZTz/91GzZssXdZuLEiSYmJsbMmjXLrFmzxlx//fWmRYsWpqyszN3muuuuM507dzaLFy82//vf/0yrVq3MbbfdZuMj+a3x48eb+Ph4M3v2bLN9+3bz7rvvmsjISPOXv/zF3YZrffo+/vhjM27cOPPBBx8YSWbmzJke5+vimhYUFJikpCQzbNgws27dOjN9+nQTFhZmXn311bOun6BynB49epiMjAz31zU1NSY1NdVMmDDBYlXnt9zcXCPJfPnll8YYY/Lz801wcLB599133W02bNhgJJlFixYZYw5/YwUEBJjs7Gx3m8mTJ5vo6GhTUVFRvx/AzxUVFZnWrVubuXPnmiuvvNIdVLjOdePhhx82l112mc/zLpfLJCcnm+eee859LD8/3zidTjN9+nRjjDHr1683ksyyZcvcbebMmWMcDofZs2fPuSv+PDNo0CBz5513ehy7+eabzbBhw4wxXOu6cHxQqatr+vLLL5u4uDiPfzcefvhh06ZNm7OumaGfY1RWVmrFihXq16+f+1hAQID69eunRYsWWazs/FZQUCBJatSokSRpxYoVqqqq8rjObdu2VbNmzdzXedGiRbrooouUlJTkbnPttdeqsLBQ3333XT1W7/8yMjI0aNAgj+spcZ3rykcffaRu3brplltuUWJiorp06aLXXnvNfX779u3Kzs72uM4xMTHq2bOnx3WOjY1Vt27d3G369eungIAALVmypP4+jJ/r3bu35s+fr02bNkmS1qxZo6+++koDBgyQxLU+F+rqmi5atEhXXHGFQkJC3G2uvfZaZWZm6tChQ2dV43l9U8K6duDAAdXU1Hj8oy1JSUlJ2rhxo6Wqzm8ul0tjxoxRnz591LFjR0lSdna2QkJCFBsb69E2KSlJ2dnZ7jbe/hyOnMNhM2bM0MqVK7Vs2bJa57jOdWPbtm2aPHmyHnjgAf3ud7/TsmXLNHr0aIWEhGj48OHu6+TtOh57nRMTEz3OBwUFqVGjRlznYzzyyCMqLCxU27ZtFRgYqJqaGo0fP17Dhg2TJK71OVBX1zQ7O1stWrSo9RpHzsXFxZ1xjQQVnFMZGRlat26dvvrqK9ul/OBkZWXpvvvu09y5cxUaGmq7nB8sl8ulbt266emnn5YkdenSRevWrdMrr7yi4cOHW67uh+Wdd97RW2+9pWnTpqlDhw5avXq1xowZo9TUVK51A8bQzzEaN26swMDAWqsicnJylJycbKmq89e9996r2bNna8GCBWratKn7eHJysiorK5Wfn+/R/tjrnJyc7PXP4cg5HB7ayc3N1SWXXKKgoCAFBQXpyy+/1IsvvqigoCAlJSVxnetASkqK2rdv73GsXbt22rVrl6Sj1+lE/24kJycrNzfX43x1dbXy8vK4zsd46KGH9Mgjj+jWW2/VRRddpJ/+9Ke6//77NWHCBElc63Ohrq7pufy3hKByjJCQEHXt2lXz5893H3O5XJo/f7569eplsbLzizFG9957r2bOnKnPP/+8Vndg165dFRwc7HGdMzMztWvXLvd17tWrl9auXevxzTF37lxFR0fX+qHRUPXt21dr167V6tWr3Y9u3bpp2LBh7t9znc9enz59ai2v37Rpk5o3by5JatGihZKTkz2uc2FhoZYsWeJxnfPz87VixQp3m88//1wul0s9e/ash09xfigtLVVAgOePpcDAQLlcLklc63Ohrq5pr169tHDhQlVVVbnbzJ07V23atDmrYR9JLE8+3owZM4zT6TRvvPGGWb9+vbn77rtNbGysx6oInNg999xjYmJizBdffGH27dvnfpSWlrrbjBw50jRr1sx8/vnnZvny5aZXr16mV69e7vNHls3279/frF692nzyyScmISGBZbMnceyqH2O4znVh6dKlJigoyIwfP95s3rzZvPXWWyY8PNz8+9//dreZOHGiiY2NNR9++KH59ttvzQ033OB1eWeXLl3MkiVLzFdffWVat27doJfMejN8+HDTpEkT9/LkDz74wDRu3Nj89re/dbfhWp++oqIis2rVKrNq1Sojybzwwgtm1apVZufOncaYurmm+fn5Jikpyfz0pz8169atMzNmzDDh4eEsTz5X/vrXv5pmzZqZkJAQ06NHD7N48WLbJZ1XJHl9TJ061d2mrKzMjBo1ysTFxZnw8HBz0003mX379nm8zo4dO8yAAQNMWFiYady4sXnwwQdNVVVVPX+a88vxQYXrXDf+85//mI4dOxqn02natm1rpkyZ4nHe5XKZRx991CQlJRmn02n69u1rMjMzPdocPHjQ3HbbbSYyMtJER0ebESNGmKKiovr8GH6vsLDQ3HfffaZZs2YmNDTUXHDBBWbcuHEeS1651qdvwYIFXv9NHj58uDGm7q7pmjVrzGWXXWacTqdp0qSJmThxYp3U7zDmmC3/AAAA/AhzVAAAgN8iqAAAAL9FUAEAAH6LoAIAAPwWQQUAAPgtggoAAPBbBBUAAOC3CCoAznsOh0OzZs2yXQaAc4CgAuCs/PznP5fD4aj1uO6662yXBuAHIMh2AQDOf9ddd52mTp3qcczpdFqqBsAPCT0qAM6a0+lUcnKyx+PIHVMdDocmT56sAQMGKCwsTBdccIHee+89j+evXbtW//d//6ewsDDFx8fr7rvvVnFxsUeb119/XR06dJDT6VRKSoruvfdej/MHDhzQTTfdpPDwcLVu3VofffSR+9yhQ4c0bNgwJSQkKCwsTK1bt64VrAD4J4IKgHPu0Ucf1ZAhQ7RmzRoNGzZMt956qzZs2CBJKikp0bXXXqu4uDgtW7ZM7777rubNm+cRRCZPnqyMjAzdfffdWrt2rT766CO1atXK4z2efPJJ/fjHP9a3336rgQMHatiwYcrLy3O///r16zVnzhxt2LBBkydPVuPGjevvAgA4c3Vya0MADdbw4cNNYGCgiYiI8HiMHz/eGHP4btojR470eE7Pnj3NPffcY4wxZsqUKSYuLs4UFxe7z//3v/81AQEBJjs72xhjTGpqqhk3bpzPGiSZ3//+9+6vi4uLjSQzZ84cY4wxgwcPNiNGjKibDwygXjFHBcBZu/rqqzV58mSPY40aNXL/vlevXh7nevXqpdWrV0uSNmzYoM6dOysiIsJ9vk+fPnK5XMrMzJTD4dDevXvVt2/fE9bQqVMn9+8jIiIUHR2t3NxcSdI999yjIUOGaOXKlerfv79uvPFG9e7d+4w+K4D6RVABcNYiIiJqDcXUlbCwsFNqFxwc7PG1w+GQy+WSJA0YMEA7d+7Uxx9/rLlz56pv377KyMjQ888/X+f1AqhbzFEBcM4tXry41tft2rWTJLVr105r1qxRSUmJ+/zXX3+tgIAAtWnTRlFRUUpPT9f8+fPPqoaEhAQNHz5c//73vzVp0iRNmTLlrF4PQP2gRwXAWauoqFB2drbHsaCgIPeE1XfffVfdunXTZZddprfeektLly7VP/7xD0nSsGHD9Pjjj2v48OF64okntH//fv3617/WT3/6UyUlJUmSnnjiCY0cOVKJiYkaMGCAioqK9PXXX+vXv/71KdX32GOPqWvXrurQoYMqKio0e/Zsd1AC4N8IKgDO2ieffKKUlBSPY23atNHGjRslHV6RM2PGDI0aNUopKSmaPn262rdvL0kKDw/Xp59+qvvuu0/du3dXeHi4hgwZohdeeMH9WsOHD1d5ebn+/Oc/6ze/+Y0aN26soUOHnnJ9ISEhGjt2rHbs2KGwsDBdfvnlmjFjRh18cgDnmsMYY2wXAeCHy+FwaObMmbrxxhttlwLgPMQcFQAA4LcIKgAAwG8xRwXAOcXoMoCzQY8KAADwWwQVAADgtwgqAADAbxFUAACA3yKoAAAAv0VQAQAAfougAgAA/BZBBQAA+C2CCgAA8Fv/D/IzotoErG9XAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_history(hist);" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "8b581b00-8d8e-4ea6-80c1-8f11dfbad692", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training accuracy: 0.8869346733668342\n", + "Test accuracy: 0.8888888888888888\n" + ] + } + ], + "source": [ + "train_predictions = np.round(NN.predict(train_data))\n", + "print(\"Training accuracy: \", accuracy_score(train[\"label\"].to_numpy(), train_predictions))\n", + "test_predictions = np.round(NN.predict(test_data))\n", + "print(\"Test accuracy: \", accuracy_score(test[\"label\"].to_numpy(), test_predictions))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b49a45e0-9906-4a34-912c-f3ec10ea2fa2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}