Skip to content

Commit 97242a6

Browse files
committed
Merge branch 'develop'
2 parents 1545afc + 8bcd987 commit 97242a6

File tree

4 files changed

+106
-15
lines changed

4 files changed

+106
-15
lines changed

pyramid/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55
# The pyramid module
66

7-
__version__ = '0.2-alpha'
7+
__version__ = '0.3'
88

99
try:
1010
# this var is injected in the setup build to enable

pyramid/arima/arima.py

+51-3
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,30 @@
88
from __future__ import print_function, absolute_import, division
99

1010
from sklearn.base import BaseEstimator
11-
from sklearn.utils.validation import check_array, check_is_fitted
11+
from sklearn.utils.validation import check_array, check_is_fitted, column_or_1d
12+
from sklearn.metrics import mean_absolute_error, mean_squared_error
1213
from sklearn.utils.metaestimators import if_delegate_has_method
1314
from statsmodels.tsa.arima_model import ARIMA as _ARIMA
1415
from statsmodels.tsa.base.tsa_model import TimeSeriesModelResults
1516
from statsmodels import api as sm
17+
import numpy as np
1618
import datetime
1719
import warnings
1820
import os
1921

2022
# DTYPE for arrays
2123
from ..compat.numpy import DTYPE
24+
from ..utils import get_callable
2225

2326
__all__ = [
2427
'ARIMA'
2528
]
2629

30+
VALID_SCORING = {
31+
'mse': mean_squared_error,
32+
'mae': mean_absolute_error
33+
}
34+
2735

2836
class ARIMA(BaseEstimator):
2937
"""An ARIMA, or autoregressive integrated moving average, is a generalization of an autoregressive
@@ -123,6 +131,17 @@ class ARIMA(BaseEstimator):
123131
Many warnings might be thrown inside of statsmodels. If ``suppress_warnings``
124132
is True, all of these warnings will be squelched.
125133
134+
out_of_sample_size : int, optional (default=0)
135+
The number of examples from the tail of the time series to use as validation
136+
examples.
137+
138+
scoring : str, optional (default='mse')
139+
If performing validation (i.e., if ``out_of_sample_size`` > 0), the metric
140+
to use for scoring the out-of-sample data. One of {'mse', 'mae'}
141+
142+
scoring_args : dict, optional (default=None)
143+
A dictionary of keyword arguments to be passed to the ``scoring`` metric.
144+
126145
127146
Notes
128147
-----
@@ -141,7 +160,8 @@ class ARIMA(BaseEstimator):
141160
"""
142161
def __init__(self, order, seasonal_order=None, start_params=None, trend='c',
143162
method=None, transparams=True, solver='lbfgs', maxiter=50,
144-
disp=0, callback=None, suppress_warnings=False):
163+
disp=0, callback=None, suppress_warnings=False, out_of_sample_size=0,
164+
scoring='mse', scoring_args=None):
145165
super(ARIMA, self).__init__()
146166

147167
self.order = order
@@ -155,6 +175,9 @@ def __init__(self, order, seasonal_order=None, start_params=None, trend='c',
155175
self.disp = disp
156176
self.callback = callback
157177
self.suppress_warnings = suppress_warnings
178+
self.out_of_sample_size = out_of_sample_size
179+
self.scoring = scoring
180+
self.scoring_args = dict() if not scoring_args else scoring_args
158181

159182
def fit(self, y, exogenous=None, **fit_args):
160183
"""Fit an ARIMA to a vector, ``y``, of observations with an
@@ -171,13 +194,19 @@ def fit(self, y, exogenous=None, **fit_args):
171194
include a constant or trend. If provided, these variables are
172195
used as additional features in the regression operation.
173196
"""
174-
y = check_array(y, ensure_2d=False, force_all_finite=False, copy=True, dtype=DTYPE)
197+
y = column_or_1d(check_array(y, ensure_2d=False, force_all_finite=False, copy=True, dtype=DTYPE))
198+
n_samples = y.shape[0]
175199

176200
# if exog was included, check the array...
177201
if exogenous is not None:
178202
exogenous = check_array(exogenous, ensure_2d=True, force_all_finite=False,
179203
copy=False, dtype=DTYPE)
180204

205+
# determine the CV args, if any
206+
cv = self.out_of_sample_size
207+
scoring = get_callable(self.scoring, VALID_SCORING)
208+
cv = max(min(cv, n_samples), 0) # don't allow negative, don't allow > n_samples
209+
181210
def _fit_wrapper():
182211
# these might change depending on which one
183212
method = self.method
@@ -227,6 +256,14 @@ def _fit_wrapper():
227256
# if the model is fit with an exogenous array, it must be predicted with one as well.
228257
self.fit_with_exog_ = exogenous is not None
229258

259+
# now make a prediction if we're validating to save the out-of-sample value
260+
if cv > 0:
261+
# get the predictions
262+
pred = self.arima_res_.predict(exog=exogenous, typ='linear')[-cv:]
263+
self.oob_ = scoring(y[-cv:], pred, **self.scoring_args)
264+
else:
265+
self.oob_ = np.nan
266+
230267
return self
231268

232269
def predict(self, n_periods=10, exogenous=None):
@@ -500,6 +537,17 @@ def maroots(self):
500537
"""
501538
return self.arima_res_.maroots
502539

540+
def oob(self):
541+
"""If the model was built with ``out_of_sample_size`` > 0, a validation
542+
score will have been computed. Otherwise it will be np.nan.
543+
544+
Returns
545+
-------
546+
oob_ : float
547+
The "out-of-bag" score.
548+
"""
549+
return self.oob_
550+
503551
@if_delegate_has_method('arima_res_')
504552
def params(self):
505553
"""Get the parameters of the model. The order of variables is the trend

pyramid/arima/auto.py

+30-11
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
]
2525

