2023-01-03 06:53:32 +00:00
# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
# MIT
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# implementation of:
# "Self-attention Does Not Need O(n^2) Memory":
# https://arxiv.org/abs/2112.05682v2
from functools import partial
import torch
from torch import Tensor
from torch . utils . checkpoint import checkpoint
import math
2024-03-11 20:24:47 +00:00
import logging
2023-03-06 16:41:40 +00:00
try :
from typing import Optional , NamedTuple , List , Protocol
except ImportError :
from typing import Optional , NamedTuple , List
from typing_extensions import Protocol
2023-01-03 06:53:32 +00:00
from torch import Tensor
from typing import List
2023-04-15 22:55:17 +00:00
from comfy import model_management
2023-02-09 17:33:27 +00:00
2023-01-03 06:53:32 +00:00
def dynamic_slice(
    x: Tensor,
    starts: List[int],
    sizes: List[int],
) -> Tensor:
    """Return the window of ``x`` that begins at ``starts`` and spans ``sizes``.

    Args:
        x: tensor to slice.
        starts: start offset for each leading dimension of ``x``.
        sizes: extent of the window along each of those dimensions.

    Returns:
        ``x[starts[0]:starts[0] + sizes[0], starts[1]:starts[1] + sizes[1], ...]``.
        Ordinary slice semantics apply, so a window that runs past the end of
        a dimension is truncated instead of raising.
    """
    # Index with a tuple: a list of slices is a "non-tuple sequence" index,
    # which PyTorch deprecates for multidimensional indexing.
    slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
    return x[slicing]
class AttnChunk(NamedTuple):
    """Partial attention result for a single key/value chunk.

    Chunks are combined later by rescaling each one with
    ``exp(max_score - global_max)`` before summing.
    """

    # exponentiated (max-shifted) scores multiplied by the value chunk
    exp_values: Tensor
    # sum of the exponentiated scores over the chunk's key axis
    exp_weights_sum: Tensor
    # per-query maximum score that was subtracted before exponentiating
    max_score: Tensor
class SummarizeChunk(Protocol):
    """Structural type of a callable that reduces one key/value chunk.

    Implementations receive a query block, a transposed key chunk and the
    matching value chunk, and return the chunk's partial softmax statistics
    as an ``AttnChunk``. ``mask`` is declared here because the chunk scanner
    invokes the callable with a ``mask=`` keyword argument.
    """

    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> AttnChunk: ...
class ComputeQueryChunkAttn(Protocol):
    """Structural type of a callable that computes attention for one query chunk.

    Implementations receive a query block together with the transposed keys
    and the values, and return the attention output tensor for that block.
    ``mask`` is declared here because callers pass a ``mask=`` keyword
    argument.
    """

    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor: ...
