# -*- coding: utf-8 -*-
"""
大眼瘦脸变形处理器
基于局部图像变形算法 (Local Image Warping)

原理：
使用反向映射 (Inverse Mapping) 和 cv2.remap
对于每个像素点 (x, y)，计算其在原图中的对应位置 (u, v)

1. 大眼：局部放大
   以眼睛为中心，将周围像素向中心拉伸（反向映射时，屏幕点取原图更靠中心的点）
   公式：u = c + (x - c) * (1 - strength * coeff)

2. 瘦脸：局部缩小/平移
   以脸颊为控制区域，将脸部轮廓向内推（反向映射时，屏幕点取原图更靠外的点）
   实际上是把脸部区域"缩小"，或者把背景区域"拉伸"进来
"""

import cv2
import numpy as np
import math
from typing import Optional, Tuple

try:
    from .face_landmarker import FaceLandmarker, LandmarkIndices
    MEDIAPIPE_AVAILABLE = True
except ImportError:
    try:
        from src.utils.face_landmarker import FaceLandmarker, LandmarkIndices
        MEDIAPIPE_AVAILABLE = True
    except ImportError:
        MEDIAPIPE_AVAILABLE = False


class FaceDeformation:
    """人脸变形处理器（大眼瘦脸）"""
    
    def __init__(self):
        # 变形参数 (0-100)
        self.big_eye_strength = 0    # 大眼强度
        self.face_lift_strength = 0  # 瘦脸强度
        
        self.enabled = True
        
        # 初始化人脸检测
        self.landmarker = None
        if MEDIAPIPE_AVAILABLE:
            self.landmarker = FaceLandmarker()
            
        # 缓存映射表，避免重复申请内存
        self._map_x = None
        self._map_y = None
        self._last_shape = None
        
    def set_params(self, big_eye: int = None, face_lift: int = None):
        """设置参数"""
        if big_eye is not None:
            self.big_eye_strength = max(0, min(100, big_eye))
        if face_lift is not None:
            self.face_lift_strength = max(0, min(100, face_lift))
            
    def process(self, frame: np.ndarray) -> np.ndarray:
        """处理一帧图像"""
        # 如果没有启用或强度都为0，直接返回
        if (not self.enabled or 
            (self.big_eye_strength == 0 and self.face_lift_strength == 0) or 
            self.landmarker is None):
            return frame
            
        # 获取人脸关键点
        landmarks_data = self.landmarker.process(frame)
        if not landmarks_data:
            return frame
            
        h, w = frame.shape[:2]
        
        # 初始化映射表 (如果是第一帧或尺寸变化)
        if self._map_x is None or self._last_shape != (h, w):
            self._map_x, self._map_y = np.meshgrid(np.arange(w), np.arange(h))
            self._map_x = self._map_x.astype(np.float32)
            self._map_y = self._map_y.astype(np.float32)
            self._last_shape = (h, w)
            
        # 重置映射表为初始状态
        # 为了性能，我们不仅是重置，而是在原图坐标基础上叠加偏移量
        # 由于每帧人脸位置不同，我们需要基于每一帧重新计算变形量
        # 直接使用 copy 可能会慢，但比重新 meshgrid 快
        # 优化：只创建一个 base map，每帧复制
        if not hasattr(self, '_base_map_x'):
            self._base_map_x, self._base_map_y = np.meshgrid(np.arange(w), np.arange(h))
            self._base_map_x = self._base_map_x.astype(np.float32)
            self._base_map_y = self._base_map_y.astype(np.float32)
            
        map_x = self._base_map_x.copy()
        map_y = self._base_map_y.copy()
        
        landmarks = landmarks_data.landmarks
        
        # 1. 应用大眼
        if self.big_eye_strength > 0:
            self._apply_big_eye(map_x, map_y, landmarks)
            
        # 2. 应用瘦脸
        if self.face_lift_strength > 0:
            self._apply_face_lift(map_x, map_y, landmarks)
            
        # 执行重映射
        result = cv2.remap(frame, map_x, map_y, interpolation=cv2.INTER_LINEAR)
        
        return result
        
    def _apply_big_eye(self, map_x, map_y, landmarks):
        """应用大眼变形"""
        left_eye_idx = LandmarkIndices.LEFT_EYE_CENTER
        right_eye_idx = LandmarkIndices.RIGHT_EYE_CENTER
        
        left_center = landmarks[left_eye_idx]
        right_center = landmarks[right_eye_idx]
        
        # 计算眼睛半径 (用眼角距离估计)
        # 左眼角: 33, 133
        l_p1 = landmarks[33]
        l_p2 = landmarks[133]
        r_radius = np.linalg.norm(l_p1 - l_p2) * 1.5 # 稍微扩大范围
        
        # 右眼角: 362, 263
        r_p1 = landmarks[362]
        r_p2 = landmarks[263]
        l_radius = np.linalg.norm(r_p1 - r_p2) * 1.5
        
        radius = max(l_radius, r_radius)
        strength = self.big_eye_strength / 100.0 * 0.5  # 缩放系数
        
        self._warp_local_scaling(map_x, map_y, left_center, radius, strength)
        self._warp_local_scaling(map_x, map_y, right_center, radius, strength)
        
    def _apply_face_lift(self, map_x, map_y, landmarks):
        """应用瘦脸变形"""
        # 瘦脸通常是把脸颊往嘴巴/鼻子方向推
        # 简单实现：选取左右下颌点，向中心点变形
        
        left_cheek = landmarks[LandmarkIndices.LEFT_CHEEK]   # 234
        right_cheek = landmarks[LandmarkIndices.RIGHT_CHEEK] # 454
        chin = landmarks[LandmarkIndices.CHIN]               # 152
        
        # 鼻子作为中心参考
        nose = landmarks[LandmarkIndices.NOSE_TIP]           # 1
        
        # 计算变形半径 (脸颊到鼻子的距离的一半)
        l_dist = np.linalg.norm(left_cheek - nose)
        r_dist = np.linalg.norm(right_cheek - nose)
        radius = max(l_dist, r_dist) * 0.8
        
        # 变形强度 (负值表示拉向目标，正值表示推开)
        # 瘦脸：屏幕上的脸颊位置，需要显示原图更靠外的内容 -> 局部缩小
        # 实际上我们希望把脸变窄，意味着背景“入侵”脸部区域
        # 这是一个向内的平移变形
        # 使用 warp_local_translation
        
        strength = self.face_lift_strength / 100.0 * 1.0
        
        # 左脸颊：向右推 (x增加)
        # 右脸颊：向左推 (x减小)
        # 或者直接指向鼻子
        
        self._warp_local_translation(map_x, map_y, left_cheek, nose, radius, strength)
        self._warp_local_translation(map_x, map_y, right_cheek, nose, radius, strength)
        
    def _warp_local_scaling(self, map_x, map_y, center, radius, strength):
        """
        局部缩放变形 (大眼)
        center: 变形中心
        radius: 变形半径
        strength: 变形强度 (>0 放大, <0 缩小)
        """
        cx, cy = center
        
        # 截取感兴趣区域 (ROI) 以加速计算
        x_min = int(max(0, cx - radius))
        x_max = int(min(map_x.shape[1], cx + radius + 1))
        y_min = int(max(0, cy - radius))
        y_max = int(min(map_x.shape[0], cy + radius + 1))
        
        if x_min >= x_max or y_min >= y_max:
            return
            
        # 提取ROI坐标
        grid_x = map_x[y_min:y_max, x_min:x_max]
        grid_y = map_y[y_min:y_max, x_min:x_max]
        
        # 计算相对于中心的偏移
        dx = grid_x - cx
        dy = grid_y - cy
        d2 = dx*dx + dy*dy
        r2 = radius * radius
        
        # 蒙版：只处理半径内的点
        mask = d2 < r2
        
        # 变形公式
        # r = sqrt(d2)
        # scale = 1 - strength * (1 - r/radius)^2
        # map = center + (grid - center) * scale
        
        # 为了避免开方，使用 d2/r2
        # (1 - r/radius)^2 = (1 - sqrt(d2)/radius)^2
        # 这是一个钟形曲线，中心 strength 最大，边缘为 0
        
        # 归一化距离平方
        dist_ratio = d2 / r2
        # 简单的线性衰减可能不够平滑，使用 (1 - sqrt(dist))^2
        # 这里使用近似：scale = 1.0 - strength * (1.0 - dist_ratio) 
        # 中心变形最大
        
        # 更准确的大眼公式 (Interactive Image Warping):
        # f(r) = r * (1 - strength * (1 - r/R)^2)
        # 这里我们需要反向映射，放大效果对应 k < 1
        # 所以 strength > 0 时，scale < 1
        
        # 计算 scale
        # 只有 mask 为 True 的地方计算
        d = np.sqrt(d2[mask])
        r = radius
        
        # scale = 1.0 - strength * (1.0 - d/r)**2
        # 放大：取原图更靠里的点，所以偏移量减小，scale < 1
        scale = 1.0 - strength * (1.0 - d/r)**2
        
        # 应用变形
        # new_pos = center + offset * scale
        # offset = grid - center
        
        grid_x[mask] = cx + dx[mask] * scale
        grid_y[mask] = cy + dy[mask] * scale
        
    def _warp_local_translation(self, map_x, map_y, start_point, end_point, radius, strength):
        """
        局部平移变形 (瘦脸)
        start_point: 变形起始中心 (脸颊)
        end_point: 变形方向目标点 (鼻子)
        radius: 影响半径
        strength: 变形强度
        """
        cx, cy = start_point
        tx, ty = end_point
        
        # 变形向量：从 start 指向 end
        vx = tx - cx
        vy = ty - cy
        norm = math.sqrt(vx*vx + vy*vy)
        if norm == 0:
            return
            
        vx /= norm
        vy /= norm
        
        # 截取 ROI
        x_min = int(max(0, cx - radius))
        x_max = int(min(map_x.shape[1], cx + radius + 1))
        y_min = int(max(0, cy - radius))
        y_max = int(min(map_x.shape[0], cy + radius + 1))
        
        if x_min >= x_max or y_min >= y_max:
            return
            
        grid_x = map_x[y_min:y_max, x_min:x_max]
        grid_y = map_y[y_min:y_max, x_min:x_max]
        
        dx = grid_x - cx
        dy = grid_y - cy
        d2 = dx*dx + dy*dy
        r2 = radius * radius
        
        mask = d2 < r2
        
        # 瘦脸：屏幕上的点 (ROI内) 需要显示 原图更靠外(反方向) 的点
        # 向量 V 指向脸内 (start -> end)
        # 我们希望把脸变窄，即屏幕上靠近 V 方向的点，其实显示的是原图上还没到那里的点？
        # 不，瘦脸是把宽脸变窄。
        # 屏幕上的脸颊轮廓 X (窄)，应该显示 原图的脸颊轮廓 X' (宽)。
        # X' 比 X 更靠外 (逆着 V 方向)。
        # 所以 map = grid - strength * V * weight
        
        d = np.sqrt(d2[mask])
        r = radius
        
        # 权重函数：中心最大，边缘为0
        weight = (1.0 - d/r)**2
        
        # 变形量
        move_x = -vx * strength * 100 * weight  # 100是像素系数
        move_y = -vy * strength * 100 * weight
        
        grid_x[mask] += move_x
        grid_y[mask] += move_y

