-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
75 lines (65 loc) · 2.61 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from argparse import ArgumentParser
from transformers import pipeline
from transformers.utils.dummy_pt_objects import load_tf_weights_in_bert_generation
from preprocessing.document import *
from preprocessing.config import *
from algorithm.basic_algorithm import *
from preprocessing.stanza_processor import *
from algorithm.mention_detection import mentionDetection
from evaluation.mention_matching import *
from evaluation.cluster_matching import *
from evaluation.score import writeConllForScoring
from algorithm.add_features import addFeatures
def logGreen(message: str) -> None:
    """Print ``message`` to stdout wrapped in ANSI bright-green escape codes.

    The trailing reset code restores the terminal's default colour so later
    output is unaffected.
    """
    colourOn = '\033[92m'
    colourOff = '\033[0m'
    print(''.join([colourOn, message, colourOff]))
def main():
    """Run coreference prediction and mention-level evaluation.

    Takes the path to a config file as the single positional command-line
    argument, then:

    1. loads documents from ``config.inputFile`` (all of them, or only
       ``config.docId`` when ``config.useAllDocs`` is false);
    2. for each document runs Stanza annotation, mention detection (or copies
       the gold mentions when ``config.useGoldMentions`` is set), feature
       extraction, coreference prediction, and mention matching;
    3. prints aggregate mention counts and precision/recall;
    4. writes CoNLL output when ``config.writeForScoring`` is set.

    Raises:
        ValueError: if ``config.useAllDocs`` is false and ``config.docId`` is
            not a valid index into the loaded documents.
    """
    logGreen('Starting coreference prediction procedure')
    parser = ArgumentParser()
    parser.add_argument('configFile', help='Path to the config file')
    args = parser.parse_args()
    config = Config(args.configFile)

    # Both annotators are created once and reused for every document.
    stanzaAnnotator = StanzaAnnotator()
    nerModel = 'KB/bert-base-swedish-cased-ner'
    nerPipeline = pipeline('ner', model=nerModel, tokenizer=nerModel)

    docs = documentsFromTextinatorFile(config.inputFile)
    # Track the real document ids so log lines match the config, even when
    # only a single document is processed.
    docIds = list(range(len(docs)))
    if not config.useAllDocs:
        # Negative ids are rejected too: they would otherwise silently pick a
        # document counted from the end of the list.
        if not 0 <= config.docId < len(docs):
            raise ValueError(f'Document id {config.docId} out of bounds, check config')
        docIds = [config.docId]
        docs = [docs[config.docId]]

    correct = 0
    falseNegatives = 0
    falsePositives = 0
    for docId, doc in zip(docIds, docs):
        logGreen(f'Processing document {docId}')
        logGreen('Adding stanza annotation')
        stanzaAnnotator.annotateDocument(doc)
        if config.useGoldMentions:
            # NOTE(review): this aliases the gold list rather than copying it;
            # confirm later stages never mutate predictedMentions in place.
            doc.predictedMentions = doc.goldMentions
        else:
            logGreen('Doing mention detection')
            mentionDetection(doc)
        logGreen('Preprocessing')
        addStanzaLinksToGoldMentions(doc)
        addFeatures(doc, nerPipeline)
        logGreen('Coreference prediction')
        predictCoreference(doc, config)
        logGreen('Evaluation')
        c, fp, fn = matchMentions(doc, config)
        correct += c
        falsePositives += fp
        falseNegatives += fn
        if config.compareClusters:
            compareClusters(doc)
        print()

    predictedTotal = correct + falsePositives
    goldTotal = correct + falseNegatives
    # Report 0.0 instead of crashing when nothing was predicted / nothing was gold.
    precision = float(correct) / predictedTotal if predictedTotal else 0.0
    recall = float(correct) / goldTotal if goldTotal else 0.0
    print("Total stats")
    print(f'Correct mentions: {correct}')
    print(f'Missed mentions: {falseNegatives}')
    print(f'Extra mentions: {falsePositives}')
    print(f'Precision: {precision}')
    print(f'Recall: {recall}')
    # printStatistics(docs)  # per-document statistics; currently disabled
    if config.writeForScoring:
        writeConllForScoring(docs)
if __name__ == '__main__':
    # Run the pipeline only when executed as a script; importing does nothing.
    main()