Skip to content

Commit ca1a86e

Browse files
De-duplicate exogenous dim between DFM and SARIMAX
1 parent 1c35fc7 commit ca1a86e

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

pymc_extras/statespace/models/SARIMAX.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
ALL_STATE_AUX_DIM,
1818
ALL_STATE_DIM,
1919
AR_PARAM_DIM,
20-
EXOGENOUS_DIM,
20+
EXOG_STATE_DIM,
2121
MA_PARAM_DIM,
2222
OBS_STATE_DIM,
2323
SARIMAX_STATE_STRUCTURES,
@@ -315,7 +315,7 @@ def param_names(self):
315315
def data_info(self) -> dict[str, dict[str, Any]]:
316316
info = {
317317
"exogenous_data": {
318-
"dims": (TIME_DIM, EXOGENOUS_DIM),
318+
"dims": (TIME_DIM, EXOG_STATE_DIM),
319319
"shape": (None, self.k_exog),
320320
}
321321
}
@@ -403,7 +403,7 @@ def param_dims(self):
403403
"ma_params": (MA_PARAM_DIM,),
404404
"seasonal_ar_params": (SEASONAL_AR_PARAM_DIM,),
405405
"seasonal_ma_params": (SEASONAL_MA_PARAM_DIM,),
406-
"beta_exog": (EXOGENOUS_DIM,),
406+
"beta_exog": (EXOG_STATE_DIM,),
407407
}
408408
if self.k_endog == 1:
409409
coord_map["sigma_state"] = None
@@ -438,7 +438,7 @@ def coords(self) -> dict[str, Sequence]:
438438
if self.Q > 0:
439439
coords.update({SEASONAL_MA_PARAM_DIM: list(range(1, self.Q + 1))})
440440
if self.k_exog > 0:
441-
coords.update({EXOGENOUS_DIM: self.exog_state_names})
441+
coords.update({EXOG_STATE_DIM: self.exog_state_names})
442442
return coords
443443

444444
def _stationary_initialization(self):

pymc_extras/statespace/utils/constants.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
ETS_SEASONAL_DIM = "seasonal_lag"
1515
FACTOR_DIM = "factor"
1616
ERROR_AR_PARAM_DIM = "error_lag_ar"
17-
EXOG_STATE_DIM = "exogenous_state"
18-
EXOGENOUS_DIM = "exogenous"
17+
EXOG_STATE_DIM = "exogenous"
1918

2019
NEVER_TIME_VARYING = ["initial_state", "initial_state_cov", "a0", "P0"]
2120
VECTOR_VALUED = ["initial_state", "state_intercept", "obs_intercept", "a0", "c", "d"]

0 commit comments

Comments
 (0)