Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-01-11 02:15:17 +00:00
Implement support for t2i style model.
It needs the CLIPVision model, so I added CLIPVisionLoader and CLIPVisionEncode. Put the CLIP vision model in models/clip_vision and the t2i style model in models/style_models. StyleModelLoader loads the style model, StyleModelApply applies it, and ConditioningAppend appends the conditioning it outputs to a positive conditioning.
parent cc8baf1080
commit 47acb3d73e
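
For reference, a minimal sketch of how the new pieces chain together when driven directly from Python instead of through the node graph. Only the comfy.sd and comfy_extras.clip_vision calls come from this commit; the checkpoint file names and the stand-in image/conditioning tensors are assumptions for illustration.

import torch
import comfy.sd
import comfy_extras.clip_vision

# Hypothetical checkpoint names; put the real files in the new model folders.
clip_vision = comfy_extras.clip_vision.load("models/clip_vision/clip_vit_l.safetensors")
style_model = comfy.sd.load_style_model("models/style_models/t2iadapter_style_sd14v1.pth")

# Stand-ins for LoadImage and CLIPTextEncode outputs (SD1.x shapes assumed).
image = torch.rand(1, 512, 512, 3)           # IMAGE: [batch, H, W, C], floats in 0..1
positive = [[torch.zeros(1, 77, 768), {}]]   # CONDITIONING: list of [tensor, options]

embed = clip_vision.encode_image(image)      # what CLIPVisionEncode does
style_cond = style_model.get_cond(embed)     # what StyleModelApply does (it wraps this as [[c, {}]])
# what ConditioningAppend does: concatenate the style tokens onto each positive entry
positive = [[torch.cat((t[0], style_cond), dim=1), t[1].copy()] for t in positive]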
comfy/sd.py | 26
@@ -613,11 +613,7 @@ class T2IAdapter:
 def load_t2i_adapter(ckpt_path, model=None):
     t2i_data = load_torch_file(ckpt_path)
     keys = t2i_data.keys()
-    if "style_embedding" in keys:
-        pass
-        # TODO
-        # model_ad = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
-    elif "body.0.in_conv.weight" in keys:
+    if "body.0.in_conv.weight" in keys:
         cin = t2i_data['body.0.in_conv.weight'].shape[1]
         model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
     else:
@@ -626,6 +622,26 @@ def load_t2i_adapter(ckpt_path, model=None):
     model_ad.load_state_dict(t2i_data)
     return T2IAdapter(model_ad, cin // 64)
 
+
+class StyleModel:
+    def __init__(self, model, device="cpu"):
+        self.model = model
+
+    def get_cond(self, input):
+        return self.model(input.last_hidden_state)
+
+
+def load_style_model(ckpt_path):
+    model_data = load_torch_file(ckpt_path)
+    keys = model_data.keys()
+    if "style_embedding" in keys:
+        model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
+    else:
+        raise Exception("invalid style model {}".format(ckpt_path))
+    model.load_state_dict(model_data)
+    return StyleModel(model)
+
+
 def load_clip(ckpt_path, embedding_directory=None):
     clip_data = load_torch_file(ckpt_path)
     config = {}
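
A quick sketch of exercising the new comfy.sd helpers on their own; the checkpoint path is a placeholder, and the fake CLIP vision output (ViT-L/14 hidden size 1024 assumed) only exists to show what get_cond() expects.

import torch
import comfy.sd

# load_style_model() builds a StyleAdapter when the checkpoint contains "style_embedding".
style_model = comfy.sd.load_style_model("models/style_models/t2iadapter_style_sd14v1.pth")  # placeholder path

# get_cond() only reads .last_hidden_state from the CLIP vision output,
# so a small stand-in object is enough for a smoke test.
class FakeClipVisionOutput:
    last_hidden_state = torch.randn(1, 257, 1024)  # assumed ViT-L/14 vision shape

style_cond = style_model.get_cond(FakeClipVisionOutput())
print(style_cond.shape)  # expected [1, 8, 768]: num_token=8 tokens of context_dim=768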
comfy_extras/clip_vision.py | 32 (new file)
@@ -0,0 +1,32 @@
+from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor
+from comfy.sd import load_torch_file
+import os
+
+class ClipVisionModel():
+    def __init__(self):
+        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json")
+        config = CLIPVisionConfig.from_json_file(json_config)
+        self.model = CLIPVisionModel(config)
+        self.processor = CLIPImageProcessor(crop_size=224,
+                                            do_center_crop=True,
+                                            do_convert_rgb=True,
+                                            do_normalize=True,
+                                            do_resize=True,
+                                            image_mean=[ 0.48145466,0.4578275,0.40821073],
+                                            image_std=[0.26862954,0.26130258,0.27577711],
+                                            resample=3, #bicubic
+                                            size=224)
+
+    def load_sd(self, sd):
+        self.model.load_state_dict(sd, strict=False)
+
+    def encode_image(self, image):
+        inputs = self.processor(images=[image[0]], return_tensors="pt")
+        outputs = self.model(**inputs)
+        return outputs
+
+def load(ckpt_path):
+    clip_data = load_torch_file(ckpt_path)
+    clip = ClipVisionModel()
+    clip.load_sd(clip_data)
+    return clip
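
The new module can also be used standalone; a short sketch follows, where the checkpoint name is a placeholder and the random tensor stands in for the IMAGE type produced by LoadImage.

import torch
import comfy_extras.clip_vision

clip = comfy_extras.clip_vision.load("models/clip_vision/clip_vit_large_patch14.safetensors")  # placeholder name

# encode_image() takes a batched image tensor shaped like ComfyUI's IMAGE type:
# [batch, height, width, channels] with float values in 0..1; it processes image[0].
image = torch.rand(1, 512, 512, 3)
outputs = clip.encode_image(image)
print(outputs.last_hidden_state.shape)  # e.g. [1, 257, 1024] for a ViT-L/14 vision tower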
models/clip_vision/put_clip_vision_models_here | 0 (new file)
models/style_models/put_t2i_style_model_here | 0 (new file)
nodes.py | 90
@@ -18,6 +18,8 @@ import comfy.samplers
 import comfy.sd
 import comfy.utils
 
+import comfy_extras.clip_vision
+
 import model_management
 import importlib
 
@@ -370,6 +372,89 @@ class CLIPLoader:
         clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
         return (clip,)
 
+class CLIPVisionLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    clip_dir = os.path.join(models_dir, "clip_vision")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
+                             }}
+    RETURN_TYPES = ("CLIP_VISION",)
+    FUNCTION = "load_clip"
+
+    CATEGORY = "loaders"
+
+    def load_clip(self, clip_name):
+        clip_path = os.path.join(self.clip_dir, clip_name)
+        clip_vision = comfy_extras.clip_vision.load(clip_path)
+        return (clip_vision,)
+
+class CLIPVisionEncode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_vision": ("CLIP_VISION",),
+                              "image": ("IMAGE",)
+                             }}
+    RETURN_TYPES = ("CLIP_VISION_EMBED",)
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning"
+
+    def encode(self, clip_vision, image):
+        output = clip_vision.encode_image(image)
+        return (output,)
+
+class StyleModelLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    style_model_dir = os.path.join(models_dir, "style_models")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}}
+
+    RETURN_TYPES = ("STYLE_MODEL",)
+    FUNCTION = "load_style_model"
+
+    CATEGORY = "loaders"
+
+    def load_style_model(self, style_model_name):
+        style_model_path = os.path.join(self.style_model_dir, style_model_name)
+        style_model = comfy.sd.load_style_model(style_model_path)
+        return (style_model,)
+
+
+class StyleModelApply:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"clip_vision_embed": ("CLIP_VISION_EMBED", ),
+                             "style_model": ("STYLE_MODEL", )
+                             }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "apply_stylemodel"
+
+    CATEGORY = "conditioning"
+
+    def apply_stylemodel(self, clip_vision_embed, style_model):
+        c = style_model.get_cond(clip_vision_embed)
+        return ([[c, {}]], )
+
+
+class ConditioningAppend:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", )}}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "append"
+
+    CATEGORY = "conditioning"
+
+    def append(self, conditioning_to, conditioning_from):
+        c = []
+        to_append = conditioning_from[0][0]
+        for t in conditioning_to:
+            n = [torch.cat((t[0],to_append), dim=1), t[1].copy()]
+            c.append(n)
+        return (c, )
+
 class EmptyLatentImage:
     def __init__(self, device="cpu"):
         self.device = device
@@ -866,6 +951,11 @@ NODE_CLASS_MAPPINGS = {
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "StyleModelLoader": StyleModelLoader,
+    "CLIPVisionLoader": CLIPVisionLoader,
+    "CLIPVisionEncode": CLIPVisionEncode,
+    "StyleModelApply":StyleModelApply,
+    "ConditioningAppend":ConditioningAppend,
     "ControlNetApply": ControlNetApply,
     "ControlNetLoader": ControlNetLoader,
     "DiffControlNetLoader": DiffControlNetLoader,
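
To make the tensor bookkeeping in ConditioningAppend concrete, here is a shape-level illustration with assumed SD1.x dimensions (77 text tokens, 8 style tokens, 768-dim embeddings); it mirrors the loop in append() above.

import torch

conditioning_to = [[torch.zeros(1, 77, 768), {}]]   # e.g. a CLIPTextEncode output
conditioning_from = [[torch.zeros(1, 8, 768), {}]]  # e.g. a StyleModelApply output

to_append = conditioning_from[0][0]
c = []
for t in conditioning_to:
    # token-wise concatenation: each entry grows from 77 to 77 + 8 = 85 tokens
    c.append([torch.cat((t[0], to_append), dim=1), t[1].copy()])

print(c[0][0].shape)  # torch.Size([1, 85, 768])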