-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
25 lines (20 loc) · 889 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import numpy as np
import pathlib
# Number of label classes; labels are one-hot encoded into this many columns.
_MNIST_NUM_CLASSES = 10


def _default_mnist_path():
    """
    builds the location of the bundled archive
    :return: path of data/mnist.npz inside the directory containing this module
    """
    return pathlib.Path(__file__).parent.absolute() / "data" / "mnist.npz"


def _load_mnist_split(images_key, labels_key, path=None):
    """
    loads one split (images + labels) from an MNIST-style npz archive
    :param images_key: name of the image array inside the archive, e.g. "x_train"
    :param labels_key: name of the label array inside the archive, e.g. "y_train"
    :param path: npz file to read; None means the bundled data/mnist.npz
    :return: (images, labels) where images is a float32 array with one flattened
             image per row, pixel values divided by 255, and labels is a one-hot
             float array of shape (n, 10)
    """
    if path is None:
        path = _default_mnist_path()
    # Read the arrays while the archive is open; NpzFile members are loaded on access.
    with np.load(path) as archive:
        images = archive[images_key]
        labels = archive[labels_key]
    # Divide by 255 to scale pixel values; assumes 8-bit (0-255) grayscale input.
    images = images.astype("float32") / 255
    # Flatten every sample to a single row. The product of the trailing dims is
    # used instead of -1 so that an empty split (0 samples) still reshapes.
    pixels_per_image = int(np.prod(images.shape[1:]))
    images = images.reshape(images.shape[0], pixels_per_image)
    # Row i of the identity matrix is the one-hot vector for class i.
    labels = np.eye(_MNIST_NUM_CLASSES)[labels]
    return images, labels


def get_training_data(path=None):
    """
    extracts training data from npz file
    :param path: optional npz file to read instead of the bundled data/mnist.npz
    :return: array of image pixel data (grayscale) and array of labels for training purposes
    """
    return _load_mnist_split("x_train", "y_train", path)


def get_test_data(path=None):
    """
    extracts test data from npz file
    :param path: optional npz file to read instead of the bundled data/mnist.npz
    :return: array of image pixel data (grayscale) and array of labels for evaluation purposes
    """
    return _load_mnist_split("x_test", "y_test", path)