Add a way to set the timestep multiplier in the flow sampling.

This commit is contained in:
comfyanonymous 2024-07-06 04:06:03 -04:00
parent ff63893d10
commit 2dc84d1444

View File

@ -190,11 +190,12 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.0))
self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000))
def set_parameters(self, shift=1.0, timesteps=1000):
def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
self.shift = shift
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
self.multiplier = multiplier
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier)
self.register_buffer('sigmas', ts)
@property
@ -206,10 +207,10 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * 1000
return sigma * self.multiplier
def sigma(self, timestep):
return time_snr_shift(self.shift, timestep / 1000)
return time_snr_shift(self.shift, timestep / self.multiplier)
def percent_to_sigma(self, percent):
if percent <= 0.0: