Hello team,
I am trying to profile an inference run for a model as follows:
```python
import time

import torch
import torch_xla

device = torch_xla.device()
torch_model = torch_model.to(device)
torch_inputs = torch_inputs.to(device)

print(f"Running inference on {model}")
print(f"Warming up {model} with {num_warmup_runs} runs")

# Warmup runs, so compilation is excluded from the timed iterations.
for _ in range(num_warmup_runs):
    with torch.no_grad():
        logits = torch_model(torch_inputs).logits
torch_xla.sync(wait=True)

# Timed iterations.
latencies_ms = []
for i in range(args.num_iterations):
    start_time = time.time()
    with torch.no_grad():
        logits = torch_model(torch_inputs).logits
    torch_xla.sync(wait=True)
    end_time = time.time()
    latencies_ms.append((end_time - start_time) * 1000)
```

However, this program only terminates when I comment out the torch_xla.sync lines. I also noticed that I get similar performance numbers with and without the torch_xla.sync() calls. When do I need to call torch_xla.sync(), and why might it be causing the program to hang?
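For what it's worth, my current mental model from the lazy-tensor docs is that operations are only traced until a sync point, so timing without one would mostly measure Python-side tracing rather than TPU execution. Here is a minimal sketch of the difference I'd expect to see (the variable names and the toy matmul are my own, not from my actual script):

```python
import time

import torch
import torch_xla

device = torch_xla.device()
x = torch.randn(1024, 1024, device=device)

# Without a sync, the timed region may only capture graph tracing;
# the matmul is not guaranteed to have run on the TPU yet.
start = time.time()
y = x @ x
trace_ms = (time.time() - start) * 1000
torch_xla.sync(wait=True)  # flush the pending work before the next measurement

# With torch_xla.sync(wait=True), the timer should also include
# compilation (on the first run) and the actual device execution.
start = time.time()
y = x @ x
torch_xla.sync(wait=True)
exec_ms = (time.time() - start) * 1000

print(f"trace only: {trace_ms:.2f} ms, traced + executed: {exec_ms:.2f} ms")
```

Is that picture correct, and if so, why would adding the sync make the program hang rather than just take longer?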
Perhaps more importantly, what is the right way to measure how long an inference takes on a TPU?
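To make the question concrete, the sketch below is the shape of measurement I have in mind: per-iteration wall clock on the host, cross-checked against the device-side counters from torch_xla.debug.metrics. Whether that is the right instrument is exactly what I'm unsure about; the helper function is my own, not an API.

```python
import statistics
import time

import torch
import torch_xla
import torch_xla.debug.metrics as met

def measure_latency_ms(model, inputs, num_iterations=100):
    """Per-inference wall-clock latency, syncing once per iteration."""
    latencies_ms = []
    for _ in range(num_iterations):
        start = time.time()
        with torch.no_grad():
            _ = model(inputs).logits
        torch_xla.sync(wait=True)  # block until this step has finished on device
        latencies_ms.append((time.time() - start) * 1000)
    return latencies_ms

latencies = measure_latency_ms(torch_model, torch_inputs)
print(f"median latency: {statistics.median(latencies):.2f} ms")

# Device-side counters (e.g. ExecuteTime, CompileTime) as a cross-check
# against the host-side wall-clock numbers.
print(met.metrics_report())
```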