Skip to content

Commit b72c803

Browse files
authored
Merge pull request #1250 from SchoenTannenbaum/master
fix: regularizer serialization problem
2 parents 8775b0b + 5f9fce5 commit b72c803

File tree

10 files changed

+289
-69
lines changed

10 files changed

+289
-69
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,25 @@
1-
namespace Tensorflow.Keras
1+
using Newtonsoft.Json;
2+
using System.Collections.Generic;
3+
using Tensorflow.Keras.Saving.Common;
4+
5+
namespace Tensorflow.Keras
26
{
3-
public interface IRegularizer
4-
{
5-
Tensor Apply(RegularizerArgs args);
6-
}
7+
/// <summary>
/// Contract for a weight regularizer: a penalty term computed from a tensor,
/// plus the metadata needed to serialize it.
/// Values of this interface type are (de)serialized by
/// <c>CustomizedRegularizerJsonConverter</c>, as requested by the attribute below.
/// </summary>
[JsonConverter(typeof(CustomizedRegularizerJsonConverter))]
8+
public interface IRegularizer
9+
{
10+
/// <summary>
/// Identifier of the concrete regularizer; mapped to the JSON property "class_name".
/// </summary>
[JsonProperty("class_name")]
11+
string ClassName { get; }
12+
/// <summary>
/// Constructor arguments of the regularizer keyed by name; mapped to the JSON property "config".
/// </summary>
[JsonProperty("config")]
13+
IDictionary<string, object> Config { get; }
14+
/// <summary>
/// Computes the regularization term for the given arguments.
/// </summary>
/// <param name="args">Arguments wrapper; the implementations in this change read <c>args.X</c>.</param>
/// <returns>The penalty as a <c>Tensor</c>.</returns>
Tensor Apply(RegularizerArgs args);
15+
}
16+
17+
/// <summary>
/// Entry points for obtaining <see cref="IRegularizer"/> instances,
/// either by name or through one accessor per built-in kind.
/// </summary>
public interface IRegularizerApi
17+
{
18+
/// <summary>
/// Resolves a regularizer from its name.
/// NOTE(review): the accepted names and the failure behavior are defined by the implementer,
/// not by this interface - confirm against the implementing class.
/// </summary>
IRegularizer GetRegularizerFromName(string name);
19+
/// <summary>An L1 regularizer (coefficients chosen by the implementer).</summary>
IRegularizer L1 { get; }
20+
/// <summary>An L2 regularizer (coefficients chosen by the implementer).</summary>
IRegularizer L2 { get; }
21+
/// <summary>A combined L1 + L2 regularizer (coefficients chosen by the implementer).</summary>
IRegularizer L1L2 { get; }
22+
}
24+
725
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using Newtonsoft.Json.Linq;
2+
using Newtonsoft.Json;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Text;
6+
using Tensorflow.Operations.Regularizers;
7+
8+
namespace Tensorflow.Keras.Saving.Common
9+
{
10+
/// <summary>
/// Intermediate JSON shape used by <c>CustomizedRegularizerJsonConverter</c>:
/// an object with a "class_name" string and a "config" object.
/// The lower-case property names are the JSON keys, so they must not be renamed.
/// </summary>
class RegularizerInfo
11+
{
12+
/// <summary>Name of the regularizer class, e.g. "L1", "L2" or "L1L2".</summary>
public string class_name { get; set; }
13+
/// <summary>Raw configuration object; the converter reads the "l1" and "l2" entries from it.</summary>
public JObject config { get; set; }
14+
}
15+
16+
/// <summary>
/// Newtonsoft.Json converter for <see cref="IRegularizer"/> values.
/// A regularizer is written as <c>{ "class_name": ..., "config": { ... } }</c> and is
/// rebuilt from that shape by dispatching on <c>class_name</c>.
/// </summary>
public class CustomizedRegularizerJsonConverter : JsonConverter
{
    /// <summary>
    /// Only members declared with the <see cref="IRegularizer"/> interface type are handled.
    /// </summary>
    public override bool CanConvert(Type objectType)
    {
        return objectType == typeof(IRegularizer);
    }

    public override bool CanRead => true;

    public override bool CanWrite => true;

    /// <summary>
    /// Writes <paramref name="value"/> as a class_name/config object, or JSON null when
    /// the value is null (or is not an <see cref="IRegularizer"/>).
    /// </summary>
    public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
    {
        var regularizer = value as IRegularizer;
        if (regularizer is null)
        {
            // JToken.FromObject(null) throws ArgumentNullException, so a layer without a
            // regularizer could never be saved; emit the null token directly instead.
            writer.WriteNull();
            return;
        }

        JToken.FromObject(new RegularizerInfo()
        {
            class_name = regularizer.ClassName,
            config = JObject.FromObject(regularizer.Config)
        }, serializer).WriteTo(writer);
    }

    /// <summary>
    /// Reads a class_name/config object and instantiates the matching regularizer.
    /// </summary>
    /// <returns>The regularizer, or null when the JSON value is null.</returns>
    /// <exception cref="JsonSerializationException">
    /// The <c>class_name</c> is not one of "L1", "L2" or "L1L2".
    /// </exception>
    public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
    {
        var info = serializer.Deserialize<RegularizerInfo>(reader);
        if (info is null)
        {
            return null;
        }

        // Fallback values repeat the constructor defaults of L1 / L2 / L1L2, so a config
        // that omits a coefficient behaves as if the constructor had been called without it.
        return info.class_name switch
        {
            "L1L2" => (IRegularizer)new L1L2(ReadCoefficient(info.config, "l1", 0.0f), ReadCoefficient(info.config, "l2", 0.0f)),
            "L1" => new L1(ReadCoefficient(info.config, "l1", 0.01f)),
            "L2" => new L2(ReadCoefficient(info.config, "l2", 0.01f)),
            // Without a default arm an unknown name surfaced as an opaque
            // compiler-generated exception; report what was actually found.
            _ => throw new JsonSerializationException(
                $"Unsupported regularizer class_name '{info.class_name}'; expected 'L1', 'L2' or 'L1L2'."),
        };
    }

    /// <summary>
    /// Reads a float entry from <paramref name="config"/>, returning
    /// <paramref name="fallback"/> when the config, the key, or its value is missing/null.
    /// </summary>
    private static float ReadCoefficient(JObject? config, string key, float fallback)
    {
        var token = config?[key];
        if (token is null || token.Type == JTokenType.Null)
        {
            return fallback;
        }

        return token.ToObject<float>();
    }
}
57+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
3+
using Tensorflow.Keras;
4+
5+
namespace Tensorflow.Operations.Regularizers
6+
{
7+
/// <summary>
/// L1 regularizer: the penalty is <c>l1 * sum(|x|)</c>.
/// </summary>
public class L1 : IRegularizer
{
    private readonly float _l1;
    private readonly Dictionary<string, object> _config;

    /// <summary>Serialized class name of this regularizer.</summary>
    public string ClassName => "L1";

    /// <summary>Constructor arguments keyed by name; holds the single entry "l1".</summary>
    public virtual IDictionary<string, object> Config => _config;

    /// <summary>Creates the regularizer with the given L1 coefficient.</summary>
    /// <param name="l1">Coefficient applied to the sum of absolute values.</param>
    public L1(float l1 = 0.01f)
    {
        _l1 = l1;
        _config = new Dictionary<string, object>
        {
            ["l1"] = l1,
        };
    }

    /// <summary>Returns <c>l1 * reduce_sum(abs(args.X))</c>.</summary>
    public Tensor Apply(RegularizerArgs args)
    {
        var magnitudes = math_ops.abs(args.X);
        var total = math_ops.reduce_sum(magnitudes);
        return _l1 * total;
    }
}
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
using System;
2+
3+
using Tensorflow.Keras;
4+
5+
namespace Tensorflow.Operations.Regularizers
6+
{
7+
/// <summary>
/// Combined regularizer: the penalty is <c>l1 * sum(|x|) + l2 * sum(x^2)</c>.
/// </summary>
public class L1L2 : IRegularizer
{
    private readonly float _l1;
    private readonly float _l2;
    private readonly Dictionary<string, object> _config;

    /// <summary>Serialized class name of this regularizer.</summary>
    public string ClassName => "L1L2";

    /// <summary>Constructor arguments keyed by name; holds the entries "l1" and "l2".</summary>
    public virtual IDictionary<string, object> Config => _config;

    /// <summary>Creates the regularizer with the given coefficients.</summary>
    /// <param name="l1">Coefficient applied to the sum of absolute values.</param>
    /// <param name="l2">Coefficient applied to the sum of squares.</param>
    public L1L2(float l1 = 0.0f, float l2 = 0.0f)
    {
        _l1 = l1;
        _l2 = l2;
        _config = new Dictionary<string, object>
        {
            ["l1"] = l1,
            ["l2"] = l2,
        };
    }

    /// <summary>
    /// Starts from a zero constant of the input's dtype and adds the L1 term, then the L2 term.
    /// Both terms are always added, even when their coefficient is zero.
    /// </summary>
    public Tensor Apply(RegularizerArgs args)
    {
        Tensor penalty = tf.constant(0.0, args.X.dtype);
        penalty = penalty + _l1 * math_ops.reduce_sum(math_ops.abs(args.X));
        penalty = penalty + _l2 * math_ops.reduce_sum(math_ops.square(args.X));
        return penalty;
    }
}
48+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
3+
using Tensorflow.Keras;
4+
5+
namespace Tensorflow.Operations.Regularizers
6+
{
7+
/// <summary>
/// L2 regularizer: the penalty is <c>l2 * sum(x^2)</c>.
/// </summary>
public class L2 : IRegularizer
{
    private readonly float _l2;
    private readonly Dictionary<string, object> _config;

    /// <summary>Serialized class name of this regularizer.</summary>
    public string ClassName => "L2";

    /// <summary>Constructor arguments keyed by name; holds the single entry "l2".</summary>
    public virtual IDictionary<string, object> Config => _config;

    /// <summary>Creates the regularizer with the given L2 coefficient.</summary>
    /// <param name="l2">Coefficient applied to the sum of squares.</param>
    public L2(float l2 = 0.01f)
    {
        _l2 = l2;
        _config = new Dictionary<string, object>
        {
            ["l2"] = l2,
        };
    }

    /// <summary>Returns <c>l2 * reduce_sum(square(args.X))</c>.</summary>
    public Tensor Apply(RegularizerArgs args)
    {
        var squares = math_ops.square(args.X);
        var total = math_ops.reduce_sum(squares);
        return _l2 * total;
    }
}
33+
}
+47-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,51 @@
1-
namespace Tensorflow.Keras
1+
using Tensorflow.Operations.Regularizers;
2+
3+
namespace Tensorflow.Keras
24
{
3-
public class Regularizers
5+
/// <summary>
/// Factory for the built-in regularizers (L1, L2, L1L2) and name-based lookup of
/// default-configured instances.
/// </summary>
public class Regularizers: IRegularizerApi
{
    // Default-configured prototypes keyed by class name; filled once by the static constructor.
    private static readonly Dictionary<string, IRegularizer> _nameActivationMap;

    /// <summary>Creates an L1 regularizer with the given coefficient.</summary>
    public IRegularizer l1(float l1 = 0.01f)
        => new L1(l1);

    /// <summary>Creates an L2 regularizer with the given coefficient.</summary>
    public IRegularizer l2(float l2 = 0.01f)
        => new L2(l2);

    //From TF source
    //# The default value for l1 and l2 are different from the value in l1_l2
    //# for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2
    //# and no l1 penalty.
    /// <summary>Creates a combined L1 + L2 regularizer with the given coefficients.</summary>
    public IRegularizer l1l2(float l1 = 0.00f, float l2 = 0.00f)
        => new L1L2(l1, l2);

    static Regularizers()
    {
        // Each prototype is registered under its own class name. All three used to be
        // written under "L1", so "L2" and "L1L2" were never found and "L1" resolved
        // to an L1L2 instance.
        _nameActivationMap = new Dictionary<string, IRegularizer>();
        _nameActivationMap["L1"] = new L1();
        _nameActivationMap["L2"] = new L2();
        _nameActivationMap["L1L2"] = new L1L2();
    }

    /// <summary>A new L1 regularizer with the default coefficient.</summary>
    public IRegularizer L1 => l1();

    /// <summary>A new L2 regularizer with the default coefficient.</summary>
    public IRegularizer L2 => l2();

    /// <summary>A new L1L2 regularizer with the default coefficients.</summary>
    public IRegularizer L1L2 => l1l2();

    /// <summary>
    /// Returns the default-configured regularizer registered under <paramref name="name"/>
    /// ("L1", "L2" or "L1L2"). The returned instance is shared between callers.
    /// </summary>
    /// <exception cref="Exception">
    /// <paramref name="name"/> is null or no regularizer is registered under it.
    /// </exception>
    public IRegularizer GetRegularizerFromName(string name)
    {
        // Exception type and messages are kept as they were so existing callers are unaffected.
        if (name == null)
        {
            throw new Exception($"Regularizer name cannot be null");
        }

        if (_nameActivationMap.TryGetValue(name, out var res))
        {
            return res;
        }

        throw new Exception($"Regularizer {name} not found");
    }
}
851
}

src/TensorFlowNET.Keras/Regularizers/L1.cs

-19
This file was deleted.

src/TensorFlowNET.Keras/Regularizers/L1L2.cs

-24
This file was deleted.

src/TensorFlowNET.Keras/Regularizers/L2.cs

-17
This file was deleted.

test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs

+48
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Microsoft.VisualStudio.TestPlatform.Utilities;
22
using Microsoft.VisualStudio.TestTools.UnitTesting;
33
using Newtonsoft.Json.Linq;
4+
using System.Collections.Generic;
45
using System.Linq;
56
using System.Xml.Linq;
67
using Tensorflow.Keras.Engine;
@@ -129,6 +130,53 @@ public void TestModelBeforeTF2_5()
129130
}
130131

131132

133+
/// <summary>
/// Round-trips a small convolutional model whose Conv2D layers use L1L2, L2 and L1 bias
/// regularizers: train one epoch, save in "tf" format, load it back, re-compile and train
/// again. The test passes when none of these steps throws.
/// </summary>
[TestMethod]
public void BiasRegularizerSaveAndLoad()
{
    const int epochs = 1;
    const int batchSize = 8;
    const string modelDirectory = @"./bias_regularizer_save_and_load";

    // Build: one plain conv stage followed by three conv layers, one per regularizer kind.
    var layers = new List<ILayer>()
    {
        tf.keras.layers.InputLayer((227, 227, 3)),
        tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),

        tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1L2),
        tf.keras.layers.BatchNormalization(),

        tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L2),
        tf.keras.layers.BatchNormalization(),

        tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

        tf.keras.layers.Flatten(),

        tf.keras.layers.Dense(1000, activation: "linear"),
        tf.keras.layers.Softmax(1)
    };
    var original = keras.Sequential(layers);

    // Train the freshly built model once, then persist it.
    original.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
    var firstBatch = new RandomDataSet(new Shape(227, 227, 3), 16);
    original.fit(firstBatch.Data, firstBatch.Labels, batchSize, epochs);
    original.save(modelDirectory, save_format: "tf");

    // Load it back and make sure the restored model can still be compiled and trained.
    var restored = tf.keras.models.load_model(modelDirectory);
    restored.summary();
    restored.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
    var secondBatch = new RandomDataSet(new Shape(227, 227, 3), 16);
    restored.fit(secondBatch.Data, secondBatch.Labels, batchSize, epochs);
}
179+
132180

133181
[TestMethod]
134182
public void CreateConcatenateModelSaveAndLoad()

0 commit comments

Comments
 (0)