-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfasttextRun.py
executable file
·117 lines (102 loc) · 3.72 KB
/
fasttextRun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/python3
"""
fasttextRun.py: run fasttext via python api
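usage: fasttextRun.py -T trainFile [-t testFile] [-l]
       without -t: run 10-fold cross-validation on trainFile
       -l: print the first predicted label of each item instead of the accuracy
notes: input line format: label token1 token2 ...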
pyfasttext api: github.com/vrasneur/pyfasttext
20180212 erikt(at)xs4all.nl
"""
import getopt
import os
import random
import re
import sys
import tempfile

from pyfasttext import FastText
COMMAND = sys.argv[0]
# -t is optional: without it, the train file is evaluated with cross-validation.
USAGE = "usage: "+COMMAND+" -T trainFile [-t testFile] [-l]"
DIM = 300          # passed as dim to FastText.supervised
MINCOUNT = 5       # passed as minCount to FastText.supervised
NBROFFOLDS = 10    # number of cross-validation folds used when -t is absent
# Base name for the temporary model (.bin/.vec) and fold (.train/.test) files.
# It lives in the system temp directory rather than next to the script
# (COMMAND may contain a directory that is not writable). The pid and the
# random number keep concurrent runs from clobbering each other's files.
TMPFILENAME = os.path.join(
    tempfile.gettempdir(),
    os.path.basename(COMMAND) + "." + str(os.getpid()) + "."
    + str(random.randrange(1000000)),
)
showLabels = False  # set to True by processOpts when -l is given
def processOpts(argv):
    """Parse the command line.

    argv: the full argument vector, program name first (e.g. sys.argv).
        It is not modified.
    Returns (trainFile, testFile); testFile is "" when -t was not given.
    Sets the module-level flag showLabels when -l is present.
    Exits with the usage message on an unknown option or a missing -T.
    """
    global showLabels
    try:
        # Slice instead of pop(0) so the caller's list (sys.argv) is left intact.
        options, _ = getopt.getopt(argv[1:], "T:t:l", [])
    except getopt.GetoptError:
        sys.exit(USAGE)
    trainFile = ""
    testFile = ""
    for flag, value in options:
        if flag == "-T":
            trainFile = value
        elif flag == "-t":
            testFile = value
        elif flag == "-l":
            showLabels = True
    if trainFile == "":
        sys.exit(USAGE)
    return (trainFile, testFile)
def readData(inFileName):
    """Read a file in fastText input format.

    Each line is split on whitespace. The first field is taken as the class
    label, kept as-is including any "__label__" prefix. The remaining fields
    are the tokens. Blank or whitespace-only lines are skipped; previously
    they crashed with IndexError.

    Returns {"text": list of token lists, "classes": list of labels},
    two parallel lists in file order.
    Exits with an error message if the file cannot be opened.
    """
    text = []
    classes = []
    try:
        inFile = open(inFileName, "r")
    except OSError:
        sys.exit(COMMAND+": cannot read file "+inFileName)
    with inFile:
        for line in inFile:
            fields = line.split()
            if not fields:
                continue
            classes.append(fields[0])
            text.append(fields[1:])
    return {"text": text, "classes": classes}
def runExperiment(trainFileName, testFileName):
    """Train a supervised fastText model on one file and score it on another.

    trainFileName, testFileName: files in fastText input format
        ("label token1 token2 ..." per line).
    Returns {"correct": n, "labels": predictions}. "labels" is the value of
    FastText.predict_file, with one list of predicted labels per test item.
    "correct" counts the items whose first predicted label equals the gold
    label with its "__label__" prefix removed.
    The model files TMPFILENAME.bin and TMPFILENAME.vec are removed even when
    training or prediction fails. Exits with a message if the number of
    predictions differs from the number of test items, since the comparison
    would otherwise be misaligned.
    """
    labelPrefix = "__label__"
    model = FastText()
    try:
        model.supervised(input=trainFileName, output=TMPFILENAME, dim=DIM,
                         minCount=MINCOUNT, verbose=0)
        labels = model.predict_file(testFileName)
    finally:
        # Clean up the model files; they may be missing if training failed.
        for suffix in (".bin", ".vec"):
            try:
                os.unlink(TMPFILENAME + suffix)
            except FileNotFoundError:
                pass
    goldClasses = readData(testFileName)["classes"]
    if len(labels) != len(goldClasses):
        sys.exit("{0}: {1} predictions for {2} test items in {3}".format(
            COMMAND, len(labels), len(goldClasses), testFileName))
    correct = 0
    for predicted, gold in zip(labels, goldClasses):
        # The comparison assumes predictions carry no "__label__" prefix,
        # so strip it from the gold label (prefix only, not elsewhere).
        if gold.startswith(labelPrefix):
            gold = gold[len(labelPrefix):]
        # An item with no prediction counts as incorrect instead of crashing.
        if predicted and predicted[0] == gold:
            correct += 1
    return {"correct": correct, "labels": labels}
def writeFile(fileName, text, labels):
    """Write items to fileName in fastText input format.

    text: list of token lists, parallel to labels.
    labels: list of label strings; one output line is written per label.
    Each line is the label followed by that item's tokens, separated by
    single spaces: "label token1 token2 ...". An existing file is overwritten.
    """
    # The with statement closes the file; no explicit close() is needed.
    with open(fileName, "w") as outFile:
        for i, label in enumerate(labels):
            outFile.write(" ".join([label] + list(text[i])) + "\n")
def printResults(correct, nbrOfLabels, prefix):
    """Print the accuracy correct/nbrOfLabels as a percentage after prefix.

    Prints "n/a" instead of raising ZeroDivisionError when there were no
    items, e.g. an empty test file or an empty cross-validation fold.
    """
    if nbrOfLabels == 0:
        print("{0:s}Correct: n/a".format(prefix))
    else:
        print("{1:s}Correct: {0:0.1f}%".format(100*correct/nbrOfLabels, prefix))
def run10cv(trainFileName, nbrOfFolds=None):
    """Evaluate fastText on one data file with k-fold cross-validation.

    The folds are contiguous slices of the file in its original order; the
    data is not shuffled. For each fold, the fold is written to a temporary
    test file and the remaining items to a temporary train file, and the pair
    is scored with runExperiment. Per-fold accuracy is printed unless
    showLabels is set.

    trainFileName: data file in fastText input format.
    nbrOfFolds: number of folds; defaults to NBROFFOLDS.
    Returns {"correct": total correct over all folds, "labels": predictions
    of all folds concatenated, which is file order}.
    The temporary fold files are removed even if a fold fails.
    """
    if nbrOfFolds is None:
        nbrOfFolds = NBROFFOLDS
    if nbrOfFolds < 1:
        raise ValueError("nbrOfFolds must be at least 1")
    data = readData(trainFileName)
    classes = data["classes"]
    text = data["text"]
    nbrOfItems = len(text)
    foldTrainFileName = TMPFILENAME+".train"
    foldTestFileName = TMPFILENAME+".test"
    totalCorrect = 0
    labels = []
    try:
        for i in range(nbrOfFolds):
            # Exact integer fold boundaries; consecutive folds tile the data.
            testStart = i * nbrOfItems // nbrOfFolds
            testEnd = (i + 1) * nbrOfItems // nbrOfFolds
            writeFile(foldTestFileName, text[testStart:testEnd],
                      classes[testStart:testEnd])
            writeFile(foldTrainFileName, text[:testStart] + text[testEnd:],
                      classes[:testStart] + classes[testEnd:])
            results = runExperiment(foldTrainFileName, foldTestFileName)
            totalCorrect += results["correct"]
            labels.extend(results["labels"])
            if not showLabels:
                printResults(results["correct"], len(results["labels"]),
                             "Fold: {0:2d}; ".format(i + 1))
    finally:
        # Files may not exist if the first writeFile failed.
        for fileName in (foldTrainFileName, foldTestFileName):
            try:
                os.unlink(fileName)
            except FileNotFoundError:
                pass
    return {"correct": totalCorrect, "labels": labels}
def main(argv):
    """Program entry point: train/evaluate as requested by argv and report.

    With -t, trains on the -T file and scores the -t file. Otherwise it
    cross-validates on the -T file. Prints the overall accuracy, or with -l
    the first predicted label of every item. Returns 0 as the exit status.
    """
    trainFileName, testFileName = processOpts(argv)
    if testFileName == "":
        results = run10cv(trainFileName)
    else:
        results = runExperiment(trainFileName, testFileName)
    predictions = results["labels"]
    if showLabels:
        for prediction in predictions:
            print(prediction[0])
    else:
        printResults(results["correct"], len(predictions), "")
    return 0
if __name__ == "__main__":
    # main() returns the process exit status.
    exitStatus = main(sys.argv)
    sys.exit(exitStatus)