# -*- coding: utf-8 -*-
"""
网络抓包服务 - 完整版
移植自 stream_capture (C#) 的所有功能

功能特性:
1. TCP 流重组 - 处理跨包的 RTMP 数据
2. AMF0 协议解析 - 字节级别提取 streamKey
3. 多平台专用解析器 - 抖音、小红书、B站、快手、京东
4. tcUrl 缓存机制 - 关联 connect 和 publish 命令
5. 去重机制 - 避免重复触发
6. 全流量调试模式 - 用于问题排查
"""

import re
import json
import time
import threading
import subprocess
from typing import Optional, Callable, List, Dict, Any, Tuple
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from datetime import datetime, timedelta

from src.utils.logger import logger

# Optional dependency: scapy (packet sniffing). The Windows-specific helper is
# imported inside the same guard, so if that import fails (presumably on
# non-Windows hosts - verify) SCAPY_AVAILABLE is False even when scapy.all
# itself imported fine.
try:
    from scapy.all import sniff, IP, TCP, Raw, get_if_list, get_if_hwaddr, conf, UDP
    from scapy.arch.windows import get_windows_if_list
except ImportError as import_error:
    SCAPY_AVAILABLE = False
    print(f"[PacketCapture] scapy load FAILED: {import_error}")
else:
    SCAPY_AVAILABLE = True
    print("[PacketCapture] scapy loaded OK")

# Optional dependency: psutil
try:
    import psutil
except ImportError:
    PSUTIL_AVAILABLE = False
else:
    PSUTIL_AVAILABLE = True


class Platform(Enum):
    """Live-streaming platforms known to the capture service.

    UNKNOWN is what the detection helpers return when no platform matches.
    """

    DOUYIN = "douyin"    # Douyin
    UNKNOWN = "unknown"  # no platform identified


@dataclass
class StreamInfo:
    """A captured push-stream endpoint (server URL + stream key) with metadata.

    Field order is part of the positional constructor and must not change.
    """

    platform: Platform
    server_url: str
    stream_key: str
    full_url: str
    protocol: str = "RTMP"
    room_id: str = ""
    web_rid: str = ""
    anchor_name: str = ""
    parameters: Dict[str, str] = field(default_factory=dict)  # fresh dict per instance
    raw_data: str = ""
    capture_time: float = field(default_factory=time.time)    # time.time() at creation

    @property
    def is_valid(self) -> bool:
        """True only when both ``server_url`` and ``stream_key`` are non-empty."""
        has_server = bool(self.server_url)
        return has_server and bool(self.stream_key)


@dataclass
class NetworkInterface:
    """A network adapter that can be offered for capture."""

    name: str           # adapter name as reported by the OS
    description: str    # human-readable adapter description
    ip: str = ""        # IPv4 address ("" when unknown)
    mac: str = ""       # hardware address ("" when unknown)
    is_up: bool = True  # adapter state flag


@dataclass
class PlatformConfig:
    """Matching rules for one platform (mirrors an entry in platforms.json)."""

    id: str                       # machine id, e.g. "douyin"
    name: str                     # display name
    enabled: bool
    protocols: List[str]          # e.g. ["rtmp", "rtmps"]
    ports: List[int]
    domain_patterns: List[str]    # regexes searched case-insensitively
    url_patterns: List[str]       # regexes for push URLs
    keyword_patterns: List[str]   # plain substrings


class ReassemblyState:
    """Per-flow reassembly state: accumulated payload bytes plus last-activity time."""

    def __init__(self):
        # when this flow was last appended to (used for idle-flow eviction)
        self.last_seen: datetime = datetime.now()
        # payload bytes concatenated in arrival order (no sequence reordering)
        self.buffer: bytearray = bytearray()


class PlatformMatcher:
    """Platform matcher (ported from stream_capture).

    Holds per-platform matching rules in two shapes: ``PLATFORM_CONFIGS``
    (``PlatformConfig`` objects, kept in sync with platforms.json) and
    ``PATTERNS`` (the legacy dict layout read by ``detect_platform``,
    ``extract_stream_url`` and ``extract_room_id``).

    NOTE(review): the ``.*\\.douyincdn\\.com``-style patterns are only used
    with ``re.search`` as a yes/no test here, where the leading ``.*`` only
    adds backtracking; consider dropping it once no caller relies on
    ``match.group(0)`` for them.
    """

    # Platform configuration - mirrors platforms.json
    PLATFORM_CONFIGS: Dict[Platform, PlatformConfig] = {
        Platform.DOUYIN: PlatformConfig(
            id="douyin",
            name="抖音",
            enabled=True,
            protocols=["rtmp", "rtmps"],
            ports=[1935, 443, 19350],
            domain_patterns=[
                r"push\.live\.douyin\.com",
                r"live\.douyin\.com",
                r".*\.douyincdn\.com",
                r".*\.douyinpic\.com",
                r"webcast.*\.douyincdn\.com"
            ],
            url_patterns=[
                r"rtmps?://[^\s\"'<>]+",
                r"push_url[\"']?:\s*[\"']([^\"']+)[\"']",
                r"stream_url[\"']?:\s*[\"']([^\"']+)[\"']"
            ],
            keyword_patterns=[
                "rtmp://",
                "rtmps://",
                "live.douyin.com",
                "stream_id",
                "stream_key",
                "push_url"
            ]
        ),
    }

    # Legacy PATTERNS layout, kept for backward compatibility with older code
    PATTERNS = {
        Platform.DOUYIN: {
            'domains': [
                r'push\.live\.douyin\.com',
                r'live\.douyin\.com',
                r'.*\.douyincdn\.com',
                r'.*\.douyinpic\.com',
                r'webcast.*\.douyincdn\.com'
            ],
            'url_patterns': [
                r'rtmps?://[^\s"\'<>]+douyincdn\.com[^\s"\'<>]*',
                r'rtmps?://[^\s"\'<>]+douyin[^\s"\'<>]*',
            ],
            'room_id_patterns': [
                r'"room_id"\s*:\s*"?(\d+)"?',
                r'"roomId"\s*:\s*"?(\d+)"?',
                r'room_id=(\d+)',
                r'roomId=(\d+)',
            ],
            'web_rid_patterns': [
                r'"web_rid"\s*:\s*"(\d+)"',
                r'"webRid"\s*:\s*"(\d+)"',
                r'web_rid=(\d+)',
            ]
        }
    }

    @classmethod
    def get_config(cls, platform: Platform) -> Optional[PlatformConfig]:
        """Return the ``PlatformConfig`` for *platform*, or None if it has none."""
        return cls.PLATFORM_CONFIGS.get(platform)

    @classmethod
    def detect_platform(cls, data: str) -> Platform:
        """Return the first platform whose legacy domain regexes match *data*, else UNKNOWN."""
        for platform, patterns in cls.PATTERNS.items():
            if any(re.search(p, data, re.IGNORECASE) for p in patterns['domains']):
                return platform
        return Platform.UNKNOWN

    @classmethod
    def detect_platform_from_config(cls, data: str) -> Platform:
        """Detect the platform using ``PLATFORM_CONFIGS``.

        For each platform in order, its domain regexes (case-insensitive) are
        tried first, then its keywords (case-insensitive substring match).
        Returns UNKNOWN when nothing matches.
        """
        # Lower-case once instead of once per keyword per platform.
        data_lower = data.lower()
        for platform, config in cls.PLATFORM_CONFIGS.items():
            if any(re.search(p, data, re.IGNORECASE) for p in config.domain_patterns):
                return platform
            if any(keyword.lower() in data_lower for keyword in config.keyword_patterns):
                return platform
        return Platform.UNKNOWN

    @classmethod
    def matches_platform_domain(cls, url: str, platform: Platform) -> bool:
        """True if *url* matches one of *platform*'s configured domain regexes."""
        config = cls.PLATFORM_CONFIGS.get(platform)
        if not config:
            return False
        return any(re.search(p, url, re.IGNORECASE) for p in config.domain_patterns)

    @classmethod
    def extract_stream_url(cls, data: str, platform: Platform) -> Optional[str]:
        """Return the first push URL found in *data* (whole regex match), or None.

        For ``Platform.UNKNOWN`` the legacy url_patterns of every platform are
        tried in order; otherwise only *platform*'s patterns are used.
        """
        if platform == Platform.UNKNOWN:
            pattern_groups = [p.get('url_patterns', []) for p in cls.PATTERNS.values()]
        else:
            pattern_groups = [cls.PATTERNS.get(platform, {}).get('url_patterns', [])]
        for patterns in pattern_groups:
            for pattern in patterns:
                match = re.search(pattern, data, re.IGNORECASE)
                if match:
                    return match.group(0)
        return None

    @staticmethod
    def _first_group(patterns: List[str], data: str) -> str:
        """Return capture group 1 of the first pattern that matches *data*, else ""."""
        for pattern in patterns:
            match = re.search(pattern, data)
            if match:
                return match.group(1)
        return ""

    @classmethod
    def extract_room_id(cls, data: str, platform: Platform) -> Tuple[str, str]:
        """Extract ``(room_id, web_rid)`` from *data*; missing values are "".

        Platforms without legacy patterns (including UNKNOWN) fall back to the
        Douyin patterns.
        """
        patterns = cls.PATTERNS.get(platform, cls.PATTERNS[Platform.DOUYIN])
        room_id = cls._first_group(patterns.get('room_id_patterns', []), data)
        web_rid = cls._first_group(patterns.get('web_rid_patterns', []), data)
        return room_id, web_rid


