chore: miscellaneous fixes for handling graph breaks #3488

Open · wants to merge 1 commit into base: main
py/torch_tensorrt/_compile.py (8 additions, 2 deletions)

@@ -584,6 +584,7 @@ def save(
     arg_inputs: Optional[Sequence[torch.Tensor]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
     retrace: bool = False,
+    pickle_protocol: int = 2,
 ) -> None:
     """
     Save the model to disk in the specified output format.
@@ -596,6 +597,7 @@ def save(
         output_format (str): Format to save the model. Options include exported_program | torchscript.
         retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
             This flag is experimental for now.
+        pickle_protocol (int): The pickle protocol to use to save the model. Default is 2. Increase this to 4 or higher for large models (e.g., SAM2).
     """
     if isinstance(module, CudaGraphsTorchTensorRTModule):
         module = module.compiled_module
@@ -668,7 +670,9 @@ def save(
                     "Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
                 )
             exp_program = export(module)
-            torch.export.save(exp_program, file_path)
+            torch.export.save(
+                exp_program, file_path, pickle_protocol=pickle_protocol
+            )
         else:
             if arg_inputs is None:
                 raise ValueError(
@@ -680,4 +684,6 @@ def save(
                 kwargs=kwarg_inputs,
                 strict=False,
             )
-            torch.export.save(exp_program, file_path)
+            torch.export.save(
+                exp_program, file_path, pickle_protocol=pickle_protocol
+            )
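
A minimal usage sketch of the new parameter (the model, its input shapes, and `MyLargeModel` are hypothetical stand-ins, not part of this PR):

    import torch
    import torch_tensorrt

    model = MyLargeModel().eval().cuda()  # hypothetical large model
    inputs = [torch.randn(1, 3, 1024, 1024, device="cuda")]

    trt_model = torch_tensorrt.compile(model, arg_inputs=inputs)

    # Pickle protocol 2 (the default) cannot serialize objects larger than
    # about 4 GiB; protocol 4 lifts that limit, which very large models
    # such as SAM2 can exceed.
    torch_tensorrt.save(trt_model, "trt.ep", arg_inputs=inputs, pickle_protocol=4)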
py/torch_tensorrt/dynamo/conversion/truncate_double.py (1 addition, 1 deletion)

@@ -195,7 +195,7 @@ def repair_double_inputs(
 
         # If the data type of the input is long/double, insert necessary
         # casts to replace the operation
-        if param.dtype == torch.float64:
+        if isinstance(param, torch.Tensor) and param.dtype == torch.float64:
            # Ensure outputs are only repaired once per submodule to avoid
            # unnecessary ops showing up in the graph
            if not repaired_outputs_once:
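
Given the PR's focus, the guard presumably matters because graph breaks can hand this pass a submodule input that is a Python scalar or symbolic integer rather than a Tensor, and those have no .dtype attribute, so the old comparison raised AttributeError. A standalone illustration of the guarded check (not the actual pass):

    import torch

    # With graph breaks, submodule inputs can mix Tensors with plain ints.
    params = [torch.randn(2, 2, dtype=torch.float64), 3]

    for param in params:
        # `param.dtype` alone would raise AttributeError on the int;
        # the isinstance guard skips non-tensor inputs entirely.
        if isinstance(param, torch.Tensor) and param.dtype == torch.float64:
            param = param.to(torch.float32)  # stand-in for the inserted cast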
py/torch_tensorrt/dynamo/utils.py (5 additions, 1 deletion)

@@ -419,7 +419,9 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
     """
     Returns the dtype of torch.tensor or FakeTensor. For symbolic integers, we return int64
     """
-    if isinstance(tensor, (torch.Tensor, FakeTensor, int, float, bool)):
+    if isinstance(tensor, (torch.Tensor, FakeTensor)):
         return tensor.dtype
+    elif isinstance(tensor, (int, float, bool)):
+        return torch.tensor(tensor).dtype
     elif isinstance(tensor, torch.SymInt):
         return torch.int64
@@ -791,6 +793,8 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
                         output_dtypes.append(dtype.float32)
                     else:
                         output_dtypes.append(dtype._from(output_meta.dtype))
+                elif isinstance(output_meta, torch.SymInt):
+                    output_dtypes.append(dtype.int64)
             elif "tensor_meta" in output.meta:
                 output_meta = output.meta["tensor_meta"]
                 output_dtypes.append(dtype._from(output_meta.dtype))
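
To see what the patched helper returns for each input kind, here is a standalone sketch that mirrors the new `unwrap_tensor_dtype` branches (simplified: FakeTensor handling and error cases omitted):

    import torch

    def unwrap_dtype(value):
        # Mirrors the patched logic: Tensors report their own dtype.
        if isinstance(value, torch.Tensor):
            return value.dtype
        # Python scalars have no .dtype, so wrap them in a 0-d tensor first;
        # the old combined isinstance check hit AttributeError here.
        elif isinstance(value, (int, float, bool)):
            return torch.tensor(value).dtype
        # Symbolic integers from dynamic shapes map to int64, matching the
        # new SymInt branch added to get_output_dtypes as well.
        elif isinstance(value, torch.SymInt):
            return torch.int64
        return None

    print(unwrap_dtype(torch.ones(1)))  # torch.float32
    print(unwrap_dtype(3))              # torch.int64
    print(unwrap_dtype(2.5))            # torch.float32
    print(unwrap_dtype(True))           # torch.bool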