All resolutions now work with t2i adapter for SDXL.

This commit is contained in:
comfyanonymous 2023-08-22 16:23:54 -04:00
parent 85fde89d7f
commit afcb9cb1df
2 changed files with 16 additions and 8 deletions

View File

@ -2,6 +2,7 @@ import torch
import contextlib import contextlib
import copy import copy
import inspect import inspect
import math
from comfy import model_management from comfy import model_management
from .ldm.util import instantiate_from_config from .ldm.util import instantiate_from_config
@ -1099,6 +1100,12 @@ class T2IAdapter(ControlBase):
self.channels_in = channels_in self.channels_in = channels_in
self.control_input = None self.control_input = None
def scale_image_to(self, width, height):
    """Round a target size up to a multiple of the adapter's unshuffle amount.

    Reads ``self.t2i_model.unshuffle_amount`` and returns ``(width, height)``
    where each value is the smallest multiple of that amount that is greater
    than or equal to the corresponding input.
    """
    step = self.t2i_model.unshuffle_amount
    # Ceil-divide each dimension by the step, then scale back up so the
    # result lands exactly on a multiple of the step.
    rounded_width, rounded_height = (
        math.ceil(dim / step) * step for dim in (width, height)
    )
    return rounded_width, rounded_height
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None control_prev = None
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
@ -1116,7 +1123,8 @@ class T2IAdapter(ControlBase):
del self.cond_hint del self.cond_hint
self.control_input = None self.control_input = None
self.cond_hint = None self.cond_hint = None
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device) width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
self.cond_hint = utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
if self.channels_in == 1 and self.cond_hint.shape[1] > 1: if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
if x_noisy.shape[0] != self.cond_hint.shape[0]: if x_noisy.shape[0] != self.cond_hint.shape[0]:

View File

@ -103,17 +103,17 @@ class ResnetBlock(nn.Module):
class Adapter(nn.Module): class Adapter(nn.Module):
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True): def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
super(Adapter, self).__init__() super(Adapter, self).__init__()
unshuffle = 8 self.unshuffle_amount = 8
resblock_no_downsample = [] resblock_no_downsample = []
resblock_downsample = [3, 2, 1] resblock_downsample = [3, 2, 1]
self.xl = xl self.xl = xl
if self.xl: if self.xl:
unshuffle = 16 self.unshuffle_amount = 16
resblock_no_downsample = [1] resblock_no_downsample = [1]
resblock_downsample = [2] resblock_downsample = [2]
self.input_channels = cin // (unshuffle * unshuffle) self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
self.unshuffle = nn.PixelUnshuffle(unshuffle) self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
self.channels = channels self.channels = channels
self.nums_rb = nums_rb self.nums_rb = nums_rb
self.body = [] self.body = []
@ -264,9 +264,9 @@ class extractor(nn.Module):
class Adapter_light(nn.Module): class Adapter_light(nn.Module):
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64): def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
super(Adapter_light, self).__init__() super(Adapter_light, self).__init__()
unshuffle = 8 self.unshuffle_amount = 8
self.unshuffle = nn.PixelUnshuffle(unshuffle) self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
self.input_channels = cin // (unshuffle * unshuffle) self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
self.channels = channels self.channels = channels
self.nums_rb = nums_rb self.nums_rb = nums_rb
self.body = [] self.body = []