# -*- coding: utf-8 -*-
"""
Seed-VC 模型下载器
预下载所有需要的模型文件
"""

import os
from pathlib import Path
from typing import Optional, Callable

# Point Hugging Face Hub downloads at a mirror. setdefault() leaves an
# HF_ENDPOINT value the user already exported untouched. This runs at module
# import time, i.e. before download_models() imports huggingface_hub.
# NOTE(review): if huggingface_hub was imported elsewhere in the process
# before this module, it may already have read HF_ENDPOINT — confirm order.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

# Local model directory: <grandparent of this file's directory>/models/seed_vc.
MODELS_DIR = Path(__file__).parent.parent.parent / "models" / "seed_vc"

# Models required by Seed-VC, keyed by a short name. Dict order is the order
# in which download_models() processes them. Keys per entry:
#   repo      - Hugging Face repo id (hub-based entries only)
#   filename  - file name in the repo and under MODELS_DIR; for folder entries,
#               the sub-directory of MODELS_DIR the whole repo is placed in
#   is_folder - fetch the whole repo (snapshot) instead of a single file
#   url       - direct download URL; checked before repo/is_folder
#   size_mb   - approximate size, used in this file only for status/log output
REQUIRED_MODELS = {
    "DiT": {
        "repo": "Plachta/Seed-VC",
        "filename": "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
        "size_mb": 800,
    },
    "config": {
        "repo": "Plachta/Seed-VC", 
        "filename": "config_dit_mel_seed_uvit_whisper_small_wavenet.yml",
        "size_mb": 1,
    },
    "bigvgan": {
        "repo": "nvidia/bigvgan_v2_22khz_80band_256x",
        "filename": "bigvgan_v2_22khz_80band_256x",
        "is_folder": True,
        "size_mb": 150,
    },
    "whisper": {
        "repo": "openai/whisper-small",
        "filename": "whisper-small",
        "is_folder": True,
        "size_mb": 500,
    },
    "campplus": {
        # Direct URL; the host in this string is huggingface.co, not the
        # HF_ENDPOINT mirror set above.
        "url": "https://huggingface.co/Plachta/Seed-VC/resolve/main/campplus_cn_common.bin",
        "filename": "campplus_cn_common.bin",
        "size_mb": 27,
    }
}


def check_models() -> dict:
    """Report which entries of ``REQUIRED_MODELS`` are present in ``MODELS_DIR``.

    A single-file model counts as present only when a regular file exists at
    ``MODELS_DIR / filename``. A folder model counts as present only when that
    path is a directory containing at least one entry whose name does not
    start with ``"."``. Hidden entries are ignored because recent
    huggingface_hub versions keep download metadata (and in-progress files)
    in a hidden ``.cache`` folder inside ``local_dir``. Without this, an
    interrupted download would leave a non-empty directory with no model
    files in it.

    Returns:
        dict mapping each ``REQUIRED_MODELS`` key to a dict with
        ``"exists"`` (bool), ``"path"`` (str), and ``"size_mb"`` (the
        approximate size declared in ``REQUIRED_MODELS``).
    """
    status = {}

    for name, info in REQUIRED_MODELS.items():
        model_path = MODELS_DIR / info["filename"]

        if info.get("is_folder"):
            # is_dir() guards iterdir(), which raises if a plain file sits here.
            exists = model_path.is_dir() and any(
                not entry.name.startswith(".") for entry in model_path.iterdir()
            )
        else:
            # is_file() rather than exists(): a directory is not a usable model file.
            exists = model_path.is_file()

        status[name] = {
            "exists": exists,
            "path": str(model_path),
            "size_mb": info["size_mb"],
        }

    return status


def _mirror_url(url: str) -> str:
    """Rewrite a direct ``https://huggingface.co/...`` URL onto ``HF_ENDPOINT``.

    ``hf_hub_download`` and ``snapshot_download`` honour the ``HF_ENDPOINT``
    environment variable themselves, but a hand-written URL does not.
    Without this rewrite, a direct download would bypass the mirror this
    module configures. URLs on any other host are returned unchanged.
    """
    official = "https://huggingface.co"
    endpoint = os.environ.get("HF_ENDPOINT", official).rstrip("/")
    if url.startswith(official + "/"):
        return endpoint + url[len(official):]
    return url


def _download_url(url: str, dest: Path, timeout: tuple = (10, 60)) -> None:
    """Stream ``url`` into ``dest`` so that ``dest`` only ever holds a complete file.

    The body is written to a sibling ``<name>.part`` file and renamed over
    ``dest`` once the transfer has finished. On any error the partial file
    is removed and the exception re-raised. This stops an interrupted
    transfer from leaving a truncated file that a later existence check
    would mistake for a finished model.

    Args:
        url: URL to fetch.
        dest: Final file path.
        timeout: ``(connect, read)`` timeout in seconds passed to requests,
            so a stalled connection raises instead of hanging forever.

    Raises:
        requests.RequestException: on connection, timeout, or HTTP errors.
        OSError: if the file cannot be written or renamed.
    """
    import requests

    part_path = dest.with_name(dest.name + ".part")
    try:
        # Using the response as a context manager releases the connection.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(part_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        part_path.replace(dest)
    except BaseException:
        # Also covers KeyboardInterrupt, so Ctrl+C does not leave a .part behind.
        part_path.unlink(missing_ok=True)
        raise


