iFLYTEK Simple Triage for Non-Standardized Disease Complaints Challenge: Baseline

Project Overview

Track: iFLYTEK Simple Triage for Non-Standardized Disease Complaints Challenge

Track link: https://challenge.xfyun.cn/topic/info?type=disease-claims

Background

Demand for healthcare keeps growing, but medical resources are scarce: patients often queue for an entire morning to get a ten-minute consultation, at a heavy cost in time and stress. How to allocate medical resources more effectively, point patients in the right direction, and implement tiered diagnosis and treatment is an important problem for society.

People who feel unwell often cannot judge on their own whether they actually have a disease, and need someone with professional knowledge to assess them. But chief complaints are usually phrased colloquially, which makes precise and efficient routing difficult.

Task

Simple triage needs data and domain knowledge to back it up. This competition provides real consultation records from Haodf Online (好大夫在线), rigorously de-identified, for a multi-class classification task: given a patient's written complaint, predict one of 10 common consultation directions.

Data

The competition provides roughly 5,000 training samples and over 1,000 test samples (the files loaded below actually contain 7,844 and 1,412 rows).

Each sample contains text fields for age group, chief complaint, condition description, title, and desired help, plus a consultation-direction label i ∈ {0, 1, ..., 9}.

Evaluation Metric

macro F1-score
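Macro F1 averages the per-class F1 scores with equal weight, so the rare classes count exactly as much as the common ones. A minimal sketch with scikit-learn, using made-up labels:

```python
from sklearn.metrics import f1_score

# Hypothetical ground truth and predictions for a 3-class problem
y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 0]

# F1 is computed per class, then averaged without weighting by class size,
# so the completely-missed class 2 pulls the score down hard
print(f1_score(y_true, y_pred, average='macro'))  # about 0.489
```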

EDA

          
```python
# Import the required libraries
import warnings
warnings.simplefilter('ignore')

import numpy as np
import pandas as pd

%matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
```

      

          
```
Using matplotlib backend: TkAgg
```

      

          
```python
# Load the datasets
train = pd.read_excel('data/data104082/train.xlsx')
test = pd.read_excel('data/data104082/test.xlsx')

print(train.shape, test.shape)
```

      

          
```
(7844, 7) (1412, 6)
```

      

          
```python
train.head()
```

      
|   | id | age | diseaseName | conditionDesc | title | hopeHelp | label |
|---|---|---|---|---|---|---|---|
| 0 | 1 | 50+ | Fluid in the lungs, etc. | Pain in the right abdomen with shortness of breath; examination found fluid in the lungs, among other symptoms | Inquiry about admission time | Good morning, could ** be seen at the hospital today? | 8 |
| 1 | 2 | 30+ | How can I tell whether I am calcium deficient | Occasional soreness in the lower back and knees and weakness in the limbs; poor sleep, with frequent dreaming | Not sure whether it is calcium deficiency or kidney deficiency | Are these symptoms of kidney deficiency or of calcium deficiency? | 4 |
| 2 | 3 | 20+ | Numb tongue, stiff lower-right lip, with fever | Numb tongue, stiff lower-right lip, with fever; inflamed throat | Numb tongue, stiff lower-right lip, with fever | Please give some advice on whether further examination is needed | 4 |
| 3 | 4 | 30+ | Back discomfort, at the spot opposite the chest | Frequent belching and stomach fullness after eating; gastroscopy showed superficial gastritis; over a week of medication brought some improvement, but the back still feels off, ... | Back discomfort opposite the chest; not pain, not soreness, just discomfort | How to adjust the treatment | 4 |
| 4 | 5 | 40+ | Congested throat with a foreign-body sensation, no pain, a little phlegm | Two months ago the throat was congested with coughing; a hospital visit diagnosed pharyngitis; now when coughing the throat is congested, with a foreign-body sensation | What medicine would help it clear up quickly | What medicine would help it clear up quickly | 0 |

          
```python
test.head()
```

      
|   | id | age | diseaseName | conditionDesc | title | hopeHelp |
|---|---|---|---|---|---|---|
| 0 | 1 | 20+ | Coughing up blood | History of stomach trouble and gastric perforation; woke from a midday nap coughing up blood; it has been like this for years | Hoping to learn roughly what the problem is | Hoping to learn roughly what the problem is |
| 1 | 2 | 50+ | Phlegm in the throat that cannot be coughed up | Sore throat and cough since January 2015; an X-ray showed bronchitis and IV treatment began, but the cough persisted; a CT a month later showed pneumonia, ... | What should I do | What should I do |
| 2 | 3 | 40+ | Legs keep cramping | Male, 43. The legs keep cramping, from the top of the thigh down to the calf; the cramps wake me at night | Please help me find out what illness this is and how to treat it | Calcium supplements have not helped either |
| 3 | 4 | 30+ | Suspected laryngopharyngeal reflux | After a nosebleed in June, coughing for half a month (with a burning throat at the time); the cough then stopped, but in the 10 months since ... | NaN | NaN |
| 4 | 5 | 50+ | Small stools; shortly after finishing, the urge returns but nothing comes | Previously had gastroenteritis with several bowel movements a day, which was cured; now stools are small and the urge returns shortly after finishing, ... | Shortly after a bowel movement the urge immediately returns but nothing comes | Is there an effective medicine I could order online and start taking? Thank you! |

          
```python
# Label distribution
train['label'].value_counts()
```

      

          
```
4    1086
5     947
8     888
7     843
0     842
2     824
1     818
3     806
6     516
9     274
Name: label, dtype: int64
```
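The classes are imbalanced: label 9 has only 274 samples against 1,086 for label 4. Since the metric is macro F1, the rare classes matter as much as the common ones, so weighting the loss by inverse class frequency is worth trying. A minimal sketch of "balanced" weights computed from the counts above (this only computes the weights; wiring them into the PaddleHub trainer is left out):

```python
import numpy as np

# Label counts from the value_counts() output above, in label order 0..9
counts = np.array([842, 818, 824, 806, 1086, 947, 516, 843, 888, 274])

# "Balanced" class weights: n_samples / (n_classes * count_per_class);
# rarer classes get larger weights
weights = counts.sum() / (len(counts) * counts)
print({i: round(w, 2) for i, w in enumerate(weights)})
```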

      

          
```python
# Fill missing values
for col in ['diseaseName', 'conditionDesc', 'title', 'hopeHelp']:
    train[col].fillna('', inplace=True)
    test[col].fillna('', inplace=True)
```

      

          
```python
# Text length per field
for col in ['diseaseName', 'conditionDesc', 'title', 'hopeHelp']:
    print(train[col].apply(len).describe())
    print(test[col].apply(len).describe())
```

      

          
Character-length statistics, condensed from the `describe()` output:

| column | split | count | mean | std | min | 25% | 50% | 75% | max |
|---|---|---|---|---|---|---|---|---|---|
| diseaseName | train | 7844 | 11.09 | 7.51 | 1 | 5 | 10 | 17 | 101 |
| diseaseName | test | 1412 | 11.21 | 8.45 | 0 | 5 | 10 | 17 | 147 |
| conditionDesc | train | 7844 | 47.46 | 36.60 | 2 | 23 | 33 | 58 | 248 |
| conditionDesc | test | 1412 | 49.15 | 37.89 | 2 | 24 | 33 | 60 | 250 |
| title | train | 7844 | 10.05 | 6.14 | 0 | 6 | 9 | 12 | 50 |
| title | test | 1412 | 10.05 | 6.09 | 0 | 6 | 9 | 12 | 50 |
| hopeHelp | train | 7844 | 14.75 | 11.10 | 0 | 8 | 13 | 18 | 130 |
| hopeHelp | test | 1412 | 15.46 | 12.53 | 0 | 8 | 13 | 19 | 137 |

      

          
```python
# Strip line-break characters; left in place, they would corrupt the
# tab-separated dataset files created below
def clean_str(x):
    return x.replace('\r', '').replace('\t', ' ').replace('\n', ' ')

for col in ['diseaseName', 'conditionDesc', 'title', 'hopeHelp']:
    train[col] = train[col].apply(clean_str)
    test[col] = test[col].apply(clean_str)
```
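A quick check of what the helper does, on a made-up string containing all three characters it removes:

```python
def clean_str(x):
    # Same helper as above: drop \r, turn \t and \n into spaces
    return x.replace('\r', '').replace('\t', ' ').replace('\n', ' ')

print(clean_str('头疼\r\n三天\t了'))  # -> '头疼 三天 了'
```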

      

          
```python
# Concatenate the text fields and check the combined length
train['text'] = train['diseaseName'].astype(str) + " " + \
                train['conditionDesc'].astype(str) + " " + \
                train['title'].astype(str) + " " + \
                train['hopeHelp'].astype(str)

test['text'] = test['diseaseName'].astype(str) + " " + \
               test['conditionDesc'].astype(str) + " " + \
               test['title'].astype(str) + " " + \
               test['hopeHelp'].astype(str)

print(train['text'].apply(len).describe())
print(test['text'].apply(len).describe())
```

      

          
```
count    7844.000000
mean       86.343447
std        41.443523
min        19.000000
25%        58.000000
50%        75.000000
75%       101.000000
max       257.000000
Name: text, dtype: float64
count    1412.000000
mean       88.853399
std        43.299549
min        24.000000
25%        59.000000
50%        76.000000
75%       104.000000
max       258.000000
Name: text, dtype: float64
```

      

It looks like max_seq_len = 256 will cover essentially every sample: the longest concatenated text is 258 characters, and ERNIE's tokenization usually produces fewer tokens than characters.

Baseline approach

This is a classic multi-class text-classification task: concatenate the text fields and fine-tune a pretrained model such as BERT.

First, manually carve out 10% of the training data as a validation set (a 9:1 split), then save the training, validation, and test sets in ChnSentiCorp format.


          
```python
train = train[['text', 'label']].copy()
test = test[['text']].copy()

train = train.sample(frac=1, random_state=42)  # shuffle

train_size = int(0.9 * len(train))
train_df = train[:train_size]
valid_df = train[train_size:]
test_df = test.copy()

print(train_df.shape, valid_df.shape, test_df.shape)
```

      

          
```
(7059, 2) (785, 2) (1412, 1)
```
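A plain random split can skew the label proportions in a fold, which matters most for the rare label 9. If scikit-learn is available, a stratified split keeps the per-class shares identical in both folds; a self-contained sketch on toy data (the real code would pass the `train` DataFrame instead):

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy stand-in for the real data: 100 rows with a 90/10 label imbalance
df = pd.DataFrame({'text': [f't{i}' for i in range(100)],
                   'label': [0] * 90 + [1] * 10})

# stratify= keeps each label's share the same in both folds
train_df, valid_df = train_test_split(
    df, test_size=0.1, random_state=42, stratify=df['label'])

print(valid_df['label'].value_counts())  # 9 rows of label 0, 1 of label 1
```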

      

          
```python
# Save as tab-separated text files: label<TAB>text, no header
train_df[['label', 'text']].to_csv('train.txt', index=False, header=False, sep='\t')
valid_df[['label', 'text']].to_csv('valid.txt', index=False, header=False, sep='\t')
```
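The saved files follow the ChnSentiCorp convention of one `label<TAB>text` record per line with no header. A quick round trip through an in-memory buffer confirms the layout (the two rows here are made-up stand-ins for the real data):

```python
import io
import pandas as pd

# Toy frame standing in for train_df (hypothetical rows)
df = pd.DataFrame({'label': [3, 7],
                   'text': ['sore throat for two days', 'lower back pain']})

buf = io.StringIO()
df[['label', 'text']].to_csv(buf, index=False, header=False, sep='\t')
buf.seek(0)

# Read it back: two tab-separated columns, no header row
back = pd.read_csv(buf, sep='\t', header=None, names=['label', 'text'])
print(back.shape)  # (2, 2)
```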

      
Custom data loading

          
```python
# Upgrade PaddleHub
!pip install -q --upgrade paddlehub -i https://pypi.tuna.tsinghua.edu.cn/simple
```

      

          
```python
from typing import Union

import paddle
import paddlehub as hub
from paddlehub.datasets.base_nlp_dataset import TextClassificationDataset
from paddlehub.text.bert_tokenizer import BertTokenizer
from paddlehub.text.tokenizer import CustomTokenizer

class DemoDataset(TextClassificationDataset):
    def __init__(self, tokenizer: Union[BertTokenizer, CustomTokenizer], max_seq_len: int = 256, mode: str = 'train'):
        base_path = './'
        if mode == 'train':
            data_file = 'train.txt'
        elif mode == 'test':
            data_file = 'test.txt'
        else:
            data_file = 'valid.txt'
        super().__init__(
            base_path=base_path,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            mode=mode,
            data_file=data_file,
            label_list=["0", "1", "2", "3", "4",
                        "5", "6", "7", "8", "9"],
            is_file_with_header=False)
```

      
Load the pretrained model

We'll try the ernie_tiny model and see how it does.


          
```python
model = hub.Module(name='ernie_tiny', version='2.0.1', task='seq-cls', num_classes=10)
```

      

          
```
[2021-08-12 20:53:41,155] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-tiny/ernie_tiny.pdparams
```

      

          
```python
# Build the datasets
train_dataset = DemoDataset(tokenizer=model.get_tokenizer(), mode='train')
valid_dataset = DemoDataset(tokenizer=model.get_tokenizer(), mode='valid')
```

      

          
```
[2021-08-12 20:53:49,156] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/vocab.txt
[2021-08-12 20:53:49,160] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/spm_cased_simp_sampled.model
[2021-08-12 20:53:49,163] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/dict.wordseg.pickle
[2021-08-12 20:53:57,314] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/vocab.txt
[2021-08-12 20:53:57,318] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/spm_cased_simp_sampled.model
[2021-08-12 20:53:57,321] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/dict.wordseg.pickle
```

      
Training

          
```python
optimizer = paddle.optimizer.AdamW(learning_rate=5e-5, parameters=model.parameters())
trainer = hub.Trainer(model, optimizer, checkpoint_dir='./checkpoint', use_gpu=True)
```

      

          
```
[2021-08-12 20:54:02,807] [ WARNING] - PaddleHub model checkpoint not found, start from scratch...
```

      

          
```python
trainer.train(
    train_dataset,
    epochs=5,
    batch_size=32,
    eval_dataset=valid_dataset,
    save_interval=5,
)
```

      

          
```
[2021-08-12 20:54:04,768] [   TRAIN] - Epoch=1/5, Step=10/221 loss=2.2223 acc=0.1812 lr=0.000050 step/sec=5.11 | ETA 00:03:36
[2021-08-12 20:54:06,600] [   TRAIN] - Epoch=1/5, Step=20/221 loss=1.8941 acc=0.4281 lr=0.000050 step/sec=5.46 | ETA 00:03:29
...
[2021-08-12 20:54:43,334] [   TRAIN] - Epoch=1/5, Step=220/221 loss=0.4790 acc=0.8219 lr=0.000050 step/sec=5.49 | ETA 00:03:23
...
[2021-08-12 20:55:24,147] [   TRAIN] - Epoch=2/5, Step=220/221 loss=0.3500 acc=0.8812 lr=0.000050 step/sec=5.27 | ETA 00:03:23
...
[2021-08-12 20:56:05,705] [   TRAIN] - Epoch=3/5, Step=220/221 loss=0.2106 acc=0.9344 lr=0.000050 step/sec=5.34 | ETA 00:03:25
...
[2021-08-12 20:56:46,987] [   TRAIN] - Epoch=4/5, Step=220/221 loss=0.0801 acc=0.9781 lr=0.000050 step/sec=5.36 | ETA 00:03:25
...
[2021-08-12 20:57:28,293] [   TRAIN] - Epoch=5/5, Step=220/221 loss=0.0575 acc=0.9812 lr=0.000050 step/sec=5.39 | ETA 00:03:25
[2021-08-12 20:57:30,037] [    EVAL] - Evaluation on validation dataset: [Evaluation result] avg_acc=0.8229
[2021-08-12 20:57:32,928] [    EVAL] - Saving best model to ./checkpoint/best_model [best acc=0.8229]
[2021-08-12 20:57:32,932] [    INFO] - Saving model checkpoint to ./checkpoint/epoch_5
```

      
Prediction

          
```python
model = hub.Module(
    name='ernie_tiny',
    version='2.0.1',
    task='seq-cls',
    load_checkpoint='./checkpoint/best_model/model.pdparams',
    label_map={i: str(i) for i in range(10)}
)
```

      

          
```
[2021-08-12 21:11:45,345] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-tiny/ernie_tiny.pdparams
[2021-08-12 21:11:55,772] [    INFO] - Loaded parameters from /home/aistudio/checkpoint/best_model/model.pdparams
```

      

          
```python
data = [[i] for i in test_df['text'].values]
# batch_size=1 with use_gpu=False is slow; a larger batch (and use_gpu=True)
# speeds up inference considerably
results = model.predict(data, max_seq_len=256, batch_size=1, use_gpu=False)
```

      

          
```
[2021-08-12 21:12:16,153] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/vocab.txt
[2021-08-12 21:12:16,226] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/spm_cased_simp_sampled.model
[2021-08-12 21:12:16,230] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/dict.wordseg.pickle
```

      

          
```python
sub = pd.DataFrame({'id': [i + 1 for i in range(len(test_df))], 'label': results})
sub.head()
```

      
|   | id | label |
|---|---|---|
| 0 | 1 | 3 |
| 1 | 2 | 8 |
| 2 | 3 | 2 |
| 3 | 4 | 0 |
| 4 | 5 | 3 |

          
```python
sub.to_csv('ernie_tiny_baseline.csv', index=False)
```

      
Submission

Online score: 0.85174
