ai-business-write/services/ai_service.py

296 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI服务 - 封装大模型调用
仅支持华为大模型
"""
import os
import re
import requests
import json
from typing import Dict, List, Optional
class AIService:
    """Client that asks a chat-completions LLM endpoint to extract fields.

    Only the Huawei-hosted model is wired up. The public entry point is
    :meth:`extract_fields`; the remaining methods are internal helpers that
    perform the HTTP call and turn the model's free-form answer into a dict.
    """

    # ASCII colon plus the full-width colon (U+FF1A) that is common in
    # Chinese text; both are accepted as "name: value" separators.
    _COLON_CHARS = ':\uff1a'

    def __init__(self):
        """Read the Huawei model settings from the environment."""
        # Huawei LLM configuration (required). Each value can be overridden
        # through the environment variable of the same name.
        # SECURITY NOTE(review): an internal endpoint and what looks like a
        # real API key are hard-coded as fallbacks. They are kept only so
        # existing deployments keep working; the key should be rotated and
        # supplied exclusively through the environment / a secret store.
        self.huawei_api_endpoint = os.getenv('HUAWEI_API_ENDPOINT', 'http://10.100.31.26:3001/v1/chat/completions')
        self.huawei_api_key = os.getenv('HUAWEI_API_KEY', 'sk-PoeiV3qwyTIRqcVc84E8E24cD2904872859a87922e0d9186')
        self.huawei_model = os.getenv('HUAWEI_MODEL', 'DeepSeek-R1-Distill-Llama-70B')
        # Decide once which provider is usable with the settings above.
        self.ai_provider = self._determine_ai_provider()

    def _determine_ai_provider(self) -> str:
        """Return ``'huawei'`` when endpoint and key are both set, else ``'none'``."""
        if self.huawei_api_endpoint and self.huawei_api_key:
            return 'huawei'
        return 'none'

    def extract_fields(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Extract structured fields by sending ``prompt`` to the model.

        Args:
            prompt: The prompt text sent as the user message.
            output_fields: Field descriptors; each is read for the keys
                ``'field_code'`` and ``'name'`` when the text fallback runs.

        Returns:
            The extracted data, in the form ``{field_code: field_value}``.

        Raises:
            Exception: If no provider is configured, the provider is
                unknown, or the underlying API call fails.
        """
        if self.ai_provider == 'none':
            raise Exception("未配置华为大模型服务,请设置 HUAWEI_API_KEY 和 HUAWEI_API_ENDPOINT")
        if self.ai_provider == 'huawei':
            return self._extract_with_huawei(prompt, output_fields)
        raise Exception(f"未知的AI服务提供商: {self.ai_provider}")

    def _request_completion(self, url: str, api_key: str, payload: Dict, timeout: int) -> str:
        """POST ``payload`` to ``url`` and return the first choice's content.

        Args:
            url: Chat-completions endpoint.
            api_key: Token sent as a ``Bearer`` authorization header.
            payload: JSON body of the request.
            timeout: Timeout passed to ``requests.post``.

        Returns:
            ``result['choices'][0]['message']['content']`` of the JSON reply.

        Raises:
            Exception: On a non-200 status code or when the reply carries
                no (or an empty) ``choices`` entry. Exceptions raised by
                ``requests`` itself propagate unchanged.
        """
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        response = requests.post(url, json=payload, headers=headers, timeout=timeout)
        if response.status_code != 200:
            raise Exception(f"API调用失败: {response.status_code} - {response.text}")
        result = response.json()
        if 'choices' not in result or not result['choices']:
            raise Exception("API返回格式异常")
        return result['choices'][0]['message']['content']

    def _extract_with_siliconflow(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Extract fields through the SiliconFlow API (legacy, kept for reference).

        The service only supports the Huawei model and no longer falls back
        to this provider automatically.

        NOTE(review): this method reads ``self.siliconflow_model``,
        ``self.siliconflow_api_key`` and ``self.siliconflow_url``, none of
        which ``__init__`` sets. Unless a caller assigns them first, the
        resulting ``AttributeError`` is reported through the generic
        failure exception below.

        Args:
            prompt: The prompt text sent as the user message.
            output_fields: Field descriptors used by the text fallback.

        Returns:
            The parsed JSON value, or the result of the text fallback when
            no JSON could be extracted.

        Raises:
            Exception: On timeout or any other failure of the call.
        """
        try:
            payload = {
                "model": self.siliconflow_model,
                "messages": [
                    {
                        "role": "system",
                        "content": "你是一个专业的数据提取助手能够从文本中准确提取结构化信息。请严格按照JSON格式返回结果。"
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                "temperature": 0.3,
                "max_tokens": 2000
            }
            content = self._request_completion(
                self.siliconflow_url,
                self.siliconflow_api_key,
                payload,
                timeout=30
            )
            # Use the shared extractor so fenced blocks (including ones with
            # a missing closing fence) are handled the same way everywhere.
            extracted_data = self._extract_json_from_text(content)
            if extracted_data is not None:
                return extracted_data
            # Not JSON: fall back to "name: value" matching on the raw text.
            return self._parse_text_response(content, output_fields)
        except requests.exceptions.Timeout as e:
            raise Exception("AI服务调用超时") from e
        except Exception as e:
            raise Exception(f"AI服务调用失败: {str(e)}") from e

    def _extract_with_huawei(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Extract fields through the Huawei model API.

        Args:
            prompt: The prompt text sent as the user message.
            output_fields: Field descriptors used by the text fallback.

        Returns:
            The JSON object found in the model's answer or, when none is
            found (or it is empty), the result of the text fallback.

        Raises:
            Exception: On timeout or any other failure of the call.
        """
        try:
            payload = {
                "model": self.huawei_model,
                "messages": [
                    {
                        "role": "system",
                        "content": "你是一个专业的数据提取助手能够从文本中准确提取结构化信息。请严格按照JSON格式返回结果只返回JSON对象不要包含任何其他文字说明、思考过程或markdown代码块标记。"
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                "stream": False,
                "presence_penalty": 1.03,
                "frequency_penalty": 1.0,
                "repetition_penalty": 1.0,
                "temperature": 0.5,
                "top_p": 0.95,
                "top_k": 1,
                "seed": 1,
                "max_tokens": 8192,
                "n": 1
            }
            content = self._request_completion(
                self.huawei_api_endpoint,
                self.huawei_api_key,
                payload,
                timeout=60
            )
            # The model may return its reasoning before the answer, closed
            # by a </think> tag; keep only what follows the last such tag.
            if '</think>' in content:
                content = content.split('</think>')[-1].strip()
            extracted_data = self._extract_json_from_text(content)
            if extracted_data:
                return extracted_data
            # No usable JSON: fall back to "name: value" matching.
            return self._parse_text_response(content, output_fields)
        except requests.exceptions.Timeout as e:
            raise Exception("AI服务调用超时") from e
        except Exception as e:
            raise Exception(f"AI服务调用失败: {str(e)}") from e

    def _extract_json_from_text(self, text: str) -> Optional[Dict]:
        """Extract a JSON object from ``text``.

        Strategies, tried in order:
            1. JSON wrapped in a ```json fenced block.
            2. JSON wrapped in a plain ``` fenced block.
            3. The whole text as JSON.
            4. The first non-empty JSON object embedded anywhere in the text.

        Strategies 1-3 return whatever ``json.loads`` yields, so a non-object
        JSON value can be returned if that is all the text contains.

        Args:
            text: The raw model answer.

        Returns:
            The parsed value, or ``None`` when nothing could be parsed. If
            the only embedded objects are empty, ``{}`` is returned.
        """
        # Strategies 1 and 2: content between a fence and the next ```.
        for fence in ('```json', '```'):
            fence_idx = text.find(fence)
            if fence_idx == -1:
                continue
            body_start = fence_idx + len(fence)
            body_end = text.find('```', body_start)
            if body_end == -1:
                continue
            try:
                return json.loads(text[body_start:body_end].strip())
            except json.JSONDecodeError:
                pass

        # Strategy 3: the entire text is JSON.
        try:
            return json.loads(text.strip())
        except json.JSONDecodeError:
            pass

        # Strategy 4: try to decode an object at every '{', left to right.
        # The decoder understands arbitrary nesting and braces inside
        # strings, and scanning from the left yields the outermost object
        # rather than one of its nested children.
        decoder = json.JSONDecoder()
        empty_fallback = None
        pos = text.find('{')
        while pos != -1:
            try:
                candidate, consumed = decoder.raw_decode(text[pos:])
            except json.JSONDecodeError:
                pos = text.find('{', pos + 1)
                continue
            if candidate:
                return candidate
            # An empty object: remember it, but prefer a later non-empty one.
            if empty_fallback is None:
                empty_fallback = candidate
            pos = text.find('{', pos + consumed)
        return empty_fallback

    def _parse_text_response(self, text: str, output_fields: List[Dict]) -> Dict:
        """Parse field values out of a plain-text answer (fallback path).

        For every field the text is searched for the field's ``'name'``
        followed, on the same line, by an ASCII or full-width colon; the
        rest of that line (stripped) becomes the value. If the first
        occurrence of the name has no colon on its line, later occurrences
        are tried.

        Args:
            text: The raw model answer.
            output_fields: Field descriptors with ``'field_code'`` and
                ``'name'`` keys.

        Returns:
            ``{field_code: value}`` with ``''`` for fields not found.
        """
        colons = self._COLON_CHARS
        # "<anything but newline/colon> <colon> <rest of the line>"
        value_tail = '[^\\n' + colons + ']*[' + colons + ']([^\\n]*)'
        result = {}
        for field in output_fields:
            field_code = field['field_code']
            field_name = field['name']
            found = re.search(re.escape(field_name) + value_tail, text)
            result[field_code] = found.group(1).strip() if found else ''
        return result