mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 14:09:36 +00:00
Add a way to set the timestep multiplier in the flow sampling.
This commit is contained in:
parent
ff63893d10
commit
2dc84d1444
@ -190,11 +190,12 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
else:
|
||||
sampling_settings = {}
|
||||
|
||||
self.set_parameters(shift=sampling_settings.get("shift", 1.0))
|
||||
self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000))
|
||||
|
||||
def set_parameters(self, shift=1.0, timesteps=1000):
|
||||
def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
|
||||
self.shift = shift
|
||||
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||
self.multiplier = multiplier
|
||||
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier)
|
||||
self.register_buffer('sigmas', ts)
|
||||
|
||||
@property
|
||||
@ -206,10 +207,10 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return sigma * 1000
|
||||
return sigma * self.multiplier
|
||||
|
||||
def sigma(self, timestep):
|
||||
return time_snr_shift(self.shift, timestep / 1000)
|
||||
return time_snr_shift(self.shift, timestep / self.multiplier)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
|
Loading…
Reference in New Issue
Block a user