mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Merge T2IAdapterLoader and ControlNetLoader.
Workflows will be auto updated.
This commit is contained in:
parent
e1a9e26968
commit
2e73367f45
13
comfy/sd.py
13
comfy/sd.py
@ -527,8 +527,10 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
elif key in controlnet_data:
|
elif key in controlnet_data:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
print("error checkpoint does not contain controlnet data", ckpt_path)
|
net = load_t2i_adapter(controlnet_data)
|
||||||
return None
|
if net is None:
|
||||||
|
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
|
||||||
|
return net
|
||||||
|
|
||||||
context_dim = controlnet_data[key].shape[1]
|
context_dim = controlnet_data[key].shape[1]
|
||||||
|
|
||||||
@ -682,15 +684,16 @@ class T2IAdapter:
|
|||||||
out += self.previous_controlnet.get_control_models()
|
out += self.previous_controlnet.get_control_models()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_t2i_adapter(ckpt_path, model=None):
|
def load_t2i_adapter(t2i_data):
|
||||||
t2i_data = load_torch_file(ckpt_path)
|
|
||||||
keys = t2i_data.keys()
|
keys = t2i_data.keys()
|
||||||
if "body.0.in_conv.weight" in keys:
|
if "body.0.in_conv.weight" in keys:
|
||||||
cin = t2i_data['body.0.in_conv.weight'].shape[1]
|
cin = t2i_data['body.0.in_conv.weight'].shape[1]
|
||||||
model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
|
model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
|
||||||
else:
|
elif 'conv_in.weight' in keys:
|
||||||
cin = t2i_data['conv_in.weight'].shape[1]
|
cin = t2i_data['conv_in.weight'].shape[1]
|
||||||
model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
|
model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
model_ad.load_state_dict(t2i_data)
|
model_ad.load_state_dict(t2i_data)
|
||||||
return T2IAdapter(model_ad, cin // 64)
|
return T2IAdapter(model_ad, cin // 64)
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
from comfy_extras.chainner_models import model_loading
|
from comfy_extras.chainner_models import model_loading
|
||||||
from comfy.sd import load_torch_file
|
from comfy.sd import load_torch_file
|
||||||
import model_management
|
import model_management
|
||||||
from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions
|
|
||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
38
nodes.py
38
nodes.py
@ -24,26 +24,6 @@ import model_management
|
|||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
supported_ckpt_extensions = ['.ckpt', '.pth']
|
|
||||||
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
|
|
||||||
try:
|
|
||||||
import safetensors.torch
|
|
||||||
supported_ckpt_extensions += ['.safetensors']
|
|
||||||
supported_pt_extensions += ['.safetensors']
|
|
||||||
except:
|
|
||||||
print("Could not import safetensors, safetensors support disabled.")
|
|
||||||
|
|
||||||
def recursive_search(directory):
|
|
||||||
result = []
|
|
||||||
for root, subdir, file in os.walk(directory, followlinks=True):
|
|
||||||
for filepath in file:
|
|
||||||
#we os.path,join directory with a blank string to generate a path separator at the end.
|
|
||||||
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def filter_files_extensions(files, extensions):
|
|
||||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
|
||||||
|
|
||||||
|
|
||||||
def before_node_execution():
|
def before_node_execution():
|
||||||
model_management.throw_exception_if_processing_interrupted()
|
model_management.throw_exception_if_processing_interrupted()
|
||||||
@ -348,23 +328,6 @@ class ControlNetApply:
|
|||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
class T2IAdapterLoader:
|
|
||||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
|
||||||
t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter")
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("CONTROL_NET",)
|
|
||||||
FUNCTION = "load_t2i_adapter"
|
|
||||||
|
|
||||||
CATEGORY = "loaders"
|
|
||||||
|
|
||||||
def load_t2i_adapter(self, t2i_adapter_name):
|
|
||||||
t2i_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name)
|
|
||||||
t2i_adapter = comfy.sd.load_t2i_adapter(t2i_path)
|
|
||||||
return (t2i_adapter,)
|
|
||||||
|
|
||||||
class CLIPLoader:
|
class CLIPLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -963,7 +926,6 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"ControlNetApply": ControlNetApply,
|
"ControlNetApply": ControlNetApply,
|
||||||
"ControlNetLoader": ControlNetLoader,
|
"ControlNetLoader": ControlNetLoader,
|
||||||
"DiffControlNetLoader": DiffControlNetLoader,
|
"DiffControlNetLoader": DiffControlNetLoader,
|
||||||
"T2IAdapterLoader": T2IAdapterLoader,
|
|
||||||
"StyleModelLoader": StyleModelLoader,
|
"StyleModelLoader": StyleModelLoader,
|
||||||
"CLIPVisionLoader": CLIPVisionLoader,
|
"CLIPVisionLoader": CLIPVisionLoader,
|
||||||
"VAEDecodeTiled": VAEDecodeTiled,
|
"VAEDecodeTiled": VAEDecodeTiled,
|
||||||
|
@ -614,6 +614,12 @@ class ComfyApp {
|
|||||||
if (!graphData) {
|
if (!graphData) {
|
||||||
graphData = defaultGraph;
|
graphData = defaultGraph;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
|
||||||
|
for (let n of graphData.nodes) {
|
||||||
|
if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader";
|
||||||
|
}
|
||||||
|
|
||||||
this.graph.configure(graphData);
|
this.graph.configure(graphData);
|
||||||
|
|
||||||
for (const node of this.graph._nodes) {
|
for (const node of this.graph._nodes) {
|
||||||
|
Loading…
Reference in New Issue
Block a user