diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index 037dbf28..ad4b4623 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from dataclasses import dataclass
 from typing import Optional, Any
 
-from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
 import comfy.ldm.common_dit
 
@@ -81,6 +81,7 @@ class Attention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
+        optimized_attention=None,
     ):
         batch_size, seq_length, _ = hidden_states.shape
 
@@ -124,6 +125,7 @@ class TransformerBlock(nn.Module):
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
+        optimized_attention=None,
     ):
         # Self Attention
         residual = x
@@ -132,6 +134,7 @@ class TransformerBlock(nn.Module):
             hidden_states=x,
             attention_mask=attention_mask,
             freqs_cis=freqs_cis,
+            optimized_attention=optimized_attention,
         )
         x = residual + x
 
@@ -180,6 +183,7 @@ class Llama2_(nn.Module):
             mask += causal_mask
         else:
             mask = causal_mask
+        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
 
         intermediate = None
         if intermediate_output is not None:
@@ -191,6 +195,7 @@ class Llama2_(nn.Module):
                 x=x,
                 attention_mask=mask,
                 freqs_cis=freqs_cis,
+                optimized_attention=optimized_attention,
             )
             if i == intermediate_output:
                 intermediate = x.clone()