CatBoost模型的Java推理相比LightGBM会简单许多,无需转换成pmml格式,直接用官方的Java-package即可。
最主要的是,它直接支持字符串类型的类别特征,无需做各种编码转换,简直不要太6。
参考文档:https://catboost.ai/en/docs/concepts/java-package
一,Java项目添加Maven依赖
注意:依赖中的 version 需要与 Python 端训练模型时使用的 catboost 版本保持一致。
<!-- https://mvnrepository.com/artifact/ai.catboost/catboost-prediction -->
<dependency>
<groupId>ai.catboost</groupId>
<artifactId>catboost-prediction</artifactId>
<version>1.0.6</version>
</dependency>
二,Python端训练CatBoost模型
此处以adult数据集的二分类问题为例。
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import datetime
import numpy as np
import pandas as pd
import plotly.express as px
import catboost as cb
def printlog(info):
    """Print a timestamped separator line followed by a progress message.

    Args:
        info: Message text. "..." and two newlines are appended when printed.
    """
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # Separator: a blank line, 80 '=' characters, then the current time.
    print("\n" + "==========" * 8 + "%s" % nowtime)
    print(info + '...\n\n')
#================================================================================
# Step 1: prepare data
#================================================================================
printlog("step1: preparing data...")

adult = fetch_openml(name = "adult",version=1,as_frame=True)
label_col = "target"

# Copy the feature frame so that adding the label column does not write into
# the object returned by fetch_openml.
dfdata = adult["data"].copy()
dfdata[label_col] = (adult["target"]==">50K").astype("float")

# Categorical columns are the ones with pandas "category" dtype; everything
# else (except the label) is numeric. Both lists are sorted by name.
cat_features = dfdata.drop(label_col,axis=1
                          ).select_dtypes("category").columns.values.tolist()
cat_features.sort()

num_features = [col for col in dfdata.columns if col not in cat_features+[label_col]]
num_features.sort()

# Fill missing categories BEFORE casting to str and assign the result back.
# Casting first would turn NaN into an ordinary string, after which fillna has
# nothing to fill; going through object dtype also avoids the error raised when
# filling a categorical column with a value outside its categories.
dfdata[cat_features] = dfdata[cat_features].astype(object).fillna("missed").astype(str)

# Column order: numeric features, then categorical features, then the label.
dfdata = dfdata[num_features+cat_features+[label_col]]

dftrain_val,dftest = train_test_split(dfdata,test_size=0.3)
dftrain,dfval = train_test_split(dftrain_val,test_size=0.3)

# Wrap each split in a catboost Pool
pool_train = cb.Pool(data = dftrain.drop(label_col,axis=1),
                     label = dftrain[label_col], cat_features=cat_features)
pool_val = cb.Pool(data = dfval.drop(label_col,axis=1),
                   label = dfval[label_col], cat_features=cat_features)
pool_test = cb.Pool(data = dftest.drop(label_col,axis=1),
                    label = dftest[label_col], cat_features=cat_features)
#================================================================================
# Step 2: set parameters
#================================================================================
printlog("step2: setting parameters...")

iterations = 1000
early_stopping_rounds = 200

params = {
    "loss_function": "Logloss",
    "eval_metric": "AUC",
    "random_seed": 42,
    "logging_level": 'Silent',
    "use_best_model": True,
    "nan_mode": 'Min',
    # --- tree / boosting settings ---
    "learning_rate": 0.05,
    "depth": 6,
    "min_data_in_leaf": 10,
    # Features with more categories than this use ordered target statistics
    "one_hot_max_size": 5,
    # "Ordered" or "Plain"
    "boosting_type": "Ordered",
    # Maximum number of features in a combination; 1 disables combinations
    "max_ctr_complexity": 2,
}

model = cb.CatBoostClassifier(
    iterations=iterations,
    early_stopping_rounds=early_stopping_rounds,
    **params,
)
#================================================================================
# Step 3: train the model
#================================================================================
printlog("step3: training model...")

# Fit on the training pool, evaluating on the validation pool.
model.fit(pool_train, eval_set=pool_val, plot=True)
#================================================================================
# Step 4: evaluate the model
#================================================================================
printlog("step4: evaluating model ...")

