MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16 #8326


Open

wants to merge 2 commits into base: dev
Conversation

johnzielke (Contributor)

Description

The current MAISI VAE encoder only supports float32 and float16, selected via a parameter that must be set manually. This PR instead infers the norm dtype from the input, which should also enable the use of other dtypes.
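As an illustrative sketch (not the actual MONAI implementation), the idea is a GroupNorm that computes the normalization in float32 for numerical stability and then casts the result back to whatever dtype the input arrived in, instead of relying on a manually set `norm_float16` flag. The class name `AutoCastGroupNorm` below is hypothetical:

```python
import torch
from torch import nn

class AutoCastGroupNorm(nn.GroupNorm):
    """Hypothetical sketch: normalize in float32, return the input's dtype."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Infer the output dtype from the input rather than a config flag.
        target_dtype = input.dtype
        # Run the normalization in float32, then cast back.
        out = super().forward(input.to(torch.float32))
        return out.to(target_dtype)

norm = AutoCastGroupNorm(num_groups=4, num_channels=8)
x = torch.randn(2, 8, 4, 4, dtype=torch.bfloat16)
y = norm(x)
print(y.dtype)  # torch.bfloat16
```

Because the output dtype follows the input, the module composes with the usual `model.to(dtype=...)` pattern for any floating-point dtype, not just float16/float32.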

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@johnzielke johnzielke force-pushed the bugfix/maisi-vae-autoencoder-fix-dtype branch from 81865e4 to 8888a48 Compare February 4, 2025 15:55
KumoLiu (Contributor) commented Apr 13, 2025

Hi @johnzielke, could you please help resolve the conflict? Then I can help trigger the blossom build, thanks!

KumoLiu (Contributor) commented Apr 21, 2025

/build

KumoLiu (Contributor) commented Apr 21, 2025

Hi @dongyang0122, could you please check whether this change makes sense to you? Just want to confirm there is no other specific concern regarding the norm_float16 param.

        self.print_info = print_info
        self.save_mem = save_mem

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.print_info:
            logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")

        target_dtype = input.dtype
Contributor

What if the input is float32 but users want to convert the output to float16?

Contributor Author

This change only affects the group norm and makes its behavior consistent with the rest of the model.
If I understand you correctly, the common pattern to achieve what you want would be:

model = model.to(dtype=torch.bfloat16)
prediction = model(x.to(dtype=torch.bfloat16))

Or are you referring to something else?

johnzielke (Contributor, Author) commented Apr 21, 2025

Without this change, the parameter needs to be adjusted manually to produce a tensor that is compatible with the rest of the model, and it could only be float32 or float16. Is there a reason one would want the GroupNorm in a different datatype than the rest of the model?

KumoLiu (Contributor) commented Apr 22, 2025

/build
