После написания статьи о Prophet и SARIMA я подумал, что было бы интересно сравнить прогнозы, построив обе модели на одном и том же наборе данных. В этом посте я попытаюсь спрогнозировать индекс промышленного производства США - экономический индикатор, который измеряет реальный объем производства для всех производственных, горнодобывающих, электроэнергетических и газовых предприятий США. Выходные данные двух моделей будут сравниваться с использованием значения R-квадрата (силы корреляции) и средней абсолютной ошибки (средней величины ошибок в наборе прогнозов) и среднеквадратичной ошибки.

Записную книжку data и jupyter можно скачать с моей страницы github.

Импорт данных и библиотек

Первый шаг - импортировать библиотеки, такие как fbprophet, statsmodels, sklearn, pandas, numpy, seaborn, matplotlib и т. Д. Убедитесь, что вы установили эти библиотеки перед запуском программы. Для создания графиков мы будем использовать стиль «пять тридцать восемь».

from fbprophet import Prophet
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import warnings
import itertools
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.stattools import adfuller
import statsmodels.api as sm
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
warnings.filterwarnings("ignore")
plt.style.use('fivethirtyeight')

Затем мы импортируем индекс промышленного производства США с веб-сайта Федерального резервного банка Сент-Луиса и загрузим эти данные в «данные» фрейма данных pandas.

data = pd.read_csv("INDPRO.csv")
# Check the last 5 elements of the dataframe
data.tail()

# Check if the the data is set up in proper format and then start modeling/forecasting. 
data.dtypes
DATE       object
INDPRO    float64
dtype: object

Тип данных, определяющий поле «ДАТА», является объектом. Давайте преобразуем это в формат datetime, что является предпосылкой модели Prophet.

#Convert 'DATE' object to date datatype
data['DATE'] = pd.to_datetime(data['DATE'])
#Visualize the dataframe
plt.figure(figsize=(10,5))
sns.lineplot(data=data, x="DATE", y="INDPRO")
plt.title("U.S. Industrial Production Index (INDPRO)")
plt.grid(True)
plt.show()

Prophet ожидает, что формат фрейма данных будет конкретным. Модель ожидает столбец «ds», содержащий поле datetime, и столбец «y», содержащий значение, которое мы хотим моделировать / прогнозировать. Следовательно, нам необходимо соответствующим образом переименовать столбцы. Затем мы определяем параметры Prophet для оптимизации вывода модели.

data.columns = ["ds","y"]
model = Prophet(growth="linear", seasonality_mode="multiplicative", changepoint_prior_scale=30, seasonality_prior_scale=35,
               daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=False
               ).add_seasonality(
                name='monthly',
                period=30.5,
                fourier_order=30)

model.fit(data)
<fbprophet.forecaster.Prophet at 0x2022a54a520>

Пришло время приступить к прогнозированию. В Prophet мы начинаем с создания некоторых данных о будущем с помощью следующей команды:

future = model.make_future_dataframe(periods= 120, freq='m')

В этой строке кода мы создали фрейм данных pandas со 120 (периодами = 120) будущими точками данных с ежемесячной периодичностью (freq = «m»). В следующей строке кода мы проверяем последние пять дат прогнозируемых данных.

future.tail()

Теперь мы попытаемся предсказать фактические значения, используя библиотеку Prophet, и проверим последние пять элементов прогноза.

forecast = model.predict(future)
forecast.tail()

Если мы посмотрим на данные с помощью .tail (), мы заметим, что во фрейме данных прогноза есть несколько столбцов. Важными (на данный момент) являются «ds» (дата и время), «yhat» (прогноз), «yhat_lower» и «yhat_upper» (уровни неопределенности).

forecast[["ds","yhat","yhat_lower","yhat_upper"]].head()

# Plot the graph of this data to get an understanding of how well forecast looks
model.plot(forecast);
plt.title("U.S. Industrial Production Index (INDPRO)")
plt.show()

Далее мы проверим надежность модели, используя лучшие метрики для измерения точности этой модели. Использование комбинации R-квадрата, среднеквадратичной ошибки и средней абсолютной ошибки поможет нам оценить качество нашей модели. Мы создадим библиотеку Python Scikit-Learn, чтобы быстро вычислить эти показатели.

Определения

R-квадрат (R2) - это статистическая мера, которая представляет собой долю дисперсии для зависимой переменной, которая объясняется независимой переменной или переменными в регрессионной модели.

Средняя абсолютная ошибка (MAE) измеряет среднюю величину ошибок в наборе прогнозов без учета их направления. Это среднее значение по тестовой выборке абсолютных различий между прогнозом и фактическим наблюдением, при котором все индивидуальные различия имеют равный вес.

RMSE (среднеквадратичная ошибка) - это среднее расстояние между точкой данных от подобранной линии, измеренное вдоль вертикальной линии. RMSE можно напрямую интерпретировать в единицах измерения, и поэтому он является лучшим показателем соответствия, чем коэффициент корреляции.

# calculate MAE between expected and predicted values
y_true = data['y'].values
y_pred = forecast['yhat'][:1225].values
mae = mean_absolute_error(y_true, y_pred)
print('MAE: %.3f' % mae)
r = r2_score(y_true, y_pred)
print('R-squared Score: %.3f' % r)
rms = mean_squared_error(y_true, y_pred, squared=False)
print('RMSE: %.3f' % rms)
MAE: 1.219
R-squared Score: 0.997
RMSE: 1.798

Для данных временных рядов INDPRO США модель Пророка дает значение R-квадрат 0,997, то есть 99,7% дисперсии в нашем наборе данных объясняется этой моделью. MAE рассчитывается как 1,219, то есть для каждой точки данных средняя ошибка величины составляет примерно 1,22%, а RMSE составляет 1,798, что указывает на надежность нашей модели.

Наконец, мы создаем график для сравнения фактических и прогнозируемых значений, чтобы дать четкое представление о том, как наша модель визуально выглядит по сравнению с существующим набором данных INDPRO США.

plt.figure(figsize=(10,5))
# plot expected vs actual
plt.plot(y_true, label='Actual')
plt.plot(y_pred, label='Predicted')
plt.title("United States Industrial Production Index (INDPRO)")
plt.grid(True)
plt.legend()
plt.show()

Модель SARIMA

Теперь давайте посмотрим, как модель SARIMA будет работать для того же набора данных. Обозначение модели - SARIMA (p, d, q). (P, D, Q) lag. Эти три параметра учитывают сезонность, тенденцию и шум в данных. Я попытался вычислить AIC (информационный критерий Акаике), который является оценкой относительного качества статистических моделей для заданного набора данных. Учитывая набор моделей для данных, AIC оценивает качество каждой модели относительно каждой из других моделей. Нам нужно выбрать лучшую комбинацию, которая обеспечивает наименьшее значение AIC. Следующая программа определит оптимальный триплет, необходимый для получения наилучшей (P, D, Q) комбинации от 0 до 2. Для этого короткого фрагмента кода мы будем использовать библиотеку itertools.

data = pd.read_csv('INDPRO.csv')

# Define the p, d and q parameters to take any value between 0 and 3
p = d = q = range(0, 2)

# Generate all different combinations of p, q and q triplets
simple_pdq = list(itertools.product(p, d, q))

# Generate all different combinations of seasonal p, q and q triplets
seasonal_pdq = [(i[0], i[1], i[2], 12) for i in list(itertools.product(p, d, q))]

print('Parameter combinations for Seasonal ARIMA...')

warnings.filterwarnings("ignore") # specify to ignore warning messages

for param in simple_pdq:
    for param_seasonal in seasonal_pdq:
        try:
            mod = sm.tsa.statespace.SARIMAX(data['INDPRO'],
                                            order=param,
                                            )

            results = mod.fit()

            print('ARIMA{}x{}12 - AIC:{}'.format(param, param_seasonal, results.aic))
        except:
            continue
