44import torch
55import comfy .utils
66import folder_paths
7+ from typing_extensions import override
8+ from comfy_api .latest import ComfyExtension , io
79
810try :
911 from spandrel_extra_arches import EXTRA_REGISTRY
1315except :
1416 pass
1517
16- class UpscaleModelLoader :
18+ class UpscaleModelLoader ( io . ComfyNode ) :
1719 @classmethod
18- def INPUT_TYPES (s ):
19- return {"required" : { "model_name" : (folder_paths .get_filename_list ("upscale_models" ), ),
20- }}
21- RETURN_TYPES = ("UPSCALE_MODEL" ,)
22- FUNCTION = "load_model"
20+ def define_schema (cls ):
21+ return io .Schema (
22+ node_id = "UpscaleModelLoader" ,
23+ display_name = "Load Upscale Model" ,
24+ category = "loaders" ,
25+ inputs = [
26+ io .Combo .Input ("model_name" , options = folder_paths .get_filename_list ("upscale_models" )),
27+ ],
28+ outputs = [
29+ io .UpscaleModel .Output (),
30+ ],
31+ )
2332
24- CATEGORY = "loaders"
25-
26- def load_model (self , model_name ):
33+ @classmethod
34+ def execute (cls , model_name ) -> io .NodeOutput :
2735 model_path = folder_paths .get_full_path_or_raise ("upscale_models" , model_name )
2836 sd = comfy .utils .load_torch_file (model_path , safe_load = True )
2937 if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd :
@@ -33,21 +41,29 @@ def load_model(self, model_name):
3341 if not isinstance (out , ImageModelDescriptor ):
3442 raise Exception ("Upscale model must be a single-image model." )
3543
36- return (out , )
44+ return io . NodeOutput (out )
3745
46+ load_model = execute # TODO: remove
3847
39- class ImageUpscaleWithModel :
40- @classmethod
41- def INPUT_TYPES (s ):
42- return {"required" : { "upscale_model" : ("UPSCALE_MODEL" ,),
43- "image" : ("IMAGE" ,),
44- }}
45- RETURN_TYPES = ("IMAGE" ,)
46- FUNCTION = "upscale"
4748
48- CATEGORY = "image/upscaling"
49+ class ImageUpscaleWithModel (io .ComfyNode ):
50+ @classmethod
51+ def define_schema (cls ):
52+ return io .Schema (
53+ node_id = "ImageUpscaleWithModel" ,
54+ display_name = "Upscale Image (using Model)" ,
55+ category = "image/upscaling" ,
56+ inputs = [
57+ io .UpscaleModel .Input ("upscale_model" ),
58+ io .Image .Input ("image" ),
59+ ],
60+ outputs = [
61+ io .Image .Output (),
62+ ],
63+ )
4964
50- def upscale (self , upscale_model , image ):
65+ @classmethod
66+ def execute (cls , upscale_model , image ) -> io .NodeOutput :
5167 device = model_management .get_torch_device ()
5268
5369 memory_required = model_management .module_size (upscale_model .model )
@@ -75,9 +91,19 @@ def upscale(self, upscale_model, image):
7591
7692 upscale_model .to ("cpu" )
7793 s = torch .clamp (s .movedim (- 3 ,- 1 ), min = 0 , max = 1.0 )
78- return (s ,)
94+ return io .NodeOutput (s )
95+
96+ upscale = execute # TODO: remove
97+
98+
99+ class UpscaleModelExtension (ComfyExtension ):
100+ @override
101+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
102+ return [
103+ UpscaleModelLoader ,
104+ ImageUpscaleWithModel ,
105+ ]
106+
79107
80- NODE_CLASS_MAPPINGS = {
81- "UpscaleModelLoader" : UpscaleModelLoader ,
82- "ImageUpscaleWithModel" : ImageUpscaleWithModel
83- }
108+ async def comfy_entrypoint () -> UpscaleModelExtension :
109+ return UpscaleModelExtension ()
0 commit comments