Skip to content

Commit

Permalink
SAM2: Use torch.export for VOS (#1708)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch authored Feb 20, 2025
1 parent 0293bcd commit 6bab4db
Show file tree
Hide file tree
Showing 6 changed files with 361 additions and 212 deletions.
1 change: 0 additions & 1 deletion examples/sam2_amg_server/compile_export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
TASK_TYPES = ["amg", "sps", "mps"]


# NOTE: We have to declare a separate class, because torch.export demands it.
# We build this explicitly for the sole purpose of exporting _predict_masks
# We made sure _predict_masks is fullgraph=True compileable so it can be exported
# We must be sure to export using example args that are big enough and past
Expand Down
2 changes: 1 addition & 1 deletion examples/sam2_amg_server/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def main(
sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle"
)
if export_model != "":
if not Path(output_folder).is_dir():
if not Path(export_model).is_dir():
raise ValueError(f"Expected {export_model} to be a directory.")
print(f"Exporting model to {export_model}.")
from compile_export_utils import export_model as export_model_fn
Expand Down
271 changes: 271 additions & 0 deletions examples/sam2_vos_example/compile_export_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
import time
from pathlib import Path
from typing import Optional

import torch

from torchao._models.sam2.sam2_video_predictor import SAM2VideoPredictor

# Tools used to avoid compilation cold start and dynamo cache lookups
# We take the compiled model and export it using the largest
# inputs possible (to avoid recompilations).
# We track the largest size and fail if we size something larger
# We export every compile-able subregion after wrapping it into
# a class to make export happy.

TASK_TYPES = ["amg", "sps", "mps"]


class SAM2VideoPredictor_forward_sam_heads(torch.nn.Module):
def __init__(
self,
predictor: Optional[SAM2VideoPredictor],
batch_size=1,
aoti_compiled_model=None,
furious=False,
):
super().__init__()
self.predictor = predictor
self.batch_size = batch_size
self.aoti_compiled_model = aoti_compiled_model
self.furious = furious

def forward(
self,
backbone_features,
point_inputs=None,
mask_inputs=None,
high_res_features=None,
multimask_output=False,
):
assert mask_inputs is None
assert multimask_output
if self.predictor is None:
assert self.aoti_compiled_model is not None
return self.aoti_compiled_model(
backbone_features=backbone_features,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)
return self.predictor._forward_sam_heads(
backbone_features=backbone_features,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)


def aot_compile(
model_directory,
name,
fn,
sample_args,
sample_kwargs=None,
options=None,
overwrite=False,
):
path = Path(model_directory) / Path(f"{name}.pt2")
if path.exists() and not overwrite:
raise ValueError(f"{path} already exists and overwrite is {overwrite}")
print(f"Saving at {path=}")
if options is None:
options = {
"max_autotune": True,
"triton.cudagraphs": True,
}

from torch.export import export_for_inference

exported = export_for_inference(fn, sample_args, sample_kwargs)
output_path = torch._inductor.aoti_compile_and_package(
exported,
package_path=str(path),
inductor_configs=options,
)
return output_path


def aot_load(path):
return torch._export.aot_load(path, "cuda")


class FunctionModel(torch.nn.Module):
def __init__(self, module, fn_name):
super().__init__()
self.module = module
self.fn_name = fn_name

def forward(self, *args):
return getattr(self.module, self.fn_name)(*args)


def export_model(
predictor,
model_directory,
furious=False,
batch_size=1,
overwrite=False,
):
if furious:
set_furious(predictor)

example_input = torch.empty(batch_size, 3, 1024, 1024)
# example_input = example_input.to(predictor._image_dtype)
example_input = example_input.to(torch.bfloat16)
# example_input = (example_input.to(predictor.device),)
example_input = (example_input.to("cuda:0"),)
aot_compile(
model_directory,
"sam2_image_encoder_trunk",
predictor.image_encoder.trunk,
example_input,
overwrite=overwrite,
)

example_input_args = ()
example_input_kwargs = {
"backbone_features": torch.randn(
batch_size, 256, 64, 64, dtype=torch.float32, device="cuda"
),
# "point_inputs": {
# "point_coords": torch.ones(batch_size, 1, 2, dtype=torch.float32, device="cuda"),
# "point_labels": torch.ones(batch_size, 1, dtype=torch.int32, device="cuda"),
# },
"point_inputs": None,
"mask_inputs": None,
"high_res_features": [
torch.randn(
batch_size,
32,
256,
256,
dtype=torch.bfloat16,
device="cuda",
),
torch.randn(
batch_size,
64,
128,
128,
dtype=torch.bfloat16,
device="cuda",
),
],
"multimask_output": True,
}
sam2_video_forward_sam_heads = SAM2VideoPredictor_forward_sam_heads(
predictor,
batch_size=batch_size,
furious=False,
)
aot_compile(
model_directory,
"sam2_video_forward_sam_heads",
sam2_video_forward_sam_heads,
example_input_args,
sample_kwargs=example_input_kwargs,
overwrite=overwrite,
)

return predictor


class LoadedModel(torch.nn.Module):
def __init__(self, aoti_compiled_model):
super().__init__()
self.aoti_compiled_model = aoti_compiled_model

def forward(self, *args, **kwargs):
return self.aoti_compiled_model(*args, **kwargs)


class LoadedDecoder(torch.nn.Module):
def __init__(self, aoti_compiled_model, other):
super().__init__()
self.aoti_compiled_model = aoti_compiled_model
self.other = other

def forward(self, *args):
return self.aoti_compiled_model(*args)

def get_dense_pe(self, *args, **kwargs) -> torch.Tensor:
return self.other.get_dense_pe(*args, **kwargs)


def load_exported_model(
predictor,
model_directory,
furious=False,
batch_size=1,
):
if furious:
set_furious(predictor)
t0 = time.time()
path = Path(model_directory) / Path("sam2_image_encoder_trunk.pt2")
assert path.exists(), f"Expected {path} to exist"
print(f"Start load from {path}")
pkg = torch._inductor.aoti_load_package(str(path))
pkg_m = LoadedModel(pkg)
predictor.image_encoder.trunk = pkg_m

path = Path(model_directory) / Path("sam2_video_forward_sam_heads.pt2")
assert path.exists(), f"Expected {path} to exist"
print(f"Start load from {path}")
pkg = torch._inductor.aoti_load_package(str(path))
pkg_m = SAM2VideoPredictor_forward_sam_heads(
None,
batch_size=batch_size,
aoti_compiled_model=pkg,
furious=furious,
)
predictor._forward_sam_heads = pkg_m.forward

print(f"End load image encoder and _forward_sam_heads. Took {time.time() - t0}s")
return predictor


def set_fast(predictor, loaded_exported_model=False):
if not loaded_exported_model:
predictor.image_encoder.trunk.forward = torch.compile(
predictor.image_encoder.trunk.forward,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)
if not loaded_exported_model:
predictor._forward_sam_heads = torch.compile(
predictor._forward_sam_heads,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)
predictor.memory_attention = torch.compile(
predictor.memory_attention,
mode="max-autotune",
fullgraph=True,
dynamic=True,
)
predictor.memory_encoder.forward = torch.compile(
predictor.memory_encoder.forward,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)


def set_furious(mask_generator):
mask_generator.predictor.model.image_encoder = (
mask_generator.predictor.model.image_encoder.to(torch.float16)
)
# NOTE: Not baseline feature
mask_generator.predictor._image_dtype = torch.float16
mask_generator.predictor._transforms_device = mask_generator.predictor.device
torch.set_float32_matmul_precision("high")
mask_generator.predictor.model.sam_mask_decoder = (
mask_generator.predictor.model.sam_mask_decoder.to(torch.float16)
)
# NOTE: Not baseline feature
mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16
2 changes: 2 additions & 0 deletions examples/sam2_vos_example/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
requests
fire
Loading

0 comments on commit 6bab4db

Please sign in to comment.