Skip to content

Commit 5f9fce5

Browse files
RegularizerAPI and UnitTest
1 parent f5ba382 commit 5f9fce5

File tree

4 files changed

+98
-7
lines changed

4 files changed

+98
-7
lines changed

src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs

+10-1
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,14 @@ public interface IRegularizer
1212
[JsonProperty("config")]
1313
IDictionary<string, object> Config { get; }
1414
Tensor Apply(RegularizerArgs args);
15-
}
15+
}
16+
17+
public interface IRegularizerApi
18+
{
19+
IRegularizer GetRegularizerFromName(string name);
20+
IRegularizer L1 { get; }
21+
IRegularizer L2 { get; }
22+
IRegularizer L1L2 { get; }
23+
}
24+
1625
}

src/TensorFlowNET.Core/Operations/Regularizers/L1.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ public class L1 : IRegularizer
99
float _l1;
1010
private readonly Dictionary<string, object> _config;
1111

12-
public string ClassName => "L2";
12+
public string ClassName => "L1";
1313
public virtual IDictionary<string, object> Config => _config;
1414

1515
public L1(float l1 = 0.01f)
+39-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,51 @@
1-
namespace Tensorflow.Keras
1+
using Tensorflow.Operations.Regularizers;
2+
3+
namespace Tensorflow.Keras
24
{
3-
public class Regularizers
5+
public class Regularizers: IRegularizerApi
46
{
7+
private static Dictionary<string, IRegularizer> _nameActivationMap;
8+
59
public IRegularizer l1(float l1 = 0.01f)
6-
=> new Tensorflow.Operations.Regularizers.L1(l1);
10+
=> new L1(l1);
711
public IRegularizer l2(float l2 = 0.01f)
8-
=> new Tensorflow.Operations.Regularizers.L2(l2);
12+
=> new L2(l2);
913

1014
//From TF source
1115
//# The default value for l1 and l2 are different from the value in l1_l2
1216
//# for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2
1317
//# and no l1 penalty.
1418
public IRegularizer l1l2(float l1 = 0.00f, float l2 = 0.00f)
15-
=> new Tensorflow.Operations.Regularizers.L1L2(l1, l2);
19+
=> new L1L2(l1, l2);
20+
21+
static Regularizers()
22+
{
23+
_nameActivationMap = new Dictionary<string, IRegularizer>();
24+
_nameActivationMap["L1"] = new L1();
25+
_nameActivationMap["L2"] = new L2();
26+
_nameActivationMap["L1L2"] = new L1L2();
27+
}
28+
29+
public IRegularizer L1 => l1();
30+
31+
public IRegularizer L2 => l2();
32+
33+
public IRegularizer L1L2 => l1l2();
34+
35+
public IRegularizer GetRegularizerFromName(string name)
36+
{
37+
if (name == null)
38+
{
39+
throw new Exception("Regularizer name cannot be null");
40+
}
41+
if (!_nameActivationMap.TryGetValue(name, out var res))
42+
{
43+
throw new Exception($"Regularizer {name} not found");
44+
}
45+
else
46+
{
47+
return res;
48+
}
49+
}
1650
}
1751
}

test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs

+48
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Microsoft.VisualStudio.TestPlatform.Utilities;
22
using Microsoft.VisualStudio.TestTools.UnitTesting;
33
using Newtonsoft.Json.Linq;
4+
using System.Collections.Generic;
45
using System.Linq;
56
using System.Xml.Linq;
67
using Tensorflow.Keras.Engine;
@@ -129,6 +130,53 @@ public void TestModelBeforeTF2_5()
129130
}
130131

131132

133+
[TestMethod]
134+
public void BiasRegularizerSaveAndLoad()
135+
{
136+
var savemodel = keras.Sequential(new List<ILayer>()
137+
{
138+
tf.keras.layers.InputLayer((227, 227, 3)),
139+
tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
140+
tf.keras.layers.BatchNormalization(),
141+
tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),
142+
143+
tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1L2),
144+
tf.keras.layers.BatchNormalization(),
145+
146+
tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L2),
147+
tf.keras.layers.BatchNormalization(),
148+
149+
tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1),
150+
tf.keras.layers.BatchNormalization(),
151+
tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),
152+
153+
tf.keras.layers.Flatten(),
154+
155+
tf.keras.layers.Dense(1000, activation: "linear"),
156+
tf.keras.layers.Softmax(1)
157+
});
158+
159+
savemodel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
160+
161+
var num_epochs = 1;
162+
var batch_size = 8;
163+
164+
var trainDataset = new RandomDataSet(new Shape(227, 227, 3), 16);
165+
166+
savemodel.fit(trainDataset.Data, trainDataset.Labels, batch_size, num_epochs);
167+
168+
savemodel.save(@"./bias_regularizer_save_and_load", save_format: "tf");
169+
170+
var loadModel = tf.keras.models.load_model(@"./bias_regularizer_save_and_load");
171+
loadModel.summary();
172+
173+
loadModel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
174+
175+
var fitDataset = new RandomDataSet(new Shape(227, 227, 3), 16);
176+
177+
loadModel.fit(fitDataset.Data, fitDataset.Labels, batch_size, num_epochs);
178+
}
179+
132180

133181
[TestMethod]
134182
public void CreateConcatenateModelSaveAndLoad()

0 commit comments

Comments
 (0)