Prophetで時系列データ予測

2021.06.27
2024.03.24
機械学習
ProphetPythonseaborn時系列データ

はじめに

時系列データを簡単に予測できるProphetについて、基本的な使い方と実際に使ってみた具体例を紹介したいと思います。

Prophetとは

Prophetとは、Facebookが開発した時系列データの予測ができるライブラリです。

外れ値や欠損値があっても簡単に精度のいいモデルを作成できます。また、パラメータを調整することでよりいいモデルの作成も可能になります。

どのように予測しているか

どのように予測しているのか簡単に説明します。

トレンドや年単位や週単位などの周期性、休日の影響に基づいて時系列データを予測するモデルになります。

予測結果は以下の式から求められます。

y(t)=g(t)+s(t)+h(t)+ϵty(t)=g(t)+s(t)+h(t)+\epsilon_{t} g(t):トレンド s(t):周期性 h(t):休日 ϵt:誤差g(t): トレンド \\\ s(t): 周期性 \\\ h(t): 休日 \\\ \epsilon_{t}: 誤差

基本的な使い方

Prophetの基本的な使い方を紹介していきます。

ここではAirline Passengersのデータを利用します。今回はkaggleでアップされていたデータを利用しています。

Air Passengers

Air Passengers

Number of air passengers per month

インストール

まずはProphetをインストールします。

1pip install Prophet

準備

まずは、必要なライブラリをインポートします。

1import pandas as pd
2from prophet import Prophet
3
4import seaborn as sns
5sns.set()

利用データ

利用するデータを読み込みます。

1data = pd.read_csv('../input/air-passengers/AirPassengers.csv')

以下のような1949年1月から1960年12月の飛行機の乗客データになっています。

Month#Passengers
01949-01112
11949-02118
21949-03132
31949-04129
41949-05121
.........
1391960-08606
1401960-09508
1411960-10461
1421960-11390
1431960-12432

どのようなデータか可視化してみます。

1sns.lineplot(x="Month", y="#Passengers", data=data)

学習データの準備

学習データの準備をしていきます。

Prophetで学習させるデータは時系列を表すdsカラムと予測する対象となるyカラムが必要になります。

データの形としては今のままでいいですが、カラム名が適切ではないので、カラム名の変更をします。

1data.columns = ['ds', 'y']

また、学習データとテストデータを分割します。ここでは、最後の12ヶ月をテストデータとしています。

1test_length = 12
2train = data.iloc[:-test_length]
3test = data.iloc[-test_length:]

学習

学習データの準備ができたので、そのデータを用いてモデルの学習をします。

1model = Prophet(seasonality_mode='multiplicative')
2model.fit(train)

パラメータ

モデルには以下のパラメータが設定可能です。問題により適切なパラメータを設定することで精度が向上します。

  • growth: トレンドを表す関数(線形かロジスティック曲線)
  • changepoints: トレンドの変化点のリスト
  • n_changepoints: トレンドの変化点の数
  • changepoint_range: トレンドの変化点を推測する幅
  • yearly_seasonality: 年単位の周期性を考慮するか
  • weekly_seasonalityJ: 週単位の周期性を考慮するか
  • daily_seasonality: 日単位の周期性を考慮するか
  • holidays: 休日
  • seasonality_mode: 周期性の傾向
  • seasonality_prior_scale: 周期性の強さを表すパラメータ
  • holidays_prior_scale: 休日の強さを表すパラメータ
  • changepoint_prior_scale: トレンドの変化点の強さを表すパラメータ
  • mcmc_samples: MCMC法(マルコフ連鎖モンテカルロ法)のサンプル数
  • interval_width: 誤差の範囲の広さ
  • uncertainty_samples: 誤差を推測するためのサンプル数
  • stan_backend: バックエンドの指定

以下のようにパラメータの指定ができます。下記は全てデフォルト値になっています。

1params = {'growth': 'linear',
2          'changepoints': None,
3          'n_changepoints': 25,
4          'changepoint_range': 0.8,
5          'yearly_seasonality': 'auto',
6          'weekly_seasonality': 'auto',
7          'daily_seasonality': 'auto',
8          'holidays': None,
9          'seasonality_mode': 'additive',
10          'seasonality_prior_scale': 10.0,
11          'holidays_prior_scale': 10.0,
12          'changepoint_prior_scale': 0.05,
13          'mcmc_samples': 0,
14          'interval_width': 0.80,
15          'uncertainty_samples': 1000,
16          'stan_backend': None}
17
18model = Prophet(**params)

パラメータチューニングする場合は以下のパラメータが多いです。

  • changepoint_prior_scale
  • seasonality_prior_scale
  • holidays_prior_scale
  • seasonality_mode
  • changepoint_range

予測

学習が完了したらテストデータで予測します。

make_future_dataframeで学習データの期間にテストデータ(未来のデータ)を追加したDataFrameを作成します。このデータを使って予測します。

1future = model.make_future_dataframe(periods=test_length, freq='M')
2pred = model.predict(future)

予測した結果は以下のようになります。評価指標を計算する場合はyhatカラムを利用します。

dstrendyhat_loweryhat_uppertrend_lowertrend_uppermultiplicative_termsmultiplicative_terms_lowermultiplicative_terms_upperyearlyyearly_loweryearly_upperadditive_termsadditive_terms_loweradditive_terms_upperyhat
01949-01-01115.60396691.865217116.958190115.603966115.603966-0.101135-0.101135-0.101135-0.101135-0.101135-0.1011350.00.00.0103.912403
11949-02-01117.27501986.208456111.255977117.275019117.275019-0.154216-0.154216-0.154216-0.154216-0.154216-0.1542160.00.00.099.189377
21949-03-01118.784358107.379400131.664532118.784358118.7843580.0027210.0027210.0027210.0027210.0027210.0027210.00.00.0119.107520
31949-04-01120.455412103.425669129.164002120.455412120.455412-0.033256-0.033256-0.033256-0.033256-0.033256-0.0332560.00.00.0116.449565
41949-05-01122.072561106.650815131.868761122.072561122.072561-0.027357-0.027357-0.027357-0.027357-0.027357-0.0273570.00.00.0118.733027
...................................................
1391960-07-31464.860836571.806814598.207529464.129166465.4903360.2573930.2573930.2573930.2573930.2573930.2573930.00.00.0584.512761
1401960-08-31467.660516474.601698499.457627466.735673468.4833370.0410490.0410490.0410490.0410490.0410490.0410490.00.00.0486.857711
1411960-09-30470.369885414.832248440.271254469.248001471.331630-0.091353-0.091353-0.091353-0.091353-0.091353-0.0913530.00.00.0427.400268
1421960-10-31473.169565360.064275386.413159471.812173474.293732-0.210581-0.210581-0.210581-0.210581-0.210581-0.2105810.00.00.0373.529040
1431960-11-30475.878934411.538263436.644966474.320539477.207354-0.108360-0.108360-0.108360-0.108360-0.108360-0.1083600.00.00.0424.312896

可視化

予測した結果を可視化します。 Prophetではmatplotlibなどのライブラリを利用しなくても可視化できます。

黒い点が学習データの実際の値になります。

1pred_plot = model.plot(pred)

plot_componentsを使うとトレンドや周期性の可視化もできます。

1component_plot = model.plot_components(pred)

まとめ

  • Prophetを使うと時系列データを簡単に予測できる

参考

Support

\ この記事が役に立ったと思ったら、サポートお願いします! /

buy me a coffee
Share

Profile

author

Masa

都内のIT企業で働くエンジニア
自分が学んだことをブログでわかりやすく発信していきながらスキルアップを目指していきます!

buy me a coffee