5
5
import pandas as pd
6
6
import scipy .stats as stats
7
7
import statsmodels .api as sm
8
+ import statsmodels .formula .api as smf
9
+ from patsy import dmatrices
8
10
9
11
from ..exceptions import PlotnineError , PlotnineWarning
10
12
from ..utils import get_valid_kwargs
@@ -47,17 +49,25 @@ def lm(data, xseq, **params):
47
49
"""
48
50
Fit OLS / WLS if data has weight
49
51
"""
52
+ if params ['formula' ]:
53
+ return lm_formula (data , xseq , ** params )
54
+
50
55
X = sm .add_constant (data ['x' ])
51
56
Xseq = sm .add_constant (xseq )
57
+ weights = data .get ('weights' , None )
52
58
53
- if 'weight' in data :
54
- init_kwargs , fit_kwargs = separate_method_kwargs (
55
- params ['method_args' ], sm .WLS , sm .WLS .fit )
56
- model = sm .WLS (data ['y' ], X , weights = data ['weight' ], ** init_kwargs )
57
- else :
59
+ if weights is None :
58
60
init_kwargs , fit_kwargs = separate_method_kwargs (
59
61
params ['method_args' ], sm .OLS , sm .OLS .fit )
60
62
model = sm .OLS (data ['y' ], X , ** init_kwargs )
63
+ else :
64
+ if np .any (weights < 0 ):
65
+ raise ValueError (
66
+ "All weights must be greater than zero."
67
+ )
68
+ init_kwargs , fit_kwargs = separate_method_kwargs (
69
+ params ['method_args' ], sm .WLS , sm .WLS .fit )
70
+ model = sm .WLS (data ['y' ], X , weights = data ['weight' ], ** init_kwargs )
61
71
62
72
results = model .fit (** fit_kwargs )
63
73
data = pd .DataFrame ({'x' : xseq })
@@ -74,10 +84,60 @@ def lm(data, xseq, **params):
74
84
return data
75
85
76
86
87
+ def lm_formula (data , xseq , ** params ):
88
+ """
89
+ Fit OLS / WLS using a formula
90
+ """
91
+ formula = params ['formula' ]
92
+ eval_env = params ['enviroment' ]
93
+ weights = data .get ('weight' , None )
94
+
95
+ if weights is None :
96
+ init_kwargs , fit_kwargs = separate_method_kwargs (
97
+ params ['method_args' ], sm .OLS , sm .OLS .fit )
98
+ model = smf .ols (
99
+ formula ,
100
+ data ,
101
+ eval_env = eval_env ,
102
+ ** init_kwargs
103
+ )
104
+ else :
105
+ if np .any (weights < 0 ):
106
+ raise ValueError (
107
+ "All weights must be greater than zero."
108
+ )
109
+ init_kwargs , fit_kwargs = separate_method_kwargs (
110
+ params ['method_args' ], sm .OLS , sm .OLS .fit )
111
+ model = smf .wls (
112
+ formula ,
113
+ data ,
114
+ weights = weights ,
115
+ eval_env = eval_env ,
116
+ ** init_kwargs
117
+ )
118
+
119
+ results = model .fit (** fit_kwargs )
120
+ data = pd .DataFrame ({'x' : xseq })
121
+ data ['y' ] = results .predict (data )
122
+
123
+ if params ['se' ]:
124
+ _ , predictors = dmatrices (formula , data , eval_env = eval_env )
125
+ alpha = 1 - params ['level' ]
126
+ prstd , iv_l , iv_u = wls_prediction_std (
127
+ results , predictors , alpha = alpha )
128
+ data ['se' ] = prstd
129
+ data ['ymin' ] = iv_l
130
+ data ['ymax' ] = iv_u
131
+ return data
132
+
133
+
77
134
def rlm (data , xseq , ** params ):
78
135
"""
79
136
Fit RLM
80
137
"""
138
+ if params ['formula' ]:
139
+ return rlm_formula (data , xseq , ** params )
140
+
81
141
X = sm .add_constant (data ['x' ])
82
142
Xseq = sm .add_constant (xseq )
83
143
@@ -96,10 +156,38 @@ def rlm(data, xseq, **params):
96
156
return data
97
157
98
158
159
+ def rlm_formula (data , xseq , ** params ):
160
+ """
161
+ Fit RLM using a formula
162
+ """
163
+ eval_env = params ['enviroment' ]
164
+ formula = params ['formula' ]
165
+ init_kwargs , fit_kwargs = separate_method_kwargs (
166
+ params ['method_args' ], sm .RLM , sm .RLM .fit )
167
+ model = smf .rlm (
168
+ formula ,
169
+ data ,
170
+ eval_env = eval_env ,
171
+ ** init_kwargs
172
+ )
173
+ results = model .fit (** fit_kwargs )
174
+ data = pd .DataFrame ({'x' : xseq })
175
+ data ['y' ] = results .predict (data )
176
+
177
+ if params ['se' ]:
178
+ warnings .warn ("Confidence intervals are not yet implemented"
179
+ "for RLM smoothing." , PlotnineWarning )
180
+
181
+ return data
182
+
183
+
99
184
def gls (data , xseq , ** params ):
100
185
"""
101
186
Fit GLS
102
187
"""
188
+ if params ['formula' ]:
189
+ return gls_formula (data , xseq , ** params )
190
+
103
191
X = sm .add_constant (data ['x' ])
104
192
Xseq = sm .add_constant (xseq )
105
193
@@ -122,10 +210,42 @@ def gls(data, xseq, **params):
122
210
return data
123
211
124
212
213
+ def gls_formula (data , xseq , ** params ):
214
+ """
215
+ Fit GLL using a formula
216
+ """
217
+ eval_env = params ['enviroment' ]
218
+ formula = params ['formula' ]
219
+ init_kwargs , fit_kwargs = separate_method_kwargs (
220
+ params ['method_args' ], sm .GLS , sm .GLS .fit )
221
+ model = smf .gls (
222
+ formula ,
223
+ data ,
224
+ eval_env = eval_env ,
225
+ ** init_kwargs
226
+ )
227
+ results = model .fit (** fit_kwargs )
228
+ data = pd .DataFrame ({'x' : xseq })
229
+ data ['y' ] = results .predict (data )
230
+
231
+ if params ['se' ]:
232
+ _ , predictors = dmatrices (formula , data , eval_env = eval_env )
233
+ alpha = 1 - params ['level' ]
234
+ prstd , iv_l , iv_u = wls_prediction_std (
235
+ results , predictors , alpha = alpha )
236
+ data ['se' ] = prstd
237
+ data ['ymin' ] = iv_l
238
+ data ['ymax' ] = iv_u
239
+ return data
240
+
241
+
125
242
def glm (data , xseq , ** params ):
126
243
"""
127
244
Fit GLM
128
245
"""
246
+ if params ['formula' ]:
247
+ return glm_formula (data , xseq , ** params )
248
+
129
249
X = sm .add_constant (data ['x' ])
130
250
Xseq = sm .add_constant (xseq )
131
251
@@ -146,6 +266,29 @@ def glm(data, xseq, **params):
146
266
return data
147
267
148
268
269
+ def glm_formula (data , xseq , ** params ):
270
+ eval_env = params ['enviroment' ]
271
+ init_kwargs , fit_kwargs = separate_method_kwargs (
272
+ params ['method_args' ], sm .GLM , sm .GLM .fit )
273
+ model = smf .glm (
274
+ params ['formula' ],
275
+ data ,
276
+ eval_env = eval_env ,
277
+ ** init_kwargs
278
+ )
279
+ results = model .fit (** fit_kwargs )
280
+ data = pd .DataFrame ({'x' : xseq })
281
+ data ['y' ] = results .predict (data )
282
+
283
+ if params ['se' ]:
284
+ df = pd .DataFrame ({'x' : xseq })
285
+ prediction = results .get_prediction (df )
286
+ ci = prediction .conf_int (1 - params ['level' ])
287
+ data ['ymin' ] = ci [:, 0 ]
288
+ data ['ymax' ] = ci [:, 1 ]
289
+ return data
290
+
291
+
149
292
def lowess (data , xseq , ** params ):
150
293
for k in ('is_sorted' , 'return_sorted' ):
151
294
with suppress (KeyError ):
0 commit comments