diff --git a/comfy/cldm/mmdit.py b/comfy/cldm/mmdit.py
new file mode 100644
index 00000000..6e72474c
--- /dev/null
+++ b/comfy/cldm/mmdit.py
@@ -0,0 +1,91 @@
+import torch
+from typing import Dict, Optional
+import comfy.ldm.modules.diffusionmodules.mmdit
+import comfy.latent_formats
+
+class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
+    def __init__(
+        self,
+        num_blocks = None,
+        dtype = None,
+        device = None,
+        operations = None,
+        **kwargs,
+    ):
+        super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
+        # controlnet_blocks
+        self.controlnet_blocks = torch.nn.ModuleList([])
+        for _ in range(len(self.joint_blocks)):
+            self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
+
+        self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
+            None,
+            self.patch_size,
+            self.in_channels,
+            self.hidden_size,
+            bias=True,
+            strict_img_size=False,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+
+        self.latent_format = comfy.latent_formats.SD3()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        context: Optional[torch.Tensor] = None,
+        hint = None,
+    ) -> torch.Tensor:
+
+        #weird sd3 controlnet specific stuff
+        hint = hint * self.latent_format.scale_factor # self.latent_format.process_in(hint)
+        y = torch.zeros_like(y)
+
+
+        if self.context_processor is not None:
+            context = self.context_processor(context)
+
+        hw = x.shape[-2:]
+        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
+        x += self.pos_embed_input(hint)
+
+        c = self.t_embedder(timesteps, dtype=x.dtype)
+        if y is not None and self.y_embedder is not None:
+            y = self.y_embedder(y)
+            c = c + y
+
+        if context is not None:
+            context = self.context_embedder(context)
+
+        if self.register_length > 0:
+            context = torch.cat(
+                (
+                    repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
+                    default(context, torch.Tensor([]).type_as(x)),
+                ),
+                1,
+            )
+
+        output = []
+
+        blocks = len(self.joint_blocks)
+        for i in range(blocks):
+            context, x = self.joint_blocks[i](
+                context,
+                x,
+                c=c,
+                use_checkpoint=self.use_checkpoint,
+            )
+
+            out = self.controlnet_blocks[i](x)
+            count = self.depth // blocks
+            if i == blocks - 1:
+                count -= 1
+            for j in range(count):
+                output.append(out)
+
+        return {"output": output}
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index f50df683..9202c319 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -11,6 +11,7 @@ import comfy.ops
 import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
 import comfy.ldm.cascade.controlnet
+import comfy.cldm.mmdit
 
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -94,13 +95,17 @@ class ControlBase:
 
         for key in control:
             control_output = control[key]
+            applied_to = set()
             for i in range(len(control_output)):
                 x = control_output[i]
                 if x is not None:
                     if self.global_average_pooling:
                         x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
 
-                    x *= self.strength
+                    if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
+                        applied_to.add(x)
+                        x *= self.strength
+
                     if x.dtype != output_dtype:
                         x = x.to(output_dtype)
 
@@ -120,17 +125,18 @@ class ControlBase:
                             if o[i].shape[0] < prev_val.shape[0]:
                                 o[i] = prev_val + o[i]
                             else:
-                                o[i] += prev_val
+                                o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
         return out
 
 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
         if control_model is not None:
             self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
 
+        self.compression_ratio = compression_ratio
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
         self.manual_cast_dtype = manual_cast_dtype
@@ -308,6 +314,37 @@ class ControlLora(ControlNet):
     def inference_memory_requirements(self, dtype):
         return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
 
