Skip to content

Commit b3ce158

Browse files
Update tensor_util.cs
1 parent 43f43eb commit b3ce158

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

src/TensorFlowNET.Core/Tensors/tensor_util.cs

+27-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*****************************************************************************
1+
/*****************************************************************************
22
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
@@ -135,6 +135,23 @@ T[] ExpandArrayToSize<T>(IList<T> src)
135135
TF_DataType.TF_QINT32
136136
};
137137

138+
private static TOut[,] ConvertArray2D<TIn, TOut>(TIn[,] inputArray, Func<TIn, TOut> converter)
139+
{
140+
var rows = inputArray.GetLength(0);
141+
var cols = inputArray.GetLength(1);
142+
var outputArray = new TOut[rows, cols];
143+
144+
for (var i = 0; i < rows; i++)
145+
{
146+
for (var j = 0; j < cols; j++)
147+
{
148+
outputArray[i, j] = converter(inputArray[i, j]);
149+
}
150+
}
151+
152+
return outputArray;
153+
}
154+
138155
/// <summary>
139156
/// Create a TensorProto, invoked in graph mode
140157
/// </summary>
@@ -157,19 +174,16 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
157174
else if(origin_dtype != dtype)
158175
{
159176
var new_system_dtype = dtype.as_system_dtype();
160-
if (values is long[] long_values)
161-
{
162-
if (dtype == TF_DataType.TF_INT32)
163-
values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray();
164-
}
165-
else if (values is double[] double_values)
177+
178+
values = values switch
166179
{
167-
if (dtype == TF_DataType.TF_FLOAT)
168-
values = double_values.Select(x => (float)Convert.ChangeType(x, new_system_dtype)).ToArray();
169-
}
170-
else
171-
values = Convert.ChangeType(values, new_system_dtype);
172-
180+
long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(),
181+
float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(),
182+
float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble),
183+
double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(),
184+
double[,] double2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(double2DValues, Convert.ToSingle),
185+
_ => Convert.ChangeType(values, new_system_dtype),
186+
};
173187
dtype = values.GetDataType();
174188
}
175189

0 commit comments

Comments
 (0)