diff --git a/dl4j-gan-examples/README.md b/dl4j-gan-examples/README.md
new file mode 100644
index 0000000000..0143fef903
--- /dev/null
+++ b/dl4j-gan-examples/README.md
@@ -0,0 +1,6 @@
+An example of a simple gan implemented with DL4J
+
+ ***** ******** *****************
+ z ---- * G *----* G(z) * ------ * discriminator * ---- fake
+ ***** ******** * *
+ x ----------------------------- ***************** ---- real
diff --git a/dl4j-gan-examples/pom.xml b/dl4j-gan-examples/pom.xml
new file mode 100644
index 0000000000..1b3f4bb647
--- /dev/null
+++ b/dl4j-gan-examples/pom.xml
@@ -0,0 +1,113 @@
+
+
+ 4.0.0
+
+ org.deeplearning4j
+ dl4j-gan-examples
+ 1.0.0-SNAPSHOT
+
+
+
+ 1.0.0-beta7
+ 1.2.3
+ 1.8
+ 2.4.3
+ UTF-8
+
+
+
+
+
+ org.deeplearning4j
+ deeplearning4j-core
+ ${dl4j-master.version}
+
+
+ org.nd4j
+ nd4j-native
+ ${dl4j-master.version}
+
+
+ ch.qos.logback
+ logback-classic
+ ${logback.version}
+
+
+ junit
+ junit
+ 4.12
+ compile
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+ 3.5.1
+
+ ${java.version}
+ ${java.version}
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+ ${maven-shade-plugin.version}
+
+ true
+ bin
+ true
+
+
+ *:*
+
+ org/datanucleus/**
+ META-INF/*.SF
+ META-INF/*.DSA
+ META-INF/*.RSA
+
+
+
+
+
+
+
+ package
+
+ shade
+
+
+
+
+ reference.conf
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/MNISTVisualizer.java b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/MNISTVisualizer.java
new file mode 100644
index 0000000000..04f57c5537
--- /dev/null
+++ b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/MNISTVisualizer.java
@@ -0,0 +1,75 @@
+package org.deeplearning4j.ganexamples;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+import javax.swing.*;
+import java.awt.*;
+import java.awt.image.BufferedImage;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * @author zdl
+ */
+public class MNISTVisualizer {
+ private double imageScale;
+ private List digits;
+ private String title;
+ private int gridWidth;
+ private JFrame frame;
+
+ public MNISTVisualizer(double imageScale, String title) {
+ this(imageScale, title, 5);
+ }
+
+ public MNISTVisualizer(double imageScale, String title, int gridWidth) {
+ this.imageScale = imageScale;
+ this.title = title;
+ this.gridWidth = gridWidth;
+ }
+
+ public void visualize() {
+ if (null != frame) {
+ frame.dispose();
+ }
+ frame = new JFrame();
+ frame.setTitle(title);
+ frame.setSize(800, 600);
+ JPanel panel = new JPanel();
+ panel.setPreferredSize(new Dimension(800, 600));
+ panel.setLayout(new GridLayout(0, gridWidth));
+ List list = getComponents();
+ for (JLabel image : list) {
+ panel.add(image);
+ }
+
+ frame.add(panel);
+ frame.setVisible(true);
+ frame.pack();
+ }
+
+ public List getComponents() {
+ List images = new ArrayList();
+ for (INDArray arr : digits) {
+ BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
+ for (int i = 0; i < 784; i++) {
+ bi.getRaster().setSample(i % 28, i / 28, 0, (int) (255 * arr.getDouble(i)));
+ }
+ ImageIcon orig = new ImageIcon(bi);
+ Image imageScaled = orig.getImage().getScaledInstance((int) (imageScale * 28), (int) (imageScale * 28),
+ Image.SCALE_DEFAULT);
+ ImageIcon scaled = new ImageIcon(imageScaled);
+ images.add(new JLabel(scaled));
+ }
+ return images;
+ }
+
+ public List getDigits() {
+ return digits;
+ }
+
+ public void setDigits(List digits) {
+ this.digits = digits;
+ }
+
+}
diff --git a/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java
new file mode 100644
index 0000000000..9452e469b1
--- /dev/null
+++ b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java
@@ -0,0 +1,115 @@
+package org.deeplearning4j.ganexamples;
+
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.RmsProp;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * ***** ******** *****************
+ * z ---- * G *----* G(z) * ------ * discriminator * ---- fake
+ * ***** ******** * *
+ * x ----------------------------- ***************** ---- real
+ *
+ * @author zdl
+ */
+public class SimpleGan {
+
+ public static void main(String[] args) throws Exception {
+
+ /**
+ *Build the discriminator
+ */
+ MultiLayerConfiguration discriminatorConf = new NeuralNetConfiguration.Builder().seed(12345)
+ .weightInit(WeightInit.XAVIER).updater(new RmsProp(0.001))
+ .list()
+ .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(512).activation(Activation.RELU).build())
+ .layer(1, new DenseLayer.Builder().activation(Activation.RELU)
+ .nIn(512).nOut(256).build())
+ .layer(2, new DenseLayer.Builder().activation(Activation.RELU)
+ .nIn(256).nOut(128).build())
+ .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
+ .activation(Activation.SIGMOID).nIn(128).nOut(1).build()).build();
+
+
+ MultiLayerConfiguration ganConf = new NeuralNetConfiguration.Builder().seed(12345)
+ .weightInit(WeightInit.XAVIER)
+ //generator
+ .updater(new RmsProp(0.001)).list()
+ .layer(0, new DenseLayer.Builder().nIn(20).nOut(256).activation(Activation.RELU).build())
+ .layer(1, new DenseLayer.Builder().activation(Activation.RELU)
+ .nIn(256).nOut(512).build())
+ .layer(2, new DenseLayer.Builder().activation(Activation.RELU)
+ .nIn(512).nOut(28 * 28).build())
+ //Freeze the discriminator parameter
+ .layer(3, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(28 * 28).nOut(512).activation(Activation.RELU).build()))
+ .layer(4, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(512).nOut(256).activation(Activation.RELU).build()))
+ .layer(5, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU).build()))
+ .layer(6, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
+ .activation(Activation.SIGMOID).nIn(128).nOut(1).build())).build();
+
+
+ MultiLayerNetwork discriminatorNetwork = new MultiLayerNetwork(discriminatorConf);
+ discriminatorNetwork.init();
+ System.out.println(discriminatorNetwork.summary());
+ discriminatorNetwork.setListeners(new ScoreIterationListener(1));
+
+ MultiLayerNetwork ganNetwork = new MultiLayerNetwork(ganConf);
+ ganNetwork.init();
+ ganNetwork.setListeners(new ScoreIterationListener(1));
+ System.out.println(ganNetwork.summary());
+
+ DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
+
+ INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));
+ INDArray labelG = Nd4j.ones(30, 1);
+ MNISTVisualizer mnistVisualizer = new MNISTVisualizer(1, "Gan");
+ for (int i = 1; i <= 100000; i++) {
+ if (!train.hasNext()) {
+ train.reset();
+ }
+ INDArray trueImage = train.next().getFeatures();
+ INDArray z = Nd4j.rand(new NormalDistribution(), new long[]{30, 20});
+ List ganFeedForward = ganNetwork.feedForward(z, false);
+ INDArray fakeImage = ganFeedForward.get(3);
+ INDArray trainDiscriminatorFeatures = Nd4j.vstack(trueImage, fakeImage);
+ //Training discriminator
+ discriminatorNetwork.fit(trainDiscriminatorFeatures, labelD);
+ copyDiscriminatorParam(discriminatorNetwork, ganNetwork);
+ //Training generator
+ ganNetwork.fit(z, labelG);
+ if (i % 1000 == 0) {
+ List indArrays = ganNetwork.feedForward(Nd4j.rand(new NormalDistribution(), new long[]{30, 20}), false);
+ List list = new ArrayList<>();
+ INDArray indArray = indArrays.get(3);
+ for (int j = 0; j < indArray.size(0); j++) {
+ list.add(indArray.getRow(j));
+ }
+ mnistVisualizer.setDigits(list);
+ mnistVisualizer.visualize();
+ }
+ }
+ }
+
+ public static void copyDiscriminatorParam(MultiLayerNetwork discriminatorNetwork, MultiLayerNetwork ganNetwork) {
+ for (int i = 0; i <= 3; i++) {
+ ganNetwork.getLayer(i + 3).setParams(discriminatorNetwork.getLayer(i).params());
+ }
+ }
+}