diff --git a/execution.py b/execution.py index 9a5e27771..53bc1c02f 100644 --- a/execution.py +++ b/execution.py @@ -1,5 +1,4 @@ import sys -import copy import logging import threading import heapq @@ -906,6 +905,16 @@ class PromptQueue: self.flags = {} server.prompt_queue = self + def _copy(self, obj): + if isinstance(obj, dict): + return {k: self._copy(v) for k, v in obj.items()} + elif isinstance(obj, tuple): + return tuple(self._copy(v) for v in obj) + elif isinstance(obj, list): + return [self._copy(v) for v in obj] + else: + return obj + def put(self, item): with self.mutex: heapq.heappush(self.queue, item) @@ -920,7 +929,7 @@ class PromptQueue: return None item = heapq.heappop(self.queue) i = self.task_counter - self.currently_running[i] = copy.deepcopy(item) + self.currently_running[i] = self._copy(item) self.task_counter += 1 self.server.queue_updated() return (item, i) @@ -939,7 +948,7 @@ class PromptQueue: status_dict: Optional[dict] = None if status is not None: - status_dict = copy.deepcopy(status._asdict()) + status_dict = self._copy(status._asdict()) self.history[prompt[1]] = { "prompt": prompt, @@ -954,7 +963,7 @@ class PromptQueue: out = [] for x in self.currently_running.values(): out += [x] - return (out, copy.deepcopy(self.queue)) + return (out, self._copy(self.queue)) def get_tasks_remaining(self): with self.mutex: @@ -993,7 +1002,7 @@ class PromptQueue: i += 1 return out elif prompt_id in self.history: - return {prompt_id: copy.deepcopy(self.history[prompt_id])} + return {prompt_id: self._copy(self.history[prompt_id])} else: return {}