From 61ad962154a0b3d4e01f103352ad8b8dc3f560cd Mon Sep 17 00:00:00 2001 From: Maruhi <38591025+maruhidd@users.noreply.github.com> Date: Wed, 25 Dec 2024 20:39:06 +0900 Subject: [PATCH] Update server.py --- server.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/server.py b/server.py index 22525507..57f062ad 100644 --- a/server.py +++ b/server.py @@ -327,15 +327,37 @@ class PromptServer(): else: return web.Response(status=400) + async def get_post_from_request(request): + post = None + if request.content_type == 'application/json': + json_data = await request.json() + post = await request.post() + mutable_post = post.copy() + for key, value in json_data.items(): + if key == "image": + file_name = json_data.get("filename") + value = base64.b64decode(value) + import io + image = aiohttp.web.FileField(name=key, filename=file_name, file=io.BytesIO(value), content_type='image/png', headers=None) + mutable_post[key] = image + elif key != "filename": + mutable_post[key] = value + post = mutable_post + + elif request.content_type == 'multipart/form-data': + post = await request.post() + else: + raise web.HTTPBadRequest(text="Unsupported content type") + return post + @routes.post("/upload/image") async def upload_image(request): - post = await request.post() + post = await get_post_from_request(request) return image_upload(post) - @routes.post("/upload/mask") async def upload_mask(request): - post = await request.post() + post = await get_post_from_request(request) def image_save_function(image, post, filepath): original_ref = json.loads(post.get("original_ref"))