+def load_controlnet_mmdit(sd):
+    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
+    model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
+    num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    supported_inference_dtypes = model_config.supported_inference_dtypes
+
+    controlnet_config = model_config.unet_config
+    unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
+    load_device = comfy.model_management.get_torch_device()
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+    if manual_cast_dtype is not None:
+        operations = comfy.ops.manual_cast
+    else:
+        operations = comfy.ops.disable_weight_init
+
+    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
+    missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
+
+    if len(missing) > 0:
+        logging.warning("missing controlnet keys: {}".format(missing))
+
+    if len(unexpected) > 0:
+        logging.debug("unexpected controlnet keys: {}".format(unexpected))
+
+    control = ControlNet(control_model, compression_ratio=1, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+    return control
+
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
     if "lora_controlnet" in controlnet_data:
@@ -360,6 +397,8 @@ def load_controlnet(ckpt_path, model=None):
         if len(leftover_keys) > 0:
             logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
+    elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
+        return load_controlnet_mmdit(controlnet_data)
 
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index 20d3a321..92745153 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -745,6 +745,8 @@ class MMDiT(nn.Module):
         qkv_bias: bool = True,
         context_processor_layers = None,
         context_size = 4096,
+        num_blocks = None,
+        final_layer = True,
         dtype = None, #TODO
         device = None,
         operations = None,
@@ -766,7 +768,10 @@ class MMDiT(nn.Module):
         # apply magic --> this defines a head_size of 64
         self.hidden_size = 64 * depth
         num_heads = depth
+        if num_blocks is None:
+            num_blocks = depth
 
+        self.depth = depth
         self.num_heads = num_heads
 
         self.x_embedder = PatchEmbed(
@@ -821,7 +826,7 @@ class MMDiT(nn.Module):
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
                     attn_mode=attn_mode,
-                    pre_only=i == depth - 1,
+                    pre_only=(i == num_blocks - 1) and final_layer,
                     rmsnorm=rmsnorm,
                     scale_mod_only=scale_mod_only,
                     swiglu=swiglu,
@@ -830,11 +835,12 @@ class MMDiT(nn.Module):
                     device=device,
                     operations=operations
                 )
-                for i in range(depth)
+                for i in range(num_blocks)
             ]
         )
 
-        self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
+        if final_layer:
+            self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
 
         if compile_core:
             assert False
@@ -893,6 +899,7 @@ class MMDiT(nn.Module):
         x: torch.Tensor,
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
+        control = None,
     ) -> torch.Tensor:
         if self.register_length > 0:
             context = torch.cat(
@@ -905,13 +912,20 @@ class MMDiT(nn.Module):
 
         # context is B, L', D
         # x is B, L, D
-        for block in self.joint_blocks:
-            context, x = block(
+        blocks = len(self.joint_blocks)
+        for i in range(blocks):
+            context, x = self.joint_blocks[i](
                 context,
                 x,
                 c=c_mod,
                 use_checkpoint=self.use_checkpoint,
             )
+            if control is not None:
+                control_o = control.get("output")
+                if i < len(control_o):
+                    add = control_o[i]
+                    if add is not None:
+                        x += add
 
         x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)
         return x
@@ -922,6 +936,7 @@ class MMDiT(nn.Module):
         t: torch.Tensor,
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
+        control = None,
     ) -> torch.Tensor:
         """
         Forward pass of DiT.
@@ -943,7 +958,7 @@ class MMDiT(nn.Module):
         if context is not None:
             context = self.context_embedder(context)
 
-        x = self.forward_core_with_concat(x, c, context)
+        x = self.forward_core_with_concat(x, c, context, control)
 
         x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
         return x[:,:,:hw[-2],:hw[-1]]
@@ -956,7 +971,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
         timesteps: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
+        control = None,
         **kwargs,
     ) -> torch.Tensor:
-        return super().forward(x, timesteps, context=context, y=y)
+        return super().forward(x, timesteps, context=context, y=y, control=control)
 
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index e09dd381..0b678480 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -41,7 +41,9 @@ def detect_unet_config(state_dict, key_prefix):
         unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
         patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
         unet_config["patch_size"] = patch_size
-        unet_config["out_channels"] = state_dict['{}final_layer.linear.weight'.format(key_prefix)].shape[0] // (patch_size * patch_size)
+        final_layer = '{}final_layer.linear.weight'.format(key_prefix)
+        if final_layer in state_dict:
+            unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)
 
         unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
         unet_config["input_size"] = None
@@ -435,10 +437,11 @@ def model_config_from_diffusers_unet(state_dict):
     return None
 
 def convert_diffusers_mmdit(state_dict, output_prefix=""):
-    depth = count_blocks(state_dict, 'transformer_blocks.{}.')
-    if depth > 0:
+    num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
+    if num_blocks > 0:
+        depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
         out_sd = {}
-        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth}, output_prefix=output_prefix)
+        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
         for k in sd_map:
             weight = state_dict.get(k, None)
             if weight is not None:
diff --git a/comfy/utils.py b/comfy/utils.py
index ed6c58a6..48618e07 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -298,7 +298,8 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
     key_map = {}
 
     depth = mmdit_config.get("depth", 0)
-    for i in range(depth):
+    num_blocks = mmdit_config.get("num_blocks", depth)
+    for i in range(num_blocks):
         block_from = "transformer_blocks.{}".format(i)
         block_to = "{}joint_blocks.{}".format(output_prefix, i)