import re
import hashlib
from typing import Any, List, Optional, Tuple

import torch


# A line that consists solely of a "#shot N" header, optionally followed by an
# ASCII or full-width colon; any other content on the line prevents a match.
_HASH_SHOT = re.compile(
    r"^\s*#\s*shot\s*(\d+)\s*[:：]?\s*$",
    flags=re.I,
)
# A "[shot N]" or bare "shot N" prefix at the start of a line. The pattern is not
# anchored at the end: the match stops after the optional colon and any trailing
# whitespace, so text after the match on the same line is not consumed.
_TAG_SHOT = re.compile(
    r"^\s*\[?\s*shot\s*(\d+)\s*\]?\s*[:：]?\s*",
    flags=re.I,
)


def _txt(x: Any) -> str:
    """Coerce an arbitrary node input into a single string.

    ``None`` becomes ``""``. A string is returned unchanged. A one-element
    list/tuple holding a string yields that string. Any other list/tuple has its
    non-``None`` items stringified and joined with newlines (an empty sequence
    gives ``""``). Everything else goes through ``str()``.
    """
    if x is None:
        return ""
    if isinstance(x, str):
        return x
    if not isinstance(x, (list, tuple)):
        return str(x)
    # Fast path for the common "single string wrapped in a list" input.
    if len(x) == 1 and isinstance(x[0], str):
        return x[0]
    return "\n".join(str(item) for item in x if item is not None)


def _sha12(pos: List[str], neg: List[str]) -> str:
    """Return the first 12 hex chars of a SHA-1 over both prompt lists.

    Each entry (``None`` treated as ``""``) is UTF-8 encoded with undecodable
    characters ignored and terminated by a NUL byte. A fixed ``--NEG--``
    separator sits between the positive and negative entries, so moving an
    entry from one list to the other changes the digest.
    """
    digest = hashlib.sha1()

    def feed(entries: List[str]) -> None:
        for entry in entries:
            # update(a + b) is equivalent to update(a); update(b).
            digest.update((entry or "").encode("utf-8", errors="ignore") + b"\0")

    feed(pos)
    digest.update(b"\n--NEG--\n")
    feed(neg)
    return digest.hexdigest()[:12]


def _clean(block: str, joiner: str, strip_lines: bool, drop_empty: bool) -> str:
    """Normalize one shot's raw text into a single prompt string.

    The block is split into lines. Each line is optionally stripped
    (``strip_lines``), empty lines are optionally dropped (``drop_empty``), and
    the remaining lines are joined with ``joiner``.

    ``joiner`` semantics: ``None`` means a single space; ``""`` means keep
    newlines. The joiner is applied even when both flags are off. Previously that
    case returned the raw block unjoined, which contradicted the joiner's
    documented purpose. The result is stripped of surrounding whitespace.

    Returns ``""`` for an empty or ``None`` block.
    """
    if not block:
        return ""

    sep = " " if joiner is None else joiner
    if sep == "":
        # An empty joiner is the documented way to keep the original line breaks.
        sep = "\n"

    kept: List[str] = []
    for line in block.splitlines():
        if strip_lines:
            line = line.strip()
        if drop_empty and not line:
            continue
        kept.append(line)

    return sep.join(kept).strip()


def _split_marker(text: str, marker: str, drop_header: bool) -> List[str]:
    """Split ``text`` on a literal ``marker`` string.

    If ``text`` is empty, ``marker`` is empty, or ``marker`` does not occur in
    ``text``, the whole stripped text is returned as a single shot (or ``[]`` if
    it is blank). Otherwise every piece is stripped and blank pieces are removed.
    When ``drop_header`` is true, the piece before the first marker is discarded.
    """
    has_marker = bool(text) and bool(marker) and marker in text
    if not has_marker:
        whole = text.strip() if text else ""
        return [whole] if whole else []

    # With the marker present, split() always yields at least two pieces.
    start = 1 if drop_header else 0
    stripped = (piece.strip() for piece in text.split(marker)[start:])
    return [piece for piece in stripped if piece]


def _split_by_headers(text: str, header: "re.Pattern") -> List[str]:
    """Split ``text`` into shots at every line that ``header.match``-es.

    Lines before the first header line are discarded. Text that follows the
    matched header on the same line (e.g. ``"[shot 1] a cat"``) becomes the
    first line of that shot instead of being lost. Each shot's lines are joined
    with newlines and stripped, and empty shots are dropped.
    """
    if not text:
        return []

    shots: List[List[str]] = []
    for line in text.splitlines():
        m = header.match(line)
        if m:
            rest = line[m.end():]
            shots.append([rest] if rest.strip() else [])
            continue
        if shots:  # ignore preamble lines before the first header
            shots[-1].append(line)

    out: List[str] = []
    for chunk in shots:
        s = "\n".join(chunk).strip()
        if s:
            out.append(s)
    return out


