Skip to content

Commit a59ebae

Browse files
committed
Fix the errors caused by branch merge.
1 parent 1903700 commit a59ebae

File tree

6 files changed

+99
-53
lines changed

6 files changed

+99
-53
lines changed

src/TensorFlowNET.Core/Checkpoint/functional_saver.cs

+35-34
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<s
208208
_keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn;
209209
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec));
210210

211-
// skip the process of device name because lack of API.
212211
string host_device;
213212
if (tensor.IsT0)
214213
{
@@ -218,6 +217,7 @@ public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<s
218217
{
219218
host_device = tensor.AsT1.device;
220219
}
220+
host_device = saveable_object_util.set_cpu0(host_device);
221221
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>());
222222
if (!internal_dict.ContainsKey(checkpoint_key))
223223
{
@@ -329,51 +329,52 @@ IDictionary<string, Operation> restore_func()
329329
{
330330
var restored_tensor_dict = saver.restore(file_prefix, options);
331331

332-
foreach(var pair in restored_tensor_dict)
333-
{
334-
var checkpoint_key = pair.Key;
335-
var slice_and_tensor = pair.Value;
336-
foreach(var item in slice_and_tensor)
332+
foreach (var pair in restored_tensor_dict)
337333
{
338-
var slice_spec = item.Key;
339-
var tensor = item.Value;
340-
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
341-
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>());
342-
if (!string.IsNullOrEmpty(slice_spec))
334+
var checkpoint_key = pair.Key;
335+
var slice_and_tensor = pair.Value;
336+
foreach (var item in slice_and_tensor)
343337
{
344-
if (!internal_dict.ContainsKey(checkpoint_key))
338+
var slice_spec = item.Key;
339+
var tensor = item.Value;
340+
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
341+
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>());
342+
if (!string.IsNullOrEmpty(slice_spec))
345343
{
346-
Dictionary<string, Tensor> dict = new();
347-
dict[slice_spec] = tensor;
348-
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict);
344+
if (!internal_dict.ContainsKey(checkpoint_key))
345+
{
346+
Dictionary<string, Tensor> dict = new();
347+
dict[slice_spec] = tensor;
348+
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict);
349+
}
350+
else
351+
{
352+
internal_dict[checkpoint_key].AsT1[slice_spec] = tensor;
353+
}
349354
}
350355
else
351356
{
352-
internal_dict[checkpoint_key].AsT1[slice_spec] = tensor;
357+
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor);
353358
}
354-
}
355-
else
356-
{
357-
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor);
358-
}
359-
restore_fn_input_count[restore_fn]--;
359+
restore_fn_input_count[restore_fn]--;
360360

361-
if (restore_fn_input_count[restore_fn] == 0)
362-
{
363-
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
364-
foreach(var input in restore_fn_inputs[restore_fn])
361+
if (restore_fn_input_count[restore_fn] == 0)
365362
{
366-
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
367-
}
368-
var ret = restore_fn.DynamicInvoke(restored_tensors);
369-
if(ret is IDictionary<string, Operation>)
370-
{
371-
var dict = (IDictionary<string, Operation>)ret;
372-
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
363+
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
364+
foreach (var input in restore_fn_inputs[restore_fn])
365+
{
366+
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
367+
}
368+
var ret = restore_fn.DynamicInvoke(restored_tensors);
369+
if (ret is IDictionary<string, Operation>)
370+
{
371+
var dict = (IDictionary<string, Operation>)ret;
372+
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
373+
}
373374
}
374375
}
375376
}
376-
}
377+
});
377378
}
378379

379380
foreach(var item in _registered_savers)

src/TensorFlowNET.Core/Contexts/Context.Device.cs

+8
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ public PhysicalDevice[] list_physical_devices(string device_type = null)
111111
return results.ToArray();
112112
}
113113

114+
/// <summary>
/// Reports whether <paramref name="device_name"/> names a custom device.
/// </summary>
/// <param name="device_name">The device name to query. It is currently ignored.</param>
/// <returns>
/// Always <c>false</c> for now: the native query is disabled, so no device is
/// ever reported as custom.
/// </returns>
public bool is_custom_device(string device_name)
{
    // TODO(Rinne): After tf2.11 TFE_IsCustomDevice has been added to C APIs.
    // Once the native library in use exports it, replace the stub below with:
    //   ensure_initialized();
    //   return c_api.TFE_IsCustomDevice(_handle, device_name);
    return false;
}
121+
114122
public EagerDeviceContext device(string name)
115123
{
116124
return new EagerDeviceContext(this, name);

src/TensorFlowNET.Core/Eager/c_api.eager.cs

+3
Original file line numberDiff line numberDiff line change
@@ -483,5 +483,8 @@ public static extern SafeStatusHandle TFE_TapeGradient(IntPtr tape,
483483
IntPtr[] target, int target_size,
484484
IntPtr[] sources, int source_size,
485485
IntPtr[] outputs, int output_size);
486+
487+
/// <summary>
/// Native binding for <c>TFE_IsCustomDevice</c> from the TensorFlow eager C API.
/// </summary>
/// <param name="ctx">Handle of the eager context to query.</param>
/// <param name="device_name">Name of the device to test.</param>
/// <returns>The boolean result reported by the native function.</returns>
/// <remarks>
/// The native function returns a C <c>bool</c>, which is a single byte. Without an
/// explicit <c>MarshalAs</c>, P/Invoke treats a <c>bool</c> return value as a 4-byte
/// Win32 BOOL and may read uninitialised upper bytes, so the return value is
/// marshalled as <see cref="UnmanagedType.U1"/>.
/// </remarks>
[DllImport(TensorFlowLibName)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool TFE_IsCustomDevice(SafeContextHandle ctx, string device_name);
486489
}
487490
}

