mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-26 00:05:18 +00:00
Fix issue with gligen.
This commit is contained in:
parent
d6e4b342e6
commit
b80c3276dc
@ -16,6 +16,8 @@ if model_management.xformers_enabled():
|
|||||||
import xformers.ops
|
import xformers.ops
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
# CrossAttn precision handling
|
# CrossAttn precision handling
|
||||||
if args.dont_upcast_attention:
|
if args.dont_upcast_attention:
|
||||||
print("disabling upcasting of attention")
|
print("disabling upcasting of attention")
|
||||||
@ -51,7 +53,7 @@ def init_(tensor):
|
|||||||
|
|
||||||
# feedforward
|
# feedforward
|
||||||
class GEGLU(nn.Module):
|
class GEGLU(nn.Module):
|
||||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
|
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
|
||||||
|
|
||||||
@ -61,7 +63,7 @@ class GEGLU(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Module):
|
||||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
|
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = int(dim * mult)
|
inner_dim = int(dim * mult)
|
||||||
dim_out = default(dim_out, dim)
|
dim_out = default(dim_out, dim)
|
||||||
@ -147,7 +149,7 @@ class SpatialSelfAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class CrossAttentionBirchSan(nn.Module):
|
class CrossAttentionBirchSan(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -244,7 +246,7 @@ class CrossAttentionBirchSan(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class CrossAttentionDoggettx(nn.Module):
|
class CrossAttentionDoggettx(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -342,7 +344,7 @@ class CrossAttentionDoggettx(nn.Module):
|
|||||||
return self.to_out(r2)
|
return self.to_out(r2)
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -398,7 +400,7 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
class MemoryEfficientCrossAttention(nn.Module):
|
class MemoryEfficientCrossAttention(nn.Module):
|
||||||
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||||
f"{heads} heads.")
|
f"{heads} heads.")
|
||||||
@ -449,7 +451,7 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
class CrossAttentionPytorch(nn.Module):
|
class CrossAttentionPytorch(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -507,7 +509,7 @@ else:
|
|||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
||||||
disable_self_attn=False, dtype=None, device=None, operations=None):
|
disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.disable_self_attn = disable_self_attn
|
self.disable_self_attn = disable_self_attn
|
||||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||||
@ -647,7 +649,7 @@ class SpatialTransformer(nn.Module):
|
|||||||
def __init__(self, in_channels, n_heads, d_head,
|
def __init__(self, in_channels, n_heads, d_head,
|
||||||
depth=1, dropout=0., context_dim=None,
|
depth=1, dropout=0., context_dim=None,
|
||||||
disable_self_attn=False, use_linear=False,
|
disable_self_attn=False, use_linear=False,
|
||||||
use_checkpoint=True, dtype=None, device=None, operations=None):
|
use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if exists(context_dim) and not isinstance(context_dim, list):
|
if exists(context_dim) and not isinstance(context_dim, list):
|
||||||
context_dim = [context_dim] * depth
|
context_dim = [context_dim] * depth
|
||||||
|
@ -70,7 +70,7 @@ class Upsample(nn.Module):
|
|||||||
upsampling occurs in the inner-two dimensions.
|
upsampling occurs in the inner-two dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=None):
|
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.out_channels = out_channels or channels
|
self.out_channels = out_channels or channels
|
||||||
@ -106,7 +106,7 @@ class Downsample(nn.Module):
|
|||||||
downsampling occurs in the inner-two dimensions.
|
downsampling occurs in the inner-two dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=None):
|
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.out_channels = out_channels or channels
|
self.out_channels = out_channels or channels
|
||||||
@ -156,7 +156,7 @@ class ResBlock(TimestepBlock):
|
|||||||
down=False,
|
down=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None
|
operations=comfy.ops
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
|
Loading…
Reference in New Issue
Block a user