# -*- coding: utf-8 -*-
"""
口型同步 WebSocket 服务器
实时发送音量数据到 Live2D 前端
"""

import asyncio
import threading
import json
from typing import Optional, Set, Callable
import wave
import struct
import os

try:
    import websockets
    from websockets.server import serve
    WEBSOCKETS_AVAILABLE = True
except ImportError:
    WEBSOCKETS_AVAILABLE = False

from src.utils.logger import logger


class LipSyncServer:
    """WebSocket server that streams lip-sync data to a Live2D front end.

    The server runs an asyncio event loop in a daemon thread.  Other threads
    push mouth-openness values and spoken text through ``send_mouth_value``
    and ``send_speak_text``; every connected client receives each message as
    a small JSON object.
    """

    # Fraction of full scale that maps to a fully open mouth in the plain WAV
    # analyser.  For 16-bit audio this is exactly the historical
    # ``rms / 10000`` rule (32768 * 10000 / 32768 == 10000).
    _WAV_FULL_OPEN_RATIO = 10000 / 32768

    # Seconds ``stop`` waits for the server thread to finish shutting down.
    _STOP_JOIN_TIMEOUT = 2.0

    def __init__(self, host: str = "127.0.0.1", port: int = 8765):
        self.host = host
        self.port = port
        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._clients: Set = set()
        self._server = None

    def start(self) -> bool:
        """Start the WebSocket server in a background daemon thread.

        Returns:
            False when the ``websockets`` package is missing; True when the
            server is already running or its thread has just been launched.
            The socket is bound asynchronously, so a bind failure is reported
            through the log (and clears the running flag so a later call can
            retry) rather than through the return value.
        """
        if not WEBSOCKETS_AVAILABLE:
            logger.warning("websockets 未安装，口型同步功能不可用")
            return False

        if self._running:
            return True

        self._running = True
        self._thread = threading.Thread(target=self._run_server, daemon=True)
        self._thread.start()
        logger.info(f"口型同步服务器启动: ws://{self.host}:{self.port}")
        return True

    def stop(self):
        """Stop the server and release its listening socket.

        Safe to call repeatedly and before the server has finished starting.
        Waits briefly for the server thread so that a following ``start`` can
        bind the same port again.
        """
        self._running = False

        loop = self._loop
        if loop is not None and not loop.is_closed():
            try:
                loop.call_soon_threadsafe(loop.stop)
            except RuntimeError:
                # The loop was closed between the check and the call.
                pass

        thread = self._thread
        # Never join the current thread (stop() may be called from a handler).
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=self._STOP_JOIN_TIMEOUT)

        logger.info("口型同步服务器已停止")

    def _run_server(self):
        """Thread target: own the event loop for the lifetime of the server."""
        loop = asyncio.new_event_loop()
        self._loop = loop
        asyncio.set_event_loop(loop)

        try:
            loop.run_until_complete(self._start_server())
            # stop() may have been called while the socket was being bound.
            if self._running:
                loop.run_forever()
        except Exception as e:
            logger.error(f"口型同步服务器错误: {e}")
            # Allow start() to be retried after a failed bind or a crash.
            self._running = False
        finally:
            try:
                # Close the listening socket and client connections before
                # the loop goes away; otherwise the port stays occupied.
                loop.run_until_complete(self._shutdown())
            except Exception as e:
                logger.error(f"口型同步服务器错误: {e}")
            if self._loop is loop:
                self._loop = None
            loop.close()

    async def _start_server(self):
        """Bind the WebSocket server to ``host``/``port``."""
        self._server = await serve(
            self._handle_client,
            self.host,
            self.port
        )

    async def _shutdown(self):
        """Close the WebSocket server (if any) and forget its clients."""
        server, self._server = self._server, None
        if server is not None:
            server.close()
            await server.wait_closed()
        self._clients.clear()

    async def _handle_client(self, websocket, path=None):
        """Track one client connection until it closes.

        ``path`` is accepted (and ignored) because websockets releases older
        than 10.1 call handlers with ``(websocket, path)``.
        """
        self._clients.add(websocket)
        logger.info(f"Live2D 客户端连接: {websocket.remote_address}")

        try:
            async for message in websocket:
                # Messages from the front end are currently ignored.
                pass
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            self._clients.discard(websocket)
            logger.info("Live2D 客户端断开")

    def send_mouth_value(self, value: float):
        """Broadcast a mouth-openness value to all clients.

        Args:
            value: Openness, clamped to 0.0 (closed) .. 1.0 (fully open).
        """
        if not self._running or not self._clients:
            return

        value = max(0.0, min(1.0, value))
        self._send_message(json.dumps({
            "type": "mouth",
            "value": value
        }))

    def send_speak_text(self, text: str):
        """Broadcast spoken text (shown next to the model) to all clients.

        Args:
            text: Text to display.
        """
        if not self._running or not self._clients:
            return

        self._send_message(json.dumps({
            "type": "speak",
            "text": text
        }))

    def _send_message(self, message: str):
        """Schedule ``message`` for broadcast on the server loop.

        Called from arbitrary threads; silently does nothing when the loop is
        not available (not started yet, or already shut down).
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return

        coro = self._broadcast(message)
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # Loop closed concurrently: drop the message and close the
            # coroutine so it does not trigger a "never awaited" warning.
            coro.close()

    async def _broadcast(self, message: str):
        """Send ``message`` to every connected client, ignoring send errors."""
        # Snapshot: handlers may add/remove clients while sends are awaited.
        clients = list(self._clients)
        if clients:
            await asyncio.gather(
                *(client.send(message) for client in clients),
                return_exceptions=True
            )

    def analyze_audio_file(self, audio_path: str, callback: Callable[[float], None],
                           sample_interval: float = 0.05):
        """Analyse an audio file in a background thread, reporting volume.

        Args:
            audio_path: Path of the audio file.
            callback: Called with a volume in 0.0 - 1.0 for each sample
                window, and with 0.0 once analysis ends.
            sample_interval: Window length in seconds.

        Returns:
            The started daemon thread.
        """
        thread = threading.Thread(
            target=self._analyze_audio_thread,
            args=(audio_path, callback, sample_interval),
            daemon=True
        )
        thread.start()
        return thread

    def _analyze_audio_thread(self, audio_path: str, callback: Callable[[float], None],
                               sample_interval: float):
        """Analysis thread body: prefer pydub, fall back to the wave module."""
        import time

        try:
            try:
                from pydub import AudioSegment
            except ImportError:
                # pydub is not installed: use the dependency-free analyser.
                self._analyze_wav_simple(audio_path, callback, sample_interval)
                return

            audio = AudioSegment.from_file(audio_path)

            duration_ms = len(audio)
            # At least 1 ms, otherwise range() would be given a zero step.
            interval_ms = max(1, int(sample_interval * 1000))
            # Full-scale amplitude for this file's sample width
            # (32768 for 16-bit audio).
            max_rms = audio.max_possible_amplitude

            start_time = time.monotonic()

            for i in range(0, duration_ms, interval_ms):
                if not self._running:
                    break

                chunk = audio[i:i + interval_ms]
                # Amplified 3x so that normal speech moves the mouth visibly.
                volume = min(1.0, chunk.rms / max_rms * 3)
                callback(volume)

                # Pace against the playback clock so processing time does not
                # accumulate as drift.
                elapsed = time.monotonic() - start_time
                sleep_time = (i + interval_ms) / 1000 - elapsed
                if sleep_time > 0:
                    time.sleep(sleep_time)

            # Close the mouth when done.
            callback(0.0)

        except Exception as e:
            logger.error(f"音频分析错误: {e}")
            callback(0.0)

    @staticmethod
    def _decode_pcm(frames: bytes, sampwidth: int):
        """Decode little-endian PCM bytes into a sequence of signed ints.

        Supports 1 (unsigned 8-bit), 2, 3 and 4 byte samples; a truncated
        trailing sample is ignored.  Returns an empty tuple for any other
        sample width.
        """
        count = len(frames) // sampwidth if sampwidth > 0 else 0
        if count == 0:
            return ()
        data = frames[:count * sampwidth]

        if sampwidth == 1:
            # 8-bit WAV samples are unsigned with the midpoint at 128.
            return [b - 128 for b in data]
        if sampwidth == 2:
            return struct.unpack(f"<{count}h", data)
        if sampwidth == 3:
            return [
                int.from_bytes(data[i:i + 3], "little", signed=True)
                for i in range(0, len(data), 3)
            ]
        if sampwidth == 4:
            return struct.unpack(f"<{count}i", data)
        return ()

    def _analyze_wav_simple(self, audio_path: str, callback: Callable[[float], None],
                            sample_interval: float):
        """Analyse a WAV file with the standard library only.

        MP3 paths and files that cannot be read as WAV fall back to simulated
        lip movement.
        """
        import time

        # The wave module cannot read mp3: simulate instead.
        if audio_path.lower().endswith('.mp3'):
            self._simulate_lipsync(callback, sample_interval)
            return

        try:
            with wave.open(audio_path, 'rb') as wf:
                framerate = wf.getframerate()
                sampwidth = wf.getsampwidth()
                frame_size = sampwidth * wf.getnchannels()

                frames_per_sample = max(1, int(framerate * sample_interval))
                # RMS level treated as "fully open" for this sample width.
                full_open = (1 << (8 * sampwidth - 1)) * self._WAV_FULL_OPEN_RATIO

                frames_done = 0
                start_time = time.monotonic()

                while self._running:
                    frames = wf.readframes(frames_per_sample)
                    if not frames:
                        break

                    samples = self._decode_pcm(frames, sampwidth)
                    if samples:
                        rms = (sum(s * s for s in samples) / len(samples)) ** 0.5
                        callback(min(1.0, rms / full_open))

                    # Pace against the audio position actually consumed so
                    # that timing does not drift over long files.
                    frames_done += len(frames) // frame_size
                    elapsed = time.monotonic() - start_time
                    sleep_time = frames_done / framerate - elapsed
                    if sleep_time > 0:
                        time.sleep(sleep_time)

                callback(0.0)

        except Exception as e:
            logger.error(f"WAV 分析错误: {e}")
            self._simulate_lipsync(callback, sample_interval)

    def _simulate_lipsync(self, callback: Callable[[float], None], interval: float = 0.1):
        """Emit random mouth values for about 3 seconds, then close the mouth.

        Used when the audio cannot be analysed.
        """
        import time
        import random

        if interval <= 0:
            # A non-positive step would never advance ``elapsed`` below.
            interval = 0.1

        duration = 3.0
        elapsed = 0

        while elapsed < duration and self._running:
            callback(random.uniform(0.2, 0.8))
            time.sleep(interval)
            elapsed += interval

        callback(0.0)


# Module-level singleton instance; stays None until get_lipsync_server() creates it
_lipsync_server: Optional[LipSyncServer] = None


def get_lipsync_server() -> LipSyncServer:
    """Return the shared LipSyncServer, creating it on first use."""
    global _lipsync_server
    server = _lipsync_server
    if server is not None:
        return server
    server = LipSyncServer()
    _lipsync_server = server
    return server