def _summarize_chunk(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> AttnChunk:
    """Compute the un-normalised softmax statistics for one key/value chunk.

    Args:
        query: query block; batch-multiplied with ``key_t``.
        key_t: transposed key chunk.
        value: value chunk matching ``key_t``.
        scale: factor applied to the query/key product.
        upcast_attention: when true, the scores are computed in float32 with
            CUDA autocast disabled.
        mask: optional tensor added to the max-shifted scores, or ``None``.

    Returns:
        ``AttnChunk`` holding the exponentiated scores multiplied by ``value``,
        the sum of the exponentiated scores over the last axis, and the
        (detached) per-row maximum score that was subtracted.
    """

    def scaled_scores(q: Tensor, k_t: Tensor) -> Tensor:
        # With beta=0 baddbmm ignores its input tensor, so the uninitialised
        # placeholder never leaks into the result: this is scale * (q @ k_t).
        placeholder = torch.empty(1, 1, 1, device=q.device, dtype=q.dtype)
        return torch.baddbmm(placeholder, q, k_t, alpha=scale, beta=0)

    if not upcast_attention:
        scores = scaled_scores(query, key_t)
    else:
        with torch.autocast(enabled=False, device_type='cuda'):
            query = query.float()
            key_t = key_t.float()
            scores = scaled_scores(query, key_t)

    # The row maximum is taken before the mask is applied and is detached so
    # that it only serves as a numerical-stability shift.
    row_max = torch.max(scores, -1, keepdim=True).values.detach()

    # Everything below reuses the score buffer in place to avoid extra copies.
    scores -= row_max
    if mask is not None:
        scores += mask
    torch.exp(scores, out=scores)

    probs = scores.to(value.dtype)
    weighted_values = torch.bmm(probs, value)
    return AttnChunk(weighted_values, probs.sum(dim=-1), row_max.squeeze(-1))
def _query_chunk_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
    mask: Optional[Tensor],
) -> Tensor:
    """Attend one query block to all keys/values, one key/value chunk at a time.

    Args:
        query: query block, passed unchanged to ``summarize_chunk``.
        key_t: transposed keys, unpacked as
            ``[batch_x_heads, channels_per_head, k_tokens]``.
        value: values, unpacked as ``[_, _, channels_per_head]`` and sliced
            along dimension 1 in step with the keys.
        summarize_chunk: callable producing an ``AttnChunk`` per chunk; it is
            invoked with ``mask`` as a keyword argument.
        kv_chunk_size: number of key tokens per chunk (the last chunk may be
            shorter).
        mask: optional tensor indexed with three axes here and sliced along
            its last axis to match each key chunk, or ``None``.

    Returns:
        The chunk results merged into a normalised attention output:
        summed rescaled values divided by summed rescaled weights.
    """
    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int) -> AttnChunk:
        key_chunk = dynamic_slice(
            key_t,
            (0, 0, chunk_idx),
            (batch_x_heads, k_channels_per_head, kv_chunk_size),
        )
        value_chunk = dynamic_slice(
            value,
            (0, chunk_idx, 0),
            (batch_x_heads, kv_chunk_size, v_channels_per_head),
        )
        mask_chunk = None
        if mask is not None:
            mask_chunk = mask[:, :, chunk_idx:chunk_idx + kv_chunk_size]
        return summarize_chunk(query, key_chunk, value_chunk, mask=mask_chunk)

    # Chunk offsets are only used as Python slice bounds, so a plain range is
    # enough; iterating a torch.arange would allocate a tensor per offset.
    chunks: List[AttnChunk] = [
        chunk_scanner(chunk_start) for chunk_start in range(0, k_tokens, kv_chunk_size)
    ]

    # Stack each AttnChunk field across chunks (new leading "chunk" axis).
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    # Each chunk was shifted by its own maximum; rescale all of them to the
    # shared global maximum so they can be summed consistently.
    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> Tensor:
    """Compute softmax attention for a query block against all keys at once.

    Args:
        query: query block; batch-multiplied with ``key_t``.
        key_t: transposed keys.
        value: values multiplied by the attention probabilities.
        scale: factor applied to the query/key product.
        upcast_attention: when true, the scores are computed in float32 with
            CUDA autocast disabled.
        mask: optional tensor added in place to the scores, or ``None``.

    Returns:
        ``softmax(scale * query @ key_t + mask) @ value``, with the
        probabilities cast to ``value.dtype`` before the final product.
    """

    def scaled_scores(q: Tensor, k_t: Tensor) -> Tensor:
        # With beta=0 baddbmm ignores its input tensor, so the uninitialised
        # placeholder never leaks into the result: this is scale * (q @ k_t).
        placeholder = torch.empty(1, 1, 1, device=q.device, dtype=q.dtype)
        return torch.baddbmm(placeholder, q, k_t, alpha=scale, beta=0)

    if not upcast_attention:
        scores = scaled_scores(query, key_t)
    else:
        with torch.autocast(enabled=False, device_type='cuda'):
            query = query.float()
            key_t = key_t.float()
            scores = scaled_scores(query, key_t)

    if mask is not None:
        scores += mask

    try:
        probs = scores.softmax(dim=-1)
        # Drop the raw scores as soon as possible to lower peak memory.
        del scores
    except model_management.OOM_EXCEPTION:
        logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
        # Manual softmax that overwrites the score buffer instead of
        # allocating a second tensor of the same size.
        scores -= scores.max(dim=-1, keepdim=True).values  # noqa: F821 scores is not defined
        torch.exp(scores, out=scores)
        row_totals = torch.sum(scores, dim=-1, keepdim=True)
        scores /= row_totals
        probs = scores

    return torch.bmm(probs.to(value.dtype), value)
