-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #261 from okcd00/master
Solve the misalignment problem on UNK tokens (MacBERT).
- Loading branch information
Showing
1 changed file
with
29 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
@author:XuMing([email protected]), Abtion([email protected]) | ||
@author:XuMing([email protected]), Abtion([email protected]), okcd00([email protected]) | ||
@description: | ||
""" | ||
import sys | ||
import operator | ||
import torch | ||
import operator | ||
from transformers import BertTokenizer | ||
|
||
sys.path.append('../..') | ||
|
@@ -25,6 +25,7 @@ def __init__(self, ckpt_path='output/macbert4csc/epoch=09-val_loss=0.01.ckpt', | |
logger.debug("device: {}".format(device)) | ||
self.tokenizer = BertTokenizer.from_pretrained(vocab_path) | ||
cfg.merge_from_file(cfg_path) | ||
|
||
if 'macbert4csc' in cfg_path: | ||
self.model = MacBert4Csc.load_from_checkpoint(checkpoint_path=ckpt_path, | ||
cfg=cfg, | ||
|
@@ -75,23 +76,38 @@ def predict_with_error_detail(self, sentence_list): | |
sentence_list = [sentence_list] | ||
corrected_texts = self.model.predict(sentence_list) | ||
|
||
def get_errors(corrected_text, origin_text): | ||
def get_errors(_corrected_text, _origin_text): | ||
sub_details = [] | ||
for i, ori_char in enumerate(origin_text): | ||
if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']: | ||
|
||
# Flags: we found that blanks are kept in the corrected text, but newlines are removed. | ||
blanks_cleaned = False | ||
enter_cleaned = True | ||
|
||
for i, ori_char in enumerate(_origin_text): | ||
if ori_char == " ": | ||
# restore the original space character | ||
_corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i if blanks_cleaned else i + 1:] | ||
continue | ||
if ori_char == "\n": | ||
# restore the original newline character | ||
_corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i if enter_cleaned else i + 1:] | ||
continue | ||
if ori_char in ['“', '”', '‘', '’', '琊', '…', '—', '擤']: | ||
# add unk word | ||
corrected_text = corrected_text[:i] + ori_char + corrected_text[i:] | ||
_corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i + 1:] | ||
continue | ||
if i >= len(corrected_text): | ||
if i >= len(_corrected_text): | ||
continue | ||
if ori_char != corrected_text[i]: | ||
if ori_char.lower() == corrected_text[i]: | ||
if ori_char != _corrected_text[i]: | ||
# print(ori_char, corrected_text[i]) | ||
if (ori_char.lower() == _corrected_text[i]) or _corrected_text[i] == '֍': | ||
# pass english upper char | ||
corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:] | ||
_corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i + 1:] | ||
continue | ||
sub_details.append((ori_char, corrected_text[i], i, i + 1)) | ||
sub_details.append((ori_char, _corrected_text[i], i, i + 1)) | ||
# print(_corrected_text) | ||
sub_details = sorted(sub_details, key=operator.itemgetter(2)) | ||
return corrected_text, sub_details | ||
return _corrected_text, sub_details | ||
|
||
for corrected_text, text in zip(corrected_texts, sentence_list): | ||
corrected_text, sub_details = get_errors(corrected_text, text) | ||
|
@@ -111,6 +127,7 @@ def get_errors(corrected_text, origin_text): | |
inputs = [ | ||
'它的本领是呼风唤雨,因此能灭火防灾。狎鱼后面是獬豸。獬豸通常头上长着独角,有时又被称为独角羊。它很聪彗,而且明辨是非,象征着大公无私,又能镇压斜恶。', | ||
'老是较书。', | ||
'少先队 员因该 为老人让 坐', | ||
'感谢等五分以后,碰到一位很棒的奴生跟我可聊。', | ||
'遇到一位很棒的奴生跟我聊天。', | ||
'遇到一位很美的女生跟我疗天。', | ||
|