Skip to content

Commit 836261f

Browse files
authored
Merge pull request #58 from Renumics/feature/huggingface-integration
Feature/huggingface integration
2 parents 957f70a + ca0731d commit 836261f

7 files changed

Lines changed: 284 additions & 20 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ dependencies = [
2121
"scikit-learn >= 1.2.2",
2222
"umap-learn >= 0.5.3",
2323
"tqdm >= 4.65.0",
24-
"renumics-spotlight == 1.4.0rc2",
24+
"renumics-spotlight >= 1.5.3",
2525
"datasets >= 2.13.1",
26+
"puremagic >= 1.15"
2627
]
2728

2829
[project.optional-dependencies]

sliceguard/data.py

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from typing import List, Optional
1+
from os import rename
2+
from typing import List
23
from pathlib import Path
34
import pandas as pd
45
import datasets
5-
from datasets import Image, ClassLabel, Value, Sequence
6-
6+
from datasets import Image, Audio, ClassLabel, Value, Sequence
7+
import uuid
8+
import puremagic
79

810
def _get_tutorial_imports():
911
try:
@@ -15,43 +17,94 @@ def _get_tutorial_imports():
1517
return downloader
1618

1719

18-
def from_huggingface(dataset_identifier: str):
20+
def write_file(data: dict, suffix: str, data_dir: str):
21+
with open(f"{data_dir}/{uuid.uuid4().hex}{suffix}", "wb") as tmp:
22+
tmp.write(data["bytes"])
23+
return tmp.name
24+
25+
26+
def convert_data(data: dict, data_dir: str):
27+
"""
28+
Prefer raw data over path
29+
"""
30+
if "bytes" in data and data['bytes'] is not None:
31+
if len(data['bytes']) > 0:
32+
suffix = puremagic.from_string(data['bytes'])
33+
return write_file(data, suffix, data_dir)
34+
35+
if "path" in data and data['path'] is not None:
36+
if data['path'] != "":
37+
suffix = puremagic.from_file(data['path'])
38+
new_path = f"{data['path']}{suffix}"
39+
40+
# In case of missing file extension
41+
rename(data['path'], new_path)
42+
43+
return new_path
44+
45+
46+
# Tested with the following datasets:
47+
# Image:
48+
# "mnist"
49+
# "ceyda/smithsonian_butterflies"
50+
# "GabrielVidal/dead-by-daylight-perks"
51+
52+
# Audio:
53+
# "437aewuh/dog-dataset"
54+
# "Gae8J/modeling"
55+
# "ccmusic-database/piano_sound_quality"
56+
57+
# Text:
58+
# "xtreme", "XNLI"
59+
# "indonlp/indonlu", "smsa"
60+
# "tweet_eval", "emoji"
61+
62+
63+
def from_huggingface(dataset_identifier: str, name=None, split=None, extract_dir="./sliceguard_tmp"):
1964
# Simple utility method to support loading of huggingface datasets
20-
# Currently only supports image data. Use custom load function if you need something else.
21-
dataset = datasets.load_dataset(dataset_identifier)
65+
dataset = datasets.load_dataset(dataset_identifier, name, split)
2266
overall_df = None
67+
68+
# Create missing directories if non-existent
69+
Path(extract_dir).mkdir(parents=True, exist_ok=True)
70+
71+
# Iterate splits in dataset.
2372
for split in dataset.keys():
2473
cur_split = dataset[split]
2574

2675
split_df = dataset[split].to_pandas()
2776
split_df["split"] = split
2877

78+
# Create a dataframe from each split.
2979
for fname, ftype in cur_split.features.items():
3080
if (
3181
not isinstance(ftype, Image)
82+
and not isinstance(ftype, Audio)
3283
and not isinstance(ftype, ClassLabel)
3384
and not isinstance(ftype, Value)
85+
and not isinstance(ftype, list)
3486
and not isinstance(ftype, Sequence)
3587
):
3688
raise RuntimeError(
3789
f"Found unsupported datatype {ftype}. Use custom load function."
3890
)
91+
92+
if isinstance(ftype, list):
93+
split_df = split_df.drop(columns=fname)
94+
print(
95+
f"Column {fname} with type {ftype} dropped. Lists are currently not supported."
96+
)
97+
3998
# Run transformations for specific data types if needed.
4099
if isinstance(ftype, ClassLabel):
41100
class_label_lookup = {i: l for i, l in enumerate(ftype.names)}
42101
split_df[fname] = split_df[fname].map(lambda x: class_label_lookup[x])
43102

