diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 67978d4c..88c442d1 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -175,13 +175,11 @@ class CrossAttentionBirchSan(nn.Module):
         value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)

         dtype = query.dtype
-        # TODO: do we still need to do *everything* in float32, given how we delay the division?
-        # TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
-        # if self.upcast_attention:
-        #     query = query.float()
-        #     key_t = key_t.float()
-
-        bytes_per_token = torch.finfo(query.dtype).bits//8
+        upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
+        if upcast_attention:
+            bytes_per_token = torch.finfo(torch.float32).bits//8
+        else:
+            bytes_per_token = torch.finfo(query.dtype).bits//8
         batch_x_heads, q_tokens, _ = query.shape
         _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
@@ -198,7 +196,7 @@ class CrossAttentionBirchSan(nn.Module):

         query_chunk_size_x = 1024 * 4
         kv_chunk_size_min_x = None
-        kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 1.2) // 1024) * 1024
+        kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
         if kv_chunk_size_x < 1024:
             kv_chunk_size_x = None

@@ -220,6 +218,7 @@ class CrossAttentionBirchSan(nn.Module):
             kv_chunk_size=kv_chunk_size,
             kv_chunk_size_min=kv_chunk_size_min,
             use_checkpoint=self.training,
+            upcast_attention=upcast_attention,
         )

         hidden_states = hidden_states.to(dtype)
@@ -383,8 +382,15 @@ class OriginalCrossAttention(nn.Module):
         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
         return self.to_out(out)

-class CrossAttention(CrossAttentionDoggettx):
-    pass
+import sys
+if "--use-split-cross-attention" in sys.argv:
+    print("Using split optimization for cross attention")
+    class CrossAttention(CrossAttentionDoggettx):
+        pass
+else:
+    print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+    class CrossAttention(CrossAttentionBirchSan):
+        pass

 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index fe9bb82c..9649f9d1 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -53,14 +53,27 @@ def _summarize_chunk(
     key_t: Tensor,
     value: Tensor,
     scale: float,
+    upcast_attention: bool,
 ) -> AttnChunk:
-    attn_weights = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
-        query,
-        key_t,
-        alpha=scale,
-        beta=0,
-    )
+    if upcast_attention:
+        with torch.autocast(enabled=False, device_type = 'cuda'):
+            query = query.float()
+            key_t = key_t.float()
+            attn_weights = torch.baddbmm(
+                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+                query,
+                key_t,
+                alpha=scale,
+                beta=0,
+            )
+    else:
+        attn_weights = torch.baddbmm(
+            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+            query,
+            key_t,
+            alpha=scale,
+            beta=0,
+        )
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
     exp_weights = torch.exp(attn_weights - max_score)
@@ -112,14 +125,27 @@ def _get_attention_scores_no_kv_chunking(
     key_t: Tensor,
     value: Tensor,
     scale: float,
+    upcast_attention: bool,
 ) -> Tensor:
-    attn_scores = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
-        query,
-        key_t,
-        alpha=scale,
-        beta=0,
-    )
+    if upcast_attention:
+        with torch.autocast(enabled=False, device_type = 'cuda'):
+            query = query.float()
+            key_t = key_t.float()
+            attn_scores = torch.baddbmm(
+                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+                query,
+                key_t,
+                alpha=scale,
+                beta=0,
+            )
+    else:
+        attn_scores = torch.baddbmm(
+            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+            query,
+            key_t,
+            alpha=scale,
+            beta=0,
+        )
     attn_probs = attn_scores.softmax(dim=-1)
     del attn_scores
     hidden_states_slice = torch.bmm(attn_probs, value)
@@ -137,6 +163,7 @@ def efficient_dot_product_attention(
     kv_chunk_size: Optional[int] = None,
     kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
+    upcast_attention=False,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
     This is efficient version of attention presented in
@@ -170,11 +197,12 @@ def efficient_dot_product_attention(
             (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
         )

-    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
+    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
     summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
     compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
         _get_attention_scores_no_kv_chunking,
-        scale=scale
+        scale=scale,
+        upcast_attention=upcast_attention
     ) if k_tokens <= kv_chunk_size else (
         # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
         partial(