Parameter combinations for Seasonal ARIMA...
ARIMA(0, 0, 0)x(0, 0, 0, 12)12 - AIC:13402.14774760819
ARIMA(0, 0, 0)x(0, 0, 1, 12)12 - AIC:13402.14774760819
ARIMA(0, 0, 0)x(0, 1, 0, 12)12 - AIC:13402.14774760819
ARIMA(0, 0, 0)x(0, 1, 1, 12)12 - AIC:13402.14774760819
ARIMA(0, 0, 0)x(1, 0, 0, 12)12 - AIC:13402.14774760819
ARIMA(0, 0, 0)x(1, 0, 1, 12)12 - AIC:13402.14774760819
ARIMA(0, 0, 0)x(1, 1, 0, 12)12 - AIC:13402.14774760819
ARIMA(0, 0, 0)x(1, 1, 1, 12)12 - AIC:13402.14774760819
ARIMA(0, 0, 1)x(0, 0, 0, 12)12 - AIC:11723.74201989737
ARIMA(0, 0, 1)x(0, 0, 1, 12)12 - AIC:11723.74201989737
ARIMA(0, 0, 1)x(0, 1, 0, 12)12 - AIC:11723.74201989737
ARIMA(0, 0, 1)x(0, 1, 1, 12)12 - AIC:11723.74201989737
ARIMA(0, 0, 1)x(1, 0, 0, 12)12 - AIC:11723.74201989737
ARIMA(0, 0, 1)x(1, 0, 1, 12)12 - AIC:11723.74201989737
ARIMA(0, 0, 1)x(1, 1, 0, 12)12 - AIC:11723.74201989737
ARIMA(0, 0, 1)x(1, 1, 1, 12)12 - AIC:11723.74201989737
ARIMA(0, 1, 0)x(0, 0, 0, 12)12 - AIC:2305.0299916689955
ARIMA(0, 1, 0)x(0, 0, 1, 12)12 - AIC:2305.0299916689955
ARIMA(0, 1, 0)x(0, 1, 0, 12)12 - AIC:2305.0299916689955
ARIMA(0, 1, 0)x(0, 1, 1, 12)12 - AIC:2305.0299916689955
ARIMA(0, 1, 0)x(1, 0, 0, 12)12 - AIC:2305.0299916689955
ARIMA(0, 1, 0)x(1, 0, 1, 12)12 - AIC:2305.0299916689955
ARIMA(0, 1, 0)x(1, 1, 0, 12)12 - AIC:2305.0299916689955
ARIMA(0, 1, 0)x(1, 1, 1, 12)12 - AIC:2305.0299916689955
ARIMA(0, 1, 1)x(0, 0, 0, 12)12 - AIC:2133.7033551662826
ARIMA(0, 1, 1)x(0, 0, 1, 12)12 - AIC:2133.7033551662826
ARIMA(0, 1, 1)x(0, 1, 0, 12)12 - AIC:2133.7033551662826
ARIMA(0, 1, 1)x(0, 1, 1, 12)12 - AIC:2133.7033551662826
ARIMA(0, 1, 1)x(1, 0, 0, 12)12 - AIC:2133.7033551662826
ARIMA(0, 1, 1)x(1, 0, 1, 12)12 - AIC:2133.7033551662826
ARIMA(0, 1, 1)x(1, 1, 0, 12)12 - AIC:2133.7033551662826
ARIMA(0, 1, 1)x(1, 1, 1, 12)12 - AIC:2133.7033551662826
ARIMA(1, 0, 0)x(0, 0, 0, 12)12 - AIC:2318.496694496929
ARIMA(1, 0, 0)x(0, 0, 1, 12)12 - AIC:2318.496694496929
ARIMA(1, 0, 0)x(0, 1, 0, 12)12 - AIC:2318.496694496929
ARIMA(1, 0, 0)x(0, 1, 1, 12)12 - AIC:2318.496694496929
ARIMA(1, 0, 0)x(1, 0, 0, 12)12 - AIC:2318.496694496929
ARIMA(1, 0, 0)x(1, 0, 1, 12)12 - AIC:2318.496694496929
ARIMA(1, 0, 0)x(1, 1, 0, 12)12 - AIC:2318.496694496929
ARIMA(1, 0, 0)x(1, 1, 1, 12)12 - AIC:2318.496694496929
ARIMA(1, 0, 1)x(0, 0, 0, 12)12 - AIC:2147.1544799750695
ARIMA(1, 0, 1)x(0, 0, 1, 12)12 - AIC:2147.1544799750695
ARIMA(1, 0, 1)x(0, 1, 0, 12)12 - AIC:2147.1544799750695
ARIMA(1, 0, 1)x(0, 1, 1, 12)12 - AIC:2147.1544799750695
ARIMA(1, 0, 1)x(1, 0, 0, 12)12 - AIC:2147.1544799750695
ARIMA(1, 0, 1)x(1, 0, 1, 12)12 - AIC:2147.1544799750695
ARIMA(1, 0, 1)x(1, 1, 0, 12)12 - AIC:2147.1544799750695
ARIMA(1, 0, 1)x(1, 1, 1, 12)12 - AIC:2147.1544799750695
ARIMA(1, 1, 0)x(0, 0, 0, 12)12 - AIC:2159.076233067918
ARIMA(1, 1, 0)x(0, 0, 1, 12)12 - AIC:2159.076233067918
ARIMA(1, 1, 0)x(0, 1, 0, 12)12 - AIC:2159.076233067918
ARIMA(1, 1, 0)x(0, 1, 1, 12)12 - AIC:2159.076233067918
ARIMA(1, 1, 0)x(1, 0, 0, 12)12 - AIC:2159.076233067918
ARIMA(1, 1, 0)x(1, 0, 1, 12)12 - AIC:2159.076233067918
ARIMA(1, 1, 0)x(1, 1, 0, 12)12 - AIC:2159.076233067918
ARIMA(1, 1, 0)x(1, 1, 1, 12)12 - AIC:2159.076233067918
ARIMA(1, 1, 1)x(0, 0, 0, 12)12 - AIC:2135.4274909372502
ARIMA(1, 1, 1)x(0, 0, 1, 12)12 - AIC:2135.4274909372502
ARIMA(1, 1, 1)x(0, 1, 0, 12)12 - AIC:2135.4274909372502
ARIMA(1, 1, 1)x(0, 1, 1, 12)12 - AIC:2135.4274909372502
ARIMA(1, 1, 1)x(1, 0, 0, 12)12 - AIC:2135.4274909372502
ARIMA(1, 1, 1)x(1, 0, 1, 12)12 - AIC:2135.4274909372502
ARIMA(1, 1, 1)x(1, 1, 0, 12)12 - AIC:2135.4274909372502
ARIMA(1, 1, 1)x(1, 1, 1, 12)12 - AIC:2135.4274909372502
best_model = SARIMAX(data['INDPRO'], order=(0, 1, 1), seasonal_order=(0, 0, 0, 12)).fit()
print(best_model.summary())
SARIMAX Results                                
==============================================================================
Dep. Variable:                 INDPRO   No. Observations:                 1225
Model:               SARIMAX(0, 1, 1)   Log Likelihood               -1064.852
Date:                Tue, 16 Mar 2021   AIC                           2133.703
Time:                        22:04:40   BIC                           2143.923
Sample:                             0   HQIC                          2137.549
                               - 1225                                         
