# -*- coding: utf-8 -*-
"""
远程增量更新客户端
支持从远程服务器检查和下载增量更新

功能：
1. 版本检查 - 对比本地和远程版本
2. 增量下载 - 只下载变更的文件
3. 断点续传 - 支持大文件中断后继续
4. 进度回调 - 实时更新下载进度
5. 错误重试 - 网络异常自动重试
6. 版本回滚 - 更新失败可回滚
"""

import os
import sys
import json
import hashlib
import shutil
import tempfile
import logging
import time
from typing import Dict, List, Optional, Tuple, Callable, Any
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

from pathlib import Path

try:
    import requests
    REQUESTS_AVAILABLE = True
except ImportError:
    REQUESTS_AVAILABLE = False

logger = logging.getLogger(__name__)


def _get_update_server_url() -> str:
    """Read the update-server endpoint from ``config.yaml`` (configuration is mandatory).

    Candidate files, checked in order (the first one that defines a non-empty
    ``server.api_base_url`` wins):

    1. ``<root>/config/config.yaml`` where ``<root>`` is ``sys._MEIPASS`` when
       running frozen, otherwise three directories above this file.
    2. ``~/.douyin_live_assistant/config.yaml``.

    Returns:
        ``<api_base_url>/c/biz/license`` (trailing slashes stripped from the base).

    Raises:
        ValueError: if no candidate file provides ``server.api_base_url``.
    """
    import yaml

    # Resolve the application root (compatible with PyInstaller-frozen builds).
    if getattr(sys, 'frozen', False):
        root_dir = Path(sys._MEIPASS)
    else:
        root_dir = Path(__file__).parent.parent.parent.absolute()

    # Per-user configuration directory.
    user_config_dir = Path.home() / ".douyin_live_assistant"

    config_paths = [
        root_dir / "config" / "config.yaml",
        user_config_dir / "config.yaml",
    ]

    for config_path in config_paths:
        try:
            if not config_path.exists():
                continue
            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
        except Exception as e:
            # Unreadable file or invalid YAML: report it and try the next candidate.
            logger.warning(f"[Update] 读取配置文件失败 {config_path}: {e}")
            continue

        # Tolerate files where `server` (or the file itself) is not a mapping.
        server = config.get('server') if isinstance(config, dict) else None
        url = server.get('api_base_url') if isinstance(server, dict) else None
        if isinstance(url, str) and url.strip():
            return f"{url.strip().rstrip('/')}/c/biz/license"

    checked = ", ".join(str(p) for p in config_paths)
    raise ValueError(f"未找到 API 配置，请在 config/config.yaml 中配置 server.api_base_url，已检查路径: {checked}")


class UpdateStatus(Enum):
    """Lifecycle state of the update client; passed to status-change callbacks."""
    IDLE = "idle"                # nothing in progress (also set when already up to date)
    CHECKING = "checking"        # querying the server for a newer version
    AVAILABLE = "available"      # a newer version was found and UpdateInfo is ready
    DOWNLOADING = "downloading"  # downloading changed files
    INSTALLING = "installing"    # backing up and replacing local files
    DONE = "done"                # update installed successfully
    ERROR = "error"              # the check or the update failed


@dataclass
class FileInfo:
    """Metadata for one file that an update adds or modifies."""
    path: str           # path relative to the application directory
    hash: str           # expected SHA256 hash; presumably a hex digest — verify against server
    size: int           # file size in bytes
    modified: str       # modification time (left empty for entries built from the server reply)
    download_url: str = ""  # full download URL (optional; used in preference to base-URL joining)


@dataclass
class UpdateInfo:
    """Description of an available update."""
    version: str                           # new (target) version string
    release_date: str                      # release date
    changelog: str                         # changelog text
    files_to_update: List[FileInfo] = field(default_factory=list)  # files to download and install
    files_to_delete: List[str] = field(default_factory=list)       # relative paths to remove
    total_size: int = 0                    # total download size in bytes


