-
Notifications
You must be signed in to change notification settings - Fork 1
/
nbc.go
86 lines (70 loc) · 2.2 KB
/
nbc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
package main
import (
"flag"
"fmt"
"math"
"os"
)
// Command-line flags. Each variable holds a pointer that flag.Parse fills in.
var (
	// train selects training mode; with -train=false the document in
	// -filename is scored against every known class instead.
	train = flag.Bool("train", true, "training mode")
	// class is the label attached to the n-grams generated from -filename.
	class = flag.String("class", "true", "The class associated with this training set")
	// filename is the file that main tokenizes; it is read in both modes,
	// despite what the usage string says.
	filename = flag.String("filename", "./nbc.go", "the filename to read from in training mode")
	// forget (-nuke) makes main discard the learned data and exit.
	forget = flag.Bool("nuke", false, "forget the learned data")
	// collection is not referenced in this file's visible code; presumably it
	// is read by the Mongo helpers — confirm.
	collection = flag.String("collection", "data", "The db collection to use")
	// laplaceConstant (-k) is the additive smoothing constant read by
	// laplaceSmoothing.
	laplaceConstant = flag.Float64("k", 1, "The laplacian smoothing constant to use")
	// nGramSize (-n) is the n-gram length passed to GenerateNGrams.
	nGramSize = flag.Int("n", 3, "The size of the ngrams")
	// verbose (-v) enables diagnostic output of n-grams and probabilities.
	verbose = flag.Bool("v", false, "Be verbose")
)
// main parses the command-line flags, connects to MongoDB and then runs in one
// of three modes:
//
//   - -nuke: prints the target database/collection, calls forgetData and exits.
//   - -train (the default): tokenizes -filename, generates n-grams of size -n
//     labelled with -class and stores them with DumpToMongo.
//   - -train=false: scores the document in -filename against every class
//     returned by GetClassProbabilities and prints one score per class.
func main() {
	flag.Parse()
	mongoConnect()
	defer mongoDisconnect()

	if *forget {
		fmt.Printf("Forgetting learned data in %s.%s\n", mongoDB, mongoCollection)
		forgetData()
		// os.Exit does not run deferred calls, so release the connection
		// explicitly before exiting.
		mongoDisconnect()
		os.Exit(0)
	}

	doc := NewDocument()
	doc.TokenizeFile(*filename)
	doc.GenerateNGrams(*nGramSize, *class)

	if *train {
		if *verbose { // dump out the ngrams we've discovered
			for _, ng := range doc.ngrams {
				fmt.Printf("%d -> %s\n", ng.Count[*class], ng.Hash)
			}
		}
		doc.DumpToMongo()
		return
	}

	// Passed to laplaceSmoothing as the vocabulary-size term; presumably the
	// number of distinct n-grams across all classes — confirm.
	vocabSize := CountDistinctNGrams()
	priors := GetClassProbabilities()
	if *verbose {
		for label, prior := range priors {
			fmt.Printf("P(%s) = %f\n", label, prior)
		}
	}

	for label, prior := range priors {
		totalngrams := GetTotalNGrams(label)
		// Sized from the n-grams actually iterated: sizing by a different
		// count either panics (too small) or leaves zero entries whose log is
		// -Inf (too large).
		probabilities := make([]float64, 0, len(doc.ngrams))
		for _, ng := range doc.ngrams {
			instanceCount := ng.GetInstanceCount(label)
			p := laplaceSmoothing(instanceCount, totalngrams, vocabSize)
			probabilities = append(probabilities, p)
			if *verbose {
				fmt.Printf("P(%s|%s) = %f (n=%d, N=%d, V=%d)\n",
					label, ng.Hash, p, instanceCount, totalngrams, vocabSize)
			}
		}
		// totalProbability returns a sum of logarithms, i.e. an unnormalized
		// log-score; the "P(...)" label is kept for output compatibility.
		fmt.Printf("P(%s|Message) = %f\n", label, totalProbability(probabilities, prior))
	}
}
// totalProbability combines a class prior and the per-n-gram conditional
// probabilities into a naive Bayes log-score:
//
//	log P(class) + Σ log P(ngram|class)
//
// Summing logarithms instead of multiplying probabilities avoids
// floating-point underflow. Both classProbability and every element of
// probabilities must be plain probabilities, not logarithms. A zero anywhere
// yields -Inf, and an empty probabilities slice yields just log(classProbability).
func totalProbability(probabilities []float64, classProbability float64) float64 {
	// The prior must be moved into log space too; adding the raw probability
	// to a sum of logs mixes incompatible scales.
	score := math.Log(classProbability)
	for _, p := range probabilities {
		score += math.Log(p)
	}
	return score
}
// laplaceSmoothing returns the additively (Laplace/Lidstone) smoothed
// estimate of P(ngram|class):
//
//	(n + k) / (N + k*V)
//
// n is the number of times the n-gram was seen in the class. N is the total
// n-gram count for the class (main passes GetTotalNGrams(class)). V is
// classCount (main passes CountDistinctNGrams()). k is the -k flag
// (*laplaceConstant). Scaling V by k keeps the estimates summing to 1 over the
// vocabulary for any k. For the default k = 1 this is exactly the classic
// add-one formula. A zero denominator yields +Inf or NaN, not a panic.
func laplaceSmoothing(n int, N int, classCount int) float64 {
	k := *laplaceConstant
	return (float64(n) + k) / (float64(N) + k*float64(classCount))
}