488 lines
20 KiB
Python
488 lines
20 KiB
Python
"""
|
||
大模型 API 客户端
|
||
"""
|
||
import json
|
||
import hashlib
|
||
import asyncio
|
||
from typing import Optional, Dict, Any, List
|
||
import httpx
|
||
from app.core.config import settings
|
||
from app.core.exceptions import LLMAPIException
|
||
from app.utils.logger import logger
|
||
from app.utils.cache import get_cache_manager
|
||
|
||
|
||
class LLMClient:
    """Async client for several chat-completion LLM APIs.

    Providers are selected by model name in :meth:`call`:

    * DashScope / Tongyi Qianwen  -- lowercase ``qwen-*`` ids
    * OpenAI                      -- ``gpt-*`` / ``openai*`` ids
    * SiliconFlow                 -- ``siliconflow*`` / ``deepseek-*`` ids
    * Vision models               -- capitalised ``Qwen/Qwen3-VL-*`` style ids,
      served through the SiliconFlow-compatible endpoint

    Successful responses are cached through the shared cache manager, keyed
    on (model, temperature, MD5 prefix of the prompt), with a fixed TTL.
    Failed HTTP calls are retried with exponential backoff.
    """

    # TTL (seconds) for cached LLM responses.
    CACHE_TTL = 3600

    def __init__(self, model: Optional[str] = None):
        """Initialize the client.

        Args:
            model: Default model name; falls back to ``settings.DEFAULT_LLM_MODEL``.
        """
        self.model = model or settings.DEFAULT_LLM_MODEL
        self.timeout = settings.LLM_TIMEOUT
        self.max_retries = settings.LLM_MAX_RETRIES
        self.cache_manager = get_cache_manager()

    async def call(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        model: Optional[str] = None,
        use_cache: bool = True,
        **kwargs
    ) -> str:
        """Call an LLM API, routing on the model name.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt.
            temperature: Sampling temperature (``None`` -> settings default).
            model: Model name (``None`` -> client default).
            use_cache: Whether to read/write the response cache.
            **kwargs: Extra provider-specific request parameters.

        Returns:
            The text content returned by the model.

        Raises:
            LLMAPIException: If the model is unsupported or the API call fails.
        """
        model = model or self.model
        # BUG FIX: the original used `temperature or DEFAULT`, which silently
        # replaced an explicit 0.0 with the default; `is None` keeps 0.0 usable.
        if temperature is None:
            temperature = settings.DEFAULT_TEMPERATURE

        # DashScope / Tongyi Qianwen: lowercase "qwen*" model ids.
        if model.startswith("qwen") and "siliconflow" not in model.lower():
            return await self._call_qwen(prompt, system_prompt, temperature, model, use_cache, **kwargs)
        # OpenAI.
        if model.startswith(("gpt", "openai")):
            return await self._call_openai(prompt, system_prompt, temperature, model, use_cache, **kwargs)
        # SiliconFlow (deepseek, qwen, ...).  NOTE(review): the "qwen-*"
        # entries below are unreachable -- the first branch already matches
        # them; kept only for parity with the original routing table.
        if model.startswith(("siliconflow", "deepseek")) or \
                model in ("deepseek-chat", "deepseek-coder", "qwen-turbo", "qwen-plus", "qwen-max"):
            return await self._call_siliconflow(prompt, system_prompt, temperature, model, use_cache, **kwargs)
        # Vision models (capitalised "Qwen..." ids, e.g. Qwen/Qwen3-VL-*).
        # The original's extra `startswith("Qwen3")` test was redundant.
        if model.startswith("Qwen"):
            return await self._call_vision_model(prompt, system_prompt, temperature, model, use_cache, **kwargs)
        raise LLMAPIException(
            f"不支持的大模型: {model}。支持的模型: qwen-* (通义千问), gpt-* (OpenAI), "
            f"deepseek-* (硅基流动), Qwen/Qwen3-VL (视觉模型), 或配置 SILICONFLOW_API_KEY 使用硅基流动平台"
        )

    # ------------------------------------------------------------------
    # Shared machinery (deduplicated from four near-identical methods)
    # ------------------------------------------------------------------

    @staticmethod
    def _build_messages(prompt: str, system_prompt: Optional[str]) -> List[Dict[str, str]]:
        """Assemble the chat message list: optional system turn, then user turn."""
        messages: List[Dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    def _cache_key_parts(self, model: str, temperature: float, prompt: str):
        """Return (cache-arg tuple, printable key) for one request.

        The prompt is folded into a short MD5 prefix so cache keys stay
        bounded regardless of prompt length (MD5 is acceptable here: the
        digest is a cache key, not a security boundary).
        """
        prompt_hash = hashlib.md5(prompt.encode()).hexdigest()[:16]
        parts = ("llm", model, str(temperature), prompt_hash)
        # NOTE(review): uses the cache manager's private _generate_key purely
        # to produce a loggable key string, exactly as the original did.
        return parts, self.cache_manager._generate_key(*parts)

    async def _chat(
        self,
        label: str,
        url: str,
        headers: Dict[str, str],
        payload: Dict[str, Any],
        extract,
        model: str,
        temperature: float,
        prompt: str,
        use_cache: bool,
    ) -> str:
        """Run one cached, retried chat-completion POST request.

        Args:
            label: Provider name used in log and error messages.
            url: Endpoint URL.
            headers: HTTP headers (including authorization).
            payload: JSON request body.
            extract: Callable mapping the decoded JSON response to the answer
                text (each provider has its own response envelope).
            model: Model name (part of the cache key).
            temperature: Sampling temperature (part of the cache key).
            prompt: User prompt (hashed into the cache key).
            use_cache: Whether to consult and populate the cache.

        Returns:
            The model's text answer.

        Raises:
            LLMAPIException: After all retries are exhausted.
        """
        parts = key = None
        if use_cache:
            parts, key = self._cache_key_parts(model, temperature, prompt)
            cached = await self.cache_manager.get(*parts)
            if cached:
                logger.info(f"{label} 响应缓存命中: {key}")
                return cached
            logger.debug(f"{label} 响应缓存未命中: {key}")

        last_error: Optional[Exception] = None
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            for attempt in range(self.max_retries):
                try:
                    response = await client.post(url, headers=headers, json=payload)
                    response.raise_for_status()
                    content = extract(response.json())
                    logger.info(f"{label} API 调用成功 (attempt {attempt + 1})")
                    if use_cache:
                        # Best effort: a failed cache write is logged, never fatal.
                        if await self.cache_manager.set(*parts, data=content, ttl=self.CACHE_TTL):
                            logger.info(f"{label} 响应已缓存: {key}")
                        else:
                            logger.warning(f"{label} 响应缓存设置失败")
                    return content
                except httpx.HTTPError as e:
                    # httpx.HTTPStatusError subclasses HTTPError, so a single
                    # handler covers 4xx/5xx and transport errors alike.  The
                    # original retried transport errors on some providers but
                    # raised immediately on others; unified here so every
                    # provider gets the same exponential backoff.
                    last_error = e
                    if attempt == self.max_retries - 1:
                        break
                    wait_time = 2 ** attempt  # exponential backoff: 1s, 2s, 4s, ...
                    logger.warning(f"API 调用失败,{wait_time}秒后重试 (attempt {attempt + 1})")
                    await asyncio.sleep(wait_time)

        logger.error(f"{label} API 调用失败: {str(last_error)}")
        raise LLMAPIException(
            f"{label} API 调用失败: {str(last_error)}",
            error_detail=str(last_error),
            retryable=True,
        )

    # ------------------------------------------------------------------
    # Provider-specific wrappers (signatures unchanged from the original)
    # ------------------------------------------------------------------

    async def _call_qwen(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        model: str = "qwen-max",
        use_cache: bool = True,
        **kwargs
    ) -> str:
        """Call the DashScope (Tongyi Qianwen) API."""
        if not settings.DASHSCOPE_API_KEY:
            raise LLMAPIException("未配置 DASHSCOPE_API_KEY")
        messages = self._build_messages(prompt, system_prompt)
        # DashScope uses its own envelope: messages under "input",
        # sampling options under "parameters".
        payload = {
            "model": model,
            "input": {"messages": messages},
            "parameters": {
                "temperature": temperature,
                "result_format": "message",
                **kwargs,
            },
        }
        headers = {
            "Authorization": f"Bearer {settings.DASHSCOPE_API_KEY}",
            "Content-Type": "application/json",
        }
        logger.debug(f"通义千问 API 请求 - 模型: {model}, 消息数量: {len(messages)}")
        return await self._chat(
            "通义千问",
            settings.DASHSCOPE_BASE_URL,
            headers,
            payload,
            lambda r: r["output"]["choices"][0]["message"]["content"],
            model, temperature, prompt, use_cache,
        )

    async def _call_openai(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        model: str = "gpt-4",
        use_cache: bool = True,
        **kwargs
    ) -> str:
        """Call the OpenAI chat-completions API."""
        if not settings.OPENAI_API_KEY:
            raise LLMAPIException("未配置 OPENAI_API_KEY")
        messages = self._build_messages(prompt, system_prompt)
        payload = {"model": model, "messages": messages, "temperature": temperature, **kwargs}
        headers = {
            "Authorization": f"Bearer {settings.OPENAI_API_KEY}",
            "Content-Type": "application/json",
        }
        logger.debug(f"OpenAI API 请求 - 模型: {model}, 消息数量: {len(messages)}")
        return await self._chat(
            "OpenAI",
            settings.OPENAI_BASE_URL,
            headers,
            payload,
            lambda r: r["choices"][0]["message"]["content"],
            model, temperature, prompt, use_cache,
        )

    async def _call_siliconflow(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        model: str = "deepseek-chat",
        use_cache: bool = True,
        **kwargs
    ) -> str:
        """Call the SiliconFlow API (response format is OpenAI-compatible)."""
        if not settings.SILICONFLOW_API_KEY:
            raise LLMAPIException("未配置 SILICONFLOW_API_KEY")
        messages = self._build_messages(prompt, system_prompt)
        payload = {"model": model, "messages": messages, "temperature": temperature, **kwargs}
        headers = {
            "Authorization": f"Bearer {settings.SILICONFLOW_API_KEY}",
            "Content-Type": "application/json",
        }
        logger.debug(f"硅基流动 API 请求 - 模型: {model}, 消息数量: {len(messages)}")
        return await self._chat(
            "硅基流动",
            settings.SILICONFLOW_BASE_URL,
            headers,
            payload,
            lambda r: r["choices"][0]["message"]["content"],
            model, temperature, prompt, use_cache,
        )

    async def _call_vision_model(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        model: str = "Qwen/Qwen3-VL-32B-Instruct",
        use_cache: bool = True,
        **kwargs
    ) -> str:
        """Call a vision LLM (Qwen3-VL) through the SiliconFlow-style endpoint.

        NOTE(review): as in the original, this validates VISION_MODEL is
        configured but authorizes with SILICONFLOW_API_KEY -- confirm that
        an unset SILICONFLOW_API_KEY should not be rejected here too.
        """
        if not settings.VISION_MODEL:
            raise LLMAPIException("未配置 VISION_MODEL")
        messages = self._build_messages(prompt, system_prompt)
        payload = {"model": model, "messages": messages, "temperature": temperature, **kwargs}
        headers = {
            "Authorization": f"Bearer {settings.SILICONFLOW_API_KEY}",
            "Content-Type": "application/json",
        }
        logger.debug(f"视觉大模型 API 请求 - 模型: {model}, 消息数量: {len(messages)}")
        return await self._chat(
            "视觉大模型",
            settings.VISION_MODEL_BASE_URL,
            headers,
            payload,
            lambda r: r["choices"][0]["message"]["content"],
            model, temperature, prompt, use_cache,
        )

    def parse_json_response(self, response_text: str) -> Dict[str, Any]:
        """Parse a JSON object out of an LLM response.

        Strips a surrounding Markdown code fence (```json ... ``` or
        ``` ... ```) if present, then decodes the remainder as JSON.

        Args:
            response_text: Raw text returned by the model.

        Returns:
            The decoded JSON object.

        Raises:
            LLMAPIException: If the text is not valid JSON.
        """
        try:
            text = response_text.strip()
            if "```json" in text:
                json_text = text.split("```json")[1].split("```")[0].strip()
            elif "```" in text:
                json_text = text.split("```")[1].split("```")[0].strip()
            else:
                json_text = text
            return json.loads(json_text)
        except json.JSONDecodeError as e:
            logger.error(f"JSON 解析失败: {str(e)}")
            # Log a bounded prefix so huge responses don't flood the log.
            logger.error(f"原始响应: {response_text[:500]}")
            raise LLMAPIException(f"大模型返回的 JSON 格式错误: {str(e)}")
|
||
|
||
|
||
# Module-level shared LLM client instance (global singleton).
llm_client = LLMClient()
|