From 72890c3b111691bed997e6d49513d45405cb16a1 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 26 Mar 2025 18:20:39 +1000
Subject: [PATCH 1/2] experiment(backend): autocast dtype in CustomLinear

This resolves an issue where specifying `float32` precision causes FLUX
Fill to error.

I noticed that our other customized torch modules do some dtype casting
themselves, so maybe this is a fine place to do this? Maybe this could
break things...

See #7836
---
 .../torch_module_autocast/cast_to_dtype.py    | 19 +++++++++++++++++++
 .../custom_modules/custom_linear.py           |  5 +++++
 2 files changed, 24 insertions(+)
 create mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_dtype.py

diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_dtype.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_dtype.py
new file mode 100644
index 00000000000..e7ce95bcc4b
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_dtype.py
@@ -0,0 +1,19 @@
+from typing import TypeVar
+
+import torch
+
+T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
+
+
+def cast_to_dtype(t: T, to_dtype: torch.dtype) -> T:
+    """Helper function to cast an optional tensor to a target dtype."""
+
+    if t is None:
+        # If the tensor is None, return it as is.
+        return t
+
+    if t.dtype != to_dtype:
+        # The tensor is on the wrong device and we don't care about the dtype - or the dtype is already correct.
+        return t.to(dtype=to_dtype)
+
+    return t
diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
index c440526b9b9..2f3f4266d75 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py
@@ -3,6 +3,7 @@
 import torch
 
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_dtype import cast_to_dtype
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
     CustomModuleMixin,
 )
@@ -73,6 +74,10 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
         bias = cast_to_device(self.bias, input.device)
+
+        weight = cast_to_dtype(weight, input.dtype)
+        bias = cast_to_dtype(bias, input.dtype)
+
         return torch.nn.functional.linear(input, weight, bias)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
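
A minimal sketch of the failure mode the patch above works around:
`torch.nn.functional.linear` raises a RuntimeError when the input and
weight dtypes disagree, which is what happens when a model whose weights
are held in bfloat16 is run with `float32` precision requested. The layer
here is a hypothetical stand-in, not FLUX Fill itself:

    import torch

    # Hypothetical stand-in for a model layer whose weights live in bfloat16.
    linear = torch.nn.Linear(8, 8).to(dtype=torch.bfloat16)

    # Activation arriving in float32 because `float32` precision was requested.
    x = torch.randn(2, 8, dtype=torch.float32)

    # linear(x) would raise a RuntimeError along the lines of
    # "mat1 and mat2 must have the same dtype".

    # Casting weight and bias to the input dtype first, as the patched
    # _autocast_forward does, lets the matmul succeed:
    weight = linear.weight.to(dtype=x.dtype)
    bias = linear.bias.to(dtype=x.dtype)
    y = torch.nn.functional.linear(x, weight, bias)
    print(y.dtype)  # torch.float32

Casting at call time sidesteps the mismatch without touching the stored
model weights.
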
From eaa1d8eb71357a2588f5774ea282cae9933a6637 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 27 Mar 2025 05:33:57 +1000
Subject: [PATCH 2/2] tidy(backend): errant comments

---
 .../load/model_cache/torch_module_autocast/cast_to_dtype.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_dtype.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_dtype.py
index e7ce95bcc4b..f0af929955f 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_dtype.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_dtype.py
@@ -7,13 +7,9 @@
 
 def cast_to_dtype(t: T, to_dtype: torch.dtype) -> T:
     """Helper function to cast an optional tensor to a target dtype."""
-
     if t is None:
-        # If the tensor is None, return it as is.
         return t
 
     if t.dtype != to_dtype:
-        # The tensor is on the wrong device and we don't care about the dtype - or the dtype is already correct.
        return t.to(dtype=to_dtype)
-
     return t
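
A quick sanity check of `cast_to_dtype` as it stands after the tidy-up,
assuming the module is importable from the path in the diff; the tensors
are illustrative. Each assertion mirrors one branch of the helper:

    import torch

    from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_dtype import cast_to_dtype

    w = torch.ones(4, dtype=torch.bfloat16)

    # Mismatched dtype: a cast copy is returned.
    assert cast_to_dtype(w, torch.float32).dtype == torch.float32

    # Matching dtype: the original tensor is returned unchanged.
    assert cast_to_dtype(w, torch.bfloat16) is w

    # None (e.g. a bias-free linear layer): passed through as-is.
    assert cast_to_dtype(None, torch.float32) is None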