-
Notifications
You must be signed in to change notification settings - Fork 91
GANs mnist #172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
GANs mnist #172
Changes from 4 commits
8ff8b75
f0ea62c
a1bf251
3a68c83
acb2048
90214b6
3f9bafe
2ebecff
7ded9fc
e55e29e
0a6a590
abfcd33
e155073
bc3afc2
e3f9003
8ba21fd
a9d5011
2a45d92
2ea30b2
4cac2a4
25abdc9
3281b0a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
# Makefile for the mnist_gan example.
TARGET := mnist_gan
SRC := mnist_gan.cpp
LIBS_NAME := armadillo mlpack

CXX := g++
CXXFLAGS += -std=c++11 -Wall -Wextra -O3 -DNDEBUG
# Use these CXXFLAGS instead if you want to compile with debugging symbols and
# without optimizations.
# CXXFLAGS += -std=c++11 -Wall -Wextra -g -O0
LDFLAGS += -fopenmp
LDFLAGS += -lboost_serialization
LDFLAGS += -larmadillo
LDFLAGS += -L /home/viole/mlpack/build/lib/ # /path/to/mlpack/library/ # if installed locally.
# Add header directories for any includes that aren't on the
# default compiler search path.
# FIX: corrected typo "mlpac" -> "mlpack" so the headers come from the same
# build tree as the library path above.
INCLFLAGS := -I /home/viole/mlpack/build/include/
CXXFLAGS += $(INCLFLAGS)

OBJS := $(SRC:.cpp=.o)
LIBS := $(addprefix -l,$(LIBS_NAME))
CLEAN_LIST := $(TARGET) $(OBJS)

# default rule
default: all

$(TARGET): $(OBJS)
	$(CXX) $(CXXFLAGS) $(OBJS) -o $(TARGET) $(LDFLAGS) $(LIBS)

.PHONY: all
all: $(TARGET)

.PHONY: clean
clean:
	@echo CLEAN $(CLEAN_LIST)
	@rm -f $(CLEAN_LIST)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,205 @@ | ||