def download_models(
    progress_callback: Optional[Callable[[str, float], None]] = None,
    force: bool = False
) -> bool:
    """Download every entry of ``REQUIRED_MODELS`` into ``MODELS_DIR``.

    Entries that ``check_models()`` reports as present are skipped unless
    ``force`` is true. Each remaining entry is fetched in one of three ways:

    * an entry with a ``"url"`` key is streamed over HTTP. huggingface.co
      URLs are redirected to ``HF_ENDPOINT``, and the file only appears at
      its final path once complete.
    * an ``"is_folder"`` entry has its whole repo fetched with
      ``snapshot_download`` into ``MODELS_DIR / filename``. An optional
      ``"allow_patterns"`` key restricts which repo files are fetched.
    * any other entry has the single ``filename`` fetched with
      ``hf_hub_download`` into ``MODELS_DIR``.

    A failure for one model is printed and does not stop the remaining ones.

    Args:
        progress_callback: Optional ``(model_name, progress_percent)``
            callback. It is called before each model with the percentage of
            models already processed, and once more with ``("完成", 100)``
            after the loop.
        force: Re-download even when a model is already present. This is
            also forwarded to the hub download functions so they do not
            short-circuit on already-downloaded files.

    Returns:
        True if every model was already present or downloaded successfully;
        False if any model's download raised.

    Raises:
        ImportError: if ``huggingface_hub`` is not installed.
    """
    from huggingface_hub import hf_hub_download, snapshot_download

    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    success = True
    total_models = len(REQUIRED_MODELS)

    for idx, (name, info) in enumerate(REQUIRED_MODELS.items()):
        try:
            if progress_callback:
                progress_callback(name, idx / total_models * 100)

            model_path = MODELS_DIR / info["filename"]

            # Reuse check_models() so "present" means exactly the same thing
            # here as in the status report.
            if not force and check_models()[name]["exists"]:
                print(f"✅ {name} 已存在，跳过")
                continue

            print(f"📥 下载 {name} ({info['size_mb']}MB)...")

            if info.get("url"):
                _download_url(_mirror_url(info["url"]), model_path)
            elif info.get("is_folder"):
                snapshot_download(
                    repo_id=info["repo"],
                    local_dir=model_path,
                    # Deprecated/ignored in recent huggingface_hub releases;
                    # kept so older versions copy real files, not symlinks.
                    local_dir_use_symlinks=False,
                    # None (the default) fetches every file in the repo.
                    allow_patterns=info.get("allow_patterns"),
                    force_download=force,
                )
            else:
                hf_hub_download(
                    repo_id=info["repo"],
                    filename=info["filename"],
                    local_dir=MODELS_DIR,
                    local_dir_use_symlinks=False,
                    force_download=force,
                )

            print(f"✅ {name} 下载完成")

        except Exception as e:
            # Best-effort per model: report the failure and keep going.
            print(f"❌ {name} 下载失败: {e}")
            success = False

    if progress_callback:
        progress_callback("完成", 100)

    return success


def get_model_paths() -> dict:
    """Return the local path of each model, as strings, keyed by role.

    Paths are derived from ``REQUIRED_MODELS`` so they cannot drift from the
    filenames the downloader actually uses. The paths are not checked for
    existence; call ``check_models()`` for that.

    Returns:
        dict with keys ``"dit_checkpoint"``, ``"dit_config"``, ``"bigvgan"``,
        ``"whisper"``, and ``"campplus"``. Each value is
        ``str(MODELS_DIR / filename)`` for the corresponding
        ``REQUIRED_MODELS`` entry.
    """
    # Output key -> REQUIRED_MODELS key; order fixes the returned key order.
    role_to_model = {
        "dit_checkpoint": "DiT",
        "dit_config": "config",
        "bigvgan": "bigvgan",
        "whisper": "whisper",
        "campplus": "campplus",
    }
    return {
        role: str(MODELS_DIR / REQUIRED_MODELS[model]["filename"])
        for role, model in role_to_model.items()
    }


if __name__ == "__main__":
    # CLI entry point: print model status, download anything missing, and
    # exit with status 1 if any download failed.
    import sys

    print("检查模型状态...")
    status = check_models()

    for name, info in status.items():
        icon = "✅" if info["exists"] else "❌"
        print(f"{icon} {name}: {info['path']}")

    missing = [n for n, s in status.items() if not s["exists"]]
    if missing:
        print(f"\n缺少模型: {missing}")
        print("开始下载...")
        all_ok = download_models()
    else:
        all_ok = True

    if all_ok:
        print("\n所有模型已就绪！")
    else:
        print("\n❌ 部分模型下载失败")
        # A non-zero exit status lets calling scripts/CI detect the failure.
        sys.exit(1)