src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs

+16-10
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,18 @@ Func<Tensor> _read_variable_closure(BaseResourceVariable v)
4646
{
4747
return () =>
4848
{
49-
tf.device(v.Device);
50-
if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
49+
return tf_with(ops.device(v.Device), _ =>
5150
{
52-
return null;
53-
}
54-
var x = v.read_value_no_copy();
55-
tf.device("/device:CPU:0");
56-
return array_ops.identity(x);
51+
if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
52+
{
53+
return null;
54+
}
55+
var x = v.read_value_no_copy();
56+
return tf_with(ops.device("/device:CPU:0"), _ =>
57+
{
58+
return array_ops.identity(x);
59+
});
60+
});
5761
};
5862
}
5963

@@ -69,10 +73,12 @@ Func<Tensor> _read_variable_closure(BaseResourceVariable v)
6973
public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null)
7074
{
7175
var restored_tensor = restored_tensors[0];
72-
tf.device(_var_device);
73-
restored_tensor = array_ops.identity(restored_tensor);
74-
return resource_variable_ops.shape_safe_assign_variable_handle(
76+
return tf_with(ops.device(_var_device), _ =>
77+
{
78+
restored_tensor = array_ops.identity(restored_tensor);
79+
return resource_variable_ops.shape_safe_assign_variable_handle(
7580
handle_op, _var_shape, restored_tensor);
81+
});
7682
}
7783
}
7884
}

src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs

+13
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ limitations under the License.
2020
using System.Diagnostics;
2121
using System.Linq;
2222
using Tensorflow.Checkpoint;
23+
using Tensorflow.Contexts;
24+
using Tensorflow.Device;
2325
using Tensorflow.Operations.Activation;
2426
using Tensorflow.Train;
2527
using Tensorflow.Training;
@@ -406,6 +408,17 @@ public static OneOf<BaseResourceVariable, MySaveableObject> create_saveable_obje
406408
return factory(key);
407409
}
408410

411+
/// <summary>
/// Rewrites a device string so that it points at CPU device 0.
/// </summary>
/// <param name="device_string">The device string to rewrite.</param>
/// <returns>
/// <paramref name="device_string"/> unchanged when the current context reports it
/// as a custom device; otherwise the string form of the parsed device spec with its
/// device type replaced by "CPU" and its device index replaced by 0.
/// </returns>
public static string set_cpu0(string device_string)
{
    bool isCustom = tf.Context.is_custom_device(device_string);
    if (!isCustom)
    {
        // Only the device type and index are overridden; the remaining
        // components of the parsed spec are carried over by replace().
        var cpuSpec = DeviceSpec.from_string(device_string)
            .replace(device_type: "CPU", device_index: 0);
        return cpuSpec.ToString();
    }

    // Custom devices are passed through untouched.
    return device_string;
}
421+
409422
private static bool _tensor_comes_from_variable(object v)
410423
{
411424
return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type);

src/TensorFlowNET.Core/Variables/ResourceVariable.cs

+24-9
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,29 @@ private void _init_from_args(object initial_value = null,
124124

125125
if (_in_graph_mode)
126126
{
127+
// TODO(Rinne): deal with initializer_op.
128+
//if(initial_value is not null)
129+
//{
130+
// tf_with(ops.name_scope("Assign"), n =>
131+
// {
132+
// tf_with(ops.device(handle.Device), _ =>
133+
// {
134+
135+
// });
136+
// });
137+
//}
127138
handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
128139
initializer_op = gen_state_ops.assign(handle, _initial_value, true).op;
129140

130141
ops.colocate_with(initializer_op);
131-
tf.device(handle.Device);
132-
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
133-
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
134-
_graph_element = gen_array_ops.identity(handle, name = "read");
135-
ops.add_to_collections<IVariableV1>(collections, this);
136-
_dtype = handle.dtype;
142+
tf_with(ops.device(handle.Device), _ =>
143+
{
144+
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
145+
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
146+
_graph_element = gen_array_ops.identity(handle, name = "read");
147+
ops.add_to_collections<IVariableV1>(collections, this);
148+
_dtype = handle.dtype;
149+
});
137150
}
138151
else
139152
{
@@ -149,9 +162,11 @@ private void _init_from_args(object initial_value = null,
149162
_graph_element = null;
150163
if (!string.IsNullOrEmpty(caching_device))
151164
{
152-
tf.device(caching_device);
153-
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
154-
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
165+
tf_with(ops.device(caching_device), _ =>
166+
{
167+
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
168+
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
169+
});
155170
}
156171
_dtype = _initial_value.dtype.as_base_dtype();
157172
// initial_value = _in_graph_mode ? initial_value : null;

0 commit comments

Comments
 (0)