Skip to content

Commit cd2595b

Browse files
authored
🔖 0.8.6 (#134)
1 parent 4f9c2a3 commit cd2595b

9 files changed

+97
-44
lines changed

datar/__init__.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
)
1414

1515
__all__ = ("f", "get_versions")
16-
__version__ = "0.8.5"
16+
__version__ = "0.8.6"
1717

1818
apply_init_callbacks()
1919

datar/base/arithmetic.py

+35-10
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""Arithmetic or math functions"""
22

3+
from functools import singledispatch
34
import inspect
45
from typing import TYPE_CHECKING, Union
56

@@ -883,18 +884,42 @@ def std(
883884
sd = std
884885

885886

886-
@func_factory("transform", {"x", "w"})
887-
def weighted_mean(
888-
x: Series, w: Series = 1, na_rm=True, __args_raw=None
889-
) -> Series:
890-
"""Calculate weighted mean"""
891-
if __args_raw["w"] is not None and np.nansum(w) == 0:
887+
@singledispatch
888+
def _weighted_mean(
889+
df: DataFrame,
890+
has_w: bool = True,
891+
na_rm: bool = True,
892+
) -> np.ndarray:
893+
if not has_w:
894+
return np.nanmean(df["x"]) if na_rm else np.mean(df["x"])
895+
896+
if np.nansum(df["w"]) == 0:
892897
return np.nan
893898

894899
if na_rm:
895-
na_mask = pd.isnull(x)
896-
x = x[~na_mask.values]
897-
w = w[~na_mask.values]
900+
na_mask = pd.isnull(df["x"])
901+
x = df["x"][~na_mask.values]
902+
w = df["w"][~na_mask.values]
898903
return np.average(x, weights=w)
899904

900-
return np.average(x, weights=w)
905+
return np.average(df["x"], weights=df["w"])
906+
907+
908+
@_weighted_mean.register(TibbleGrouped)
909+
def _(
910+
df: TibbleGrouped,
911+
has_w: bool = True,
912+
na_rm: bool = True,
913+
) -> Series:
914+
return df._datar["grouped"].apply(
915+
lambda subdf: _weighted_mean(subdf, has_w, na_rm)
916+
)
917+
918+
919+
@func_factory(None, {"x", "w"})
920+
def weighted_mean(
921+
x: Series, w: Series = 1, na_rm=True, __args_raw=None, __args_frame=None,
922+
) -> Series:
923+
"""Calculate weighted mean"""
924+
has_w = __args_raw["w"] is not None
925+
return _weighted_mean(__args_frame, has_w, na_rm)

datar/base/verbs.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -234,7 +234,7 @@ def union(x, y):
234234

235235
@register_verb(context=Context.EVAL)
236236
def unique(x):
237-
"""Union of two iterables"""
237+
"""Get unique elements from an iterable and keep their order"""
238238
# order is not kept by np.unique (it returns sorted values), so it is not used:
239239
# return np.unique(x)
240240
if is_scalar(x):

datar/dplyr/distinct.py

+42-23
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
See source https://github.com/tidyverse/dplyr/blob/master/R/distinct.R
44
"""
55
from pipda import register_verb
6+
from pipda.symbolic import Reference
67

78
from ..core.backends.pandas import DataFrame
89
from ..core.backends.pandas.core.groupby import GroupBy
@@ -11,7 +12,7 @@
1112
from ..core.factory import func_factory
1213
from ..core.utils import regcall
1314
from ..core.tibble import Tibble, TibbleGrouped, reconstruct_tibble
14-
from ..base import union, setdiff, intersect
15+
from ..base import union, setdiff, intersect, unique
1516
from .mutate import mutate
1617

1718

