Skip to content

Commit c9fe0ba

Browse files
committed
demo04 ...
1 parent 36bcfd7 commit c9fe0ba

File tree

9 files changed

+259
-0
lines changed

9 files changed

+259
-0
lines changed

demo04/init/lr_data.csv

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2,1,0
2+
2,2,0
3+
5,4,1
4+
4,5,1
5+
2,3,0
6+
3,2,0
7+
6,5,1
8+
4,1,0
9+
6,3,1
10+
7,4,1

demo04/pom.xml

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5+
<parent>
6+
<artifactId>deeplearning-java</artifactId>
7+
<groupId>com.codingapi.deepleaning</groupId>
8+
<version>0.0.1-SNAPSHOT</version>
9+
</parent>
10+
<modelVersion>4.0.0</modelVersion>
11+
12+
<artifactId>demo04</artifactId>
13+
14+
15+
<dependencies>
16+
<dependency>
17+
<groupId>org.nd4j</groupId>
18+
<artifactId>nd4j-native-platform</artifactId>
19+
<version>1.0.0-beta4</version>
20+
</dependency>
21+
</dependencies>
22+
23+
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package com.codingapi.deeplearning.demo04;
2+
3+
import org.springframework.boot.SpringApplication;
4+
import org.springframework.boot.autoconfigure.SpringBootApplication;
5+
6+
/**
7+
* @author lorne
8+
* @date 2019-10-31
9+
* @description
10+
*/
11+
@SpringBootApplication
12+
public class DeepLearningJavaDemo04Application {
13+
14+
public static void main(String[] args) {
15+
SpringApplication.run(DeepLearningJavaDemo04Application.class,args);
16+
}
17+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package com.codingapi.deeplearning.demo04.learn;
2+
3+
import lombok.extern.slf4j.Slf4j;
4+
import org.nd4j.linalg.api.ndarray.INDArray;
5+
import org.nd4j.linalg.factory.Nd4j;
6+
import org.nd4j.linalg.ops.transforms.Transforms;
7+
8+
9+
//关于运算Nd4j下的运算符号介绍:
10+
11+
//INDArray mmul INDArray 是矩阵之间的相乘 (2,3)*(3,2) => (2,2),两个矩阵需要满足矩阵相乘的形状 (A,B) * (B,C) => (A,C)
12+
//INDArray mul INDArray 是矩阵中每个值相乘 (2,3)*(2,3) => (2,3),两个矩阵必须是相同的shape
13+
//Transforms.sigmoid sigmoid函数
14+
//INDArray broadcast 是将矩阵的数据做扩充,例如(1,3)的数据通过 broadcast(4,3)就变成一个(4,3)的矩阵,数据都与第一行相同
15+
//INDArray add 是矩阵的加分运算,必须是(A,B)+ (A,B)形状的。add方法也可以加 一个数字,就是对所有的项都加该数组
16+
//INDArray transpose 是矩阵的转置运算
17+
//INDArray rsub 是对矩阵的逆减法,例如 [[2,3]].rsub(1) = >[[1-2,1-3]]
18+
//INDArray sub 是矩阵的减法运算,方式同加分运算
19+
//Nd4j.sum 是 对矩阵的row,columns是求和,设置为0是对所有的columns求和,1是对columns求和,不设置是全部求和
20+
21+
/**
22+
*
23+
* @author lorne
24+
* @date 2019-10-31
25+
* @description 神经网络反向传递算法实现
26+
*/
27+
@Slf4j
28+
public class BackPropagationFunction {
29+
30+
private INDArray w1;
31+
private INDArray b1;
32+
private INDArray w2;
33+
private INDArray b2;
34+
private INDArray dw1,dw2,db1,db2;
35+
36+
37+
private INDArray delta1;
38+
private INDArray delta2;
39+
40+
private double lambda;
41+
private double alpha;
42+
43+
private int batch;
44+
45+
46+
public BackPropagationFunction(double lambda, double alpha, int batch,int inputs) {
47+
this.lambda = lambda;
48+
this.alpha = alpha;
49+
this.batch = batch;
50+
51+
w1 = Nd4j.rand(2,inputs);
52+
b1 = Nd4j.rand(1,inputs);
53+
54+
w2 = Nd4j.rand(inputs,1);
55+
b2 = Nd4j.rand(1,1);
56+
57+
58+
}
59+
60+
/**
61+
* 反向传播的训练过程
62+
* @param dataSet 数据集
63+
* 假如: 显示一个简单的2层网络
64+
*
65+
*/
66+
public void train(DataSet dataSet){
67+
68+
log.info("x:shape->{},y:shape->{},w1:shape->{}," +
69+
"b1:shape->{},w2:shape->{},b2:shape->{}"
70+
,dataSet.getX().shape(),dataSet.getY().shape(),w1.shape(),
71+
b1.shape(),w2.shape(),b2.shape());
72+
73+
int m = dataSet.getX().rows();
74+
75+
76+
for(int i=0;i<batch;i++) {
77+
//FP
78+
INDArray z1 = dataSet.getX().mmul(w1).add(b1.broadcast(m,b1.columns()));
79+
INDArray a1 = Transforms.sigmoid(z1);
80+
81+
INDArray z2 = a1.mmul(w2).add(b2.broadcast(m,b2.columns()));
82+
INDArray a2 = Transforms.sigmoid(z2);
83+
84+
//BP
85+
delta2 = dataSet.getY().sub(a2);
86+
delta1 = delta2.mmul(w2.transpose()).mul(a1.mul(a1.rsub(1)));
87+
88+
dw1 = dataSet.getX().transpose().mmul(delta1).add(w1.mul(lambda));
89+
db1 = Nd4j.sum(delta1, 0);
90+
dw2 = a1.transpose().mmul(delta2).add(w2.mul(lambda));
91+
db2 = Nd4j.sum(delta2, 0);
92+
93+
w1 = w1.sub(dw1.mul(alpha));
94+
b1 = b1.sub(db1.mul(alpha));
95+
w2 = w2.sub(dw2.mul(alpha));
96+
b2 = b2.sub(db2.mul(alpha));
97+
}
98+
99+
log.info("w1:->\n{}",w1);
100+
log.info("b1:->\n{}",b1);
101+
log.info("w2:->\n{}",w2);
102+
log.info("b2:->\n{}",b2);
103+
104+
}
105+
106+
107+
108+
109+
110+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package com.codingapi.deeplearning.demo04.learn;
2+
3+
import org.nd4j.linalg.api.ndarray.INDArray;
4+
5+
/**
6+
* 特征缩放
7+
* @author lorne
8+
* @date 2019-10-31
9+
* @description
10+
*/
11+
public class DataScalingHelper {
12+
13+
private double max;
14+
private double min;
15+
private DataSet data;
16+
17+
public DataScalingHelper(DataSet data) {
18+
this.data = data;
19+
this.max = data.getX().maxNumber().doubleValue();
20+
this.min = data.getY().maxNumber().doubleValue();
21+
}
22+
23+
public void scaling(){
24+
INDArray array = data.getX();
25+
array = array.sub(min).div((max-min));
26+
data.setX(array);
27+
}
28+
29+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.codingapi.deeplearning.demo04.learn;

import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.IOException;

/**
 * Training data set loaded from a CSV file: the first two columns become the
 * feature matrix X, the third column becomes the label vector Y.
 *
 * @author lorne
 * @date 2019-10-31
 */
@Data
public class DataSet {

