-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNaiveBayes.java
executable file
·324 lines (315 loc) · 12.1 KB
/
NaiveBayes.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import java.io.*;
import java.util.*;
/*
 * Name: Yan Pan
 * This program assumes the user enters an integer value for N, with 0 < N < number of biographies.
 * The stopword list and the corpus are read from files using FileReader.
 */
/**
 * Naive Bayes text classifier for short biographies.
 *
 * <p>The program reads an integer N from standard input, a stopword list from
 * {@code stopwords.txt} and a corpus from {@code corpus.txt}. The first N biographies of the
 * corpus are used for training; the remaining ones are classified and the result is printed
 * together with the overall accuracy.
 *
 * <p>All state lives in static fields, so the learning steps must be called in the order used
 * by {@link #main(String[])}: {@code countC}, {@code countW}, {@code freqC}/{@code freqW},
 * {@code prob_C}/{@code prob_W}, {@code log_probC}/{@code log_probW}, then {@code computeL},
 * {@code recoverProb} and {@code output}.
 */
public class NaiveBayes {
    /** Number of leading corpus entries used as the training set. */
    static int N;
    /** Words that are dropped from every biography during normalization. */
    static ArrayList<String> stopwords = new ArrayList<String>();
    /** Distinct words that occur in the training biographies. */
    static Set<String> training_bio_words = new HashSet<String>();
    /** Maps a person's name to the normalized word list of the biography. */
    static Map<String, ArrayList<String>> nameToBio;
    /** Maps a person's name to the (trimmed) category label. */
    static Map<String, String> nameToLabel;
    /** Names in corpus order; indices below N are training data, the rest test data. */
    static ArrayList<String> names;

    /**
     * Runs the whole pipeline: read the inputs, learn from the first N biographies, classify
     * the rest and print the report.
     *
     * @param args unused
     * @throws IOException if standard input or one of the input files cannot be read, or the
     *     corpus is truncated
     * @throws IllegalArgumentException if N is missing, not an integer, or outside
     *     1..number of biographies
     */
    public static void main(String[] args) throws Exception {
        N = readTrainingSize();
        readStopwords("stopwords.txt");
        readCorpus("corpus.txt");

        // Fail early with a readable message instead of an index error deep in the pipeline.
        if (N < 1 || N > names.size()) {
            throw new IllegalArgumentException(
                    "N must be between 1 and the number of biographies (" + names.size()
                            + "), got " + N);
        }

        countC(nameToLabel);
        int[][] occ_W = countW(nameToBio, nameToLabel);
        // step 1: frequencies
        double[] freq_C = freqC();
        double[][] freq_W = freqW(occ_W);
        // step 2: smoothed probabilities
        double[] prob_C = prob_C(freq_C);
        double[][] prob_W = prob_W(freq_W);
        // step 3: negative log probabilities -- end of learning
        double[] log_c = log_probC(prob_C);
        double[][] log_w = log_probW(prob_W);

        // apply the classifier to the test data
        omitUnseenW();
        double[][] predict = computeL(log_w, log_c);
        double[][] real_prob = recoverProb(predict);
        output(real_prob, predict);
    }

    /** Prompts for N on standard output and parses the answer read from standard input. */
    private static int readTrainingSize() throws IOException {
        // The reader is intentionally not closed: closing it would close System.in.
        BufferedReader console = new BufferedReader(new InputStreamReader(System.in));
        System.out.print("Type an integer for N and press Enter:");
        String answer = console.readLine();
        if (answer == null) {
            throw new IllegalArgumentException("No value for N was provided on standard input");
        }
        return Integer.parseInt(answer.trim());
    }

