-
Notifications
You must be signed in to change notification settings - Fork 0
/
MLTraining.py
113 lines (87 loc) · 3.69 KB
/
MLTraining.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Trains an image-classification model, following the Microsoft course on edX.
# It builds a model and trains it on the pictures in the training folder.
import os
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import tensorflow as tf
# Root folder of the dataset; it is expected to hold one subfolder per class
training_folder_name = "C:/Users/sular/Desktop/Dataset"
# Every image is resized to 128x128 pixels on load (passed as target_size below)
img_size = (128, 128)
# The folder contains a subfolder for each class of shape.
# Keep directories only: flow_from_directory builds its classes from the
# subfolders, so a stray file (desktop.ini, Thumbs.db, ...) listed here would
# shift the class names used later to label the confusion matrix.
classes = sorted(
    entry
    for entry in os.listdir(training_folder_name)
    if os.path.isdir(os.path.join(training_folder_name, entry))
)
print(classes)
# Number of images per batch yielded by both generators
batch_size = 10
print("Getting Data...")
datagen = ImageDataGenerator(
    rescale=1./255,  # scale pixel values by 1/255
    validation_split=0.2,  # hold back 20% of the images for validation
)
print("Preparing training dataset...")
train_generator = datagen.flow_from_directory(
    training_folder_name,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='training',  # set as training data
)
print("Preparing validation dataset...")
validation_generator = datagen.flow_from_directory(
    training_folder_name,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',  # set as validation data
)
# Build the classifier as a linear stack of layers, listed in forward order
model = tf.keras.models.Sequential([
    # Convolutional feature extractor; the input shape is taken from the generator
    tf.keras.layers.Conv2D(
        50,
        (3, 3),
        input_shape=train_generator.image_shape,
        activation='relu',
    ),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2)),
    # Flatten the feature maps so they can feed the fully connected layers
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(40, activation='relu'),
    # NOTE(review): ReLU output is never negative, so the LeakyReLU that follows
    # passes it through unchanged - confirm whether the Dense layers above
    # each LeakyReLU were meant to be linear.
    tf.keras.layers.LeakyReLU(alpha=0.1),
    tf.keras.layers.Dropout(0.6, noise_shape=None, seed=None),
    tf.keras.layers.Dense(20, activation='relu'),
    tf.keras.layers.LeakyReLU(alpha=0.1),
    # One softmax output per class found by the training generator
    tf.keras.layers.Dense(train_generator.num_classes, activation='softmax'),
])
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# Prefix under which the model weights are checkpointed
checkpoint_path = "CategoryChkp/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# save_freq='epoch' writes the weights after every epoch, so even a
# single-epoch run (num_epochs = 1 below) persists what it learned.
# The deprecated `period=2` argument skipped the save in that case.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    save_weights_only=True,
    verbose=1,
    save_freq='epoch',
)
# summary() prints the table itself and returns None, so it is not wrapped in print()
model.summary()
# Resume from earlier weights only when a checkpoint is present; on the very
# first run there is nothing to load and load_weights would raise.
# A TensorFlow-format checkpoint is written as "<prefix>.index" plus data shards.
if os.path.exists(checkpoint_path + ".index"):
    model.load_weights(checkpoint_path)
else:
    print("No checkpoint found at " + checkpoint_path + ", training from scratch.")
num_epochs = 1
# fit() accepts the generators directly (fit_generator is deprecated).
# max(1, ...) keeps at least one step when a subset holds fewer images than
# a single batch, where the floor division alone would yield 0 steps.
history = model.fit(
    train_generator,
    steps_per_epoch=max(1, train_generator.samples // batch_size),
    validation_data=validation_generator,
    validation_steps=max(1, validation_generator.samples // batch_size),
    epochs=num_epochs,
    callbacks=[cp_callback],
)
# Plot the per-epoch training and validation loss recorded by the training run
recorded_metrics = history.history
epoch_nums = range(1, num_epochs + 1)
training_loss = recorded_metrics["loss"]
validation_loss = recorded_metrics["val_loss"]
# Draw training first, then validation, so the legend order below matches
for loss_series in (training_loss, validation_loss):
    plt.plot(epoch_nums, loss_series)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['training', 'validation'], loc='upper right')
plt.show()
print("Generating predictions from validation data...")
# Get the image and label arrays for the first batch of validation data.
# The batch is fetched once: every index access loads the images again.
x_test, y_test = validation_generator[0]
# Use the model to predict the class
class_probabilities = model.predict(x_test)
# The model returns a probability value for each class
# The one with the highest probability is the predicted class
predictions = np.argmax(class_probabilities, axis=1)
# The actual labels are one-hot encoded (e.g. [0 1 0]), so take the index of the 1
true_labels = np.argmax(y_test, axis=1)
# Class names ordered by the index the generator assigned to them, so the
# tick labels line up with the one-hot / argmax indices used above
class_indices = validation_generator.class_indices
class_names = sorted(class_indices, key=class_indices.get)
tick_marks = np.arange(len(class_names))
# Plot the confusion matrix.
# Passing labels forces one row/column per class even when this single batch
# does not contain every class; otherwise the matrix would be smaller than
# the tick labels and the axes would be mislabelled.
cm = confusion_matrix(true_labels, predictions, labels=tick_marks)
plt.imshow(cm, interpolation="nearest", cmap='gray')
plt.colorbar()
plt.xticks(tick_marks, class_names, rotation=85)
plt.yticks(tick_marks, class_names)
plt.xlabel("Predicted Class")
plt.ylabel("True Class")
plt.show()