# -*- coding: utf-8 -*-
import json
import uuid
import requests
import base64
import os
import time
import asyncio
from .base import BaseTTS
from src.utils.logger import logger

class VolcEngineTTS(BaseTTS):
    """TTS backend that calls the VolcEngine (openspeech.bytedance.com) HTTP TTS API.

    Each ``synthesize`` call posts the text to ``api_url``, base64-decodes
    the ``data`` field of the JSON response and saves it as an ``.mp3`` file
    under ``output_dir``.
    """

    def __init__(self, app_id, access_token, voice_type="BV001_streaming", cluster="volcano_tts", output_dir="./output", timeout=5):
        """Store API credentials/settings and make sure ``output_dir`` exists.

        Args:
            app_id: Application id, sent as ``app.appid`` in the request body.
            access_token: Token, sent as ``app.token`` and in the
                ``Authorization`` header.
            voice_type: Voice identifier, sent as ``audio.voice_type``.
            cluster: Cluster name, sent as ``app.cluster``.
            output_dir: Directory where synthesized MP3 files are written;
                created (including parents) if it does not exist.
            timeout: Seconds passed as ``timeout`` to ``requests.post``.
        """
        super().__init__()
        self.app_id = app_id
        self.access_token = access_token
        self.voice_type = voice_type
        self.cluster = cluster
        self.output_dir = output_dir
        self.timeout = timeout

        # exist_ok avoids the check-then-create race when several instances
        # or processes start with the same output directory.
        os.makedirs(output_dir, exist_ok=True)

        self.api_url = "https://openspeech.bytedance.com/api/v1/tts"

    def _build_request(self, text, reqid):
        """Return the JSON request body for synthesizing *text* with id *reqid*."""
        return {
            "app": {
                "appid": self.app_id,
                "token": self.access_token,
                "cluster": self.cluster
            },
            "user": {
                "uid": "douyin_assistant"
            },
            "audio": {
                "voice_type": self.voice_type,
                "encoding": "mp3",
                "speed_ratio": 1.0,
                "volume_ratio": 1.0,
                "pitch_ratio": 1.0,
            },
            "request": {
                "reqid": reqid,
                "text": text,
                "text_type": "plain",
                "operation": "query",
                "with_frontend": 1,
                "frontend_type": "unitTson"
            }
        }

    @staticmethod
    def _write_bytes(path, data):
        """Write *data* to *path* in binary mode (meant to run in a worker thread)."""
        with open(path, "wb") as f:
            f.write(data)

    async def synthesize(self, text: str) -> str:
        """Synthesize *text* to an MP3 file and return its absolute path.

        Returns ``""`` when *text* is empty, and also on any failure. Failures
        include a non-200 HTTP status, a ``code`` field other than 3000, a
        missing or empty ``data`` field, or any raised exception. Every failure
        except empty ``data`` is logged.
        """
        if not text:
            return ""

        try:
            reqid = str(uuid.uuid4())
            req = self._build_request(text, reqid)

            headers = {
                # NOTE(review): the "Bearer;" prefix (semicolon) is kept from the
                # original; presumably the format this API expects — confirm.
                "Authorization": f"Bearer; {self.access_token}",
                "Content-Type": "application/json"
            }

            # requests is blocking; run it in a worker thread so the event
            # loop is not stalled while waiting for the HTTP response.
            response = await asyncio.to_thread(
                requests.post,
                self.api_url,
                json=req,
                headers=headers,
                timeout=self.timeout
            )

            if response.status_code != 200:
                logger.error(f"VolcEngine API Error: {response.status_code} - {response.text}")
                return ""

            resp_json = response.json()

            # Any explicit code other than 3000 is treated as a failure.
            if "code" in resp_json and resp_json["code"] != 3000:
                logger.error(f"VolcEngine Logic Error: {resp_json}")
                return ""

            if "data" not in resp_json:
                logger.error(f"VolcEngine No Data: {resp_json}")
                return ""

            data = resp_json["data"]
            if not data:
                return ""

            # The audio payload arrives base64-encoded.
            audio_data = base64.b64decode(data)

            # The millisecond timestamp keeps names roughly chronological. The
            # reqid suffix keeps them unique when concurrent calls finish within
            # the same millisecond; without it they would overwrite each other.
            filename = f"volc_{int(time.time() * 1000)}_{reqid[:8]}.mp3"
            abs_output_dir = os.path.abspath(self.output_dir)
            filepath = os.path.join(abs_output_dir, filename)

            # File I/O is also blocking; keep it off the event loop.
            await asyncio.to_thread(self._write_bytes, filepath, audio_data)

            return filepath

        except Exception as e:
            # Deliberate best-effort boundary: callers get "" instead of an exception.
            logger.error(f"VolcTTS Synthesis Error: {e}")
            return ""
