Skip to content

Pr 451 - modified and added tests to statespace #466

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: pymc-extras-test
name: pymc-extras
channels:
- conda-forge
- nodefaults
Expand Down
29 changes: 16 additions & 13 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,9 @@ def _kalman_filter_outputs_from_dummy_graph(
provided when the model was built.
data_dims: str or tuple of str, optional
Dimension names associated with the model data. If None, defaults to ("time", "obs_state")
scenario: dict[str, pd.DataFrame], optional
Dictionary of out-of-sample scenario dataframes. If provided, it must have values for all data variables
in the model. pm.set_data is used to replace training data with new values.

Returns
-------
Expand Down Expand Up @@ -1567,8 +1570,10 @@ def _validate_forecast_args(
raise ValueError(
"Integer start must be within the range of the data index used to fit the model."
)
if periods is None and end is None:
raise ValueError("Must specify one of either periods or end")
if periods is None and end is None and not use_scenario_index:
raise ValueError(
"Must specify one of either periods or end unless use_scenario_index=True"
)
if periods is not None and end is not None:
raise ValueError("Must specify exactly one of either periods or end")
if scenario is None and use_scenario_index:
Expand Down Expand Up @@ -2060,9 +2065,18 @@ def forecast(

with pm.Model(coords=temp_coords) as forecast_model:
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
scenario=scenario,
data_dims=["data_time", OBS_STATE_DIM],
)

for name in self.data_names:
if name in scenario.keys():
pm.set_data(
{"data": np.zeros((len(forecast_index), self.k_endog))},
coords={"data_time": np.arange(len(forecast_index))},
)
break
Comment on lines +2072 to +2078
Copy link
Author

@Dekermanjian Dekermanjian May 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this logic to update the static shape of the target variable when forecasting. I realize that it is naive in a few respects:
1. It makes an assumption about how the scenario data is constructed (I think I have resolved that).
2. The point at which it is called may be inappropriate.
3. There are probably other things that I am not thinking of right now.

With your suggestions I can make this more robust. I just wanted to confirm my suspicion: is the issue that the static shape of the target needs to be updated to reflect the shape of the forecast index?

EDIT:
Sorry about the multiple pings, I didn't highlight all of the code.

I should also mention that these do pass the unit tests in test_statespace.py


group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
mu, cov = grouped_outputs[group_idx]

Expand All @@ -2073,17 +2087,6 @@ def forecast(
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
)

if scenario is not None:
sub_dict = {
forecast_model[data_name]: pt.as_tensor_variable(
scenario.get(data_name), name=data_name
)
for data_name in self.data_names
}

matrices = graph_replace(matrices, replace=sub_dict, strict=True)
[setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)]

_ = LinearGaussianStateSpace(
"forecast",
x0,
Expand Down
6 changes: 6 additions & 0 deletions tests/statespace/test_SARIMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def test_make_SARIMA_transition_matrix(p, d, q, P, D, Q, S):
"ignore:Non-stationary starting autoregressive parameters found",
"ignore:Non-invertible starting seasonal moving average",
"ignore:Non-stationary starting seasonal autoregressive",
"ignore:divide by zero encountered in matmul:RuntimeWarning",
"ignore:overflow encountered in matmul:RuntimeWarning",
"ignore:invalid value encountered in matmul:RuntimeWarning",
)
def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng):
sm_sarimax = sm.tsa.SARIMAX(data, order=(p, d, q), seasonal_order=(P, D, Q, S))
Expand Down Expand Up @@ -361,6 +364,9 @@ def test_interpretable_states_are_interpretable(arima_mod_interp, pymc_mod_inter
"ignore:Non-invertible starting MA parameters found.",
"ignore:Non-stationary starting autoregressive parameters found",
"ignore:Maximum Likelihood optimization failed to converge.",
"ignore:divide by zero encountered in matmul:RuntimeWarning",
"ignore:overflow encountered in matmul:RuntimeWarning",
"ignore:invalid value encountered in matmul:RuntimeWarning",
)
def test_representations_are_equivalent(p, d, q, P, D, Q, S, data, rng):
if (d + D) > 0:
Expand Down
91 changes: 69 additions & 22 deletions tests/statespace/test_statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,32 +128,51 @@ def ss_mod_no_exog_dt(rng):


@pytest.fixture(scope="session")
def exog_ss_mod(rng):
ll = st.LevelTrendComponent()
reg = st.RegressionComponent(name="exog", state_names=["a", "b", "c"])
mod = (ll + reg).build(verbose=False)
def exog_data(rng):
# simulate data
df = pd.DataFrame(
{
"date": pd.date_range(start="2023-05-01", end="2023-05-10", freq="D"),
"x1": rng.choice(2, size=10, replace=True).astype(float),
"y": rng.normal(size=(10,)),
}
)

return mod
df.loc[[1, 3, 9], ["y"]] = np.nan
return df.set_index("date")


