
Replace xm.mark_step with torch_xla.sync() wherever possible #9070


Open
wants to merge 4 commits into master

Conversation

@ghpvnist (Collaborator) commented May 1, 2025

Fixes #8862.
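For context, the change this PR rolls out across examples and tests is mechanical: calls to xm.mark_step() become torch_xla.sync(). Below is a minimal sketch of the new spelling in a single-device training loop; the model and loop are illustrative, not taken from this repository, and assume a torch_xla build where torch_xla.device() and torch_xla.sync() are available.

```python
import torch
import torch_xla

device = torch_xla.device()  # default XLA device
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for step in range(3):
    data = torch.randn(4, 8, device=device)
    target = torch.randint(0, 2, (4,), device=device)
    loss = torch.nn.functional.cross_entropy(model(data), target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Previously: xm.mark_step()
    # New spelling: cut the lazy trace here and launch the pending graph.
    torch_xla.sync()
```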

@tengyifei (Collaborator) left a comment

LGTM modulo one nit on the comment.

I think we have some tests that look for strings (yuck!) and there's a subtle string mismatch.

"""Launches all pending graph operations.

Args:
wait (bool): whether to block the current process until the execution finished.

reset_scope (bool): whether to reset the tracing scope of lazy tensor.
Collaborator

Could you elaborate on what this means? Reading the code, I wasn't sure what "resetting" means. Does it mean that any tracing scope set by the user for the profiler is invalidated? Maybe you could run some simple tests to verify its behavior?
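As context for the question above, here is a small sketch of how the two flags in the quoted docstring would be passed. It assumes the signature matches the docstring and makes no claim about what reset_scope actually does, since that is exactly what is being asked.

```python
import torch
import torch_xla

device = torch_xla.device()
t = torch.ones(2, 2, device=device) + 1

# Default call: launch all pending graph operations without blocking.
torch_xla.sync()

# Per the docstring, wait=True should block the caller until the launched
# execution has finished on the device.
t = t * 3
torch_xla.sync(wait=True)

# reset_scope controls whether the lazy tensor tracing scope is reset after
# the sync; its observable effect (e.g. on profiler scope names) is the open
# question in this thread.
```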

@yaoshiang (Collaborator)

This comment shouldn't hold up this PR, but I want to register that we have an opportunity to rename this again in the future, to make it easier for users to understand what it does without reading documentation. The problem with sync is that it is too close to torch.cuda.synchronize, which has a different meaning, so torch_xla.sync is going to confuse the majority of our users.

What I think we should do is have:

- torch_xla.synchronize, which ONLY waits for all results from the underlying devices to return, similar to torch.cuda.synchronize.
- torch_xla.compile(), which will match the norms of torch developers, who expect a possible graph break at the end of a decorated function.
- torch_xla._barrier(), which will become a documented but not recommended way to indicate that LazyTensor should break the graph.

There are unresolved design questions:
What happens if a synchronize is inside a compile()? We'd have to look at how this works in CUDA and try to keep the semantics parallel.

Next: torch.compile vs torch_xla.compile().

torch.cuda.synchronize(device=None)
Wait for all kernels in all streams on a CUDA device to complete.

Parameters:
device (torch.device or int, optional) – device for which to synchronize. It uses the current device, given by current_device(), if device is None (default).
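To make the proposed split concrete, here is a hedged sketch of what user code could look like under it. torch_xla.synchronize and torch_xla._barrier are hypothetical names taken from the comment, not existing APIs; torch_xla.compile does exist today, and the decorator usage below is sketched from its documentation, though its graph-break semantics are part of the open questions above.

```python
import torch
import torch_xla

device = torch_xla.device()
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Existing API: the decorated function is traced as one unit, with a graph
# break expected at the end of the function.
@torch_xla.compile
def train_step(data, target):
    loss = torch.nn.functional.cross_entropy(model(data), target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss

# Hypothetical APIs from the proposal (they do not exist today):
#   torch_xla.synchronize()  # only wait for outstanding device work,
#                            # mirroring torch.cuda.synchronize
#   torch_xla._barrier()     # documented but discouraged explicit graph break
```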

@bhavya01 (Collaborator) commented May 2, 2025

(Quoting @yaoshiang's comment above in full.)

Agree that torch_xla.sync() is not the best naming. Should we explicitly try to educate users by using something like compile_all_lazy_graphs()? torch_xla.barrier is also a good option. My only concern is that a barrier is generally used in the context of threads and processes, and we already have APIs like apply_backward_optimization_barrier. I just want to avoid the term becoming too overloaded.

@tengyifei (Collaborator)

IMO it's well worth having a separate discussion about renaming torch_xla.sync(), but outside of this PR. I agree that sync is also confusing.

Successfully merging this pull request may close these issues.

Replace xm.mark_step with torch_xla.sync() in examples and tests
4 participants