MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16 #8326
base: dev
Conversation
Signed-off-by: John Zielke <[email protected]>
Force-pushed from 81865e4 to 8888a48
Hi @johnzielke, could you please help resolve the conflict? Then I can help trigger the blossom, thanks!
/build
Hi @dongyang0122, could you please help check whether this change makes sense to you? Just want to confirm there is no other specific concern regarding this.
```python
        self.print_info = print_info
        self.save_mem = save_mem

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.print_info:
            logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")

        target_dtype = input.dtype
```
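For reviewers skimming this, here is a minimal, hypothetical sketch of the auto-cast idea in the diff above. Only `print_info`, the `logger.info` call, and `target_dtype = input.dtype` come from the diff; the `nn.GroupNorm` base class and the float32 intermediate are illustrative assumptions, not the actual MONAI implementation:

```python
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class MaisiGroupNorm3DSketch(nn.GroupNorm):
    """Simplified stand-in for MaisiGroupNorm3D (the save_mem path is omitted)."""

    def __init__(self, num_groups: int, num_channels: int, print_info: bool = False) -> None:
        super().__init__(num_groups, num_channels)
        self.print_info = print_info

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.print_info:
            logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")
        # The change under review: infer the output dtype from the input
        # instead of a manual norm_float16 flag.
        target_dtype = input.dtype
        # One way to realize it: normalize in float32, then cast back.
        return super().forward(input.to(torch.float32)).to(target_dtype)
```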
What if the input is float32 but users want to convert the output to float16?
This change only affects the group norm and makes its behavior consistent with the rest of the model.
If I understand you correctly, the common pattern to achieve what you want would be:

```python
model = model.to(dtype=torch.bfloat16)
prediction = model(x.to(dtype=torch.bfloat16))
```

Or are you referring to something else? (A runnable version of the pattern is sketched below.)
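To make the pattern concrete, a self-contained sketch with a toy stand-in model (the `nn.Sequential` model and shapes are assumptions for illustration; bfloat16 support for these ops depends on the device and PyTorch build):

```python
import torch
import torch.nn as nn

# Toy stand-in model; any module whose ops support bfloat16 behaves the same.
model = nn.Sequential(nn.Conv3d(1, 8, kernel_size=3, padding=1), nn.GroupNorm(4, 8))
model = model.to(dtype=torch.bfloat16)

x = torch.randn(1, 1, 8, 8, 8)
prediction = model(x.to(dtype=torch.bfloat16))
print(prediction.dtype)  # torch.bfloat16
```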
Without this change, the parameter needs to be manually adjusted to produce a tensor that's compatible with the rest of the model. In addition, it could only be float32 or float16. Is there a reason one would want to have the GroupNorm in a different datatype than the rest of the model?
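To illustrate the compatibility issue (a hypothetical setup, not code from this repo): a norm hard-wired to emit float16 hands an incompatible tensor to a downstream layer running in bfloat16:

```python
import torch
import torch.nn as nn

# What the deprecated flag could force: a float16 norm output...
norm_out = torch.randn(2, 16, 4, 4, 4, dtype=torch.float16)
# ...feeding a model that runs in a different dtype.
conv = nn.Conv3d(16, 16, kernel_size=1).to(dtype=torch.bfloat16)
try:
    conv(norm_out)
except RuntimeError as err:
    print(err)  # input/weight dtype mismatch
```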
/build
Description
The current MAISI VAE encoder only supports float32 and float16, selected via a parameter that needs to be set manually. This PR instead infers the norm dtype from the input and should therefore enable the use of other data types.
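As a quick demonstration of that claim (using a hypothetical minimal subclass standing in for the patched norm, not the actual implementation):

```python
import torch
import torch.nn as nn

class AutoCastGroupNorm(nn.GroupNorm):
    # Stand-in for the patched norm: the output dtype follows the input.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.to(torch.float32)).to(x.dtype)

norm = AutoCastGroupNorm(num_groups=4, num_channels=8)
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    x = torch.randn(2, 8, 4, 4, 4, dtype=dtype)
    assert norm(x).dtype == dtype  # no norm_float16 flag required
```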
Types of changes
- `./runtests.sh -f -u --net --coverage`
- `./runtests.sh --quick --unittests --disttests`
- `make html` command in the `docs/` folder.