Skip to content

Commit ffec6b0

Browse files
authored
feat: Gumbel-Softmax and Activation Interface layers
* Gumbel Softmax and Activation Interface base
* Generalize cat_lens property + optimize runtime
* GS serializable + Remove optimization (no improvement)
* Interface serializable
* GumbelSoftmaxActivation integrations in regular models
* Renaming to GS Activation and activation info
* Lingering preprocess artifacts
* CGAN label column dtype validation fix
1 parent dda17d2 commit ffec6b0

File tree

10 files changed

+297
-43
lines changed

10 files changed

+297
-43
lines changed

src/ydata_synthetic/preprocessing/base_processor.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,38 @@
1-
"Implements a BaseProcessor Class, not meant to be directly instantiated."
1+
"Base class of Data Preprocessors, do not instantiate this class directly."
22
from __future__ import annotations
33

44
from abc import ABC, abstractmethod
5+
from collections import namedtuple
56
from typing import List, Optional
67

7-
from numpy import ndarray
8-
from pandas import DataFrame, Series
8+
from numpy import concatenate, ndarray, split, zeros
9+
from pandas import DataFrame, Series, concat
910
from sklearn.base import BaseEstimator, TransformerMixin
1011
from sklearn.exceptions import NotFittedError
1112
from typeguard import typechecked
1213

14+
ProcessorInfo = namedtuple("ProcessorInfo", ["numerical", "categorical"])
15+
PipelineInfo = namedtuple("PipelineInfo", ["feat_names_in", "feat_names_out"])
1316

17+
# pylint: disable=R0902
1418
@typechecked
1519
class BaseProcessor(ABC, BaseEstimator, TransformerMixin):
1620
"""
17-
Base class for Data Preprocessing.
18-
It works like any other transformer in scikit learn with the methods fit, transform and inverse transform.
21+
This data processor works like a scikit-learn transformer, with the methods fit, transform and inverse transform.
1922
Args:
2023
num_cols (list of strings):
2124
List of names of numerical columns.
2225
cat_cols (list of strings):
2326
List of names of categorical columns.
2427
"""
2528
def __init__(self, num_cols: Optional[List[str]] = None, cat_cols: Optional[List[str]] = None):
26-
2729
self.num_cols = [] if num_cols is None else num_cols
2830
self.cat_cols = [] if cat_cols is None else cat_cols
2931

30-
self._num_pipeline = None
31-
self._cat_pipeline = None
32+
self._num_pipeline = None # To be overridden by child processors
33+
self._cat_pipeline = None # To be overridden by child processors
3234

33-
self._types = None
35+
self._col_transform_info = None # Metadata object mapping inputs/outputs of each pipeline
3436

3537
@property
3638
def num_pipeline(self) -> BaseEstimator:
@@ -47,6 +49,25 @@ def types(self) -> Series:
4749
"""Returns a Series with the dtypes of each column in the fitted DataFrame."""
4850
return self._types
4951

52+
@property
53+
def col_transform_info(self) -> ProcessorInfo:
54+
"""Returns a ProcessorInfo object specifying input/output feature mappings of this processor's pipelines."""
55+
self._check_is_fitted()
56+
if self._col_transform_info is None:
57+
self._col_transform_info = self.__create_metadata_synth()
58+
return self._col_transform_info
59+
60+
def __create_metadata_synth(self):
61+
num_info = PipelineInfo([], [])
62+
cat_info = PipelineInfo([], [])
63+
# Named tuple with the input/output feature names of the numerical pipeline
64+
if self.num_cols:
65+
num_info = PipelineInfo(self.num_pipeline.feature_names_in_, self.num_pipeline.get_feature_names_out())
66+
# Named tuple with the input/output feature names of the categorical pipeline
67+
if self.cat_cols:
68+
cat_info = PipelineInfo(self.cat_pipeline.feature_names_in_, self.cat_pipeline.get_feature_names_out())
69+
return ProcessorInfo(num_info, cat_info)
70+
5071
def _check_is_fitted(self):
5172
"""Checks if the processor is fitted by testing the numerical pipeline.
5273
Raises NotFittedError if not."""
@@ -86,8 +107,7 @@ def transform(self, X: DataFrame) -> ndarray:
86107
DataFrame used to fit the processor parameters.
87108
Should be aligned with the columns types defined in initialization.
88109
Returns:
89-
transformed (ndarray):
90-
Processed version of the passed DataFrame.
110+
transformed (ndarray): Processed version of the passed DataFrame.
91111
"""
92112
raise NotImplementedError
93113

src/ydata_synthetic/synthesizers/regular/cgan/model.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""CGAN implementation"""
22
import os
33
from os import path
4-
from typing import List, Tuple, Union
4+
from typing import List, Tuple, Union, Optional, NamedTuple
55