44-
if isinstance(ftype, Image):
45-
all_has_paths = all(
46-
x is not None and "path" in x for x in split_df[fname].values
47-
)
48-
if not all_has_paths:
49-
print(
50-
f"Column {fname} dropped. Images are not extracted onto harddrive. Currently this is not supported."
51-
)
52-
split_df = split_df.drop(columns=fname)
103+
if isinstance(ftype, Image) or isinstance(ftype, Audio):
104+
if any(x is None for x in split_df[fname].values):
105+
print("Column {fname} dropped due to None-type entries.")
53106
else:
54-
split_df[fname] = split_df[fname].map(lambda x: x["path"])
107+
split_df[fname] = split_df[fname].map(lambda x: convert_data(x, extract_dir))
55108

56109
if overall_df is None:
57110
overall_df = split_df

sliceguard/detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def detect_issues(
322322
marked_issue_idx = 0
323323
for idx, row in all_groups_df.iterrows():
324324
if row["issue"] == True:
325-
group_dfs[int(row["level"])].loc[idx] = True
325+
group_dfs[int(row["level"])].loc[idx, "issue"] = True
326326

327327
marked_issue_idx += 1
328328
if n_slices is not None and marked_issue_idx >= n_slices:

sliceguard/explanation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,13 @@ def explain_clusters(features, feature_types, issues, df, prereduced_embeddings)
103103
predicate["maximum"] = val.max()
104104
elif feature_type in ["nominal", "ordinal"]:
105105
val = df[f].iloc[issue_rows]
106-
predicate["mode"] = val.mode()[0]
106+
feature_mode = val.mode()
107+
if len(feature_mode) == 0:
108+
predicate["mode"] = "no mode"
109+
elif len(feature_mode == 1):
110+
predicate["mode"] = feature_mode[0]
111+
else:
112+
raise RuntimeError("Invalid value encountered when calculating feature mode.")
107113
predicates_list.append(predicate)
108114

109115
issue["explanation"] = predicates_list

sliceguard/sliceguard.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
import warnings
33
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
44

5+
# Ignore warnings caused by dependency umap-learn
56
warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
67
warnings.simplefilter("ignore", category=NumbaPendingDeprecationWarning)
8+
# For now ignore warnings caused by dependency fairlearn. Remove once they address Pandas 2.0
9+
warnings.simplefilter(action='ignore', category=FutureWarning)
710

811
# Real imports
912
from uuid import uuid4

sliceguard/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,6 @@ def encode_normalize_features(
282282
num_dimensions,
283283
), # TODO: Do not hardcode this, probably determine based on embedding size and variance. Also, check implications on normalization.
284284
# min_dist=0.0,
285-
random_state=42,
286285
set_op_mix_ratio=op_mix_ratio_prereduction,
287286
).fit_transform(embeddings)
288287

