Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 32 additions & 189 deletions Experiment_Script_Adult.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,24 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 1%|█▋ | 4/300 [3:00:00<219:14:43, 2666.50s/it]"
]
}
],
"source": [
"synthesizer = CTABGAN(raw_csv_path = real_path,\n",
" test_ratio = 0.20,\n",
Expand All @@ -38,17 +53,22 @@
" mixed_columns= {'capital-loss':[0.0],'capital-gain':[0.0]},\n",
" integer_columns = ['age', 'fnlwgt','capital-gain', 'capital-loss','hours-per-week'],\n",
" problem_type= {\"Classification\": 'income'},\n",
" epochs = 300) \n",
" epochs = 300,\n",
" batch_size = 500,\n",
" class_dim = (256, 256, 256, 256),\n",
" random_dim = 100,\n",
" num_channels = 64,\n",
" l2scale = 1e-5) \n",
"\n",
"for i in range(num_exp):\n",
" synthesizer.fit()\n",
" syn = synthesizer.generate_samples()\n",
" syn = synthesizer.generate_samples(100)\n",
" syn.to_csv(fake_file_root+\"/\"+dataset+\"/\"+ dataset+\"_fake_{exp}.csv\".format(exp=i), index= False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -57,77 +77,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Acc</th>\n",
" <th>AUC</th>\n",
" <th>F1_Score</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>lr</th>\n",
" <td>8.383663</td>\n",
" <td>0.076099</td>\n",
" <td>-0.016126</td>\n",
" </tr>\n",
" <tr>\n",
" <th>dt</th>\n",
" <td>11.372710</td>\n",
" <td>0.116821</td>\n",
" <td>0.126888</td>\n",
" </tr>\n",
" <tr>\n",
" <th>rf</th>\n",
" <td>11.587675</td>\n",
" <td>0.119969</td>\n",
" <td>0.120840</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mlp</th>\n",
" <td>11.700276</td>\n",
" <td>0.095017</td>\n",
" <td>0.088192</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Acc AUC F1_Score\n",
"lr 8.383663 0.076099 -0.016126\n",
"dt 11.372710 0.116821 0.126888\n",
"rf 11.587675 0.119969 0.120840\n",
"mlp 11.700276 0.095017 0.088192"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"classifiers_list = [\"lr\",\"dt\",\"rf\",\"mlp\"]\n",
"result_mat = get_utility_metrics(real_path,fake_paths,\"MinMax\",classifiers_list, test_ratio = 0.20)\n",
Expand All @@ -139,59 +91,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Average WD (Continuous Columns</th>\n",
" <th>Average JSD (Categorical Columns)</th>\n",
" <th>Correlation Distance</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.018747</td>\n",
" <td>0.085125</td>\n",
" <td>1.847952</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Average WD (Continuous Columns Average JSD (Categorical Columns) \\\n",
"0 0.018747 0.085125 \n",
"\n",
" Correlation Distance \n",
"0 1.847952 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"adult_categorical = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'gender', 'native-country', 'income']\n",
"stat_res_avg = []\n",
Expand All @@ -206,68 +108,9 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>DCR between Real and Fake (5th perc)</th>\n",
" <th>DCR within Real(5th perc)</th>\n",
" <th>DCR within Fake (5th perc)</th>\n",
" <th>NNDR between Real and Fake (5th perc)</th>\n",
" <th>NNDR within Real (5th perc)</th>\n",
" <th>NNDR within Fake (5th perc)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.569546</td>\n",
" <td>0.216545</td>\n",
" <td>0.451031</td>\n",
" <td>0.634042</td>\n",
" <td>0.442052</td>\n",
" <td>0.567227</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" DCR between Real and Fake (5th perc) DCR within Real(5th perc) \\\n",
"0 0.569546 0.216545 \n",
"\n",
" DCR within Fake (5th perc) NNDR between Real and Fake (5th perc) \\\n",
"0 0.451031 0.634042 \n",
"\n",
" NNDR within Real (5th perc) NNDR within Fake (5th perc) \n",
"0 0.442052 0.567227 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"priv_res_avg = []\n",
"for fake_path in fake_paths:\n",
Expand All @@ -281,7 +124,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down
22 changes: 14 additions & 8 deletions model/ctabgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,38 @@ def __init__(self,
mixed_columns= {'capital-loss':[0.0],'capital-gain':[0.0]},
integer_columns = ['age', 'fnlwgt','capital-gain', 'capital-loss','hours-per-week'],
problem_type= {"Classification": 'income'},
epochs = 1):
epochs = 1,
batch_size=500,
class_dim=(256, 256, 256, 256),
random_dim=100,
num_channels=64,
l2scale=1e-5):

self.__name__ = 'CTABGAN'

self.synthesizer = CTABGANSynthesizer(epochs = epochs)
self.synthesizer = CTABGANSynthesizer(epochs = epochs, batch_size = batch_size, class_dim = class_dim, random_dim = random_dim,
num_channels = num_channels, l2scale = l2scale)
self.raw_df = pd.read_csv(raw_csv_path)
self.test_ratio = test_ratio
self.categorical_columns = categorical_columns
self.log_columns = log_columns
self.mixed_columns = mixed_columns
self.integer_columns = integer_columns
self.problem_type = problem_type

def fit(self):

start_time = time.time()
self.data_prep = DataPrep(self.raw_df,self.categorical_columns,self.log_columns,self.mixed_columns,self.integer_columns,self.problem_type,self.test_ratio)
self.synthesizer.fit(train_data=self.data_prep.df, categorical = self.data_prep.column_types["categorical"],
self.synthesizer.fit(train_data=self.data_prep.df, categorical = self.data_prep.column_types["categorical"],
mixed = self.data_prep.column_types["mixed"],type=self.problem_type)
end_time = time.time()
print('Finished training in',end_time-start_time," seconds.")


def generate_samples(self):
sample = self.synthesizer.sample(len(self.raw_df))
def generate_samples(self, num_samples):

sample = self.synthesizer.sample(num_samples)
sample_df = self.data_prep.inverse_prep(sample)

return sample_df