Skip to content

Commit 65639d5

Browse files
committed
style
1 parent 3178c4e commit 65639d5

File tree

6 files changed

+21
-9
lines changed

6 files changed

+21
-9
lines changed

src/diffusers/models/transformers/latte_transformer_3d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ def __init__(
171171

172172
self.gradient_checkpointing = False
173173

174-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
174+
def _load_from_state_dict(
175+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
176+
):
175177
if "scale_shift_table" in state_dict:
176178
scale_shift_table = state_dict.pop("scale_shift_table")
177179
state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1]

src/diffusers/models/transformers/pixart_transformer_2d.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,13 @@ def __init__(
185185
)
186186
self.caption_projection = None
187187
if self.config.caption_channels is not None:
188-
self.caption_projection = PixArtAlphaTextProjection(
189-
in_features=self.config.caption_channels, hidden_size=self.inner_dim
190-
)
188+
self.caption_projection = PixArtAlphaTextProjection(
189+
in_features=self.config.caption_channels, hidden_size=self.inner_dim
190+
)
191191

192-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
192+
def _load_from_state_dict(
193+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
194+
):
193195
if "scale_shift_table" in state_dict:
194196
scale_shift_table = state_dict.pop("scale_shift_table")
195197
state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1]

src/diffusers/models/transformers/transformer_allegro.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,9 @@ def __init__(
310310

311311
self.gradient_checkpointing = False
312312

313-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
313+
def _load_from_state_dict(
314+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
315+
):
314316
if "scale_shift_table" in state_dict:
315317
scale_shift_table = state_dict.pop("scale_shift_table")
316318
state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1]

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,9 @@ def __init__(
400400

401401
self.gradient_checkpointing = False
402402

403-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
403+
def _load_from_state_dict(
404+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
405+
):
404406
key = "scale_shift_table"
405407
if prefix + key in state_dict:
406408
scale_shift_table = state_dict.pop(prefix + key)

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,9 @@ def __init__(
439439

440440
self.gradient_checkpointing = False
441441

442-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
442+
def _load_from_state_dict(
443+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
444+
):
443445
key = "scale_shift_table"
444446
if prefix + key in state_dict:
445447
scale_shift_table = state_dict.pop(prefix + key)

src/diffusers/models/transformers/transformer_wan_vace.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ def __init__(
270270

271271
self.gradient_checkpointing = False
272272

273-
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
273+
def _load_from_state_dict(
274+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
275+
):
274276
key = "scale_shift_table"
275277
if prefix + key in state_dict:
276278
scale_shift_table = state_dict.pop(prefix + key)

0 commit comments

Comments
 (0)