Compare commits

...

4 Commits

Author SHA1 Message Date
drhead
a58ef65212
Merge d175cbd315 into e346d8584e 2025-04-09 20:15:32 +03:00
drhead
d175cbd315
fix syntax 2025-03-09 19:57:51 -04:00
drhead
8e5f33cc9c
better way of gating non blocking use 2025-03-09 19:45:29 -04:00
drhead
c7b8e7250d
make conds use non-blocking transfers 2025-03-09 19:42:19 -04:00

View File

@ -1,6 +1,7 @@
import torch
import math
import comfy.utils
from comfy.model_management import device_should_use_non_blocking
class CONDRegular:
@ -10,8 +11,17 @@ class CONDRegular:
def _copy_with(self, cond):
return self.__class__(cond)
def _pin_memory(self, cond):
if cond.device == torch.device('cpu'):
return cond.pin_memory()
else:
return cond
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
if device_should_use_non_blocking(device):
return self._copy_with(comfy.utils.repeat_to_batch_size(self._pin_memory(self.cond), batch_size).to(device, non_blocking=True))
else:
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
def can_concat(self, other):
if self.cond.shape != other.cond.shape: