diff --git a/dl4j-gan-examples/README.md b/dl4j-gan-examples/README.md
new file mode 100644
index 0000000000..0143fef903
--- /dev/null
+++ b/dl4j-gan-examples/README.md
@@ -0,0 +1,8 @@
+An example of a simple GAN implemented with DL4J.
+
+```
+         *****        ********           *****************
+ z ----  * G * ------ * G(z) * --------- * discriminator * ---- fake
+         *****        ********           *               *
+ x ------------------------------------- ***************** ---- real
+```
diff --git a/dl4j-gan-examples/pom.xml b/dl4j-gan-examples/pom.xml new file mode 100644 index 0000000000..1b3f4bb647 --- /dev/null +++ b/dl4j-gan-examples/pom.xml @@ -0,0 +1,113 @@ + + + 4.0.0 + + org.deeplearning4j + dl4j-gan-examples + 1.0.0-SNAPSHOT + + + + 1.0.0-beta7 + 1.2.3 + 1.8 + 2.4.3 + UTF-8 + + + + + + org.deeplearning4j + deeplearning4j-core + ${dl4j-master.version} + + + org.nd4j + nd4j-native + ${dl4j-master.version} + + + ch.qos.logback + logback-classic + ${logback.version} + + + junit + junit + 4.12 + compile + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.5.1 + + ${java.version} + ${java.version} + + + + + + + org.apache.maven.plugins + maven-shade-plugin + ${maven-shade-plugin.version} + + true + bin + true + + + *:* + + org/datanucleus/** + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + package + + shade + + + + + reference.conf + + + + + + + + + + + +
diff --git a/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/MNISTVisualizer.java b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/MNISTVisualizer.java
new file mode 100644
index 0000000000..04f57c5537
--- /dev/null
+++ b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/MNISTVisualizer.java
@@ -0,0 +1,97 @@
+package org.deeplearning4j.ganexamples;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+import javax.swing.*;
+import java.awt.*;
+import java.awt.image.BufferedImage;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Swing helper that renders a batch of 28x28 grayscale images as a grid of labels in a window.
+ *
+ * @author zdl
+ */
+public class MNISTVisualizer {
+    /** Width and height, in pixels, of one source image. */
+    private static final int IMAGE_SIDE = 28;
+
+    private final double imageScale;
+    private final String title;
+    private final int gridWidth;
+    private List<INDArray> digits;
+    private JFrame frame;
+
+    /** Creates a visualizer that places five images in each row. */
+    public MNISTVisualizer(double imageScale, String title) {
+        this(imageScale, title, 5);
+    }
+
+    /**
+     * @param imageScale factor by which each image is enlarged when drawn
+     * @param title      title of the window
+     * @param gridWidth  number of images per row
+     */
+    public MNISTVisualizer(double imageScale, String title, int gridWidth) {
+        this.imageScale = imageScale;
+        this.title = title;
+        this.gridWidth = gridWidth;
+    }
+
+    /** Disposes the previously opened window, if any, and opens a new one showing the current digits. */
+    public void visualize() {
+        if (frame != null) {
+            frame.dispose();
+        }
+        frame = new JFrame();
+        frame.setTitle(title);
+        frame.setSize(800, 600);
+        JPanel panel = new JPanel();
+        panel.setPreferredSize(new Dimension(800, 600));
+        panel.setLayout(new GridLayout(0, gridWidth));
+        for (JLabel image : getComponents()) {
+            panel.add(image);
+        }
+        frame.add(panel);
+        // Lay the window out before showing it so it does not appear at one size and then resize.
+        frame.pack();
+        frame.setVisible(true);
+    }
+
+    /**
+     * Converts every digit into a scaled grayscale image wrapped in a label.
+     *
+     * @return one label per digit, in order; empty when no digits have been set
+     */
+    public List<JLabel> getComponents() {
+        List<JLabel> images = new ArrayList<>();
+        if (digits == null) {
+            return images;
+        }
+        int scaledSide = (int) (imageScale * IMAGE_SIDE);
+        for (INDArray arr : digits) {
+            BufferedImage bi = new BufferedImage(IMAGE_SIDE, IMAGE_SIDE, BufferedImage.TYPE_BYTE_GRAY);
+            for (int i = 0; i < IMAGE_SIDE * IMAGE_SIDE; i++) {
+                // Clamp to 0..255, the range of an 8-bit gray sample; inputs are not guaranteed to lie in [0, 1].
+                int gray = (int) Math.max(0, Math.min(255, 255 * arr.getDouble(i)));
+                bi.getRaster().setSample(i % IMAGE_SIDE, i / IMAGE_SIDE, 0, gray);
+            }
+            Image imageScaled = new ImageIcon(bi).getImage()
+                    .getScaledInstance(scaledSide, scaledSide, Image.SCALE_DEFAULT);
+            images.add(new JLabel(new ImageIcon(imageScaled)));
+        }
+        return images;
+    }
+
+    /** @return the digits that the next call to visualize will draw, or null if none were set */
+    public List<INDArray> getDigits() {
+        return digits;
+    }
+
+    /** @param digits images to draw; the first 784 values of each are read in row-major order */
+    public void setDigits(List<INDArray> digits) {
+        this.digits = digits;
+    }
+
+}
diff --git a/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java
new file mode 100644
index 0000000000..9452e469b1
--- /dev/null
+++ b/dl4j-gan-examples/src/main/java/org/deeplearning4j/ganexamples/SimpleGan.java
@@ -0,0 +1,115 @@
+package org.deeplearning4j.ganexamples;
+
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import 
org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.RmsProp;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Minimal GAN on MNIST: a generator G maps noise z to an image and a discriminator scores real vs. fake.
+ *
+ *          *****        ********           *****************
+ *  z ----  * G * ------ * G(z) * --------- * discriminator * ---- fake
+ *          *****        ********           *               *
+ *  x ------------------------------------- ***************** ---- real
+ *
+ * @author zdl
+ */
+public class SimpleGan {
+
+    private static final Logger log = LoggerFactory.getLogger(SimpleGan.class);
+    private static final int BATCH_SIZE = 30; // examples per real batch and per generated batch
+    private static final int LATENT_DIM = 20; // length of the noise vector z
+    private static final int GENERATOR_LAYERS = 3; // generator depth; also the feedForward index of G(z)
+
+    /** Builds both networks, then alternates discriminator and generator updates on MNIST. */
+    public static void main(String[] args) throws Exception {
+        // Discriminator: 784 -> 512 -> 256 -> 128 -> 1, sigmoid output trained with binary cross-entropy.
+        MultiLayerConfiguration discriminatorConf = new NeuralNetConfiguration.Builder().seed(12345)
+                .weightInit(WeightInit.XAVIER).updater(new RmsProp(0.001)).list()
+                .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(512).activation(Activation.RELU).build())
+                .layer(1, new DenseLayer.Builder().nIn(512).nOut(256).activation(Activation.RELU).build())
+                .layer(2, new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU).build())
+                .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
+                        .activation(Activation.SIGMOID).nIn(128).nOut(1).build()).build();
+
+        // GAN stack: trainable generator (layers 0-2) followed by a frozen copy of the discriminator (3-6).
+        MultiLayerConfiguration ganConf = new NeuralNetConfiguration.Builder().seed(12345)
+                .weightInit(WeightInit.XAVIER).updater(new RmsProp(0.001)).list()
+                .layer(0, new DenseLayer.Builder().nIn(LATENT_DIM).nOut(256).activation(Activation.RELU).build())
+                .layer(1, new DenseLayer.Builder().nIn(256).nOut(512).activation(Activation.RELU).build())
+                .layer(2, new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU).build())
+                .layer(3, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(28 * 28).nOut(512).activation(Activation.RELU).build()))
+                .layer(4, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(512).nOut(256).activation(Activation.RELU).build()))
+                .layer(5, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU).build()))
+                .layer(6, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
+                        .activation(Activation.SIGMOID).nIn(128).nOut(1).build())).build();
+
+        MultiLayerNetwork discriminatorNetwork = new MultiLayerNetwork(discriminatorConf);
+        discriminatorNetwork.init();
+        discriminatorNetwork.setListeners(new ScoreIterationListener(1));
+        log.info(discriminatorNetwork.summary());
+
+        MultiLayerNetwork ganNetwork = new MultiLayerNetwork(ganConf);
+        ganNetwork.init();
+        ganNetwork.setListeners(new ScoreIterationListener(1));
+        log.info(ganNetwork.summary());
+
+        DataSetIterator train = new MnistDataSetIterator(BATCH_SIZE, true, 12345);
+        // Discriminator targets: 1 for the real half of the stacked batch, 0 for the generated half.
+        INDArray labelD = Nd4j.vstack(Nd4j.ones(BATCH_SIZE, 1), Nd4j.zeros(BATCH_SIZE, 1));
+        INDArray labelG = Nd4j.ones(BATCH_SIZE, 1);
+        MNISTVisualizer mnistVisualizer = new MNISTVisualizer(1, "Gan");
+        for (int i = 1; i <= 100000; i++) {
+            if (!train.hasNext()) {
+                train.reset();
+            }
+            INDArray trueImage = train.next().getFeatures();
+            INDArray z = Nd4j.rand(new NormalDistribution(), new long[]{BATCH_SIZE, LATENT_DIM});
+            INDArray fakeImage = ganNetwork.feedForward(z, false).get(GENERATOR_LAYERS);
+            // Train the discriminator on real + fake, then mirror its new weights into the frozen layers.
+            discriminatorNetwork.fit(Nd4j.vstack(trueImage, fakeImage), labelD);
+            copyDiscriminatorParam(discriminatorNetwork, ganNetwork);
+            // Train the generator: it is rewarded when the frozen discriminator answers "real".
+            ganNetwork.fit(z, labelG);
+            if (i % 1000 == 0) {
+                INDArray preview = Nd4j.rand(new NormalDistribution(), new long[]{BATCH_SIZE, LATENT_DIM});
+                INDArray samples = ganNetwork.feedForward(preview, false).get(GENERATOR_LAYERS);
+                List<INDArray> digits = new ArrayList<>();
+                for (int j = 0; j < samples.size(0); j++) {
+                    digits.add(samples.getRow(j));
+                }
+                mnistVisualizer.setDigits(digits);
+                mnistVisualizer.visualize();
+            }
+        }
+    }
+
+    /**
+     * Copies every discriminator layer's parameters into the trailing layers of the GAN network.
+     * @param discriminatorNetwork stand-alone discriminator whose parameters are read
+     * @param ganNetwork stacked network whose last layers mirror the discriminator and are overwritten
+     */
+    public static void copyDiscriminatorParam(MultiLayerNetwork discriminatorNetwork, MultiLayerNetwork ganNetwork) {
+        int offset = ganNetwork.getnLayers() - discriminatorNetwork.getnLayers();
+        for (int i = 0; i < discriminatorNetwork.getnLayers(); i++) {
+            ganNetwork.getLayer(offset + i).setParams(discriminatorNetwork.getLayer(i).params());
+        }
+    }
+}