@pytest.fixture(scope="session")
def exog_pymc_mod(exog_ss_mod, rng):
y = rng.normal(size=(100, 1)).astype(floatX)
X = rng.normal(size=(100, 3)).astype(floatX)
def exog_ss_mod(exog_data):
level_trend = st.LevelTrendComponent(order=1, innovations_order=[0])
exog = st.RegressionComponent(
name="exog", # Name of this exogenous variable component
k_exog=1, # Only one exogenous variable now
innovations=False, # Typically fixed effect (no stochastic evolution)
state_names=exog_data[["x1"]].columns.tolist(),
)

with pm.Model(coords=exog_ss_mod.coords) as m:
exog_data = pm.Data("data_exog", X)
initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
P0_sigma = pm.Exponential("P0_sigma", 1)
P0 = pm.Deterministic(
"P0", pt.eye(exog_ss_mod.k_states) * P0_sigma, dims=["state", "state_aux"]
combined_model = level_trend + exog
return combined_model.build()


@pytest.fixture(scope="session")
def exog_pymc_mod(exog_ss_mod, exog_data):
# define pymc model
with pm.Model(coords=exog_ss_mod.coords) as struct_model:
P0_diag = pm.Gamma("P0_diag", alpha=2, beta=4, dims=["state"])
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])

initial_trend = pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["trend_state"])

data_exog = pm.Data(
"data_exog", exog_data["x1"].values[:, None], dims=["time", "exog_state"]
)
beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])

sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
exog_ss_mod.build_statespace_graph(y, save_kalman_filter_outputs_in_idata=True)
exog_ss_mod.build_statespace_graph(exog_data["y"])

return m
return struct_model


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -844,10 +863,14 @@ def test_forecast(filter_output, mod_name, idata_name, start, end, periods, rng,
assert forecast_idx[0] == (t0 + delta)


@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.parametrize("start", [None, -1, 10])
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
@pytest.mark.parametrize("start", [None, -1, 5])
def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
scenario = pd.DataFrame(np.zeros((10, 3)), columns=["a", "b", "c"])
scenario = pd.DataFrame(np.zeros((10, 1)), columns=["x1"])
scenario.iloc[5, 0] = 1e9

forecast_idata = exog_ss_mod.forecast(
Expand All @@ -856,17 +879,41 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):

components = exog_ss_mod.extract_components_from_idata(forecast_idata)
level = components.forecast_latent.sel(state="LevelTrend[level]")
betas = components.forecast_latent.sel(state=["exog[a]", "exog[b]", "exog[c]"])
betas = components.forecast_latent.sel(state=["exog[x1]"])

scenario.index.name = "time"
scenario_xr = (
scenario.unstack()
.to_xarray()
.rename({"level_0": "state"})
.assign_coords(state=["exog[a]", "exog[b]", "exog[c]"])
.assign_coords(state=["exog[x1]"])
)

regression_effect = forecast_idata.forecast_observed.isel(observed_state=0) - level
regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])

assert_allclose(regression_effect, regression_effect_expected)


@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
def test_foreacast_valid_index(exog_pymc_mod, exog_ss_mod, exog_data):
    """Regression test for https://github.com/pymc-devs/pymc-extras/issues/424.

    Forecasts from a model with exogenous data, passing a scenario whose
    datetime index is used as the forecast index (``use_scenario_index=True``),
    and checks that the forecast is produced on exactly that index.

    Parameters
    ----------
    exog_pymc_mod
        Fixture providing the PyMC model; used as a context manager to draw
        prior predictive samples.
    exog_ss_mod
        Fixture providing the statespace model whose ``forecast`` is tested.
    exog_data
        Fixture providing the training dataframe; its ``x1`` column supplies
        the scenario values.
    """
    with exog_pymc_mod:
        idata = pm.sample_prior_predictive()

    # Forecast window: up to ``n_periods`` rows starting at ``start_date``.
    start_date, n_periods = pd.to_datetime("2023-05-05"), 5

    # Slice the exogenous column directly; the slice is already a DataFrame
    # with the right columns, so no re-wrapping in pd.DataFrame is needed.
    scenario_df = exog_data[["x1"]].loc[start_date:].iloc[:n_periods]
    scenario = {"data_exog": scenario_df}

    forecasts = exog_ss_mod.forecast(idata.prior, scenario=scenario, use_scenario_index=True)

    # The point of the regression test: the forecast must live on the
    # scenario's own index. Check the length first so that a shape mismatch
    # fails with a clear message instead of a broadcasting comparison.
    # NOTE(review): assumes the forecast output exposes a "time" coordinate,
    # as the other forecast tests in this file do — confirm.
    forecast_times = forecasts.coords["time"].values
    assert len(forecast_times) == n_periods
    assert (forecast_times == scenario_df.index.values).all()
5 changes: 5 additions & 0 deletions tests/statespace/test_structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,11 @@ def test_autoregressive_model(order, rng):
@pytest.mark.parametrize("s", [10, 25, 50])
@pytest.mark.parametrize("innovations", [True, False])
@pytest.mark.parametrize("remove_first_state", [True, False])
@pytest.mark.filterwarnings(
"ignore:divide by zero encountered in matmul:RuntimeWarning",
"ignore:overflow encountered in matmul:RuntimeWarning",
"ignore:invalid value encountered in matmul:RuntimeWarning",
)
def test_time_seasonality(s, innovations, remove_first_state, rng):
def random_word(rng):
return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5))
Expand Down
Loading