mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
All resolutions now work with t2i adapter for SDXL.
This commit is contained in:
parent
85fde89d7f
commit
afcb9cb1df
10
comfy/sd.py
10
comfy/sd.py
@ -2,6 +2,7 @@ import torch
|
|||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
|
import math
|
||||||
|
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
from .ldm.util import instantiate_from_config
|
from .ldm.util import instantiate_from_config
|
||||||
@ -1099,6 +1100,12 @@ class T2IAdapter(ControlBase):
|
|||||||
self.channels_in = channels_in
|
self.channels_in = channels_in
|
||||||
self.control_input = None
|
self.control_input = None
|
||||||
|
|
||||||
|
def scale_image_to(self, width, height):
|
||||||
|
unshuffle_amount = self.t2i_model.unshuffle_amount
|
||||||
|
width = math.ceil(width / unshuffle_amount) * unshuffle_amount
|
||||||
|
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
|
||||||
|
return width, height
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
@ -1116,7 +1123,8 @@ class T2IAdapter(ControlBase):
|
|||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
self.control_input = None
|
self.control_input = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
|
width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
|
||||||
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
|
||||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
||||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
|
@ -103,17 +103,17 @@ class ResnetBlock(nn.Module):
|
|||||||
class Adapter(nn.Module):
|
class Adapter(nn.Module):
|
||||||
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
|
||||||
super(Adapter, self).__init__()
|
super(Adapter, self).__init__()
|
||||||
unshuffle = 8
|
self.unshuffle_amount = 8
|
||||||
resblock_no_downsample = []
|
resblock_no_downsample = []
|
||||||
resblock_downsample = [3, 2, 1]
|
resblock_downsample = [3, 2, 1]
|
||||||
self.xl = xl
|
self.xl = xl
|
||||||
if self.xl:
|
if self.xl:
|
||||||
unshuffle = 16
|
self.unshuffle_amount = 16
|
||||||
resblock_no_downsample = [1]
|
resblock_no_downsample = [1]
|
||||||
resblock_downsample = [2]
|
resblock_downsample = [2]
|
||||||
|
|
||||||
self.input_channels = cin // (unshuffle * unshuffle)
|
self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
|
||||||
self.unshuffle = nn.PixelUnshuffle(unshuffle)
|
self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.nums_rb = nums_rb
|
self.nums_rb = nums_rb
|
||||||
self.body = []
|
self.body = []
|
||||||
@ -264,9 +264,9 @@ class extractor(nn.Module):
|
|||||||
class Adapter_light(nn.Module):
|
class Adapter_light(nn.Module):
|
||||||
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
|
||||||
super(Adapter_light, self).__init__()
|
super(Adapter_light, self).__init__()
|
||||||
unshuffle = 8
|
self.unshuffle_amount = 8
|
||||||
self.unshuffle = nn.PixelUnshuffle(unshuffle)
|
self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
|
||||||
self.input_channels = cin // (unshuffle * unshuffle)
|
self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.nums_rb = nums_rb
|
self.nums_rb = nums_rb
|
||||||
self.body = []
|
self.body = []
|
||||||
|
Loading…
Reference in New Issue
Block a user