diff --git a/main.py b/main.py index a8f376cc..9983859b 100644 --- a/main.py +++ b/main.py @@ -10,13 +10,35 @@ import torch import nodes +def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}): + valid_inputs = class_def.INPUT_TYPES() + input_data_all = {} + for x in inputs: + input_data = inputs[x] + if isinstance(input_data, list): + input_unique_id = input_data[0] + output_index = input_data[1] + obj = outputs[input_unique_id][output_index] + input_data_all[x] = obj + else: + if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): + input_data_all[x] = input_data + + if "hidden" in valid_inputs: + h = valid_inputs["hidden"] + for x in h: + if h[x] == "PROMPT": + input_data_all[x] = prompt + if h[x] == "EXTRA_PNGINFO": + if "extra_pnginfo" in extra_data: + input_data_all[x] = extra_data['extra_pnginfo'] + return input_data_all def recursive_execute(prompt, outputs, current_item, extra_data={}): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] - c_obj = nodes.NODE_CLASS_MAPPINGS[class_type] - valid_inputs = c_obj.INPUT_TYPES() + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: return [] @@ -31,28 +53,8 @@ def recursive_execute(prompt, outputs, current_item, extra_data={}): if input_unique_id not in outputs: executed += recursive_execute(prompt, outputs, input_unique_id, extra_data) - input_data_all = {} - for x in inputs: - input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - obj = outputs[input_unique_id][output_index] - input_data_all[x] = obj - else: - if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = input_data - - - obj = c_obj() - if "hidden" in valid_inputs: - h = valid_inputs["hidden"] - for x in h: - if h[x] == "PROMPT": - input_data_all[x] = prompt - if h[x] == "EXTRA_PNGINFO": - if "extra_pnginfo" in extra_data: - input_data_all[x] = extra_data['extra_pnginfo'] + input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data) + obj = class_def() outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) return executed + [unique_id] @@ -61,12 +63,27 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + + is_changed_old = '' + is_changed = '' + if hasattr(class_def, 'IS_CHANGED'): + if 'is_changed' not in prompt[unique_id]: + if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: + is_changed_old = old_prompt[unique_id]['is_changed'] + input_data_all = get_input_data(inputs, class_def) + is_changed = class_def.IS_CHANGED(**input_data_all) + prompt[unique_id]['is_changed'] = is_changed + else: + is_changed = prompt[unique_id]['is_changed'] if unique_id not in outputs: return True to_delete = False - if unique_id not in old_prompt: + if is_changed != is_changed_old: + to_delete = True + elif unique_id not in old_prompt: to_delete = True elif inputs == old_prompt[unique_id]['inputs']: for x in inputs: diff --git a/nodes.py b/nodes.py index 20ae2838..031c2620 100644 --- a/nodes.py +++ b/nodes.py @@ -3,6 +3,7 @@ import torch import os import sys import json +import hashlib from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -226,6 +227,14 @@ class LoadImage: image = torch.from_numpy(image[None])[None,] return image + @classmethod + def IS_CHANGED(s, image): + image_path = os.path.join(s.input_dir, image) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + NODE_CLASS_MAPPINGS = {