From 58e24b694bb4e2c67556662d435cbd45f9a6c506 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Tue, 22 Apr 2025 15:30:58 -0700
Subject: [PATCH] chore: miscellaneous fixes

---
 py/torch_tensorrt/_compile.py                          | 10 ++++++++--
 py/torch_tensorrt/dynamo/conversion/truncate_double.py |  2 +-
 py/torch_tensorrt/dynamo/utils.py                      |  6 +++++-
 3 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index e9c5c3d622..b6567fd5c4 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -584,6 +584,7 @@ def save(
     arg_inputs: Optional[Sequence[torch.Tensor]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
     retrace: bool = False,
+    pickle_protocol: int = 2,
 ) -> None:
     """
     Save the model to disk in the specified output format.
@@ -596,6 +597,7 @@ def save(
         output_format (str): Format to save the model. Options include exported_program | torchscript.
         retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it. This flag is experimental for now.
+        pickle_protocol (int): The pickle protocol to use to save the model. Default is 2. Increase this to 4 or higher for large models (e.g., SAM2).
     """
     if isinstance(module, CudaGraphsTorchTensorRTModule):
         module = module.compiled_module
@@ -668,7 +670,9 @@ def save(
                 "Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
             )
             exp_program = export(module)
-            torch.export.save(exp_program, file_path)
+            torch.export.save(
+                exp_program, file_path, pickle_protocol=pickle_protocol
+            )
         else:
             if arg_inputs is None:
                 raise ValueError(
@@ -680,4 +684,6 @@ def save(
             exp_program = torch.export.export(
                 module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False,
             )
-            torch.export.save(exp_program, file_path)
+            torch.export.save(
+                exp_program, file_path, pickle_protocol=pickle_protocol
+            )
diff --git a/py/torch_tensorrt/dynamo/conversion/truncate_double.py b/py/torch_tensorrt/dynamo/conversion/truncate_double.py
index b14ee95dab..51e35a7840 100644
--- a/py/torch_tensorrt/dynamo/conversion/truncate_double.py
+++ b/py/torch_tensorrt/dynamo/conversion/truncate_double.py
@@ -195,7 +195,7 @@ def repair_double_inputs(
 
         # If the data type of the input is long/double, insert necessary
        # casts to replace the operation
-        if param.dtype == torch.float64:
+        if isinstance(param, torch.Tensor) and param.dtype == torch.float64:
             # Ensure outputs are only repaired once per submodule to avoid
             # unnecessary ops showing up in the graph
             if not repaired_outputs_once:
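A note on the changes above: the isinstance guard in truncate_double.py simply skips non-tensor graph inputs, which have no .dtype attribute, while the pickle_protocol plumbing in _compile.py matters because torch.export.save serializes the ExportedProgram with Python's pickle, and protocols 2 and 3 cannot serialize any single object larger than 4 GB; large checkpoints such as SAM2 therefore need protocol 4 or higher. A minimal usage sketch of the new keyword (the stand-in model, inputs, and file name are illustrative, not part of this patch):

    import torch
    import torch_tensorrt

    # Stand-in model; any module compilable with the dynamo frontend works.
    model = torch.nn.Linear(64, 64).eval().cuda()
    inputs = [torch.randn(1, 64).cuda()]

    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

    # Protocol 2 (the default) fails once any pickled object exceeds 4 GB;
    # protocol 4 lifts that limit for large models.
    torch_tensorrt.save(trt_gm, "trt_model.ep", pickle_protocol=4)
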
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index e4018ae95c..e75f5149ba 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -419,7 +419,9 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
     """
     Returns the dtype of torch.tensor or FakeTensor.
     For symbolic integers, we return int64
     """
-    if isinstance(tensor, (torch.Tensor, FakeTensor, int, float, bool)):
+    if isinstance(tensor, (torch.Tensor, FakeTensor)):
+        return tensor.dtype
+    elif isinstance(tensor, (int, float, bool)):
         return torch.tensor(tensor).dtype
     elif isinstance(tensor, torch.SymInt):
         return torch.int64
@@ -791,6 +793,8 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
                 output_dtypes.append(dtype.float32)
             else:
                 output_dtypes.append(dtype._from(output_meta.dtype))
+        elif isinstance(output_meta, torch.SymInt):
+            output_dtypes.append(dtype.int64)
         elif "tensor_meta" in output.meta:
             output_meta = output.meta["tensor_meta"]
             output_dtypes.append(dtype._from(output_meta.dtype))
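For reference, the unwrap_tensor_dtype change reads .dtype directly off Tensor and FakeTensor inputs instead of round-tripping through torch.tensor(value), which copies data and is unreliable for FakeTensors produced during tracing; plain Python scalars still infer their dtype via torch.tensor, and symbolic integers map to int64, which is also what the new SymInt branch in get_output_dtypes records for symbolic graph outputs. A standalone sketch of the same dispatch (the function name here is illustrative, not the library's API):

    from typing import Union

    import torch

    def unwrap_dtype(
        value: Union[torch.Tensor, int, float, bool, torch.SymInt]
    ) -> torch.dtype:
        if isinstance(value, torch.Tensor):  # FakeTensor subclasses torch.Tensor
            return value.dtype
        elif isinstance(value, (int, float, bool)):
            # Infer dtype from the Python scalar:
            # int -> int64, float -> float32, bool -> bool
            return torch.tensor(value).dtype
        elif isinstance(value, torch.SymInt):
            return torch.int64  # symbolic integer shapes are int64-backed
        raise ValueError(f"Unsupported input type: {type(value)}")

    assert unwrap_dtype(torch.zeros(2, dtype=torch.float64)) == torch.float64
    assert unwrap_dtype(3) == torch.int64
    assert unwrap_dtype(2.0) == torch.float32
    assert unwrap_dtype(True) == torch.bool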