From b80c3276dce510973c24d1c9b7fb48be36292396 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 18 Aug 2023 16:32:23 -0400 Subject: [PATCH] Fix issue with gligen. --- comfy/ldm/modules/attention.py | 20 ++++++++++--------- .../modules/diffusionmodules/openaimodel.py | 6 +++--- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 87a4aa807..973619bf2 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -16,6 +16,8 @@ if model_management.xformers_enabled(): import xformers.ops from comfy.cli_args import args +import comfy.ops + # CrossAttn precision handling if args.dont_upcast_attention: print("disabling upcasting of attention") @@ -51,7 +53,7 @@ def init_(tensor): # feedforward class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None): + def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops): super().__init__() self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) @@ -61,7 +63,7 @@ class GEGLU(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) @@ -147,7 +149,7 @@ class SpatialSelfAttention(nn.Module): class CrossAttentionBirchSan(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -244,7 +246,7 @@ class CrossAttentionBirchSan(nn.Module): class CrossAttentionDoggettx(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -342,7 +344,7 @@ class CrossAttentionDoggettx(nn.Module): return self.to_out(r2) class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -398,7 +400,7 @@ class CrossAttention(nn.Module): class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops): super().__init__() print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " f"{heads} heads.") @@ -449,7 +451,7 @@ class MemoryEfficientCrossAttention(nn.Module): return self.to_out(out) class CrossAttentionPytorch(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -507,7 +509,7 @@ else: class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False, dtype=None, device=None, operations=None): + disable_self_attn=False, dtype=None, device=None, operations=comfy.ops): super().__init__() self.disable_self_attn = disable_self_attn self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, @@ -647,7 +649,7 @@ class SpatialTransformer(nn.Module): def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, - use_checkpoint=True, dtype=None, device=None, operations=None): + use_checkpoint=True, dtype=None, device=None, operations=comfy.ops): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] * depth diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 8063adb85..11cec0eda 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -70,7 +70,7 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=None): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -106,7 +106,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=None): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -156,7 +156,7 @@ class ResBlock(TimestepBlock): down=False, dtype=None, device=None, - operations=None + operations=comfy.ops ): super().__init__() self.channels = channels