def _split_hash(text: str) -> List[str]:
    """Split on full-line ``#shot N`` headers (see ``_HASH_SHOT``)."""
    return _split_by_headers(text, _HASH_SHOT)


def _split_tag(text: str) -> List[str]:
    """Split on ``[shot N]`` / ``shot N`` line prefixes (see ``_TAG_SHOT``).

    Content written on the same line as the tag is kept as part of the shot.
    """
    return _split_by_headers(text, _TAG_SHOT)


def _split_auto(text: str, marker: str, drop_header: bool) -> List[str]:
    """Pick a splitting strategy from the content of ``text``.

    Priority:
    1. the literal ``marker``, if it occurs in the text;
    2. full-line ``#shot N`` headers;
    3. ``[shot N]`` / ``shot N`` line prefixes;
    4. otherwise, blank-line-separated paragraphs.

    Returns ``[]`` for empty or blank text. ``drop_header`` only affects the
    marker strategy.
    """
    if not text or not text.strip():
        return []
    if marker and marker in text:
        return _split_marker(text, marker, drop_header)

    # Test each line separately. Both header regexes are anchored with ``^`` and
    # compiled without re.MULTILINE, so a whole-text ``search`` only inspects the
    # very start of the text. The ``$``-anchored ``#shot`` pattern would also need
    # nothing else to follow the header, so real storyboards were never detected.
    lines = text.splitlines()
    if any(_HASH_SHOT.match(ln) for ln in lines):
        return _split_hash(text)
    if any(_TAG_SHOT.match(ln) for ln in lines):
        return _split_tag(text)

    return [c.strip() for c in re.split(r"\n\s*\n+", text) if c.strip()]


def _neg_fill(neg: List[str], n: int) -> List[str]:
    """Return exactly ``n`` negative prompts aligned with the positive shots.

    - ``n <= 0`` gives ``[]``.
    - No negatives gives ``n`` empty strings.
    - Too many negatives are truncated to ``n``.
    - Too few are padded with copies of the first negative, which also covers
      the single-negative case.

    Always returns a new list.
    """
    if n <= 0:
        return []
    if not neg:
        return [""] * n
    if len(neg) >= n:
        return neg[:n]
    shortfall = n - len(neg)
    return neg + [neg[0]] * shortfall


def _pad_len(x: torch.Tensor, target_len: int) -> torch.Tensor:
    """Force dimension 1 of a 3-D tensor to ``target_len``.

    Longer tensors are truncated. Shorter ones are right-padded with zeros of
    the same dtype and device. ``None`` and tensors that are not 3-D are returned
    untouched, as is a tensor that already has the target length (the same
    object is returned).
    """
    if x is None or x.dim() != 3:
        return x

    batch, length, channels = x.shape
    if length >= target_len:
        return x if length == target_len else x[:, :target_len, :]

    filler = torch.zeros(
        (batch, target_len - length, channels),
        dtype=x.dtype,
        device=x.device,
    )
    return torch.cat([x, filler], dim=1)


def _pooled_from(cond: torch.Tensor, mode: str) -> torch.Tensor:
    """Synthesize a ``[B, C]`` pooled tensor from a ``[B, L, C]`` conditioning.

    Modes:
    - ``"first"``: take the first position.
    - ``"zeros"``: return zeros of the same dtype and device.
    - any other value: average over dimension 1.

    Raises:
        RuntimeError: if ``cond`` is ``None`` or not 3-D.
    """
    if cond is None or cond.dim() != 3:
        raise RuntimeError("pooled_fallback 需要 cond 为 [B,L,C] 张量")

    if mode == "zeros":
        return torch.zeros(
            (cond.shape[0], cond.shape[2]),
            device=cond.device,
            dtype=cond.dtype,
        )
    if mode == "first":
        return cond[:, 0, :]
    return cond.mean(dim=1)