66
import numpy as np
77
from numpy import array, empty, hstack, ndarray, vstack, save
@@ -19,6 +19,7 @@
1919

2020
from ydata_synthetic.synthesizers import TrainParameters
2121
from ydata_synthetic.synthesizers.gan import BaseModel
22+
from ydata_synthetic.utils.gumbel_softmax import GumbelSoftmaxActivation
2223

2324

2425
class CGAN(BaseModel):
@@ -44,15 +45,16 @@ def label_col(self, data_label: Tuple[Union[DataFrame, array], str]):
4445
cannot be used as condition."
4546
assert data[label_col].isna().sum() == 0, "The label column contains NaN values, please impute or drop the \
4647
respective records before proceeding."
47-
assert is_float_dtype(data[label_col]) or is_integer_dtype(float), "The label column is expected to be an \
48+
assert is_float_dtype(data[label_col]) or is_integer_dtype(data[label_col]), "The label column is expected to be an \
4849
integer or a float dtype to ensure the function of the embedding layer."
4950
unique_frac = data[label_col].nunique()/len(data.index)
5051
assert unique_frac < 1, "The provided column {label_col} is constituted by unique values and is not suitable \
5152
to be used as condition."
5253

53-
def define_gan(self):
54+
def define_gan(self, activation_info: Optional[NamedTuple] = None):
5455
self.generator = Generator(self.batch_size, self.num_classes). \
55-
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)
56+
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
57+
activation_info = activation_info)
5658

5759
self.discriminator = Discriminator(self.batch_size, self.num_classes). \
5860
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -121,7 +123,7 @@ def train(self, data: Union[DataFrame, array], label_col: str, train_arguments:
121123

122124
processed_data = self.processor.transform(data)
123125
self.data_dim = processed_data.shape[1]
124-
self.define_gan()
126+
self.define_gan(self.processor.col_transform_info)
125127

126128
# Merging labels with processed data
127129
processed_data = hstack([processed_data, label])
@@ -198,7 +200,7 @@ def __init__(self, batch_size, num_classes):
198200
self.batch_size = batch_size
199201
self.num_classes = num_classes
200202

201-
def build_model(self, input_shape, dim, data_dim):
203+
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None):
202204
noise = Input(shape=input_shape, batch_size=self.batch_size)
203205
label = Input(shape=(1,), batch_size=self.batch_size, dtype='int32')
204206
label_embedding = Flatten()(Embedding(self.num_classes, 1)(label))
@@ -208,6 +210,8 @@ def build_model(self, input_shape, dim, data_dim):
208210
x = Dense(dim * 2, activation='relu')(x)
209211
x = Dense(dim * 4, activation='relu')(x)
210212
x = Dense(data_dim)(x)
213+
if activation_info:
214+
x = GumbelSoftmaxActivation(activation_info).call(x)
211215
return Model(inputs=[noise, label], outputs=x)
212216

213217

src/ydata_synthetic/synthesizers/regular/cramergan/model.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from os import path
3-
from typing import List
3+
from typing import List, Optional, NamedTuple
44

55
import numpy as np
66
import tensorflow as tf
@@ -12,6 +12,7 @@
1212
from ydata_synthetic.synthesizers import TrainParameters
1313
from ydata_synthetic.synthesizers.gan import BaseModel
1414
from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty
15+
from ydata_synthetic.utils.gumbel_softmax import GumbelSoftmaxActivation
1516

1617

1718
class CRAMERGAN(BaseModel):
@@ -26,9 +27,10 @@ def __init__(self, model_parameters, gradient_penalty_weight=10):
2627
self.gradient_penalty_weight = gradient_penalty_weight
2728
super().__init__(model_parameters)
2829

29-
def define_gan(self):
30+
def define_gan(self, activation_info: Optional[NamedTuple] = None):
3031
self.generator = Generator(self.batch_size). \
31-
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)
32+
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
33+
activation_info=activation_info)
3234

