CatBoost的Java端推理

技术

CatBoost模型的Java推理相比LightGBM会简单许多,无需转换成pmml格式,直接用官方的Java-package即可。

最主要的是,它直接支持字符串类型的类别特征,无需做各种编码转换,非常方便。

参考文档:https://catboost.ai/en/docs/concepts/java-package

picture.image

一,Java项目添加Maven依赖

注意:这里的 version 需与 Python 端训练模型所用的 catboost 版本保持一致。


        
          
<!-- https://mvnrepository.com/artifact/ai.catboost/catboost-prediction -->  
<dependency>  
    <groupId>ai.catboost</groupId>  
    <artifactId>catboost-prediction</artifactId>  
    <version>1.0.6</version>  
</dependency>
      

二,Python端训练CatBoost模型

此处以adult数据集的二分类问题为例。


        
          
  
from sklearn.datasets import fetch_openml  
from sklearn.model_selection import train_test_split   
import datetime   
  
import numpy as np   
import pandas as pd   
import plotly.express as px   
  
  
import catboost as cb   
def printlog(info):
    """Print a banner of 80 '=' followed by the current local time, then `info` + '...'.

    Purely a console progress marker; returns None.
    """
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    banner = "\n" + "=" * 80 + stamp
    print(banner)
    print(info + '...\n\n')
      
      
#================================================================================
# 1. Prepare data
#================================================================================
printlog("step1: preparing data...")


adult = fetch_openml(name = "adult",version=1,as_frame=True)

label_col = "target"
dfdata = adult["data"]
# Binary label: 1.0 when income is ">50K", else 0.0.
dfdata[label_col] = (adult["target"]==">50K").astype("float")

# Every pandas "category" column becomes a CatBoost categorical feature.
# Sorting fixes a deterministic column order (numeric first, then categorical).
cat_features = dfdata.drop(label_col,axis=1
                ).select_dtypes("category").columns.values.tolist()
cat_features.sort()

num_features = [col for col in dfdata.columns if col not in cat_features+[label_col]]
num_features.sort()

# Replace missing categories with an explicit "missed" token, then convert to str.
# - astype(str) alone would turn missing values into the literal string "nan".
# - Going through object dtype keeps them as NaN so fillna can see them.
# - The result must be assigned back; fillna does not modify the frame in place.
# Callers at inference time should pass "missed" for an unknown categorical value.
dfdata[cat_features] = dfdata[cat_features].astype(object).fillna("missed").astype(str)

dfdata = dfdata[num_features+cat_features+[label_col]]

dftrain_val,dftest = train_test_split(dfdata,test_size=0.3)
dftrain,dfval = train_test_split(dftrain_val,test_size=0.3)

# Wrap the splits in CatBoost Pools, declaring which columns are categorical.
pool_train = cb.Pool(data = dftrain.drop(label_col,axis=1),
                     label = dftrain[label_col], cat_features=cat_features)
pool_val = cb.Pool(data = dfval.drop(label_col,axis=1),
                     label = dfval[label_col], cat_features=cat_features)
pool_test = cb.Pool(data = dftest.drop(label_col,axis=1),
            label = dftest[label_col], cat_features=cat_features)
  
  
#================================================================================
# 2. Set parameters
#================================================================================
printlog("step2: setting parameters...")

iterations = 1000
early_stopping_rounds = 200

# Hyper-parameters forwarded to CatBoostClassifier; iterations and
# early stopping are passed separately below.
params = {
    "loss_function": "Logloss",
    "eval_metric": "AUC",
    "random_seed": 42,
    "logging_level": "Silent",
    "use_best_model": True,
    "nan_mode": "Min",
    # tree / categorical-feature settings
    "learning_rate": 0.05,
    "depth": 6,
    "min_data_in_leaf": 10,
    # categorical features with more distinct values than this use
    # ordered target statistics instead of one-hot encoding
    "one_hot_max_size": 5,
    # "Ordered" or "Plain"
    "boosting_type": "Ordered",
    # max number of features in one categorical combination; 1 disables combinations
    "max_ctr_complexity": 2,
}

model = cb.CatBoostClassifier(
    iterations=iterations,
    early_stopping_rounds=early_stopping_rounds,
    **params,
)
  
  
#================================================================================
# 3. Train model
#================================================================================
printlog("step3: training model...")

# Fit on the training pool.
# - pool_val is the evaluation set that early stopping and use_best_model refer to.
# - plot=True asks CatBoost for an interactive training plot.
model.fit(pool_train, eval_set=pool_val, plot=True)
  
  
  
#================================================================================
# 4. Evaluate model
#================================================================================
printlog("step4: evaluating model ...")


# Feature importance: keep the 20 most important features, in ascending order.
dfimportance = model.get_feature_importance(prettified=True)
dfimportance = dfimportance.sort_values(by = "Importances").iloc[-20:]
fig_importance = px.bar(dfimportance,x="Importances",y="Feature Id",title="Feature Importance")

# `display` exists only inside IPython/Jupyter; print()/fig.show() also work in a plain script.
print(dfimportance)
fig_importance.show()


# Score distribution: positive-class probability on the test split, coloured by the true label.
y_test_prob = model.predict_proba(dftest.drop(label_col,axis = 1))[:,-1]
fig_hist = px.histogram(
    x=y_test_prob,color =dftest[label_col],  nbins=50,
    title = "Score Distribution",
    labels=dict(color='True Labels', x='Score')
)
fig_hist.show()
  
  
  
#================================================================================
# 5. Use model
#================================================================================
printlog("step5: using model ...")

# Predict on the same feature columns the model was trained on (label column removed).
X_test = dftest.drop(label_col,axis=1)
y_pred_test = model.predict(X_test)
y_pred_test_prob = model.predict_proba(X_test)

print("y_pred_test:\n",y_pred_test[:10])
print("y_pred_test_prob:\n",y_pred_test_prob[:10])
  
  
#================================================================================
# 6. Save model
#================================================================================
printlog("step6: saving model ...")

# Plain file name: this exact name is what the Java side loads from its classpath.
model_dir = 'adult_model.cbm'
model.save_model(model_dir)
# Reload into a fresh classifier to confirm the saved file can be read back.
model_loaded = cb.CatBoostClassifier()
model_loaded.load_model(model_dir)
  

      

将得到的 adult_model.cbm 放入 Java 项目的 resources 目录(Maven 项目通常为 src/main/resources)下,使其位于类路径中,以便下面的代码通过 ClassLoader 加载。

三,Java端推理预测封装

推理代码封装如下


        
          
package com.example.model;  
  
import ai.catboost.CatBoostModel;  
import ai.catboost.CatBoostPredictions;  
import org.slf4j.Logger;  
import org.slf4j.LoggerFactory;  
  
import java.io.InputStream;  
import java.io.PrintWriter;  
import java.io.StringWriter;  
import java.util.Map;  
  
public class CatBoostClassifier {  
  
    public CatBoostModel model;  
  
    private static Logger LOG= LoggerFactory.getLogger(CatBoostClassifier.class);  
  
    public CatBoostClassifier(String model\_path) {  
        try{  
            InputStream model_file = Thread.currentThread()  
                    .getContextClassLoader()  
                    .getResourceAsStream(model_path);  
            this.model = CatBoostModel.loadModel(model_file);  
  
        }catch (Exception err){  
            LOG.error("catboost-init-error", err);  
            System.out.println(err.toString());  
        };  
  
    }  
  
    public Double predict(Map<String,Double> num\_features,Map<String,String> cat\_features){  
        try {  
            String[] feature_names = this.model.getFeatureNames();  
            assert (num_features.size() + cat_features.size()) == feature_names.length;  
  
            float[] num_arr = new float[num_features.size()];  
            String[] cat_arr = new String[cat_features.size()];  
  
            int i = 0;  
            int j = 0;  
            for (String name : feature_names) {  
                if (num_features.keySet().contains(name)) {  
                    num_arr[i] = num_features.get(name).floatValue();  
                    i += 1;  
                } else {  
                    assert cat_features.keySet().contains(name);  
                    cat_arr[j] = cat_features.get(name);  
                    j += 1;  
                }  
            }  
            CatBoostPredictions prediction = this.model.predict(  
                    num_arr,  
                    cat_arr);  
            Double prob = 1.0 - 1.0 / (1.0 + Math.exp(prediction.get(0, 0)));  
            return prob;  
        }catch (Exception err){  
            LOG.error("catboost-predict-error", err);  
            return null;  
        }  
    }  
}  
  

      

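四,预测调用测试代码

注意:传入的特征名和类别取值必须与训练数据一致。OpenML 上 version=1 的 adult 数据集中,age、capitalgain、capitalloss、hoursperweek 等列已被离散化为类别特征,因此下面把它们以字符串形式放入 cat_features。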


        
          
import java.util.*;  
import com.example.model.CatBoostClassifier;  
  
public class CatBoostTest {  
    public void testCatBoostClassifier(){  
        CatBoostClassifier clf = new CatBoostClassifier("adult\_model.cbm");  
          
        Map<String,Double> num_features = new HashMap<String,Double>();  
        Map<String,String> cat_features = new HashMap<String,String>();  
          
        num_features.put("fnlwgt",236379.0);  
        num_features.put("education-num",11.0);  
          
          
        cat_features.put("workclass","Private");  
        cat_features.put("sex","Male");  
        cat_features.put("relationship","Husband");  
        cat_features.put("race","White");  
        cat_features.put("occupation","Craft-repair");  
        cat_features.put("native-country","United-States");  
        cat_features.put("marital-status","Married-civ-spouse");  
        cat_features.put("hoursperweek","2");  
        cat_features.put("education","Assoc-voc");  
        cat_features.put("capitalloss","0");  
        cat_features.put("capitalgain","0");  
        cat_features.put("age","1");  
  
        Double model_prob = clf.predict(num_features,cat_features);  
        System.out.println(model_prob);  
        System.out.println("sucessed!");  
    }  
  
}  
  

      

picture.image

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
VikingDB:大规模云原生向量数据库的前沿实践与应用
本次演讲将重点介绍 VikingDB 解决各类应用中极限性能、规模、精度问题上的探索实践,并通过落地的案例向听众介绍如何在多模态信息检索、RAG 与知识库等领域进行合理的技术选型和规划。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论