def _encode_one(clip: Any, text: str) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Encode one prompt with ``clip`` and return ``(cond, pooled_or_None)``.

    Preferred path: when the object has both ``tokenize`` and
    ``encode_from_tokens``, tokenize and request the pooled output via
    ``return_pooled=True``. If that raises ``TypeError`` (for example, an
    implementation without the keyword), retry without it and take the first two
    items of a returned tuple, or the bare result with ``None`` pooled.

    Fallback path: ``clip.encode(text)``. A tuple of two or more items gives its
    first two items; a tensor gives ``(tensor, None)``.

    Raises:
        RuntimeError: if ``clip`` is ``None``, or no supported interface
            produced a usable result.
    """
    if clip is None:
        raise RuntimeError("clip 为空")

    prompt = text or ""

    if hasattr(clip, "tokenize") and hasattr(clip, "encode_from_tokens"):
        tokens = clip.tokenize(prompt)
        try:
            cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        except TypeError:
            result = clip.encode_from_tokens(tokens)
        else:
            return cond, pooled
        if isinstance(result, tuple) and len(result) >= 2:
            return result[0], result[1]
        return result, None

    if hasattr(clip, "encode"):
        result = clip.encode(prompt)
        if isinstance(result, tuple) and len(result) >= 2:
            return result[0], result[1]
        if torch.is_tensor(result):
            return result, None

    raise RuntimeError("不支持的 clip：缺少 tokenize/encode_from_tokens 或 encode")


class StoryboardPromptParserBatch:
    """ComfyUI node that splits storyboard text into per-shot prompt lists.

    ``run`` returns, in order:

    - the cleaned positive prompts, one per shot;
    - the negative prompts aligned to the same length;
    - the shot count;
    - a 12-character SHA-1 fingerprint of both lists;
    - a human-readable preview string.
    """

    @classmethod
    def INPUT_TYPES(cls):
        mode_choices = ["auto(推荐)", "marker(分隔符)", "hash(井号分镜)"]
        header_choices = ["auto(推荐)", "on", "off"]
        # NOTE(review): key order presumably drives widget order in the UI; keep it stable.
        required = {
            "positive_block": ("STRING", {
                "multiline": True,
                "default": "",
                "tooltip": "正面分镜输入",
            }),
            "negative_block": ("STRING", {
                "multiline": True,
                "default": "",
                "tooltip": "负面词输入（只写一段会自动复制到所有分镜）",
            }),
            "split_mode": (mode_choices, {
                "default": mode_choices[0],
                "tooltip": "auto：marker > #shot > [shot] > 空行；marker：按 shot_delim；hash：按 #shot",
            }),
            "shot_delim": ("STRING", {
                "default": "===SHOT===",
                "tooltip": "marker 模式的分隔符；auto 模式下只有文本中出现该符号才会按它切",
            }),
            "joiner": ("STRING", {"default": " ", "tooltip": "同一分镜内多行合并连接符（留空保留换行）"}),
            "strip_lines": ("BOOLEAN", {"default": True, "tooltip": "每行去首尾空格"}),
            "drop_empty_lines": ("BOOLEAN", {"default": True, "tooltip": "删除空行"}),
            "strip_header": (header_choices, {
                "default": header_choices[0],
                "tooltip": "是否丢弃分镜标记前的头部说明",
            }),
            "min_shots": ("INT", {"default": 1, "min": 1, "max": 999, "step": 1, "tooltip": "最少分镜数"}),
            "preview_sep": ("STRING", {"default": "\n\n", "tooltip": "预览分隔符"}),
            "preview_prefix_index": ("BOOLEAN", {"default": True, "tooltip": "预览是否加 [SHOT n]"}),
        }
        return {"required": required}

    RETURN_TYPES = ("LIST", "LIST", "INT", "STRING", "STRING")
    RETURN_NAMES = ("pos_list", "neg_list", "count", "fingerprint", "preview_text")
    FUNCTION = "run"
    CATEGORY = "NB/Storyboard"

    def run(
        self,
        positive_block,
        negative_block,
        split_mode,
        shot_delim,
        joiner,
        strip_lines,
        drop_empty_lines,
        strip_header,
        min_shots,
        preview_sep,
        preview_prefix_index,
    ):
        """Parse both text blocks into aligned per-shot lists.

        Raises:
            RuntimeError: if the positive block is blank, splits into no
                shots, or yields fewer than ``min_shots`` non-empty shots after
                cleaning.
        """
        pos_src = _txt(positive_block)
        if not pos_src.strip():
            raise RuntimeError("正面分镜为空")
        neg_src = _txt(negative_block)

        # Only "off" keeps the header; "auto(推荐)" currently behaves exactly like "on".
        drop_header = strip_header != "off"

        def split(src: str) -> List[str]:
            # Dispatch on the combo label's prefix; any other label means auto.
            if split_mode.startswith("marker"):
                return _split_marker(src, shot_delim, drop_header)
            if split_mode.startswith("hash"):
                return _split_hash(src)
            return _split_auto(src, shot_delim, drop_header)

        pos_raw = split(pos_src)
        neg_raw = split(neg_src) if neg_src.strip() else []

        if not pos_raw:
            raise RuntimeError("没有解析到任何分镜，请检查分隔方式/标记")

        def clean(chunk: str) -> str:
            return _clean(chunk, joiner, strip_lines, drop_empty_lines)

        # Positive shots that clean down to nothing are discarded before counting.
        pos_list: List[str] = [shot for shot in map(clean, pos_raw) if shot]
        if len(pos_list) < int(min_shots):
            raise RuntimeError(f"解析分镜数量={len(pos_list)} 小于 min_shots={min_shots}")

        # Negatives keep empty entries; they are then aligned to the shot count.
        neg_list = _neg_fill([clean(chunk) for chunk in neg_raw], len(pos_list))

        sep = "\n\n" if preview_sep is None else preview_sep
        rendered = [
            f"[SHOT {num}] {shot}" if preview_prefix_index else shot
            for num, shot in enumerate(pos_list, start=1)
        ]
        count = len(pos_list)
        return (pos_list, neg_list, count, _sha12(pos_list, neg_list), sep.join(rendered))


class StoryboardPromptRouter:
    """ComfyUI node that passes a list through and picks one entry by index.

    The index is clamped into ``[0, len - 1]``. The clamped index is returned
    alongside the list and the chosen string.
    """

    @classmethod
    def INPUT_TYPES(cls):
        list_spec = ("LIST", {"tooltip": "输入：list[str]"})
        index_spec = ("INT", {"default": 0, "min": 0, "max": 9999, "step": 1, "tooltip": "索引（0开始）"})
        return {"required": {"list_in": list_spec, "index": index_spec}}

    RETURN_TYPES = ("LIST", "STRING", "INT")
    RETURN_NAMES = ("batch_text", "single_text", "index")
    FUNCTION = "run"
    CATEGORY = "NB/Storyboard"

    def run(self, list_in, index):
        """Return ``(list, item_at_clamped_index, clamped_index)``.

        A non-list input is wrapped as a one-element list. For an empty list,
        ``([], "", int(index))`` is returned without clamping.
        """
        items = list_in if isinstance(list_in, list) else [_txt(list_in)]
        if not items:
            return ([], "", int(index))

        idx = min(max(int(index), 0), len(items) - 1)
        picked = items[idx]
        if not isinstance(picked, str):
            picked = _txt(picked)
        return (items, picked, idx)


class StoryboardFanOut16:
    """ComfyUI node that spreads the first 16 list entries across 16 STRING outputs.

    Missing slots are filled with ``""``; entries beyond the 16th are ignored.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "list_in": ("LIST", {"tooltip": "输入：list[str]（不足16段会输出空字符串）"}),
            }
        }

    # ("STRING", ...) x16 and ("shot_01", ..., "shot_16").
    RETURN_TYPES = ("STRING",) * 16
    RETURN_NAMES = tuple(f"shot_{n:02d}" for n in range(1, 17))
    FUNCTION = "run"
    CATEGORY = "NB/Storyboard"

    def run(self, list_in):
        """Return a 16-tuple of strings taken from ``list_in`` (padded with ``""``)."""
        items = list_in if isinstance(list_in, list) else [_txt(list_in)]
        texts = [x if isinstance(x, str) else _txt(x) for x in items]
        shortfall = 16 - len(texts)
        if shortfall > 0:
            texts.extend([""] * shortfall)
        return tuple(texts[:16])


