@@ -60,6 +60,12 @@ class CLIPField(BaseModel):
60
60
loras : List [LoRAField ] = Field (description = "LoRAs to apply on model loading" )
61
61
62
62
63
+
64
+ class TransformerField (BaseModel ):
65
+ transformer : ModelIdentifierField = Field (description = "Info to load Transformer submodel" )
66
+ scheduler : ModelIdentifierField = Field (description = "Info to load scheduler submodel" )
67
+
68
+
63
69
class VAEField (BaseModel ):
64
70
vae : ModelIdentifierField = Field (description = "Info to load vae submodel" )
65
71
seamless_axes : List [str ] = Field (default_factory = list , description = 'Axes("x" and "y") to which apply seamless' )
@@ -122,6 +128,49 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
122
128
return ModelIdentifierOutput (model = self .model )
123
129
124
130
131
+ @invocation_output ("flux_model_loader_output" )
132
+ class FluxModelLoaderOutput (BaseInvocationOutput ):
133
+ """Flux base model loader output"""
134
+
135
+ transformer : TransformerField = OutputField (description = FieldDescriptions .transformer , title = "Transformer" )
136
+ clip : CLIPField = OutputField (description = FieldDescriptions .clip , title = "CLIP 1" )
137
+ clip2 : CLIPField = OutputField (description = FieldDescriptions .clip , title = "CLIP 2" )
138
+ vae : VAEField = OutputField (description = FieldDescriptions .vae , title = "VAE" )
139
+
140
+
141
+ @invocation ("flux_model_loader" , title = "Flux Main Model" , tags = ["model" , "flux" ], category = "model" , version = "1.0.3" )
142
+ class FluxModelLoaderInvocation (BaseInvocation ):
143
+ """Loads a flux base model, outputting its submodels."""
144
+
145
+ model : ModelIdentifierField = InputField (
146
+ description = FieldDescriptions .flux_model ,
147
+ ui_type = UIType .FluxMainModel ,
148
+ input = Input .Direct ,
149
+ )
150
+
151
+ def invoke (self , context : InvocationContext ) -> FluxModelLoaderOutput :
152
+ model_key = self .model .key
153
+
154
+ # TODO: not found exceptions
155
+ if not context .models .exists (model_key ):
156
+ raise Exception (f"Unknown model: { model_key } " )
157
+
158
+ transformer = self .model .model_copy (update = {"submodel_type" : SubModelType .Transformer })
159
+ scheduler = self .model .model_copy (update = {"submodel_type" : SubModelType .Scheduler })
160
+ tokenizer = self .model .model_copy (update = {"submodel_type" : SubModelType .Tokenizer })
161
+ text_encoder = self .model .model_copy (update = {"submodel_type" : SubModelType .TextEncoder })
162
+ tokenizer2 = self .model .model_copy (update = {"submodel_type" : SubModelType .Tokenizer2 })
163
+ text_encoder2 = self .model .model_copy (update = {"submodel_type" : SubModelType .TextEncoder2 })
164
+ vae = self .model .model_copy (update = {"submodel_type" : SubModelType .VAE })
165
+
166
+ return FluxModelLoaderOutput (
167
+ transformer = TransformerField (transformer = transformer , scheduler = scheduler ),
168
+ clip = CLIPField (tokenizer = tokenizer , text_encoder = text_encoder , loras = [], skipped_layers = 0 ),
169
+ clip2 = CLIPField (tokenizer = tokenizer2 , text_encoder = text_encoder2 , loras = [], skipped_layers = 0 ),
170
+ vae = VAEField (vae = vae ),
171
+ )
172
+
173
+
125
174
@invocation (
126
175
"main_model_loader" ,
127
176
title = "Main Model" ,
0 commit comments