296 lines
11 KiB
Python
296 lines
11 KiB
Python
"""
|
||
AI服务 - 封装大模型调用
|
||
仅支持华为大模型
|
||
"""
|
||
import os
|
||
import re
|
||
import requests
|
||
import json
|
||
from typing import Dict, List, Optional
|
||
|
||
|
||
class AIService:
    """Client for the chat-completions API used to extract structured fields.

    Only the Huawei-hosted model is supported.  The configuration is read from
    the ``HUAWEI_API_ENDPOINT``, ``HUAWEI_API_KEY`` and ``HUAWEI_MODEL``
    environment variables when the instance is created.
    """

    # Matches a markdown fenced block, with or without a ``json`` language tag,
    # and captures its body.
    _FENCED_BLOCK_RE = re.compile(r'```(?:json)?\s*(.*?)```', re.DOTALL | re.IGNORECASE)

    def __init__(self):
        # Huawei model configuration (required); each value may be overridden
        # through the environment variable of the same name.
        # NOTE(security): the fallback API key below is a credential committed
        # to source control.  It is kept only so existing deployments keep
        # working; it should be rotated and supplied via HUAWEI_API_KEY.
        self.huawei_api_endpoint = os.getenv('HUAWEI_API_ENDPOINT', 'http://10.100.31.26:3001/v1/chat/completions')
        self.huawei_api_key = os.getenv('HUAWEI_API_KEY', 'sk-PoeiV3qwyTIRqcVc84E8E24cD2904872859a87922e0d9186')
        self.huawei_model = os.getenv('HUAWEI_MODEL', 'DeepSeek-R1-Distill-Llama-70B')

        # Decide which backend will serve requests.
        self.ai_provider = self._determine_ai_provider()

    def _determine_ai_provider(self) -> str:
        """Return ``'huawei'`` when both endpoint and key are set, else ``'none'``."""
        if self.huawei_api_endpoint and self.huawei_api_key:
            return 'huawei'
        return 'none'

    def extract_fields(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Extract structured fields by sending ``prompt`` to the model.

        Args:
            prompt: The prompt sent as the user message.
            output_fields: Field descriptors; each is indexed with the keys
                ``'field_code'`` and ``'name'`` by the plain-text fallback.

        Returns:
            A dict of the form ``{field_code: field_value}``.

        Raises:
            Exception: If no provider is configured, the provider is unknown,
                or the API call fails.
        """
        if self.ai_provider == 'none':
            raise Exception("未配置华为大模型服务,请设置 HUAWEI_API_KEY 和 HUAWEI_API_ENDPOINT")

        if self.ai_provider == 'huawei':
            return self._extract_with_huawei(prompt, output_fields)

        raise Exception(f"未知的AI服务提供商: {self.ai_provider}")

    def _extract_with_siliconflow(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Extract fields through the SiliconFlow API (unused, kept for reference).

        ``extract_fields`` never dispatches here.  The method reads
        ``self.siliconflow_model``, ``self.siliconflow_api_key`` and
        ``self.siliconflow_url``, none of which ``__init__`` assigns, so unless
        they are set elsewhere the call ends in the generic failure exception.

        Raises:
            Exception: On timeout or on any other failure.
        """
        try:
            payload = {
                "model": self.siliconflow_model,
                "messages": [
                    {
                        "role": "system",
                        "content": "你是一个专业的数据提取助手,能够从文本中准确提取结构化信息。请严格按照JSON格式返回结果。"
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                "temperature": 0.3,
                "max_tokens": 2000
            }
            result = self._post_chat_completion(
                self.siliconflow_url, self.siliconflow_api_key, payload, timeout=30
            )
            return self._fields_from_completion(result, output_fields)
        except requests.exceptions.Timeout as e:
            raise Exception("AI服务调用超时") from e
        except Exception as e:
            raise Exception(f"AI服务调用失败: {str(e)}") from e

    def _extract_with_huawei(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Extract fields through the Huawei model's chat-completions endpoint.

        Args:
            prompt: The prompt sent as the user message.
            output_fields: Field descriptors used by the plain-text fallback.

        Returns:
            The JSON object found in the reply, or the result of the
            plain-text fallback when no JSON object can be recovered.

        Raises:
            Exception: On timeout, a non-200 status, or a malformed response.
        """
        try:
            payload = {
                "model": self.huawei_model,
                "messages": [
                    {
                        "role": "system",
                        "content": "你是一个专业的数据提取助手,能够从文本中准确提取结构化信息。请严格按照JSON格式返回结果,只返回JSON对象,不要包含任何其他文字说明、思考过程或markdown代码块标记。"
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                "stream": False,
                "presence_penalty": 1.03,
                "frequency_penalty": 1.0,
                "repetition_penalty": 1.0,
                "temperature": 0.5,
                "top_p": 0.95,
                "top_k": 1,
                "seed": 1,
                "max_tokens": 8192,
                "n": 1
            }
            result = self._post_chat_completion(
                self.huawei_api_endpoint, self.huawei_api_key, payload, timeout=60
            )
            return self._fields_from_completion(result, output_fields)
        except requests.exceptions.Timeout as e:
            raise Exception("AI服务调用超时") from e
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise Exception(f"AI服务调用失败: {str(e)}") from e

    def _post_chat_completion(self, url: str, api_key: str, payload: Dict, timeout: int):
        """POST ``payload`` as JSON with a bearer token and return the decoded body."""
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        response = requests.post(url, json=payload, headers=headers, timeout=timeout)
        if response.status_code != 200:
            raise Exception(f"API调用失败: {response.status_code} - {response.text}")
        return response.json()

    def _fields_from_completion(self, result, output_fields: List[Dict]) -> Dict:
        """Turn a decoded chat-completions response into a field dict."""
        choices = result.get('choices') if isinstance(result, dict) else None
        if not choices:
            raise Exception("API返回格式异常")

        content = choices[0]['message']['content']
        if not isinstance(content, str):
            # A null/non-text reply is a malformed response, not a TypeError.
            raise Exception("API返回格式异常")

        # The model may emit its reasoning before a closing </think> tag;
        # keep only what follows the last such tag.
        if '</think>' in content:
            content = content.split('</think>')[-1].strip()

        extracted = self._extract_json_from_text(content)
        if extracted:
            return extracted

        # No usable JSON object: fall back to "label: value" line matching.
        return self._parse_text_response(content, output_fields)

    @staticmethod
    def _load_json_object(candidate: str) -> Optional[Dict]:
        """Parse ``candidate`` and return it only if it is a JSON object."""
        try:
            data = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None
        return data if isinstance(data, dict) else None

    def _extract_json_from_text(self, text: str) -> Optional[Dict]:
        """Recover a JSON object from free-form model output.

        Candidates are tried in this order:
            1. the body of each markdown fenced block (with or without a
               ``json`` tag),
            2. the whole text,
            3. the first non-empty JSON object that starts at any ``{`` in the
               text.

        Step 3 uses the JSON decoder itself, so arbitrarily nested objects and
        braces inside string values are handled.

        Returns:
            The decoded dict, or ``None`` when no JSON object is found.
            Non-object JSON values (strings, lists, numbers) are never
            returned.
        """
        if not text or not isinstance(text, str):
            return None

        # 1. Fenced code blocks.
        for fenced in self._FENCED_BLOCK_RE.finditer(text):
            data = self._load_json_object(fenced.group(1).strip())
            if data is not None:
                return data

        # 2. The whole text is itself a JSON object.
        data = self._load_json_object(text.strip())
        if data is not None:
            return data

        # 3. A JSON object embedded in surrounding prose.
        decoder = json.JSONDecoder()
        pos = text.find('{')
        while pos != -1:
            try:
                data, end = decoder.raw_decode(text, pos)
            except ValueError:
                pos = text.find('{', pos + 1)
                continue
            if isinstance(data, dict) and data:
                return data
            # Decoded something unusable (e.g. "{}"); resume after it.
            pos = text.find('{', end)

        return None

    def _parse_text_response(self, text: str, output_fields: List[Dict]) -> Dict:
        """Fallback parser: read ``label: value`` lines from plain text.

        For each field, the first line that contains the field's ``name``
        followed by a colon supplies the value.  Both the ASCII colon and the
        full-width colon are accepted.  Fields that cannot be found map to an
        empty string.

        Args:
            text: The model's reply.
            output_fields: Field descriptors with ``'field_code'`` and
                ``'name'`` keys.

        Returns:
            ``{field_code: value}`` with one entry per descriptor.
        """
        result = {}
        for field in output_fields:
            field_code = field['field_code']
            field_name = field['name']

            value = ''
            if text:
                # Label, optional decoration (e.g. "**"), a colon on the same
                # line, then the value up to the end of that line.
                pattern = re.escape(field_name) + r'[^\n::]*[::][ \t]*([^\n]*)'
                found = re.search(pattern, text)
                if found:
                    value = found.group(1).strip()
            result[field_code] = value

        return result
|
||
|