    /** Appends every whitespace-separated token of the given file to {@link #stopwords}. */
    private static void readStopwords(String path) throws IOException {
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                StringTokenizer tokens = new StringTokenizer(line);
                while (tokens.hasMoreTokens()) {
                    stopwords.add(tokens.nextToken());
                }
            }
        }
    }

    /**
     * Parses the corpus file into {@link #names}, {@link #nameToLabel} and {@link #nameToBio}.
     * Each entry is a name line, a category line and one or more biography lines; entries are
     * separated by one or more blank lines.
     */
    private static void readCorpus(String path) throws IOException {
        names = new ArrayList<String>();
        nameToLabel = new HashMap<String, String>();
        nameToBio = new HashMap<String, ArrayList<String>>();
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // blank separator lines, including trailing ones at end of file
                }
                String name = line;
                String label = reader.readLine();
                if (label == null) {
                    throw new IOException(
                            path + " ends after the name \"" + name + "\" without a category line");
                }
                names.add(name);
                nameToLabel.put(name, label.trim()); // drop whitespace around the category

                StringBuilder text = new StringBuilder();
                String bioLine;
                while ((bioLine = reader.readLine()) != null && !bioLine.trim().isEmpty()) {
                    // The separating space keeps the last word of one line from fusing with
                    // the first word of the next line.
                    text.append(bioLine.toLowerCase(Locale.ROOT)).append(' ');
                }
                nameToBio.put(name, omitStopwords(omitPunctuation(text.toString())));
            }
        }
    }

    /**
     * Removes every comma and period from the text.
     *
     * @param bio raw biography text
     * @return the whitespace-separated tokens of {@code bio} without commas and periods, each
     *     followed by a single space
     */
    static String omitPunctuation(String bio) {
        StringBuilder cleaned = new StringBuilder(bio.length() + 1);
        StringTokenizer tokens = new StringTokenizer(bio);
        while (tokens.hasMoreTokens()) {
            String word = tokens.nextToken().replace(",", "").replace(".", "");
            cleaned.append(word).append(' ');
        }
        return cleaned.toString();
    }

    /**
     * Splits the text on whitespace and drops stopwords and words shorter than three
     * characters.
     *
     * @param bio punctuation-free biography text
     * @return the remaining words in their original order (duplicates are kept)
     */
    static ArrayList<String> omitStopwords(String bio) {
        ArrayList<String> kept = new ArrayList<String>();
        StringTokenizer tokens = new StringTokenizer(bio);
        while (tokens.hasMoreTokens()) {
            String word = tokens.nextToken();
            if (word.length() >= 3 && !stopwords.contains(word)) {
                kept.add(word);
            }
        }
        return kept;
    }

    /** Maps each category label to the number of training biographies carrying it. */
    static Map<String, Integer> labelToNum = new HashMap<String, Integer>();
    /** Distinct category labels seen in the training set. */
    static Set<String> set = new HashSet<String>();

    /**
     * Collects the categories of the first N biographies into {@link #set} and stores how
     * often each one occurs in {@link #labelToNum}. Also refreshes {@link #categ} and
     * {@link #num_categ} so later steps do not depend on {@code categoryID} having run.
     *
     * @param nameToLabel name-to-category map covering at least the first N names
     */
    static void countC(Map<String, String> nameToLabel) {
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (int i = 0; i < N; i++) {
            String label = nameToLabel.get(names.get(i));
            set.add(label);
            Integer seen = counts.get(label);
            counts.put(label, seen == null ? 1 : seen + 1);
        }
        for (String label : set) {
            Integer count = counts.get(label);
            labelToNum.put(label, count == null ? 0 : count);
        }
        categ = set.toArray(new String[0]);
        num_categ = set.size();
    }

    /** Number of categories in the training set. */
    static int num_categ;
    /** Number of distinct words in the training set. */
    static int num_word_t;
    /** The training vocabulary as an array; its indices are the row indices of the tables. */
    static String[] train_words;

    /**
     * Builds the training vocabulary and counts, for every word and category, how many
     * training biographies of that category contain the word.
     *
     * @param nameToBio name-to-word-list map covering at least the first N names
     * @param nameToLabel name-to-category map covering at least the first N names
     * @return table indexed {@code [word][category]}; a biography counts at most once per word
     */
    static int[][] countW(Map<String, ArrayList<String>> nameToBio, Map<String, String> nameToLabel) {
        for (int i = 0; i < N; i++) {
            training_bio_words.addAll(nameToBio.get(names.get(i)));
        }
        num_categ = set.size();
        num_word_t = training_bio_words.size();
        train_words = training_bio_words.toArray(new String[0]);

        Map<String, Integer> wordRow = new HashMap<String, Integer>();
        for (int i = 0; i < train_words.length; i++) {
            wordRow.put(train_words[i], i);
        }

        int[][] occurrences = new int[num_word_t][num_categ];
        for (int j = 0; j < N; j++) {
            String name = names.get(j);
            int column = categoryID(nameToLabel.get(name));
            // De-duplicate so a word repeated inside one biography is counted once.
            for (String word : new HashSet<String>(nameToBio.get(name))) {
                occurrences[wordRow.get(word)][column]++;
            }
        }
        return occurrences;
    }

    /** Category labels in the order that defines their integer ids. */
    static String[] categ = new String[0];

    /**
     * Returns the integer id of a category, i.e. its position in {@link #categ}. The array is
     * refreshed from {@link #set} on every call.
     *
     * @param C category label
     * @return the id, or -1 if the label is not a known category
     */
    static int categoryID(String C) {
        categ = set.toArray(new String[0]);
        for (int i = 0; i < categ.length; i++) {
            if (categ[i].equals(C)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Step 1: fraction of the N training biographies that belong to each category.
     *
     * @return array indexed by category id
     */
    static double[] freqC() {
        double[] freq = new double[num_categ];
        for (String label : categ) {
            double occurrences = labelToNum.get(label);
            freq[categoryID(label)] = occurrences / (double) N;
        }
        return freq;
    }

    /**
     * Step 1: for each word and category, the fraction of that category's training
     * biographies that contain the word.
     *
     * @param occ occurrence table produced by {@code countW}
     * @return table indexed {@code [word][category]}
     */
    static double[][] freqW(int[][] occ) {
        double[][] freq = new double[num_word_t][num_categ];
        for (int j = 0; j < num_categ; j++) {
            double biosInCategory = labelToNum.get(categ[j]);
            for (int i = 0; i < num_word_t; i++) {
                freq[i][j] = occ[i][j] / biosInCategory;
            }
        }
        return freq;
    }

    /** Smoothing constant of the Laplacian correction used in step 2. */
    static double e = 0.1;

    /**
     * Step 2: smoothed category probabilities, {@code (freq + e) / (1 + num_categ * e)}.
     *
     * @param freq category frequencies from {@code freqC}
     * @return array indexed by category id
     */
    static double[] prob_C(double[] freq) {
        double[] prob = new double[num_categ];
        double denominator = 1 + num_categ * e;
        for (int i = 0; i < freq.length; i++) {
            prob[i] = (freq[i] + e) / denominator;
        }
        return prob;
    }

    /**
     * Step 2: smoothed word probabilities, {@code (freq + e) / (1 + 2 * e)}.
     * NOTE(review): the factor 2 presumably reflects the two outcomes word present / word
     * absent -- confirm against the assignment specification.
     *
     * @param freq word frequencies from {@code freqW}
     * @return table indexed {@code [word][category]}
     */
    static double[][] prob_W(double[][] freq) {
        double[][] prob = new double[num_word_t][num_categ];
        for (int i = 0; i < num_word_t; i++) {
            for (int j = 0; j < num_categ; j++) {
                prob[i][j] = (freq[i][j] + e) / (1 + 2 * e);
            }
        }
        return prob;
    }

    // Step 3 is kept separate from step 2 so that P(C) and P(W|C) stay accessible.

    /** Returns the base-2 logarithm of {@code x}. */
    static double log_base2(double x) {
        return Math.log(x) / Math.log(2);
    }

    /**
     * Step 3: negative base-2 logarithm of each category probability.
     *
     * @param prob category probabilities from {@code prob_C}
     * @return array indexed by category id
     */
    static double[] log_probC(double[] prob) {
        double[] cost = new double[num_categ];
        for (int i = 0; i < prob.length; i++) {
            cost[i] = -log_base2(prob[i]);
        }
        return cost;
    }

    /**
     * Step 3: negative base-2 logarithm of each word probability.
     *
     * @param prob word probabilities from {@code prob_W}
     * @return table indexed {@code [word][category]}
     */
    static double[][] log_probW(double[][] prob) {
        double[][] cost = new double[num_word_t][num_categ];
        for (int i = 0; i < num_word_t; i++) {
            for (int j = 0; j < num_categ; j++) {
                cost[i][j] = -log_base2(prob[i][j]);
            }
        }
        return cost;
    }

    /**
     * Replaces every test biography (indices N and up) in {@link #nameToBio} by a copy that
     * contains only words from the training vocabulary.
     */
    static void omitUnseenW() {
        for (int i = N; i < names.size(); i++) {
            String name = names.get(i);
            ArrayList<String> known = new ArrayList<String>();
            for (String word : nameToBio.get(name)) {
                if (training_bio_words.contains(word)) {
                    known.add(word);
                }
            }
            nameToBio.put(name, known);
        }
    }

    /**
     * Scores every test biography against every category. The score of a category is the sum
     * of the word costs of all words in the biography plus the category cost; the category
     * with the smallest score is the prediction. Words that are not in {@link #train_words}
     * are skipped.
     *
     * @param log_w word costs indexed {@code [word][category]}
     * @param log_c category costs indexed by category id
     * @return one row per test biography: columns {@code 0..num_categ-1} hold the scores,
     *     column {@code num_categ} the predicted category id and column {@code num_categ+1}
     *     the smallest score
     */
    static double[][] computeL(double[][] log_w, double[] log_c) {
        double[][] predict = new double[names.size() - N][num_categ + 2];

        // One hash lookup per word instead of a linear scan of the vocabulary.
        Map<String, Integer> wordRow = new HashMap<String, Integer>();
        for (int i = 0; i < train_words.length; i++) {
            if (!wordRow.containsKey(train_words[i])) {
                wordRow.put(train_words[i], i);
            }
        }

        for (int i = N; i < names.size(); i++) {
            ArrayList<String> bio = nameToBio.get(names.get(i));
            double[] row = predict[i - N];
            int prediction = -1;
            double min = Double.POSITIVE_INFINITY;
            for (int j = 0; j < num_categ; j++) {
                double sum = 0;
                for (String word : bio) {
                    Integer idx = wordRow.get(word);
                    if (idx != null) { // a word never seen in training carries no evidence
                        sum += log_w[idx][j];
                    }
                }
                sum += log_c[j]; // add L(C)
                row[j] = sum;
                if (prediction < 0 || sum < min) {
                    min = sum;
                    prediction = j;
                }
            }
            row[num_categ] = prediction;
            row[num_categ + 1] = min;
        }
        return predict;
    }

    /**
     * Converts the scores of {@code computeL} into probabilities: each score becomes
     * {@code 2^(min - score)} and the values of a row are then divided by their sum.
     *
     * @param predict result of {@code computeL}
     * @return one row per test biography: columns {@code 0..num_categ-1} hold the normalized
     *     probabilities, column {@code num_categ} the sum used for normalization
     */
    static double[][] recoverProb(double[][] predict) {
        int testSize = names.size() - N;
        double[][] real_prob = new double[testSize][num_categ + 1];
        for (int t = 0; t < testSize; t++) {
            double best = predict[t][num_categ + 1];
            double sum = 0;
            for (int j = 0; j < num_categ; j++) {
                real_prob[t][j] = Math.pow(2, best - predict[t][j]);
                sum += real_prob[t][j];
            }
            real_prob[t][num_categ] = sum;
            for (int j = 0; j < num_categ; j++) {
                real_prob[t][j] = real_prob[t][j] / sum;
            }
        }
        return real_prob;
    }

    /**
     * Prints, for every test biography, the predicted category, whether it matches the label
     * from the corpus and the probability of each category, followed by the overall accuracy.
     *
     * @param real_prob result of {@code recoverProb}
     * @param predict result of {@code computeL}
     */
    static void output(double[][] real_prob, double[][] predict) {
        StringBuilder report = new StringBuilder();
        int sizeTest = names.size() - N;
        int accurat_count = 0;
        for (int t = 0; t < sizeTest; t++) {
            String name = names.get(t + N);
            String predicted = categ[(int) predict[t][num_categ]].trim();
            String correct = nameToLabel.get(name).trim();
            String outcome;
            if (predicted.equals(correct)) {
                outcome = "Right";
                accurat_count++;
            } else {
                outcome = "Wrong";
            }
            report.append(name).append(" Prediction: ").append(predicted)
                    .append(" ").append(outcome).append("\n");
            for (int j = 0; j < num_categ; j++) {
                report.append(categ[j]).append(": ")
                        .append(String.format("%.2f", real_prob[t][j])).append(" ");
            }
            report.append("\n").append("\n");
        }
        report.append("Overall accuracy: ").append(accurat_count)
                .append(" out of ").append(sizeTest)
                .append(" = ").append(String.format("%.2f", accurat_count / (double) sizeTest))
                .append("\n");
        System.out.print(report);
    }
}