diff --git a/Experiment_Script_Adult.ipynb b/Experiment_Script_Adult.ipynb index 31336b8..a98cf49 100644 --- a/Experiment_Script_Adult.ipynb +++ b/Experiment_Script_Adult.ipynb @@ -27,9 +27,24 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 1%|█▋ | 4/300 [3:00:00<219:14:43, 2666.50s/it]" + ] + } + ], "source": [ "synthesizer = CTABGAN(raw_csv_path = real_path,\n", " test_ratio = 0.20,\n", @@ -38,17 +53,22 @@ " mixed_columns= {'capital-loss':[0.0],'capital-gain':[0.0]},\n", " integer_columns = ['age', 'fnlwgt','capital-gain', 'capital-loss','hours-per-week'],\n", " problem_type= {\"Classification\": 'income'},\n", - " epochs = 300) \n", + " epochs = 300,\n", + " batch_size = 500,\n", + " class_dim = (256, 256, 256, 256),\n", + " random_dim = 100,\n", + " num_channels = 64,\n", + " l2scale = 1e-5) \n", "\n", "for i in range(num_exp):\n", " synthesizer.fit()\n", - " syn = synthesizer.generate_samples()\n", + " syn = synthesizer.generate_samples(100)\n", " syn.to_csv(fake_file_root+\"/\"+dataset+\"/\"+ dataset+\"_fake_{exp}.csv\".format(exp=i), index= False)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -57,77 +77,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
AccAUCF1_Score
lr8.3836630.076099-0.016126
dt11.3727100.1168210.126888
rf11.5876750.1199690.120840
mlp11.7002760.0950170.088192
\n", - "
" - ], - "text/plain": [ - " Acc AUC F1_Score\n", - "lr 8.383663 0.076099 -0.016126\n", - "dt 11.372710 0.116821 0.126888\n", - "rf 11.587675 0.119969 0.120840\n", - "mlp 11.700276 0.095017 0.088192" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "classifiers_list = [\"lr\",\"dt\",\"rf\",\"mlp\"]\n", "result_mat = get_utility_metrics(real_path,fake_paths,\"MinMax\",classifiers_list, test_ratio = 0.20)\n", @@ -139,59 +91,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Average WD (Continuous ColumnsAverage JSD (Categorical Columns)Correlation Distance
00.0187470.0851251.847952
\n", - "
" - ], - "text/plain": [ - " Average WD (Continuous Columns Average JSD (Categorical Columns) \\\n", - "0 0.018747 0.085125 \n", - "\n", - " Correlation Distance \n", - "0 1.847952 " - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "adult_categorical = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'gender', 'native-country', 'income']\n", "stat_res_avg = []\n", @@ -206,68 +108,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
DCR between Real and Fake (5th perc)DCR within Real(5th perc)DCR within Fake (5th perc)NNDR between Real and Fake (5th perc)NNDR within Real (5th perc)NNDR within Fake (5th perc)
00.5695460.2165450.4510310.6340420.4420520.567227
\n", - "
" - ], - "text/plain": [ - " DCR between Real and Fake (5th perc) DCR within Real(5th perc) \\\n", - "0 0.569546 0.216545 \n", - "\n", - " DCR within Fake (5th perc) NNDR between Real and Fake (5th perc) \\\n", - "0 0.451031 0.634042 \n", - "\n", - " NNDR within Real (5th perc) NNDR within Fake (5th perc) \n", - "0 0.442052 0.567227 " - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "priv_res_avg = []\n", "for fake_path in fake_paths:\n", @@ -281,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/model/ctabgan.py b/model/ctabgan.py index fde96e7..df7fc93 100644 --- a/model/ctabgan.py +++ b/model/ctabgan.py @@ -21,11 +21,17 @@ def __init__(self, mixed_columns= {'capital-loss':[0.0],'capital-gain':[0.0]}, integer_columns = ['age', 'fnlwgt','capital-gain', 'capital-loss','hours-per-week'], problem_type= {"Classification": 'income'}, - epochs = 1): + epochs = 1, + batch_size=500, + class_dim=(256, 256, 256, 256), + random_dim=100, + num_channels=64, + l2scale=1e-5): self.__name__ = 'CTABGAN' - self.synthesizer = CTABGANSynthesizer(epochs = epochs) + self.synthesizer = CTABGANSynthesizer(epochs = epochs, batch_size = batch_size, class_dim = class_dim, random_dim = random_dim, + num_channels = num_channels, l2scale = l2scale) self.raw_df = pd.read_csv(raw_csv_path) self.test_ratio = test_ratio self.categorical_columns = categorical_columns @@ -33,20 +39,20 @@ def __init__(self, self.mixed_columns = mixed_columns self.integer_columns = integer_columns self.problem_type = problem_type - + def fit(self): - + start_time = time.time() self.data_prep = DataPrep(self.raw_df,self.categorical_columns,self.log_columns,self.mixed_columns,self.integer_columns,self.problem_type,self.test_ratio) - self.synthesizer.fit(train_data=self.data_prep.df, categorical = self.data_prep.column_types["categorical"], + self.synthesizer.fit(train_data=self.data_prep.df, categorical = self.data_prep.column_types["categorical"], mixed = self.data_prep.column_types["mixed"],type=self.problem_type) end_time = time.time() print('Finished training in',end_time-start_time," seconds.") - def generate_samples(self): - - sample = self.synthesizer.sample(len(self.raw_df)) + def generate_samples(self, num_samples): + + sample = self.synthesizer.sample(num_samples) sample_df = self.data_prep.inverse_prep(sample) return sample_df