NLP.TM[33] | 纠错:pycorrector的错误检测

技术

【 NLP.TM 】

本人有关自然语言处理和文本挖掘方面的学习和笔记,欢迎大家关注。

往期回顾

纠错是NLP中的一个看着不是很火但其实在现实应用中非常重要的一个部分,在一个强NLP以来的项目(如搜索)发展至中期,纠错就会成为一个效果提升的新增长点,经过统计,在微博等新媒体领域中,文本出错概率在2%左右,在语音识别领域中,出错率最高可达8-10%(数据来自:https://zhuanlan.zhihu.com/p/159101860),从这个比例来看,如果能修正这些错误,对效果的提升无疑是巨大的,那么我们来看看,纠错任务是怎么做的。

文章较长,懒人目录再现:

  • pycorrector简介
  • pycorrector的纠错思路
  • 混淆词典
  • 未登录词检测
  • 语言模型
  • 结果输出
  • 小结

pycorrector简介

pycorrector是非常基础的纠错模块工具,里面已经实现了一些非常通用的纠错方法,用里面的方法来做基线其实其实非常方便。

连接先放在这里:https://github.com/shibing624/pycorrector

他的使用方法其实也比较简单:


            
import pycorrector  
  
corrected_sent, detail = pycorrector.correct('少先队员因该为老人让坐')  
print(corrected_sent, detail)  

        

这是一个非常简单的官方case,详情还是可以去github里面去看看。

pycorrect的纠错思路

其实pycorrect里面造了很多飞机,不过实质上正式使用的还是非常经典的方法,来看看它的主函数具体思路是什么样的。


            
def correct(self, text, include\_symbol=True, num\_fragment=1, threshold=57, **kwargs):  
    """  
    句子改错  
    :param text: str, query 文本  
    :param include\_symbol: bool, 是否包含标点符号  
    :param num\_fragment: 纠错候选集分段数, 1 / (num\_fragment + 1)  
    :param threshold: 语言模型纠错ppl阈值  
    :param kwargs: ...  
    :return: text (str)改正后的句子, list(wrong, right, begin\_idx, end\_idx)  
    """  
    text_new = ''  
    details = []  
    self.check_corrector_initialized()  
    # 编码统一,utf-8 to unicode  
    text = convert_to_unicode(text)  
    # 长句切分为短句  
    blocks = self.split_2_short_text(text, include_symbol=include_symbol)  
    for blk, idx in blocks:  
        maybe_errors = self.detect_short(blk, idx)  
        for cur_item, begin_idx, end_idx, err_type in maybe_errors:  
            # 纠错,逐个处理  
            before_sent = blk[:(begin_idx - idx)]  
            after_sent = blk[(end_idx - idx):]  
  
            # 困惑集中指定的词,直接取结果  
            if err_type == ErrorType.confusion:  
                corrected_item = self.custom_confusion[cur_item]  
            else:  
                # 取得所有可能正确的词  
                candidates = self.generate_items(cur_item, fragment=num_fragment)  
                if not candidates:  
                    continue  
                corrected_item = self.get_lm_correct_item(cur_item, candidates, before_sent, after_sent,  
                                                          threshold=threshold)  
            # output  
            if corrected_item != cur_item:  
                blk = before_sent + corrected_item + after_sent  
                detail_word = [cur_item, corrected_item, begin_idx, end_idx]  
                details.append(detail_word)  
        text_new += blk  
    details = sorted(details, key=operator.itemgetter(2))  
    return text_new, details  

        

这里面其实还是比较明确的:

  • 分句。一个长句分成多个断句。
  • 对每个短句进行错误检测 detect\_short
  • 错误点召回可能正确的词。
  • 召回后筛选最佳结果。

在这个框架下,来看看具体pycorrect的错误检测是怎么做的。

混淆词典

直接看源码:


            
# 自定义混淆集加入疑似错误词典  
for confuse in self.custom_confusion:  
    idx = sentence.find(confuse)  
    if idx > -1:  
        maybe_err = [confuse, idx + start_idx, idx + len(confuse) + start_idx, ErrorType.confusion]  
        self._add_maybe_error_item(maybe_err, maybe_errors)  

        

这块其实还是比较简单的,其实就是用户自定义了一个词典,这个词典作者叫做混淆词典,我更愿意叫做改写词典,遇到了key,就去找v,直接做这种改写。

不过个人感觉这种遍历整个整个词典然后find的方法复杂度可能比较高,如果是我我还是比较喜欢最大逆向匹配的方式来查字典。

未登录词检测

同样上代码:


            
if self.is_word_error_detect:  
    # 切词  
    tokens = self.tokenizer.tokenize(sentence)  
    # 未登录词加入疑似错误词典  
    for token, begin_idx, end_idx in tokens:  
        # pass filter word  
        if self.is_filter_token(token):  
            continue  
        # pass in dict  
        if token in self.word_freq:  
            continue  
        maybe_err = [token, begin_idx + start_idx, end_idx + start_idx, ErrorType.word]  
        self._add_maybe_error_item(maybe_err, maybe_errors)  

        

注释其实还是非常友好的,其实就这几个步骤:

  • 切词。
  • 跳过特定词汇的检测。
  • 查字典看是否有低频词(未登录词)出现。
  • 结果整理。

首先就是切词,这里的切词是一个函数,我们也来看看他具体切词是怎么切的:


            
class Tokenizer(object):  
    def \_\_init\_\_(self, dict\_path='', custom\_word\_freq\_dict=None, custom\_confusion\_dict=None):  
        self.model = jieba  
        self.model.default_logger.setLevel(logging.ERROR)  
        # 初始化大词典  
        if os.path.exists(dict_path):  
            self.model.set_dictionary(dict_path)  
        # 加载用户自定义词典  
        if custom_word_freq_dict:  
            for w, f in custom_word_freq_dict.items():  
                self.model.add_word(w, freq=f)  
  
        # 加载混淆集词典、  
        if custom_confusion_dict:  
            for k, word in custom_confusion_dict.items():  
                # 添加到分词器的自定义词典中  
                self.model.add_word(k)  
                self.model.add_word(word)  
  
    def tokenize(self, unicode\_sentence, mode="search"):  
        """  
        切词并返回切词位置, search mode用于错误扩召回  
        :param unicode\_sentence: query  
        :param mode: search, default, ngram  
        :param HMM: enable HMM  
        :return: (w, start, start + width) model='default'  
        """  
        if mode == 'ngram':  
            n = 2  
            result_set = set()  
            tokens = self.model.lcut(unicode_sentence)  
            tokens_len = len(tokens)  
            start = 0  
            for i in range(0, tokens_len):  
                w = tokens[i]  
                width = len(w)  
                result_set.add((w, start, start + width))  
                for j in range(i, i + n):  
                    gram = "".join(tokens[i:j + 1])  
                    gram_width = len(gram)  
                    if i + j > tokens_len:  
                        break  
                    result_set.add((gram, start, start + gram_width))  
                start += width  
            results = list(result_set)  
            result = sorted(results, key=lambda x: x[-1])  
        else:  
            result = list(self.model.tokenize(unicode_sentence, mode=mode))  
        return result  

        

看着很高端,稍微看看源码其实就可以发现用的是以jieba为基础的操作,只不过多了一种n-gram切词而已,其实就是切词以后按照n-gram拼装而已。

切完词后,就是过滤一些不需要检测的词汇,主要是一些数字之类的,来看看具体有哪些:


            
@staticmethod  
def is\_filter\_token(token):  
    result = False  
    # pass blank  
    if not token.strip():  
        result = True  
    # pass num  
    if token.isdigit():  
        result = True  
    # pass alpha  
    if is_alphabet_string(token.lower()):  
        result = True  
    # pass not chinese  
    if not is_chinese_string(token):  
        result = True  
    return result  

        
  • 空字符串
  • 数字
  • 字母
  • 非中文

然后就是判断是否是低频词,这个就比较容易,他是构建了一个词典,直接判断是否在里面就好了。

语言模型

NLP领域最基础的东西就要数语言模型了,这里的假设其实是人输入的语言大都是常用的,如果出现了不太常用的东西,其实说明是有错的,带着这个假设,我们来看看利用这个方法是怎么判错的。


            
# 语言模型检测疑似错误字  
try:  
    ngram_avg_scores = []  
    for n in [2, 3]:  
        scores = []  
        for i in range(len(sentence) - n + 1):  
            word = sentence[i:i + n]  
            score = self.ngram_score(list(word))  
            scores.append(score)  
        if not scores:  
            continue  
        # 移动窗口补全得分  
        for _ in range(n - 1):  
            scores.insert(0, scores[0])  
            scores.append(scores[-1])  
        avg_scores = [sum(scores[i:i + n]) / len(scores[i:i + n]) for i in range(len(sentence))]  
        ngram_avg_scores.append(avg_scores)  
  
    if ngram_avg_scores:  
        # 取拼接后的n-gram平均得分  
        sent_scores = list(np.average(np.array(ngram_avg_scores), axis=0))  
        # 取疑似错字信息  
        for i in self._get_maybe_error_index(sent_scores):  
            token = sentence[i]  
            # pass filter word  
            if self.is_filter_token(token):  
                continue  
            # pass in stop word dict  
            if token in self.stopwords:  
                continue  
            # token, begin\_idx, end\_idx, error\_type  
            maybe_err = [token, i + start_idx, i + start_idx + 1,  
                         ErrorType.char]  
            self._add_maybe_error_item(maybe_err, maybe_errors)  
except IndexError as ie:  
    logger.warn("index error, sentence:" + sentence + str(ie))  
except Exception as e:  
    logger.warn("detect error, sentence:" + sentence + str(e))  

        

首先这个是基于字来判断的,所以不需要切词,直接把字符串一个一个的拼接成n-gram即可。

要分析整个句子中每个位点字合理,是需要看上下文的,这里分别采用了2-gram和3-gram进行了分析,分别计算了一个叫做ngram_score的东西,具体是这样的:


            
def ngram\_score(self, chars):  
    """  
    取n元文法得分  
    :param chars: list, 以词或字切分  
    :return:  
    """  
    self.check_detector_initialized()  
    return self.lm.score(' '.join(chars), bos=False, eos=False)  

        

这里使用的是kenlm来训练的语言模型,然后用score进行得分计算,这个得分实质上就是分析这个句子组合产生的可能性,概率当然就是在之间了,然后取对数,因此这个得分就是一个非正数了,越接近0,说明这个组合出现的可能性越大,越不可能有错了。

另外,为了保证整个句子的完整性,是需要padding的,代码里做了一个移动窗口的处理,直接看可能有些难懂,但是知道了padding,应该会好明白一些:


            
# 移动窗口补全得分  
for _ in range(n - 1):  
    scores.insert(0, scores[0])  
    scores.append(scores[-1])  

        

然后就对分数进行根据句子长度的均值计算,计算完之后分别保存了每个字的2-gram得分和3-gram得分,然后后续取了这两个分数的均值,这里的代码这么看:


            
avg_scores = [sum(scores[i:i + n]) / len(scores[i:i + n]) for i in range(len(sentence))]  
ngram_avg_scores.append(avg_scores)  
  
if ngram_avg_scores:  
    # 取拼接后的n-gram平均得分  
    sent_scores = list(np.average(np.array(ngram_avg_scores), axis=0))  

        

然后就会开始对这个分数进行分析,最终抽取可能有问题的位点,使用的函数就是 \_get\_maybe\_error\_index


            
@staticmethod  
def \_get\_maybe\_error\_index(scores, ratio=0.6745, threshold=2):  
    """  
    取疑似错字的位置,通过平均绝对离差(MAD)  
    :param scores: np.array  
    :param ratio: 正态分布表参数  
    :param threshold: 阈值越小,得到疑似错别字越多  
    :return: 全部疑似错误字的index: list  
    """  
    result = []  
    scores = np.array(scores)  
    if len(scores.shape) == 1:  
        scores = scores[:, None]  
    median = np.median(scores, axis=0)  # get median of all scores  
    margin_median = np.abs(scores - median).flatten()  # deviation from the median  
    # 平均绝对离差值  
    med_abs_deviation = np.median(margin_median)  
    if med_abs_deviation == 0:  
        return result  
    y_score = ratio * margin_median / med_abs_deviation  
    # 打平  
    scores = scores.flatten()  
    maybe_error_indices = np.where((y_score > threshold) & (scores < median))  
    # 取全部疑似错误字的index  
    result = list(maybe_error_indices[0])  
    return result  

        

思路其实大概说了,就是基于平均离差来算,这其实就是常用异常检测的MAD。说白了就是整个句子,大部分情况是不会出错的,正常情况下打分就会在特定的一个范围内,但是出错的位置的打分会距离这个打分很远(可以理解为和常规语境和语言水平差别很大),我们需要把这几个打分比较远的对应位置提取出来。

另外这里蛮有意思的是,可以看到作者对numpy比较熟悉,可以看看里面这些操作。

结果输出

然后就是一些整理结果输出的操作了,基本的数据处理还是比较容易的,直接看看最终的输出格式吧


            
import pycorrector  
  
idx_errors = pycorrector.detect('少先队员因该为老人让坐')  
print(idx_errors)  
  
# 输出:[['因该', 4, 6, 'word'], ['坐', 10, 11, 'char']]  

        

会把他定的位置和错误类型给指出来,最终只需要整理出这个格式就行。

小结

这里给大家介绍的是pycorrector内baseline的检测方法,让大家理解最基本的错误识别方式。

picture.image

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
VikingDB:大规模云原生向量数据库的前沿实践与应用
本次演讲将重点介绍 VikingDB 解决各类应用中极限性能、规模、精度问题上的探索实践,并通过落地的案例向听众介绍如何在多模态信息检索、RAG 与知识库等领域进行合理的技术选型和规划。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论