Project overview
Track: iFLYTEK – Simple Triage of Non-standardized Disease Complaints Challenge
Track link: https://challenge.xfyun.cn/topic/info?type=disease-claims
Background
Demand for healthcare keeps growing, but medical resources are scarce: patients often queue all morning for a ten-minute consultation, at great cost in time and energy. Allocating medical resources better, pointing patients toward the right specialty, and implementing tiered triage is therefore an important problem.
People who feel unwell often cannot judge whether they are actually ill and seek a professional opinion, but they usually describe their complaint colloquially, which makes accurate and efficient routing difficult.
Task
Simple triage needs data and domain knowledge to back it up. This competition provides a portion of real, rigorously de-identified consultation records from Haodf Online (好大夫在线) for a multi-class task: given the textual complaint, predict which of 10 common consultation directions it belongs to.
Data
The competition description mentions roughly 5,000 training samples and over 1,000 test samples; the released files actually contain 7,844 training rows and 1,412 test rows, as shown below.
Each record contains text fields for age bracket, chief complaint, title, help requested, and other descriptions, plus a consultation-direction label i ∈ {0, 1, …, 9}.
Evaluation metric
macro F1-score
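Macro F1 averages the per-class F1 scores with equal weight, so rare classes count just as much as the frequent ones. A minimal pure-Python sketch on toy labels (not competition data):

```python
def macro_f1(y_true, y_pred, labels):
    """Unweighted mean of per-class F1 scores."""
    scores = []
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        scores.append(2 * precision * recall / (precision + recall)
                      if precision + recall else 0.0)
    return sum(scores) / len(scores)

# per-class F1 = 0.5, 0.8, 2/3 -> macro F1 ≈ 0.6556
print(macro_f1([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], labels=[0, 1, 2]))
```

In practice one would call `sklearn.metrics.f1_score(..., average='macro')`, but the hand-rolled version makes the equal weighting explicit.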
EDA
# Import the necessary libraries
import warnings
warnings.simplefilter('ignore')
import numpy as np
import pandas as pd
%matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
Using matplotlib backend: TkAgg
# Read the datasets
train = pd.read_excel('data/data104082/train.xlsx')
test = pd.read_excel('data/data104082/test.xlsx')
print(train.shape, test.shape)
(7844, 7) (1412, 6)
train.head()
| | id | age | diseaseName | conditionDesc | title | hopeHelp | label |
|---|---|---|---|---|---|---|---|
| 0 | 1 | 50+ | 肺部积水等 | 右腹处疼痛 伴气短 检查出肺部有积水等症状 | 入院时间咨询 | 早上好 请问下**今天去医院可以安排吗 | 8 |
| 1 | 2 | 30+ | 怎么才能知道自己缺钙 | 有时候腰膝酸软四肢无力感觉,睡眠不是太好,总是爱做梦 | 不知道是缺钙还是肾虚 | 请问这是肾虚症状还是缺钙症状啊 | 4 |
| 2 | 3 | 20+ | 舌头发麻,右下唇僵硬,带有发烧 | 舌头发麻,右下唇僵硬,带有发烧,嗓子有炎症 | 舌头发麻,右下唇僵硬,带有发烧 | 给一些建议,是否需要进行下一步检查 | 4 |
| 3 | 4 | 30+ | 后背不适,胸口对应处 | 因经常嗳气,胃吃完顶人,后就医,胃镜显示浅表性胃炎,吃药一周余,有所好转,但后背一直不舒服,... | 后背不适,胸口对应处,不是疼,不是酸痛,就是难受 | 如何调整治疗 | 4 |
| 4 | 5 | 40+ | 喉咙充血,有异物感,不痛。有少许痰 | 两个月前,喉部充血,咳嗽,去医院看过,咽炎,现在咳了,就是喉咙充血,有异物感。 | 看吃点什么药能尽快好起来 | 看吃点什么药能尽快好起来 | 0 |
test.head()
| | id | age | diseaseName | conditionDesc | title | hopeHelp |
|---|---|---|---|---|---|---|
| 0 | 1 | 20+ | 咳血 | 有过胃病,胃穿孔,中午睡觉起来咳血,几年都这样 | 希望得到大概是什么问题 | 希望得到大概是什么问题 |
| 1 | 2 | 50+ | 喉咙有痰吐不出来 | 我是2015年1月嗓子疼开始咳嗽,拍片支气管炎,开始输液,但咳嗽一直不好,一月后拍CT肺炎,... | 我该怎么办 | 我该怎么办 |
| 2 | 3 | 40+ | 双腿总是抽筋 | 男,43岁。平时双腿总抽筋从大腿跟一直到小腿晚上睡觉都能抽醒 | 请您帮助我查下是什么病该如何看 | 补钙也不好 |
| 3 | 4 | 30+ | 怀疑咽喉返流 | 六月一个鼻出血后,咳嗽了半个月(当时喉咙灼热),半个月后咳嗽停止,从那以后到现在10个月来每... | NaN | NaN |
| 4 | 5 | 50+ | 大便量少,拉完过一会儿又想拉拉不出来。 | 以前有肠胃炎一天拉几次但是看好了,现在是刚大完量少过会儿又想拉拉不出来。\n大便量少,天天刚... | 刚大完便过会儿马上又想拉拉不出来 | 有什么特效药网上开药先吃,谢谢! |
# Label distribution
train['label'].value_counts()
4 1086
5 947
8 888
7 843
0 842
2 824
1 818
3 806
6 516
9 274
Name: label, dtype: int64
# Fill missing values
for col in ['diseaseName', 'conditionDesc', 'title', 'hopeHelp']:
    train[col].fillna('', inplace=True)
    test[col].fillna('', inplace=True)
# Text length per field
for col in ['diseaseName', 'conditionDesc', 'title', 'hopeHelp']:
    print(train[col].apply(len).describe())
    print(test[col].apply(len).describe())
count 7844.000000
mean 11.085798
std 7.512438
min 1.000000
25% 5.000000
50% 10.000000
75% 17.000000
max 101.000000
Name: diseaseName, dtype: float64
count 1412.000000
mean 11.211048
std 8.447323
min 0.000000
25% 5.000000
50% 10.000000
75% 17.000000
max 147.000000
Name: diseaseName, dtype: float64
count 7844.000000
mean 47.455762
std 36.604923
min 2.000000
25% 23.000000
50% 33.000000
75% 58.000000
max 248.000000
Name: conditionDesc, dtype: float64
count 1412.000000
mean 49.145892
std 37.890270
min 2.000000
25% 24.000000
50% 33.000000
75% 60.000000
max 250.000000
Name: conditionDesc, dtype: float64
count 7844.000000
mean 10.052907
std 6.138516
min 0.000000
25% 6.000000
50% 9.000000
75% 12.000000
max 50.000000
Name: title, dtype: float64
count 1412.000000
mean 10.050283
std 6.085816
min 0.000000
25% 6.000000
50% 9.000000
75% 12.000000
max 50.000000
Name: title, dtype: float64
count 7844.000000
mean 14.754080
std 11.096384
min 0.000000
25% 8.000000
50% 13.000000
75% 18.000000
max 130.000000
Name: hopeHelp, dtype: float64
count 1412.000000
mean 15.459632
std 12.526165
min 0.000000
25% 8.000000
50% 13.000000
75% 19.000000
max 137.000000
Name: hopeHelp, dtype: float64
# Strip newline-type characters; otherwise they break the dataset files later
def clean_str(x):
    return x.replace('\r', '').replace('\t', ' ').replace('\n', ' ')

for col in ['diseaseName', 'conditionDesc', 'title', 'hopeHelp']:
    train[col] = train[col].apply(clean_str)
    test[col] = test[col].apply(clean_str)
# Concatenate the four text fields and check the combined length
train['text'] = train['diseaseName'].astype(str) + " " + \
train['conditionDesc'].astype(str) + " " + \
train['title'].astype(str) + " " + \
train['hopeHelp'].astype(str)
test['text'] = test['diseaseName'].astype(str) + " " + \
test['conditionDesc'].astype(str) + " " + \
test['title'].astype(str) + " " + \
test['hopeHelp'].astype(str)
print(train['text'].apply(len).describe())
print(test['text'].apply(len).describe())
count 7844.000000
mean 86.343447
std 41.443523
min 19.000000
25% 58.000000
50% 75.000000
75% 101.000000
max 257.000000
Name: text, dtype: float64
count 1412.000000
mean 88.853399
std 43.299549
min 24.000000
25% 59.000000
50% 76.000000
75% 104.000000
max 258.000000
Name: text, dtype: float64
The longest concatenated texts are 257–258 characters, so max_seq_len = 256 covers essentially every sample; only a handful of texts lose their last couple of characters to truncation.
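To quantify this, one can count how many texts a given cap would actually truncate. A sketch where the toy `lengths` list stands in for `train['text'].apply(len)`:

```python
# Fraction of samples a max_seq_len cap would truncate
# (toy lengths; substitute train['text'].apply(len) in practice).
def truncated_fraction(lengths, max_seq_len=256):
    return sum(1 for n in lengths if n > max_seq_len) / len(lengths)

print(truncated_fraction([19, 58, 75, 101, 257], max_seq_len=256))  # 0.2
```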
Baseline approach
This is a typical multi-class text-classification task, so we concatenate the text fields and fine-tune a pretrained model such as BERT.
First we do a manual 9:1 train/validation split and save the training, validation, and test sets in ChnSentiCorp format (one `label<TAB>text` line per sample).
train = train[['text', 'label']].copy()
test = test[['text']].copy()
train = train.sample(frac=1, random_state=42)  # shuffle the rows
train_size = int(0.9 * len(train))
train_df = train[:train_size]
valid_df = train[train_size:]
test_df = test.copy()
print(train_df.shape, valid_df.shape, test_df.shape)
(7059, 2) (785, 2) (1412, 1)
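The split above is purely random; given how rare labels 6 and 9 are, a stratified split would keep the validation label distribution closer to the training one. A pure-Python sketch of that alternative (not what this baseline actually uses):

```python
from collections import defaultdict

def stratified_split(labels, valid_frac=0.1):
    """Return (train_idx, valid_idx), taking the last valid_frac of each class."""
    by_label = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_label[lab].append(idx)
    train_idx, valid_idx = [], []
    for idxs in by_label.values():
        # keep at least one validation example per class
        cut = min(len(idxs) - 1, int(len(idxs) * (1 - valid_frac)))
        train_idx += idxs[:cut]
        valid_idx += idxs[cut:]
    return train_idx, valid_idx

labels = [0] * 9 + [1] * 9 + [2] * 2   # class 2 is rare, like label 9 here
tr, va = stratified_split(labels, valid_frac=0.1)
print(sorted(set(labels[i] for i in va)))  # [0, 1, 2] -- every class represented
```

`sklearn.model_selection.train_test_split(..., stratify=train['label'])` does the same job with proportional sampling.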
# Save as tab-separated text files
train_df[['label', 'text']].to_csv('train.txt', index=False, header=False, sep='\t')
valid_df[['label', 'text']].to_csv('valid.txt', index=False, header=False, sep='\t')
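A quick way to sanity-check the saved format (a sketch with hypothetical rows, not the real records): each line should parse back into exactly a label and a text field when split on the first tab.

```python
# Round-trip check of the label<TAB>text layout used above
# (the rows here are hypothetical stand-ins, not competition data).
rows = [('8', 'right abdominal pain with shortness of breath'),
        ('4', 'numb tongue and a mild fever')]
with open('demo.txt', 'w', encoding='utf-8') as f:
    for label, text in rows:
        f.write(f'{label}\t{text}\n')

with open('demo.txt', encoding='utf-8') as f:
    parsed = [line.rstrip('\n').split('\t', 1) for line in f]

assert all(len(p) == 2 for p in parsed)
print(parsed[0])  # ['8', 'right abdominal pain with shortness of breath']
```

This is also why `clean_str` removed tabs and newlines earlier: a stray `\t` or `\n` inside the text would corrupt this layout.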
Custom data loading
# Upgrade paddlehub
!pip install -q --upgrade paddlehub -i https://pypi.tuna.tsinghua.edu.cn/simple
from typing import Dict, List, Optional, Union, Tuple
import paddle
import paddlehub as hub
from paddlehub.datasets.base_nlp_dataset import TextClassificationDataset
from paddlehub.text.bert_tokenizer import BertTokenizer
from paddlehub.text.tokenizer import CustomTokenizer
class DemoDataset(TextClassificationDataset):
    def __init__(self, tokenizer: Union[BertTokenizer, CustomTokenizer], max_seq_len: int = 256, mode: str = 'train'):
        base_path = './'
        if mode == 'train':
            data_file = 'train.txt'
        elif mode == 'test':
            data_file = 'test.txt'
        else:
            data_file = 'valid.txt'
        super().__init__(
            base_path=base_path,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            mode=mode,
            data_file=data_file,
            label_list=["0", "1", "2", "3", "4",
                        "5", "6", "7", "8", "9"],
            is_file_with_header=False)
Load the pretrained model
We try ernie_tiny first to see how it performs.
model = hub.Module(name='ernie_tiny', version='2.0.1', task='seq-cls', num_classes=10)
[2021-08-12 20:53:41,155] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-tiny/ernie_tiny.pdparams
# Build the train/valid datasets
train_dataset = DemoDataset(tokenizer=model.get_tokenizer(), mode='train')
valid_dataset = DemoDataset(tokenizer=model.get_tokenizer(), mode='valid')
[2021-08-12 20:53:49,156] [ INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/vocab.txt
[2021-08-12 20:53:49,160] [ INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/spm_cased_simp_sampled.model
[2021-08-12 20:53:49,163] [ INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/dict.wordseg.pickle
[2021-08-12 20:53:57,314] [ INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/vocab.txt
[2021-08-12 20:53:57,318] [ INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/spm_cased_simp_sampled.model
[2021-08-12 20:53:57,321] [ INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/dict.wordseg.pickle
Training
optimizer = paddle.optimizer.AdamW(learning_rate=5e-5, parameters=model.parameters())
trainer = hub.Trainer(model, optimizer, checkpoint_dir='./checkpoint', use_gpu=True)
[2021-08-12 20:54:02,807] [ WARNING] - PaddleHub model checkpoint not found, start from scratch...
trainer.train(
train_dataset,
epochs=5,
batch_size=32,
eval_dataset=valid_dataset,
save_interval=5,
)
[2021-08-12 20:54:04,768] [ TRAIN] - Epoch=1/5, Step=10/221 loss=2.2223 acc=0.1812 lr=0.000050 step/sec=5.11 | ETA 00:03:36
[2021-08-12 20:54:06,600] [ TRAIN] - Epoch=1/5, Step=20/221 loss=1.8941 acc=0.4281 lr=0.000050 step/sec=5.46 | ETA 00:03:29
[2021-08-12 20:54:08,432] [ TRAIN] - Epoch=1/5, Step=30/221 loss=1.4312 acc=0.6438 lr=0.000050 step/sec=5.46 | ETA 00:03:26
[2021-08-12 20:54:10,267] [ TRAIN] - Epoch=1/5, Step=40/221 loss=1.0634 acc=0.7125 lr=0.000050 step/sec=5.45 | ETA 00:03:25
[2021-08-12 20:54:12,098] [ TRAIN] - Epoch=1/5, Step=50/221 loss=0.8385 acc=0.7438 lr=0.000050 step/sec=5.46 | ETA 00:03:25
[2021-08-12 20:54:13,937] [ TRAIN] - Epoch=1/5, Step=60/221 loss=0.6728 acc=0.8094 lr=0.000050 step/sec=5.44 | ETA 00:03:24
[2021-08-12 20:54:15,775] [ TRAIN] - Epoch=1/5, Step=70/221 loss=0.6850 acc=0.7844 lr=0.000050 step/sec=5.44 | ETA 00:03:24
[2021-08-12 20:54:17,606] [ TRAIN] - Epoch=1/5, Step=80/221 loss=0.5537 acc=0.8125 lr=0.000050 step/sec=5.46 | ETA 00:03:24
[2021-08-12 20:54:19,449] [ TRAIN] - Epoch=1/5, Step=90/221 loss=0.5635 acc=0.8219 lr=0.000050 step/sec=5.43 | ETA 00:03:24
[2021-08-12 20:54:21,291] [ TRAIN] - Epoch=1/5, Step=100/221 loss=0.5011 acc=0.8281 lr=0.000050 step/sec=5.43 | ETA 00:03:24
[2021-08-12 20:54:23,129] [ TRAIN] - Epoch=1/5, Step=110/221 loss=0.5155 acc=0.8031 lr=0.000050 step/sec=5.44 | ETA 00:03:24
[2021-08-12 20:54:24,971] [ TRAIN] - Epoch=1/5, Step=120/221 loss=0.5309 acc=0.8219 lr=0.000050 step/sec=5.43 | ETA 00:03:24
[2021-08-12 20:54:26,815] [ TRAIN] - Epoch=1/5, Step=130/221 loss=0.5388 acc=0.8187 lr=0.000050 step/sec=5.42 | ETA 00:03:24
[2021-08-12 20:54:28,654] [ TRAIN] - Epoch=1/5, Step=140/221 loss=0.5070 acc=0.8250 lr=0.000050 step/sec=5.44 | ETA 00:03:23
[2021-08-12 20:54:30,491] [ TRAIN] - Epoch=1/5, Step=150/221 loss=0.4604 acc=0.8156 lr=0.000050 step/sec=5.44 | ETA 00:03:23
[2021-08-12 20:54:32,333] [ TRAIN] - Epoch=1/5, Step=160/221 loss=0.3950 acc=0.8688 lr=0.000050 step/sec=5.43 | ETA 00:03:23
[2021-08-12 20:54:34,178] [ TRAIN] - Epoch=1/5, Step=170/221 loss=0.4469 acc=0.8406 lr=0.000050 step/sec=5.42 | ETA 00:03:23
[2021-08-12 20:54:36,006] [ TRAIN] - Epoch=1/5, Step=180/221 loss=0.5265 acc=0.8219 lr=0.000050 step/sec=5.47 | ETA 00:03:23
[2021-08-12 20:54:37,834] [ TRAIN] - Epoch=1/5, Step=190/221 loss=0.4789 acc=0.8375 lr=0.000050 step/sec=5.47 | ETA 00:03:23
[2021-08-12 20:54:39,670] [ TRAIN] - Epoch=1/5, Step=200/221 loss=0.4247 acc=0.8750 lr=0.000050 step/sec=5.44 | ETA 00:03:23
[2021-08-12 20:54:41,510] [ TRAIN] - Epoch=1/5, Step=210/221 loss=0.4540 acc=0.8313 lr=0.000050 step/sec=5.43 | ETA 00:03:23
[2021-08-12 20:54:43,334] [ TRAIN] - Epoch=1/5, Step=220/221 loss=0.4790 acc=0.8219 lr=0.000050 step/sec=5.49 | ETA 00:03:23
[2021-08-12 20:54:45,302] [ TRAIN] - Epoch=2/5, Step=10/221 loss=0.3176 acc=0.8812 lr=0.000050 step/sec=5.59 | ETA 00:03:23
[2021-08-12 20:54:47,144] [ TRAIN] - Epoch=2/5, Step=20/221 loss=0.3946 acc=0.8562 lr=0.000050 step/sec=5.43 | ETA 00:03:23
[2021-08-12 20:54:48,982] [ TRAIN] - Epoch=2/5, Step=30/221 loss=0.3702 acc=0.8688 lr=0.000050 step/sec=5.44 | ETA 00:03:23
[2021-08-12 20:54:50,826] [ TRAIN] - Epoch=2/5, Step=40/221 loss=0.3149 acc=0.8969 lr=0.000050 step/sec=5.42 | ETA 00:03:23
[2021-08-12 20:54:52,675] [ TRAIN] - Epoch=2/5, Step=50/221 loss=0.2551 acc=0.9094 lr=0.000050 step/sec=5.41 | ETA 00:03:23
[2021-08-12 20:54:54,514] [ TRAIN] - Epoch=2/5, Step=60/221 loss=0.3289 acc=0.8906 lr=0.000050 step/sec=5.44 | ETA 00:03:23
[2021-08-12 20:54:56,358] [ TRAIN] - Epoch=2/5, Step=70/221 loss=0.3411 acc=0.8750 lr=0.000050 step/sec=5.42 | ETA 00:03:23
[2021-08-12 20:54:58,205] [ TRAIN] - Epoch=2/5, Step=80/221 loss=0.2811 acc=0.9125 lr=0.000050 step/sec=5.41 | ETA 00:03:23
[2021-08-12 20:55:00,050] [ TRAIN] - Epoch=2/5, Step=90/221 loss=0.3720 acc=0.8406 lr=0.000050 step/sec=5.42 | ETA 00:03:23
[2021-08-12 20:55:01,902] [ TRAIN] - Epoch=2/5, Step=100/221 loss=0.3408 acc=0.8844 lr=0.000050 step/sec=5.40 | ETA 00:03:23
[2021-08-12 20:55:03,752] [ TRAIN] - Epoch=2/5, Step=110/221 loss=0.2865 acc=0.8969 lr=0.000050 step/sec=5.41 | ETA 00:03:23
[2021-08-12 20:55:05,602] [ TRAIN] - Epoch=2/5, Step=120/221 loss=0.2984 acc=0.9000 lr=0.000050 step/sec=5.40 | ETA 00:03:23
[2021-08-12 20:55:07,451] [ TRAIN] - Epoch=2/5, Step=130/221 loss=0.2746 acc=0.8906 lr=0.000050 step/sec=5.41 | ETA 00:03:23
[2021-08-12 20:55:09,309] [ TRAIN] - Epoch=2/5, Step=140/221 loss=0.3675 acc=0.8750 lr=0.000050 step/sec=5.38 | ETA 00:03:23
[2021-08-12 20:55:11,152] [ TRAIN] - Epoch=2/5, Step=150/221 loss=0.3222 acc=0.8750 lr=0.000050 step/sec=5.42 | ETA 00:03:23
[2021-08-12 20:55:12,994] [ TRAIN] - Epoch=2/5, Step=160/221 loss=0.2806 acc=0.9000 lr=0.000050 step/sec=5.43 | ETA 00:03:23
[2021-08-12 20:55:14,840] [ TRAIN] - Epoch=2/5, Step=170/221 loss=0.2960 acc=0.8875 lr=0.000050 step/sec=5.42 | ETA 00:03:23
[2021-08-12 20:55:16,693] [ TRAIN] - Epoch=2/5, Step=180/221 loss=0.3479 acc=0.8781 lr=0.000050 step/sec=5.40 | ETA 00:03:23
[2021-08-12 20:55:18,541] [ TRAIN] - Epoch=2/5, Step=190/221 loss=0.3052 acc=0.9000 lr=0.000050 step/sec=5.41 | ETA 00:03:23
[2021-08-12 20:55:20,393] [ TRAIN] - Epoch=2/5, Step=200/221 loss=0.3134 acc=0.8938 lr=0.000050 step/sec=5.40 | ETA 00:03:23
[2021-08-12 20:55:22,250] [ TRAIN] - Epoch=2/5, Step=210/221 loss=0.2998 acc=0.8781 lr=0.000050 step/sec=5.39 | ETA 00:03:23
[2021-08-12 20:55:24,147] [ TRAIN] - Epoch=2/5, Step=220/221 loss=0.3500 acc=0.8812 lr=0.000050 step/sec=5.27 | ETA 00:03:23
[2021-08-12 20:55:26,186] [ TRAIN] - Epoch=3/5, Step=10/221 loss=0.1696 acc=0.9469 lr=0.000050 step/sec=5.40 | ETA 00:03:23
[2021-08-12 20:55:28,082] [ TRAIN] - Epoch=3/5, Step=20/221 loss=0.1910 acc=0.9281 lr=0.000050 step/sec=5.28 | ETA 00:03:23
[2021-08-12 20:55:29,986] [ TRAIN] - Epoch=3/5, Step=30/221 loss=0.1739 acc=0.9375 lr=0.000050 step/sec=5.25 | ETA 00:03:24
[2021-08-12 20:55:31,879] [ TRAIN] - Epoch=3/5, Step=40/221 loss=0.1944 acc=0.9344 lr=0.000050 step/sec=5.28 | ETA 00:03:24
[2021-08-12 20:55:33,791] [ TRAIN] - Epoch=3/5, Step=50/221 loss=0.1955 acc=0.9375 lr=0.000050 step/sec=5.23 | ETA 00:03:24
[2021-08-12 20:55:35,684] [ TRAIN] - Epoch=3/5, Step=60/221 loss=0.1678 acc=0.9406 lr=0.000050 step/sec=5.28 | ETA 00:03:24
[2021-08-12 20:55:37,585] [ TRAIN] - Epoch=3/5, Step=70/221 loss=0.1499 acc=0.9469 lr=0.000050 step/sec=5.26 | ETA 00:03:24
[2021-08-12 20:55:39,494] [ TRAIN] - Epoch=3/5, Step=80/221 loss=0.1700 acc=0.9531 lr=0.000050 step/sec=5.24 | ETA 00:03:24
[2021-08-12 20:55:41,381] [ TRAIN] - Epoch=3/5, Step=90/221 loss=0.2193 acc=0.9094 lr=0.000050 step/sec=5.30 | ETA 00:03:24
[2021-08-12 20:55:43,247] [ TRAIN] - Epoch=3/5, Step=100/221 loss=0.2092 acc=0.9313 lr=0.000050 step/sec=5.36 | ETA 00:03:24
[2021-08-12 20:55:45,104] [ TRAIN] - Epoch=3/5, Step=110/221 loss=0.2010 acc=0.9437 lr=0.000050 step/sec=5.38 | ETA 00:03:24
[2021-08-12 20:55:46,966] [ TRAIN] - Epoch=3/5, Step=120/221 loss=0.2267 acc=0.9219 lr=0.000050 step/sec=5.37 | ETA 00:03:24
[2021-08-12 20:55:48,828] [ TRAIN] - Epoch=3/5, Step=130/221 loss=0.1778 acc=0.9437 lr=0.000050 step/sec=5.37 | ETA 00:03:24
[2021-08-12 20:55:50,679] [ TRAIN] - Epoch=3/5, Step=140/221 loss=0.2249 acc=0.9094 lr=0.000050 step/sec=5.40 | ETA 00:03:24
[2021-08-12 20:55:52,529] [ TRAIN] - Epoch=3/5, Step=150/221 loss=0.2156 acc=0.9375 lr=0.000050 step/sec=5.41 | ETA 00:03:24
[2021-08-12 20:55:54,385] [ TRAIN] - Epoch=3/5, Step=160/221 loss=0.1875 acc=0.9437 lr=0.000050 step/sec=5.39 | ETA 00:03:24
[2021-08-12 20:55:56,249] [ TRAIN] - Epoch=3/5, Step=170/221 loss=0.1529 acc=0.9500 lr=0.000050 step/sec=5.36 | ETA 00:03:24
[2021-08-12 20:55:58,160] [ TRAIN] - Epoch=3/5, Step=180/221 loss=0.1865 acc=0.9375 lr=0.000050 step/sec=5.23 | ETA 00:03:24
[2021-08-12 20:56:00,042] [ TRAIN] - Epoch=3/5, Step=190/221 loss=0.1899 acc=0.9375 lr=0.000050 step/sec=5.32 | ETA 00:03:24
[2021-08-12 20:56:01,936] [ TRAIN] - Epoch=3/5, Step=200/221 loss=0.2619 acc=0.9000 lr=0.000050 step/sec=5.28 | ETA 00:03:25
[2021-08-12 20:56:03,831] [ TRAIN] - Epoch=3/5, Step=210/221 loss=0.2324 acc=0.9156 lr=0.000050 step/sec=5.28 | ETA 00:03:25
[2021-08-12 20:56:05,705] [ TRAIN] - Epoch=3/5, Step=220/221 loss=0.2106 acc=0.9344 lr=0.000050 step/sec=5.34 | ETA 00:03:25
[2021-08-12 20:56:07,705] [ TRAIN] - Epoch=4/5, Step=10/221 loss=0.1316 acc=0.9688 lr=0.000050 step/sec=5.50 | ETA 00:03:25
[2021-08-12 20:56:09,609] [ TRAIN] - Epoch=4/5, Step=20/221 loss=0.1366 acc=0.9594 lr=0.000050 step/sec=5.25 | ETA 00:03:25
[2021-08-12 20:56:11,470] [ TRAIN] - Epoch=4/5, Step=30/221 loss=0.1062 acc=0.9656 lr=0.000050 step/sec=5.37 | ETA 00:03:25
[2021-08-12 20:56:13,328] [ TRAIN] - Epoch=4/5, Step=40/221 loss=0.1167 acc=0.9688 lr=0.000050 step/sec=5.38 | ETA 00:03:25
[2021-08-12 20:56:15,182] [ TRAIN] - Epoch=4/5, Step=50/221 loss=0.1110 acc=0.9656 lr=0.000050 step/sec=5.39 | ETA 00:03:25
[2021-08-12 20:56:17,040] [ TRAIN] - Epoch=4/5, Step=60/221 loss=0.1017 acc=0.9719 lr=0.000050 step/sec=5.38 | ETA 00:03:25
[2021-08-12 20:56:18,900] [ TRAIN] - Epoch=4/5, Step=70/221 loss=0.1088 acc=0.9688 lr=0.000050 step/sec=5.38 | ETA 00:03:25
[2021-08-12 20:56:20,758] [ TRAIN] - Epoch=4/5, Step=80/221 loss=0.1037 acc=0.9750 lr=0.000050 step/sec=5.38 | ETA 00:03:25
[2021-08-12 20:56:22,628] [ TRAIN] - Epoch=4/5, Step=90/221 loss=0.1002 acc=0.9625 lr=0.000050 step/sec=5.35 | ETA 00:03:25
[2021-08-12 20:56:24,504] [ TRAIN] - Epoch=4/5, Step=100/221 loss=0.1085 acc=0.9563 lr=0.000050 step/sec=5.33 | ETA 00:03:25
[2021-08-12 20:56:26,373] [ TRAIN] - Epoch=4/5, Step=110/221 loss=0.0752 acc=0.9750 lr=0.000050 step/sec=5.35 | ETA 00:03:25
[2021-08-12 20:56:28,257] [ TRAIN] - Epoch=4/5, Step=120/221 loss=0.0959 acc=0.9656 lr=0.000050 step/sec=5.31 | ETA 00:03:25
[2021-08-12 20:56:30,118] [ TRAIN] - Epoch=4/5, Step=130/221 loss=0.0969 acc=0.9688 lr=0.000050 step/sec=5.37 | ETA 00:03:25
[2021-08-12 20:56:31,995] [ TRAIN] - Epoch=4/5, Step=140/221 loss=0.1358 acc=0.9469 lr=0.000050 step/sec=5.33 | ETA 00:03:25
[2021-08-12 20:56:33,873] [ TRAIN] - Epoch=4/5, Step=150/221 loss=0.1128 acc=0.9719 lr=0.000050 step/sec=5.32 | ETA 00:03:25
[2021-08-12 20:56:35,738] [ TRAIN] - Epoch=4/5, Step=160/221 loss=0.1150 acc=0.9563 lr=0.000050 step/sec=5.36 | ETA 00:03:25
[2021-08-12 20:56:37,612] [ TRAIN] - Epoch=4/5, Step=170/221 loss=0.0940 acc=0.9625 lr=0.000050 step/sec=5.34 | ETA 00:03:25
[2021-08-12 20:56:39,489] [ TRAIN] - Epoch=4/5, Step=180/221 loss=0.0867 acc=0.9719 lr=0.000050 step/sec=5.33 | ETA 00:03:25
[2021-08-12 20:56:41,366] [ TRAIN] - Epoch=4/5, Step=190/221 loss=0.0969 acc=0.9625 lr=0.000050 step/sec=5.33 | ETA 00:03:25
[2021-08-12 20:56:43,248] [ TRAIN] - Epoch=4/5, Step=200/221 loss=0.1025 acc=0.9656 lr=0.000050 step/sec=5.31 | ETA 00:03:25
[2021-08-12 20:56:45,121] [ TRAIN] - Epoch=4/5, Step=210/221 loss=0.1372 acc=0.9500 lr=0.000050 step/sec=5.34 | ETA 00:03:25
[2021-08-12 20:56:46,987] [ TRAIN] - Epoch=4/5, Step=220/221 loss=0.0801 acc=0.9781 lr=0.000050 step/sec=5.36 | ETA 00:03:25
[2021-08-12 20:56:48,979] [ TRAIN] - Epoch=5/5, Step=10/221 loss=0.0604 acc=0.9875 lr=0.000050 step/sec=5.52 | ETA 00:03:25
[2021-08-12 20:56:50,857] [ TRAIN] - Epoch=5/5, Step=20/221 loss=0.0337 acc=0.9906 lr=0.000050 step/sec=5.33 | ETA 00:03:25
[2021-08-12 20:56:52,726] [ TRAIN] - Epoch=5/5, Step=30/221 loss=0.0336 acc=0.9938 lr=0.000050 step/sec=5.35 | ETA 00:03:25
[2021-08-12 20:56:54,603] [ TRAIN] - Epoch=5/5, Step=40/221 loss=0.0772 acc=0.9781 lr=0.000050 step/sec=5.33 | ETA 00:03:25
[2021-08-12 20:56:56,500] [ TRAIN] - Epoch=5/5, Step=50/221 loss=0.0518 acc=0.9844 lr=0.000050 step/sec=5.27 | ETA 00:03:25
[2021-08-12 20:56:58,379] [ TRAIN] - Epoch=5/5, Step=60/221 loss=0.0603 acc=0.9781 lr=0.000050 step/sec=5.32 | ETA 00:03:25
[2021-08-12 20:57:00,255] [ TRAIN] - Epoch=5/5, Step=70/221 loss=0.0578 acc=0.9875 lr=0.000050 step/sec=5.33 | ETA 00:03:25
[2021-08-12 20:57:02,120] [ TRAIN] - Epoch=5/5, Step=80/221 loss=0.0897 acc=0.9719 lr=0.000050 step/sec=5.36 | ETA 00:03:25
[2021-08-12 20:57:03,997] [ TRAIN] - Epoch=5/5, Step=90/221 loss=0.0341 acc=0.9906 lr=0.000050 step/sec=5.33 | ETA 00:03:25
[2021-08-12 20:57:05,879] [ TRAIN] - Epoch=5/5, Step=100/221 loss=0.0322 acc=0.9938 lr=0.000050 step/sec=5.31 | ETA 00:03:25
[2021-08-12 20:57:07,753] [ TRAIN] - Epoch=5/5, Step=110/221 loss=0.0264 acc=0.9969 lr=0.000050 step/sec=5.34 | ETA 00:03:25
[2021-08-12 20:57:09,628] [ TRAIN] - Epoch=5/5, Step=120/221 loss=0.0279 acc=0.9938 lr=0.000050 step/sec=5.33 | ETA 00:03:25
[2021-08-12 20:57:11,494] [ TRAIN] - Epoch=5/5, Step=130/221 loss=0.0326 acc=0.9906 lr=0.000050 step/sec=5.36 | ETA 00:03:25
[2021-08-12 20:57:13,355] [ TRAIN] - Epoch=5/5, Step=140/221 loss=0.0694 acc=0.9844 lr=0.000050 step/sec=5.37 | ETA 00:03:25
[2021-08-12 20:57:15,217] [ TRAIN] - Epoch=5/5, Step=150/221 loss=0.0618 acc=0.9812 lr=0.000050 step/sec=5.37 | ETA 00:03:25
[2021-08-12 20:57:17,089] [ TRAIN] - Epoch=5/5, Step=160/221 loss=0.0614 acc=0.9781 lr=0.000050 step/sec=5.34 | ETA 00:03:25
[2021-08-12 20:57:18,958] [ TRAIN] - Epoch=5/5, Step=170/221 loss=0.0854 acc=0.9656 lr=0.000050 step/sec=5.35 | ETA 00:03:25
[2021-08-12 20:57:20,829] [ TRAIN] - Epoch=5/5, Step=180/221 loss=0.0989 acc=0.9688 lr=0.000050 step/sec=5.34 | ETA 00:03:25
[2021-08-12 20:57:22,703] [ TRAIN] - Epoch=5/5, Step=190/221 loss=0.0864 acc=0.9656 lr=0.000050 step/sec=5.34 | ETA 00:03:25
[2021-08-12 20:57:24,571] [ TRAIN] - Epoch=5/5, Step=200/221 loss=0.0838 acc=0.9719 lr=0.000050 step/sec=5.35 | ETA 00:03:25
[2021-08-12 20:57:26,436] [ TRAIN] - Epoch=5/5, Step=210/221 loss=0.0357 acc=0.9906 lr=0.000050 step/sec=5.36 | ETA 00:03:25
[2021-08-12 20:57:28,293] [ TRAIN] - Epoch=5/5, Step=220/221 loss=0.0575 acc=0.9812 lr=0.000050 step/sec=5.39 | ETA 00:03:25
[2021-08-12 20:57:30,037] [   EVAL] - Evaluation on validation dataset: [Evaluation result] avg_acc=0.8229
[2021-08-12 20:57:32,928] [ EVAL] - Saving best model to ./checkpoint/best_model [best acc=0.8229]
[2021-08-12 20:57:32,932] [ INFO] - Saving model checkpoint to ./checkpoint/epoch_5
Prediction
model = hub.Module(
    name='ernie_tiny',
    version='2.0.1',
    task='seq-cls',
    load_checkpoint='./checkpoint/best_model/model.pdparams',
    label_map={i: str(i) for i in range(10)}
)
[2021-08-12 21:11:45,345] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-tiny/ernie_tiny.pdparams
[2021-08-12 21:11:55,772] [ INFO] - Loaded parameters from /home/aistudio/checkpoint/best_model/model.pdparams
data = [[i] for i in test_df['text'].values]
results = model.predict(data, max_seq_len=256, batch_size=1, use_gpu=False)
[2021-08-12 21:12:16,153] [ INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/vocab.txt
[2021-08-12 21:12:16,226] [ INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/spm_cased_simp_sampled.model
[2021-08-12 21:12:16,230] [ INFO] - Found /home/aistudio/.paddlenlp/models/ernie-tiny/dict.wordseg.pickle
sub = pd.DataFrame({'id':[i+1 for i in range(len(test_df))], 'label': results})
sub.head()
| | id | label |
|---|---|---|
| 0 | 1 | 3 |
| 1 | 2 | 8 |
| 2 | 3 | 2 |
| 3 | 4 | 0 |
| 4 | 5 | 3 |
sub.to_csv('ernie_tiny_baseline.csv', index=False)
Submission
Online score: 0.85174