Per device stream counters for async offload. (#7873)

This commit is contained in:
comfyanonymous 2025-04-29 17:28:52 -07:00 committed by GitHub
parent 5c5457a4ef
commit 0a66d4b0af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -946,9 +946,9 @@ if args.async_offload:
NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
stream_counter = 0
stream_counters = {}
def get_offload_stream(device):
global stream_counter
stream_counter = stream_counters.get(device, 0)
if NUM_STREAMS <= 1:
return None
@ -958,6 +958,7 @@ def get_offload_stream(device):
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
stream_counters[device] = stream_counter
return s
elif is_device_cuda(device):
ss = []
@ -966,6 +967,7 @@ def get_offload_stream(device):
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None