Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 117 additions & 22 deletions tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,84 +1031,179 @@ def get_scalar_value(tensor):
return list(tensor.values)[0]

def fold_shape(tensor):
    """Fold the output of a ``Shape`` node into a constant when possible.

    Honors the optional ``start`` and ``end`` attributes of ``Shape`` (opset 15+),
    so only the dimensions in ``inp.shape[start:end]`` are returned.

    Args:
        tensor: The tensor to fold; expected to be produced by a ``Shape`` node.

    Returns:
        A 1-D ``np.int64`` array with the selected dimensions, or ``None`` when the
        tensor cannot be folded: no usable input tensor was found, the input shape
        is unknown, ``start`` lies past ``end``, or one of the selected dimensions
        is dynamic.
    """
    shape_node = get_producer(tensor, "Shape")
    inp = get_input(shape_node)
    if inp is None or inp.shape is None:
        return None

    # slice.indices() resolves negative values relative to the rank and clamps
    # both bounds to [0, rank]; a missing "end" (None) means "up to the last axis".
    bounds = slice(shape_node.attrs.get("start", 0), shape_node.attrs.get("end"))
    start, end, _ = bounds.indices(len(inp.shape))
    if start > end:
        return None

    target_shape = inp.shape[start:end]
    # Only the selected dimensions have to be static: dynamic dimensions outside
    # of [start, end) do not influence the Shape node's output.
    if misc.is_dynamic_shape(target_shape):
        return None
    return np.array(target_shape, dtype=np.int64)

def fold_shape_gather(tensor):
    """Fold a ``Shape -> Gather`` pattern into a constant of the gathered dimensions.

    Honors the optional ``start`` and ``end`` attributes of ``Shape`` (opset 15+):
    the gather indices are interpreted relative to ``inp.shape[start:end]``, not
    relative to the full input shape.

    Args:
        tensor: The tensor to fold; expected to be produced by a ``Gather`` node
            whose data input comes from a ``Shape`` node and whose indices input
            is a ``Constant``.

    Returns:
        An ``np.int64`` array (0-D for scalar indices, 1-D otherwise) with the
        gathered dimensions, or ``None`` when the pattern does not match, the
        input shape is unknown, ``start`` lies past ``end``, an index is out of
        range, or a gathered dimension is dynamic.
    """
    gather = get_producer(tensor, "Gather")
    if gather is None:
        return None

    data = gather.inputs[0]
    indices_tensor = gather.inputs[1]

    shape_node = get_producer(data, "Shape")
    inp = get_input(shape_node)
    if inp is None or inp.shape is None:
        return None

    if not isinstance(indices_tensor, Constant):
        return None

    # slice.indices() resolves negative values relative to the rank and clamps
    # both bounds to [0, rank]; a missing "end" (None) means "up to the last axis".
    bounds = slice(shape_node.attrs.get("start", 0), shape_node.attrs.get("end"))
    start, end, _ = bounds.indices(len(inp.shape))
    if start > end:
        return None

    shape_slice = inp.shape[start:end]
    num_selected = len(shape_slice)

    def resolve(index):
        # Map a (possibly negative) gather index onto shape_slice; None if out of range.
        idx = int(index)
        if idx < 0:
            idx += num_selected
        return idx if 0 <= idx < num_selected else None

    indices = indices_tensor.values
    if not indices.shape:  # Scalar-case
        idx = resolve(indices)
        if idx is None:
            return None
        shape = shape_slice[idx]
        if misc.is_dynamic_dimension(shape):
            return None
    else:
        shape = []
        for index in indices:
            idx = resolve(index)
            if idx is None:
                # Bail out instead of raising: an out-of-range index cannot be folded.
                return None
            shape.append(shape_slice[idx])
        if misc.is_dynamic_shape(shape):
            return None

    return np.array(shape, dtype=np.int64)

def fold_shape_slice(tensor):
slice = get_producer(tensor, "Slice")
if slice is None:
"""Fold tensor shape slice information into a NumPy array of int64 type.
Handles Shape node with optional 'start' and 'end' attributes (opset 15+).
"""
slice_node = get_producer(tensor, "Slice")
if slice_node is None:
return None

data = slice.inputs[0]
data = slice_node.inputs[0]

if len(slice.inputs) >= 3:
starts, ends = slice.inputs[1:3]
if len(slice_node.inputs) >= 3:
starts, ends = slice_node.inputs[1:3]
if any(not isinstance(t, Constant) for t in [starts, ends]):
return None
starts, ends = get_scalar_value(starts), get_scalar_value(ends)
elif "starts" in slice.attrs and "ends" in slice.attrs:
starts, ends = slice.attrs["starts"][0], slice.attrs["ends"][0]
elif "starts" in slice_node.attrs and "ends" in slice_node.attrs:
starts, ends = slice_node.attrs["starts"][0], slice_node.attrs["ends"][0]
else:
return None

inp = get_input(get_producer(data, "Shape"))
shape_node = get_producer(data, "Shape")
inp = get_input(shape_node)
if inp is None or inp.shape is None:
return None

# For shape tensors, we can only slice on the 0th dimension.
if len(slice.inputs) > 3:
axes = slice.inputs[3]
if len(slice_node.inputs) > 3:
axes = slice_node.inputs[3]
if not isinstance(axes, Constant):
return None

if get_scalar_value(axes) != 0:
return None
elif "axes" in slice.attrs:
if slice.attrs["axes"][0] != 0:
elif "axes" in slice_node.attrs:
if slice_node.attrs["axes"][0] != 0:
return None

steps = 1
if len(slice.inputs) > 4:
steps = slice.inputs[4]
if len(slice_node.inputs) > 4:
steps = slice_node.inputs[4]
if not isinstance(steps, Constant):
return None
steps = get_scalar_value(steps)
elif "steps" in slice.attrs:
steps = slice.attrs["steps"][0]
elif "steps" in slice_node.attrs:
steps = slice_node.attrs["steps"][0]

# Get the shape slice from Shape node (considering start/end attributes)
full_shape = inp.shape
num_dims = len(full_shape)

shape_start = shape_node.attrs.get("start", 0)
shape_end = shape_node.attrs.get("end", None)

if shape_start < 0:
shape_start = num_dims + shape_start
if shape_end is None:
shape_end = num_dims
elif shape_end < 0:
shape_end = num_dims + shape_end

shape_start = max(0, min(shape_start, num_dims))
shape_end = max(0, min(shape_end, num_dims))

if shape_start > shape_end:
return None

shape_slice = full_shape[shape_start:shape_end]

shape = inp.shape[starts:ends:steps]
# Apply the Slice operation on the shape_slice
shape = shape_slice[starts:ends:steps]
if misc.is_dynamic_shape(shape):
return None

Expand Down