Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-03-15 05:57:20 +00:00)
Fix saving text encoder in fp8.

parent e6482fbbfc
commit 6c6a39251f
@@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
 code2idx = {"q": 0, "k": 1, "v": 2}
 
+# This function exists because at the time of writing torch.cat can't do fp8 with cuda
+def cat_tensors(tensors):
+    x = 0
+    for t in tensors:
+        x += t.shape[0]
+
+    shape = [x] + list(tensors[0].shape)[1:]
+    out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
+
+    x = 0
+    for t in tensors:
+        out[x:x + t.shape[0]] = t
+        x += t.shape[0]
+
+    return out
 
 def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
     new_state_dict = {}
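For context, a minimal usage sketch of the new helper (not part of the commit): cat_tensors concatenates along dim 0 by preallocating the output with torch.empty and copying each input into its row range, which sidesteps torch.cat's missing fp8 support on CUDA at the time. The float8_e4m3fn dtype assumes a PyTorch build that ships the float8 types (2.1+); the shapes here are made up for the example.

import torch

# Hypothetical demo inputs; cat_tensors is the helper added in the hunk above.
q = torch.randn(4, 8).to(torch.float8_e4m3fn)
k = torch.randn(4, 8).to(torch.float8_e4m3fn)
v = torch.randn(4, 8).to(torch.float8_e4m3fn)

merged = cat_tensors([q, k, v])  # behaves like torch.cat([q, k, v], dim=0)
assert merged.shape == (12, 8)
assert merged.dtype == torch.float8_e4m3fn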
@@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
         if None in tensors:
             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
-        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+        new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
 
     for k_pre, tensors in capture_qkv_bias.items():
         if None in tensors:
             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
-        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+        new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
 
     return new_state_dict
 
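And a hypothetical illustration of what the two loops above now build (names and sizes are made up for the example, and cat_tensors from the first hunk is assumed to be in scope): the captured q, k, v tensors for one attention block are stacked row-wise, in code2idx order, into a single in_proj tensor, so each row block of the merged result recovers one projection.

import torch

dim = 8
# Stand-ins for the q/k/v projection weights captured for one block,
# indexed per code2idx: 0 = q, 1 = k, 2 = v.
tensors = [torch.randn(dim, dim) for _ in range(3)]

in_proj_weight = cat_tensors(tensors)  # shape (3 * dim, dim)

assert torch.equal(in_proj_weight[0:dim], tensors[0])        # q rows
assert torch.equal(in_proj_weight[dim:2 * dim], tensors[1])  # k rows
assert torch.equal(in_proj_weight[2 * dim:], tensors[2])     # v rows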