Skip to content

Commit f45437d

Browse files
authored
[ML-52574] Add covariate support in predict_timeseries for prediction table usage (#169)
* init * change name * fix * fix * fix * fix * fix * fix
1 parent 6fb07f8 commit f45437d

File tree

3 files changed

+165
-36
lines changed

3 files changed

+165
-36
lines changed

runtime/databricks/automl_runtime/forecast/prophet/model.py

Lines changed: 45 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from databricks.automl_runtime.forecast import OFFSET_ALIAS_MAP, DATE_OFFSET_KEYWORD_MAP
2727
from databricks.automl_runtime.forecast.model import ForecastModel, mlflow_forecast_log_model
2828
from databricks.automl_runtime import version
29-
from databricks.automl_runtime.forecast.utils import is_quaterly_alias, make_future_dataframe
29+
from databricks.automl_runtime.forecast.utils import is_quaterly_alias, make_future_dataframe, apply_preprocess_func
3030

3131

3232
PROPHET_ADDITIONAL_PIP_DEPS = [
@@ -110,26 +110,36 @@ def make_future_dataframe(self, horizon: int = None, include_history: bool = Tru
110110
freq=pd.DateOffset(**offset_kwarg),
111111
include_history=include_history)
112112

113-
def _predict_impl(self, horizon: int = None, include_history: bool = True) -> pd.DataFrame:
113+
def _predict_impl(self, future_df: pd.DataFrame) -> pd.DataFrame:
114114
"""
115115
Predict using the API from prophet model.
116-
:param horizon: Int number of periods to forecast forward.
117-
:param include_history: Boolean to include the historical dates in the data
118-
frame for predictions.
119-
:return: A pd.DataFrame with the forecast components.
116+
:param future_df: future input dataframe. This dataframe should contain
117+
the time series column and covariate columns if available. It is used as the
118+
input for generating predictions.
119+
:return: A pd.DataFrame that represents the model's output. The predicted target
120+
column is named 'yhat'.
120121
"""
121-
future_pd = self.make_future_dataframe(horizon=horizon or self._horizon, include_history=include_history)
122-
return self.model().predict(future_pd)
122+
return self.model().predict(future_df)
123123

124-
def predict_timeseries(self, horizon: int = None, include_history: bool = True) -> pd.DataFrame:
124+
def predict_timeseries(self, horizon: int = None, include_history: bool = True, future_df: pd.DataFrame = None) -> pd.DataFrame:
125125
"""
126-
Predict using the prophet model.
126+
Predict using the prophet model. The input dataframe will be preprocessed if with covariates.
127127
:param horizon: Int number of periods to forecast forward.
128128
:param include_history: Boolean to include the historical dates in the data
129129
frame for predictions.
130-
:return: A pd.DataFrame with the forecast components.
130+
:param future_df: Optional future input dataframe. This dataframe should contain
131+
the time series column and covariate columns if available. It is used as the
132+
input for generating predictions.
133+
:return: A pd.DataFrame that represents the model's output. The predicted target
134+
column is named 'yhat'.
131135
"""
132-
return self._predict_impl(horizon, include_history)
136+
if future_df is None:
137+
future_df = self.make_future_dataframe(horizon=horizon or self._horizon, include_history=include_history)
138+
139+
if self._preprocess_func and self._split_col:
140+
future_df = apply_preprocess_func(future_df, self._preprocess_func, self._split_col)
141+
future_df.rename(columns={self._time_col: "ds"}, inplace=True)
142+
return self._predict_impl(future_df)
133143

134144
def predict(self, context: mlflow.pyfunc.model.PythonModelContext, model_input: pd.DataFrame) -> pd.Series:
135145
"""
@@ -143,15 +153,7 @@ def predict(self, context: mlflow.pyfunc.model.PythonModelContext, model_input:
143153
test_df = model_input.copy()
144154

145155
if self._preprocess_func and self._split_col:
146-
# Apply the same preprocessing pipeline to test_df. The preprocessing function requires the "y" column
147-
# and the split column to be present, as they are used in the trial notebook. These columns are added
148-
# temporarily and removed after preprocessing.
149-
# see https://src.dev.databricks.com/databricks-eng/universe/-/blob/automl/python/databricks/automl/core/sections/templates/preprocess/finish_with_transform.jinja?L3
150-
# and https://src.dev.databricks.com/databricks-eng/universe/-/blob/automl/python/databricks/automl/core/sections/templates/preprocess/select_columns.jinja?L8-10
151-
test_df["y"] = None
152-
test_df[self._split_col] = "prediction"
153-
test_df = self._preprocess_func(test_df)
154-
test_df.drop(columns=["y", self._split_col], inplace=True, errors="ignore")
156+
test_df = apply_preprocess_func(test_df, self._preprocess_func, self._split_col)
155157

156158
test_df.rename(columns={self._time_col: "ds"}, inplace=True)
157159
predict_df = self.model().predict(test_df)
@@ -260,28 +262,36 @@ def _predict_impl(self, df: pd.DataFrame, horizon: int = None, include_history:
260262
future_pd[self._id_cols] = df[self._id_cols].iloc[0]
261263
return future_pd
262264

263-
def predict_timeseries(self, horizon: int = None, include_history: bool = True) -> pd.DataFrame:
265+
def predict_timeseries(self, horizon: int = None, include_history: bool = True, future_df: pd.DataFrame = None) -> pd.DataFrame:
264266
"""
265267
Predict using the prophet model.
266268
:param horizon: Int number of periods to forecast forward.
267269
:param include_history: Boolean to include the historical dates in the data
268270
frame for predictions.
269-
:return: A pd.DataFrame with the forecast components.
271+
:param future_df: Optional future input dataframe. This dataframe should contain
272+
the time series column and covariate columns if available. It is used as the
273+
input for generating predictions.
274+
:return: A pd.DataFrame that represents the model's output. The predicted target
275+
column is named 'yhat'.
270276
"""
271277
horizon=horizon or self._horizon
272-
end_time = pd.Timestamp(self._timeseries_end)
273-
future_df = make_future_dataframe(
274-
start_time=self._timeseries_starts,
275-
end_time=end_time,
276-
horizon=horizon,
277-
frequency_unit=self._frequency_unit,
278-
frequency_quantity=self._frequency_quantity,
279-
include_history=include_history,
280-
groups=self._model_json.keys(),
281-
identity_column_names=self._id_cols
282-
)
278+
if future_df is None:
279+
end_time = pd.Timestamp(self._timeseries_end)
280+
future_df = make_future_dataframe(
281+
start_time=self._timeseries_starts,
282+
end_time=end_time,
283+
horizon=horizon,
284+
frequency_unit=self._frequency_unit,
285+
frequency_quantity=self._frequency_quantity,
286+
include_history=include_history,
287+
groups=self._model_json.keys(),
288+
identity_column_names=self._id_cols
289+
)
283290
future_df["ts_id"] = future_df[self._id_cols].apply(tuple, axis=1)
284-
return future_df.groupby(self._id_cols).apply(lambda df: self._predict_impl(df, horizon, include_history)).reset_index()
291+
if self._preprocess_func and self._split_col:
292+
future_df = apply_preprocess_func(future_df, self._preprocess_func, self._split_col)
293+
future_df.rename(columns={self._time_col: "ds"}, inplace=True)
294+
return future_df.groupby(self._id_cols).apply(lambda df: self._predict_impl(df, horizon, include_history)).reset_index(drop=True)
285295

286296
@staticmethod
287297
def get_reserved_cols() -> List[str]:
@@ -354,7 +364,6 @@ def model_prediction(df):
354364
return_df = test_df.merge(predict_df, how="left", on=["ds"] + self._id_cols)
355365
return return_df["yhat"]
356366

357-
358367
def mlflow_prophet_log_model(prophet_model: Union[ProphetModel, MultiSeriesProphetModel],
359368
sample_input: pd.DataFrame = None) -> None:
360369
"""

runtime/databricks/automl_runtime/forecast/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,21 @@ def calculate_period_differences(
279279
freq_alias = PERIOD_ALIAS_MAP[OFFSET_ALIAS_MAP[frequency_unit]]
280280
# It is intended to get the floor value. And in the later check we will use this floor value to find out if it is not consistent.
281281
return (end_time.to_period(freq_alias) - start_time.to_period(freq_alias)).n // frequency_quantity
282+
283+
def apply_preprocess_func(df: pd.DataFrame, preprocess_func: callable, split_col: str) -> pd.DataFrame:
284+
"""
285+
Apply the preprocessing function to the dataframe. The preprocessing function requires the "y" column
286+
and the split column to be present, as they are used in the trial notebook. These columns are added
287+
temporarily and removed after preprocessing.
288+
see https://src.dev.databricks.com/databricks-eng/universe/-/blob/automl/python/databricks/automl/core/sections/templates/preprocess/finish_with_transform.jinja?L3
289+
and https://src.dev.databricks.com/databricks-eng/universe/-/blob/automl/python/databricks/automl/core/sections/templates/preprocess/select_columns.jinja?L8-10
290+
:param df: pd.DataFrame to be preprocessed.
291+
:param preprocess_func: preprocessing function to be applied to the dataframe.
292+
:param split_col: name of the split column to be added to the dataframe.
293+
:return: preprocessed pd.DataFrame.
294+
"""
295+
df["y"] = None
296+
df[split_col] = "prediction"
297+
df = preprocess_func(df)
298+
df.drop(columns=["y", split_col], inplace=True, errors="ignore")
299+
return df

runtime/tests/automl_runtime/forecast/prophet/model_test.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#
1616

1717
import unittest
18+
from unittest.mock import patch
1819
import datetime
1920

2021
import pandas as pd
@@ -103,6 +104,42 @@ def test_model_save_and_load(self):
103104
include_history=False
104105
)
105106
self.assertEqual(len(forecast_future_pd), 1)
107+
108+
@patch("databricks.automl_runtime.forecast.prophet.model.ProphetModel._predict_impl")
109+
def test_predict_timeseries_with_preprocess_func(self, mock_predict_impl):
110+
# Mock the output of _predict_impl
111+
mock_predict_impl.side_effect = lambda df: df
112+
113+
# Define a preprocess function
114+
def preprocess_func(df):
115+
df["feature"] = df["feature"] * 2
116+
return df
117+
118+
# Create a ProphetModel instance with preprocess_func
119+
prophet_model = ProphetModel(
120+
model_json=PROPHET_MODEL_JSON,
121+
horizon=3,
122+
frequency_unit="d",
123+
frequency_quantity=1,
124+
time_col="time",
125+
preprocess_func=preprocess_func,
126+
split_col="split"
127+
)
128+
129+
# Input DataFrame
130+
input_df = pd.DataFrame({"time": ["2020-10-01", "2020-10-02", "2020-10-03"], "feature": [1, 2, 3]})
131+
132+
# Call predict_timeseries
133+
result = prophet_model.predict_timeseries(future_df=input_df)
134+
135+
# Assertions
136+
mock_predict_impl.assert_called_once()
137+
self.assertEqual(len(result), 3)
138+
139+
# Check if the preprocess_func was applied
140+
processed_df = mock_predict_impl.call_args[0][0] # Get the DataFrame passed to _predict_impl
141+
self.assertTrue((processed_df["feature"] == [2, 4, 6]).all()) # Check if "y" was doubled
142+
self.assertIn("ds", processed_df.columns) # Ensure "ds" column exists
106143

107144
def test_make_future_dataframe(self):
108145
for feq_unit in OFFSET_ALIAS_MAP:
@@ -451,3 +488,68 @@ def preprocess_func(df):
451488
)
452489
yhat = prophet_model.predict(None, test_df)
453490
self.assertEqual(2, len(yhat))
491+
492+
@patch("databricks.automl_runtime.forecast.prophet.model.MultiSeriesProphetModel._predict_impl")
493+
def test_predict_timeseries(self, mock_predict_impl):
494+
# Mock the output of _predict_impl
495+
mock_predict_impl.side_effect = lambda df, horizon, include_history: pd.DataFrame({
496+
"ds": df["ds"],
497+
"feature": df["feature"],
498+
"id": df["id"]
499+
})
500+
501+
# Define a preprocess function
502+
def preprocess_func(df):
503+
df["feature"] = df["feature"] * 2
504+
return df
505+
506+
# Create a MultiSeriesProphetModel instance
507+
model_json = {
508+
("id1",): '{"model": "mock_model_1"}',
509+
("id2",): '{"model": "mock_model_2"}'
510+
}
511+
timeseries_starts = {("id1",): pd.Timestamp("2020-01-01"), ("id2",): pd.Timestamp("2020-01-01")}
512+
timeseries_end = "2020-12-31"
513+
prophet_model = MultiSeriesProphetModel(
514+
model_json=model_json,
515+
timeseries_starts=timeseries_starts,
516+
timeseries_end=timeseries_end,
517+
horizon=3,
518+
frequency_unit="d",
519+
frequency_quantity=1,
520+
time_col="time",
521+
id_cols=["id"],
522+
preprocess_func=preprocess_func,
523+
split_col="split"
524+
)
525+
526+
# Input DataFrame
527+
input_df = pd.DataFrame({
528+
"time": ["2020-10-01", "2020-10-02", "2020-10-03", "2020-10-01", "2020-10-02", "2020-10-03"],
529+
"feature": [1, 2, 3, 4, 5, 6],
530+
"id": ["id1", "id1", "id1", "id2", "id2", "id2"]
531+
})
532+
533+
# Call predict_timeseries
534+
result = prophet_model.predict_timeseries(future_df=input_df)
535+
536+
# Assertions
537+
mock_predict_impl.assert_called()
538+
self.assertEqual(len(result), 6)
539+
self.assertIn("feature", result.columns)
540+
self.assertIn("ds", result.columns)
541+
self.assertIn("id", result.columns)
542+
543+
# Check the calls to _predict_impl
544+
calls = mock_predict_impl.call_args_list
545+
self.assertEqual(len(calls), 2) # Ensure _predict_impl is called twice (once per group)
546+
547+
# Check the first call
548+
first_call_df = calls[0][0][0] # Get the DataFrame passed in the first call
549+
self.assertTrue((first_call_df["feature"] == [2, 4, 6]).all())
550+
self.assertTrue((first_call_df["id"] == ["id1", "id1", "id1"]).all())
551+
552+
# Check the second call
553+
second_call_df = calls[1][0][0] # Get the DataFrame passed in the second call
554+
self.assertTrue((second_call_df["feature"] == [8, 10, 12]).all())
555+
self.assertTrue((second_call_df["id"] == ["id2", "id2", "id2"]).all())

0 commit comments

Comments
 (0)