引言
你知道模型的弱点在哪儿吗?你知道该如何改进模型的弱点吗?
本文介绍ACL 2023论文:Targeted Data Generation:Finding and Fixing Model Weaknesses[1]。这篇论文旨在 自动挖掘模型weakness 、 并生成可解决模型weakness的数据 。
笔者认为,这篇论文有两点意义:
- 提供了一种半自动化的数据迭代框架 ,由人类参与、并且该过程人类可以理解,这为NLP模型迭代提供了一种新思路;
- 展现了机器生成样本 vs 自然真实样本 的优势,这优势在于机器生成样本能更有的放矢,因此更高效 。
TDG框架
论文提出的TDG框架,主要有以下3个阶段:
- 阶段1 :Automatic Subgroup Discovery,自动从验证集Val中挖掘subgroups;
- 阶段2 :Select Challenging Subgroup,找到值得进行数据增强的challenging subgroup(指模型无法很好解决的subgroup,也就是weakness);
- 阶段3 :Subgroup Augmentation with LLM + Human-in-the-Loop,利用LLM生成in-subgroup数据,并由人工参与标注,在基于local model和global model的disagreement的迭代式框架下(后文将做详解),最终确定可以解决模型weakness的数据。
换句话说,你提供1个原始模型 + 1个验证集 + 可进行标注的人工,执行TDG方法之后你将获得:
- 知道模型的weakness在哪,即知道模型在哪类数据上效果明显更差 ;
- 得到能够解决模型weakness的数据 。
这就是论文题目 Finding and Fixing Model Weakness
的意思。
阶段1:Automatic Subgroup Discovery
这一阶段主要的思路是: 聚类(clustering) ,聚类的对象为验证集Val。
作者提出3类clustering策略:
- Agnostic clustering 。使用与任务完全解耦的embedding model(例如sentence-BERT)来聚类;
- Task-based clustering 。使用模型自身的embedding来进行聚类;
- Task-based + label information 。在上一种方式的基础上,再加一个限制:一个cluster内的数据的label必须相同。
直观上理解,第二、第三种方法会更好,因为 模型的弱点模型自己知道的最清楚 。
作者给出了3种策略的示意图。从图上看,第一种Agnostic clustering方法混淆了positive point和negative point,而后两种方法可以区分, 这也暗示了后两种方法可能更好,下一章的实验也的确证明了这一点 。
阶段2:Select Challenging Subgroups
该如何选择clustering策略?该选择哪些subgroup来进行数据增强,从而既能提升subgroup效果而又不损害整体效果?
为了筛选,作者引入了两个概念: GC(Generalization in Context)
和 IC(Interference in Context)
。
其中Cval指验证集Val中的某个cluster的数据;Dval指验证集Val的所有数据;M指原模型,M’指加入额外cluster数据训练的模型。
GC
衡量
一个cluster是否可以从更多数据中受益,
IC
衡量加入更多cluster data是否对整体效果有损。
因此, GC - IC
的值最大的clustering策略,是最好的策略。
其中 Ck
表示 Top-k subgroups
(按error rate降序)。
作者在SST和MNLI数据集上,对3类策略进行了对比实验,结果如下图。
从结果来看, Task-based clustering的效果更好 ;而Task-based + label information方法,虽然GC很高,但是IC也很高,作者指出这种方式带来的增强数据是 label-imbalance 的,因此会影响整体效果。
补充:
从公式来看,本节的方法不仅可以挑选clustering策略,还可以确定k的数值,但原论文未对这块内容进行描述;在后文实验中,SST选择了top2的subgroup进行实验,MNLI数据集选择了top10的subgroup,但作者并未说明为何如此挑选
在QQP数据中,作者发现无论什么方法,
IC
都很高,因此不适合使用本论文的框架,作者进一步分析发现,QQP具有high label noise
,因此影响了效果;这也说明,本论文的方法只适用于标注准确率高的数据集
阶段3:Subgroup Augmentation with LLM + Human-in-the-Loop
经过了阶段1和阶段2,找到了值得进行数据增强的subgroup, 阶段3的目的是针对这些subgroup进行数据增强 。
乍一看,直接把subgroup的数据给LLM,通过写prompt,让LLM生成类似的数据,不就可以了么?
但问题没这么简单:
- 该让LLM生成多少数据? 少了,不足以解决问题;多了,会影响其他类型数据的效果;
- 直接把LLM生成的数据拿来用就够了么? 一方面有标注准确性问题,一方面LLM生成的easy sample可能原模型已经能解决,没有必要再加入,加入了反而可能引起
shortcuts
问题。
作者使用的方法源自CoDev这篇论文[2]。整体上这个框架是迭代式的, 利用local model和global model的disagreement来驱动整个流程 ,并且由人工来进行最后的标注。
其步骤如下:
- 步骤a :提供
concept
(CoDev论文中的概念,在本文中concept = subgroup
) - 步骤b :训练local model(指仅基于这一个subgroup训练的model);
- 步骤c :基于subgroup内的数据,GPT-3生成一批
in-cluster data
,并且将global model与local model预测结果不一致的,送往人工标注(global model指在所有subgroup上训练的model); - 步骤d :人工标注,并且放弃标注不属于这一cluster的生成数据;
- 步骤e :基于标注后的数据,更新local model和global model;
- 循环步骤c-步骤e,直到收敛,收敛条件是 local model和global model不再有disagreement 。
其伪代码如下:
直观上理解, local model类似于一个持续进化的专家,负责指导LLM生成能够帮助global model真正学习到subgroup信息的数据 。
与从unlabel data中采样(即active learning)相比,这种指导机器生成数据的方法, 更有针对性,因此理论上更加高效 。
更多细节请读者参阅CoDev论文[2]。
实验效果
在SST、MNLI上的效果如下:
其中,两个Baseline为Reweighting(一种训练策略,用以优化subgroup效果)和Paraphrasing(用T5-based paraphrase model来生成相似样本);TDG(single)指仅对一个subgroup进行修复,TDG(all)指对挑选的所有subgroup同时进行修复。
根据实验结果,可以发现:
- 在修复单个challenging subgroup时,TDG(single)的效果通常比TDG(all) 更好 ;
- TDG(all)不仅可以提升多个challenging subgroups的效果(即完成 修复 工作),同时还可以 小幅提升 devtest的效果,而反观Baseline,两头都做不好。
本节省略了一些细节(对baseline的详细介绍、消融实验结果等),有兴趣的读者请查阅原论文。
结论
- 该论文提出了TDG框架,能够自动挖掘model weakness、并生成修复weakness的数据,你所需要提供的是 1个原始模型 + 1个验证集 + 可以标注的人工
- 经过实验,最好的clustering策略是task-based clustering,即使用任务模型自身的表征来进行聚类
- TDG不仅可以针对性地修复challenging subgroup问题,对整体数据的效果无害,甚至能有小幅提升
局限性
论文作者在最后提到了以下局限:
- TDG方法并未完全考虑subgroup之间的关系(体现在阶段2中,直接选择Top-K subgroups进行修复)
- 受限于计算资源,本文用于实验的subgroup比较少,SST选取了Top2,MNLI选择了Top10
同时,作者强调TDG应该作为提升低表现group的 最后一步 ,如果你发现有 大量 的低表现group,说明模型是 under-trained
,此时不应该选择TDG,而是需要进行 better data/modeling
。
一些思考
论文主体部分讲解结束,以下是笔者的一些思考:
- TDG这种半自动化的数据迭代方法, 如何与基于产品反馈的数据迭代方法进行有效结合 ?
- TDG的各个子模块可以独立应用,比如阶段1的方法可以用于从unlabel data中 挖掘 易错样本送往人工标注、也可以在推理时 识别 易错样本并进行单独处理,而阶段2的方法则可以用于 评估 数据增强方法的有效性
- 本论文的实验均为分类任务,但TDG的这一框架理论上也可以支持 信息抽取和文本生成 任务
关注笔者
参考资料
[1] Targeted Data Generation: Finding and Fixing Model Weaknesses: https://aclanthology.org/2023.acl-long.474/
[2] Collaborative Development of NLP models: https://arxiv.org/abs/2305.12219