import sd1_clip
import open_clip
import torch

class SD2ClipModel(torch.nn.Module, sd1_clip.ClipTokenWeightEncoder):
    """
    Uses the OpenCLIP transformer encoder for text.

    Builds an OpenCLIP model, deletes its visual tower, and encodes token ids
    by running the text transformer up to a selectable block. The result is
    passed through the model's final layer norm (``ln_final``).
    """
    LAYERS = [
        #"pooled",
        "last",
        "penultimate",
        "hidden"
    ]
    #version="laion2b_s32b_b79k"
    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77,
                 freeze=True, layer="penultimate", layer_idx=None):
        """
        Args:
            arch: OpenCLIP architecture name passed to
                ``open_clip.create_model_and_transforms``.
            device: device that token tensors are moved to in ``forward``.
                The model itself is created on CPU here; moving it elsewhere
                is left to the caller.
            max_length: token sequence length; also sizes ``empty_tokens``.
            freeze: if True, put the model in eval mode and disable gradients.
            layer: one of ``LAYERS``; "last" runs every block, "penultimate"
                skips the final block, "hidden" uses ``layer_idx``.
            layer_idx: required when ``layer == "hidden"``; interpreted as in
                ``clip_layer`` and must satisfy ``abs(layer_idx) < num_layers``.

        Raises:
            AssertionError: on an unknown ``layer`` or an invalid ``layer_idx``.
        """
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'))
        del model.visual  # only the text tower is used
        self.model = model

        self.device = device
        self.max_length = max_length
        # Prompt made of the two special ids 49406/49407 (presumably CLIP's
        # start/end-of-text tokens) followed by zero padding up to max_length.
        # Previously hard-coded to 75 zeros, i.e. only correct for max_length=77.
        self.empty_tokens = [[49406] + [49407] + [0] * (max_length - 2)]
        if freeze:
            self.freeze()
        self.layer = layer
        # self.layer_idx counts how many trailing transformer blocks to skip.
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        elif self.layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.clip_layer(layer_idx)
        else:
            raise NotImplementedError()

    @property
    def num_layers(self):
        """Number of residual blocks in the text transformer.

        Read from the model instead of hard-coding 24 (the value for the
        default ViT-H-14), so non-default ``arch`` values pick the right block.
        """
        return len(self.model.transformer.resblocks)

    def freeze(self):
        """Switch the model to eval mode and stop gradients for all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def clip_layer(self, layer_idx):
        """Choose which transformer block's output ``encode`` returns.

        layer_idx should have the same logic as the one for SD1:
        negative values count from the end (-1 = last block, -2 = penultimate),
        and non-negative values select block ``layer_idx`` counting from 0.
        Values with ``abs(layer_idx) >= num_layers`` fall back to the last block.
        """
        num_layers = self.num_layers
        if abs(layer_idx) >= num_layers:
            self.layer_idx = 0
        elif layer_idx < 0:
            self.layer_idx = -(layer_idx + 1)
        else:
            # Running blocks 0..layer_idx means skipping the remaining ones.
            self.layer_idx = num_layers - (layer_idx + 1)

    def forward(self, tokens):
        """Encode a batch of token-id sequences.

        ``tokens`` is converted with ``torch.LongTensor`` and moved to
        ``self.device``. Returns the ``ln_final``-normalized hidden states in
        [batch_size, n_ctx, d_model] layout.
        """
        tokens = torch.LongTensor(tokens).to(self.device)
        z = self.encode_with_transformer(tokens)
        return z

    def encode_with_transformer(self, tokens):
        """Embed ``tokens``, run the truncated transformer, apply ``ln_final``."""
        x = self.model.token_embedding(tokens)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
        """Run the text transformer's residual blocks, skipping the last
        ``self.layer_idx`` of them.

        Mirrors the transformer's own forward loop (including optional
        gradient checkpointing) but stops early.
        """
        resblocks = self.model.transformer.resblocks
        stop = len(resblocks) - self.layer_idx
        use_checkpoint = self.model.transformer.grad_checkpointing and not torch.jit.is_scripting()
        if use_checkpoint:
            # `checkpoint` was previously used without being imported, which
            # raised NameError whenever gradient checkpointing was enabled.
            from torch.utils.checkpoint import checkpoint
        for i, r in enumerate(resblocks):
            if i == stop:
                break
            if use_checkpoint:
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, tokens):
        """Alias for ``self(tokens)``."""
        return self(tokens)



class SD2Tokenizer(sd1_clip.SD1Tokenizer):
    """Tokenizer for the SD2 text encoder.

    Identical to ``sd1_clip.SD1Tokenizer`` except that it is always
    constructed with ``pad_with_end=False``.
    """

    def __init__(self, tokenizer_path=None):
        # NOTE(review): presumably this makes padding use a non-end token,
        # matching SD2ClipModel.empty_tokens (which pads with 0); confirm
        # against SD1Tokenizer.
        super(SD2Tokenizer, self).__init__(tokenizer_path, pad_with_end=False)