Skip to content

Commit b21a58a

Browse files
Merge branch 'SciSharp:master' into alnovi/gradient_more_tests
2 parents 18db147 + b72c803 commit b21a58a

File tree

11 files changed

+294
-78
lines changed

11 files changed

+294
-78
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,25 @@
1-
namespace Tensorflow.Keras
1+
using Newtonsoft.Json;
2+
using System.Collections.Generic;
3+
using Tensorflow.Keras.Saving.Common;
4+
5+
namespace Tensorflow.Keras
26
{
3-
public interface IRegularizer
4-
{
5-
Tensor Apply(RegularizerArgs args);
6-
}
7+
/// <summary>
/// Contract for a regularizer: it turns <see cref="RegularizerArgs"/> into a
/// <see cref="Tensor"/> and describes itself through a class name and a configuration map.
/// JSON (de)serialization is delegated to <see cref="CustomizedRegularizerJsonConverter"/>.
/// </summary>
[JsonConverter(typeof(CustomizedRegularizerJsonConverter))]
public interface IRegularizer
{
    /// <summary>Identifier of the concrete regularizer; mapped to the JSON property "class_name".</summary>
    [JsonProperty(PropertyName = "class_name")]
    string ClassName { get; }

    /// <summary>Named settings of the regularizer; mapped to the JSON property "config".</summary>
    [JsonProperty(PropertyName = "config")]
    IDictionary<string, object> Config { get; }

    /// <summary>Evaluates the regularizer for the supplied arguments.</summary>
    /// <param name="args">Arguments carrying the input the regularizer operates on.</param>
    /// <returns>The tensor produced by the regularizer.</returns>
    Tensor Apply(RegularizerArgs args);
}
16+
17+
/// <summary>
/// Entry point for obtaining regularizers, either by name or through one
/// property per built-in kind.
/// </summary>
public interface IRegularizerApi
{
    /// <summary>Resolves a regularizer from its name.</summary>
    /// <param name="name">Name identifying the regularizer.</param>
    /// <returns>The regularizer associated with <paramref name="name"/>.</returns>
    IRegularizer GetRegularizerFromName(string name);

    /// <summary>Gets an L1 regularizer.</summary>
    IRegularizer L1 { get; }

    /// <summary>Gets an L2 regularizer.</summary>
    IRegularizer L2 { get; }

    /// <summary>Gets a combined L1/L2 regularizer.</summary>
    IRegularizer L1L2 { get; }
}
24+
725
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using Newtonsoft.Json.Linq;
2+
using Newtonsoft.Json;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Text;
6+
using Tensorflow.Operations.Regularizers;
7+
8+
namespace Tensorflow.Keras.Saving.Common
9+
{
10+
/// <summary>
/// Intermediate carrier used when reading and writing a regularizer as JSON:
/// a class name plus a free-form configuration object.
/// </summary>
/// <remarks>
/// The member names are deliberately lower-case with an underscore; they are
/// what the serializer sees, so they must not be renamed.
/// </remarks>
internal class RegularizerInfo
{
    /// <summary>Name of the regularizer class.</summary>
    public string class_name { get; set; }

    /// <summary>Configuration values of the regularizer.</summary>
    public JObject config { get; set; }
}
15+
16+
/// <summary>
/// Newtonsoft.Json converter for <see cref="IRegularizer"/>. A regularizer is written as an
/// object with a <c>class_name</c> string and a <c>config</c> object, and the built-in
/// <c>L1</c>, <c>L2</c> and <c>L1L2</c> regularizers are rebuilt from that shape.
/// </summary>
public class CustomizedRegularizerJsonConverter : JsonConverter
{
    /// <summary>Only the <see cref="IRegularizer"/> interface type itself is handled.</summary>
    public override bool CanConvert(Type objectType)
    {
        return objectType == typeof(IRegularizer);
    }

    public override bool CanRead => true;

    public override bool CanWrite => true;

    /// <summary>
    /// Writes <paramref name="value"/> as <c>{ class_name, config }</c>, or as JSON null when
    /// it is null or not an <see cref="IRegularizer"/>.
    /// </summary>
    public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
    {
        var regularizer = value as IRegularizer;
        if (regularizer is null)
        {
            // JToken.FromObject(null) throws ArgumentNullException, so the null token
            // has to be emitted through the writer directly.
            writer.WriteNull();
            return;
        }

        // A regularizer without a configuration is written with an empty config object
        // instead of failing inside JObject.FromObject.
        var config = regularizer.Config is null
            ? new JObject()
            : JObject.FromObject(regularizer.Config);

        JToken.FromObject(new RegularizerInfo()
        {
            class_name = regularizer.ClassName,
            config = config
        }, serializer).WriteTo(writer);
    }

    /// <summary>
    /// Reads a <c>{ class_name, config }</c> object and creates the matching regularizer.
    /// </summary>
    /// <returns>The regularizer, or null when the JSON value is null.</returns>
    /// <exception cref="InvalidOperationException">
    /// The <c>class_name</c> is missing or is not one of "L1", "L2", "L1L2".
    /// </exception>
    public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
    {
        var info = serializer.Deserialize<RegularizerInfo>(reader);
        if (info is null)
        {
            return null;
        }

        // Coefficients absent from the config fall back to the same values the
        // regularizer constructors use as parameter defaults.
        return info.class_name switch
        {
            "L1L2" => new L1L2(
                ReadCoefficient(info.config, "l1") ?? 0.0f,
                ReadCoefficient(info.config, "l2") ?? 0.0f),
            "L1" => new L1(ReadCoefficient(info.config, "l1") ?? 0.01f),
            "L2" => new L2(ReadCoefficient(info.config, "l2") ?? 0.01f),
            // An unmatched switch expression used to fail without any message; throwing
            // InvalidOperationException keeps the failure in the same exception family
            // while naming the offending value.
            _ => throw new InvalidOperationException(
                $"Cannot deserialize regularizer: unknown class_name '{info.class_name}'. Expected one of: L1, L2, L1L2."),
        };
    }

    /// <summary>
    /// Reads a float entry from <paramref name="config"/>; returns null when the config
    /// itself, the key, or the stored value is missing/null.
    /// </summary>
    private static float? ReadCoefficient(JObject? config, string key)
    {
        return config?[key]?.ToObject<float?>();
    }
}
57+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
3+
using Tensorflow.Keras;
4+
5+
namespace Tensorflow.Operations.Regularizers
6+
{
7+
/// <summary>
/// L1 regularizer: the coefficient multiplied by the sum of the absolute values of the input.
/// </summary>
public class L1 : IRegularizer
{
    private readonly float _l1;
    private readonly Dictionary<string, object> _config;

    /// <summary>Always "L1".</summary>
    public string ClassName
    {
        get { return "L1"; }
    }

    /// <summary>Configuration map holding the coefficient under the key "l1".</summary>
    public virtual IDictionary<string, object> Config
    {
        get { return _config; }
    }

    /// <summary>Creates the regularizer.</summary>
    /// <param name="l1">Coefficient applied to the sum of absolute values.</param>
    /// <remarks>
    /// The Python reference validates the argument (validate_float_arg);
    /// no validation is performed here.
    /// </remarks>
    public L1(float l1 = 0.01f)
    {
        _l1 = l1;
        _config = new Dictionary<string, object>
        {
            ["l1"] = l1
        };
    }

    /// <summary>Computes <c>l1 * reduce_sum(abs(args.X))</c>.</summary>
    public Tensor Apply(RegularizerArgs args)
    {
        var absoluteSum = math_ops.reduce_sum(math_ops.abs(args.X));
        return _l1 * absoluteSum;
    }
}
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
using System;
2+
3+
using Tensorflow.Keras;
4+
5+
namespace Tensorflow.Operations.Regularizers
6+
{
7+
/// <summary>
/// Combined regularizer: an L1 term (sum of absolute values) plus an L2 term
/// (sum of squares), each scaled by its own coefficient.
/// </summary>
public class L1L2 : IRegularizer
{
    private readonly float _l1;
    private readonly float _l2;
    private readonly Dictionary<string, object> _config;

    /// <summary>Always "L1L2".</summary>
    public string ClassName
    {
        get { return "L1L2"; }
    }

    /// <summary>Configuration map holding both coefficients under the keys "l1" and "l2".</summary>
    public virtual IDictionary<string, object> Config
    {
        get { return _config; }
    }

    /// <summary>Creates the regularizer.</summary>
    /// <param name="l1">Coefficient of the absolute-value term.</param>
    /// <param name="l2">Coefficient of the squared term.</param>
    /// <remarks>
    /// The Python reference validates both arguments (validate_float_arg);
    /// no validation is performed here.
    /// </remarks>
    public L1L2(float l1 = 0.0f, float l2 = 0.0f)
    {
        _l1 = l1;
        _l2 = l2;
        _config = new Dictionary<string, object>
        {
            ["l1"] = l1,
            ["l2"] = l2
        };
    }

    /// <summary>
    /// Computes <c>0 + l1 * reduce_sum(abs(args.X)) + l2 * reduce_sum(square(args.X))</c>,
    /// starting from a zero constant of the input's dtype.
    /// </summary>
    /// <remarks>
    /// Both terms are always added, even when a coefficient is zero; the Python
    /// reference skips a term whose coefficient is falsy.
    /// </remarks>
    public Tensor Apply(RegularizerArgs args)
    {
        Tensor total = tf.constant(0.0, args.X.dtype);

        // L1 contribution first, then L2, so the operations are created in a fixed order.
        var l1Penalty = _l1 * math_ops.reduce_sum(math_ops.abs(args.X));
        total = total + l1Penalty;

        var l2Penalty = _l2 * math_ops.reduce_sum(math_ops.square(args.X));
        total = total + l2Penalty;

        return total;
    }
}
48+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
3+
using Tensorflow.Keras;
4+
5+
namespace Tensorflow.Operations.Regularizers
6+
{
7+
/// <summary>
/// L2 regularizer: the coefficient multiplied by the sum of the squared values of the input.
/// </summary>
public class L2 : IRegularizer
{
    private readonly float _l2;
    private readonly Dictionary<string, object> _config;

    /// <summary>Always "L2".</summary>
    public string ClassName
    {
        get { return "L2"; }
    }

    /// <summary>Configuration map holding the coefficient under the key "l2".</summary>
    public virtual IDictionary<string, object> Config
    {
        get { return _config; }
    }

    /// <summary>Creates the regularizer.</summary>
    /// <param name="l2">Coefficient applied to the sum of squares.</param>
    /// <remarks>
    /// The Python reference validates the argument (validate_float_arg);
    /// no validation is performed here.
    /// </remarks>
    public L2(float l2 = 0.01f)
    {
        _l2 = l2;
        _config = new Dictionary<string, object>
        {
            ["l2"] = l2
        };
    }

    /// <summary>Computes <c>l2 * reduce_sum(square(args.X))</c>.</summary>
    public Tensor Apply(RegularizerArgs args)
    {
        var squaredSum = math_ops.reduce_sum(math_ops.square(args.X));
        return _l2 * squaredSum;
    }
}
33+
}

src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs

+5-9
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,19 @@ private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concre
8888

8989
public override (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
9090
{
91-
Trackable get_merged_trackable(Trackable x)
91+
void merged_trackable(Trackable x)
9292
{
9393
// TODO: complete it with new definitions `Asset` and `TrackableConstant`.
94-
return x;
9594
}
95+
9696
var trackable_objects = base.breadth_first_traversal();
9797

9898
foreach(var obj in _children_cache.Keys)
9999
{
100100
// skip the deletion of cache (maybe do it later).
101101
foreach(var pair in _children_cache[obj])
102102
{
103-
_children_cache[obj][pair.Key] = get_merged_trackable(pair.Value);
103+
merged_trackable(pair.Value);
104104
}
105105
}
106106

@@ -109,15 +109,11 @@ Trackable get_merged_trackable(Trackable x)
109109

110110
public List<(string, Trackable)> list_dependencies(Trackable obj)
111111
{
112-
IDictionary<string, Trackable> children;
113-
if (!_children_cache.ContainsKey(obj))
112+
if (!_children_cache.TryGetValue(obj, out var children))
114113
{
115114
children= new Dictionary<string, Trackable>();
116115
}
117-
else
118-
{
119-
children= _children_cache[obj];
120-
}
116+
121117
List<(string, Trackable)> res = new();
122118
foreach(var pair in obj.deserialization_dependencies(children))
123119
{
+47-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,51 @@
1-
namespace Tensorflow.Keras
1+
using Tensorflow.Operations.Regularizers;
2+
3+
namespace Tensorflow.Keras
24
{
3-
public class Regularizers
5+
/// <summary>
/// Factory for the built-in regularizers, with lookup by name.
/// </summary>
public class Regularizers: IRegularizerApi
{
    // Instances handed out by GetRegularizerFromName. They are created once, with
    // default coefficients, and the same instance is returned on every lookup.
    private static readonly Dictionary<string, IRegularizer> _nameRegularizerMap;

    /// <summary>Creates an L1 regularizer with the given coefficient.</summary>
    public IRegularizer l1(float l1 = 0.01f)
        => new L1(l1);

    /// <summary>Creates an L2 regularizer with the given coefficient.</summary>
    public IRegularizer l2(float l2 = 0.01f)
        => new L2(l2);

    //From TF source
    //# The default value for l1 and l2 are different from the value in l1_l2
    //# for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2
    //# and no l1 penalty.
    /// <summary>Creates a combined L1/L2 regularizer with the given coefficients.</summary>
    public IRegularizer l1l2(float l1 = 0.00f, float l2 = 0.00f)
        => new L1L2(l1, l2);

    static Regularizers()
    {
        // Each regularizer is registered under its own name. Previously all three were
        // stored under "L1", so "L2" and "L1L2" could not be resolved and "L1" returned
        // an L1L2 instance.
        _nameRegularizerMap = new Dictionary<string, IRegularizer>
        {
            ["L1"] = new L1(),
            ["L2"] = new L2(),
            ["L1L2"] = new L1L2(),
        };
    }

    /// <summary>Gets a new L1 regularizer with the default coefficient.</summary>
    public IRegularizer L1 => l1();

    /// <summary>Gets a new L2 regularizer with the default coefficient.</summary>
    public IRegularizer L2 => l2();

    /// <summary>Gets a new L1L2 regularizer with the default coefficients.</summary>
    public IRegularizer L1L2 => l1l2();

    /// <summary>Resolves a built-in regularizer from its name ("L1", "L2" or "L1L2").</summary>
    /// <param name="name">Case-sensitive regularizer name.</param>
    /// <returns>The shared default-configured instance registered under <paramref name="name"/>.</returns>
    /// <exception cref="Exception">
    /// <paramref name="name"/> is null or no regularizer is registered under it.
    /// </exception>
    public IRegularizer GetRegularizerFromName(string name)
    {
        if (name == null)
        {
            throw new Exception("Regularizer name cannot be null");
        }

        if (_nameRegularizerMap.TryGetValue(name, out var res))
        {
            return res;
        }

        throw new Exception($"Regularizer {name} not found");
    }
}
851
}

src/TensorFlowNET.Keras/Regularizers/L1.cs

-19
This file was deleted.

src/TensorFlowNET.Keras/Regularizers/L1L2.cs

-24
This file was deleted.

src/TensorFlowNET.Keras/Regularizers/L2.cs

-17
This file was deleted.

0 commit comments

Comments
 (0)