Use fp16 for fp16 control nets.

comfyanonymous 2023-02-17 15:31:38 -05:00
parent 71354c7c57
commit 220a72d36b


@@ -1,4 +1,5 @@
 import torch
+import contextlib
 import sd1_clip
 import sd2_clip
@@ -327,23 +328,36 @@ class VAE:
         return samples

 class ControlNet:
-    def __init__(self, control_model):
+    def __init__(self, control_model, device="cuda"):
         self.control_model = control_model
         self.cond_hint_original = None
         self.cond_hint = None
         self.strength = 1.0
+        self.device = device

     def get_control(self, x_noisy, t, cond_txt):
+        output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
                 self.cond_hint = None
-            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
+            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
             print("set cond_hint", self.cond_hint.shape)

-        control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
+        if self.control_model.dtype == torch.float16:
+            precision_scope = torch.autocast
+        else:
+            precision_scope = contextlib.nullcontext
+
+        with precision_scope(self.device):
+            control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
+
+        out = []
+        autocast_enabled = torch.is_autocast_enabled()
         for x in control:
             x *= self.strength
-        return control
+            if x.dtype != output_dtype and not autocast_enabled:
+                x = x.to(output_dtype)
+            out.append(x)
+        return out

     def set_cond_hint(self, cond_hint, strength=1.0):
         self.cond_hint_original = cond_hint
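
Note (illustration, not part of the commit): the hunk above picks a precision scope from the control model's dtype. Below is a minimal, self-contained sketch of that pattern, assuming a hypothetical TinyControl module and run() helper that stand in for the real classes.

import contextlib
import torch

class TinyControl(torch.nn.Module):
    # Hypothetical stand-in for the real control model; it exposes a .dtype
    # attribute the way the cldm ControlNet does.
    def __init__(self, dtype=torch.float16):
        super().__init__()
        self.dtype = dtype
        self.proj = torch.nn.Linear(8, 8).to(dtype)

    def forward(self, x):
        return self.proj(x.to(self.dtype))

def run(model, x, device="cuda"):
    # fp16 models run under torch.autocast; fp32 models get a no-op context.
    if model.dtype == torch.float16:
        precision_scope = torch.autocast
    else:
        precision_scope = contextlib.nullcontext

    with precision_scope(device):
        y = model(x)

    # Cast the result back to the caller's dtype unless an outer autocast
    # region is still active and will handle dtypes itself.
    if y.dtype != x.dtype and not torch.is_autocast_enabled():
        y = y.to(x.dtype)
    return y

As in get_control(), outputs are converted back to the caller's dtype only when no enclosing autocast region will take care of the conversion.
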
@@ -377,6 +391,11 @@ def load_controlnet(ckpt_path):
         return None

     context_dim = controlnet_data[key].shape[1]
+
+    use_fp16 = False
+    if controlnet_data[key].dtype == torch.float16:
+        use_fp16 = True
+
     control_model = cldm.ControlNet(image_size=32,
                                     in_channels=4,
                                     hint_channels=3,
@@ -389,7 +408,8 @@ def load_controlnet(ckpt_path):
                                     transformer_depth=1,
                                     context_dim=context_dim,
                                     use_checkpoint=True,
-                                    legacy=False)
+                                    legacy=False,
+                                    use_fp16=use_fp16)

     if pth:
         class WeightsLoader(torch.nn.Module):
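
Note (illustration, not part of the commit): load_controlnet() above enables fp16 when the checkpoint weights are already stored in half precision. A rough sketch of that detection follows; the checkpoint path is a placeholder and the probe is simplified, whereas the real loader checks a specific key before passing use_fp16 to cldm.ControlNet.

import torch

state_dict = torch.load("controlnet.pth", map_location="cpu")  # placeholder path

# If the stored weights are half precision, build the model in fp16 as well.
probe = next(iter(state_dict.values()))  # simplified: any weight tensor serves as a probe here
use_fp16 = probe.dtype == torch.float16
print("building ControlNet with use_fp16 =", use_fp16)
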