@@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight(
         where it is located in the shard if it exists, or -1 if it's not in the shard.
         Used to determine the location of each entry in a different distributed configuration.
         """
-
-        # Create an empty index for the global parameter.
-        index = torch.full(
-            parameter_meta.global_shape,
-            -1,
-            dtype=torch.int64,
-            device=device,
-        )
         # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard
         begin, end = self._get_parameter_range_in_shard(parameter_name)
 
-        buffer_index = parameter_meta.global_to_local(index, expand=True)
-        # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible.
-        # In that case, we work with a separate tensor to be copied back into `buffer_index`.
-        try:
-            buffer_index_flat = buffer_index.view(-1)
-            is_view = True
-        except RuntimeError:
-            buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1)
-            is_view = False
-
-        # Copy the shard indices at their respective positions in the flat buffer index.
-        buffer_index_flat[
+        # Create an empty local index to hold the local shard indices.
+        buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device)
+
+        # Copy the shard indices at their respective positions in the buffer index.
+        buffer_index.flatten()[
             self._index_buffer_to_param(
                 self._fsdp_dim.rank * self._shard_size, parameter_name
             ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name)
         ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device))
 
-        # If needed, copy the flat buffer index back into the index.
-        if not is_view:
-            buffer_index.copy_(buffer_index_flat.view_as(buffer_index))
-
-        return index
+        # Create a global index from the local one.
+        return parameter_meta.local_to_global_partial(buffer_index, -1)
 
     def copy_shard_overlaps(
         self,
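To make the refactored index construction concrete, here is a minimal, self-contained sketch of what the new code path computes. All names and values (`global_shape`, `shard_size`, `rank`, `begin`, `end`) are made up for illustration, and the plain `view` at the end merely stands in for `parameter_meta.local_to_global_partial`, which in the real code maps the local buffer back to the global parameter layout; the reshape is only assumed equivalent in the trivial case with no tensor-parallel sharding.

```python
import torch

# Illustrative setup: a 3x4 "global" parameter whose 12 elements are split
# evenly across 3 FSDP ranks, seen from rank 1.
global_shape = (3, 4)
numel = 12
shard_size = numel // 3
rank = 1
begin, end = 0, shard_size  # this parameter's range within the rank's shard

# Flat local index, with -1 meaning "this entry is not in my shard".
buffer_index = torch.full((numel,), -1, dtype=torch.int64)

# Write each owned entry's position in the shard, mirroring the
# `buffer_index.flatten()[...].copy_(torch.arange(begin, end, ...))` step.
buffer_index[rank * shard_size : (rank + 1) * shard_size] = torch.arange(
    begin, end, dtype=torch.int64
)

# Stand-in for `parameter_meta.local_to_global_partial(buffer_index, -1)`:
# reshape the local buffer to the global parameter shape.
index = buffer_index.view(global_shape)
print(index)
# tensor([[-1, -1, -1, -1],
#         [ 0,  1,  2,  3],
#         [-1, -1, -1, -1]])
```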