    private INDArray x;  // (m, 2) feature matrix
    private INDArray y;  // (m, 1) label column

    /**
     * Reads {@code init/lr_data.csv} (comma-separated, resolved against the
     * working directory) and splits it into features and labels.
     *
     * @throws IOException if the file cannot be read
     */
    public DataSet() throws IOException {
        String filePath = "init/lr_data.csv";
        INDArray table = Nd4j.readNumpy(filePath, ",");

        this.x = table.getColumns(0, 1);
        this.y = table.getColumns(2);
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package com.codingapi.deeplearning.demo04.learn;

/**
 * Placeholder for a simple neural-network layer abstraction; no behavior has
 * been implemented yet.
 *
 * @author lorne
 * @date 2019-10-31
 */
public class SimpleNeuralNetworkLayer {

}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.codingapi.deeplearning.demo04.learn;

import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;

import java.io.IOException;

/**
 * End-to-end check: load the CSV data set, min-max scale it, then run
 * back-propagation training over it.
 *
 * @author lorne
 * @date 2019-10-31
 */
@SpringBootTest
class BackPropagationFunctionTest {

    @Test
    void train() throws IOException {
        DataSet samples = new DataSet();

        DataScalingHelper helper = new DataScalingHelper(samples);
        helper.scaling();

        // lambda = 0 (no regularization), learning rate 0.1,
        // 100 iterations, 3 hidden units.
        BackPropagationFunction function =
                new BackPropagationFunction(0, 0.1, 100, 3);
        function.train(samples);
    }
}

pom.xml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
<module>demo01</module>
88
<module>demo02</module>
99
<module>demo03</module>
10+
<module>demo04</module>
1011
</modules>
1112
<parent>
1213
<groupId>org.springframework.boot</groupId>

0 commit comments

Comments
 (0)