diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py index 7884c85b..3117dc8e 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py @@ -1031,15 +1031,46 @@ def get_scalar_value(tensor): return list(tensor.values)[0] def fold_shape(tensor): - inp = get_input(get_producer(tensor, "Shape")) + """Returns the input tensor shape if available, otherwise returns None. + Handles Shape node with optional 'start' and 'end' attributes (opset 15+). + """ + shape_node = get_producer(tensor, "Shape") + inp = get_input(shape_node) if inp is None: return None if inp.shape is None or misc.is_dynamic_shape(inp.shape): return None - return np.array(inp.shape, dtype=np.int64) + + full_shape = inp.shape + num_dims = len(full_shape) + + # Get start and end attributes (default: start=0, end=None means full shape) + start = shape_node.attrs.get("start", 0) + end = shape_node.attrs.get("end", None) + + # Handle negative indices + if start < 0: + start = num_dims + start + if end is None: + end = num_dims + elif end < 0: + end = num_dims + end + + # Clamp to valid range + start = max(0, min(start, num_dims)) + end = max(0, min(end, num_dims)) + + if start > end: + return None + + target_shape = full_shape[start:end] + return np.array(target_shape, dtype=np.int64) def fold_shape_gather(tensor): + """Retrieves and returns the shape of the input tensor as a NumPy array, otherwise returns None. + Handles Shape node with optional 'start' and 'end' attributes (opset 15+). + """ gather = get_producer(tensor, "Gather") if gather is None: return None @@ -1047,68 +1078,132 @@ def fold_shape_gather(tensor): data = gather.inputs[0] indices_tensor = gather.inputs[1] - inp = get_input(get_producer(data, "Shape")) + shape_node = get_producer(data, "Shape") + inp = get_input(shape_node) if inp is None or inp.shape is None: return None if not isinstance(indices_tensor, Constant): return None + # Get the shape slice from Shape node (considering start/end attributes) + full_shape = inp.shape + num_dims = len(full_shape) + + start = shape_node.attrs.get("start", 0) + end = shape_node.attrs.get("end", None) + + if start < 0: + start = num_dims + start + if end is None: + end = num_dims + elif end < 0: + end = num_dims + end + + start = max(0, min(start, num_dims)) + end = max(0, min(end, num_dims)) + + if start > end: + return None + + shape_slice = full_shape[start:end] + indices = indices_tensor.values if not indices.shape: # Scalar-case - shape = inp.shape[int(indices)] + idx = int(indices) + # Handle negative indices relative to shape_slice + if idx < 0: + idx = len(shape_slice) + idx + if idx < 0 or idx >= len(shape_slice): + return None + shape = shape_slice[idx] if misc.is_dynamic_dimension(shape): return None else: - shape = [inp.shape[index] for index in indices] + shape = [] + for index in indices: + idx = int(index) + # Handle negative indices relative to shape_slice + if idx < 0: + idx = len(shape_slice) + idx + if idx < 0 or idx >= len(shape_slice): + return None + shape.append(shape_slice[idx]) if misc.is_dynamic_shape(shape): return None return np.array(shape, dtype=np.int64) def fold_shape_slice(tensor): - slice = get_producer(tensor, "Slice") - if slice is None: + """Fold tensor shape slice information into a NumPy array of int64 type. + Handles Shape node with optional 'start' and 'end' attributes (opset 15+). + """ + slice_node = get_producer(tensor, "Slice") + if slice_node is None: return None - data = slice.inputs[0] + data = slice_node.inputs[0] - if len(slice.inputs) >= 3: - starts, ends = slice.inputs[1:3] + if len(slice_node.inputs) >= 3: + starts, ends = slice_node.inputs[1:3] if any(not isinstance(t, Constant) for t in [starts, ends]): return None starts, ends = get_scalar_value(starts), get_scalar_value(ends) - elif "starts" in slice.attrs and "ends" in slice.attrs: - starts, ends = slice.attrs["starts"][0], slice.attrs["ends"][0] + elif "starts" in slice_node.attrs and "ends" in slice_node.attrs: + starts, ends = slice_node.attrs["starts"][0], slice_node.attrs["ends"][0] else: return None - inp = get_input(get_producer(data, "Shape")) + shape_node = get_producer(data, "Shape") + inp = get_input(shape_node) if inp is None or inp.shape is None: return None # For shape tensors, we can only slice on the 0th dimension. - if len(slice.inputs) > 3: - axes = slice.inputs[3] + if len(slice_node.inputs) > 3: + axes = slice_node.inputs[3] if not isinstance(axes, Constant): return None if get_scalar_value(axes) != 0: return None - elif "axes" in slice.attrs: - if slice.attrs["axes"][0] != 0: + elif "axes" in slice_node.attrs: + if slice_node.attrs["axes"][0] != 0: return None steps = 1 - if len(slice.inputs) > 4: - steps = slice.inputs[4] + if len(slice_node.inputs) > 4: + steps = slice_node.inputs[4] if not isinstance(steps, Constant): return None steps = get_scalar_value(steps) - elif "steps" in slice.attrs: - steps = slice.attrs["steps"][0] + elif "steps" in slice_node.attrs: + steps = slice_node.attrs["steps"][0] + + # Get the shape slice from Shape node (considering start/end attributes) + full_shape = inp.shape + num_dims = len(full_shape) + + shape_start = shape_node.attrs.get("start", 0) + shape_end = shape_node.attrs.get("end", None) + + if shape_start < 0: + shape_start = num_dims + shape_start + if shape_end is None: + shape_end = num_dims + elif shape_end < 0: + shape_end = num_dims + shape_end + + shape_start = max(0, min(shape_start, num_dims)) + shape_end = max(0, min(shape_end, num_dims)) + + if shape_start > shape_end: + return None + + shape_slice = full_shape[shape_start:shape_end] - shape = inp.shape[starts:ends:steps] + # Apply the Slice operation on the shape_slice + shape = shape_slice[starts:ends:steps] if misc.is_dynamic_shape(shape): return None