mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Add a node to load diff controlnets.
This commit is contained in:
parent
3ae61a2bca
commit
62df8dd62a
17
comfy/sd.py
17
comfy/sd.py
@ -400,7 +400,7 @@ class ControlNet:
|
|||||||
out.append(self.control_model)
|
out.append(self.control_model)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_controlnet(ckpt_path):
|
def load_controlnet(ckpt_path, model=None):
|
||||||
controlnet_data = load_torch_file(ckpt_path)
|
controlnet_data = load_torch_file(ckpt_path)
|
||||||
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||||
pth = False
|
pth = False
|
||||||
@ -437,6 +437,21 @@ def load_controlnet(ckpt_path):
|
|||||||
use_fp16=use_fp16)
|
use_fp16=use_fp16)
|
||||||
|
|
||||||
if pth:
|
if pth:
|
||||||
|
if 'difference' in controlnet_data:
|
||||||
|
if model is not None:
|
||||||
|
m = model.patch_model()
|
||||||
|
model_sd = m.state_dict()
|
||||||
|
for x in controlnet_data:
|
||||||
|
c_m = "control_model."
|
||||||
|
if x.startswith(c_m):
|
||||||
|
sd_key = "model.diffusion_model.{}".format(x[len(c_m):])
|
||||||
|
if sd_key in model_sd:
|
||||||
|
cd = controlnet_data[x]
|
||||||
|
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
||||||
|
model.unpatch_model()
|
||||||
|
else:
|
||||||
|
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
||||||
|
|
||||||
class WeightsLoader(torch.nn.Module):
|
class WeightsLoader(torch.nn.Module):
|
||||||
pass
|
pass
|
||||||
w = WeightsLoader()
|
w = WeightsLoader()
|
||||||
|
19
nodes.py
19
nodes.py
@ -232,6 +232,24 @@ class ControlNetLoader:
|
|||||||
controlnet = comfy.sd.load_controlnet(controlnet_path)
|
controlnet = comfy.sd.load_controlnet(controlnet_path)
|
||||||
return (controlnet,)
|
return (controlnet,)
|
||||||
|
|
||||||
|
class DiffControlNetLoader:
|
||||||
|
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||||
|
controlnet_dir = os.path.join(models_dir, "controlnet")
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONTROL_NET",)
|
||||||
|
FUNCTION = "load_controlnet"
|
||||||
|
|
||||||
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
|
def load_controlnet(self, model, control_net_name):
|
||||||
|
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
|
||||||
|
controlnet = comfy.sd.load_controlnet(controlnet_path, model)
|
||||||
|
return (controlnet,)
|
||||||
|
|
||||||
|
|
||||||
class ControlNetApply:
|
class ControlNetApply:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -770,6 +788,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"CLIPLoader": CLIPLoader,
|
"CLIPLoader": CLIPLoader,
|
||||||
"ControlNetApply": ControlNetApply,
|
"ControlNetApply": ControlNetApply,
|
||||||
"ControlNetLoader": ControlNetLoader,
|
"ControlNetLoader": ControlNetLoader,
|
||||||
|
"DiffControlNetLoader": DiffControlNetLoader,
|
||||||
}
|
}
|
||||||
|
|
||||||
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
|
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
|
||||||
|
Loading…
Reference in New Issue
Block a user