Skip to content

Commit cf09804

Browse files
committed
seperate imputation between categorical and numerical
1 parent 9092d1c commit cf09804

File tree

1 file changed

+44
-32
lines changed

1 file changed

+44
-32
lines changed

language/mlsql/mlsql/functions/keywords/preprocessing/impute_functions.py

+44-32
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
impute_strategy: What to replace missing values with
1313
Options:
1414
Imputer Class
15-
'most frequent'
16-
'median'
17-
'mean'
15+
'mode'
16+
'median' - numerical
17+
'mean' - numerical
1818
Custom Functions
1919
'remove'
2020
'dummy'
@@ -30,59 +30,63 @@ def impute_missing(data, columns=None, impute_strategy='mode', missing_values='N
3030
cols_to_impute = _find_cols_with_missing_vals(data, missing_values)
3131
else:
3232
cols_to_impute = columns
33-
if cols_to_impute == None:
33+
if not cols_to_impute:
3434
return datacopy
3535
if impute_strategy == 'mode':
36+
print(cols_to_impute)
3637
for col in cols_to_impute:
3738
modeVal = data[col].mode()
38-
if missing_values =='NaN':
39-
datacopy[col] = data[col].fillna(modeVal[0])
40-
else:
41-
datacopy[col] = data[col].replace(missing_values, modeVal[0], regex = True)
39+
print(modeVal[0])
40+
datacopy[col] = _fill_col(data[col], missing_values, modeVal[0])
4241
return datacopy
4342
elif impute_strategy == 'mean':
4443
for col in cols_to_impute:
45-
meanVal = data[col].mean()
46-
if missing_values == 'NaN':
47-
datacopy[col] = data[col].fillna(meanVal)
44+
if data[col].dtype != 'object':
45+
meanVal = data[col].mean()
46+
datacopy[col] = _fill_col(data[col], missing_values, meanVal)
4847
else:
49-
datacopy[col] = data[col].replace(missing_values, meanVal, regex = True)
50-
return datacopy
48+
datacopy[col] = _fill_col(data[col], missing_values, dummy_val)
49+
return datacopy
5150
elif impute_strategy == 'median':
5251
for col in cols_to_impute:
53-
medianVal = data[col].median()
54-
if missing_values == 'NaN':
55-
datacopy[col] = data[col].fillna(medianVal)
52+
if data[col].dtype != 'object':
53+
medianVal = data[col].median()
54+
datacopy[col] = _fill_col(data[col], missing_values, medianVal)
5655
else:
57-
datacopy[col] = data[col].replace(missing_values, medianVal, regex = True)
58-
return datacopy
56+
datacopy[col] = _fill_col(data[col], missing_values, dummy_val)
57+
return datacopy
5958
elif impute_strategy == 'drop column':
6059
return _remove_columns(data, cols_to_impute)
6160
elif impute_strategy == 'maximum':
6261
for col in cols_to_impute:
63-
maxVal = max(data[col])
64-
if missing_values == 'NaN':
65-
datacopy[col] = data[col].fillna(maxVal)
62+
if data[col].dtype != 'object':
63+
maxVal = max(data[col])
64+
datacopy[col] = _fill_col(data[col], missing_values, maxVal)
6665
else:
67-
datacopy[col] = data[col].replace(missing_values, maxVal, regex = True)
68-
return data
66+
datacopy[col] = _fill_col(data[col], missing_values, dummy_val)
67+
return datacopy
6968
elif impute_strategy == 'minimum':
7069
for col in cols_to_impute:
71-
minVal = min(data[col])
72-
if missing_values == 'NaN':
73-
datacopy[col] = data[col].fillna(minVal)
70+
if data[col].dtype != 'object':
71+
minVal = min(data[col])
72+
datacopy[col] = _fill_col(data[col], missing_values, minVal)
7473
else:
75-
datacopy[col] = data[col].replace(missing_values, minVal, regex = True)
76-
return data
74+
datacopy[col] = _fill_col(data[col], missing_values, dummy_val)
75+
return datacopy
7776
elif impute_strategy == 'dummy':
78-
return data.replace(missing_values, dummy_val, regex = True)
77+
for col in cols_to_impute:
78+
if data[col].dtype != 'object':
79+
datacopy[col] = _fill_col(data[col], missing_values, 0)
80+
else:
81+
datacopy[col] = _fill_col(data[col], missing_values, dummy_val)
82+
return datacopy
7983
# Do some more research on this before implementing
8084
elif impute_strategy == 'rand_forest_reg':
8185
print("RANDOM FOREST REGRESSOR NOT IMPLEMENTED NO IMPUTATION HAPPENED")
82-
return None
86+
return datacopy
8387
else:
84-
print ("REPLACE COMMAND NOT RECOGNIZED")
85-
return None
88+
print ("REPLACE COMMAND NOT RECOGNIZED NO IMPUTATION HAPPENED")
89+
return datacopy
8690

8791
"""
8892
remove_columns()
@@ -109,3 +113,11 @@ def _find_cols_with_missing_vals(data=None, missing_values= 'NaN'):
109113
if data[col].str.contains(missing_values).any():
110114
cols_to_impute.append(col)
111115
return cols_to_impute
116+
117+
def _fill_col(column, missing_values, replace_val):
118+
ret = column
119+
if missing_values == 'NaN':
120+
ret = column.fillna(replace_val)
121+
else:
122+
ret = column.replace(missing_values, replace_val, regex = True)
123+
return ret

0 commit comments

Comments
 (0)