Covariance Type:                  opg                                         
==============================================================================
                 coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
ma.L1          0.3835      0.005     79.215      0.000       0.374       0.393
sigma2         0.3335      0.002    170.314      0.000       0.330       0.337
===================================================================================
Ljung-Box (L1) (Q):                   0.02   Jarque-Bera (JB):           1024545.17
Prob(Q):                              0.89   Prob(JB):                         0.00
Heteroskedasticity (H):              16.73   Skew:                            -6.34
Prob(H) (two-sided):                  0.00   Kurtosis:                       144.17
===================================================================================

Warnings:
[1] Covariance matrix calculated using the outer product of gradients (complex-step).

После исследования различных комбинаций p, d и q мы получаем самый низкий номер AIC для модели SARIMA с порядком (0,1,1) и сезонным порядком (0,0,0,12).

Модельный прогноз

На этапе прогноза мы попытаемся спрогнозировать INDPRO США на следующие 120 этапов, то есть на 10 лет.

#Forecasting 10 years ahead
forecast_values = best_model.get_forecast(steps = 120)

#Confidence intervals of the forecasted values
forecast_ci = forecast_values.conf_int()

#Plot the data
ax = data.plot(x='DATE' ,y='INDPRO', figsize = (12, 5), legend = True)

#Plot the forecasted values 
forecast_values.predicted_mean.plot(ax=ax, label='Forecasts', figsize = (12, 5), grid=True)