tests/test_huggingface.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import os
2+
import uuid
3+
import shutil
4+
from pathlib import Path
5+
from urllib.parse import urlparse
6+
7+
from sklearn.metrics import accuracy_score
8+
import requests
9+
import pandas as pd
10+
import matplotlib.pyplot as plt
11+
import numpy as np
12+
from jiwer import wer
13+
import datasets
14+
from renumics.spotlight import Image, Audio
15+
from sliceguard import data
16+
17+
from sliceguard import SliceGuard
18+
19+
20+
def wer_metric(y_true, y_pred):
21+
return np.mean([wer(s_y, s_pred) for s_y, s_pred in zip(y_true, y_pred)])
22+
23+
24+
def test_huggingface_mnist():
25+
df = data.from_huggingface("mnist")
26+
27+
sg = SliceGuard()
28+
issue_df = sg.find_issues(
29+
df.sample(100),
30+
["image"],
31+
y="label",
32+
metric=accuracy_score,
33+
metric_mode="max",
34+
min_support=10,
35+
min_drop=0.08,
36+
)
37+
38+
sg.report(spotlight_dtype={"image_path": Image})
39+
40+
41+
def test_huggingface_butterflies():
42+
df = data.from_huggingface("ceyda/smithsonian_butterflies")
43+
44+
sg = SliceGuard()
45+
issue_df = sg.find_issues(
46+
df,
47+
["image"],
48+
y="scientific_name",
49+
metric=accuracy_score,
50+
metric_mode="max",
51+
min_support=10,
52+
min_drop=0.08,
53+
automl_train_split="train",
54+
automl_task="classification",
55+
automl_time_budget=40.0,
56+
)
57+
58+
sg.report(spotlight_dtype={"image_path": Image})
59+
60+
61+
def test_huggingface_dead_by_daylight_perks():
62+
df = data.from_huggingface("GabrielVidal/dead-by-daylight-perks")
63+
64+
sg = SliceGuard()
65+
issue_df = sg.find_issues(
66+
df,
67+
["image"],
68+
y="type",
69+
metric=accuracy_score,
70+
metric_mode="max",
71+
min_support=10,
72+
min_drop=0.08,
73+
automl_train_split="train",
74+
automl_task="classification",
75+
# automl_use_full_embeddings=True,
76+
automl_time_budget=40.0,
77+
)
78+
79+
sg.report(spotlight_dtype={"image_path": Image})
80+
81+
82+
def test_huggingface_dog_dataset():
83+
df = data.from_huggingface("437aewuh/dog-dataset")
84+
85+
sg = SliceGuard()
86+
issue_df = sg.find_issues(
87+
df.sample(200),
88+
["audio"],
89+
"label",
90+
metric=accuracy_score,
91+
metric_mode="max",
92+
embedding_models={"path": "superb/wav2vec2-base-superb-sid"},
93+
min_support=5,
94+
min_drop=0.1,
95+
)
96+
sg.report(spotlight_dtype={"path": Audio})
97+
98+
99+
def test_huggingface_modeling():
100+
df = data.from_huggingface("Gae8J/modeling")
101+
102+
sg = SliceGuard()
103+
issue_df = sg.find_issues(
104+
df.sample(200),
105+
["audio"],
106+
"label",
107+
metric=accuracy_score,
108+
metric_mode="max",
109+
automl_train_split="train",
110+
automl_task="classification",
111+
automl_time_budget=40.0,
112+
)
113+
sg.report(spotlight_dtype={"path": Audio})
114+
115+
116+
def test_huggingface_piano():
117+
df = data.from_huggingface("ccmusic-database/piano_sound_quality")
118+
119+
sg = SliceGuard()
120+
issue_df = sg.find_issues(
121+
df.sample(200),
122+
["audio"],
123+
"label",
124+
metric=accuracy_score,
125+
metric_mode="max",
126+
automl_train_split="train",
127+
automl_task="classification",
128+
# automl_use_full_embeddings=True,
129+
automl_time_budget=40.0,
130+
)
131+
sg.report(spotlight_dtype={"path": Audio})
132+
133+
134+
def test_huggingface_xtreme():
135+
df = data.from_huggingface("xtreme", "XNLI")
136+
sg = SliceGuard()
137+
issue_df = sg.find_issues(
138+
df.sample(1000),
139+
['language'],
140+
"gold_label",
141+
metric=accuracy_score,
142+
min_drop=0.05,
143+
min_support=10,
144+
automl_task="classification",
145+
automl_time_budget=40.0,
146+
)
147+
sg.report()
148+
149+
150+
def test_huggingface_indonlu():
151+
df = data.from_huggingface("indonlp/indonlu", "smsa")
152+
sg = SliceGuard()
153+
issue_df = sg.find_issues(
154+
df.sample(1000),
155+
['text'],
156+
"label",
157+
metric=accuracy_score,
158+
min_drop=0.05,
159+
min_support=10,
160+
automl_train_split="train",
161+
automl_task="classification",
162+
automl_time_budget=40.0,
163+
)
164+
sg.report()
165+
166+
167+
def test_huggingface_tweet_eval():
168+
df = data.from_huggingface("tweet_eval", "emoji")
169+
sg = SliceGuard()
170+
issue_df = sg.find_issues(
171+
df.sample(1000),
172+
['text'],
173+
"label",
174+
metric=accuracy_score,
175+
# metric_mode="max",
176+
# wer_metric,
177+
# metric_mode="min",
178+
min_drop=0.05,
179+
min_support=10,
180+
# automl_split_key="",
181+
automl_train_split="train",
182+
automl_task="classification",
183+
# automl_use_full_embeddings=True,
184+
automl_time_budget=40.0,
185+
)
186+
sg.report()
187+
188+
189+
# Image:
190+
test_huggingface_mnist()
191+
# test_huggingface_butterflies()
192+
# test_huggingface_dead_by_daylight_perks()
193+
194+
# Audio:
195+
# test_huggingface_dog_dataset()
196+
# test_huggingface_modeling()
197+
# test_huggingface_piano()
198+
199+
# Text:
200+
# test_huggingface_xtreme()
201+
# test_huggingface_indonlu()
202+
# test_huggingface_tweet_eval()

0 commit comments

Comments
 (0)