一份半监督学习的指南-伪标签学习

大模型数据中台

picture.image

1 引言

picture.image

在ML中,有3种机器学习方法-监督学习、无监督学习和强化学习技术。 我们所知道的监督学习是指数据带有标签的情况, 无监督学习是仅存在数据而没有标签的情况,强化学习算法的思路非常简单,以游戏为例,如果在游戏中采取某种策略可以取得较高的得分,那么就进一步“强化”这种策略,以期继续取得较好的结果。

想象一下这样一种情况,在训练中,标记数据的数量更少,而未标记数据的数量更多。 一种称为半监督学习( [Semi-Supervised Learning],SSL)的新技术,它是监督学习和非监督学习的混合体。 顾名思义,半监督学习中同时存在一组标记的训练数据和另一组未标记的训练数据。 我们可以将这种情况想像成Google图片或Facebook通过其面孔(数据)识别出图片中的人物并根据该人物先前存储的图像生成建议名称(标签)的情况。

picture.image

在本文中,我们将讨论如何使用半监督学习技术生成伪标签。

picture.image

2 Pseudo-Labelling 伪标签

picture.image

伪标签是使用标记的数据模型预测未标记数据并进行标记的过程。 首先,模型已经训练了包含标签的数据集,该模型用于为未标记的数据集生成伪标签。 最后,将数据集和标签(原始标签和伪标签)组合在一起以进行最终模型训练。 之所以称为伪(意味着虚幻),是因为它们可能是真实标签,也可能不是真实标签,并且是通过我们基于类似的数据模型生成的标签。

picture.image

该方法的主旨思想其实很简单。首先,在标签数据上训练模型,然后使用经过训练的模型来预测无标签数据的标签,从而创建伪标签。此外,将标签数据和新生成的伪标签数据结合起来作为新的训练数据。

picture.image

3 Python 实现

picture.image

在这个例子中,我们使用了sklearn中的breast cancer数据集。我们知道整个已经包含了标签,但我们要修改它,将数据分成两部分,一部分有标签,另一部分没有标签。我们将从经过训练的带标签数据模型中为未带标签的数据生成我们自己的标签,然后最后使用两者合并的数据集来训练最终的模型。

3.1 数据集

Breast cancer dataset是预测肿瘤是良性(B)还是恶性(M)的分类问题。前两列为1)id和2)diagnosis(标签):

picture.image


                
a)radius_mean(从中心到外围点的距离的平均值)  
b)texture_mean(灰度值的标准偏差)  
c)perimeter\_mean(周长)  
d)area\_mean(面积)  
e)smoothness_mean(半径长度的局部变化)  
f)compactness_mean(周长^ 2 /面积– 1.0)  
g)concavity_mean(轮廓凹部的严重程度)  
h) concave points_mean(轮廓的凹面部分的数量)  

            

3.2 导入包


                
import pandas as pd  
import numpy as np  
from sklearn.model_selection import train_test_split  
from sklearn.datasets import load_breast_cancer  
from sklearn.ensemble import RandomForestClassifier  

            

3.3 加载数据集


                
X,y = load\_breast\_cancer(True)  
X.shape  

            

                
(569, 30)  

            

3.4 分割数据集


                
x_train,x_test,y_train,_ = train_test_split(X,y,test_size=.6)  
x_train.shape,y_train.shape,x_test.shape  

            

                
((227, 30), (227,), (342, 30)  

            

3.5 训练模型


                
model1 = RandomForestClassifier()  
history = model1.fit(x_train,y_train)  
history  

            

                
RandomForestRegressor(bootstrap=True, ccp_alpha=0.0, criterion=’mse’,  
max_depth=None, max_features=’auto’, max_leaf_nodes=None,  
max_samples=None, min_impurity_decrease=0.0,  
min_impurity_split=None, min_samples_leaf=1,  
min_samples_split=2, min_weight_fraction_leaf=0.0,  
n_estimators=100, n_jobs=None, oob_score=False,  
random_state=None, verbose=0, warm_start=False)  

            

3.6 评分


                
model1.score(x_train,y_train)  

            

                
1.0  

            

3.7 预测


                
y_new = model1.predict(x_test)  
y_new.shape  

            

                
(342,)  

            

合并数据集


                
final_X = np.concatenate((x_train,x_test))  
final_X.shape  

            

                
(569, 30)  

            

合并原始标签与伪标签


                
final_Y = np.concatenate((y_train,y_test))  
final_Y.shape  

            

                
(569,)  

            

基于合并的数据集训练最终模型


                
model2 = RandomForestRegressor()  
model2.fit(final_X,final_Y)  
model2.score(final_X,final_Y)  

            

                
1.0  

            

picture.image

4 结论

picture.image

伪标签的实现到此为止,大家可以根据自己的想法去比赛中尝试吧。

picture.image

picture.image

picture.image

picture.image

扫码关注 ChallengeHub

picture.image

仙女都在看 点点点,赞和在看都在这儿!

picture.image

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
在火山引擎云搜索服务上构建混合搜索的设计与实现
本次演讲将重点介绍字节跳动在混合搜索领域的探索,并探讨如何在多模态数据场景下进行海量数据搜索。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论