Add a node to load diff controlnets.

This commit is contained in:
comfyanonymous 2023-02-22 23:22:03 -05:00
parent 3ae61a2bca
commit 62df8dd62a
2 changed files with 35 additions and 1 deletions

View File

@ -400,7 +400,7 @@ class ControlNet:
out.append(self.control_model) out.append(self.control_model)
return out return out
def load_controlnet(ckpt_path): def load_controlnet(ckpt_path, model=None):
controlnet_data = load_torch_file(ckpt_path) controlnet_data = load_torch_file(ckpt_path)
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
pth = False pth = False
@ -437,6 +437,21 @@ def load_controlnet(ckpt_path):
use_fp16=use_fp16) use_fp16=use_fp16)
if pth: if pth:
if 'difference' in controlnet_data:
if model is not None:
m = model.patch_model()
model_sd = m.state_dict()
for x in controlnet_data:
c_m = "control_model."
if x.startswith(c_m):
sd_key = "model.diffusion_model.{}".format(x[len(c_m):])
if sd_key in model_sd:
cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
model.unpatch_model()
else:
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
class WeightsLoader(torch.nn.Module): class WeightsLoader(torch.nn.Module):
pass pass
w = WeightsLoader() w = WeightsLoader()

View File

@ -232,6 +232,24 @@ class ControlNetLoader:
controlnet = comfy.sd.load_controlnet(controlnet_path) controlnet = comfy.sd.load_controlnet(controlnet_path)
return (controlnet,) return (controlnet,)
class DiffControlNetLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
controlnet_dir = os.path.join(models_dir, "controlnet")
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
CATEGORY = "loaders"
def load_controlnet(self, model, control_net_name):
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
controlnet = comfy.sd.load_controlnet(controlnet_path, model)
return (controlnet,)
class ControlNetApply: class ControlNetApply:
@classmethod @classmethod
@ -770,6 +788,7 @@ NODE_CLASS_MAPPINGS = {
"CLIPLoader": CLIPLoader, "CLIPLoader": CLIPLoader,
"ControlNetApply": ControlNetApply, "ControlNetApply": ControlNetApply,
"ControlNetLoader": ControlNetLoader, "ControlNetLoader": ControlNetLoader,
"DiffControlNetLoader": DiffControlNetLoader,
} }
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes") CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")