diff --git a/comfy/samplers.py b/comfy/samplers.py index ddec9900..59dbab53 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -348,17 +348,27 @@ def encode_adm(noise_augmentor, conds, batch_size, device): if 'adm' in x[1]: adm_inputs = [] weights = [] + noise_aug = [] adm_in = x[1]["adm"] for adm_c in adm_in: adm_cond = adm_c[0].image_embeds weight = adm_c[1] - c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([0], device=device)) + noise_augment = adm_c[2] + noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight weights.append(weight) + noise_aug.append(noise_augment) adm_inputs.append(adm_out) - adm_out = torch.stack(adm_inputs).sum(0) - #TODO: Apply Noise to Embedding Mix + if len(noise_aug) > 1: + adm_out = torch.stack(adm_inputs).sum(0) + #TODO: add a way to control this + noise_augment = 0.05 + noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) + print(noise_level) + c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) else: adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) x[1] = x[1].copy() diff --git a/nodes.py b/nodes.py index 963ff32a..ffbba9f9 100644 --- a/nodes.py +++ b/nodes.py @@ -445,17 +445,18 @@ class unCLIPConditioning: return {"required": {"conditioning": ("CONDITIONING", ), "clip_vision_output": ("CLIP_VISION_OUTPUT", ), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_adm" CATEGORY = "_for_testing/unclip" - def apply_adm(self, conditioning, clip_vision_output, strength): + def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): c = [] for t in conditioning: o = t[1].copy() - x = (clip_vision_output, strength) + x = (clip_vision_output, strength, noise_augmentation) if "adm" in o: o["adm"] = o["adm"][:] + [x] else: