Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mnist_gan/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.ipynb_checkpoints
saved_models
samples_posterior
samples_csv_files
mnist_first250_training_4s_and_9s.arm
mnist_gan
mnist_gan_generate
mnist_gan_generate.o
mnist_gan.o
36 changes: 36 additions & 0 deletions mnist_gan/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

TARGET := mnist_gan_generate
SRC := mnist_gan_generate.cpp
LIBS_NAME := armadillo mlpack

CXX := g++
CXXFLAGS += -std=c++11 -Wall -Wextra -O3 -DNDEBUG
# Use these CXXFLAGS instead if you want to compile with debugging symbols and
# without optimizations.
# CXXFLAGS += -std=c++11 -Wall -Wextra -g -O0
LDFLAGS += -fopenmp
LDFLAGS += -lboost_serialization
LDFLAGS += -larmadillo
LDFLAGS += -L /home/viole/mlpack/build/lib/ # /path/to/mlpack/library/ # if installed locally.
# Add header directories for any includes that aren't on the
# default compiler search path.
INCLFLAGS := -I /home/viole/mlpac/build/include/
CXXFLAGS += $(INCLFLAGS)

OBJS := $(SRC:.cpp=.o)
LIBS := $(addprefix -l,$(LIBS_NAME))
CLEAN_LIST := $(TARGET) $(OBJS)

# default rule
default: all

$(TARGET): $(OBJS)
$(CXX) $(CXXFLAGS) $(OBJS) -o $(TARGET) $(LDFLAGS) $(LIBS)

.PHONY: all
all: $(TARGET)

.PHONY: clean
clean:
@echo CLEAN $(CLEAN_LIST)
@rm -f $(CLEAN_LIST)
42 changes: 42 additions & 0 deletions mnist_gan/gan_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/**
* @file gan_utils.cpp
* @author Roshan Swain
* @author Atharva Khandait
*
* Utility function necessary for working with GAN models.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MODELS_GAN_UTILS_HPP
#define MODELS_GAN_UTILS_HPP

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>

using namespace mlpack;
using namespace mlpack::ann;

// Sample from the output distribution and post-process the outputs(because
// we pre-processed it before passing it to the model).
template<typename DataType = arma::mat>
void GetSample(DataType &input, DataType& samples, bool isBinary)
{
if (isBinary)
{
samples = arma::conv_to<DataType>::from(
arma::randu<DataType>(input.n_rows, input.n_cols) <= input);
samples *= 255;
}
else
{
samples = input / 2 + 0.5;
samples *= 255;
samples = arma::clamp(samples, 0, 255);
}
}

#endif
91 changes: 91 additions & 0 deletions mnist_gan/generate_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
@file generate_images.py
@author Atharva Khandait
Generates jpg files from csv.
mlpack is free software; you may redistribute it and/or modify it under the
terms of the 3-clause BSD license. You should have received a copy of the
3-clause BSD license along with mlpack. If not, see
http://www.opensource.org/licenses/BSD-3-Clause for more information.
"""

from PIL import Image
import numpy as np
import cv2
import os

def ImagesFromCSV(filename,
imgShape = (28, 28),
destination = 'samples',
saveIndividual = False):

# Import the data into a numpy matrix.
samples = np.genfromtxt(filename, delimiter = ',', dtype = np.uint8)

# Reshape and save it as an image in the destination.
tempImage = Image.fromarray(np.reshape(samples[:, 0], imgShape), 'L')
if saveIndividual:
tempImage.save(destination + '/sample0.jpg')

# All the images will be concatenated to this for a combined image.
allSamples = tempImage

for i in range(1, samples.shape[1]):
tempImage = np.reshape(samples[:, i], imgShape)

allSamples = np.concatenate((allSamples, tempImage), axis = 1)

tempImage = Image.fromarray(tempImage, 'L')
if saveIndividual:
tempImage.save(destination + '/sample' + str(i) + '.jpg')

tempImage = allSamples
allSamples = Image.fromarray(allSamples, 'L')
allSamples.save(destination + '/allSamples' + '.jpg')

print ('Samples saved in ' + destination + '/.')

return tempImage

# Save posterior samples.
ImagesFromCSV('./samples_csv_files/samples_posterior.csv', destination =
'samples_posterior')

# Save prior samples with individual latent varying.
latentSize = 10
allLatent = ImagesFromCSV('./samples_csv_files/samples_prior_latent0.csv',
destination = 'samples_prior')

for i in range(1, latentSize):
allLatent = np.concatenate((allLatent,
(ImagesFromCSV('./samples_csv_files/samples_prior_latent' + str(i) + '.csv',
destination = 'samples_prior'))), axis = 0)

saved = Image.fromarray(allLatent, 'L')
saved.save('./samples_prior/allLatent.jpg')

# Save prior samples with 2d latent varying.
nofSamples = 20
allLatent = ImagesFromCSV('./samples_csv_files/samples_prior_latent_2d0.csv',
destination = 'latent')

for i in range(1, nofSamples):
allLatent = np.concatenate((allLatent,
(ImagesFromCSV('./samples_csv_files/samples_prior_latent_2d' + str(i) +
'.csv', destination = 'samples_prior'))), axis = 0)

saved = Image.fromarray(allLatent, 'L')
saved.save('./samples_prior/2dLatent.jpg')

# AVI file
vid_fname = 'gans_celebface_training1.avi'
sample_dir = " "

files = [os.path.join(sample_dir, f) for f in os.listdir(sample_dir) if 'generated' in f]
files.sort()

out = cv2.VideoWriter(vid_fname, cv2.VideoWriter_fourcc(*'MP4V'), 1, (530, 530))
[out.write(cv2.imread(fname)) for fname in files]
out.release()


###Output###
Loading