3335
self.critic = Critic(self.batch_size). \
3436
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -145,7 +147,7 @@ def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat
145147

146148
data = self.processor.transform(data)
147149
self.data_dim = data.shape[1]
148-
self.define_gan()
150+
self.define_gan(self.processor.col_transform_info)
149151

150152
iterations = int(abs(data.shape[0] / self.batch_size) + 1)
151153

@@ -190,12 +192,14 @@ def __init__(self, batch_size):
190192
"""Simple generator with dense feedforward layers."""
191193
self.batch_size = batch_size
192194

193-
def build_model(self, input_shape, dim, data_dim):
195+
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None):
194196
input_ = Input(shape=input_shape, batch_size=self.batch_size)
195197
x = Dense(dim, activation='relu')(input_)
196198
x = Dense(dim * 2, activation='relu')(x)
197199
x = Dense(dim * 4, activation='relu')(x)
198200
x = Dense(data_dim)(x)
201+
if activation_info:
202+
x = GumbelSoftmaxActivation(activation_info)(x)
199203
return Model(inputs=input_, outputs=x)
200204

201205
class Critic(tf.keras.Model):

src/ydata_synthetic/synthesizers/regular/dragan/model.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from os import path
33

4+
from typing import Optional, NamedTuple
45
import tensorflow as tf
56
import tqdm
67
from tensorflow.keras import Model, initializers
@@ -9,6 +10,7 @@
910

1011
from ydata_synthetic.synthesizers.gan import BaseModel
1112
from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty
13+
from ydata_synthetic.utils.gumbel_softmax import GumbelSoftmaxActivation
1214

1315

1416
class DRAGAN(BaseModel):
@@ -21,10 +23,11 @@ def __init__(self, model_parameters, n_discriminator, gradient_penalty_weight=10
2123
self.gradient_penalty_weight = gradient_penalty_weight
2224
super().__init__(model_parameters)
2325

24-
def define_gan(self):
26+
def define_gan(self, col_transform_info: Optional[NamedTuple] = None):
2527
# define generator/discriminator
2628
self.generator = Generator(self.batch_size). \
27-
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)
29+
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
30+
activation_info=col_transform_info)
2831

2932
self.discriminator = Discriminator(self.batch_size). \
3033
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -125,7 +128,7 @@ def train(self, data, train_arguments, num_cols, cat_cols):
125128

126129
processed_data = self.processor.transform(data)
127130
self.data_dim = processed_data.shape[1]
128-
self.define_gan()
131+
self.define_gan(self.processor.col_transform_info)
129132

130133
train_loader = self.get_data_batch(processed_data, self.batch_size)
131134

@@ -174,10 +177,12 @@ class Generator(Model):
174177
def __init__(self, batch_size):
175178
self.batch_size = batch_size
176179

177-
def build_model(self, input_shape, dim, data_dim):
180+
def build_model(self, input_shape, dim, data_dim, activation_info: NamedTuple = None):
178181
input = Input(shape=input_shape, batch_size = self.batch_size)
179182
x = Dense(dim, kernel_initializer=initializers.TruncatedNormal(mean=0., stddev=0.5), activation='relu')(input)
180183
x = Dense(dim * 2, activation='relu')(x)
181184
x = Dense(dim * 4, activation='relu')(x)
182185
x = Dense(data_dim)(x)
186+
if activation_info:
187+
x = GumbelSoftmaxActivation(activation_info)(x)
183188
return Model(inputs=input, outputs=x)

src/ydata_synthetic/synthesizers/regular/vanillagan/model.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import os
22
from os import path
33
import numpy as np
4-
from typing import List
4+
from typing import List, Optional, NamedTuple
55
from tqdm import trange
66

77
from ydata_synthetic.synthesizers.gan import BaseModel
88
from ydata_synthetic.synthesizers import TrainParameters
9+
from ydata_synthetic.utils.gumbel_softmax import GumbelSoftmaxActivation
910

1011
import tensorflow as tf
1112
from tensorflow.keras.layers import Input, Dense, Dropout
@@ -19,9 +20,10 @@ class VanilllaGAN(BaseModel):
1920
def __init__(self, model_parameters):
2021
super().__init__(model_parameters)
2122