class UpdateChecker:
    """Process-wide singleton that checks for, downloads and installs incremental updates.

    Workflow:
    1. ``check_for_updates`` asks the backend (``checkIncrementalUpdate``) whether a
       newer version exists and builds an :class:`UpdateInfo` from its reply.
    2. ``download_and_apply_updates`` downloads the changed files into a temporary
       directory (with resume and retry), backs up every file it is about to
       overwrite or delete into ``<app_dir>/.backup``, installs the new files and
       writes the new ``VERSION``.
    3. ``rollback`` copies the files saved by the most recent update back in place.

    Both network operations run on daemon threads and report via callbacks.
    """

    _instance = None

    # Update server endpoint read from the config file.
    # NOTE: evaluated once at import time; _get_update_server_url raises ValueError
    # when server.api_base_url is not configured, so the import fails in that case.
    DEFAULT_UPDATE_URL = _get_update_server_url()
    MANIFEST_FILE = "manifest.json"

    # Software code (matches SOFTWARE_CODE in the backend BIZ_APP_VERSION table).
    SOFTWARE_CODE = "DOUYIN_LIVE_ASSISTANT"

    # Directory/file names skipped by generate_manifest (exact name match, so e.g.
    # "dialogs" is not mistaken for "logs").
    MANIFEST_EXCLUDED_NAMES = frozenset({
        "__pycache__", ".git", ".backup", "logs", "output", "build", "dist",
    })
    # File suffixes skipped by generate_manifest.
    MANIFEST_EXCLUDED_SUFFIXES = (".pyc", ".pyo")

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(UpdateChecker, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every UpdateChecker() call; only initialise the singleton once.
        if self._initialized:
            return

        self._initialized = True
        self._status = UpdateStatus.IDLE
        self._progress = 0.0
        self._current_version = self._get_current_version()
        self._update_info: Optional[UpdateInfo] = None
        self._update_url = self.DEFAULT_UPDATE_URL
        # Base URL for files without an explicit downloadUrl; filled in by check_for_updates.
        self._files_base_url: Optional[str] = None
        self._app_dir = self._get_app_dir()
        self._backup_dir = os.path.join(self._app_dir, ".backup")

        # Registered callbacks.
        self._on_progress: List[Callable] = []
        self._on_status_changed: List[Callable] = []

        # Make sure the backup directory exists.
        os.makedirs(self._backup_dir, exist_ok=True)

    # ------------------------------------------------------------------ setup

    def _get_app_dir(self) -> str:
        """Return the application directory (the executable's folder when frozen)."""
        if hasattr(sys, '_MEIPASS'):
            return os.path.dirname(sys.executable)
        return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    def _get_current_version(self) -> str:
        """Read the installed version from ``<app_dir>/VERSION``; default to "1.0.0".

        The file is read as UTF-8 (a BOM is tolerated). A missing or empty file
        yields "1.0.0".
        """
        version_file = os.path.join(self._get_app_dir(), "VERSION")
        try:
            with open(version_file, 'r', encoding='utf-8-sig') as f:
                version = f.read().strip()
        except FileNotFoundError:
            return "1.0.0"
        return version or "1.0.0"

    # ------------------------------------------------------------- properties

    @property
    def status(self) -> UpdateStatus:
        return self._status

    @property
    def progress(self) -> float:
        return self._progress

    @property
    def current_version(self) -> str:
        return self._current_version

    # -------------------------------------------------------------- callbacks

    def set_progress_callback(self, callback: Callable):
        """Register a progress callback (called with a float 0-100); duplicates/None ignored."""
        if callback and callback not in self._on_progress:
            self._on_progress.append(callback)

    def set_status_callback(self, callback: Callable):
        """Register a status callback (called with an UpdateStatus); duplicates/None ignored."""
        if callback and callback not in self._on_status_changed:
            self._on_status_changed.append(callback)

    def set_update_url(self, url: str):
        """Override the update server endpoint."""
        self._update_url = url

    def _set_status(self, status: UpdateStatus):
        """Store the status and notify every status callback (errors are logged)."""
        self._status = status
        for callback in self._on_status_changed:
            try:
                callback(status)
            except Exception as e:
                logger.error(f"状态回调失败: {e}")

    def _set_progress(self, progress: float):
        """Store the progress and notify every progress callback (errors are logged)."""
        self._progress = progress
        for callback in self._on_progress:
            try:
                callback(progress)
            except Exception as e:
                logger.error(f"进度回调失败: {e}")

    @staticmethod
    def _invoke_callback(callback: Optional[Callable], value: Any) -> None:
        """Call a one-shot completion callback, logging rather than propagating its errors.

        Kept outside the workers' try blocks so a failing callback can never
        trigger a second (contradictory) callback invocation.
        """
        if not callback:
            return
        try:
            callback(value)
        except Exception as e:
            logger.error(f"完成回调失败: {e}")

    # --------------------------------------------------------------- manifest

    def calculate_file_hash(self, filepath: str) -> str:
        """Return the SHA256 hex digest of a file, read in 8 KiB chunks."""
        sha256 = hashlib.sha256()
        with open(filepath, 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                sha256.update(chunk)
        return sha256.hexdigest()

    def generate_manifest(self, directory: str = None) -> Dict:
        """Build a manifest of the tracked application files.

        Args:
            directory: root to scan; defaults to the application directory.

        Returns:
            ``{"version", "generated", "files": {rel_path: {"hash", "size", "modified"}}}``
            where ``rel_path`` always uses forward slashes.
        """
        if directory is None:
            directory = self._app_dir

        manifest = {
            "version": self._current_version,
            "generated": datetime.now().isoformat(),
            "files": {}
        }

        # Top-level directories that are tracked.
        track_dirs = ["src", "gui", "config", "static", "assets", "tools"]
        excluded_names = self.MANIFEST_EXCLUDED_NAMES
        excluded_suffixes = self.MANIFEST_EXCLUDED_SUFFIXES

        for track_dir in track_dirs:
            dir_path = os.path.join(directory, track_dir)
            if not os.path.isdir(dir_path):
                continue

            for root, dirs, files in os.walk(dir_path):
                # Prune in place so os.walk does not descend into excluded directories.
                dirs[:] = [d for d in dirs if d not in excluded_names]

                for filename in files:
                    if filename in excluded_names or filename.endswith(excluded_suffixes):
                        continue

                    filepath = os.path.join(root, filename)
                    # Forward slashes so keys compare equal to server paths on every OS.
                    rel_path = os.path.relpath(filepath, directory).replace(os.sep, '/')

                    try:
                        stat = os.stat(filepath)
                        manifest["files"][rel_path] = {
                            "hash": self.calculate_file_hash(filepath),
                            "size": stat.st_size,
                            "modified": datetime.fromtimestamp(stat.st_mtime).isoformat()
                        }
                    except OSError as e:
                        logger.warning(f"无法处理文件 {rel_path}: {e}")

        return manifest

    def save_manifest(self, manifest: Dict = None) -> str:
        """Write the manifest (generated if not given) to the app directory; return its path."""
        if manifest is None:
            manifest = self.generate_manifest()

        manifest_path = os.path.join(self._app_dir, self.MANIFEST_FILE)
        with open(manifest_path, 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2, ensure_ascii=False)

        logger.info(f"清单已保存: {manifest_path}")
        return manifest_path

    def load_manifest(self) -> Optional[Dict]:
        """Load the local manifest; return None if it is missing, unreadable or corrupt."""
        manifest_path = os.path.join(self._app_dir, self.MANIFEST_FILE)
        try:
            with open(manifest_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            return None
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            logger.warning(f"无法读取本地清单 {manifest_path}: {e}")
            return None

    def compare_manifests(self, local: Dict, remote: Dict) -> Tuple[List[str], List[str], List[str]]:
        """Compare two manifests.

        Returns:
            ``(added, updated, deleted)`` sorted path lists. ``added`` holds paths only in
            ``remote``; ``updated`` holds paths in both whose hashes differ; ``deleted``
            holds paths only in ``local``.
        """
        local_entries = local.get("files") or {}
        remote_entries = remote.get("files") or {}
        local_files = set(local_entries)
        remote_files = set(remote_entries)

        added = sorted(remote_files - local_files)
        deleted = sorted(local_files - remote_files)
        updated = sorted(
            path for path in local_files & remote_files
            if local_entries[path].get("hash") != remote_entries[path].get("hash")
        )
        return added, updated, deleted

    # ------------------------------------------------------------ check phase

    def check_for_updates(self, callback: Callable = None) -> Optional[UpdateInfo]:
        """Check for a newer version on a background thread.

        Calls ``GET {update_url}/checkIncrementalUpdate?softwareCode=...&currentVersion=...``
        once and, if an update exists, stores the resulting :class:`UpdateInfo`.

        Args:
            callback: invoked exactly once with the UpdateInfo, or None when there is
                no update or the check failed.

        Returns:
            Always None; the result is delivered through ``callback``.
        """
        def _check():
            info = None
            try:
                info = self._fetch_update_info()
            except Exception as e:
                if REQUESTS_AVAILABLE and isinstance(e, requests.exceptions.RequestException):
                    logger.error(f"网络请求失败: {e}")
                else:
                    logger.error(f"检查更新失败: {e}")
                self._set_status(UpdateStatus.ERROR)
            self._invoke_callback(callback, info)

        thread = threading.Thread(target=_check, daemon=True)
        thread.start()
        return None

    def _fetch_update_info(self) -> Optional[UpdateInfo]:
        """Query the backend and update status/_update_info; network/parse errors propagate."""
        self._set_status(UpdateStatus.CHECKING)
        self._set_progress(0)

        if not REQUESTS_AVAILABLE:
            logger.error("requests 模块不可用，无法检查更新")
            self._set_status(UpdateStatus.ERROR)
            return None

        # GET /api/c/biz/license/checkIncrementalUpdate?softwareCode=XXX&currentVersion=1.0.0
        check_url = f"{self._update_url.rstrip('/')}/checkIncrementalUpdate"
        params = {
            "softwareCode": self.SOFTWARE_CODE,
            "currentVersion": self._current_version
        }
        logger.debug(f"[检查更新] 请求URL: {check_url}")
        logger.debug(f"[检查更新] 请求参数: {params}")

        response = requests.get(check_url, params=params, timeout=30)
        logger.debug(f"[检查更新] 响应状态码: {response.status_code}")
        logger.debug(f"[检查更新] 响应内容: {response.text[:500]}")

        response.raise_for_status()
        result = response.json()

        # The backend wraps replies in CommonResult: {"code": ..., "msg": ..., "data": ...}
        if result.get("code") != 200:
            logger.error(f"[检查更新] 失败 - code: {result.get('code')}, msg: {result.get('msg')}")
            self._set_status(UpdateStatus.ERROR)
            return None

        data = result.get("data") or {}  # `or`: the server may send "data": null

        if not data.get("hasUpdate", False):
            logger.info(f"当前版本 {self._current_version} 已是最新")
            self._set_status(UpdateStatus.IDLE)
            self._update_info = None
            return None

        # Base URL for files that do not carry their own downloadUrl.
        self._files_base_url = data.get("filesBaseUrl") or self._update_url
        update_info = self._build_update_info(data)
        logger.info(f"发现新版本: {update_info.version}")

        self._update_info = update_info
        self._set_status(UpdateStatus.AVAILABLE)
        return update_info

    @staticmethod
    def _build_update_info(data: Dict) -> UpdateInfo:
        """Convert the ``data`` payload of checkIncrementalUpdate into an UpdateInfo.

        JSON nulls are treated like missing values. ``deletedFiles`` entries may be
        plain paths or objects with a ``path`` key (both shapes are accepted).
        """
        files_to_update = [
            FileInfo(
                path=file_data.get("path") or "",
                hash=file_data.get("hash") or "",
                size=int(file_data.get("size") or 0),
                modified="",
                download_url=file_data.get("downloadUrl") or "",  # backend-provided URL, if any
            )
            for file_data in (data.get("addedFiles") or []) + (data.get("modifiedFiles") or [])
        ]
        computed_size = sum(info.size for info in files_to_update)

        files_to_delete = []
        for entry in data.get("deletedFiles") or []:
            path = entry.get("path") if isinstance(entry, dict) else entry
            if path:
                files_to_delete.append(path)

        return UpdateInfo(
            version=data.get("toVersion") or "unknown",
            release_date=datetime.now().isoformat(),
            changelog=data.get("changelog") or "",
            files_to_update=files_to_update,
            files_to_delete=files_to_delete,
            total_size=int(data.get("totalSize") or computed_size),
        )

    def _compare_versions(self, v1: str, v2: str) -> int:
        """Compare dotted numeric versions.

        Returns:
            -1 if v1 < v2, 0 if equal (or either is not purely numeric), 1 if v1 > v2.
            Missing trailing components count as 0 ("1.2" == "1.2.0").
        """
        def parse_version(v):
            return [int(x) for x in v.split('.')]

        try:
            parts1 = parse_version(v1)
            parts2 = parse_version(v2)
        except (ValueError, AttributeError, TypeError):
            return 0

        for i in range(max(len(parts1), len(parts2))):
            p1 = parts1[i] if i < len(parts1) else 0
            p2 = parts2[i] if i < len(parts2) else 0
            if p1 < p2:
                return -1
            if p1 > p2:
                return 1
        return 0

    # ---------------------------------------------------------- install phase

    def download_and_apply_updates(self, callback: Callable = None) -> bool:
        """Download the pending update and install it on a background thread.

        Steps (files are processed sequentially):
        1. Validate every server-supplied path so nothing escapes the app/temp dirs.
        2. Download each file into a temp directory with resume and retry, verifying
           its SHA256 when the server provided one (progress 0-70).
        3. Reset ``.backup``, back up VERSION and each file about to be overwritten,
           then copy the new files in place (progress 70-90).
        4. Move deleted files into the backup, write VERSION, refresh the manifest
           (progress 100).

        Args:
            callback: invoked once with True on success or False on failure.

        Returns:
            False immediately (callback not invoked) when there is no pending update or
            requests is unavailable; otherwise True once the worker thread has started.
        """
        # Snapshot so a concurrent re-check cannot swap the update mid-install.
        update_info = self._update_info
        if not update_info:
            return False

        if not REQUESTS_AVAILABLE:
            logger.error("requests 模块不可用")
            return False

        def _run():
            success = False
            temp_dir = None
            try:
                temp_dir = tempfile.mkdtemp(prefix="update_")
                success = self._download_and_install(update_info, temp_dir)
            except Exception as e:
                logger.error(f"更新失败: {e}")
                self._set_status(UpdateStatus.ERROR)
            finally:
                # Clean up the temporary download directory.
                if temp_dir is not None:
                    shutil.rmtree(temp_dir, ignore_errors=True)
            self._invoke_callback(callback, success)

        thread = threading.Thread(target=_run, daemon=True)
        thread.start()
        return True

    def _download_and_install(self, update_info: UpdateInfo, temp_dir: str) -> bool:
        """Download all files of ``update_info`` into ``temp_dir`` and install them.

        Returns False (status ERROR) if any download fails; other errors propagate.
        """
        self._set_status(UpdateStatus.DOWNLOADING)
        total_files = len(update_info.files_to_update)
        total_size = update_info.total_size
        downloaded_size = 0

        # Resolve (and thereby validate) every server-supplied path before touching disk.
        plan = [
            (info,
             self._resolve_path(temp_dir, info.path),
             self._resolve_path(self._app_dir, info.path))
            for info in update_info.files_to_update
        ]
        deletions = [
            (rel_path, self._resolve_path(self._app_dir, rel_path))
            for rel_path in update_info.files_to_delete
        ]

        base_url = (self._files_base_url or self._update_url).rstrip('/')
        failed_files = []

        for i, (info, temp_path, _dest_path) in enumerate(plan):
            # Prefer the backend's full URL (e.g. object storage); otherwise join onto the base.
            url_path = info.path.replace('\\', '/')
            download_url = info.download_url or f"{base_url}/{url_path}"

            os.makedirs(os.path.dirname(temp_path), exist_ok=True)
            ok = self._download_file_with_resume(
                download_url,
                temp_path,
                expected_hash=info.hash,
                expected_size=info.size
            )

            if ok:
                downloaded_size += info.size
                logger.info(f"下载完成: {info.path}")
            else:
                failed_files.append(info.path)
                logger.error(f"下载失败: {info.path}")

            # Download accounts for 0-70% of the progress bar; cap in case totalSize is off.
            if total_size > 0:
                progress = min(downloaded_size / total_size, 1.0) * 70
            else:
                progress = ((i + 1) / total_files) * 70
            self._set_progress(progress)

        # Abort before modifying anything if a single download failed.
        if failed_files:
            logger.error(f"以下文件下载失败: {failed_files}")
            self._set_status(UpdateStatus.ERROR)
            return False

        self._set_status(UpdateStatus.INSTALLING)

        # Start a fresh backup so rollback() restores exactly the pre-update state,
        # instead of mixing files saved by several earlier updates.
        shutil.rmtree(self._backup_dir, ignore_errors=True)
        os.makedirs(self._backup_dir, exist_ok=True)

        version_file = os.path.join(self._app_dir, "VERSION")
        # Back up VERSION before any file is replaced so rollback also restores it.
        if os.path.isfile(version_file):
            self._backup_file(version_file, "VERSION")

        for i, (info, temp_path, dest_path) in enumerate(plan):
            if os.path.exists(dest_path):
                self._backup_file(dest_path, info.path)

            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            shutil.copy2(temp_path, dest_path)
            logger.info(f"更新: {info.path}")

            # Installation accounts for 70-90% of the progress bar.
            self._set_progress(70 + ((i + 1) / total_files) * 20)

        # Deleted files are moved into the backup rather than removed outright.
        for rel_path, full_path in deletions:
            if os.path.exists(full_path):
                self._backup_file(full_path, rel_path, move=True)
                logger.info(f"删除: {rel_path}")

        with open(version_file, 'w', encoding='utf-8') as f:
            f.write(update_info.version)

        self._current_version = update_info.version
        self.save_manifest()

        self._set_status(UpdateStatus.DONE)
        self._set_progress(100)

        logger.info(f"更新完成: {update_info.version}")
        return True

    def _backup_file(self, src_path: str, rel_path: str, move: bool = False) -> None:
        """Copy (or move, if ``move``) ``src_path`` to ``<backup_dir>/<rel_path>``."""
        backup_path = self._resolve_path(self._backup_dir, rel_path)
        os.makedirs(os.path.dirname(backup_path), exist_ok=True)
        if move:
            shutil.move(src_path, backup_path)
        else:
            shutil.copy2(src_path, backup_path)

    @staticmethod
    def _resolve_path(base_dir: str, rel_path: str) -> str:
        """Join a server-supplied relative path onto ``base_dir`` without escaping it.

        Both ``/`` and ``\\`` are treated as separators.

        Raises:
            ValueError: if ``rel_path`` is empty or absolute, carries a drive, or resolves
                to ``base_dir`` itself or anywhere outside it (e.g. via ``..``).
        """
        normalized = (rel_path or "").replace('\\', '/')
        if (not normalized or normalized.startswith('/')
                or os.path.isabs(normalized) or os.path.splitdrive(normalized)[0]):
            raise ValueError(f"非法的更新文件路径: {rel_path!r}")

        base = os.path.abspath(base_dir)
        target = os.path.abspath(os.path.join(base, *normalized.split('/')))
        try:
            inside = target != base and os.path.commonpath([base, target]) == base
        except ValueError:  # e.g. different drives on Windows
            inside = False
        if not inside:
            raise ValueError(f"非法的更新文件路径: {rel_path!r}")
        return target

    def _hash_matches(self, filepath: str, expected_hash: Optional[str]) -> bool:
        """True when no hash is expected or the file's SHA256 equals it (case-insensitive)."""
        if not expected_hash:
            return True
        return self.calculate_file_hash(filepath).lower() == expected_hash.strip().lower()

    def _download_file_with_resume(self, url: str, dest_path: str,
                                    expected_hash: str = None,
                                    expected_size: int = 0,
                                    max_retries: int = 3) -> bool:
        """Download ``url`` to ``dest_path``, resuming a partial file when possible.

        Args:
            url: download address.
            dest_path: destination file path.
            expected_hash: expected SHA256; when given the result must match it.
            expected_size: expected size in bytes; enables HTTP Range resume.
            max_retries: maximum number of attempts.

        Returns:
            True once the file is complete (and verified when a hash is given).
            Returns False after all attempts fail or on a non-network error.
        """
        for attempt in range(max_retries):
            try:
                # Size of any partial file left by a previous attempt.
                downloaded_size = os.path.getsize(dest_path) if os.path.exists(dest_path) else 0

                headers = {}
                if downloaded_size > 0 and expected_size > 0:
                    if downloaded_size < expected_size:
                        # Ask only for the missing tail.
                        headers['Range'] = f'bytes={downloaded_size}-'
                    elif downloaded_size == expected_size and self._hash_matches(dest_path, expected_hash):
                        return True
                    else:
                        # Oversized or corrupt leftover: start again from scratch.
                        os.remove(dest_path)

                # `with` closes the streamed response so the connection is released.
                with requests.get(url, headers=headers, stream=True, timeout=60) as response:
                    if response.status_code == 416:
                        # Range Not Satisfiable: the server has nothing past our offset.
                        # Trust the local file only if it exists and verifies.
                        if os.path.exists(dest_path) and self._hash_matches(dest_path, expected_hash):
                            return True
                        logger.warning(f"服务器返回 416，本地文件无效: {dest_path}")
                        if os.path.exists(dest_path):
                            os.remove(dest_path)
                        continue

                    if response.status_code == 206:
                        mode = 'ab'  # Partial Content: append to the partial file
                    else:
                        # 200 is a full body (the server may ignore Range); 4xx/5xx raise here.
                        response.raise_for_status()
                        mode = 'wb'

                    with open(dest_path, mode) as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            if chunk:
                                f.write(chunk)

                if not self._hash_matches(dest_path, expected_hash):
                    logger.warning(f"哈希不匹配: {dest_path}")
                    os.remove(dest_path)
                    continue  # retry from scratch

                return True

            except requests.exceptions.RequestException as e:
                logger.warning(f"下载失败 (尝试 {attempt + 1}/{max_retries}): {url}, 错误: {e}")
                time.sleep(2 ** attempt)  # exponential backoff
            except Exception as e:
                logger.error(f"下载异常: {e}")
                break

        return False

    # --------------------------------------------------------------- rollback

    def rollback(self) -> bool:
        """Restore the files saved in ``.backup`` by the most recent update.

        Every backed-up file (including VERSION when it was backed up) is copied back
        into the application directory, then the cached current version is re-read.
        Files the update newly added have no backup and are left in place.

        Returns:
            True if the backup was restored. False if there is nothing to restore or a
            copy failed.
        """
        backups = []
        if os.path.isdir(self._backup_dir):
            for root, _dirs, files in os.walk(self._backup_dir):
                backups.extend(os.path.join(root, name) for name in files)

        if not backups:
            logger.warning("没有可用的备份")
            return False

        try:
            for backup_path in backups:
                rel_path = os.path.relpath(backup_path, self._backup_dir)
                dest_path = os.path.join(self._app_dir, rel_path)

                os.makedirs(os.path.dirname(dest_path), exist_ok=True)
                shutil.copy2(backup_path, dest_path)
                logger.info(f"回滚: {rel_path}")

            # VERSION may have been restored; keep the cached version in sync.
            self._current_version = self._get_current_version()
            logger.info("回滚完成")
            return True

        except Exception as e:
            logger.error(f"回滚失败: {e}")
            return False

    def on_progress(self, callback: Callable):
        """Register a progress callback (alias of set_progress_callback)."""
        self.set_progress_callback(callback)

    def on_status_changed(self, callback: Callable):
        """Register a status callback (alias of set_status_callback)."""
        self.set_status_callback(callback)


# Convenience accessor
def get_update_checker() -> UpdateChecker:
    """Return the shared UpdateChecker instance.

    UpdateChecker.__new__ caches a single instance, so every call yields the same
    object; it is created on the first call.
    """
    checker = UpdateChecker()
    return checker