@@ -33,31 +34,49 @@ def distinct(_data, *args, _keep_all=False, **kwargs):
3334
A dataframe without duplicated rows in _data
3435
"""
3536
if not args and not kwargs:
36-
uniq = _data.drop_duplicates()
37+
out = _data.drop_duplicates()
3738
else:
38-
# keep_none_prefers_new_order
39-
uniq = (
40-
regcall(
41-
mutate,
42-
_data,
43-
*args,
44-
**kwargs,
45-
_keep="none",
39+
if (
40+
not kwargs
41+
# optimize:
42+
# iris >> distinct(f.Species, f.Sepal_Length)
43+
# We don't need to do mutation
44+
and all(
45+
isinstance(expr, Reference)
46+
and expr._pipda_level == 1
47+
and expr._pipda_ref in _data.columns
48+
for expr in args
4649
)
47-
).drop_duplicates()
50+
):
51+
subset = [expr._pipda_ref for expr in args]
52+
ucols = getattr(_data, "group_vars", [])
53+
ucols.extend(subset)
54+
ucols = regcall(unique, ucols)
55+
uniq = _data.drop_duplicates(subset=subset)[ucols]
56+
else:
57+
# keep_none_prefers_new_order
58+
uniq = (
59+
regcall(
60+
mutate,
61+
_data,
62+
*args,
63+
**kwargs,
64+
_keep="none",
65+
)
66+
).drop_duplicates()
4867

49-
if not _keep_all:
50-
# keep original order
51-
out = uniq[
52-
regcall(
53-
union,
54-
regcall(intersect, _data.columns, uniq.columns),
55-
regcall(setdiff, uniq.columns, _data.columns),
56-
)
57-
]
58-
else:
59-
out = _data.loc[uniq.index, :].copy()
60-
out[uniq.columns.tolist()] = uniq
68+
if not _keep_all:
69+
# keep original order
70+
out = uniq[
71+
regcall(
72+
union,
73+
regcall(intersect, _data.columns, uniq.columns),
74+
regcall(setdiff, uniq.columns, _data.columns),
75+
)
76+
]
77+
else:
78+
out = _data.loc[uniq.index, :].copy()
79+
out[uniq.columns.tolist()] = uniq
6180

6281
return reconstruct_tibble(_data, Tibble(out, copy=False))
6382

docs/CHANGELOG.md

+6
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,9 @@
1+
## 0.8.6
2+
3+
- 🐛 Fix weighted_mean not working for grouped data (#133)
4+
- ✅ Add tests for weighted_mean on grouped data
5+
- ⚡️ Optimize distinct on existing columns (#128)
6+
17
## 0.8.5
28

39
- 🐛 Fix columns missing after Join by same columns using mapping (#122)

docs/requirements.txt

+4-7
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,9 @@
11
# use_directory_urls doesn't work for newer versions
2-
mkdocs==1.1.2
3-
# AttributeError: module 'jinja2' has no attribute 'contextfilter'
4-
# jinja2==3.1.0
5-
jinja2==3.0.3
6-
mkdocs-material==7.2.3
7-
pymdown-extensions==8.2
2+
mkdocs
3+
mkdocs-material
4+
pymdown-extensions
85
mkapi-fix
9-
mkdocs-jupyter==0.17.3
6+
mkdocs-jupyter
107
ipykernel
118
ipython_genutils
129
# to compile readme.ipynb

pyproject.toml

+1-1
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "datar"
3-
version = "0.8.5"
3+
version = "0.8.6"
44
description = "Port of dplyr and other related R packages in python, using pipda."
55
authors = ["pwwang <[email protected]>"]
66
readme = "README.md"

tests/base/test_stats.py

+5
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,11 @@ def test_weighted_mean():
1616
with pytest.raises(ValueError):
1717
weighted_mean([1,2], [1,2,3])
1818

19+
df = tibble(g=[1, 1, 2, 2], x=[1, 2, 3, 4], w=[1, 3, 3, 3]).group_by('g')
20+
assert weighted_mean(df.g.obj, w=None) == 1.5
21+
assert_iterable_equal(weighted_mean(df.g), [1, 2])
22+
assert_iterable_equal(weighted_mean(df.x, w=df.w), [1.75, 3.5])
23+
1924

2025
def test_quantile():
2126
df = tibble(x=[1, 2, 3], g=[1, 2, 2])

tests/dplyr/test_distinct.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
)
2424
from datar.tibble import tibble
2525
from datar.datasets import iris
26+
from datar.testing import assert_frame_equal
2627

2728

2829
def test_single_column():
@@ -51,7 +52,7 @@ def test_keeps_only_specified_cols():
5152
df = tibble(x=c(1, 1, 1), y=c(1, 1, 1))
5253
expect = tibble(x=1)
5354
out = df >> distinct(f.x)
54-
assert out.equals(expect)
55+
assert_frame_equal(out, expect)
5556

5657

5758
def test_unless_keep_all_true():

0 commit comments

Comments
 (0)