#Plot the confidence intervals
ax.fill_between(forecast_ci.index,
                forecast_ci.iloc[: , 0],
                forecast_ci.iloc[: , 1], color='#D3D3D3', alpha = .5)
plt.title("United States Industrial Production Index (INDPRO)", size=16)
plt.ylabel('INDPRO', size=12)
plt.xlabel('Date', size=12)
plt.legend(loc='upper center', prop={'size': 12})
#annotation
ax.text(1235, 82, 'FORECAST', fontsize=11,  color='RED')
ax.text(1275, 72, 'TO', fontsize=11,  color='RED')
ax.text(1260, 62, '2030', fontsize=11,  color='RED')
plt.show()

Проверка прогноза

Чтобы оценить производительность модели, мы вычисляем оценку R-квадрат и среднеквадратичную ошибку моего набора данных, чтобы проверить подлинность модели. Модель имеет точность 94,5%, что неплохо. Средняя абсолютная ошибка модели составляет 0,64, а RMSE - 1,25, что достаточно мало для того, чтобы мы были уверены в способности модели точно прогнозировать в будущем.

#divide into train and validation set to calculate R-squared score and mean absolute percentage error 
train = data[:int(0.85*(len(data)))]
test = data[int(0.85*(len(data))):]
start=len(train)
end=len(train)+len(test)-1
predictions = best_model.predict(start=start, end=end, dynamic=False, typ='levels').rename('SARIMA Predictions')
evaluation_results = pd.DataFrame({'r2_score': r2_score(test['INDPRO'], predictions)}, index=[0])
evaluation_results['mean_absolute_error'] = mean_absolute_error(test['INDPRO'], predictions)
evaluation_results['root_mean_squared_error'] = np.sqrt(mean_squared_error(test['INDPRO'], predictions))
evaluation_results

Наконец, мы сравниваем фактические и прогнозируемые значения, чтобы получить четкое представление об эффективности нашей модели.

data['sarima_model'] = best_model.fittedvalues
forecast = best_model.predict(start=data.shape[0], end=data.shape[0] + 120)
forecast = data['sarima_model'].append(forecast)
plt.figure(figsize=(12, 5))
plt.plot(forecast, color='r', label='Forecast')
plt.axvspan(data.index[-1], forecast.index[-1], alpha=0.6, color='lightgrey')
plt.plot(data['INDPRO'], label='INDPRO')
plt.legend()
plt.show()

Заключение

После подгонки модели в меру наших возможностей краткое изложение результатов SARIMA и Prophet выглядит следующим образом:

Из сводной таблицы мы видим, что модель SARIMA имеет более низкую среднюю абсолютную ошибку (MAE) и среднеквадратичную ошибку (RMSE) по сравнению с моделью, построенной в Prophet. Однако значение R-Squared на 5% ниже, что указывает на более низкую положительную линейную связь по сравнению с моделью Prophet.

На этом публикация завершена. Для получения дополнительных статей, пожалуйста, проверьте следующие ссылки: -