mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Split optimization for VAE attention block.
This commit is contained in:
parent
5b4e312749
commit
e8c499ddd4
@ -186,18 +186,60 @@ class AttnBlock(nn.Module):
|
|||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
b,c,h,w = q.shape
|
b,c,h,w = q.shape
|
||||||
|
scale = (int(c)**(-0.5))
|
||||||
|
|
||||||
q = q.reshape(b,c,h*w)
|
q = q.reshape(b,c,h*w)
|
||||||
q = q.permute(0,2,1) # b,hw,c
|
q = q.permute(0,2,1) # b,hw,c
|
||||||
k = k.reshape(b,c,h*w) # b,c,hw
|
k = k.reshape(b,c,h*w) # b,c,hw
|
||||||
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
|
||||||
w_ = w_ * (int(c)**(-0.5))
|
|
||||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
|
||||||
|
|
||||||
# attend to values
|
|
||||||
v = v.reshape(b,c,h*w)
|
v = v.reshape(b,c,h*w)
|
||||||
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
|
||||||
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
r1 = torch.zeros_like(k, device=q.device)
|
||||||
h_ = h_.reshape(b,c,h,w)
|
|
||||||
|
stats = torch.cuda.memory_stats(q.device)
|
||||||
|
mem_active = stats['active_bytes.all.current']
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||||
|
mem_free_torch = mem_reserved - mem_active
|
||||||
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
|
|
||||||
|
gb = 1024 ** 3
|
||||||
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||||
|
modifier = 3 if q.element_size() == 2 else 2.5
|
||||||
|
mem_required = tensor_size * modifier
|
||||||
|
steps = 1
|
||||||
|
|
||||||
|
if mem_required > mem_free_total:
|
||||||
|
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
|
|
||||||
|
first_op_done = False
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
|
for i in range(0, q.shape[1], slice_size):
|
||||||
|
end = i + slice_size
|
||||||
|
s1 = torch.bmm(q[:, i:end], k) * scale
|
||||||
|
first_op_done = True
|
||||||
|
|
||||||
|
torch.exp(s1, out=s1)
|
||||||
|
summed = torch.sum(s1, dim=2, keepdim=True)
|
||||||
|
s1 /= summed
|
||||||
|
s2 = s1.permute(0,2,1)
|
||||||
|
del s1
|
||||||
|
|
||||||
|
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||||
|
del s2
|
||||||
|
break
|
||||||
|
except torch.cuda.OutOfMemoryError as e:
|
||||||
|
if first_op_done == False:
|
||||||
|
steps *= 2
|
||||||
|
if steps > 128:
|
||||||
|
raise e
|
||||||
|
print("out of memory error, increasing steps and trying again", steps)
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
h_ = r1.reshape(b,c,h,w)
|
||||||
|
del r1
|
||||||
|
|
||||||
h_ = self.proj_out(h_)
|
h_ = self.proj_out(h_)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user