@@ -208,7 +208,6 @@ public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<s
208
208
_keys_to_restore_fn [ ( checkpoint_key , slice_spec ) ] = restore_fn ;
209
209
_restore_fn_to_keys . SetDefault ( restore_fn , new List < ( string , string ) > ( ) ) . Add ( ( checkpoint_key , slice_spec ) ) ;
210
210
211
- // skip the process of device name because lack of API.
212
211
string host_device ;
213
212
if ( tensor . IsT0 )
214
213
{
@@ -218,6 +217,7 @@ public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<s
218
217
{
219
218
host_device = tensor . AsT1 . device ;
220
219
}
220
+ host_device = saveable_object_util . set_cpu0 ( host_device ) ;
221
221
var internal_dict = tensors_by_device . SetDefault ( host_device , new Dictionary < string , IDictionary < string , OneOf < Tensor , SaveSpec > > > ( ) ) ;
222
222
if ( ! internal_dict . ContainsKey ( checkpoint_key ) )
223
223
{
@@ -329,51 +329,52 @@ IDictionary<string, Operation> restore_func()
329
329
{
330
330
var restored_tensor_dict = saver . restore ( file_prefix , options ) ;
331
331
332
- foreach ( var pair in restored_tensor_dict )
333
- {
334
- var checkpoint_key = pair . Key ;
335
- var slice_and_tensor = pair . Value ;
336
- foreach ( var item in slice_and_tensor )
332
+ foreach ( var pair in restored_tensor_dict )
337
333
{
338
- var slice_spec = item . Key ;
339
- var tensor = item . Value ;
340
- var restore_fn = _keys_to_restore_fn [ ( checkpoint_key , slice_spec ) ] ;
341
- var internal_dict = restore_fn_inputs . SetDefault ( restore_fn , new Dictionary < string , OneOf < Tensor , IDictionary < string , Tensor > > > ( ) ) ;
342
- if ( ! string . IsNullOrEmpty ( slice_spec ) )
334
+ var checkpoint_key = pair . Key ;
335
+ var slice_and_tensor = pair . Value ;
336
+ foreach ( var item in slice_and_tensor )
343
337
{
344
- if ( ! internal_dict . ContainsKey ( checkpoint_key ) )
338
+ var slice_spec = item . Key ;
339
+ var tensor = item . Value ;
340
+ var restore_fn = _keys_to_restore_fn [ ( checkpoint_key , slice_spec ) ] ;
341
+ var internal_dict = restore_fn_inputs . SetDefault ( restore_fn , new Dictionary < string , OneOf < Tensor , IDictionary < string , Tensor > > > ( ) ) ;
342
+ if ( ! string . IsNullOrEmpty ( slice_spec ) )
345
343
{
346
- Dictionary < string , Tensor > dict = new ( ) ;
347
- dict [ slice_spec ] = tensor ;
348
- internal_dict [ checkpoint_key ] = OneOf < Tensor , IDictionary < string , Tensor > > . FromT1 ( dict ) ;
344
+ if ( ! internal_dict . ContainsKey ( checkpoint_key ) )
345
+ {
346
+ Dictionary < string , Tensor > dict = new ( ) ;
347
+ dict [ slice_spec ] = tensor ;
348
+ internal_dict [ checkpoint_key ] = OneOf < Tensor , IDictionary < string , Tensor > > . FromT1 ( dict ) ;
349
+ }
350
+ else
351
+ {
352
+ internal_dict [ checkpoint_key ] . AsT1 [ slice_spec ] = tensor ;
353
+ }
349
354
}
350
355
else
351
356
{
352
- internal_dict [ checkpoint_key ] . AsT1 [ slice_spec ] = tensor ;
357
+ internal_dict [ checkpoint_key ] = OneOf < Tensor , IDictionary < string , Tensor > > . FromT0 ( tensor ) ;
353
358
}
354
- }
355
- else
356
- {
357
- internal_dict [ checkpoint_key ] = OneOf < Tensor , IDictionary < string , Tensor > > . FromT0 ( tensor ) ;
358
- }
359
- restore_fn_input_count [ restore_fn ] -- ;
359
+ restore_fn_input_count [ restore_fn ] -- ;
360
360
361
- if ( restore_fn_input_count [ restore_fn ] == 0 )
362
- {
363
- Dictionary < string , OneOf < Tensor , IDictionary < string , Tensor > > > restored_tensors = new ( ) ;
364
- foreach ( var input in restore_fn_inputs [ restore_fn ] )
361
+ if ( restore_fn_input_count [ restore_fn ] == 0 )
365
362
{
366
- restored_tensors [ TrackableUtils . extract_local_name ( input . Key ) ] = input . Value ;
367
- }
368
- var ret = restore_fn . DynamicInvoke ( restored_tensors ) ;
369
- if ( ret is IDictionary < string , Operation > )
370
- {
371
- var dict = ( IDictionary < string , Operation > ) ret ;
372
- restore_ops = restore_ops . Concat ( dict ) . ToDictionary ( x => x . Key , x => x . Value ) ;
363
+ Dictionary < string , OneOf < Tensor , IDictionary < string , Tensor > > > restored_tensors = new ( ) ;
364
+ foreach ( var input in restore_fn_inputs [ restore_fn ] )
365
+ {
366
+ restored_tensors [ TrackableUtils . extract_local_name ( input . Key ) ] = input . Value ;
367
+ }
368
+ var ret = restore_fn . DynamicInvoke ( restored_tensors ) ;
369
+ if ( ret is IDictionary < string , Operation > )
370
+ {
371
+ var dict = ( IDictionary < string , Operation > ) ret ;
372
+ restore_ops = restore_ops . Concat ( dict ) . ToDictionary ( x => x . Key , x => x . Value ) ;
373
+ }
373
374
}
374
375
}
375
376
}
376
- }
377
+ } ) ;
377
378
}
378
379
379
380
foreach ( var item in _registered_savers )
0 commit comments