class ScannedChunk(NamedTuple):
    """Pairs a chunk index with the ``AttnChunk`` computed for that chunk.

    NOTE(review): nothing in the visible part of this file references this
    type; it may only be kept for compatibility.
    """

    # index of the chunk this result belongs to
    chunk_idx: int
    # partial attention statistics for that chunk
    attn_chunk: AttnChunk
def efficient_dot_product_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    query_chunk_size: int = 1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint: bool = True,
    upcast_attention: bool = False,
    mask: Optional[Tensor] = None,
):
    """Computes efficient dot-product attention given query, transposed key, and value.

    This is the efficient version of attention presented in
    https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.

    Args:
        query: queries for calculating attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        key_t: keys for calculating attention with shape of
            `[batch * num_heads, channels_per_head, tokens]`.
        value: values to be used in attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        query_chunk_size: int: query chunks size
        kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. changes the chosen
            chunk size into `max(chunk_size, kv_chunk_size_min)`, to ensure our chunk sizes
            don't get too small (smaller chunks = more chunks = less concurrent work done).
        use_checkpoint: bool: whether to use checkpointing (recommended True for training,
            False for inference)
        upcast_attention: bool: compute the attention scores in float32 with CUDA autocast
            disabled.
        mask: optional tensor added to the attention scores. A 2-D mask gets a leading
            axis prepended. Axis 1 is sliced per query chunk unless its size is 1, and the
            last axis is sliced per key/value chunk, so a layout broadcastable to
            `[batch * num_heads, query_tokens, key_tokens]` is presumably expected.

    Returns:
        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
    """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, _, k_tokens = key_t.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    if mask is not None and len(mask.shape) == 2:
        mask = mask.unsqueeze(0)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return dynamic_slice(
            query,
            (0, chunk_idx, 0),
            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head),
        )

    def get_mask_chunk(chunk_idx: int) -> Optional[Tensor]:
        if mask is None:
            return None
        if mask.shape[1] == 1:
            # A single mask row applies to every query, nothing to slice.
            return mask
        chunk = min(query_chunk_size, q_tokens)
        return mask[:, chunk_idx:chunk_idx + chunk]

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
    if use_checkpoint:
        # The chunk scanner calls summarize_chunk(..., mask=...). Reentrant
        # checkpointing rejects keyword arguments for the wrapped function,
        # so the non-reentrant implementation is requested explicitly.
        summarize_chunk = partial(checkpoint, summarize_chunk, use_reentrant=False)

    compute_query_chunk_attn: ComputeQueryChunkAttn
    if k_tokens <= kv_chunk_size:
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        compute_query_chunk_attn = partial(
            _get_attention_scores_no_kv_chunking,
            scale=scale,
            upcast_attention=upcast_attention,
        )
    else:
        compute_query_chunk_attn = partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key_t=key_t,
            value=value,
            mask=mask,
        )

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat([
        compute_query_chunk_attn(
            query=get_query_chunk(i * query_chunk_size),
            key_t=key_t,
            value=value,
            mask=get_mask_chunk(i * query_chunk_size),
        ) for i in range(math.ceil(q_tokens / query_chunk_size))
    ], dim=1)
    return res