# Feature importance: sort ascending and keep the last 20 rows.
dfimportance = (
    model.get_feature_importance(prettified=True)
    .sort_values(by="Importances")
    .iloc[-20:]
)
fig_importance = px.bar(
    dfimportance,
    x="Importances",
    y="Feature Id",
    title="Feature Importance",
)
display(dfimportance)
display(fig_importance)

# Score distribution: last predict_proba column on the test features,
# coloured by the true label.
y_test_prob = model.predict_proba(dftest.drop(label_col, axis=1))[:, -1]
fig_hist = px.histogram(
    x=y_test_prob,
    color=dftest[label_col],
    nbins=50,
    title="Score Distribution",
    labels=dict(color='True Labels', x='Score'),
)
fig_hist.show()
#================================================================================
# Step 5: use the model
#================================================================================
printlog("step5: using model ...")

# Predict on features only: dftest still contains the label column, which must
# not be passed to the model as an input.
X_test = dftest.drop(label_col, axis=1)
y_pred_test = model.predict(X_test)
y_pred_test_prob = model.predict_proba(X_test)

print("y_pred_test:\n", y_pred_test[:10])
print("y_pred_test_prob:\n", y_pred_test_prob[:10])
#================================================================================
# Step 6: save the model
#================================================================================
printlog("step6: saving model ...")

# File name must match the resource name loaded on the Java side.
model_dir = 'adult_model.cbm'
model.save_model(model_dir)

# Reload into a fresh classifier to check that the saved file is usable.
model_loaded = cb.CatBoostClassifier()
model_loaded.load_model(model_dir)
将得到的 adult_model.cbm 放到 Java 项目的 resources 目录(src/main/resources)下,以便通过 classpath 加载。
三,Java端推理预测封装
推理代码封装如下
package com.example.model;
import ai.catboost.CatBoostModel;
import ai.catboost.CatBoostPredictions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Map;
public class CatBoostClassifier {
public CatBoostModel model;
private static Logger LOG= LoggerFactory.getLogger(CatBoostClassifier.class);
public CatBoostClassifier(String model\_path) {
try{
InputStream model_file = Thread.currentThread()
.getContextClassLoader()
.getResourceAsStream(model_path);
this.model = CatBoostModel.loadModel(model_file);
}catch (Exception err){
LOG.error("catboost-init-error", err);
System.out.println(err.toString());
};
}
public Double predict(Map<String,Double> num\_features,Map<String,String> cat\_features){
try {
String[] feature_names = this.model.getFeatureNames();
assert (num_features.size() + cat_features.size()) == feature_names.length;
float[] num_arr = new float[num_features.size()];
String[] cat_arr = new String[cat_features.size()];
int i = 0;
int j = 0;
for (String name : feature_names) {
if (num_features.keySet().contains(name)) {
num_arr[i] = num_features.get(name).floatValue();
i += 1;
} else {
assert cat_features.keySet().contains(name);
cat_arr[j] = cat_features.get(name);
j += 1;
}
}
CatBoostPredictions prediction = this.model.predict(
num_arr,
cat_arr);
Double prob = 1.0 - 1.0 / (1.0 + Math.exp(prediction.get(0, 0)));
return prob;
}catch (Exception err){
LOG.error("catboost-predict-error", err);
return null;
}
}
}
四,预测调用测试代码
import java.util.*;
import com.example.model.CatBoostClassifier;
public class CatBoostTest {
public void testCatBoostClassifier(){
CatBoostClassifier clf = new CatBoostClassifier("adult\_model.cbm");
Map<String,Double> num_features = new HashMap<String,Double>();
Map<String,String> cat_features = new HashMap<String,String>();
num_features.put("fnlwgt",236379.0);
num_features.put("education-num",11.0);
cat_features.put("workclass","Private");
cat_features.put("sex","Male");
cat_features.put("relationship","Husband");
cat_features.put("race","White");
cat_features.put("occupation","Craft-repair");
cat_features.put("native-country","United-States");
cat_features.put("marital-status","Married-civ-spouse");
cat_features.put("hoursperweek","2");
cat_features.put("education","Assoc-voc");
cat_features.put("capitalloss","0");
cat_features.put("capitalgain","0");
cat_features.put("age","1");
Double model_prob = clf.predict(num_features,cat_features);
System.out.println(model_prob);
System.out.println("sucessed!");
}
}