-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain.go
113 lines (102 loc) · 4.71 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package main
import (
"flag"
"fmt"
"path/filepath"
"github.com/RenatoGeh/gospn/app"
"github.com/RenatoGeh/gospn/io"
"github.com/RenatoGeh/gospn/learn/gens"
"github.com/RenatoGeh/gospn/sys"
"github.com/RenatoGeh/gospn/utils"
//profile "github.com/pkg/profile"
)
// dataset is the name of the dataset directory under data/. It doubles as the
// default value of the -dataset flag, which may overwrite it at startup.
var dataset string = "olivetti_3bit"
// convertData resolves the absolute path of data/<dataset> and hands it to
// io.BufferedPGMFToData together with the output file name "all.data".
// According to the -mode flag help, this is the step that turns a dataset
// directory structure into a data file.
//
// If the absolute path cannot be resolved, the error is printed and no
// conversion is attempted.
func convertData() {
	// filepath.Join keeps the path OS-correct; filepath.Abs cleans the result,
	// so this resolves to the same directory as "data/" + dataset + "/".
	cmn, err := filepath.Abs(filepath.Join("data", dataset))
	if err != nil {
		fmt.Printf("Could not resolve the path of dataset %s: %v\n", dataset, err)
		return
	}
	io.BufferedPGMFToData(cmn, "all.data")
}
// main parses the command-line flags, validates them and dispatches to the job
// selected by -mode:
//
//   - data:  convert the dataset directory into a data file (convertData);
//   - cmpl:  run an image completion job on data/<dataset>/compiled/all.data;
//   - class: run a batch classification job on the dataset;
//   - test:  run a hard-coded experiment that overrides sys.Max, sys.Width,
//     sys.Height and sys.Verbose regardless of the flags given.
//
// Any other mode prints the list of supported modes. Invalid arguments are
// reported on standard output and main returns without running a job.
func main() {
	var (
		p           float64
		clusters    int
		rseed       int64
		iterations  int
		concurrents int
		mode        string
	)

	flag.Float64Var(&p, "p", 0.7, "Train/test partition ratio to be used for cross-validation. ")
	flag.IntVar(&clusters, "clusters", -1, "Number of clusters to be used during training. If "+
		"clusters = -1, GoSPN shall use DBSCAN. Else, if clusters = -2, then use OPTICS "+
		"(experimental). Else, if clusters > 0, then use k-means clustering with the indicated "+
		"number of clusters.")
	flag.Int64Var(&rseed, "rseed", -1, "Seed to be used when choosing which instances to be used as "+
		"training set and which to be used as testing set. If omitted, rseed defaults to -1, which "+
		"means GoSPN chooses a random seed according to the current time.")
	flag.IntVar(&iterations, "iterations", 1, "How many iterations to be run when running a "+
		"classification job. This allows for better, more general and randomized results, as some "+
		"test/train partitions may become degenerated.")
	flag.IntVar(&concurrents, "concurrents", -1, "GoSPN makes use of Go's native concurrency and is "+
		"able to run on multiple cores in parallel. Argument concurrents defines the number of "+
		"concurrent jobs GoSPN should run at most. If concurrents <= 0, then concurrents = nCPU, "+
		"where nCPU is the number of CPUs the running machine has available.")
	flag.StringVar(&dataset, "dataset", dataset, "The name of the directory containing the "+
		"dataset structure inside the data folder. Setting -mode=data will cause a new given "+
		"dataset data file to be created. Omitting -mode or setting -mode to something different "+
		"than data will run a job on the given dataset.")
	flag.IntVar(&sys.Width, "width", sys.Width, "The width of the images to be classified or "+
		"completed.")
	flag.IntVar(&sys.Height, "height", sys.Height, "The height of the images to be classified or "+
		"completed.")
	flag.IntVar(&sys.Max, "max", sys.Max, "The maximum pixel value the images can have.")
	flag.StringVar(&mode, "mode", "cmpl", "Whether to convert a directory structure into a data "+
		"file (data), run an image completion job (cmpl) or a classification job (class).")
	flag.Float64Var(&sys.Pval, "pval", sys.Pval, "The significance value for the independence test.")
	flag.Float64Var(&sys.Eps, "eps", sys.Eps, "The epsilon minimum distance value for DBSCAN.")
	flag.IntVar(&sys.Mp, "mp", sys.Mp, "The minimum points density for DBSCAN.")
	flag.BoolVar(&sys.Verbose, "v", sys.Verbose, "Verbose mode.")
	flag.Parse()

	// p must lie strictly inside (0, 1). The check is written as a negated
	// conjunction so that values above 1 and NaN are rejected as well.
	if !(p > 0 && p < 1) {
		fmt.Println("Argument p must be a float64 in range (0, 1).")
		return
	}
	if iterations <= 0 {
		fmt.Println("Argument iterations must be an integer greater than 0.")
		return
	}

	//defer profile.Start().Stop()

	//out, _ := filepath.Abs("results/" + dataset + "/models")

	switch mode {
	case "data":
		fmt.Printf("Converting dataset %s...\n", dataset)
		convertData()
	case "cmpl":
		// Only the completion job reads the compiled data directory, so the
		// path is resolved here rather than for every mode.
		in, err := filepath.Abs(filepath.Join("data", dataset, "compiled"))
		if err != nil {
			fmt.Printf("Could not resolve the path of dataset %s: %v\n", dataset, err)
			return
		}
		fmt.Printf("Running image completion on dataset %s with %d threads...\n", dataset, concurrents)
		lf := gens.Binded(clusters, sys.Pval, sys.Eps, sys.Mp)
		app.ImgCompletion(lf, utils.StringConcat(in, "/all.data"), concurrents)
	case "class":
		lf := gens.Binded(clusters, sys.Pval, sys.Eps, sys.Mp)
		app.ImgBatchClassify(lf, dataset, p, rseed, clusters, iterations)
	case "test":
		// Experimental mode: the settings below deliberately overwrite whatever
		// was passed through -max, -width, -height and -v.
		//_, data, _ := io.ParseDataNL("data/digits/compiled/all.data")
		//_, data, _ := io.ParseDataNL("data/test/compiled/all.data")
		//sys.Width, sys.Height = 4, 4
		sys.Max = 256
		//sys.Width, sys.Height = 92, 112
		//sys.Width, sys.Height = 48, 56
		sys.Width, sys.Height = 20, 30
		//sys.Max = 8
		//sys.Width, sys.Height = 4, 4
		sys.Verbose = true
		//app.ImgTest("data/olivetti_padded_u/compiled/all.data", 4, 20, 4, 0.1, 1)
		app.ImgTestParallel("data/digits_x/compiled/all.data", 4, 4, 4, 0.1, 1, 3)
		//app.ImgTest("data/fourbyfour/compiled/all.data", 2, 1, 1, 0.1)
		//lf := learn.BindedPoonGD(2, 4, 0.1, 1)
		//sc, data, _ := io.ParseDataNL(filename)
		//S := lf(sc, data)
		//app.ImgClassify(lf, "data/digits/compiled/all.data", 0.3, -1)
		//app.ImgCompletion(lf, "data/olivetti_padded/compiled/all.data", 1)
		//learn.PoonTest(data, 2, 2)
	default:
		fmt.Printf("Mode %s not found. Possible mode options:\n cmpl, class, data\n", mode)
	}
}