-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMeanSquaredError.java
More file actions
46 lines (39 loc) · 1.11 KB
/
MeanSquaredError.java
File metadata and controls
46 lines (39 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
/*
By Brendan C. Reidy
Created 12/10/2019
Last Modified 12/10/2019
Mean squared error cost function
*/
/**
 * Squared-error cost function.
 *
 * <p>For a single output value {@code a} and its expected value {@code b} the cost is
 * {@code (b - a)^2 / 2}. The factor of one half cancels the 2 produced by
 * differentiation, so the derivative of the cost with respect to {@code a} is {@code a - b}.
 *
 * <p>Note: the array form {@link #cost(float[], float[])} returns the per-element costs;
 * it does not average them. NOTE(review): any averaging ("mean") is presumably done by the
 * caller — confirm.
 */
public class MeanSquaredError implements CostFunction {
    /** Human-readable name returned by {@link #toString()}; never changes, so it is a constant. */
    private static final String NAME = "Mean Squared Error";

    /**
     * Computes the squared-error cost of one output value.
     *
     * @param a the actual output value
     * @param b the expected (correct) value
     * @return {@code (b - a)^2 / 2}
     */
    public float individualCost(float a, float b) {
        float diff = b - a;
        return (diff * diff) / 2;
    } // MSE

    /**
     * Returns the error term for one output value, {@code b - a}.
     *
     * <p>This is the <em>negative</em> of the derivative of
     * {@link #individualCost(float, float)} with respect to {@code a} (that derivative is
     * {@code a - b}). The sign convention is kept unchanged because callers may rely on it.
     * NOTE(review): confirm how the training code applies this value.
     *
     * @param a the actual output value
     * @param b the expected (correct) value
     * @return {@code b - a}
     */
    public float individualError(float a, float b) {
        return b - a;
    } // MSE derivative (negated)

    /**
     * Computes the per-element cost for a whole output vector.
     *
     * @param output        actual output values
     * @param correctOutput expected values; must have the same length as {@code output}
     * @return a new array whose element {@code i} is
     *         {@code individualCost(output[i], correctOutput[i])}
     * @throws NullPointerException     if either array is {@code null}
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public float[] cost(float[] output, float[] correctOutput) {
        requireSameLength(output, correctOutput);
        float[] costs = new float[output.length];
        for (int i = 0; i < costs.length; i++) {
            costs[i] = individualCost(output[i], correctOutput[i]);
        }
        return costs;
    }

    /**
     * Computes the per-element error term for a whole output vector.
     *
     * @param output        actual output values
     * @param correctOutput expected values; must have the same length as {@code output}
     * @return a new array whose element {@code i} is
     *         {@code individualError(output[i], correctOutput[i])}
     * @throws NullPointerException     if either array is {@code null}
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public float[] error(float[] output, float[] correctOutput) {
        requireSameLength(output, correctOutput);
        float[] errors = new float[output.length];
        for (int i = 0; i < errors.length; i++) {
            errors[i] = individualError(output[i], correctOutput[i]);
        }
        return errors;
    }

    /** Returns the display name of this cost function. */
    @Override
    public String toString() {
        return NAME;
    }

    /**
     * Validates a pair of input vectors. Without this, a longer {@code correctOutput} would be
     * silently truncated and a shorter one would fail with an opaque index exception.
     */
    private static void requireSameLength(float[] output, float[] correctOutput) {
        if (output == null || correctOutput == null) {
            throw new NullPointerException("output and correctOutput must not be null");
        }
        if (output.length != correctOutput.length) {
            throw new IllegalArgumentException("output length " + output.length
                    + " does not match correctOutput length " + correctOutput.length);
        }
    }
}