package com.codingapi.deeplearning.demo04.learn;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

// Overview of the Nd4j operations used in this class:

// INDArray.mmul(INDArray)   matrix multiplication: (2,3) x (3,2) => (2,2); shapes must satisfy (A,B) x (B,C) => (A,C)
// INDArray.mul(INDArray)    element-wise multiplication: (2,3) * (2,3) => (2,3); both operands must have the same shape
// Transforms.sigmoid        the sigmoid function
// INDArray.broadcast        repeats a matrix's data to a larger shape, e.g. broadcast(4,3) turns a (1,3) row into a (4,3) matrix whose rows all equal that row
// INDArray.add              element-wise addition, shapes must match: (A,B) + (A,B); add also accepts a scalar, which is added to every element
// INDArray.transpose        matrix transpose
// INDArray.rsub             reverse subtraction, e.g. [[2,3]].rsub(1) => [[1-2,1-3]]
// INDArray.sub              element-wise subtraction, analogous to add
// Nd4j.sum                  sums along a dimension: 0 sums over rows (one value per column), 1 sums over columns (one value per row); with no dimension, sums all elements
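//
// A minimal sketch (hypothetical values, not part of the original code) exercising the ops above:
//
//   INDArray a = Nd4j.rand(2, 3);              // (2,3)
//   INDArray b = Nd4j.rand(3, 2);              // (3,2)
//   INDArray p = a.mmul(b);                    // matrix product          -> (2,2)
//   INDArray q = a.mul(a);                     // element-wise square     -> (2,3)
//   INDArray s = Transforms.sigmoid(a);        // sigmoid of each element -> (2,3)
//   INDArray row = Nd4j.rand(1, 3);            // (1,3)
//   INDArray r = a.add(row.broadcast(2, 3));   // expand row, then add    -> (2,3)
//   INDArray t = a.rsub(1);                    // 1 - a, element-wise     -> (2,3)
//   INDArray colSums = Nd4j.sum(a, 0);         // sum over rows           -> (1,3)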

/**
 *
 * @author lorne
 * @date 2019-10-31
 * @description Backpropagation implementation for a simple neural network
 */
@Slf4j
public class BackPropagationFunction {

    // weights and biases of the hidden layer (w1, b1) and output layer (w2, b2)
    private INDArray w1;
    private INDArray b1;
    private INDArray w2;
    private INDArray b2;
    // parameter gradients
    private INDArray dw1, dw2, db1, db2;

    // error terms of the hidden layer (delta1) and output layer (delta2)
    private INDArray delta1;
    private INDArray delta2;

    private double lambda;  // L2 regularization strength
    private double alpha;   // learning rate

    private int batch;      // number of training iterations

    public BackPropagationFunction(double lambda, double alpha, int batch, int inputs) {
        this.lambda = lambda;
        this.alpha = alpha;
        this.batch = batch;

        // network layout: 2 input features -> `inputs` hidden units -> 1 output
        w1 = Nd4j.rand(2, inputs);
        b1 = Nd4j.rand(1, inputs);

        w2 = Nd4j.rand(inputs, 1);
        b2 = Nd4j.rand(1, 1);
    }

    /**
     * Training loop using backpropagation.
     * @param dataSet the training data
     * Illustrates a simple two-layer network (one hidden layer, one output unit).
     */
    public void train(DataSet dataSet) {

        log.info("x:shape->{},y:shape->{},w1:shape->{}," +
                        "b1:shape->{},w2:shape->{},b2:shape->{}",
                dataSet.getX().shape(), dataSet.getY().shape(), w1.shape(),
                b1.shape(), w2.shape(), b2.shape());

        // m: number of training samples
        int m = dataSet.getX().rows();

        for (int i = 0; i < batch; i++) {
            // forward pass
            // hidden layer: z1 = X * w1 + b1, a1 = sigmoid(z1)
            INDArray z1 = dataSet.getX().mmul(w1).add(b1.broadcast(m, b1.columns()));
            INDArray a1 = Transforms.sigmoid(z1);

            // output layer: z2 = a1 * w2 + b2, a2 = sigmoid(z2)
            INDArray z2 = a1.mmul(w2).add(b2.broadcast(m, b2.columns()));
            INDArray a2 = Transforms.sigmoid(z2);

            // backward pass
            // output-layer error: delta2 = a2 - y; this sign makes the subtractive update below descend the gradient
            delta2 = a2.sub(dataSet.getY());
            // hidden-layer error: (delta2 * w2^T) .* sigmoid'(z1), with sigmoid'(z1) = a1 .* (1 - a1)
            delta1 = delta2.mmul(w2.transpose()).mul(a1.mul(a1.rsub(1)));
            // gradients with L2 regularization (summed over the batch; the 1/m factor is folded into alpha)
            dw1 = dataSet.getX().transpose().mmul(delta1).add(w1.mul(lambda));
            db1 = Nd4j.sum(delta1, 0);  // sum the errors over all samples (dimension 0)
            dw2 = a1.transpose().mmul(delta2).add(w2.mul(lambda));
            db2 = Nd4j.sum(delta2, 0);

            // gradient-descent update
            w1 = w1.sub(dw1.mul(alpha));
            b1 = b1.sub(db1.mul(alpha));
            w2 = w2.sub(dw2.mul(alpha));
            b2 = b2.sub(db2.mul(alpha));
        }

        log.info("w1:->\n{}", w1);
        log.info("b1:->\n{}", b1);
        log.info("w2:->\n{}", w2);
        log.info("b2:->\n{}", b2);
    }
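
    /**
     * A minimal inference sketch, not part of the original code: a single forward
     * pass through the trained network, mirroring the FP step in train(). Assumes
     * x has the same column count as the training inputs (2 features here).
     */
    public INDArray predict(INDArray x) {
        int m = x.rows();
        // hidden layer
        INDArray a1 = Transforms.sigmoid(x.mmul(w1).add(b1.broadcast(m, b1.columns())));
        // output layer
        return Transforms.sigmoid(a1.mmul(w2).add(b2.broadcast(m, b2.columns())));
    }

    // Usage sketch (hypothetical values; DataSet is assumed to expose X:(m,2) and Y:(m,1)):
    //   BackPropagationFunction fn = new BackPropagationFunction(0.01, 0.1, 1000, 3);
    //   fn.train(dataSet);
    //   INDArray prediction = fn.predict(dataSet.getX());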

}