darts 时序预测入门

火山方舟向量数据库大模型

darts是一个强大而易用的Python时间序列建模工具包。在github上目前拥有超过7k颗stars。

它主要支持以下任务:

  • 时间序列预测 (包含 ARIMA, LightGBM模型, TCN, N-BEATS, TFT, DLinear, TiDE等等)

  • 时序异常检测 (包括 分位数检测 等等)

  • 时间序列滤波 (包括 卡尔曼滤波,高斯过程滤波)

本文演示使用darts构建N-BEATS模型对 牛奶月销量数据进行预测~

公众号算法美食屋后台回复关键词: 源码 ,获取本文notebook源码和数据集~


      
  `!pip install darts`
 


    

一, 准备数据

首先,你需要准备时间序列数据。如果数据有缺失,需要进行数据填充。

这里示范的是一个每月牛奶销量数据集。


      
 

 
  `import numpy as np` 
  `import pandas as pd` 
  `from matplotlib import pyplot as plt`
 
  `import darts`
  `from darts import TimeSeries`
  `from darts.dataprocessing.transformers import MissingValuesFiller`
  `from darts.dataprocessing.transformers import Scaler` 
 
 
  `# 1,读取数据集`
  `df = pd.read_csv('month_milk.csv')`
  `df.columns = ['ds','y']`
  `df = df.sort_values(by='ds')`
  `#df['y'] = df['y'].interpolate(method='linear')`
 
 
  `# 2,填充数据集`
  `ts_raw = TimeSeries.from_dataframe(df,time_col='ds',value_cols=['y'])`
  `#df = ts_raw.pd_dataframe() #timeseries转成dataframe`
  `fig, ax1 = plt.subplots(1, 1, figsize=(10,5))`
  `ts_raw.plot(color='brown',ax = ax1)`
  `ax1.set_title('before fill')`
 
  `#from darts.utils.missing_values import fill_missing_values`
  `#ts = fill_missing_values(ts_raw)`
 
  `filler = MissingValuesFiller()`
  `ts = filler.transform(ts_raw)`
 
  `fig, ax2 = plt.subplots(1, 1, figsize=(10,5))`
  `ts.plot(color='cyan', ax = ax2)`
  `ax2.set_title('after fill')`
 
  `# 3,分割数据集`
  `ts_train, ts_test= ts.split_after(pd.Timestamp("1973-01-01"))` 
  `fig, ax3 = plt.subplots(1,1,figsize=(10,5))`
  `ts_train.plot(color='blue',label='train',ax=ax3)`
  `ts_test.plot(color='red',label='test',ax=ax3)`
  `ax3.set_title('after split');`
 
 
  `# 4,缩放数据集`
  `scaler = Scaler()`
  `ts_train_scaled = scaler.fit_transform(ts_train)`
  `ts_test_scaled = scaler.transform(ts_test)`
  `ts_scaled = scaler.transform(ts)`
 
 
 
    

picture.image

picture.image

picture.image

二, 定义模型

接下来,定义一个时间序列模型。Darts支持多种模型。

包括统计模型(ARIMA, FFT, ExponentialSmoothing等)

机器学习模型(Prophet,LightGBM等)

神经网络模型(RNN类型,Transformer类型,CNN类型,MLP类型)。

此处我们使用 NBeats模型。


      
 

 
  `import warnings`
  `warnings.filterwarnings("ignore")`
  `import logging`
  `logging.disable(logging.CRITICAL)`
  `import wandb`
  `wandb.login()`
 
 
    

      
 

 
  `from darts.models import NBEATSModel` 
  `import pytorch_lightning as pl` 
  `from darts.utils.callbacks import TFMProgressBar`
  `from pytorch_lightning.loggers import WandbLogger`
 
  `wandb_logger = WandbLogger(project='MILK')`
  `progress_bar = TFMProgressBar(enable_train_bar_only=True)`
  `early_stopper = pl.callbacks.EarlyStopping(monitor="val_loss",patience=10,`
  `min_delta=1e-5,mode="min")`
  `pl_trainer_kwargs = dict(max_epochs=100,` 
  `accelerator='cpu',`
  `callbacks = [progress_bar,early_stopper],`
  `logger = wandb_logger`
  `)` 
 
  `settings = dict(`
  `input_chunk_length=10,`
  `output_chunk_length=3,`
  `generic_architecture=True,`
  `num_stacks=10,`
  `num_blocks=1,`
  `num_layers=4,`
  `layer_widths=512,`
  `save_checkpoints=True,`
 
  `batch_size=800,`
  `random_state=42,`
  `model_name='nbeats',`
  `force_reset=True`
  `)`
 
  `settings['pl_trainer_kwargs'] = pl_trainer_kwargs`
 
  `model = NBEATSModel(**settings)`
 
 
    

三,训练模型


          
model.fit(series = ts_train_scaled, val_series= ts_train_scaled)
          
model = model.load_from_checkpoint(model_name=model.model_name,best=True)
      

四,评估模型


      
 

 
  `# 历史数据逐段回测,使用真实历史数据作为特征,不做滚动预测`
  `ts_test_forecast = model.historical_forecasts(`
  `series=ts_scaled,`
  `start=ts_test_scaled.start_time(),`
  `forecast_horizon=3,`
  `stride=3,`
  `last_points_only=False,`
  `retrain=False,`
  `verbose=True,`
  `)`
 
 
    

      
 

 
  `from darts import concatenate`
  `ts_test_forecast = scaler.inverse_transform(concatenate(ts_test_forecast))`
  `ts_test_true = ts[ts_test_forecast.time_index]`
  `test_score = r2_score(ts_test_true, ts_test_preds)`
 
 
    

      
 

 
  `import matplotlib.pyplot as plt` 
  `plt.figure(figsize=(8, 5))`
  `ts_test_true.plot(label="y",color='blue')`
  `ts_test_forecast.plot(label='yhat-forecast',color='cyan')`
  `plt.title(`
  `"yhat-forecast R2-score: {}".format(test_score)`
  `)`
  `plt.legend()`
 
 
    

picture.image

五,使用模型


          
#滚动预测,使用预测的数据作为后面预测步骤的特征 
          
# (注意:当预测步数 n 小于等于模型的output_chunk_length,无需滚动)
          
ts_preds = model.predict(n = len(ts_test_scaled),series = ts_train_scaled, num_samples=1)
          
ts_test_preds  = scaler.inverse_transform(ts_preds)
      

      
 

 
  `from darts.metrics import r2_score` 
  `test_score = r2_score(ts_test_true, ts_test_preds)`
 
 
    

      
 

 
  `import matplotlib.pyplot as plt` 
  `plt.figure(figsize=(8, 5))`
  `ts_test_true.plot(label="y",color='blue')`
  `ts_test_forecast.plot(label="yhat-forecast",color='cyan')`
  `ts_test_preds.plot(label='yhat',color='red')`
  `plt.title(`
  `"yhat R2-score: {}".format(test_score)`
  `)`
  `plt.legend()`
 
 
    

picture.image

六,保存模型


      
 

 
  `model.save('nbeats')`
  `model_loaded = NBEATSModel.load('nbeats') #重新加载`
 
 
    

picture.image

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
CloudWeGo白皮书:字节跳动云原生微服务架构原理与开源实践
本书总结了字节跳动自2018年以来的微服务架构演进之路
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论