5 changes: 4 additions & 1 deletion comfy/sd.py
@@ -836,7 +836,7 @@ class CLIPType(Enum):
     OMNIGEN2 = 17
     QWEN_IMAGE = 18
     HUNYUAN_IMAGE = 19
-
+    HUNYUAN_IMAGE_REFINER = 20

 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
@@ -995,6 +995,9 @@ class EmptyClass:
             if clip_type == CLIPType.HUNYUAN_IMAGE:
                 clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
+            elif clip_type == CLIPType.HUNYUAN_IMAGE_REFINER:
+                clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, refiner=True, **llama_detect(clip_data))
+                clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageRefinerTokenizer
             else:
                 clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
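Usage sketch (reviewer aside, not part of the diff): with the new enum value, the refiner text encoder loads through the regular load_clip path and is routed to the refiner branch added above. The checkpoint filename below is a hypothetical placeholder; any Qwen2.5-VL text-encoder checkpoint that llama_detect() recognizes would do.

import comfy.sd
from comfy.sd import CLIPType

# Hypothetical checkpoint path, for illustration only.
clip = comfy.sd.load_clip(
    ckpt_paths=["models/text_encoders/qwen_2.5_vl_7b.safetensors"],
    clip_type=CLIPType.HUNYUAN_IMAGE_REFINER,
)
tokens = clip.tokenize("a red sports car on a coastal road")
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)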
85 changes: 76 additions & 9 deletions comfy/text_encoders/hunyuan_image.py
@@ -4,6 +4,8 @@
 from transformers import ByT5Tokenizer
 import os
 import re
+import torch
+import numbers

 class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -38,6 +40,13 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
         return out

+class HunyuanImageRefinerTokenizer(HunyuanImageTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.llama_template = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+
+
 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
         llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
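For illustration (not part of the diff): the refiner tokenizer only swaps in a different llama_template. Assuming the base tokenizer substitutes the user prompt via str.format, as the {} placeholder suggests, the text the model actually encodes looks like this:

refiner_template = (
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "Describe the image by detailing the color, shape, size, texture, quantity, "
    "text, spatial relationships of the objects and background:<|eot_id|>\n"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
print(refiner_template.format("a red sports car on a coastal road"))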
@@ -53,21 +62,45 @@ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)


-class HunyuanImageTEModel(QwenImageTEModel):
+class HunyuanImageTEModel(sd1_clip.SD1ClipModel):
     def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
-        super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+        super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)

         if byt5:
             self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
         else:
             self.byt5_small = None

     def encode_token_weights(self, token_weight_pairs):
-        cond, p, extra = super().encode_token_weights(token_weight_pairs)
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
+        count_im_start = 0
+        for i, v in enumerate(tok_pairs):
+            elem = v[0]
+            if not torch.is_tensor(elem):
+                if isinstance(elem, numbers.Integral):
+                    if elem == 151644 and count_im_start < 2:
+                        template_end = i
+                        count_im_start += 1
+
+        if out.shape[1] > (template_end + 3):
+            if tok_pairs[template_end + 1][0] == 872:
+                if tok_pairs[template_end + 2][0] == 198:
+                    template_end += 3
+
+        out = out[:, template_end:]
+
+        extra["attention_mask"] = extra["attention_mask"][:, template_end:]
+        if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
+            extra.pop("attention_mask")  # attention mask is useless if no masked elements
+
         if self.byt5_small is not None and "byt5" in token_weight_pairs:
-            out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
-            extra["conditioning_byt5small"] = out[0]
-        return cond, p, extra
+            byt5_out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
+            extra["conditioning_byt5small"] = byt5_out[0]
+        return out, pooled, extra

     def set_clip_options(self, options):
         super().set_clip_options(options)
Expand All @@ -84,14 +117,48 @@ def load_sd(self, sd):
return self.byt5_small.load_sd(sd)
else:
return super().load_sd(sd)
class HunyuanImageRefinerTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)

def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen25_7b"][0]
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 6171:
template_end = i
break

out = out[:, template_end-1:]

extra["attention_mask"] = extra["attention_mask"][:, template_end-1:]
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
extra.pop("attention_mask") # attention mask is useless if no masked elements

return out, pooled, extra


def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None, refiner=False):
class HunyuanImageTEModel_(HunyuanImageTEModel):

def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
class QwenImageTEModel_(HunyuanImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
return QwenImageTEModel_
class HunyuanImageTEModel_refiner(HunyuanImageRefinerTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
assert refiner, "refiner must be True"
assert not byt5, "byt5 must be False"
super().__init__(device=device, dtype=dtype, model_options=model_options)
return HunyuanImageTEModel_refiner if refiner else HunyuanImageTEModel_
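Reviewer aside (not part of the diff): to make the template trimming in encode_token_weights concrete, here is a standalone toy sketch. Token 151644 is Qwen's <|im_start|>; after the second occurrence the code checks for 872 ("user") and 198 ("\n"), then slices everything before the prompt off the hidden states and the attention mask. The system-turn and prompt token ids below are hypothetical stand-ins, and 3584 is Qwen2.5-7B's hidden size.

import torch

# Toy token/weight pairs: the second <|im_start|> (151644) is followed by
# "user" (872) and "\n" (198); ids 9999/1234/5678 are hypothetical stand-ins.
tok_pairs = [
    (151644, 1.0), (9999, 1.0), (198, 1.0),   # <|im_start|> system-turn "\n"
    (151644, 1.0), (872, 1.0), (198, 1.0),    # <|im_start|> user "\n"
    (1234, 1.0), (5678, 1.0),                 # the actual prompt tokens
]
out = torch.randn(1, len(tok_pairs), 3584)    # fake Qwen2.5-7B hidden states

# Same scan as in HunyuanImageTEModel.encode_token_weights above.
template_end = 0
count_im_start = 0
for i, (elem, _) in enumerate(tok_pairs):
    if elem == 151644 and count_im_start < 2:
        template_end = i
        count_im_start += 1

if out.shape[1] > (template_end + 3):
    if tok_pairs[template_end + 1][0] == 872 and tok_pairs[template_end + 2][0] == 198:
        template_end += 3

out = out[:, template_end:]
print(out.shape)  # torch.Size([1, 2, 3584]) -- only the prompt tokens remain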