diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 6bd3a5d79..2d95a83dc 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -190,11 +190,12 @@ class ModelSamplingDiscreteFlow(torch.nn.Module): else: sampling_settings = {} - self.set_parameters(shift=sampling_settings.get("shift", 1.0)) + self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000)) - def set_parameters(self, shift=1.0, timesteps=1000): + def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000): self.shift = shift - ts = self.sigma(torch.arange(1, timesteps + 1, 1)) + self.multiplier = multiplier + ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier) self.register_buffer('sigmas', ts) @property @@ -206,10 +207,10 @@ class ModelSamplingDiscreteFlow(torch.nn.Module): return self.sigmas[-1] def timestep(self, sigma): - return sigma * 1000 + return sigma * self.multiplier def sigma(self, timestep): - return time_snr_shift(self.shift, timestep / 1000) + return time_snr_shift(self.shift, timestep / self.multiplier) def percent_to_sigma(self, percent): if percent <= 0.0: