diff --git a/comfy/conds.py b/comfy/conds.py index 211fb8d57..5fac44ef5 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -1,6 +1,7 @@ import torch import math import comfy.utils +from comfy.model_management import device_should_use_non_blocking class CONDRegular: @@ -10,8 +11,17 @@ class CONDRegular: def _copy_with(self, cond): return self.__class__(cond) + def _pin_memory(self, cond): + if cond.device == torch.device('cpu'): + return cond.pin_memory() + else: + return cond + def process_cond(self, batch_size, device, **kwargs): - return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) + if device_should_use_non_blocking(device): + return self._copy_with(comfy.utils.repeat_to_batch_size(self._pin_memory(self.cond), batch_size).to(device, non_blocking=True)) + else: + return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) def can_concat(self, other): if self.cond.shape != other.cond.shape: