mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Lowvram mode for gligen and fix some lowvram issues.
This commit is contained in:
parent
9bd33b6bd4
commit
cb1551b819
@ -242,14 +242,28 @@ class Gligen(nn.Module):
|
||||
self.position_net = position_net
|
||||
self.key_dim = key_dim
|
||||
self.max_objs = 30
|
||||
self.lowvram = False
|
||||
|
||||
def _set_position(self, boxes, masks, positive_embeddings):
|
||||
if self.lowvram == True:
|
||||
self.position_net.to(boxes.device)
|
||||
|
||||
objs = self.position_net(boxes, masks, positive_embeddings)
|
||||
|
||||
def func(key, x):
|
||||
module = self.module_list[key]
|
||||
return module(x, objs)
|
||||
return func
|
||||
if self.lowvram == True:
|
||||
self.position_net.cpu()
|
||||
def func_lowvram(key, x):
|
||||
module = self.module_list[key]
|
||||
module.to(x.device)
|
||||
r = module(x, objs)
|
||||
module.cpu()
|
||||
return r
|
||||
return func_lowvram
|
||||
else:
|
||||
def func(key, x):
|
||||
module = self.module_list[key]
|
||||
return module(x, objs)
|
||||
return func
|
||||
|
||||
def set_position(self, latent_image_shape, position_params, device):
|
||||
batch, c, h, w = latent_image_shape
|
||||
@ -294,8 +308,11 @@ class Gligen(nn.Module):
|
||||
masks.to(device),
|
||||
conds.to(device))
|
||||
|
||||
def set_lowvram(self, value=True):
|
||||
self.lowvram = value
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
self.lowvram = False
|
||||
|
||||
def get_models(self):
|
||||
return [self]
|
||||
|
@ -572,9 +572,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
x += n
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
|
||||
if current_index is not None:
|
||||
transformer_options["current_index"] += 1
|
||||
return x
|
||||
|
||||
|
||||
|
@ -88,6 +88,19 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
|
||||
for layer in ts:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context, transformer_options)
|
||||
transformer_options["current_index"] += 1
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
@ -805,13 +818,13 @@ class UNetModel(nn.Module):
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
h = module(h, emb, context, transformer_options)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options)
|
||||
if control is not None and 'input' in control and len(control['input']) > 0:
|
||||
ctrl = control['input'].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context, transformer_options)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||
h += control['middle'].pop()
|
||||
|
||||
@ -828,7 +841,7 @@ class UNetModel(nn.Module):
|
||||
output_shape = hs[-1].shape
|
||||
else:
|
||||
output_shape = None
|
||||
h = module(h, emb, context, transformer_options, output_shape)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
|
@ -201,6 +201,9 @@ def load_controlnet_gpu(control_models):
|
||||
return
|
||||
|
||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||
for m in control_models:
|
||||
if hasattr(m, 'set_lowvram'):
|
||||
m.set_lowvram(True)
|
||||
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
|
||||
return
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user