diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 83d7e11815..a17377f68c 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -344,8 +344,7 @@ def do_autoquant_bench(op, *args, **kwargs): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) - # TODO: update to 2.8.0 after https://github.com/pytorch/ao/pull/2786 is landed - if torch_version_at_least("2.9.0"): + if torch_version_at_least("2.8.0"): from statistics import median res = benchmarker.benchmark_gpu(