1
- /*****************************************************************************
1
+ /*****************************************************************************
2
2
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3
3
4
4
Licensed under the Apache License, Version 2.0 (the "License");
@@ -135,6 +135,23 @@ T[] ExpandArrayToSize<T>(IList<T> src)
135
135
TF_DataType . TF_QINT32
136
136
} ;
137
137
138
+ private static TOut [ , ] ConvertArray2D < TIn , TOut > ( TIn [ , ] inputArray , Func < TIn , TOut > converter )
139
+ {
140
+ var rows = inputArray . GetLength ( 0 ) ;
141
+ var cols = inputArray . GetLength ( 1 ) ;
142
+ var outputArray = new TOut [ rows , cols ] ;
143
+
144
+ for ( var i = 0 ; i < rows ; i ++ )
145
+ {
146
+ for ( var j = 0 ; j < cols ; j ++ )
147
+ {
148
+ outputArray [ i , j ] = converter ( inputArray [ i , j ] ) ;
149
+ }
150
+ }
151
+
152
+ return outputArray ;
153
+ }
154
+
138
155
/// <summary>
139
156
/// Create a TensorProto, invoked in graph mode
140
157
/// </summary>
@@ -157,19 +174,16 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
157
174
else if ( origin_dtype != dtype )
158
175
{
159
176
var new_system_dtype = dtype . as_system_dtype ( ) ;
160
- if ( values is long [ ] long_values )
161
- {
162
- if ( dtype == TF_DataType . TF_INT32 )
163
- values = long_values . Select ( x => ( int ) Convert . ChangeType ( x , new_system_dtype ) ) . ToArray ( ) ;
164
- }
165
- else if ( values is double [ ] double_values )
177
+
178
+ values = values switch
166
179
{
167
- if ( dtype == TF_DataType . TF_FLOAT )
168
- values = double_values . Select ( x => ( float ) Convert . ChangeType ( x , new_system_dtype ) ) . ToArray ( ) ;
169
- }
170
- else
171
- values = Convert . ChangeType ( values , new_system_dtype ) ;
172
-
180
+ long [ ] longValues when dtype == TF_DataType . TF_INT32 => longValues . Select ( x => ( int ) x ) . ToArray ( ) ,
181
+ float [ ] floatValues when dtype == TF_DataType . TF_DOUBLE => floatValues . Select ( x => ( double ) x ) . ToArray ( ) ,
182
+ float [ , ] float2DValues when dtype == TF_DataType . TF_DOUBLE => ConvertArray2D ( float2DValues , Convert . ToDouble ) ,
183
+ double [ ] doubleValues when dtype == TF_DataType . TF_FLOAT => doubleValues . Select ( x => ( float ) x ) . ToArray ( ) ,
184
+ double [ , ] double2DValues when dtype == TF_DataType . TF_DOUBLE => ConvertArray2D ( double2DValues , Convert . ToSingle ) ,
185
+ _ => Convert . ChangeType ( values , new_system_dtype ) ,
186
+ } ;
173
187
dtype = values . GetDataType ( ) ;
174
188
}
175
189
0 commit comments