@@ -1,8 +1,9 @@
-from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from typing import Any
 
 import torch
 
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
+
 
 class CachedModelOnlyFullLoad:
     """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
@@ -78,8 +79,7 @@ def full_load_to_vram(self) -> int:
                 new_state_dict[k] = v.to(self._compute_device, copy=True)
             self._model.load_state_dict(new_state_dict, assign=True)
 
-
-        check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
+        check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
         if isinstance(check_for_gguf, GGMLTensor):
             old_value = torch.__future__.get_overwrite_module_params_on_conversion()
             torch.__future__.set_overwrite_module_params_on_conversion(True)
@@ -103,7 +103,7 @@ def full_unload_from_vram(self) -> int:
         if self._cpu_state_dict is not None:
             self._model.load_state_dict(self._cpu_state_dict, assign=True)
 
-        check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
+        check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
         if isinstance(check_for_gguf, GGMLTensor):
             old_value = torch.__future__.get_overwrite_module_params_on_conversion()
             torch.__future__.set_overwrite_module_params_on_conversion(True)
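
For context, both hunks repeat the same pattern: `"img_in.weight"` is used as a sentinel key, and if it holds a `GGMLTensor` the model is GGUF-quantized, so `torch.__future__.set_overwrite_module_params_on_conversion(True)` is enabled around the device move (forcing `.to()` to replace module parameters outright rather than copying into them in place, which can misbehave for quantized tensor subclasses), with `old_value` captured so the global flag can presumably be restored afterwards (the hunks are truncated before that point). Below is a minimal sketch of that pattern as a reusable context manager, assuming an installed `invokeai` package; the helper name `gguf_overwrite_params` is illustrative and not part of the InvokeAI codebase:

```python
from contextlib import contextmanager
from typing import Iterator

import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


@contextmanager
def gguf_overwrite_params(model: torch.nn.Module) -> Iterator[None]:
    """Force parameter overwrite on .to() conversion while the block runs.

    Mirrors the check in the diff: "img_in.weight" is treated as a sentinel
    key whose type reveals whether the model is GGUF-quantized.
    """
    check_for_gguf = hasattr(model, "state_dict") and model.state_dict().get("img_in.weight")
    if not isinstance(check_for_gguf, GGMLTensor):
        yield  # not a GGUF model; nothing to toggle
        return
    old_value = torch.__future__.get_overwrite_module_params_on_conversion()
    torch.__future__.set_overwrite_module_params_on_conversion(True)
    try:
        yield
    finally:
        # Restore the process-wide flag even if the device move raises.
        torch.__future__.set_overwrite_module_params_on_conversion(old_value)
```

With such a helper, both call sites could read `with gguf_overwrite_params(self._model): self._model.to(device)`, and the `try`/`finally` guarantees the process-wide flag is restored even when the move fails partway.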