| #include <mlpack/core.hpp> | ||
| #include <mlpack/core/data/split_data.hpp> | ||
|
|
||
| #include <mlpack/methods/ann/init_rules/gaussian_init.hpp> | ||
| #include <mlpack/methods/ann/loss_functions/sigmoid_cross_entropy_error.hpp> | ||
| #include <mlpack/methods/ann/gan/gan.hpp> | ||
| #include <mlpack/methods/ann/layer/layer.hpp> | ||
| #include <mlpack/methods/softmax_regression/softmax_regression.hpp> | ||
|
|
||
| #include <ensmallen.hpp> | ||
|
|
||
| using namespace mlpack; | ||
| using namespace mlpack::data; | ||
| using namespace mlpack::ann; | ||
| using namespace mlpack::math; | ||
| using namespace mlpack::regression; | ||
| using namespace std::placeholders; | ||
|
|
||
|
|
||
| int main() | ||
| { | ||
| size_t trainRatio = 0.8; | ||
| size_t dNumKernels = 32; | ||
swaingotnochill marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| size_t discriminatorPreTrain = 5; | ||
| size_t batchSize = 64; | ||
| size_t noiseDim = 100; | ||
| size_t generatorUpdateStep = 1; | ||
| size_t numSamples = 10; | ||
| size_t cycles = 10; | ||
| double stepSize = 0.0003; | ||
| double eps = 1e-8; | ||
| size_t numEpochs = 1; | ||
| double tolerance = 1e-5; | ||
| bool shuffle = true; | ||
| double multiplier = 10; | ||
|
|
||
| std::cout << std::boolalpha | ||
| << " batchSize = " << batchSize << std::endl | ||
| << " generatorUpdateStep = " << generatorUpdateStep << std::endl | ||
| << " noiseDim = " << noiseDim << std::endl | ||
| << " numSamples = " << numSamples << std::endl | ||
| << " stepSize = " << stepSize << std::endl | ||
| << " numEpochs = " << numEpochs << std::endl | ||
| << " shuffle = " << shuffle << std::endl; | ||
|
|
||
swaingotnochill marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| arma::mat mnistDataset; | ||
| data::Load("/home/viole/Documents/datasets/digit-recognizer/train.csv", mnistDataset, true); | ||
|
|
||
| std::cout << arma::size(mnistDataset) << std::endl; | ||
swaingotnochill marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| mnistDataset = mnistDataset.submat(1, 0, mnistDataset.n_rows-1, mnistDataset.n_cols-1); | ||
| mnistDataset /= 255.0; | ||
|
|
||
| arma::mat trainDataset, valDataset; | ||
| data::Split(mnistDataset, trainDataset, valDataset, trainRatio); | ||
|
|
||
| std::cout << " Dataset Loaded " << std::endl; | ||
| std::cout << " Train Dataset Size : (" << trainDataset.n_rows << ", " << trainDataset.n_cols << ")" << std::endl; | ||
|
|
||
| std::cout << " Validation Dataset Size : (" << valDataset.n_rows << ", " << valDataset.n_cols << ")" << std::endl; | ||
|
|
||
| arma::mat trainTest, dump; | ||
| data::Split(trainDataset, dump, trainTest, 0.045); | ||
swaingotnochill marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| size_t iterPerCycle = (numEpochs * trainDataset.n_cols); | ||
|
|
||
| /** | ||
| * @brief Model Architecture: | ||
| * | ||
| * Discriminator: | ||
| * 28x28x1-----------> conv (32 filters of size 5x5, | ||
| * stride = 1, padding = 2)----------> 28x28x32 | ||
| * 28x28x32----------> ReLU -----------------------------> 28x28x32 | ||
| * 28x28x32----------> Mean pooling ---------------------> 14x14x32 | ||
| * 14x14x32----------> conv (64 filters of size 5x5, | ||
| * stride = 1, padding = 2)------> 14x14x64 | ||
| * 14x14x64----------> ReLU -----------------------------> 14x14x64 | ||
| * 14x14x64----------> Mean pooling ---------------------> 7x7x64 | ||
| * 7x7x64------------> Linear Layer ---------------------> 1024 | ||
| * 1024--------------> ReLU -----------------------------> 1024 | ||
| * 1024 -------------> Linear ---------------------------> 1 | ||
| * | ||
| * | ||
| * Generator: | ||
| * | ||
| * | ||
| * Note: Output of a Convolution layer = [(W-K+2P)/S + 1] | ||
| * where, W : Size of input volume | ||
| * K : Kernel size | ||
| * P : Padding | ||
| * S : Stride | ||
| */ | ||
|
|
||
swaingotnochill marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| // Creating the Discriminator network. | ||
| FFN<SigmoidCrossEntropyError<> > discriminator; | ||
swaingotnochill marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| discriminator.Add<Convolution<> >(1, // Number of input activation maps | ||
| dNumKernels, // Number of output activation maps | ||
| 5, // Filter width | ||
| 5, // Filter height | ||
| 1, // Stride along width | ||
| 1, // Stride along height | ||
| 2, // Padding width | ||
| 2, // Padding height | ||
| 28, // Input widht | ||
| 28); // Input height | ||
| // Adding first ReLU | ||
swaingotnochill marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| discriminator.Add<ReLULayer<> >(); | ||
| // Adding mean pooling layer | ||
| discriminator.Add<MeanPooling<> >(2, 2, 2, 2); | ||
| // Adding second convolution layer | ||
| discriminator.Add<Convolution<> >(dNumKernels, 2 * dNumKernels, 5, 5, 1, 1, | ||
| 2, 2, 14, 14); | ||
| // Adding second ReLU | ||
| discriminator.Add<ReLULayer<> >(); | ||
| // Adding second mean pooling layer | ||
| discriminator.Add<MeanPooling<> >(2, 2, 2, 2); | ||
| // Adding linear layer | ||
| discriminator.Add<Linear<> >(7 * 7 * 2 * dNumKernels, 1024); | ||
| // Adding third ReLU | ||
| discriminator.Add<ReLULayer<> >(); | ||
| // Adding final layer | ||
| discriminator.Add<Linear<> >(1024, 1); | ||
|
|
||
swaingotnochill marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| // Creating the Generator network | ||
| FFN<SigmoidCrossEntropyError<> > generator; | ||
| generator.Add<Linear<> >(noiseDim, 3136); | ||
| generator.Add<BatchNorm<> >(3136); | ||
| generator.Add<ReLULayer<> >(); | ||
| generator.Add<Convolution<> >(1, // Number of input activation maps | ||
| noiseDim / 2, // Number of output activation maps | ||
| 3, // Filter width | ||
| 3, // Filter height | ||
| 2, // Stride along width | ||
| 2, // Stride along height | ||
| 1, // Padding width | ||
| 1, // Padding height | ||
| 56, // input width | ||
| 56); // input height | ||
| // Adding first batch normalization layer | ||
| generator.Add<BatchNorm<> >(39200); | ||
| // Adding first ReLU | ||
| generator.Add<ReLULayer<> >(); | ||
| // Adding a bilinear interpolation layer | ||
| generator.Add<BilinearInterpolation<> >(28, 28, 56, 56, noiseDim / 2); | ||
| // Adding second convolution layer | ||
| generator.Add<Convolution<> >(noiseDim / 2, noiseDim / 4, 3, 3, 2, 2, 1, 1, | ||
| 56, 56); | ||
swaingotnochill marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // Adding second batch normalization layer | ||
| generator.Add<BatchNorm<> >(19600); | ||
| // Adding second ReLU | ||
| generator.Add<ReLULayer<> >(); | ||
| // Adding second bilinear interpolation layer | ||
| generator.Add<BilinearInterpolation<> >(28, 28, 56, 56, noiseDim / 4); | ||
| // Adding third convolution layer | ||
| generator.Add<Convolution<> >(noiseDim / 4, 1, 3, 3, 2, 2, 1, 1, 56, 56); | ||
| // Adding final tanh layer | ||
| generator.Add<TanHLayer<> >(); | ||
|
|
||
| // Creating GAN. | ||
| GaussianInitialization gaussian(0, 1); | ||
| ens::Adam optimizer(stepSize, // Step size of optimizer. | ||
| batchSize, // Batch size. | ||
| 0.9, // Exponential decay rate for first moment estimates. | ||
| 0.999, // Exponential decay rate for weighted norm estimates. | ||
| eps, // Value used to initialize the mean squared gradient parameter. | ||
| iterPerCycle, // Maximum number of iterations. | ||
| tolerance, // Tolerance. | ||
| shuffle); // Shuffle. | ||
| std::function<double()> noiseFunction = [] () { | ||
| return math::RandNormal(0, 1);}; | ||
swaingotnochill marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| GAN<FFN<SigmoidCrossEntropyError<> >, GaussianInitialization, | ||
| std::function<double()> > gan(generator, discriminator, | ||
| gaussian, noiseFunction, noiseDim, batchSize, generatorUpdateStep, | ||
| discriminatorPreTrain, multiplier); | ||
swaingotnochill marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| std::cout << "Training ... " << std::endl; | ||
|
|
||
| const clock_t beginTime = clock(); | ||
| // Cycles for monitoring training progress. | ||
| for( int i = 0; i < cycles; i++) | ||
swaingotnochill marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| { | ||
| // Training the neural network. For first iteration, weights are random, | ||
| // thus using current values as starting point. | ||
| gan.Train(trainDataset, | ||
| optimizer, | ||
| ens::PrintLoss(), | ||
| ens::ProgressBar(), | ||
| ens::Report()); | ||
|
|
||
| optimizer.ResetPolicy() = false; | ||
| std::cout << " Model Performance " << | ||
| gan.Evaluate(gan.Parameters(), // Parameters of the network. | ||
| i, // Index of current input. | ||
| batchSize); // Batch size. | ||
|
||
| } | ||
|
||
|
|
||
| std::cout << " Time taken to train -> " << float(clock()-beginTime) / CLOCKS_PER_SEC << "seconds" << std::endl; | ||
|
|
||
| data::Save("./saved_models/ganMnist.bin", "ganMnist", gan); | ||
| std::cout << "Model saved in mnist_gan/saved_models." << std::endl; | ||
|
|
||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.