-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinternal.go
67 lines (57 loc) · 1.41 KB
/
internal.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
package gomnist
import (
	"image/color"
	"runtime"
	"sync"

	"github.com/petar/GoMNIST"
	"gonum.org/v1/gonum/mat"
)
// load reads the MNIST training and test sets found under rootPath by
// delegating to GoMNIST.Load (github.com/petar/GoMNIST).
//
// On failure both returned sets are nil and err carries the cause.
func load(rootPath string) (train *GoMNIST.Set, test *GoMNIST.Set, err error) {
	train, test, err = GoMNIST.Load(rootPath)
	if err != nil {
		// Return nil sets explicitly rather than whatever GoMNIST.Load
		// produced alongside the error.
		return nil, nil, err
	}
	return train, test, nil
}
// set2Mat converts a GoMNIST set into two gonum matrices with one row per image.
//
// data: each row holds the image's pixels flattened in row-major order
// (top row first, each row left to right) as float64 grey levels in
// [0, 255], or divided by 255 when normalization is true. It has
// s.NRow*s.NCol columns.
//
// labels: with oneHot it has 10 columns and a 1 in the column of the
// image's label; otherwise it has a single column holding the label value.
// A label outside 0..9 in one-hot mode makes gonum panic, as before.
//
// An empty set yields two empty (zero-value) matrices, because
// mat.NewDense panics on a zero dimension.
//
// Rows are filled by a bounded pool of goroutines. Each worker writes only
// to its own rows of d and l, so no locking is needed.
func set2Mat(s *GoMNIST.Set, normalization bool, oneHot bool) (data mat.Matrix, labels mat.Matrix) {
	n := len(s.Images)
	if n == 0 {
		return &mat.Dense{}, &mat.Dense{}
	}

	d := mat.NewDense(n, s.NRow*s.NCol, nil)
	labelCols := 1
	if oneHot {
		labelCols = 10
	}
	l := mat.NewDense(n, labelCols, nil)

	// Per-image work is tiny, so one goroutine per image is mostly
	// scheduler overhead; use at most one worker per usable CPU.
	workers := runtime.GOMAXPROCS(0)
	if workers > n {
		workers = n
	}

	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func(w int) {
			defer wg.Done()
			// Strided partition: worker w owns rows w, w+workers, ...
			for i := w; i < n; i += workers {
				img, label := s.Get(i)

				// Write straight into the matrix's backing row; no
				// per-image scratch slice is needed.
				row := d.RawRowView(i)
				b := img.Bounds()
				k := 0
				for y := b.Min.Y; y < b.Max.Y; y++ {
					for x := b.Min.X; x < b.Max.X; x++ {
						// image.Image.At takes (x, y); convert through
						// GrayModel so non-Gray colors don't panic.
						v := float64(color.GrayModel.Convert(img.At(x, y)).(color.Gray).Y)
						if normalization {
							v /= 255
						}
						row[k] = v
						k++
					}
				}

				if oneHot {
					l.Set(i, int(label), 1)
				} else {
					l.Set(i, 0, float64(label))
				}
			}
		}(w)
	}
	wg.Wait()

	return d, l
}