22-
def define_gan(self):
23+
def define_gan(self, activation_info: Optional[NamedTuple]):
2324
self.generator = Generator(self.batch_size).\
24-
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)
25+
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
26+
activation_info = activation_info)
2527

2628
self.discriminator = Discriminator(self.batch_size).\
2729
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -63,8 +65,7 @@ def get_data_batch(self, train, batch_size, seed=0):
6365
train_ix = list(train_ix) + list(train_ix) # duplicate to cover ranges past the end of the set
6466
return train[train_ix[start_i: stop_i]]
6567

66-
def train(self, data, train_arguments: TrainParameters, num_cols: List[str],
67-
cat_cols: List[str]):
68+
def train(self, data, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
6869
"""
6970
Args:
7071
data: A pandas DataFrame or a Numpy array with the data to be synthesized
@@ -76,7 +77,7 @@ def train(self, data, train_arguments: TrainParameters, num_cols: List[str],
7677

7778
processed_data = self.processor.transform(data)
7879
self.data_dim = processed_data.shape[1]
79-
self.define_gan()
80+
self.define_gan(self.processor.col_transform_info)
8081

8182
iterations = int(abs(data.shape[0]/self.batch_size)+1)
8283

@@ -130,12 +131,14 @@ class Generator(tf.keras.Model):
130131
def __init__(self, batch_size):
131132
self.batch_size=batch_size
132133

133-
def build_model(self, input_shape, dim, data_dim):
134+
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None):
134135
input= Input(shape=input_shape, batch_size=self.batch_size)
135136
x = Dense(dim, activation='relu')(input)
136137
x = Dense(dim * 2, activation='relu')(x)
137138
x = Dense(dim * 4, activation='relu')(x)
138139
x = Dense(data_dim)(x)
140+
if activation_info:
141+
x = GumbelSoftmaxActivation(activation_info)(x)
139142
return Model(inputs=input, outputs=x)
140143

141144
class Discriminator(tf.keras.Model):

src/ydata_synthetic/synthesizers/regular/wgan/model.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from os import mkdir, path
2-
from typing import List
2+
from typing import List, Optional, NamedTuple
33

44
import numpy as np
55
import tensorflow as tf
@@ -11,6 +11,7 @@
1111

1212
from ydata_synthetic.synthesizers import TrainParameters
1313
from ydata_synthetic.synthesizers.gan import BaseModel
14+
from ydata_synthetic.utils.gumbel_softmax import GumbelSoftmaxActivation
1415

1516

1617
#Auxiliary Keras backend class to calculate the Random Weighted average
@@ -41,9 +42,10 @@ def __init__(self, model_parameters, n_critic, clip_value=0.01):
4142
def wasserstein_loss(self, y_true, y_pred):
4243
return K.mean(y_true * y_pred)
4344

44-
def define_gan(self):
45+
def define_gan(self, activation_info: Optional[NamedTuple] = None):
4546
self.generator = Generator(self.batch_size). \
46-
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)
47+
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
48+
activation_info=activation_info)
4749

4850
self.critic = Critic(self.batch_size). \
4951
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
@@ -96,7 +98,7 @@ def train(self, data, train_arguments: TrainParameters, num_cols: List[str],
9698

9799
processed_data = self.processor.transform(data)
98100
self.data_dim = processed_data.shape[1]
99-
self.define_gan()
101+
self.define_gan(self.processor.col_transform_info)
100102

101103
#Create a summary file
102104
iterations = int(abs(data.shape[0]/self.batch_size)+1)
@@ -153,12 +155,14 @@ class Generator(tf.keras.Model):
153155
def __init__(self, batch_size):
154156
self.batch_size = batch_size
155157

156-
def build_model(self, input_shape, dim, data_dim):
158+
def build_model(self, input_shape, dim, data_dim, activation_info: Optional[NamedTuple] = None):
157159
input = Input(shape=input_shape, batch_size=self.batch_size)
158160
x = Dense(dim, activation='relu')(input)
159161
x = Dense(dim * 2, activation='relu')(x)
160162
x = Dense(dim * 4, activation='relu')(x)
161163
x = Dense(data_dim)(x)
164+
if activation_info:
165+
x = GumbelSoftmaxActivation(activation_info)(x)
162166
return Model(inputs=input, outputs=x)
163167

164168
class Critic(tf.keras.Model):

0 commit comments

Comments
 (0)