Conversation

zhongbozhu
Collaborator

Description

Motivation: #2053

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: zhongboz <[email protected]>
@zhongbozhu zhongbozhu requested a review from timmoon10 September 2, 2025 19:42
@zhongbozhu zhongbozhu self-assigned this Sep 2, 2025
Signed-off-by: zhongboz <[email protected]>
@@ -641,11 +641,15 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
}

int nvte_is_non_tn_fp8_gemm_supported() {
Collaborator

This doesn't handle the case where we have multiple GPUs with different archs. We could add an arg for the device ID, but that just pushes the CPU overhead problem somewhere else.
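For illustration, the device-ID variant mentioned above might look like the Python sketch below (the function name, the dict cache, and the compute-capability cutoff are all assumptions, not code from this PR); the per-call lookup is exactly where the CPU overhead would move:

import torch

# Hypothetical per-device cache for the arch check, so that nodes with
# mixed GPU architectures get a correct answer for each device.
_non_tn_fp8_supported = {}  # device_id -> bool

def is_non_tn_fp8_gemm_supported(device_id: int) -> bool:
    if device_id not in _non_tn_fp8_supported:
        # Assumed cutoff: compute capability 10.x and newer; the real
        # check in Transformer Engine may differ.
        major, _ = torch.cuda.get_device_capability(device_id)
        _non_tn_fp8_supported[device_id] = major >= 10
    return _non_tn_fp8_supported[device_id]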

Collaborator Author

Yeah, but we didn't really support this case anyway?

Collaborator Author

For a topology like one CPU with 4 or 8 GPUs of homogeneous arch, we can cache the TN layout check.
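As a sketch of that caching idea (assuming a homogeneous node; the function name and the capability cutoff are assumptions, not this PR's code):

import functools
import torch

@functools.lru_cache(maxsize=1)
def is_non_tn_fp8_gemm_supported() -> bool:
    # Run the arch check once per process and reuse the result, on the
    # assumption that every GPU on the node has the same architecture,
    # so device 0 stands in for all devices.
    major, _ = torch.cuda.get_device_capability(0)
    return major >= 10  # assumed cutoff; the real check may differ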

with torch.cuda.device(
    getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(

if is_first_microbatch is None or is_first_microbatch:
Collaborator

How can we assume it is safe to skip setting the device when is_first_microbatch=False?

Collaborator Author

I assume that the device won't change across microbatches in a global batch.

Collaborator Author

In a CPU-bound, forward-only case, skipping the set-device call on every single forward pass can account for a ~10% perf difference.
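A minimal sketch of the pattern being discussed (names are illustrative, not this PR's code; assumes the module lives on a CUDA device and its parameters stay there for the whole global batch):

import torch

class CachedDeviceModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, inp, is_first_microbatch=None):
        # Query the parameter device only on the first microbatch and
        # reuse the cached value afterwards, saving a CPU-side device
        # lookup on every subsequent forward pass.
        if is_first_microbatch is None or is_first_microbatch:
            self._fwd_device = next(self.parameters()).device
        with torch.cuda.device(self._fwd_device):
            return self.linear(inp)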

Collaborator

This approach is really ad hoc. Personally, I think it would be better not to support the multi-device case at all (basically revert #1974) than to have inconsistent multi-device support.

Collaborator Author

I agree, but I'm not sure whether there would be any impact on customers already using this feature?

Signed-off-by: zhongboz <[email protected]>
@zhongbozhu
Collaborator Author

/te-ci pytorch L1

@zhongbozhu
Collaborator Author

/te-ci pytorch L1
