Skip to content

Commit 99ab764

Browse files
committed
Add formula to stat_smooth
closes #311
1 parent ebb2c3f commit 99ab764

10 files changed

+236
-7
lines changed

doc/changelog.rst

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ New Features
1212
makes it easy to change the ordering of a discrete variable according
1313
to some other variable/column.
1414

15+
- :class:`~plotnine.stats.stat_smooth` can now use formulae for linear
16+
models.
17+
1518

1619
Bug Fixes
1720
*********

doc/conf.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,8 @@
410410
'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None),
411411
'sklearn': ('http://scikit-learn.org/stable/', None),
412412
'skmisc': ('https://has2k1.github.io/scikit-misc/', None),
413-
'adjustText': ('https://adjusttext.readthedocs.io/en/latest/', None)
413+
'adjustText': ('https://adjusttext.readthedocs.io/en/latest/', None),
414+
'patsy': ('https://patsy.readthedocs.io/en/stable', None)
414415
}
415416

416417

plotnine/stats/smoothers.py

+148-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import pandas as pd
66
import scipy.stats as stats
77
import statsmodels.api as sm
8+
import statsmodels.formula.api as smf
9+
from patsy import dmatrices
810

911
from ..exceptions import PlotnineError, PlotnineWarning
1012
from ..utils import get_valid_kwargs
@@ -47,17 +49,25 @@ def lm(data, xseq, **params):
4749
"""
4850
Fit OLS / WLS if data has weight
4951
"""
52+
if params['formula']:
53+
return lm_formula(data, xseq, **params)
54+
5055
X = sm.add_constant(data['x'])
5156
Xseq = sm.add_constant(xseq)
57+
weights = data.get('weights', None)
5258

53-
if 'weight' in data:
54-
init_kwargs, fit_kwargs = separate_method_kwargs(
55-
params['method_args'], sm.WLS, sm.WLS.fit)
56-
model = sm.WLS(data['y'], X, weights=data['weight'], **init_kwargs)
57-
else:
59+
if weights is None:
5860
init_kwargs, fit_kwargs = separate_method_kwargs(
5961
params['method_args'], sm.OLS, sm.OLS.fit)
6062
model = sm.OLS(data['y'], X, **init_kwargs)
63+
else:
64+
if np.any(weights < 0):
65+
raise ValueError(
66+
"All weights must be greater than zero."
67+
)
68+
init_kwargs, fit_kwargs = separate_method_kwargs(
69+
params['method_args'], sm.WLS, sm.WLS.fit)
70+
model = sm.WLS(data['y'], X, weights=data['weight'], **init_kwargs)
6171

6272
results = model.fit(**fit_kwargs)
6373
data = pd.DataFrame({'x': xseq})
@@ -74,10 +84,60 @@ def lm(data, xseq, **params):
7484
return data
7585

7686

87+
def lm_formula(data, xseq, **params):
    """
    Fit OLS / WLS using a formula

    A weighted fit (WLS) is used when ``data`` has a ``weight``
    column, otherwise an ordinary least squares fit (OLS).

    Parameters
    ----------
    data : dataframe
        Data handed to the statsmodels formula API. It is looked up
        for an optional ``weight`` column.
    xseq : array_like
        Values of ``x`` at which the fitted model is evaluated.
    **params : dict
        The keys read here are ``formula``, ``enviroment``,
        ``method_args``, ``se`` and (only when ``se`` is true)
        ``level``.

    Returns
    -------
    out : dataframe
        Columns ``x`` and ``y``, plus ``se``, ``ymin`` and ``ymax``
        when ``params['se']`` is true.

    Raises
    ------
    ValueError
        If any of the weights is negative.
    """
    formula = params['formula']
    # The key really is spelled 'enviroment'; it has to match the
    # name under which stat_smooth.setup_params stores it.
    eval_env = params['enviroment']
    weights = data.get('weight', None)

    if weights is None:
        init_kwargs, fit_kwargs = separate_method_kwargs(
            params['method_args'], sm.OLS, sm.OLS.fit)
        model = smf.ols(
            formula,
            data,
            eval_env=eval_env,
            **init_kwargs
        )
    else:
        # NOTE(review): only negative weights are rejected, a weight
        # of exactly zero passes this guard despite the message.
        if np.any(weights < 0):
            raise ValueError(
                "All weights must be greater than zero."
            )
        # Split the user's method_args against the model that is
        # actually built (WLS), as the non-formula lm() does.
        init_kwargs, fit_kwargs = separate_method_kwargs(
            params['method_args'], sm.WLS, sm.WLS.fit)
        model = smf.wls(
            formula,
            data,
            weights=weights,
            eval_env=eval_env,
            **init_kwargs
        )

    results = model.fit(**fit_kwargs)
    data = pd.DataFrame({'x': xseq})
    data['y'] = results.predict(data)

    if params['se']:
        # Rebuild the design matrix for the prediction grid; the
        # left hand side of the formula is evaluated as well, which
        # is why 'y' has to be in the frame by now.
        _, predictors = dmatrices(formula, data, eval_env=eval_env)
        alpha = 1 - params['level']
        prstd, iv_l, iv_u = wls_prediction_std(
            results, predictors, alpha=alpha)
        data['se'] = prstd
        data['ymin'] = iv_l
        data['ymax'] = iv_u
    return data
132+
133+
77134
def rlm(data, xseq, **params):
78135
"""
79136
Fit RLM
80137
"""
138+
if params['formula']:
139+
return rlm_formula(data, xseq, **params)
140+
81141
X = sm.add_constant(data['x'])
82142
Xseq = sm.add_constant(xseq)
83143

@@ -96,10 +156,38 @@ def rlm(data, xseq, **params):
96156
return data
97157

98158

159+
def rlm_formula(data, xseq, **params):
    """
    Fit RLM using a formula

    Parameters
    ----------
    data : dataframe
        Data handed to the statsmodels formula API.
    xseq : array_like
        Values of ``x`` at which the fitted model is evaluated.
    **params : dict
        The keys read here are ``enviroment``, ``formula``,
        ``method_args`` and ``se``.

    Returns
    -------
    out : dataframe
        Columns ``x`` and ``y``. No confidence interval columns are
        added; asking for them only triggers a warning.
    """
    # The key really is spelled 'enviroment'; it has to match the
    # name under which stat_smooth.setup_params stores it.
    eval_env = params['enviroment']
    formula = params['formula']
    init_kwargs, fit_kwargs = separate_method_kwargs(
        params['method_args'], sm.RLM, sm.RLM.fit)
    model = smf.rlm(
        formula,
        data,
        eval_env=eval_env,
        **init_kwargs
    )
    results = model.fit(**fit_kwargs)
    data = pd.DataFrame({'x': xseq})
    data['y'] = results.predict(data)

    if params['se']:
        # Adjacent string literals are concatenated verbatim, so the
        # first one must end with a space.
        warnings.warn("Confidence intervals are not yet implemented "
                      "for RLM smoothing.", PlotnineWarning)

    return data
182+
183+
99184
def gls(data, xseq, **params):
100185
"""
101186
Fit GLS
102187
"""
188+
if params['formula']:
189+
return gls_formula(data, xseq, **params)
190+
103191
X = sm.add_constant(data['x'])
104192
Xseq = sm.add_constant(xseq)
105193

@@ -122,10 +210,42 @@ def gls(data, xseq, **params):
122210
return data
123211

124212

213+
def gls_formula(data, xseq, **params):
    """
    Fit GLS using a formula

    Parameters
    ----------
    data : dataframe
        Data handed to the statsmodels formula API.
    xseq : array_like
        Values of ``x`` at which the fitted model is evaluated.
    **params : dict
        The keys read here are ``enviroment``, ``formula``,
        ``method_args``, ``se`` and (only when ``se`` is true)
        ``level``.

    Returns
    -------
    out : dataframe
        Columns ``x`` and ``y``, plus ``se``, ``ymin`` and ``ymax``
        when ``params['se']`` is true.
    """
    env = params['enviroment']
    spec = params['formula']
    ctor_kwargs, fitting_kwargs = separate_method_kwargs(
        params['method_args'], sm.GLS, sm.GLS.fit)
    fitted = smf.gls(
        spec, data, eval_env=env, **ctor_kwargs
    ).fit(**fitting_kwargs)

    out = pd.DataFrame({'x': xseq})
    out['y'] = fitted.predict(out)

    # Nothing more to compute unless an interval was requested
    if not params['se']:
        return out

    _, design = dmatrices(spec, out, eval_env=env)
    std, lower, upper = wls_prediction_std(
        fitted, design, alpha=1 - params['level'])
    out['se'] = std
    out['ymin'] = lower
    out['ymax'] = upper
    return out
240+
241+
125242
def glm(data, xseq, **params):
126243
"""
127244
Fit GLM
128245
"""
246+
if params['formula']:
247+
return glm_formula(data, xseq, **params)
248+
129249
X = sm.add_constant(data['x'])
130250
Xseq = sm.add_constant(xseq)
131251

@@ -146,6 +266,29 @@ def glm(data, xseq, **params):
146266
return data
147267

148268

269+
def glm_formula(data, xseq, **params):
    """
    Fit GLM using a formula

    Parameters
    ----------
    data : dataframe
        Data handed to the statsmodels formula API.
    xseq : array_like
        Values of ``x`` at which the fitted model is evaluated.
    **params : dict
        The keys read here are ``enviroment``, ``method_args``,
        ``formula``, ``se`` and (only when ``se`` is true) ``level``.

    Returns
    -------
    out : dataframe
        Columns ``x`` and ``y``, plus ``ymin`` and ``ymax`` when
        ``params['se']`` is true. No ``se`` column is added.
    """
    env = params['enviroment']
    ctor_kwargs, fitting_kwargs = separate_method_kwargs(
        params['method_args'], sm.GLM, sm.GLM.fit)
    fitted = smf.glm(
        params['formula'], data, eval_env=env, **ctor_kwargs
    ).fit(**fitting_kwargs)

    out = pd.DataFrame({'x': xseq})
    out['y'] = fitted.predict(out)

    if params['se']:
        # The interval is computed on a fresh frame holding only 'x'
        grid = pd.DataFrame({'x': xseq})
        bounds = fitted.get_prediction(grid).conf_int(
            1 - params['level'])
        out['ymin'] = bounds[:, 0]
        out['ymax'] = bounds[:, 1]
    return out
290+
291+
149292
def lowess(data, xseq, **params):
150293
for k in ('is_sorted', 'return_sorted'):
151294
with suppress(KeyError):

plotnine/stats/stat_smooth.py

+16
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ def my_smoother(data, xseq, **params):
7676
data['ymax'] = high
7777
7878
return data
79+
formula : formula_like
80+
An object that can be used to construct a patsy design matrix.
81+
This is usually a string. You can only use a formula if ``method``
82+
is one of *lm*, *ols*, *wls*, *glm*, *rlm* or *gls*, and in the
83+
:ref:`formula <patsy:formulas>` you may refer to the ``x`` and
84+
``y`` aesthetic variables.
7985
se : bool (default: True)
8086
If :py:`True` draw confidence interval around the smooth line.
8187
n : int (default: 80)
@@ -131,6 +137,7 @@ def my_smoother(data, xseq, **params):
131137
DEFAULT_PARAMS = {'geom': 'smooth', 'position': 'identity',
132138
'na_rm': False,
133139
'method': 'auto', 'se': True, 'n': 80,
140+
'formula': None,
134141
'fullrange': False, 'level': 0.95,
135142
'span': 0.75, 'method_args': {}}
136143
CREATES = {'se', 'ymin', 'ymax'}
@@ -168,6 +175,15 @@ def setup_params(self, data):
168175
"facets".format(window), PlotnineWarning)
169176
params['method_args']['window'] = window
170177

178+
if params['formula']:
179+
allowed = {'lm', 'ols', 'wls', 'glm', 'rlm', 'gls'}
180+
if params['method'] not in allowed:
181+
raise ValueError(
182+
"You can only use a formula with `method` is "
183+
"one of {}".format(allowed)
184+
)
185+
params['enviroment'] = self.environment
186+
171187
return params
172188

173189
@classmethod
Loading
Loading
Loading
Loading
Loading

plotnine/tests/test_geom_smooth.py

+67-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import statsmodels.api as sm
55

66

7-
from plotnine import ggplot, aes, geom_point, geom_smooth
7+
from plotnine import ggplot, aes, geom_point, geom_smooth, stat_smooth
88
from plotnine.exceptions import PlotnineWarning
99

1010

@@ -185,3 +185,69 @@ def test_init_and_fit_kwargs():
185185
)
186186

187187
assert p == 'init_and_fit_kwargs'
188+
189+
190+
# Fixture for the formula tests: sin(x) plus seeded gaussian noise.
# `x` is expected to be defined earlier in this module.
n = 100
random_state = np.random.RandomState(123)
mu, sigma = 0, 0.065
noise = mu + sigma * random_state.randn(n)
df = pd.DataFrame({'x': x, 'y': np.sin(x) + noise})
199+
200+
201+
class TestFormula:
    """
    stat_smooth with a formula, one test per supported method

    Every test adds a smoother to the shared scatter plot and
    compares the result with a name using ``==``.
    """
    # Scatter layer shared by all the tests
    p = ggplot(df, aes('x', 'y')) + geom_point()

    def test_lm(self):
        smooth = stat_smooth(
            method='lm', formula='y ~ np.sin(x)', fill='red', se=True)
        assert self.p + smooth == 'lm_formula'

    def test_lm_weights(self):
        weighting = aes(weight='x.abs()')
        smooth = stat_smooth(
            method='lm', formula='y ~ np.sin(x)', fill='red', se=True)
        assert self.p + weighting + smooth == 'lm_formula_weights'

    def test_glm(self):
        smooth = stat_smooth(
            method='glm', formula='y ~ np.sin(x)', fill='red', se=True)
        assert self.p + smooth == 'glm_formula'

    def test_rlm(self):
        # No `se` argument here, the default is used
        smooth = stat_smooth(
            method='rlm', formula='y ~ np.sin(x)', fill='red')
        assert self.p + smooth == 'rlm_formula'

    def test_gls(self):
        smooth = stat_smooth(
            method='gls', formula='y ~ np.sin(x)', fill='red', se=True)
        assert self.p + smooth == 'gls_formula'

0 commit comments

Comments
 (0)