mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
Add Flux model support for InstantX style controlnet residuals (#4444)
* Add Flux model support for InstantX style controlnet residuals * Refactor Flux controlnet residual step to a separate method * Rollback minor change * New format for applying controlnet residuals: input->double_blocks, output->single_blocks * Adjust XLabs Flux controlnet to fit new syntax of applying Flux controlnet residuals * Remove unnecessary import and minor style change
This commit is contained in:
parent
310ad09258
commit
e68763f40c
@ -82,7 +82,7 @@ class ControlNetFlux(Flux):
|
||||
block_res_sample = controlnet_block(block_res_sample)
|
||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
||||
|
||||
return {"output": (controlnet_block_res_samples * 10)[:19]}
|
||||
return {"input": (controlnet_block_res_samples * 10)[:19]}
|
||||
|
||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||
hint = hint * 2.0 - 1.0
|
||||
|
@ -114,19 +114,28 @@ class Flux(nn.Module):
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for i in range(len(self.double_blocks)):
|
||||
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
Loading…
Reference in New Issue
Block a user