mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Use fp16 for fp16 control nets.
This commit is contained in:
parent
71354c7c57
commit
220a72d36b
32
comfy/sd.py
32
comfy/sd.py
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import contextlib
|
||||
|
||||
import sd1_clip
|
||||
import sd2_clip
|
||||
@ -327,23 +328,36 @@ class VAE:
|
||||
return samples
|
||||
|
||||
class ControlNet:
|
||||
def __init__(self, control_model):
|
||||
def __init__(self, control_model, device="cuda"):
|
||||
self.control_model = control_model
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
self.device = device
|
||||
|
||||
def get_control(self, x_noisy, t, cond_txt):
|
||||
output_dtype = x_noisy.dtype
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
|
||||
print("set cond_hint", self.cond_hint.shape)
|
||||
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
|
||||
|
||||
if self.control_model.dtype == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = contextlib.nullcontext
|
||||
|
||||
with precision_scope(self.device):
|
||||
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
|
||||
out = []
|
||||
autocast_enabled = torch.is_autocast_enabled()
|
||||
for x in control:
|
||||
x *= self.strength
|
||||
return control
|
||||
if x.dtype != output_dtype and not autocast_enabled:
|
||||
x = x.to(output_dtype)
|
||||
out.append(x)
|
||||
return out
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0):
|
||||
self.cond_hint_original = cond_hint
|
||||
@ -377,6 +391,11 @@ def load_controlnet(ckpt_path):
|
||||
return None
|
||||
|
||||
context_dim = controlnet_data[key].shape[1]
|
||||
|
||||
use_fp16 = False
|
||||
if controlnet_data[key].dtype == torch.float16:
|
||||
use_fp16 = True
|
||||
|
||||
control_model = cldm.ControlNet(image_size=32,
|
||||
in_channels=4,
|
||||
hint_channels=3,
|
||||
@ -389,7 +408,8 @@ def load_controlnet(ckpt_path):
|
||||
transformer_depth=1,
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=True,
|
||||
legacy=False)
|
||||
legacy=False,
|
||||
use_fp16=use_fp16)
|
||||
|
||||
if pth:
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user