ai-business-write/services/ai_service.py
2025-12-04 14:41:20 +08:00

159 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI服务 - 封装大模型调用
支持硅基流动和华为大模型(预留)
"""
import os
import requests
import json
from typing import Dict, List, Optional
class AIService:
    """Wrapper around large-language-model calls used for structured field extraction.

    Supports SiliconFlow today; a Huawei model backend is reserved but not implemented.
    """

    def __init__(self):
        """Read provider configuration from environment variables.

        SILICONFLOW_API_KEY / SILICONFLOW_MODEL configure SiliconFlow
        (model defaults to 'deepseek-ai/DeepSeek-V3.2-Exp');
        HUAWEI_API_ENDPOINT / HUAWEI_API_KEY are reserved for the Huawei backend.
        """
        self.siliconflow_api_key = os.getenv('SILICONFLOW_API_KEY')
        self.siliconflow_model = os.getenv('SILICONFLOW_MODEL', 'deepseek-ai/DeepSeek-V3.2-Exp')
        self.siliconflow_url = "https://api.siliconflow.cn/v1/chat/completions"
        # Huawei model configuration (reserved for a future implementation).
        self.huawei_api_endpoint = os.getenv('HUAWEI_API_ENDPOINT')
        self.huawei_api_key = os.getenv('HUAWEI_API_KEY')
        # Decide once, at construction time, which backend to use.
        self.ai_provider = self._determine_ai_provider()

    def _determine_ai_provider(self) -> str:
        """Return 'siliconflow', 'huawei' or 'none' based on the configured credentials.

        SiliconFlow takes precedence when its API key is set; Huawei requires
        both an endpoint and a key.
        """
        if self.siliconflow_api_key:
            return 'siliconflow'
        if self.huawei_api_endpoint and self.huawei_api_key:
            return 'huawei'
        return 'none'

    def extract_fields(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Extract structured fields from *prompt* using the configured provider.

        Args:
            prompt: The prompt sent to the model.
            output_fields: Field descriptors; each must have 'field_code' and 'name'
                (used by the plain-text fallback parser).

        Returns:
            A dict of extracted fields, normally shaped {field_code: field_value}.

        Raises:
            Exception: when no provider is configured, the provider is unknown,
                the Huawei backend is selected (not implemented), or the call fails.
        """
        if self.ai_provider == 'none':
            raise Exception("未配置AI服务请设置SILICONFLOW_API_KEY或华为大模型配置")
        if self.ai_provider == 'siliconflow':
            return self._extract_with_siliconflow(prompt, output_fields)
        if self.ai_provider == 'huawei':
            return self._extract_with_huawei(prompt, output_fields)
        raise Exception(f"未知的AI服务提供商: {self.ai_provider}")

    def _extract_with_siliconflow(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Call the SiliconFlow chat-completions API and parse the model's reply.

        Raises:
            Exception: "AI服务调用超时" on a request timeout (30 s); any other
                failure (non-200 status, malformed payload, network error) is
                re-raised with the prefix "AI服务调用失败: ", chained to its cause.
        """
        try:
            payload = {
                "model": self.siliconflow_model,
                "messages": [
                    {
                        "role": "system",
                        "content": "你是一个专业的数据提取助手能够从文本中准确提取结构化信息。请严格按照JSON格式返回结果。"
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                # Low temperature: extraction should be as deterministic as possible.
                "temperature": 0.3,
                "max_tokens": 2000
            }
            headers = {
                "Authorization": f"Bearer {self.siliconflow_api_key}",
                "Content-Type": "application/json"
            }
            response = requests.post(
                self.siliconflow_url,
                json=payload,
                headers=headers,
                timeout=30
            )
            if response.status_code != 200:
                raise Exception(f"API调用失败: {response.status_code} - {response.text}")
            content = self._get_message_content(response.json())
            return self._parse_model_content(content, output_fields)
        except requests.exceptions.Timeout as e:
            raise Exception("AI服务调用超时") from e
        except Exception as e:
            raise Exception(f"AI服务调用失败: {str(e)}") from e

    @staticmethod
    def _get_message_content(result) -> str:
        """Return choices[0].message.content from a chat-completions response body.

        Raises:
            Exception: "API返回格式异常" when the body has no choices or the
                first choice carries no string content (e.g. content is null).
        """
        choices = result.get('choices') if isinstance(result, dict) else None
        if not choices:
            raise Exception("API返回格式异常")
        first = choices[0]
        message = first.get('message') if isinstance(first, dict) else None
        content = message.get('content') if isinstance(message, dict) else None
        if not isinstance(content, str):
            raise Exception("API返回格式异常")
        return content

    @staticmethod
    def _strip_code_fence(content: str) -> str:
        """Return the body of the first Markdown code fence in *content*.

        A ```json fence (any letter case) is preferred; otherwise the first ```
        fence is used, skipping an optional language tag on its opening line.
        A missing closing fence (e.g. truncated output) takes everything to the
        end instead of chopping the last character. Text without fences is
        returned unchanged.
        """
        marker = content.lower().find('```json')
        if marker != -1:
            body_start = marker + len('```json')
        else:
            marker = content.find('```')
            if marker == -1:
                return content
            body_start = marker + 3
            line_end = content.find('\n', body_start)
            # Drop a language tag such as "JSON" or "javascript" on the fence line.
            if line_end != -1 and content[body_start:line_end].strip().isidentifier():
                body_start = line_end + 1
        body_end = content.find('```', body_start)
        if body_end == -1:
            body_end = len(content)
        return content[body_start:body_end].strip()

    def _parse_model_content(self, content: str, output_fields: List[Dict]) -> Dict:
        """Parse the model's reply: JSON object first, plain-text fallback otherwise.

        JSON that is not an object (list, string, number) is treated as
        unusable, because callers expect a {field_code: value} mapping.
        """
        candidate = self._strip_code_fence(content)
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed
        return self._parse_text_response(candidate, output_fields)

    def _extract_with_huawei(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Extract fields via the Huawei model API (reserved; always raises)."""
        # TODO: implement the Huawei model API call.
        raise Exception("华为大模型接口暂未实现请使用硅基流动API")

    def _parse_text_response(self, text: str, output_fields: List[Dict]) -> Dict:
        """Fallback parser: read "label: value" lines out of free-form text.

        Every field in *output_fields* gets an entry keyed by its 'field_code';
        fields whose 'name' cannot be found followed by a colon on the same line
        map to ''.
        """
        return {
            field['field_code']: self._find_labeled_value(text, field['name'])
            for field in output_fields
        }

    @staticmethod
    def _find_labeled_value(text: str, label: str) -> str:
        """Return the text after the first colon following *label* on its own line.

        Both ASCII ':' and full-width '：' are accepted. Only the line that
        contains the label is searched, so a label without a value cannot pick
        up the value of a later line. Returns '' when nothing matches.
        """
        for line in text.splitlines():
            label_idx = line.find(label)
            if label_idx == -1:
                continue
            after_label = line[label_idx + len(label):]
            colon_positions = [
                pos for pos in (after_label.find(':'), after_label.find('：')) if pos != -1
            ]
            if colon_positions:
                return after_label[min(colon_positions) + 1:].strip()
        return ''