class PacketCaptureService:
    """
    Network packet-capture service - full edition (ported from stream_capture).

    Features (as listed by the original author):
    1. TCP flow reassembly - handles RTMP data split across packets
    2. AMF0 protocol parsing - byte-level extraction of the streamKey
    3. Platform-specific parsers - Douyin, Xiaohongshu, Bilibili, Kuaishou, JD
    4. tcUrl cache - links RTMP connect and publish commands
    5. De-duplication - avoids firing the same result repeatedly

    NOTE(review): PlatformMatcher in this module only configures
    Platform.DOUYIN; confirm whether the other platform parsers exist.
    """
    
    # Ports treated as RTMP-relevant: used to build the BPF filters and to drop
    # TCP packets whose source and destination port are both outside this list.
    RTMP_PORTS = [1935, 443, 19350, 1936, 80, 8080]
    
    # TCP flow-reassembly constants
    MAX_BUFFER_BYTES = 512 * 1024   # 512KB cap per flow
    TRIM_TO_BYTES = 128 * 1024      # trim to last 128KB
    REASM_WINDOW = 64 * 1024        # analyze last 64KB window
    FLOW_IDLE_TIMEOUT = timedelta(seconds=45)  # flows idle longer than this are evicted
    
    def __init__(self, debug_full_traffic: bool = False):
        """Create an idle capture service.

        Args:
            debug_full_traffic: stored on the instance and recorded in the
                capture-log header.

        Side effect: opens a new timestamped capture-log file via
        ``_setup_log_file``; failures there only disable file logging.
        """
        # --- threads and lifecycle ---
        self._capture_thread: Optional[threading.Thread] = None  # start_capture() worker
        self._capture_threads: List[threading.Thread] = []       # per-interface workers
        self._stop_capture = threading.Event()
        self._captured_streams: List[StreamInfo] = []
        self._lock = threading.Lock()

        # --- user callbacks (installed by the start_* methods) ---
        self._on_stream_captured: Optional[Callable[[StreamInfo], None]] = None
        self._on_room_id_captured: Optional[Callable[[str, str], None]] = None
        self._on_error: Optional[Callable[[str], None]] = None
        self._on_log: Optional[Callable[[str], None]] = None

        # --- state and statistics ---
        self._is_capturing = False
        self._selected_interface: Optional[str] = None
        self._debug_full_traffic = debug_full_traffic
        self._total_packets_received = 0

        # --- TCP reassembly (flow key -> state) and periodic-cleanup counter ---
        self._flow_buffers: Dict[str, ReassemblyState] = {}
        self._packet_counter = 0

        # tcUrl cache: the tcUrl of each flow, taken from its connect command
        self._tc_url_by_flow: Dict[str, str] = {}

        # De-duplication: last "url|stream_key" captured per platform
        self._last_key_by_platform: Dict[str, str] = {}

        # URLs already processed (avoid repeats)
        self._processed_urls: set = set()

        # --- log file: set up last, _setup_log_file reads fields set above ---
        self._log_file = None
        self._log_file_path = None
        self._setup_log_file()
    
    def _get_log_file_path(self) -> Path:
        """Return a new timestamped log path under ``<base>/logs/packet_capture``.

        ``<base>`` is the executable's directory when running frozen
        (``sys.frozen`` set, e.g. a PyInstaller-style bundle - verify),
        otherwise three directories above this source file. The log directory
        is created if missing.
        """
        import sys
        if getattr(sys, 'frozen', False):
            root = Path(sys.executable).parent
        else:
            root = Path(__file__).parent.parent.parent
        target_dir = root / "logs" / "packet_capture"
        target_dir.mkdir(parents=True, exist_ok=True)
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return target_dir / f"capture_{stamp}.log"
    
    def _setup_log_file(self):
        """Open this session's capture log and write its header.

        Best-effort: on any failure the error is printed and file logging is
        disabled (``self._log_file`` is None). A handle that was already
        opened when the failure happened is closed instead of being leaked.
        """
        import sys
        log_file = None
        try:
            self._log_file_path = self._get_log_file_path()
            # buffering=1 -> line-buffered text file, each line is flushed
            log_file = open(self._log_file_path, 'w', encoding='utf-8', buffering=1)
            self._log_file = log_file
            print(f"[PacketCapture] 日志文件: {self._log_file_path}")

            header = [
                f"{'=' * 80}\n",
                f"DouyinLiveAssistant 网络抓包日志\n",
                f"{'=' * 80}\n",
                f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n",
                f"日志文件: {self._log_file_path}\n",
                f"调试模式: {self._debug_full_traffic}\n",
                f"端口: {self.RTMP_PORTS}\n",
                f"sys.frozen: {getattr(sys, 'frozen', False)}\n",
                f"scapy可用: {SCAPY_AVAILABLE}\n",
                f"Npcap已安装: {self.is_npcap_installed()}\n",
                f"{'=' * 80}\n\n",
            ]
            log_file.write("".join(header))
        except Exception as e:
            print(f"[PacketCapture] 创建日志文件失败: {e}")
            if log_file is not None:
                # Previously the handle leaked when the header write failed.
                try:
                    log_file.close()
                except Exception:
                    pass
            self._log_file = None
    
    def _close_log_file(self):
        """关闭日志文件"""
        if self._log_file:
            try:
                self._log_file.write(f"\n{'=' * 60}\n")
                self._log_file.write(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                self._log_file.write(f"总数据包: {self._total_packets_received}\n")
                self._log_file.write(f"捕获结果: {len(self._captured_streams)} 个推流码\n")
                self._log_file.write(f"{'=' * 60}\n")
                self._log_file.close()
            except:
                pass
            self._log_file = None
    
    def _log(self, message: str):
        """日志输出 - 写入日志文件、控制台和回调"""
        timestamp = datetime.now().strftime('%H:%M:%S.%f')[:-3]
        full_msg = f"[{timestamp}] {message}"

        # 1. 写入日志文件
        if self._log_file:
            try:
                self._log_file.write(full_msg + "\n")
                self._log_file.flush()
            except:
                pass

        # 2. 输出到控制台（调试用）
        print(f"[PacketCapture] {message}")

        # 3. 回调到界面
        if self._on_log:
            try:
                self._on_log(message)
            except Exception as e:
                print(f"[PacketCapture] 日志回调失败: {e}")
        
    @property
    def is_capturing(self) -> bool:
        return self._is_capturing
    
    @staticmethod
    def is_npcap_installed() -> bool:
        """Return True if Npcap appears to be installed (Windows paths/registry).

        Checks, in order, stopping at the first hit:
        1. wpcap.dll under System32\\Npcap or SysWOW64\\Npcap,
        2. the default install directory C:\\Program Files\\Npcap,
        3. the HKLM\\SOFTWARE\\Npcap registry key (64-bit view).
        Any registry error (including winreg being unavailable) counts as
        "not found" for that check.
        """
        import os

        logger.info("[Npcap检测] 开始...")

        # 1) DLL files - quickest and most reliable
        for dll_path in (
            r"C:\Windows\System32\Npcap\wpcap.dll",
            r"C:\Windows\SysWOW64\Npcap\wpcap.dll",
        ):
            logger.debug(f"[Npcap检测] 检查: {dll_path}")
            if os.path.exists(dll_path):
                logger.info(f"[Npcap检测] ✓ 找到DLL: {dll_path}")
                return True

        # 2) installation directory
        install_dir = r"C:\Program Files\Npcap"
        logger.debug(f"[Npcap检测] 检查目录: {install_dir}")
        if os.path.isdir(install_dir):
            logger.info(f"[Npcap检测] ✓ 找到安装目录: {install_dir}")
            return True

        # 3) registry key
        try:
            import winreg
            logger.debug("[Npcap检测] 检查注册表...")
            access_mask = winreg.KEY_READ | winreg.KEY_WOW64_64KEY
            reg_handle = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, r"SOFTWARE\Npcap", 0, access_mask)
            winreg.CloseKey(reg_handle)
            logger.info("[Npcap检测] ✓ 注册表存在")
            return True
        except Exception as reg_error:
            logger.debug(f"[Npcap检测] 注册表查询失败: {reg_error}")

        logger.warning("[Npcap检测] ✗ 未找到Npcap")
        return False
    
    @staticmethod
    def get_network_interfaces() -> List[NetworkInterface]:
        """List physical-looking adapters (wired, Wi-Fi, ...) that have an IPv4 address.

        Adapters whose name or description contains a virtual/filter/tunnel
        marker (case-insensitive) are skipped, as are adapters without a
        dotted, non-169.254 address. Returns [] when scapy is unavailable or
        on any error (which is logged).
        """
        found: List[NetworkInterface] = []
        if not SCAPY_AVAILABLE:
            return found

        # Substrings identifying virtual adapters to ignore
        virtual_markers = (
            'loopback', 'virtual', 'vmware', 'virtualbox', 'hyper-v',
            'wan miniport', 'npcap', 'filter', 'qos', 'scheduler',
            'lightweight', 'ndis', 'zerotier', 'tunnel', 'vpn',
            'teredo', 'isatap', '6to4', 'pseudo',
        )

        try:
            for entry in get_windows_if_list():
                adapter_name = entry.get('name', '')
                adapter_desc = entry.get('description', '')
                if not adapter_name:
                    continue

                haystacks = (adapter_desc.lower(), adapter_name.lower())
                if any(marker in text for marker in virtual_markers for text in haystacks):
                    continue

                ipv4 = next(
                    (addr for addr in entry.get('ips', [])
                     if isinstance(addr, str) and '.' in addr and not addr.startswith('169.254')),
                    '',
                )
                if not ipv4:
                    continue

                found.append(NetworkInterface(
                    name=adapter_name,
                    description=adapter_desc,
                    ip=ipv4,
                    mac=entry.get('mac', ''),
                    is_up=True,
                ))

            logger.info(f"检测到 {len(found)} 个有效网络接口")
            for item in found:
                logger.debug(f"  - {item.description} ({item.ip})")
        except Exception as e:
            logger.error(f"获取网络接口失败: {e}")

        return found
    
    def start_capture(self,
                      interface: Optional[str] = None,
                      on_stream_captured: Optional[Callable[[StreamInfo], None]] = None,
                      on_room_id_captured: Optional[Callable[[str, str], None]] = None,
                      on_error: Optional[Callable[[str], None]] = None,
                      on_log: Optional[Callable[[str], None]] = None) -> bool:
        """
        Start capturing in a background thread running ``_capture_loop``.

        Args:
            interface: interface name; None means all interfaces.
            on_stream_captured: called with each captured StreamInfo.
            on_room_id_captured: called with (room_id, web_rid).
            on_error: called with an error message.
            on_log: called with every log line.

        Returns:
            True if the capture thread was started; False if a capture is
            already running or scapy/Npcap is unavailable (``on_error`` is
            called with the reason in those two cases).
        """
        if self._is_capturing:
            logger.warning("抓包已在进行中")
            return False
        
        if not SCAPY_AVAILABLE:
            error_msg = "scapy未安装，请执行: pip install scapy"
            logger.error(error_msg)
            if on_error:
                on_error(error_msg)
            return False
        
        if not self.is_npcap_installed():
            error_msg = "Npcap未安装，请先安装Npcap"
            logger.error(error_msg)
            if on_error:
                on_error(error_msg)
            return False
        
        self._on_stream_captured = on_stream_captured
        self._on_room_id_captured = on_room_id_captured
        self._on_error = on_error
        self._on_log = on_log
        self._selected_interface = interface
        self._stop_capture.clear()
        self._processed_urls.clear()
        self._flow_buffers.clear()
        self._tc_url_by_flow.clear()
        self._last_key_by_platform.clear()

        # Log the session header before the worker can emit its own lines.
        self._log(f"========== PacketCapture 初始化 ==========")
        self._log(f"开始网络抓包, 接口: {interface or '所有'}")
        self._log(f"监听端口: {', '.join(map(str, sorted(self.RTMP_PORTS)))}")
        self._log(f"=========================================")

        # Set the flag BEFORE the worker exists: _capture_loop clears it in its
        # finally block, and setting it afterwards could overwrite that reset
        # when the worker fails immediately, leaving a stale "capturing" state.
        self._is_capturing = True
        self._capture_thread = threading.Thread(target=self._capture_loop, daemon=True)
        self._capture_thread.start()
        return True
    
    def stop_capture(self):
        """Signal every capture thread to stop, wait briefly for them, then close the log.

        Joins are bounded (3s for the single-interface worker, 2s per
        multi-interface worker), so threads may still be alive on return.
        """
        self._log("停止抓包服务...")
        self._stop_capture.set()
        self._is_capturing = False

        # worker started by start_capture()
        single_worker = self._capture_thread
        if single_worker:
            single_worker.join(timeout=3)

        # workers started by start_capture_all_interfaces()
        for worker in self._capture_threads:
            worker.join(timeout=2)
        self._capture_threads.clear()

        summary = f"抓包统计: 总包数={self._total_packets_received}, 捕获推流码={len(self._captured_streams)}"
        self._log(summary)
        self._close_log_file()
        logger.info("网络抓包已停止")
    
    def start_capture_all_interfaces(self,
                                     on_stream_captured: Callable = None,
                                     on_log: Callable = None,
                                     on_error: Callable = None) -> bool:
        """
        Start capturing on every usable network adapter.

        One ``_capture_loop_interface`` thread is started per adapter returned
        by ``_get_all_interfaces_with_log``. If no adapter is found, or none
        could be started, a single global ``_capture_loop`` thread is used.

        Args:
            on_stream_captured: called with each captured stream.
            on_log: called with every log line, including the startup
                diagnostics.
            on_error: error callback.

        Returns:
            True if capturing was started; False if a capture is already
            running or scapy/Npcap is unavailable.
        """
        if self._is_capturing:
            self._log("已经在捕获中")
            return False

        # Install callbacks first so that the startup diagnostics below reach
        # the caller's on_log (previously they were logged before it was set).
        self._on_stream_captured = on_stream_captured
        self._on_log = on_log
        self._on_error = on_error

        self._log(f"\n{'=' * 80}")
        self._log(f"启动抓包服务 (监听所有网卡)")
        self._log(f"{'=' * 80}")
        npcap_installed = self._log_capture_environment()

        if not SCAPY_AVAILABLE:
            self._log("❌ scapy 未安装，请执行: pip install scapy")
            return False

        if not npcap_installed:
            self._log("❌ Npcap 未安装，请访问 https://npcap.com/ 下载安装")
            return False

        self._stop_capture.clear()
        self._processed_urls.clear()  # same per-session reset as start_capture()
        self._flow_buffers.clear()
        self._tc_url_by_flow.clear()
        self._last_key_by_platform.clear()

        # Get the network interfaces and log their details
        interfaces = self._get_all_interfaces_with_log()

        # Mark as running before any worker starts: _capture_loop resets the
        # flag in its finally block, and that reset must not be overwritten.
        self._is_capturing = True

        if not interfaces:
            self._log("未检测到有效网卡，使用全局监听模式")
            self._spawn_capture_thread(self._capture_loop)
            return True

        success_count = 0
        for iface in interfaces:
            name = iface.get('npf_name') or iface.get('name', '')
            if not name:
                continue
            try:
                self._spawn_capture_thread(self._capture_loop_interface, name)
            except Exception as e:
                self._log(f"⚠ 启动失败: {iface}: {e}")
                continue
            success_count += 1
            self._log(f"✓ 启动: {iface.get('description', name)} ({iface.get('ip', '')})")

        if success_count == 0:
            self._log("所有网卡启动失败，使用全局监听")
            self._spawn_capture_thread(self._capture_loop)

        self._log(f"抓包服务启动完成，监听 {success_count} 个接口")
        self._log(f"端口: {self.RTMP_PORTS}")
        return True

    def _log_capture_environment(self) -> bool:
        """Log Python/OS/dependency/admin/Npcap diagnostics; return whether Npcap is installed."""
        # sys is not imported at module level; using it unimported raised NameError.
        import sys

        self._log(f"\n[诊断] 环境检测:")
        self._log(f"  - Python版本: {sys.version}")
        self._log(f"  - 操作系统: {sys.platform}")
        self._log(f"  - scapy可用: {SCAPY_AVAILABLE}")
        self._log(f"  - psutil可用: {PSUTIL_AVAILABLE}")

        # Administrator check: ctypes.windll exists only on Windows, so any
        # failure is reported as "cannot detect".
        try:
            import ctypes
            is_admin = ctypes.windll.shell32.IsUserAnAdmin() != 0
            self._log(f"  - 管理员权限: {'✓ 是' if is_admin else '✗ 否 (建议以管理员身份运行)'}")
        except Exception:
            self._log(f"  - 管理员权限: 无法检测")

        npcap_installed = self.is_npcap_installed()
        self._log(f"  - Npcap已安装: {'✓ 是' if npcap_installed else '✗ 否'}")
        if not npcap_installed:
            self._log(f"  [!] 请下载安装Npcap: https://npcap.com/")
            self._log(f"  [!] 安装时请勾选 'Install Npcap in WinPcap API-compatible Mode'")
        return npcap_installed

    def _spawn_capture_thread(self, target, *args) -> threading.Thread:
        """Start a daemon thread running ``target(*args)`` and track it in ``_capture_threads``."""
        thread = threading.Thread(target=target, args=args, daemon=True)
        thread.start()
        self._capture_threads.append(thread)
        return thread
    
    def _get_all_interfaces_with_log(self) -> List[dict]:
        """Enumerate usable adapters, logging diagnostics about every adapter seen.

        Each returned dict has keys 'name', 'npf_name', 'description', 'ip'
        (first dotted, non-169.254 address) and 'guid'. 'npf_name' is the
        NPF device whose name contains the adapter GUID (case-insensitive), or
        the adapter name when no such device is found. Virtual adapters and
        adapters without such an address are skipped. Errors are logged with a
        traceback and whatever was collected so far is returned.
        """
        selected: List[dict] = []

        if not SCAPY_AVAILABLE:
            self._log("[诊断] scapy不可用，无法获取网络接口")
            return selected

        try:
            win_ifaces = get_windows_if_list()
            npf_names = get_if_list()

            self._log(f"\n[诊断] 检测到 {len(win_ifaces)} 个Windows接口, {len(npf_names)} 个NPF设备")

            # Key diagnostic: zero NPF devices usually means a broken Npcap setup
            if npf_names:
                more = '...' if len(npf_names) > 5 else ''
                self._log(f"[诊断] NPF设备列表: {npf_names[:5]}{more}")
            else:
                for hint in (
                    "[!] ⚠️ NPF设备数量为0，这通常意味着:",
                    "    1. Npcap未正确安装",
                    "    2. Npcap服务未启动 (尝试重启电脑)",
                    "    3. 安装时未勾选 'WinPcap API-compatible Mode'",
                    "    请重新安装Npcap: https://npcap.com/",
                ):
                    self._log(hint)

            # Dump every adapter before filtering
            for entry in win_ifaces:
                desc = entry.get('description', '')
                ips = entry.get('ips', [])
                guid = entry.get('guid', '')
                self._log(f"  接口: {desc}, IP={ips}, GUID={guid[:20] if guid else 'N/A'}...")

            skip_keywords = [
                'loopback', 'virtual', 'vmware', 'virtualbox', 'hyper-v',
                'wan miniport', 'npcap', 'qos', 'scheduler', 'lightweight',
                'zerotier', 'tunnel', 'vpn', 'teredo', 'isatap', '6to4'
            ]

            for entry in win_ifaces:
                name = entry.get('name', '')
                desc = entry.get('description', '')
                ips = entry.get('ips', [])
                guid = entry.get('guid', '')

                if not name:
                    continue

                if any(kw in desc.lower() or kw in name.lower() for kw in skip_keywords):
                    self._log(f"  跳过虚拟网卡: {desc}")
                    continue

                ipv4_list = [ip for ip in ips if isinstance(ip, str) and '.' in ip and not ip.startswith('169.254')]
                if not ipv4_list:
                    self._log(f"  跳过无IP: {desc}")
                    continue

                # Map the adapter to its NPF device via the GUID
                matched_npf = None
                if guid:
                    matched_npf = next((npf for npf in npf_names if guid.lower() in npf.lower()), None)
                npf_name = matched_npf or name

                selected.append({
                    'name': name,
                    'npf_name': npf_name,
                    'description': desc,
                    'ip': ipv4_list[0],
                    'guid': guid
                })
                self._log(f"  ✓ 添加: {desc} ({ipv4_list[0]})")

            self._log(f"有效接口: {len(selected)} 个\n")

        except Exception as e:
            self._log(f"获取接口失败: {e}")
            import traceback
            self._log(traceback.format_exc())

        return selected
    
    def _capture_loop_interface(self, interface: str):
        """Sniff RTMP-port TCP traffic on one interface until stop is requested.

        ``sniff`` runs in 2-second slices so the stop event is re-checked
        regularly. An error whose text mentions "timeout" is ignored; any
        other error is logged and ends this interface's loop.
        """
        label = interface[:40] if interface else "全局"
        self._log(f"[{label}] 开始捕获")

        try:
            port_clause = ' or '.join(f'port {p}' for p in self.RTMP_PORTS)
            bpf = f'tcp and ({port_clause})'
            should_stop = self._stop_capture.is_set

            while not should_stop():
                try:
                    sniff(
                        iface=interface,
                        filter=bpf,
                        prn=self._process_packet,
                        stop_filter=lambda _pkt: should_stop(),
                        store=False,
                        timeout=2
                    )
                except Exception as e:  # OSError included; only the label differs
                    if 'timeout' in str(e).lower():
                        continue
                    kind = "错误" if isinstance(e, OSError) else "异常"
                    self._log(f"[{label}] {kind}: {e}")
                    break
        except Exception as e:
            self._log(f"[{label}] 捕获循环异常: {e}")

        self._log(f"[{label}] 捕获结束")
    
    def _capture_loop(self):
        """
        Main capture loop: sniff RTMP-port TCP plus all UDP until stop is requested.

        ``sniff`` runs in 2-second slices so the stop event is re-checked.
        If the selected interface is reported missing, the loop falls back to
        ``iface=None`` once; errors mentioning "timeout" are ignored; any other
        error ends the loop, is logged, and is passed to ``on_error``. The
        capturing flag is always cleared on exit.
        """
        iface_display = self._selected_interface if self._selected_interface else "所有接口"
        try:
            # TCP restricted to RTMP ports; UDP unrestricted (dynamic ports)
            port_clause = ' or '.join(f'port {p}' for p in self.RTMP_PORTS)
            bpf_filter = f'(tcp and ({port_clause})) or udp'
            self._log(f"过滤器: {bpf_filter}")
            self._log(f"接口: {iface_display}")

            while not self._stop_capture.is_set():
                try:
                    # iface=None lets scapy listen on all available interfaces
                    sniff(
                        iface=self._selected_interface,
                        filter=bpf_filter,
                        prn=self._process_packet,
                        stop_filter=lambda x: self._stop_capture.is_set(),
                        store=False,
                        timeout=2
                    )
                except OSError as e:
                    text = str(e).lower()
                    device_missing = 'no such device' in text or 'not found' in text
                    if device_missing:
                        self._log(f"⚠ 接口 {iface_display} 不可用，尝试使用全局模式...")
                        if self._selected_interface is None:
                            raise
                        # fall back to all interfaces and retry on the next iteration
                        self._selected_interface = None
                        self._log("切换到全局监听模式")
                    elif "timeout" not in text:
                        raise
                except Exception as e:
                    if "timeout" not in str(e).lower():
                        raise
        except Exception as e:
            error_msg = f"抓包异常 ({iface_display}): {e}"
            logger.error(error_msg)
            if self._on_error:
                self._on_error(error_msg)
        finally:
            self._is_capturing = False
    
    def _get_flow_key(self, packet) -> str:
        """获取 TCP 流的唯一标识"""
        try:
            if packet.haslayer(IP) and packet.haslayer(TCP):
                ip = packet[IP]
                tcp = packet[TCP]
                return f"{ip.src}:{tcp.sport}->{ip.dst}:{tcp.dport}"
        except:
            pass
        return ""
    
    def _append_to_flow(self, key: str, data: bytes) -> ReassemblyState:
        """Append *data* to flow *key*'s reassembly buffer and return the flow state.

        Bytes are appended in arrival order (no TCP sequence reordering).
        When the buffer exceeds MAX_BUFFER_BYTES it is cut down to its last
        TRIM_TO_BYTES bytes.
        """
        # Single lookup: per-interface capture threads share _flow_buffers, and
        # _cleanup_old_flows may evict the key between a membership test and a
        # second lookup, which used to raise KeyError.
        state = self._flow_buffers.get(key)
        if state is None:
            state = ReassemblyState()
            self._flow_buffers[key] = state
        state.buffer.extend(data)
        state.last_seen = datetime.now()
        # Over the cap: keep only the most recent TRIM_TO_BYTES bytes
        if len(state.buffer) > self.MAX_BUFFER_BYTES:
            del state.buffer[:len(state.buffer) - self.TRIM_TO_BYTES]
        return state
    
    def _cleanup_old_flows(self):
        """清理超时的 TCP 流"""
        now = datetime.now()
        keys_to_remove = [k for k, s in self._flow_buffers.items() 
                         if (now - s.last_seen) > self.FLOW_IDLE_TIMEOUT]
        for key in keys_to_remove:
            del self._flow_buffers[key]
    
    def _is_duplicate(self, platform_id: str, url: str, stream_key: str) -> bool:
        """去重检查"""
        unique_key = f"{url}|{stream_key}"
        pid = platform_id.lower()
        if self._last_key_by_platform.get(pid) == unique_key:
            return True
        self._last_key_by_platform[pid] = unique_key
        return False
    
    def _extract_parameters(self, url: str) -> Dict[str, str]:
        """提取URL参数"""
        params = {}
        try:
            if '?' in url:
                query = url.split('?', 1)[1]
                for pair in query.split('&'):
                    if '=' in pair:
                        k, v = pair.split('=', 1)
                        params[k] = v
        except:
            pass
        return params
    
    def _process_packet(self, packet):
        """Callback for scapy ``sniff(prn=...)``: route one captured packet.

        UDP packets with a payload go to ``_process_udp_packet``. TCP packets
        with a non-empty payload where either port is in ``RTMP_PORTS`` are
        analysed on their own and then through the per-flow reassembly window.
        Exceptions are swallowed and logged at debug level so one bad packet
        never stops sniffing.
        """
        try:
            self._total_packets_received += 1
            if not self._total_packets_received % 500:
                self._log(f"📦 已处理 {self._total_packets_received} 个数据包")

            # UDP path
            if packet.haslayer(UDP) and packet.haslayer(Raw):
                self._process_udp_packet(packet)
                return

            # TCP path: need both a TCP header and a payload
            if not (packet.haslayer(TCP) and packet.haslayer(Raw)):
                return

            tcp = packet[TCP]
            raw_data = bytes(packet[Raw].load)
            if not raw_data:
                return

            ports = self.RTMP_PORTS
            if not (tcp.sport in ports or tcp.dport in ports):
                return

            flow_key = self._get_flow_key(packet)
            # latin-1 maps each byte to one code point: a lossless text view
            payload = raw_data.decode('latin-1', errors='ignore')

            lowered = payload.lower()
            if any(marker in lowered for marker in ('rtmp', 'publish', 'connect')):
                self._log(f"🎯 检测到RTMP特征 端口:{tcp.sport}->{tcp.dport} 长度:{len(raw_data)}")

            # single-packet analysis
            self._analyze_payload(payload, raw_data, tcp, flow_key)

            # reassembly-window analysis
            if not flow_key:
                return
            state = self._append_to_flow(flow_key, raw_data)
            self._packet_counter += 1
            if self._packet_counter % 200 == 0:
                self._cleanup_old_flows()
            self._analyze_reassembled_window(flow_key, state, tcp)

        except Exception as e:
            logger.debug(f"处理数据包异常: {e}")
    
    def _process_udp_packet(self, packet):
        """Inspect a UDP payload for a Douyin push stream (fallback path).

        Payloads shorter than 10 characters, or mentioning neither "douyin"
        nor "publish" (case-insensitive), are ignored. Otherwise the Douyin
        config is passed to ``_extract_rtmp_douyin`` and any result is
        emitted. Exceptions are logged at debug level.
        """
        try:
            raw_data = bytes(packet[Raw].load)
            payload = raw_data.decode('latin-1', errors='ignore')
            if len(payload) < 10:
                return

            lowered = payload.lower()
            if 'douyin' not in lowered and 'publish' not in lowered:
                return

            config = PlatformMatcher.PLATFORM_CONFIGS.get(Platform.DOUYIN)
            if not config:
                return
            result = self._extract_rtmp_douyin(payload, raw_data, config)
            if result:
                self._emit_result(result)
        except Exception as e:
            logger.debug(f"处理UDP包异常: {e}")
    
    def _analyze_reassembled_window(self, flow_key: str, state: ReassemblyState, tcp):
        """Re-run payload analysis on the tail of a flow's reassembly buffer.

        Takes at most ``self.REASM_WINDOW`` trailing bytes of ``state.buffer``,
        decodes them as Latin-1 and forwards them to ``_analyze_payload`` only
        if they contain one of the (case-sensitive) markers 'publish',
        'connect', 'tcUrl' or 'rtmp://'.
        """
        buf = state.buffer
        buffered_len = len(buf)
        if not buffered_len:
            return

        tail = bytes(buf[-min(self.REASM_WINDOW, buffered_len):])
        text = tail.decode('latin-1', errors='ignore')

        # Cheap pre-filter so the full analysis only runs on promising windows.
        markers = ('publish', 'connect', 'tcUrl', 'rtmp://')
        if any(marker in text for marker in markers):
            self._analyze_payload(text, tail, tcp, flow_key)
    
    def _analyze_payload(self, payload: str, raw_bytes: bytes, tcp, flow_key: str = ""):
        """Try to extract a push stream from one payload (packet or reassembled window).

        Only payloads containing an RTMP command verb are parsed. The verbs are
        "publish" (which also covers "FCPublish") and "connect", matched
        case-insensitively. Each configured platform is tried in order; the
        first successful result is emitted and analysis stops.

        Args:
            payload: text form of ``raw_bytes`` (callers decode it with Latin-1).
            raw_bytes: the raw bytes, used for byte-level AMF0 parsing.
            tcp: the packet's TCP layer; not used by this method.
            flow_key: flow identifier, used downstream to cache tcUrl values.
        """
        # The verb check is platform-independent, so do it once. Keyword or
        # domain matches without a verb never lead to parsing, so they are not
        # computed at all (this runs for every captured packet).
        lowered = payload.lower()
        if 'publish' not in lowered and 'connect' not in lowered:
            return

        for platform, config in PlatformMatcher.PLATFORM_CONFIGS.items():
            result = self._extract_rtmp_publish(payload, raw_bytes, platform, config, flow_key)
            if result:
                self._emit_result(result)
                return
    
    def _emit_result(self, stream_info: StreamInfo):
        """Record a captured stream, log it, and notify the capture callback.

        The stream is appended to ``_captured_streams`` under ``_lock``. Keys
        longer than 80 characters are truncated in the log. The callback is
        invoked outside the lock, and its exceptions propagate to the caller.
        """
        with self._lock:
            self._captured_streams.append(stream_info)

        key = stream_info.stream_key
        self._log(f"✅ 捕获到推流码: {stream_info.platform.value}")
        self._log(f"  服务器: {stream_info.server_url}")
        if len(key) > 80:
            self._log(f"  密钥: {key[:80]}...")
        else:
            self._log(f"  密钥: {key}")

        callback = self._on_stream_captured
        if callback:
            callback(stream_info)
    
    def _extract_rtmp_publish(self, payload: str, raw_bytes: bytes, platform: Platform, 
                              config: PlatformConfig, flow_key: str = "") -> Optional[StreamInfo]:
        """Handle an RTMP connect/publish payload and dispatch to the platform parser.

        If the payload looks like a ``connect`` carrying a tcUrl and ``flow_key``
        is set, the first rtmp(s) URL found is cached in
        ``self._tc_url_by_flow[flow_key]``. Publish-like payloads qualify when
        they contain publish/FCPublish, releaseStream or "streamName". Such
        payloads are handed to the Douyin parser when ``platform`` is DOUYIN.

        Returns:
            The parsed StreamInfo, or None. None covers non-RTMP payloads,
            connect-only payloads, non-Douyin platforms and parse failures, as
            well as any exception (which is logged).
        """
        try:
            lowered = payload.lower()
            is_connect = 'connect' in lowered and 'tcurl' in lowered
            # 'publish' also matches 'fcpublish'.
            is_publish = any(token in lowered for token in ('publish', 'releasestream', 'streamname'))

            if not is_connect and not is_publish:
                return None

            # Cache the tcUrl of a connect command for this flow.
            if is_connect and flow_key:
                # Optional ":port": tcUrl values commonly include the RTMP port (e.g. :1935),
                # which the host character class alone would reject.
                tc_url_match = re.search(r'(rtmps?://[a-z0-9.\-]+(?::\d+)?/[a-z0-9_\-]+)', payload, re.IGNORECASE)
                if tc_url_match:
                    tc_url = tc_url_match.group(1)
                    self._tc_url_by_flow[flow_key] = tc_url
                    self._log(f"[Flow {flow_key}] 存储tcUrl: {tc_url}")

            # A pure connect payload has nothing more to parse.
            if not is_publish:
                return None

            # Only the Douyin platform has a parser.
            if platform == Platform.DOUYIN:
                return self._extract_rtmp_douyin(payload, raw_bytes, config)
            return None
        except Exception as e:
            self._log(f"❌ RTMP解析异常: {e}")
            return None
    
    def _analyze_data(self, data: str):
        """Extract push-stream / room information from text (legacy interface).

        A stream URL not seen before is remembered in ``_processed_urls``. It
        is then parsed, enriched with room ids and emitted, and the room-id
        callback fires if a room id was found. Otherwise, for a recognised
        platform, a room id that no captured stream already carries is logged
        and reported through the room-id callback.
        """
        platform = PlatformMatcher.detect_platform(data)
        stream_url = PlatformMatcher.extract_stream_url(data, platform)

        if stream_url and stream_url not in self._processed_urls:
            self._processed_urls.add(stream_url)
            info = self._parse_stream_url(stream_url, platform)
            if not (info and info.is_valid):
                return

            room_id, web_rid = PlatformMatcher.extract_room_id(data, platform)
            info.room_id, info.web_rid = room_id, web_rid
            self._emit_result(info)

            if room_id and self._on_room_id_captured:
                self._on_room_id_captured(room_id, web_rid)
            return

        if platform == Platform.UNKNOWN:
            return

        room_id, web_rid = PlatformMatcher.extract_room_id(data, platform)
        if not room_id or room_id in [s.room_id for s in self._captured_streams]:
            return

        self._log(f"捕获到直播间ID: room_id={room_id}, web_rid={web_rid}")
        if self._on_room_id_captured:
            self._on_room_id_captured(room_id, web_rid)
    
    def _parse_stream_url(self, url: str, platform: Platform) -> Optional[StreamInfo]:
        """Split an RTMP(S) push URL into server URL and stream key.

        ``rtmp://host/app/key?auth=x`` becomes server ``rtmp://host/app`` and
        key ``key?auth=x``. A query with no key path becomes the key on its own.
        The scheme is matched case-insensitively and normalised to lower case.

        Returns:
            A StreamInfo, or None when the scheme is not rtmp/rtmps, when there
            is no app segment, or on an unexpected error (logged at debug). The
            stream key may be empty (e.g. ``rtmp://host/app``), so callers
            should check ``is_valid``.
        """
        try:
            scheme, sep, remainder = url.partition('://')
            scheme = scheme.lower()
            # URL schemes are case-insensitive (RFC 3986 §3.1).
            if not sep or scheme not in ('rtmp', 'rtmps'):
                return None
            protocol = f"{scheme}://"

            # Split off the query string (everything after the first '?').
            path_part, _, query_part = remainder.partition('?')

            # host / app / stream-key path segments
            parts = path_part.split('/')
            if len(parts) < 2:
                return None

            server, app_name = parts[0], parts[1]
            stream_key = '/'.join(parts[2:])

            if query_part:
                stream_key = f"{stream_key}?{query_part}" if stream_key else query_part

            server_url = f"{protocol}{server}/{app_name}"

            return StreamInfo(
                platform=platform,
                server_url=server_url,
                stream_key=stream_key,
                full_url=url
            )

        except Exception as e:
            logger.debug(f"解析推流URL失败: {e}")
            return None
    
    # ========== 各平台专用解析器（移植自 stream_capture）==========
    
    def _extract_rtmp_douyin(self, payload: str, raw_bytes: bytes, config: PlatformConfig) -> Optional[StreamInfo]:
        """Douyin-specific RTMP (AMF0) parser.

        Extracts the push server URL (tcUrl) and the stream key from an RTMP
        connect/publish payload.

        Args:
            payload: ``raw_bytes`` decoded as Latin-1 (as the callers in this
                file do), so character offsets equal byte offsets.
            raw_bytes: the raw payload bytes.
            config: Douyin platform config; one of its ``domain_patterns`` must
                match the extracted base URL.

        Returns:
            A StreamInfo, or None. None is returned when the base URL or stream
            key is missing, when the base URL does not match the configured
            domains, or when ``_is_duplicate`` reports the pair as seen.
        """
        self._log("[抖音] 使用专用解析器")

        # 1) Base URL: prefer an rtmp(s) URL on a known Douyin domain. The match
        #    ends at a NUL byte (typically the high byte of the next AMF0 length).
        base_url = ""
        url_match = re.search(r'(rtmps?://[^\x00]*?(push\.live\.douyin|live\.douyin|douyin|douyincdn|douyinpic)\.com/[^\x00]*?)[\x00]', payload, re.IGNORECASE)
        if url_match:
            base_url = url_match.group(1)
        else:
            # Fallback: the value after the "tcUrl" key. In AMF0 it is preceded
            # by a 0x02 marker and a 2-byte length whose low byte is printable
            # for lengths 33..126. Anchoring on the scheme keeps that byte out
            # of the URL (previously it produced e.g. "(rtmp://...").
            tc_match = re.search(r'tcUrl[\s\S]{0,8}?(rtmps?://[^\x00]+?)\x00', payload, re.IGNORECASE)
            if tc_match:
                base_url = tc_match.group(1)
        # Keep printable ASCII only (drops control characters and high Latin-1 bytes).
        base_url = re.sub(r'[^\x20-\x7E]', '', base_url)

        # The base URL must belong to one of the configured Douyin domains.
        base_url_matches = bool(base_url) and any(
            re.search(p, base_url, re.IGNORECASE) for p in config.domain_patterns
        )

        # 2) AMF0 scan for the stream key: after the command name comes a null
        #    (0x05) followed by a string marker (0x02), a big-endian u16 length
        #    and the stream name.
        stream_key = ""
        publish_idx = payload.lower().find('publish')
        if publish_idx >= 0:
            # Need 4 bytes at i: 0x05, 0x02 and the 2-byte length.
            limit = len(raw_bytes) - 4
            scan_start = min(publish_idx + len('publish'), limit)
            scan_end = min(limit, scan_start + 4096)
            for i in range(scan_start, scan_end):
                if raw_bytes[i] != 0x05 or raw_bytes[i + 1] != 0x02:
                    continue
                length = (raw_bytes[i + 2] << 8) | raw_bytes[i + 3]
                key_start = i + 4
                if length <= 0 or key_start + length > len(raw_bytes):
                    continue
                candidate = raw_bytes[key_start:key_start + length].decode('utf-8', errors='ignore')
                candidate = re.sub(r'[\x00-\x1F\x7F-\x9F]', '', candidate)
                # Very short strings are unlikely to be a real stream key; keep scanning.
                if len(candidate) >= 4:
                    stream_key = re.sub(r'[^\x20-\x7E]', '', candidate)
                    break

        # Regex fallback for the same 0x05 0x02 <len> <key> layout.
        if not stream_key:
            m = re.search(r'publish.{6,64}\x05\x02..([\x20-\x7E]{4,200})', payload, re.DOTALL)
            if m:
                stream_key = re.sub(r'[^\x20-\x7E]', '', m.group(1))

        if not stream_key or not base_url or not base_url_matches:
            return None

        if self._is_duplicate(config.id, base_url, stream_key):
            self._log("[抖音] 去重：忽略重复结果")
            return None

        return StreamInfo(platform=Platform.DOUYIN, server_url=base_url, stream_key=stream_key,
                         full_url=f"{base_url}/{stream_key}", protocol="RTMP",
                         parameters=self._extract_parameters(stream_key))
    
    
    def get_captured_streams(self) -> List[StreamInfo]:
        """Return a snapshot list of every captured stream (copied under the lock)."""
        with self._lock:
            snapshot = [*self._captured_streams]
        return snapshot
    
    def get_latest_stream(self, platform: Optional[Platform] = None) -> Optional[StreamInfo]:
        """Return the most recently captured stream, optionally filtered by platform.

        Returns None when nothing was captured, or when no capture matches
        ``platform``.
        """
        with self._lock:
            history = self._captured_streams
            if not history:
                return None
            if not platform:
                return history[-1]
            return next((item for item in reversed(history) if item.platform == platform), None)
    
    def clear_captured(self):
        """Forget all captured streams and all URLs marked as processed (under the lock)."""
        with self._lock:
            for container in (self._captured_streams, self._processed_urls):
                container.clear()
            self._processed_urls.clear()


