Skip to content

Commit

Permalink
remove giskard dep
Browse files Browse the repository at this point in the history
  • Loading branch information
bdpedigo committed May 19, 2022
1 parent 50aa85f commit 3dad199
Show file tree
Hide file tree
Showing 6 changed files with 604 additions and 602 deletions.
2 changes: 1 addition & 1 deletion pkg/pkg/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .io import savefig, OUT_PATH, FIG_PATH, get_out_dir, glue
from .io import FIG_PATH, OUT_PATH, get_out_dir, glue, savefig
3 changes: 2 additions & 1 deletion pkg/pkg/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .palette import method_palette, subgraph_palette
from .scatter import matched_stripplot
from .theme import set_theme
from .utils import (
bound_texts,
Expand All @@ -12,4 +14,3 @@
remove_shared_ax,
shrink_axis,
)
from .palette import method_palette, subgraph_palette
55 changes: 55 additions & 0 deletions pkg/pkg/plot/scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def matched_stripplot(
data,
x=None,
y=None,
jitter=0.2,
hue=None,
match=None,
ax=None,
matchline_kws=None,
order=None,
**kwargs,
):
data = data.copy()
if ax is None:
ax = plt.gca()

if order is None:
unique_x_var = data[x].unique()
else:
unique_x_var = order
ind_map = dict(zip(unique_x_var, range(len(unique_x_var))))
data["x"] = data[x].map(ind_map)
if match is not None:
groups = data.groupby(match)
for _, group in groups:
perturb = np.random.uniform(-jitter, jitter)
data.loc[group.index, "x"] += perturb
else:
data["x"] += np.random.uniform(-jitter, jitter, len(data))

sns.scatterplot(data=data, x="x", y=y, hue=hue, ax=ax, zorder=1, **kwargs)

if match is not None:
unique_match_var = data[match].unique()
fake_palette = dict(zip(unique_match_var, len(unique_match_var) * ["black"]))
if matchline_kws is None:
matchline_kws = dict(alpha=0.2, linewidth=1)
sns.lineplot(
data=data,
x="x",
y=y,
hue=match,
ax=ax,
legend=False,
palette=fake_palette,
zorder=-1,
**matchline_kws,
)
ax.set(xlabel=x, xticks=np.arange(len(unique_x_var)), xticklabels=unique_x_var)
return ax
Loading

0 comments on commit 3dad199

Please sign in to comment.