Skip to content

Commit

Permalink
Merge pull request #261 from okcd00/master
Browse files Browse the repository at this point in the history
Solve the misalignment problem on UNK tokens (MacBERT).
  • Loading branch information
shibing624 authored Mar 21, 2022
2 parents 347f045 + 46233be commit 2f2051d
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions pycorrector/macbert/infer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected]), Abtion([email protected])
@author:XuMing([email protected]), Abtion([email protected]), okcd00([email protected])
@description:
"""
import sys
import operator
import torch
import operator
from transformers import BertTokenizer

sys.path.append('../..')
Expand All @@ -25,6 +25,7 @@ def __init__(self, ckpt_path='output/macbert4csc/epoch=09-val_loss=0.01.ckpt',
logger.debug("device: {}".format(device))
self.tokenizer = BertTokenizer.from_pretrained(vocab_path)
cfg.merge_from_file(cfg_path)

if 'macbert4csc' in cfg_path:
self.model = MacBert4Csc.load_from_checkpoint(checkpoint_path=ckpt_path,
cfg=cfg,
Expand Down Expand Up @@ -75,23 +76,38 @@ def predict_with_error_detail(self, sentence_list):
sentence_list = [sentence_list]
corrected_texts = self.model.predict(sentence_list)

def get_errors(_corrected_text, _origin_text):
    """Re-align the model output with the original text and list the edits.

    Walks ``_origin_text`` character by character. Characters that are
    copied back from the original (blanks, newlines, the hard-coded
    punctuation/rare characters below, upper-case letters and the '֍'
    placeholder) are patched into ``_corrected_text`` so that both strings
    stay index-aligned; every remaining mismatch is recorded as a
    correction.

    Args:
        _corrected_text: text returned by the model for one sentence.
        _origin_text: the sentence that was given to the model.

    Returns:
        A tuple ``(patched_corrected_text, sub_details)`` where
        ``sub_details`` is a list of
        ``(origin_char, corrected_char, start, end)`` tuples ordered by
        ``start``, with ``end == start + 1``.
    """
    # Characters always restored from the original text at their position.
    # NOTE(review): presumably these are characters the tokenizer maps to
    # UNK -- confirm against the vocabulary in use.
    unk_chars = ('“', '”', '‘', '’', '琊', '…', '—', '擤')

    # Flags describing what the model output looks like, as observed by the
    # original author: blanks are kept in the output, newlines are removed.
    # A "cleaned" character has to be inserted; a kept one is overwritten.
    blanks_cleaned = False
    enter_cleaned = True

    sub_details = []
    for i, ori_char in enumerate(_origin_text):
        if ori_char == " ":
            # Restore the blank (overwrite unless blanks were stripped).
            tail_start = i if blanks_cleaned else i + 1
            _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[tail_start:]
            continue
        if ori_char == "\n":
            # Restore the newline (insert when newlines were stripped).
            tail_start = i if enter_cleaned else i + 1
            _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[tail_start:]
            continue
        if ori_char in unk_chars:
            # Overwrite whatever the model produced at this position.
            _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i + 1:]
            continue
        if i >= len(_corrected_text):
            # The output is shorter than the input: nothing left to compare.
            continue
        cor_char = _corrected_text[i]
        if ori_char != cor_char:
            if ori_char.lower() == cor_char or cor_char == '֍':
                # Keep the original character when the output only differs
                # by letter case, or holds the '֍' placeholder
                # (presumably emitted for unknown tokens -- TODO confirm).
                _corrected_text = _corrected_text[:i] + ori_char + _corrected_text[i + 1:]
                continue
            sub_details.append((ori_char, cor_char, i, i + 1))
    sub_details = sorted(sub_details, key=operator.itemgetter(2))
    return _corrected_text, sub_details

for corrected_text, text in zip(corrected_texts, sentence_list):
corrected_text, sub_details = get_errors(corrected_text, text)
Expand All @@ -111,6 +127,7 @@ def get_errors(corrected_text, origin_text):
inputs = [
'它的本领是呼风唤雨,因此能灭火防灾。狎鱼后面是獬豸。獬豸通常头上长着独角,有时又被称为独角羊。它很聪彗,而且明辨是非,象征着大公无私,又能镇压斜恶。',
'老是较书。',
'少先队 员因该 为老人让 坐',
'感谢等五分以后,碰到一位很棒的奴生跟我可聊。',
'遇到一位很棒的奴生跟我聊天。',
'遇到一位很美的女生跟我疗天。',
Expand Down

0 comments on commit 2f2051d

Please sign in to comment.