diff --git a/main.py b/main.py index b2b3f1c4..889e2cef 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,6 @@ import os import sys +import shutil import threading import asyncio @@ -53,7 +54,14 @@ def hijack_progress(server): return v setattr(tqdm, "update", wrapped_func) +def cleanup_temp(): + temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + if __name__ == "__main__": + cleanup_temp() + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) server = server.PromptServer(loop) @@ -93,3 +101,4 @@ if __name__ == "__main__": else: loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start)) + cleanup_temp() diff --git a/nodes.py b/nodes.py index 0a0a0a9c..650d7f65 100644 --- a/nodes.py +++ b/nodes.py @@ -775,6 +775,7 @@ class KSamplerAdvanced: class SaveImage: def __init__(self): self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") + self.url_suffix = "" @classmethod def INPUT_TYPES(s): @@ -808,6 +809,9 @@ class SaveImage: os.mkdir(self.output_dir) counter = 1 + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + paths = list() for image in images: i = 255. * image.cpu().numpy() @@ -820,10 +824,22 @@ class SaveImage: metadata.add_text(x, json.dumps(extra_pnginfo[x])) file = f"{filename_prefix}_{counter:05}_.png" img.save(os.path.join(self.output_dir, file), pnginfo=metadata, optimize=True) - paths.append(file) + paths.append(file + self.url_suffix) counter += 1 return { "ui": { "images": paths } } +class PreviewImage(SaveImage): + def __init__(self): + self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") + self.url_suffix = "?type=temp" + + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ), }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + class LoadImage: input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") @classmethod @@ -944,6 +960,7 @@ NODE_CLASS_MAPPINGS = { "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, "SaveImage": SaveImage, + "PreviewImage": PreviewImage, "LoadImage": LoadImage, "LoadImageMask": LoadImageMask, "ImageScale": ImageScale, diff --git a/server.py b/server.py index a29d8597..eb685701 100644 --- a/server.py +++ b/server.py @@ -113,7 +113,7 @@ class PromptServer(): async def view_image(request): if "file" in request.match_info: type = request.rel_url.query.get("type", "output") - if type != "output" and type != "input": + if type not in ["output", "input", "temp"]: return web.Response(status=400) output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type) diff --git a/web/scripts/app.js b/web/scripts/app.js index e70e1c15..445bc5d4 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -142,7 +142,14 @@ class ComfyApp { if (numImages === 1 && !imageIndex) { this.imageIndex = imageIndex = 0; } - let shiftY = this.type === "SaveImage" ? 55 : this.imageOffset || 0; + + let shiftY; + if (this.imageOffset != null) { + shiftY = this.imageOffset; + } else { + shiftY = this.computeSize()[1]; + } + let dw = this.size[0]; let dh = this.size[1]; dh -= shiftY; @@ -497,7 +504,11 @@ class ComfyApp { if (Array.isArray(type)) { // Enums e.g. latent rotation - this.addWidget("combo", inputName, type[0], () => {}, { values: type }); + let defaultValue = type[0]; + if (inputData[1] && inputData[1].default) { + defaultValue = inputData[1].default; + } + this.addWidget("combo", inputName, defaultValue, () => {}, { values: type }); } else if (`${type}:${inputName}` in widgets) { // Support custom widgets by Type:Name Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {});