Skip to content

Commit

Permalink
Some dtype fixes for vae tests (#989)
Browse files Browse the repository at this point in the history
Fixes data dependent test failures from VAE
  • Loading branch information
IanNod authored Feb 21, 2025
1 parent 63606a5 commit 5250392
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 14 deletions.
14 changes: 12 additions & 2 deletions sharktank/sharktank/models/vae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(self, hp, theta: Theta):
self.mid_block = self._create_mid_block(theta("decoder")("mid_block"))
# up
self.up_blocks = nn.ModuleList([])
self.upscale_dtype = theta("decoder")("up_blocks")(0)("resnets")(0)("conv1")(
"weight"
self.upscale_dtype = unbox_tensor(
theta("decoder")("up_blocks")(0)("resnets")(0)("conv1")("weight")
).dtype
for i, up_block_name in enumerate(hp.up_block_types):
up_block_theta = theta("decoder")("up_blocks")(i)
Expand Down Expand Up @@ -74,6 +74,16 @@ def forward(
"latent_embeds": latent_embeds,
},
)
if not self.hp.use_post_quant_conv:
sample = rearrange(
sample,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(1024 / 16),
w=math.ceil(1024 / 16),
ph=2,
pw=2,
)

sample = sample / self.hp.scaling_factor + self.hp.shift_factor

if self.hp.use_post_quant_conv:
Expand Down
10 changes: 1 addition & 9 deletions sharktank/sharktank/pipelines/flux/export_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,7 @@ def __init__(self, weight_file, height=1024, width=1024, precision="fp32"):
self.width = width

def forward(self, z):
d_in = rearrange(
z,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(self.height / 16),
w=math.ceil(self.width / 16),
ph=2,
pw=2,
)
return self.ae.forward(d_in)
return self.ae.forward(z)


def get_ae_model_and_inputs(
Expand Down
4 changes: 3 additions & 1 deletion sharktank/sharktank/tools/import_hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def import_hf_dataset(
for params_path in param_paths:
with safetensors.safe_open(params_path, framework="pt", device="cpu") as st:
tensors += [
DefaultPrimitiveTensor(name=name, data=st.get_tensor(name))
DefaultPrimitiveTensor(
name=name, data=st.get_tensor(name).to(target_dtype)
)
for name in st.keys()
]

Expand Down
4 changes: 2 additions & 2 deletions sharktank/tests/models/vae/vae_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def testVaeIreeVsHuggingFace(self):
model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu")

# TODO: Decomposing attention due to https://github.com/iree-org/iree/issues/19286, remove once issue is resolved
module = export_vae(model, inputs, True)
module = export_vae(model, inputs.to(dtype=dtype), True)
module_f32 = export_vae(model_f32, inputs, True)

module.save_mlir("{self._temp_dir}/flux_vae_bf16.mlir")
Expand Down Expand Up @@ -317,7 +317,7 @@ def testVaeIreeVsHuggingFace(self):
parameters_path="{self._temp_dir}/flux_vae_bf16.irpa",
)

input_args = OrderedDict([("inputs", inputs)])
input_args = OrderedDict([("inputs", inputs.to(dtype=dtype))])
iree_args = flatten_for_iree_signature(input_args)

iree_args = prepare_iree_module_function_args(
Expand Down

0 comments on commit 5250392

Please sign in to comment.