class StoryboardListPreviewReadable:
    """ComfyUI node that renders a list as one readable string plus its length."""

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "list_in": ("LIST", {"tooltip": "输入：list[str]"}),
            "sep": ("STRING", {"default": "\n\n", "tooltip": "分隔符"}),
            "prefix_index": ("BOOLEAN", {"default": True, "tooltip": "是否加 [n] 前缀"}),
        }
        return {"required": required}

    RETURN_TYPES = ("STRING", "INT")
    RETURN_NAMES = ("text", "count")
    FUNCTION = "run"
    CATEGORY = "NB/Storyboard"

    def run(self, list_in, sep, prefix_index):
        """Join the entries with ``sep`` (an empty or None ``sep`` means a blank line).

        Entries are numbered from 1 as ``[n] text`` when ``prefix_index`` is set.
        Returns ``(text, number_of_entries)``.
        """
        items = list_in if isinstance(list_in, list) else [_txt(list_in)]
        joiner = sep if sep else "\n\n"

        def render(num, item):
            text = item if isinstance(item, str) else _txt(item)
            return f"[{num}] {text}" if prefix_index else text

        body = joiner.join(render(num, item) for num, item in enumerate(items, start=1))
        return (body, len(items))


class CLIPTextEncodeBatchList:
    """ComfyUI node that encodes a list of prompts into one batched conditioning.

    Each prompt is encoded separately. Conds are zero-padded along dimension 1
    to the longest prompt and concatenated along dimension 0. Pooled outputs are
    concatenated alongside them; a prompt with no pooled output gets one
    synthesized according to ``pooled_fallback``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "clip": ("CLIP", {}),
                "text": ("LIST", {"tooltip": "输入：list[str]"}),
                "pooled_fallback": (["mean", "first", "zeros"], {
                    "default": "mean",
                    "tooltip": "pooled 缺失时的兜底方式"
                }),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("conditioning",)
    FUNCTION = "encode"
    CATEGORY = "NB/Storyboard"

    def encode(self, clip, text, pooled_fallback):
        """Return ``([[cond_batch, {"pooled_output": pooled_batch}]],)``.

        Raises:
            RuntimeError: if ``clip`` is ``None``, ``text`` is empty, or a prompt
                yields a missing cond or a cond that is not 2-D/3-D.
        """
        if clip is None:
            raise RuntimeError("clip 为空")

        if not isinstance(text, list):
            text = [_txt(text)]
        if not text:
            raise RuntimeError("text 为空（Batch List 没有任何内容）")

        texts = [_txt(t).strip() for t in text]

        conds: List[torch.Tensor] = []
        pools: List[torch.Tensor] = []

        for i, t in enumerate(texts):
            cond, pooled = _encode_one(clip, t)
            if cond is None or not torch.is_tensor(cond):
                raise RuntimeError(f"第 {i} 条 prompt 编码失败：cond 为空")

            if cond.dim() == 2:
                cond = cond.unsqueeze(0)
            if cond.dim() != 3:
                raise RuntimeError(f"cond 形状异常：{tuple(cond.shape)}（期望 [B,L,C]）")

            if pooled is None or not torch.is_tensor(pooled):
                # Build the fallback from the *unpadded* cond. Computing it after
                # zero-padding would pull "mean" toward zero for shorter prompts.
                pooled = _pooled_from(cond, pooled_fallback)
            elif pooled.dim() == 1:
                pooled = pooled.unsqueeze(0)

            conds.append(cond)
            pools.append(pooled)

        # Zero-pad every cond to the longest sequence so they can be stacked.
        max_len = max(c.shape[1] for c in conds)
        cond_cat = torch.cat([_pad_len(c, max_len) for c in conds], dim=0)
        pooled_cat = torch.cat(pools, dim=0)
        return ([[cond_cat, {"pooled_output": pooled_cat}]],)


# Node registry: maps each registered node id string to its implementing class.
# Note that the id "StoryboardPromptParser" maps to the class StoryboardPromptParserBatch,
# and "StoryboardListPreview" maps to StoryboardListPreviewReadable.
# NOTE(review): these ids are presumably what saved workflows reference, so renaming
# a key would likely break existing graphs; confirm before changing.
# NOTE(review): kept as a plain dict literal on purpose, since external tooling may
# scan this assignment textually; verify before restructuring.
NODE_CLASS_MAPPINGS = {
    "StoryboardPromptParser": StoryboardPromptParserBatch,
    "StoryboardPromptRouter": StoryboardPromptRouter,
    "StoryboardFanOut16": StoryboardFanOut16,
    "StoryboardListPreview": StoryboardListPreviewReadable,
    "CLIPTextEncodeBatchList": CLIPTextEncodeBatchList,
}

# Human-readable titles, keyed by the same node ids as NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {
    "StoryboardPromptParser": "Storyboard Parser (分镜解析/预览/指纹)",
    "StoryboardPromptRouter": "Storyboard Router (按索引取一段)",
    "StoryboardFanOut16": "Storyboard Fan-Out x16 (扇出16口)",
    "StoryboardListPreview": "Storyboard Preview (列表可读预览)",
    "CLIPTextEncodeBatchList": "CLIP Encode (批量List)",
}