class LogWatcherService:
    """
    Log-monitoring service.

    A fallback to packet capture: it tails the newest ``*.log`` file in a
    directory (the live-streaming companion app's logs) and reports room ids
    found in newly appended, complete lines.
    """

    # Patterns locating a numeric room id. When several occurrences appear in
    # the same chunk, the one positioned LAST (the most recent) wins.
    _ROOM_ID_PATTERNS = (
        re.compile(r'room_id["\s:=]+(\d+)'),
        re.compile(r'roomId["\s:=]+(\d+)'),
        re.compile(r'"room_id"\s*:\s*"?(\d+)"?'),
    )

    def __init__(self):
        self._watch_thread: Optional[threading.Thread] = None
        self._stop_watch = threading.Event()
        self._on_room_id: Optional[Callable[[str], None]] = None
        # Byte offset of the first not-yet-consumed byte, per log file path.
        self._last_position: Dict[str, int] = {}

    def start_watch(self, 
                    log_dir: str,
                    on_room_id: Optional[Callable[[str], None]] = None):
        """Start watching ``log_dir`` on a daemon thread.

        Args:
            log_dir: directory whose newest ``*.log`` file is tailed.
            on_room_id: called with each room id found (on the watcher thread).
        """
        # Avoid two watcher threads racing on the same offsets.
        if self._watch_thread is not None and self._watch_thread.is_alive():
            self.stop_watch()

        self._on_room_id = on_room_id
        self._stop_watch.clear()
        
        self._watch_thread = threading.Thread(
            target=self._watch_loop,
            args=(log_dir,),
            daemon=True
        )
        self._watch_thread.start()
        logger.info(f"开始监控日志目录: {log_dir}")
    
    def stop_watch(self):
        """Signal the watcher thread to stop and wait up to 2 seconds for it."""
        self._stop_watch.set()
        if self._watch_thread:
            self._watch_thread.join(timeout=2)
    
    def _watch_loop(self, log_dir: str):
        """Poll loop: read new content from the most recently modified ``*.log`` file."""
        log_path = Path(log_dir)
        
        while not self._stop_watch.is_set():
            try:
                if not log_path.exists():
                    self._stop_watch.wait(timeout=1)
                    continue
                
                # Pick the most recently modified log file.
                log_files = list(log_path.glob("*.log"))
                if not log_files:
                    self._stop_watch.wait(timeout=1)
                    continue
                
                latest_log = max(log_files, key=lambda f: f.stat().st_mtime)
                
                # Read whatever was appended since the last pass.
                self._read_new_content(str(latest_log))
                
            except Exception as e:
                logger.debug(f"日志监控异常: {e}")
            
            self._stop_watch.wait(timeout=0.5)
    
    def _read_new_content(self, log_file: str):
        """Read complete lines appended to ``log_file`` since the last read and analyze them.

        Offsets are byte offsets (binary read). Only data up to the last
        newline is consumed, so a line or UTF-8 character still being written
        is picked up whole on a later pass. If the file shrank below the stored
        offset (truncated or replaced), reading restarts from the beginning.
        """
        try:
            last_pos = self._last_position.get(log_file, 0)
            size = Path(log_file).stat().st_size
            if size < last_pos:
                last_pos = 0

            consumed = b""
            if size > last_pos:
                with open(log_file, 'rb') as f:
                    f.seek(last_pos)
                    chunk = f.read()
                # rfind() == -1 (no complete line yet) yields an empty slice.
                consumed = chunk[:chunk.rfind(b'\n') + 1]

            self._last_position[log_file] = last_pos + len(consumed)

            if consumed:
                self._analyze_log_content(consumed.decode('utf-8', errors='ignore'))
                
        except Exception as e:
            logger.debug(f"读取日志失败: {e}")
    
    @classmethod
    def _find_room_id(cls, content: str) -> Optional[str]:
        """Return the room id that appears last in ``content``, or None if there is none."""
        best_pos = -1
        best_id: Optional[str] = None
        for pattern in cls._ROOM_ID_PATTERNS:
            for match in pattern.finditer(content):
                if match.start(1) > best_pos:
                    best_pos, best_id = match.start(1), match.group(1)
        return best_id

    def _analyze_log_content(self, content: str):
        """Report the most recent room id found in ``content``, if any, via the callback."""
        room_id = self._find_room_id(content)
        if room_id is None:
            return
        logger.info(f"从日志捕获到room_id: {room_id}")
        if self._on_room_id:
            self._on_room_id(room_id)
