diff --git a/main.py b/main.py
index bc0af3dd4..e0035fdc6 100644
--- a/main.py
+++ b/main.py
@@ -3,7 +3,7 @@ import sys
 import copy
 import json
 import threading
-import queue
+import heapq
 import traceback
 
 if '--dont-upcast-attention' in sys.argv:
@@ -148,6 +148,7 @@ class PromptExecutor:
                         to_execute += [(0, x)]
 
                 while len(to_execute) > 0:
+                    # Always execute the output that depends on the fewest unexecuted nodes first.
                     to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
                     x = to_execute.pop(0)[-1]
 
@@ -266,10 +267,86 @@ def validate_prompt(prompt):
 def prompt_worker(q):
     e = PromptExecutor()
     while True:
-        item = q.get()
+        item, item_id = q.get()
         e.execute(item[-2], item[-1])
-        q.task_done()
+        q.task_done(item_id)
 
+class PromptQueue:
+    """Priority queue of prompts that also tracks the items being run.
+
+    Pending items are kept in a heapq-ordered list; every item handed out
+    by get() is remembered in currently_running until task_done() is called.
+    """
+
+    def __init__(self):
+        # One re-entrant lock guards all state; not_empty is built on the same
+        # lock so get() can wait on it while put() notifies.
+        self.mutex = threading.RLock()
+        self.not_empty = threading.Condition(self.mutex)
+        self.task_counter = 0
+        self.queue = []
+        self.currently_running = {}
+
+    def put(self, item):
+        """Push item onto the pending heap and wake one waiting get()."""
+        with self.mutex:
+            heapq.heappush(self.queue, item)
+            self.not_empty.notify()
+
+    def get(self):
+        """Block until an item is pending, then return (item, item_id).
+
+        Pass item_id to task_done() once the item has been processed.
+        """
+        with self.not_empty:
+            while not self.queue:
+                self.not_empty.wait()
+            item = heapq.heappop(self.queue)
+            item_id = self.task_counter
+            # A deep copy is stored, so later changes to the returned item do
+            # not show up in the running list.
+            self.currently_running[item_id] = copy.deepcopy(item)
+            self.task_counter += 1
+            return (item, item_id)
+
+    def task_done(self, item_id):
+        """Forget the running item that get() returned with item_id."""
+        with self.mutex:
+            self.currently_running.pop(item_id)
+
+    def get_current_queue(self):
+        """Return (running_items, pending_items).
+
+        pending_items is a deep copy of the heap list, so it is in heap order
+        rather than fully sorted.
+        """
+        with self.mutex:
+            running = list(self.currently_running.values())
+            return (running, copy.deepcopy(self.queue))
+
+    def get_tasks_remaining(self):
+        """Return the number of pending plus currently running items."""
+        with self.mutex:
+            return len(self.queue) + len(self.currently_running)
+
+    def wipe_queue(self):
+        """Discard all pending items; running items are not touched."""
+        with self.mutex:
+            self.queue = []
+
+    def delete_queue_item(self, function):
+        """Remove the first pending item for which function(item) is true.
+
+        Returns True if an item was removed, False otherwise.
+        """
+        with self.mutex:
+            for index, item in enumerate(self.queue):
+                if function(item):
+                    self.queue.pop(index)
+                    # Removing from the middle can break the heap invariant.
+                    heapq.heapify(self.queue)
+                    return True
+        return False
 
 from http.server import BaseHTTPRequestHandler, HTTPServer
 
@@ -285,9 +362,16 @@ class PromptServer(BaseHTTPRequestHandler):
             self._set_headers(ct='application/json')
             prompt_info = {}
             exec_info = {}
-            exec_info['queue_remaining'] = self.server.prompt_queue.unfinished_tasks
+            exec_info['queue_remaining'] = self.server.prompt_queue.get_tasks_remaining()
             prompt_info['exec_info'] = exec_info
             self.wfile.write(json.dumps(prompt_info).encode('utf-8'))
+        elif self.path == "/queue":
+            self._set_headers(ct='application/json')
+            queue_info = {}
+            current_queue = self.server.prompt_queue.get_current_queue()
+            queue_info['queue_running'] = current_queue[0]
+            queue_info['queue_pending'] = current_queue[1]
+            self.wfile.write(json.dumps(queue_info).encode('utf-8'))
         elif self.path == "/object_info":
             self._set_headers(ct='application/json')
             out = {}
@@ -325,12 +409,16 @@ class PromptServer(BaseHTTPRequestHandler):
         out_string = ""
         if self.path == "/prompt":
             print("got prompt")
-            self.data_string = self.rfile.read(int(self.headers['Content-Length']))
-            json_data = json.loads(self.data_string)
+            data_string = self.rfile.read(int(self.headers['Content-Length']))
+            json_data = json.loads(data_string)
             if "number" in json_data:
                 number = float(json_data['number'])
             else:
                 number = self.server.number
+                if "front" in json_data:
+                    if json_data['front']:
+                        number = -number
+
                 self.server.number += 1
             if "prompt" in json_data:
                 prompt = json_data["prompt"]
@@ -344,6 +432,27 @@ class PromptServer(BaseHTTPRequestHandler):
                     resp_code = 400
                     out_string = valid[1]
                     print("invalid prompt:", valid[1])
+        elif self.path == "/queue":
+            data_string = self.rfile.read(int(self.headers['Content-Length']))
+            json_data = json.loads(data_string)
+            if "clear" in json_data:
+                if json_data["clear"]:
+                    self.server.prompt_queue.wipe_queue()
+            if "delete" in json_data:
+                to_delete = json_data['delete']
+                for id_to_delete in to_delete:
+                    try:
+                        target_id = int(id_to_delete)
+                    except (TypeError, ValueError):
+                        # A malformed id must not abort the request handler.
+                        resp_code = 400
+                        out_string = "invalid queue item id"
+                        continue
+                    # Bind the parsed id as a default so it is converted once
+                    # and the lambda does not depend on the loop variable.
+                    delete_func = lambda a, target_id=target_id: a[1] == target_id
+                    self.server.prompt_queue.delete_queue_item(delete_func)
+
         self._set_headers(code=resp_code)
         self.end_headers()
         self.wfile.write(out_string.encode('utf8'))
@@ -366,7 +475,7 @@ def run(prompt_queue, address='', port=8188):
 
 
 if __name__ == "__main__":
-    q = queue.PriorityQueue()
+    q = PromptQueue()
     threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start()
     run(q, address='127.0.0.1', port=8188)
 