pyMind_MNIST.py
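#pyMind_MNIST.py wires up a flat sheet of 784 "neurons" (one per MNIST pixel),
#strengthens the axons between them with a Hebbian-style update while MNIST
#digits are presented, then plots which neurons fire in response to test digits
#and to an external symbol image.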
import numpy as np
import matplotlib.pyplot as plt
import sys
import imageio
import cv2
from mlxtend.data import loadlocal_mnist
np.set_printoptions(threshold=sys.maxsize)
np.set_printoptions(formatter={'float': lambda x: "{0:5.3f}".format(x)})
#===========================================================================================================
# Initialize Parameters
#===========================================================================================================
n = 784                      #Number of neurons (one per MNIST pixel)
eta = 0.01                   #Learning rate (see https://en.wikipedia.org/wiki/Generalized_Hebbian_Algorithm)
actionPotential = 4.3/(n-1)  #Firing threshold applied to each neuron's squared incoming potential
#===========================================================================================================
# Read in MNIST data set
#===========================================================================================================
X, y = loadlocal_mnist(
    images_path='/home/evan/Desktop/ML_Stuff/PyMind/train-images.idx3-ubyte',
    labels_path='/home/evan/Desktop/ML_Stuff/PyMind/train-labels.idx1-ubyte')
X = X.astype(float)  #astype returns a copy, so assign it back or the cast is lost
average_power = np.sum(X)/np.shape(X)[0]/255.  #Mean total intensity per image, scaled to [0,1] pixel units
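#Note: with the standard MNIST training files, loadlocal_mnist returns X as a
#(60000, 784) uint8 array and y as a (60000,) label array, so each row of X is
#already a flattened 28x28 digit.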
#===========================================================================================================
# Set up neurons and axons
#===========================================================================================================
xcoords = np.arange(0, 784) % 28                    #Alternative layout: np.random.rand(n) - 0.5
ycoords = np.repeat(np.arange(0, 28, 1), 28)[::-1]  #Alternative layout: np.random.rand(n) - 0.5
axons = np.zeros((n, n))
#Randomly distribute initial axon strengths (no self-connections)
for i in range(0, n):
    for j in range(0, n):
        if i != j:
            axons[i][j] = np.random.random()
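#A minimal vectorized equivalent of the loop above (same distribution, just
#without the Python-level double loop) would be:
#    axons = np.random.random((n, n))
#    np.fill_diagonal(axons, 0)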
#===========================================================================================================
# Train on MNIST images
#===========================================================================================================
#Train the neurons on various different input node activations
for train in range(0, 1000):  #np.shape(X)[0] to train on the full set
    #Activate the initial activation nodes by inputting a digit
    activations = X[train]/255.
    #Now we adjust the axons based on the signals that passed through them
    #x = (activations*np.ones((n,n))).transpose()/(n-1) #How much arrives along each path ij
    #np.fill_diagonal(x,0)
    #y = np.sum(x*axons,axis=0) #How much arrives at each neuron j (the sum over all paths)
    #dw = eta*(y*x - (y**2)*axons) #Generalized Hebbian Algorithm
    diff = np.abs((activations*np.ones((n,n))).transpose() - activations)  #|a_i - a_j| for every pair (i, j)
    if train % 1000 == 0:
        eta /= 2  #Halve the learning rate every 1000 images (fires once, at train=0, in a 1000-image run)
    dw = eta*(np.outer(activations, activations) - diff*np.square(axons))
    np.fill_diagonal(dw, 0)  #No self-connections
    axons += dw
    #sc = plt.scatter(xcoords, ycoords, s=activations*50+1, c=activations*1, alpha=0.8, vmin=0, vmax=1)
    #max_dw = np.max(dw)
    #for i in range(0,n):
    #    for j in range(0,n):
    #        if (dw[i][j] == max_dw):
    #            plt.plot((xcoords[i],ycoords[i]), 'r--')
    #plt.show()
    print(train)  #Progress indicator
    #print(np.mean(dw))
    #print(np.max(dw))
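#For each ordered pair (i, j) the update above is
#    dw_ij = eta*(a_i*a_j - |a_i - a_j|*w_ij**2)
#i.e. Hebbian growth when pixels i and j are bright together, and a decay term
#that grows with the activation mismatch and the square of the current weight.
#A minimal standalone sketch of the same rule on three toy neurons
#(illustration only, not part of the training run):
#    a = np.array([1.0, 0.5, 0.0])
#    w = np.ones((3, 3)); np.fill_diagonal(w, 0)
#    dw = 0.01*(np.outer(a, a) - np.abs(a[:, None] - a[None, :])*w**2)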
#===========================================================================================================
# Test
#===========================================================================================================
vmax = 0.08**2  #Max plotting value for the scatter color scale
howbig = 15     #Extra marker-size factor for the response panels

def respond(activations):
    #The electric potential must be divided up by how many nodes it is sent to
    response = np.square(np.dot(activations/(n-1), axons))
    response[response < actionPotential] = 0  #Only neurons above the action potential fire
    return response

def show_pair(activations, panel):
    #Left panel: the input pattern; right panel: the network's thresholded response
    plt.subplot(2, 4, panel)
    plt.scatter(xcoords, ycoords, s=activations*50+1, c=activations*1, alpha=0.8, vmin=0, vmax=vmax)
    response = respond(activations)
    plt.subplot(2, 4, panel+1)
    plt.scatter(xcoords, ycoords, s=response*50*howbig+1, c=response*1, alpha=0.8, vmin=0, vmax=vmax)

#Present three digits the network was not trained on (training used only the first 1000 images)...
for panel, idx in zip((1, 3, 5), (-1265, -6, -7)):
    activations = X[idx]/255.
    activations *= average_power/np.sum(activations)  #Normalize total input power across images
    show_pair(activations, panel)

#...and also read in a letter to cross-check
letter = imageio.imread("symbol.png", as_gray=True)/255.
letter = cv2.resize(letter, (28, 28), interpolation=cv2.INTER_AREA)
activations = np.ravel(1-letter)  #Invert so ink is bright, then flatten to 784 values
activations *= average_power/np.sum(activations)
show_pair(activations, 7)
plt.show()
#1268 is tough
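#Note: "symbol.png" is assumed to sit next to this script and to contain a dark
#glyph on a light background; it is inverted and resized to 28x28 before being
#presented, so any such image should work as the cross-check input.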