ai-business-write/services/ai_service.py

363 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI服务 - 封装大模型调用
仅支持华为大模型
"""
import os
import re
import time
import requests
import json
from typing import Dict, List, Optional
class AIService:
"""AI服务类"""
def __init__(self):
# 华为大模型配置(必需)
self.huawei_api_endpoint = os.getenv('HUAWEI_API_ENDPOINT', 'http://10.100.31.26:3001/v1/chat/completions')
self.huawei_api_key = os.getenv('HUAWEI_API_KEY', 'sk-PoeiV3qwyTIRqcVc84E8E24cD2904872859a87922e0d9186')
self.huawei_model = os.getenv('HUAWEI_MODEL', 'DeepSeek-R1-Distill-Llama-70B')
# 确定使用的AI服务
self.ai_provider = self._determine_ai_provider()
def _determine_ai_provider(self) -> str:
"""确定使用的AI服务提供商仅支持华为大模型"""
if self.huawei_api_endpoint and self.huawei_api_key:
return 'huawei'
else:
return 'none'
def extract_fields(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
"""
从提示词中提取结构化字段
Args:
prompt: AI提示词
output_fields: 输出字段列表
Returns:
提取的字段字典,格式: {field_code: field_value}
"""
if self.ai_provider == 'none':
raise Exception("未配置华为大模型服务,请设置 HUAWEI_API_KEY 和 HUAWEI_API_ENDPOINT")
if self.ai_provider == 'huawei':
return self._extract_with_huawei(prompt, output_fields)
else:
raise Exception(f"未知的AI服务提供商: {self.ai_provider}")
def _extract_with_siliconflow(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
"""
使用硅基流动API提取字段已不再使用仅保留用于参考
系统仅支持华为大模型,不再支持自动回退
"""
try:
payload = {
"model": self.siliconflow_model,
"messages": [
{
"role": "system",
"content": "你是一个专业的数据提取助手能够从文本中准确提取结构化信息。请严格按照JSON格式返回结果。"
},
{
"role": "user",
"content": prompt
}
],
"temperature": 0.3,
"max_tokens": 2000
}
headers = {
"Authorization": f"Bearer {self.siliconflow_api_key}",
"Content-Type": "application/json"
}
response = requests.post(
self.siliconflow_url,
json=payload,
headers=headers,
timeout=30
)
if response.status_code != 200:
raise Exception(f"API调用失败: {response.status_code} - {response.text}")
result = response.json()
# 提取AI返回的内容
if 'choices' in result and len(result['choices']) > 0:
content = result['choices'][0]['message']['content']
# 尝试解析JSON
try:
# 如果返回的是代码块提取JSON部分
if '```json' in content:
json_start = content.find('```json') + 7
json_end = content.find('```', json_start)
content = content[json_start:json_end].strip()
elif '```' in content:
json_start = content.find('```') + 3
json_end = content.find('```', json_start)
content = content[json_start:json_end].strip()
extracted_data = json.loads(content)
return extracted_data
except json.JSONDecodeError:
# 如果不是JSON尝试从文本中提取
return self._parse_text_response(content, output_fields)
else:
raise Exception("API返回格式异常")
except requests.exceptions.Timeout:
raise Exception("AI服务调用超时")
except Exception as e:
raise Exception(f"AI服务调用失败: {str(e)}")
def _extract_with_huawei(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
"""
使用华为大模型API提取字段带重试机制
至少重试3次总共最多尝试4次
"""
max_retries = 3 # 最多重试3次总共4次尝试
retry_delay = 1 # 重试延迟(秒),每次重试延迟递增
last_exception = None
for attempt in range(max_retries + 1): # 0, 1, 2, 3 (总共4次)
try:
if attempt > 0:
# 重试前等待延迟时间递增1秒、2秒、3秒
wait_time = retry_delay * attempt
print(f"[AI服务] 第 {attempt} 次重试,等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
print(f"[AI服务] 正在调用华为大模型API (尝试 {attempt + 1}/{max_retries + 1})...")
result = self._call_huawei_api_once(prompt, output_fields)
if result is not None:
if attempt > 0:
print(f"[AI服务] 重试成功!")
return result
except requests.exceptions.Timeout as e:
last_exception = e
error_msg = f"AI服务调用超时 (尝试 {attempt + 1}/{max_retries + 1})"
print(f"[AI服务] {error_msg}")
if attempt < max_retries:
continue
else:
raise Exception(f"{error_msg}: {str(e)}")
except requests.exceptions.ConnectionError as e:
last_exception = e
error_msg = f"连接错误 (尝试 {attempt + 1}/{max_retries + 1})"
print(f"[AI服务] {error_msg}: {str(e)}")
if attempt < max_retries:
continue
else:
raise Exception(f"{error_msg}: {str(e)}")
except requests.exceptions.RequestException as e:
last_exception = e
error_msg = f"请求异常 (尝试 {attempt + 1}/{max_retries + 1})"
print(f"[AI服务] {error_msg}: {str(e)}")
if attempt < max_retries:
continue
else:
raise Exception(f"{error_msg}: {str(e)}")
except Exception as e:
last_exception = e
error_msg = f"AI服务调用失败 (尝试 {attempt + 1}/{max_retries + 1})"
print(f"[AI服务] {error_msg}: {str(e)}")
# 对于其他类型的错误,也进行重试
if attempt < max_retries:
continue
else:
raise Exception(f"{error_msg}: {str(e)}")
# 如果所有重试都失败了
if last_exception:
raise Exception(f"AI服务调用失败已重试 {max_retries} 次: {str(last_exception)}")
else:
raise Exception(f"AI服务调用失败已重试 {max_retries}")
def _call_huawei_api_once(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
"""
单次调用华为大模型API不包含重试逻辑
"""
payload = {
"model": self.huawei_model,
"messages": [
{
"role": "system",
"content": "你是一个专业的数据提取助手擅长从非结构化文本中准确提取结构化信息。你需要仔细分析输入文本识别所有相关信息并按照指定的JSON格式返回结果。请确保提取的信息准确、完整对于能够从文本中推断出的信息也要尽量提取。"
},
{
"role": "user",
"content": prompt
}
],
"stream": False,
"presence_penalty": 1.03,
"frequency_penalty": 1.0,
"repetition_penalty": 1.0,
"temperature": 0.3,
"top_p": 0.95,
"top_k": 1,
"seed": 1,
"max_tokens": 8192,
"n": 1,
"enable_thinking": True
}
headers = {
"Authorization": f"Bearer {self.huawei_api_key}",
"Content-Type": "application/json"
}
response = requests.post(
self.huawei_api_endpoint,
json=payload,
headers=headers,
timeout=60
)
if response.status_code != 200:
raise Exception(f"API调用失败: {response.status_code} - {response.text}")
result = response.json()
# 提取AI返回的内容
if 'choices' in result and len(result['choices']) > 0:
content = result['choices'][0]['message']['content']
# 处理思考过程标签(华为大模型可能返回思考过程)
# 移除思考过程标签之前的内容,只保留实际回答
# 根据用户提供的示例,华为大模型使用 </think> 标签
if '</think>' in content:
content = content.split('</think>')[-1].strip()
# 尝试解析JSON
extracted_data = self._extract_json_from_text(content)
if extracted_data:
return extracted_data
# 如果无法提取JSON尝试从文本中提取
return self._parse_text_response(content, output_fields)
else:
raise Exception("API返回格式异常未找到choices字段或choices为空")
def _extract_json_from_text(self, text: str) -> Optional[Dict]:
"""
从文本中提取JSON对象
支持多种格式:
1. 纯JSON对象
2. 包裹在 ```json 代码块中的JSON
3. 包裹在 ``` 代码块中的JSON
4. 文本中包含的JSON对象
"""
# 方法1: 尝试提取代码块中的JSON
if '```json' in text:
json_start = text.find('```json') + 7
json_end = text.find('```', json_start)
if json_end != -1:
json_str = text[json_start:json_end].strip()
try:
return json.loads(json_str)
except json.JSONDecodeError:
pass
if '```' in text:
json_start = text.find('```') + 3
json_end = text.find('```', json_start)
if json_end != -1:
json_str = text[json_start:json_end].strip()
# 如果不是json标记尝试解析
try:
return json.loads(json_str)
except json.JSONDecodeError:
pass
# 方法2: 尝试直接解析整个文本
try:
return json.loads(text.strip())
except json.JSONDecodeError:
pass
# 方法3: 尝试查找文本中的JSON对象以 { 开始,以 } 结束)
# 使用正则表达式找到最外层的JSON对象
json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
matches = re.finditer(json_pattern, text, re.DOTALL)
for match in matches:
json_str = match.group(0)
try:
data = json.loads(json_str)
# 验证是否包含预期的字段至少有一个输出字段的key
if isinstance(data, dict) and len(data) > 0:
return data
except json.JSONDecodeError:
continue
# 方法4: 尝试查找嵌套的JSON对象更复杂的匹配
# 找到第一个 { 和最后一个匹配的 }
start_idx = text.find('{')
if start_idx != -1:
brace_count = 0
end_idx = start_idx
for i in range(start_idx, len(text)):
if text[i] == '{':
brace_count += 1
elif text[i] == '}':
brace_count -= 1
if brace_count == 0:
end_idx = i
break
if end_idx > start_idx:
json_str = text[start_idx:end_idx + 1]
try:
return json.loads(json_str)
except json.JSONDecodeError:
pass
return None
def _parse_text_response(self, text: str, output_fields: List[Dict]) -> Dict:
"""
从文本响应中解析字段值(备用方案)
"""
result = {}
for field in output_fields:
field_code = field['field_code']
field_name = field['name']
# 尝试在文本中查找字段值
# 这里使用简单的关键词匹配,实际可以更复杂
if field_name in text:
# 提取字段值(简单实现)
start_idx = text.find(field_name)
if start_idx != -1:
# 查找冒号后的内容
colon_idx = text.find(':', start_idx)
if colon_idx != -1:
value_start = colon_idx + 1
value_end = text.find('\n', value_start)
if value_end == -1:
value_end = len(text)
value = text[value_start:value_end].strip()
result[field_code] = value
else:
result[field_code] = ''
else:
result[field_code] = ''
else:
result[field_code] = ''
return result