Skip to content

Commit 0392027

Browse files
authored
Merge pull request #1217 from novikov-alexander/alnovi/gradient_more_tests
test: more gradient optimizer tests
2 parents b72c803 + 483ac82 commit 0392027

File tree

3 files changed

+182
-23
lines changed

3 files changed

+182
-23
lines changed

src/TensorFlowNET.Core/Tensors/tensor_util.cs

+32-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*****************************************************************************
1+
/*****************************************************************************
22
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
@@ -135,6 +135,23 @@ T[] ExpandArrayToSize<T>(IList<T> src)
135135
TF_DataType.TF_QINT32
136136
};
137137

138+
/// <summary>
/// Produces a new two-dimensional array with the same dimensions as
/// <paramref name="inputArray"/>, where each element is the result of applying
/// <paramref name="converter"/> to the element at the same position.
/// </summary>
/// <typeparam name="TIn">Element type of the source array.</typeparam>
/// <typeparam name="TOut">Element type of the resulting array.</typeparam>
/// <param name="inputArray">Source array; it is read but never modified.</param>
/// <param name="converter">Function invoked once per element, in row-major order.</param>
/// <returns>A freshly allocated array holding the converted elements.</returns>
private static TOut[,] ConvertArray2D<TIn, TOut>(TIn[,] inputArray, Func<TIn, TOut> converter)
{
    int rowCount = inputArray.GetLength(0);
    int columnCount = inputArray.GetLength(1);
    TOut[,] converted = new TOut[rowCount, columnCount];

    // Walk the source row by row so the converter sees elements in row-major order.
    for (int row = 0; row < rowCount; ++row)
    {
        for (int column = 0; column < columnCount; ++column)
        {
            TIn element = inputArray[row, column];
            converted[row, column] = converter(element);
        }
    }

    return converted;
}
154+
138155
/// <summary>
139156
/// Create a TensorProto, invoked in graph mode
140157
/// </summary>
@@ -157,19 +174,21 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
157174
else if(origin_dtype != dtype)
158175
{
159176
var new_system_dtype = dtype.as_system_dtype();
160-
if (values is long[] long_values)
161-
{
162-
if (dtype == TF_DataType.TF_INT32)
163-
values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray();
164-
}
165-
else if (values is double[] double_values)
177+
178+
values = values switch
166179
{
167-
if (dtype == TF_DataType.TF_FLOAT)
168-
values = double_values.Select(x => (float)Convert.ChangeType(x, new_system_dtype)).ToArray();
169-
}
170-
else
171-
values = Convert.ChangeType(values, new_system_dtype);
172-
180+
long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(),
181+
long[] longValues => values,
182+
float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(),
183+
float[] floatValues => values,
184+
float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble),
185+
float[,] float2DValues => values,
186+
double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(),
187+
double[] doubleValues => values,
188+
double[,] double2DValues when dtype == TF_DataType.TF_FLOAT => ConvertArray2D(double2DValues, Convert.ToSingle),
189+
double[,] double2DValues => values,
190+
_ => Convert.ChangeType(values, new_system_dtype),
191+
};
173192
dtype = values.GetDataType();
174193
}
175194

test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs

+114-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
using Microsoft.VisualStudio.TestTools.UnitTesting;
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3+
using System.Linq;
34
using Tensorflow;
45
using Tensorflow.NumPy;
56
using static Tensorflow.Binding;
@@ -67,6 +68,51 @@ public void TestBasic()
6768
TestBasic<double>();
6869
}
6970

71+
/// <summary>
/// Minimizes loss = pred * pred, where pred = matmul(var0, x) + var1, with one
/// gradient-descent step (learning rate 1.0) and checks both variables before
/// and after the step for the numeric type <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Numeric element type the variables are created with.</typeparam>
private void TestMinimizeResourceVariable<T>() where T : struct
{
    var dtype = GetTypeForNumericType<T>();

    // train.GradientDescentOptimizer is V1 only API.
    tf.Graph().as_default();
    using (var sess = self.cached_session())
    {
        // The literals deliberately mix float and double so that they have to be
        // converted to the requested dtype.
        var weights = tf.Variable(new[,] { { 1.0f, 2.0f } }, dtype: dtype);
        var bias = tf.Variable(new[] { 3.0 }, dtype: dtype);
        var inputs = tf.constant(new[,] { { 4.0f }, { 5.0f } }, dtype: dtype);

        var prediction = math_ops.matmul(weights, inputs) + bias;
        var loss = prediction * prediction;
        var trainStep = tf.train.GradientDescentOptimizer(1.0f).minimize(loss);

        var initializer = tf.global_variables_initializer();
        sess.run(initializer);

        sess.run(new[] { weights, bias });

        // The variables must still hold their initial values at this point.
        self.assertAllCloseAccordingToType<T>(new[,] { { 1.0, 2.0 } }, self.evaluate<T[,]>(weights));
        self.assertAllCloseAccordingToType(new[] { 3.0 }, self.evaluate<T[]>(bias));

        // Apply exactly one optimizer step.
        trainStep.run();

        // Reference values computed by hand: d(pred^2)/d(pred) = 2 * pred.
        var np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0;
        var np_grad = 2 * np_pred;
        var expectedWeights = new[,] { { 1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0 } };
        var expectedBias = new[] { 3.0 - np_grad };

        self.assertAllCloseAccordingToType(expectedWeights, self.evaluate<T[,]>(weights));
        self.assertAllCloseAccordingToType(expectedBias, self.evaluate<T[]>(bias));
    }
}
107+
108+
/// <summary>
/// Runs the resource-variable minimize check for each supported numeric type.
/// </summary>
[TestMethod]
public void TestMinimizeResourceVariable()
{
    // TODO: add np.half
    TestMinimizeResourceVariable<float>();
    TestMinimizeResourceVariable<double>();
}
115+
70116
private void TestTensorLearningRate<T>() where T : struct
71117
{
72118
var dtype = GetTypeForNumericType<T>();
@@ -115,5 +161,72 @@ public void TestTensorLearningRate()
115161
TestTensorLearningRate<float>();
116162
TestTensorLearningRate<double>();
117163
}
164+
165+
/// <summary>
/// Creates two single-element variables, asks the optimizer for the gradients of
/// their sum, and checks that every returned gradient is close to 1.0 for the
/// numeric type <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Numeric element type the variables are created with.</typeparam>
public void TestGradWrtRef<T>() where T : struct
{
    var dtype = GetTypeForNumericType<T>();

    // The returned graph object is not needed here; only the call matters.
    tf.Graph().as_default();
    using (var sess = self.cached_session())
    {
        var opt = tf.train.GradientDescentOptimizer(3.0f);
        var values = new[] { 1.0, 3.0 };
        var vars_ = values.Select(
            v => tf.Variable(new[] { v }, dtype: dtype) as IVariableV1
        ).ToList();
        var grads_and_vars = opt.compute_gradients(tf.add(vars_[0], vars_[1]), vars_);
        sess.run(tf.global_variables_initializer());

        // Only the gradient half of each pair is checked; the variable is discarded.
        foreach (var (grad, _) in grads_and_vars)
        {
            self.assertAllCloseAccordingToType(new[] { 1.0 }, self.evaluate<T[]>(grad));
        }
    }
}
184+
185+
/// <summary>
/// Runs the gradient-with-respect-to-variable check for each supported numeric type.
/// </summary>
[TestMethod]
public void TestGradWrtRef()
{
    TestGradWrtRef<float>();
    TestGradWrtRef<double>();
}
191+
192+
/// <summary>
/// Applies fixed gradients to two variables with a gradient-descent optimizer
/// (learning rate 3.0) while passing a global step variable, then checks the
/// updated variable values and that the global step evaluates to 1.
/// </summary>
/// <typeparam name="T">Numeric element type the variables and gradients are created with.</typeparam>
public void TestWithGlobalStep<T>() where T : struct
{
    var dtype = GetTypeForNumericType<T>();

    tf.Graph().as_default();
    using (var sess = self.cached_session())
    {
        var stepCounter = tf.Variable(0, trainable: false);
        var firstVar = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
        var secondVar = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
        var firstGrad = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
        var secondGrad = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
        var gradientPairs = new[] {
            Tuple.Create(firstGrad, firstVar as IVariableV1),
            Tuple.Create(secondGrad, secondVar as IVariableV1)
        };
        var optimizer = tf.train.GradientDescentOptimizer(3.0f);
        var updateOp = optimizer.apply_gradients(gradientPairs, global_step: stepCounter);

        sess.run(tf.global_variables_initializer());

        // Before the update both variables must hold their initial values.
        self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(firstVar));
        self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(secondVar));

        // Apply exactly one optimizer step.
        updateOp.run();

        // Each element moves by learning rate (3.0) times its gradient.
        var expectedFirst = new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 };
        var expectedSecond = new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 };
        self.assertAllCloseAccordingToType(expectedFirst, self.evaluate<T[]>(firstVar));
        self.assertAllCloseAccordingToType(expectedSecond, self.evaluate<T[]>(secondVar));

        // The global step is expected to read 1 after the single update.
        Assert.AreEqual(1, self.evaluate<int>(stepCounter));
    }
}
224+
225+
/// <summary>
/// Runs the global-step update check for each supported numeric type.
/// </summary>
[TestMethod]
public void TestWithGlobalStep()
{
    TestWithGlobalStep<float>();
    TestWithGlobalStep<double>();
}
118231
}
119232
}

test/Tensorflow.UnitTest/PythonTest.cs

+36-9
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ public int Compare(object? x, object? y)
175175
return 1;
176176
}
177177

178-
var a = (double)x;
179-
var b = (double)y;
178+
var a = Convert.ToDouble(x);
179+
var b = Convert.ToDouble(y);
180180

181181
double delta = Math.Abs(a - b);
182182
if (delta < _epsilon)
@@ -187,6 +187,19 @@ public int Compare(object? x, object? y)
187187
}
188188
}
189189

190+
/// <summary>
/// Two-dimensional overload: asserts that <paramref name="given"/> has the same
/// dimensions as <paramref name="expected"/>, then flattens <paramref name="given"/>
/// and forwards both to the collection-based overload together with the tolerances.
/// </summary>
/// <typeparam name="T">Element type of the array under test.</typeparam>
/// <param name="expected">Reference values.</param>
/// <param name="given">Values under test.</param>
/// <param name="eps">Tolerance passed through unchanged to the collection-based overload.</param>
/// <param name="float_eps">Second tolerance passed through unchanged to the collection-based overload.</param>
public void assertAllCloseAccordingToType<T>(
    double[,] expected,
    T[,] given,
    double eps = 1e-6,
    float float_eps = 1e-6f)
{
    // Both arrays have rank 2; check rows first, then columns.
    for (var dimension = 0; dimension < 2; dimension++)
    {
        Assert.AreEqual(expected.GetLength(dimension), given.GetLength(dimension));
    }

    // Cast<T>() enumerates the array element by element, giving a one-dimensional copy.
    T[] givenFlat = given.Cast<T>().ToArray();
    assertAllCloseAccordingToType(expected, givenFlat, eps, float_eps);
}
202+
190203
public void assertAllCloseAccordingToType<T>(
191204
ICollection expected,
192205
ICollection<T> given,
@@ -267,21 +280,35 @@ public T evaluate<T>(Tensor tensor)
267280
{
268281
var sess = tf.get_default_session();
269282
var ndarray = tensor.eval(sess);
270-
if (typeof(T) == typeof(double)
271-
|| typeof(T) == typeof(float)
272-
|| typeof(T) == typeof(int))
283+
284+
if (typeof(T) == typeof(int))
285+
{
286+
int i = ndarray;
287+
result = i;
288+
}
289+
else if (typeof(T) == typeof(float))
290+
{
291+
float f = ndarray;
292+
result = f;
293+
}
294+
else if (typeof(T) == typeof(double))
273295
{
274-
result = Convert.ChangeType(ndarray, typeof(T));
296+
double d = ndarray;
297+
result = d;
275298
}
276-
else if (typeof(T) == typeof(double[]))
299+
else if (
300+
typeof(T) == typeof(double[])
301+
|| typeof(T) == typeof(double[,]))
277302
{
278303
result = ndarray.ToMultiDimArray<double>();
279304
}
280-
else if (typeof(T) == typeof(float[]))
305+
else if (typeof(T) == typeof(float[])
306+
|| typeof(T) == typeof(float[,]))
281307
{
282308
result = ndarray.ToMultiDimArray<float>();
283309
}
284-
else if (typeof(T) == typeof(int[]))
310+
else if (typeof(T) == typeof(int[])
311+
|| typeof(T) == typeof(int[,]))
285312
{
286313
result = ndarray.ToMultiDimArray<int>();
287314
}

0 commit comments

Comments
 (0)