diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py
index d2d70b96e5c5..93c696ef3ad1 100644
--- a/tools/distpartitioning/convert_partition.py
+++ b/tools/distpartitioning/convert_partition.py
@@ -541,8 +541,9 @@ def create_graph_object(
     assert np.all(shuffle_global_nids[1:] - shuffle_global_nids[:-1] == 1)
     shuffle_global_nid_range = (shuffle_global_nids[0], shuffle_global_nids[-1])
 
     # Determine the node ID ranges of different node types.
     prev_last_id = last_ids.get(part_id - 1, 0)
+    max_last_id = prev_last_id  # Initialize max_last_id before the loop
     for ntype_name in global_nid_ranges:
         ntype_id = ntypes_map[ntype_name]
         type_nids = shuffle_global_nids[ntype_ids == ntype_id]
@@ -552,15 +553,17 @@ def create_graph_object(
         node_map_val[ntype_name].append(
             [int(type_nids[0]), int(type_nids[-1]) + 1]
         )
-        last_id = th.tensor(
-            [max(prev_last_id, int(type_nids[-1]) + 1)], dtype=th.int64
-        )
+        # Update max_last_id with the maximum ID seen in this partition
+        max_last_id = max(max_last_id, int(type_nids[-1]) + 1)
+
+    # Create the last_id tensor after the loop to ensure it's always initialized
+    last_id = th.tensor([max_last_id], dtype=th.int64)
     id_ntypes = list(global_nid_ranges.keys())
 
     gather_last_ids = [
         th.zeros(1, dtype=th.int64) for _ in range(dist.get_world_size())
     ]
 
     dist.all_gather(gather_last_ids, last_id)
     prev_last_id = _update_node_map(
         node_map_val, gather_last_ids, id_ntypes, prev_last_id
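
The pattern the patch switches to can be exercised on its own: accumulate the running maximum node ID across all node types inside the loop, then build the last_id tensor once after the loop so it is always defined and reflects every type, not just the last one processed. The sketch below is a minimal, self-contained illustration with made-up shuffle_global_nids, ntype_ids, and type maps; it is not the DGL pipeline itself, and the dist.all_gather/_update_node_map exchange that follows in the patched function is omitted.

import torch as th

# Synthetic stand-ins (hypothetical values, for illustration only): each node
# type owns a contiguous slice of the shuffled global node IDs in this partition.
shuffle_global_nids = th.arange(100, 160, dtype=th.int64)  # IDs 100..159
ntype_ids = th.cat([th.zeros(40, dtype=th.int64), th.ones(20, dtype=th.int64)])
global_nid_ranges = {"paper": None, "author": None}  # only the keys matter here
ntypes_map = {"paper": 0, "author": 1}
prev_last_id = 0

# Track the running maximum over *all* node types, then build the tensor once
# after the loop so it is always initialized.
max_last_id = prev_last_id
for ntype_name in global_nid_ranges:
    ntype_id = ntypes_map[ntype_name]
    type_nids = shuffle_global_nids[ntype_ids == ntype_id]
    if len(type_nids) == 0:
        continue  # a type with no nodes in this partition contributes nothing
    max_last_id = max(max_last_id, int(type_nids[-1]) + 1)

last_id = th.tensor([max_last_id], dtype=th.int64)
print(last_id)  # tensor([160]) -- one past the largest shuffled node ID here

In the patched function this single-element tensor is then exchanged across ranks with dist.all_gather, so every partition learns the others' last IDs before node_map_val is finalized.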