2626
# The valid information criteria
27-
VALID_CRITERIA = {'aic', 'bic', 'hqic'}
27+
VALID_CRITERIA = {'aic', 'bic', 'hqic', 'oob'}
2828

2929

3030
def auto_arima(y, exogenous=None, start_p=2, d=None, start_q=2, max_p=5, max_d=2, max_q=5,
@@ -33,7 +33,8 @@ def auto_arima(y, exogenous=None, start_p=2, d=None, start_q=2, max_p=5, max_d=2
3333
seasonal_test='ch', n_jobs=1, start_params=None, trend='c', method=None, transparams=True,
3434
solver='lbfgs', maxiter=50, disp=0, callback=None, offset_test_args=None, seasonal_test_args=None,
3535
suppress_warnings=False, error_action='warn', trace=False, random=False, random_state=None,
36-
n_fits=10, return_valid_fits=False, **fit_args):
36+
n_fits=10, return_valid_fits=False, out_of_sample_size=0, scoring='mse', scoring_args=None,
37+
**fit_args):
3738
"""The ``auto_arima`` function seeks to identify the most optimal parameters for an ``ARIMA`` model,
3839
and returns a fitted ARIMA model. This function is based on the commonly-used R function,
3940
``forecast::auto.arima`` [3].
@@ -45,8 +46,9 @@ def auto_arima(y, exogenous=None, start_p=2, d=None, start_q=2, max_p=5, max_d=2
4546
conducting the Canova-Hansen to determine the optimal order of seasonal differencing, ``D``.
4647
4748
In order to find the best model, ``auto_arima`` optimizes for a given ``information_criterion``, one of
48-
{'aic', 'bic', 'hqic'} (Akaine Information Criterion, Bayesian Information Criterion or Hannan-Quinn
49-
Information Criterion, respectively) and returns the ARIMA which minimizes the value.
49+
{'aic', 'bic', 'hqic', 'oob'} (Akaike Information Criterion, Bayesian Information Criterion, Hannan-Quinn
50+
Information Criterion, or "out of bag"--for validation scoring--respectively) and returns the ARIMA which
51+
minimizes the value.
5052
5153
Note that due to stationarity issues, ``auto_arima`` might not find a suitable model that will converge. If this
5254
is the case, a ``ValueError`` will be thrown suggesting stationarity-inducing measures be taken prior
@@ -127,8 +129,7 @@ def auto_arima(y, exogenous=None, start_p=2, d=None, start_q=2, max_p=5, max_d=2
127129
128130
information_criterion : str, optional (default='aic')
129131
The information criterion used to select the best ARIMA model. One of
130-
``pyramid.arima.auto_arima.VALID_CRITERIA``, ('aic', 'bic'). Note that if
131-
n_samples <= 3, AIC will be used.
132+
``pyramid.arima.auto_arima.VALID_CRITERIA``, ('aic', 'bic', 'hqic', 'oob').
132133
133134
alpha : float, optional (default=0.05)
134135
Level of the test for testing significance.
@@ -224,6 +225,17 @@ def auto_arima(y, exogenous=None, start_p=2, d=None, start_q=2, max_p=5, max_d=2
224225
If True, will return all valid ARIMA fits. If False (by default), will only
225226
return the best fit.
226227
228+
out_of_sample_size : int, optional (default=0)
229+
The number of examples from the tail of the time series to use as validation
230+
examples.
231+
232+
scoring : str, optional (default='mse')
233+
If performing validation (i.e., if ``out_of_sample_size`` > 0), the metric
234+
to use for scoring the out-of-sample data. One of {'mse', 'mae'}
235+
236+
scoring_args : dict, optional (default=None)
237+
A dictionary of keyword arguments to be passed to the ``scoring`` metric.
238+
227239
**fit_args : dict, optional (default=None)
228240
A dictionary of keyword arguments to pass to the :func:`ARIMA.fit` method.
229241
@@ -282,7 +294,9 @@ def auto_arima(y, exogenous=None, start_p=2, d=None, start_q=2, max_p=5, max_d=2
282294
transparams=transparams, solver=solver, maxiter=maxiter,
283295
disp=disp, callback=callback, fit_params=fit_args,
284296
suppress_warnings=suppress_warnings, trace=trace,
285-
error_action=error_action)),
297+
error_action=error_action, scoring=scoring,
298+
out_of_sample_size=out_of_sample_size,
299+
scoring_args=scoring_args)),
286300
return_valid_fits)
287301

288302
# test ic, and use AIC if n <= 3
@@ -396,7 +410,9 @@ def auto_arima(y, exogenous=None, start_p=2, d=None, start_q=2, max_p=5, max_d=2
396410
transparams=transparams, solver=solver, maxiter=maxiter,
397411
disp=disp, callback=callback, fit_params=fit_args,
398412
suppress_warnings=suppress_warnings, trace=trace,
399-
error_action=error_action)),
413+
error_action=error_action, scoring=scoring,
414+
out_of_sample_size=out_of_sample_size,
415+
scoring_args=scoring_args)),
400416
return_valid_fits)
401417

402418
# seasonality issues
@@ -442,7 +458,8 @@ def auto_arima(y, exogenous=None, start_p=2, d=None, start_q=2, max_p=5, max_d=2
442458
start_params=start_params, trend=trend, method=method, transparams=transparams,
443459
solver=solver, maxiter=maxiter, disp=disp, callback=callback,
444460
fit_params=fit_args, suppress_warnings=suppress_warnings,
445-
trace=trace, error_action=error_action)
461+
trace=trace, error_action=error_action, out_of_sample_size=out_of_sample_size,
462+
scoring=scoring, scoring_args=scoring_args)
446463
for order, seasonal_order in gen)
447464

448465
# filter the non-successful ones
@@ -461,12 +478,14 @@ def auto_arima(y, exogenous=None, start_p=2, d=None, start_q=2, max_p=5, max_d=2
461478

462479
def _fit_arima(x, xreg, order, seasonal_order, start_params, trend, method, transparams,
463480
solver, maxiter, disp, callback, fit_params, suppress_warnings, trace,
464-
error_action):
481+
error_action, out_of_sample_size, scoring, scoring_args):
465482
try:
466483
fit = ARIMA(order=order, seasonal_order=seasonal_order, start_params=start_params,
467484
trend=trend, method=method, transparams=transparams,
468485
solver=solver, maxiter=maxiter, disp=disp,
469-
callback=callback, suppress_warnings=suppress_warnings)\
486+
callback=callback, suppress_warnings=suppress_warnings,
487+
out_of_sample_size=out_of_sample_size, scoring=scoring,
488+
scoring_args=scoring_args)\
470489
.fit(x, exogenous=xreg, **fit_params)
471490

472491
# for non-stationarity errors, return None

pyramid/arima/tests/test_arima.py

+24
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,21 @@ def test_basic_arima():
101101
assert_array_almost_equal(preds, expected_preds)
102102

103103

104+
def test_with_oob():
105+
# show we can fit with CV (kinda)
106+
arima = ARIMA(order=(2, 1, 2), suppress_warnings=True, out_of_sample_size=10).fit(y=hr)
107+
assert not np.isnan(arima.oob()) # show this works
108+
109+
# show we can fit if ooss < 0 and oob will be nan
110+
arima = ARIMA(order=(2, 1, 2), suppress_warnings=True, out_of_sample_size=-1).fit(y=hr)
111+
assert np.isnan(arima.oob())
112+
113+
# can we do one with an exogenous array, too?
114+
arima = ARIMA(order=(2, 1, 2), suppress_warnings=True, out_of_sample_size=10).fit(
115+
y=hr, exogenous=rs.rand(hr.shape[0], 4))
116+
assert not np.isnan(arima.oob())
117+
118+
104119
def _try_get_attrs(arima):
105120
# show we can get all these attrs without getting an error
106121
attrs = {
@@ -294,6 +309,15 @@ def test_with_seasonality6():
294309
# FIXME: we get an IndexError from statsmodels summary if (0, 0, 0)
295310

296311

312+
def test_with_seasonality7():
313+
# show we can fit one with OOB as the criterion
314+
_ = auto_arima(wineind, start_p=1, start_q=1, max_p=2, max_q=2, m=12,
315+
start_P=0, seasonal=True, n_jobs=1, d=1, D=1,
316+
out_of_sample_size=10, information_criterion='oob',
317+
suppress_warnings=True, error_action='raise', # do raise so it fails fast
318+
random=True, random_state=42, n_fits=3)
319+
320+
297321
def test_corner_cases():
298322
assert_raises(ValueError, auto_arima, wineind, error_action='some-bad-string')
299323

0 commit comments

Comments
 (0)