From f30b992b18078415f7c31c6c2f5ad1513db0bf5e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Nov 2023 16:41:33 -0500 Subject: [PATCH] .sigma and .timestep now return tensors on the same device as the input. --- comfy/model_sampling.py | 6 +++--- comfy_extras/nodes_model_advanced.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index fac5c995..69c8b1f0 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -65,15 +65,15 @@ class ModelSamplingDiscrete(torch.nn.Module): def timestep(self, sigma): log_sigma = sigma.log() dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape) + return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device) def sigma(self, timestep): - t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1)) + t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1)) low_idx = t.floor().long() high_idx = t.ceil().long() w = t.frac() log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] - return log_sigma.exp() + return log_sigma.exp().to(timestep.device) def percent_to_sigma(self, percent): if percent <= 0.0: diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 20261aad..efcdf193 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -56,15 +56,15 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module): def timestep(self, sigma): log_sigma = sigma.log() dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) + return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device) def sigma(self, timestep): - t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) + t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) low_idx = t.floor().long() high_idx = t.ceil().long() w = t.frac() log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] - return log_sigma.exp() + return log_sigma.exp().to(timestep.device) def percent_to_sigma(self, percent): if percent <= 0.0: