diff --git a/Experiment_Script_Adult.ipynb b/Experiment_Script_Adult.ipynb
index 31336b8..a98cf49 100644
--- a/Experiment_Script_Adult.ipynb
+++ b/Experiment_Script_Adult.ipynb
@@ -27,9 +27,24 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 1%|█▋ | 4/300 [3:00:00<219:14:43, 2666.50s/it]"
+ ]
+ }
+ ],
"source": [
"synthesizer = CTABGAN(raw_csv_path = real_path,\n",
" test_ratio = 0.20,\n",
@@ -38,17 +53,22 @@
" mixed_columns= {'capital-loss':[0.0],'capital-gain':[0.0]},\n",
" integer_columns = ['age', 'fnlwgt','capital-gain', 'capital-loss','hours-per-week'],\n",
" problem_type= {\"Classification\": 'income'},\n",
- " epochs = 300) \n",
+ " epochs = 300,\n",
+ " batch_size = 500,\n",
+ " class_dim = (256, 256, 256, 256),\n",
+ " random_dim = 100,\n",
+ " num_channels = 64,\n",
+ " l2scale = 1e-5) \n",
"\n",
"for i in range(num_exp):\n",
" synthesizer.fit()\n",
- " syn = synthesizer.generate_samples()\n",
+ " syn = synthesizer.generate_samples(100)\n",
" syn.to_csv(fake_file_root+\"/\"+dataset+\"/\"+ dataset+\"_fake_{exp}.csv\".format(exp=i), index= False)"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -57,77 +77,9 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "<div>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- "  <thead>\n",
- "    <tr style=\"text-align: right;\">\n",
- "      <th></th>\n",
- "      <th>Acc</th>\n",
- "      <th>AUC</th>\n",
- "      <th>F1_Score</th>\n",
- "    </tr>\n",
- "  </thead>\n",
- "  <tbody>\n",
- "    <tr>\n",
- "      <th>lr</th>\n",
- "      <td>8.383663</td>\n",
- "      <td>0.076099</td>\n",
- "      <td>-0.016126</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <th>dt</th>\n",
- "      <td>11.372710</td>\n",
- "      <td>0.116821</td>\n",
- "      <td>0.126888</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <th>rf</th>\n",
- "      <td>11.587675</td>\n",
- "      <td>0.119969</td>\n",
- "      <td>0.120840</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <th>mlp</th>\n",
- "      <td>11.700276</td>\n",
- "      <td>0.095017</td>\n",
- "      <td>0.088192</td>\n",
- "    </tr>\n",
- "  </tbody>\n",
- "</table>\n",
- "</div>"
- ],
- "text/plain": [
- " Acc AUC F1_Score\n",
- "lr 8.383663 0.076099 -0.016126\n",
- "dt 11.372710 0.116821 0.126888\n",
- "rf 11.587675 0.119969 0.120840\n",
- "mlp 11.700276 0.095017 0.088192"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"classifiers_list = [\"lr\",\"dt\",\"rf\",\"mlp\"]\n",
"result_mat = get_utility_metrics(real_path,fake_paths,\"MinMax\",classifiers_list, test_ratio = 0.20)\n",
@@ -139,59 +91,9 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "<div>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- "  <thead>\n",
- "    <tr style=\"text-align: right;\">\n",
- "      <th></th>\n",
- "      <th>Average WD (Continuous Columns</th>\n",
- "      <th>Average JSD (Categorical Columns)</th>\n",
- "      <th>Correlation Distance</th>\n",
- "    </tr>\n",
- "  </thead>\n",
- "  <tbody>\n",
- "    <tr>\n",
- "      <th>0</th>\n",
- "      <td>0.018747</td>\n",
- "      <td>0.085125</td>\n",
- "      <td>1.847952</td>\n",
- "    </tr>\n",
- "  </tbody>\n",
- "</table>\n",
- "</div>"
- ],
- "text/plain": [
- " Average WD (Continuous Columns Average JSD (Categorical Columns) \\\n",
- "0 0.018747 0.085125 \n",
- "\n",
- " Correlation Distance \n",
- "0 1.847952 "
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"adult_categorical = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'gender', 'native-country', 'income']\n",
"stat_res_avg = []\n",
@@ -206,68 +108,9 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "<div>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- "  <thead>\n",
- "    <tr style=\"text-align: right;\">\n",
- "      <th></th>\n",
- "      <th>DCR between Real and Fake (5th perc)</th>\n",
- "      <th>DCR within Real(5th perc)</th>\n",
- "      <th>DCR within Fake (5th perc)</th>\n",
- "      <th>NNDR between Real and Fake (5th perc)</th>\n",
- "      <th>NNDR within Real (5th perc)</th>\n",
- "      <th>NNDR within Fake (5th perc)</th>\n",
- "    </tr>\n",
- "  </thead>\n",
- "  <tbody>\n",
- "    <tr>\n",
- "      <th>0</th>\n",
- "      <td>0.569546</td>\n",
- "      <td>0.216545</td>\n",
- "      <td>0.451031</td>\n",
- "      <td>0.634042</td>\n",
- "      <td>0.442052</td>\n",
- "      <td>0.567227</td>\n",
- "    </tr>\n",
- "  </tbody>\n",
- "</table>\n",
- "</div>"
- ],
- "text/plain": [
- " DCR between Real and Fake (5th perc) DCR within Real(5th perc) \\\n",
- "0 0.569546 0.216545 \n",
- "\n",
- " DCR within Fake (5th perc) NNDR between Real and Fake (5th perc) \\\n",
- "0 0.451031 0.634042 \n",
- "\n",
- " NNDR within Real (5th perc) NNDR within Fake (5th perc) \n",
- "0 0.442052 0.567227 "
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"priv_res_avg = []\n",
"for fake_path in fake_paths:\n",
@@ -281,7 +124,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
diff --git a/model/ctabgan.py b/model/ctabgan.py
index fde96e7..df7fc93 100644
--- a/model/ctabgan.py
+++ b/model/ctabgan.py
@@ -21,11 +21,17 @@ def __init__(self,
mixed_columns= {'capital-loss':[0.0],'capital-gain':[0.0]},
integer_columns = ['age', 'fnlwgt','capital-gain', 'capital-loss','hours-per-week'],
problem_type= {"Classification": 'income'},
- epochs = 1):
+ epochs = 1,
+ batch_size=500,
+ class_dim=(256, 256, 256, 256),
+ random_dim=100,
+ num_channels=64,
+ l2scale=1e-5):
self.__name__ = 'CTABGAN'
- self.synthesizer = CTABGANSynthesizer(epochs = epochs)
+ self.synthesizer = CTABGANSynthesizer(epochs = epochs, batch_size = batch_size, class_dim = class_dim, random_dim = random_dim,
+ num_channels = num_channels, l2scale = l2scale)
self.raw_df = pd.read_csv(raw_csv_path)
self.test_ratio = test_ratio
self.categorical_columns = categorical_columns
@@ -33,20 +39,20 @@ def __init__(self,
self.mixed_columns = mixed_columns
self.integer_columns = integer_columns
self.problem_type = problem_type
-
+
def fit(self):
-
+
start_time = time.time()
self.data_prep = DataPrep(self.raw_df,self.categorical_columns,self.log_columns,self.mixed_columns,self.integer_columns,self.problem_type,self.test_ratio)
- self.synthesizer.fit(train_data=self.data_prep.df, categorical = self.data_prep.column_types["categorical"],
+ self.synthesizer.fit(train_data=self.data_prep.df, categorical = self.data_prep.column_types["categorical"],
mixed = self.data_prep.column_types["mixed"],type=self.problem_type)
end_time = time.time()
print('Finished training in',end_time-start_time," seconds.")
- def generate_samples(self):
-
- sample = self.synthesizer.sample(len(self.raw_df))
+ def generate_samples(self, num_samples):
+
+ sample = self.synthesizer.sample(num_samples)
sample_df = self.data_prep.inverse_prep(sample)
return sample_df