ai-business-write/services/ai_service.py

616 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI服务 - 封装大模型调用
仅支持华为大模型
"""
import os
import re
import time
import requests
import json
from typing import Dict, List, Optional
class AIService:
    """AI service that wraps calls to the Huawei-hosted chat-completions model.

    All configuration is read from environment variables in ``__init__``.
    Only the Huawei provider is supported; there is no automatic fallback.
    """

    def __init__(self):
        # Huawei model configuration (required).
        self.huawei_api_endpoint = os.getenv('HUAWEI_API_ENDPOINT', 'http://10.100.31.26:3001/v1/chat/completions')
        # The key must be supplied via the environment: a default credential
        # committed to source control is readable by anyone with repository
        # access. NOTE(review): rotate the key that used to be hard-coded here.
        self.huawei_api_key = os.getenv('HUAWEI_API_KEY', '')
        self.huawei_model = os.getenv('HUAWEI_MODEL', 'DeepSeek-R1-Distill-Llama-70B')
        # Request timeout in seconds (HUAWEI_API_TIMEOUT, default 180 = 3 min).
        # Thinking mode makes responses much slower, hence the long default.
        self.api_timeout = int(os.getenv('HUAWEI_API_TIMEOUT', '180'))
        # max_tokens sent with each request (HUAWEI_API_MAX_TOKENS, default
        # 12000); thinking mode can produce much longer responses.
        self.api_max_tokens = int(os.getenv('HUAWEI_API_MAX_TOKENS', '12000'))
        # Provider in use: 'huawei' or 'none'.
        self.ai_provider = self._determine_ai_provider()

    def _determine_ai_provider(self) -> str:
        """Return 'huawei' when both endpoint and API key are set, else 'none'."""
        if self.huawei_api_endpoint and self.huawei_api_key:
            return 'huawei'
        return 'none'

    def extract_fields(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Extract structured fields from *prompt* using the configured provider.

        Args:
            prompt: Prompt text sent to the model.
            output_fields: Field descriptors; each must have a ``field_code``
                key and may have a display ``name``.

        Returns:
            Dict mapping field code to extracted value.

        Raises:
            Exception: if no provider is configured, or if the model call
                still fails after all retries.
        """
        if self.ai_provider == 'none':
            raise Exception("未配置华为大模型服务,请设置 HUAWEI_API_KEY 和 HUAWEI_API_ENDPOINT")
        if self.ai_provider == 'huawei':
            return self._extract_with_huawei(prompt, output_fields)
        raise Exception(f"未知的AI服务提供商: {self.ai_provider}")

    def _extract_with_siliconflow(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Extract fields via the SiliconFlow API (deprecated, kept for reference).

        The system only supports the Huawei model and never falls back here.
        NOTE(review): ``siliconflow_model``, ``siliconflow_api_key`` and
        ``siliconflow_url`` are not set by ``__init__``; unless a caller sets
        them, this method fails with "AI服务调用失败: ...".

        Raises:
            Exception: "AI服务调用超时" on timeout, otherwise
                "AI服务调用失败: <reason>"; the original error is chained.
        """
        try:
            payload = {
                "model": self.siliconflow_model,
                "messages": [
                    {
                        "role": "system",
                        "content": "你是一个专业的数据提取助手能够从文本中准确提取结构化信息。请严格按照JSON格式返回结果。"
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                "temperature": 0.3,
                "max_tokens": 2000
            }
            headers = {
                "Authorization": f"Bearer {self.siliconflow_api_key}",
                "Content-Type": "application/json"
            }
            response = requests.post(
                self.siliconflow_url,
                json=payload,
                headers=headers,
                timeout=30
            )
            if response.status_code != 200:
                raise Exception(f"API调用失败: {response.status_code} - {response.text}")
            result = response.json()
            if 'choices' in result and len(result['choices']) > 0:
                content = result['choices'][0]['message']['content']
                # Share the robust JSON extraction used by the Huawei path.
                extracted_data = self._extract_json_from_text(content)
                if extracted_data is not None:
                    return extracted_data
                # Not JSON: fall back to "label: value" text parsing.
                return self._parse_text_response(content, output_fields)
            raise Exception("API返回格式异常")
        except requests.exceptions.Timeout as e:
            raise Exception("AI服务调用超时") from e
        except Exception as e:
            raise Exception(f"AI服务调用失败: {str(e)}") from e

    def _extract_with_huawei(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Call the Huawei model, retrying on any failure.

        Makes up to ``max_retries + 1`` (= 4) attempts. It sleeps 2 s, 4 s
        and 6 s before the 2nd, 3rd and 4th attempt. Every exception type is
        retried, including non-transport errors such as unparseable content.

        Raises:
            Exception: when the last attempt fails. The message names the
                failure category and the underlying error, and that error is
                chained as ``__cause__``.
        """
        max_retries = 3  # retries after the first attempt -> 4 attempts total
        retry_delay = 2  # base delay in seconds; grows linearly per retry
        total_attempts = max_retries + 1
        last_exception = None

        for attempt in range(total_attempts):
            if attempt > 0:
                wait_time = retry_delay * attempt
                print(f"[AI服务] 第 {attempt} 次重试,等待 {wait_time} 秒后重试...")
                time.sleep(wait_time)
            attempt_label = f"(尝试 {attempt + 1}/{total_attempts})"
            try:
                print(f"[AI服务] 正在调用华为大模型API {attempt_label}...")
                result = self._call_huawei_api_once(prompt, output_fields)
            except requests.exceptions.Timeout as e:
                last_exception = e
                error_msg = f"AI服务调用超时 {attempt_label}"
                # Timeouts are logged without the exception text.
                print(f"[AI服务] {error_msg}")
            except requests.exceptions.ConnectionError as e:
                last_exception = e
                error_msg = f"连接错误 {attempt_label}"
                print(f"[AI服务] {error_msg}: {str(e)}")
            except requests.exceptions.RequestException as e:
                last_exception = e
                error_msg = f"请求异常 {attempt_label}"
                print(f"[AI服务] {error_msg}: {str(e)}")
            except Exception as e:
                last_exception = e
                error_msg = f"AI服务调用失败 {attempt_label}"
                print(f"[AI服务] {error_msg}: {str(e)}")
            else:
                if result is not None:
                    if attempt > 0:
                        print("[AI服务] 重试成功!")
                    return result
                continue
            if attempt == max_retries:
                raise Exception(f"{error_msg}: {str(last_exception)}") from last_exception

        # Reached only when the final attempt returned None.
        if last_exception:
            raise Exception(f"AI服务调用失败已重试 {max_retries} 次: {str(last_exception)}") from last_exception
        raise Exception(f"AI服务调用失败已重试 {max_retries}")

    @staticmethod
    def _strip_reasoning(content: str) -> str:
        """Drop a reasoning preamble from model output.

        If ``</think>`` occurs (the DeepSeek-R1 format), returns the stripped
        text after its last occurrence. Otherwise, if both ``<reasoning>`` and
        ``</reasoning>`` occur, returns the stripped text after the last
        ``</reasoning>``. In all other cases *content* is returned unchanged.
        """
        if '</think>' in content:
            print("[AI服务] 检测到 </think> 标签,提取标签后的内容")
            return content.rpartition('</think>')[2].strip()
        if '<reasoning>' in content and '</reasoning>' in content:
            print("[AI服务] 检测到 <reasoning> 标签,提取标签后的内容")
            # rpartition slices by the full tag length, so no stray '>' remains.
            return content.rpartition('</reasoning>')[2].strip()
        return content

    def _call_huawei_api_once(self, prompt: str, output_fields: List[Dict]) -> Optional[Dict]:
        """Make one call to the Huawei model (no retries) and parse the reply.

        Returns:
            Dict keyed by field code (via ``_normalize_field_names``), or the
            ``_parse_text_response`` result when the reply is not JSON but
            yields at least one non-empty value.

        Raises:
            requests.exceptions.RequestException: transport errors and timeouts.
            Exception: non-200 status, missing/empty ``choices``, or content
                from which nothing could be extracted.
        """
        payload = {
            "model": self.huawei_model,
            "messages": [
                {
                    "role": "system",
                    "content": "你是一个专业的数据提取助手。请仔细分析用户提供的输入文本提取所有相关信息并严格按照指定的JSON格式返回结果。只返回JSON对象不要包含任何其他文字说明、思考过程或markdown代码块标记。"
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "stream": False,
            "presence_penalty": 1.03,
            "frequency_penalty": 1.0,
            "repetition_penalty": 1.0,
            "temperature": 0.3,
            "top_p": 0.95,
            "top_k": 1,
            "seed": 1,
            "max_tokens": self.api_max_tokens,
            "n": 1,
            "enable_thinking": True
        }
        headers = {
            "Authorization": f"Bearer {self.huawei_api_key}",
            "Content-Type": "application/json"
        }
        # Thinking mode needs more time to reason, so it gets the full
        # configured timeout; otherwise the timeout is capped at 120 seconds.
        enable_thinking = payload.get('enable_thinking', False)
        if enable_thinking:
            timeout = self.api_timeout
            print(f"[AI服务] 思考模式已开启,使用超时时间: {timeout}")
        else:
            timeout = min(self.api_timeout, 120)
            print(f"[AI服务] 思考模式未开启,使用超时时间: {timeout}")

        response = requests.post(
            self.huawei_api_endpoint,
            json=payload,
            headers=headers,
            timeout=timeout
        )
        if response.status_code != 200:
            raise Exception(f"API调用失败: {response.status_code} - {response.text}")
        result = response.json()

        choices = result.get('choices') if isinstance(result, dict) else None
        if not choices:
            raise Exception("API返回格式异常未找到choices字段或choices为空")
        message = choices[0].get('message') or {}
        # Some servers send content=None (e.g. when only reasoning was
        # produced); treat that as empty so it surfaces as a parse failure.
        raw_content = message.get('content') or ''
        print(f"[AI服务] API返回的原始内容前500字符: {raw_content[:500]}")

        content = self._strip_reasoning(raw_content)
        print(f"[AI服务] 清理后的内容前500字符: {content[:500]}")

        extracted_data = self._extract_json_from_text(content)
        if extracted_data:
            print(f"[AI服务] JSON解析成功提取到 {len(extracted_data)} 个字段")
            print(f"[AI服务] 原始字段名: {list(extracted_data.keys())}")
            # Map whatever key spelling the model used onto the field codes.
            normalized_data = self._normalize_field_names(extracted_data, output_fields)
            print(f"[AI服务] 规范化后的字段名: {list(normalized_data.keys())}")
            return normalized_data

        print("[AI服务] 警告无法从内容中提取JSON尝试备用解析方法")
        print(f"[AI服务] 完整内容: {content}")
        parsed_data = self._parse_text_response(content, output_fields)
        if parsed_data and any(parsed_data.values()):  # at least one non-empty field
            print(f"[AI服务] 使用备用方法解析成功,提取到 {len(parsed_data)} 个字段")
            return parsed_data

        raise Exception(f"无法从API返回内容中提取JSON数据。原始内容长度: {len(raw_content)}, 清理后内容长度: {len(content)}。请检查API返回的内容格式是否正确。")

    def _loads_json_dict(self, candidate: str, label: str) -> Optional[Dict]:
        """Parse *candidate* as a JSON object, cleaning or repairing it if needed.

        Tries, in order: the text as-is, the ``_clean_json_string`` output,
        then ``_fix_json_string`` applied to the cleaned text. *label* is only
        used in the failure log line.

        Returns:
            The decoded dict, or None if every attempt fails or the decoded
            value is not a JSON object.
        """
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            cleaned = self._clean_json_string(candidate)
            try:
                data = json.loads(cleaned)
            except json.JSONDecodeError as e:
                print(f"[AI服务] JSON解析失败{label}: {e}")
                try:
                    data = json.loads(self._fix_json_string(cleaned))
                except json.JSONDecodeError:
                    return None
        return data if isinstance(data, dict) else None

    def _extract_json_from_text(self, text: str) -> Optional[Dict]:
        """Extract a JSON object from free-form model output.

        Strategies, in order:
        1. the contents of a ```json fenced block, then of a plain ``` block;
        2. the whole text;
        3. the first ``{`` from which a complete, non-empty JSON object can be
           decoded. This handles any nesting depth and braces inside strings.

        Strategies 1 and 2 go through ``_loads_json_dict`` (as-is, cleaned,
        repaired). JSON arrays and scalars are never returned.

        Returns:
            The extracted dict, or None if no JSON object was found.
        """
        # 1. Fenced code blocks. The bare '```' search also finds a '```json'
        #    opener; that extra attempt simply fails and is harmless.
        for fence in ('```json', '```'):
            fence_start = text.find(fence)
            if fence_start == -1:
                continue
            body_start = fence_start + len(fence)
            body_end = text.find('```', body_start)
            if body_end == -1:
                continue
            data = self._loads_json_dict(text[body_start:body_end].strip(), "代码块")
            if data is not None:
                return data

        # 2. The whole text.
        data = self._loads_json_dict(text.strip(), "直接解析")
        if data is not None:
            return data

        # 3. Embedded object: raw_decode parses a single JSON value starting at
        #    the given index and ignores any trailing text.
        decoder = json.JSONDecoder()
        brace_idx = text.find('{')
        while brace_idx != -1:
            try:
                obj, _ = decoder.raw_decode(text, brace_idx)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(obj, dict) and obj:
                    return obj
            brace_idx = text.find('{', brace_idx + 1)
        return None

    def _clean_json_string(self, json_str: str) -> str:
        """Remove common non-JSON residue from *json_str*.

        Strips surrounding whitespace and a leading BOM, and deletes anything
        shaped like an XML/HTML tag (``<...>``). Tag removal also applies
        inside string values, so ``_loads_json_dict`` tries the unmodified
        text first.
        """
        json_str = json_str.strip()
        if json_str.startswith('\ufeff'):
            json_str = json_str[1:]
        json_str = re.sub(r'<[^>]+>', '', json_str)
        return json_str

    def _fix_json_string(self, json_str: str) -> str:
        """Best-effort repair of common JSON mistakes in model output.

        - drops trailing commas before ``}`` / ``]``;
        - removes a leading dot from quoted keys (``".target_gender":``);
        - replaces the last whitespace run in a quoted key with ``_``
          (``"target name":`` -> ``"target_name":``);
        - quotes bare keys that directly follow ``{`` or ``,``.

        The regexes are not JSON-aware and can still alter string values
        that happen to contain matching text.
        """
        json_str = re.sub(r',\s*}', '}', json_str)
        json_str = re.sub(r',\s*]', ']', json_str)
        json_str = re.sub(r'["\']\.([^"\']+)["\']\s*:', r'"\1":', json_str)
        json_str = re.sub(r'["\']([^"\']+)\s+([^"\']+)["\']\s*:', r'"\1_\2":', json_str)
        # Only a token right after '{' or ',' can be an object key. Quoting
        # every "word:" would also rewrite values such as "10:30" or URLs.
        json_str = re.sub(r'([{,]\s*)([^\s"\'{}\[\],:]+)\s*:', r'\1"\2":', json_str)
        return json_str

    def _normalize_field_names(self, extracted_data: Dict, output_fields: List[Dict]) -> Dict:
        """Map the key spellings returned by the model onto the field codes.

        Args:
            extracted_data: Raw dict decoded from the model output.
            output_fields: Field descriptors with ``field_code`` and an
                optional display ``name``.

        Returns:
            Dict keyed by field code. Every code in *output_fields* is present
            ('' when not extracted). Keys that cannot be matched are kept
            under their cleaned original name.

        Matching order per key: exact variant, case-insensitive variant, then
        a substring ("fuzzy") match on the underscore/dash-free forms. Keys
        starting with '@' (e.g. ``@context``) are skipped. A nested dict value
        is replaced by its first non-empty string value, or by ``str(value)``
        if it has none (or '' if the dict is empty).
        """
        field_code_map = {field['field_code']: field for field in output_fields}

        # Every accepted spelling -> field code.
        name_to_code_map = {}
        for field in output_fields:
            field_code = field['field_code']
            field_name = field.get('name', '')
            name_to_code_map[field_code] = field_code
            if field_name:
                name_to_code_map[field_name] = field_code
            if '_' in field_code:
                parts = field_code.split('_')
                # camelCase, e.g. target_political_status -> targetPoliticalStatus
                name_to_code_map[parts[0] + ''.join(w.capitalize() for w in parts[1:])] = field_code
                # PascalCase, e.g. TargetPoliticalStatus
                name_to_code_map[''.join(w.capitalize() for w in parts)] = field_code
            if field_code.startswith('target_'):
                # Unprefixed variants, e.g. name / politicalStatus -> target_*
                short_name = field_code[len('target_'):]
                name_to_code_map[short_name] = field_code
                short_parts = short_name.split('_')
                name_to_code_map[short_parts[0] + ''.join(w.capitalize() for w in short_parts[1:])] = field_code

        # Common Schema.org-style names the model tends to emit.
        schema_mapping = {
            'name': 'target_name',
            'gender': 'target_gender',
            'dateOfBirth': 'target_date_of_birth',
            'date_of_birth': 'target_date_of_birth',
            'politicalStatus': 'target_political_status',
            'political_status': 'target_political_status',
            'organizationAndPosition': 'target_organization_and_position',
            'organization_and_position': 'target_organization_and_position',
            'organization': 'target_organization',
            'position': 'target_position',
            'educationLevel': 'target_education_level',
            'education_level': 'target_education_level',
            'professionalRank': 'target_professional_rank',
            'professional_rank': 'target_professional_rank',
            'clueSource': 'clue_source',
            'clue_source': 'clue_source',
            'issueDescription': 'target_issue_description',
            'issue_description': 'target_issue_description',
            'description': 'target_issue_description',  # description usually means the issue description
            'age': 'target_age',
        }
        # Only add Schema.org aliases for codes that are actually requested.
        for schema_key, code in schema_mapping.items():
            if code in field_code_map:
                name_to_code_map[schema_key] = code

        # Case-insensitive lookup; setdefault keeps the first spelling in
        # insertion order, the one a linear scan would find first.
        lower_to_code = {}
        for name, code in name_to_code_map.items():
            lower_to_code.setdefault(name.lower(), code)

        def squash(s: str) -> str:
            return s.lower().replace('_', '').replace('-', '')

        squashed_codes = [(code, squash(code)) for code in field_code_map]

        normalized_data = {}
        for key, value in extracted_data.items():
            if key.startswith('@'):
                continue

            # Nested object, e.g. description: {violationOfFamilyPlanningPolicies: "..."}
            if isinstance(value, dict):
                nested_values = [v for v in value.values() if v and isinstance(v, str)]
                if nested_values:
                    value = nested_values[0]
                else:
                    value = str(value) if value else ''

            # Drop a leading dot and surrounding whitespace from the key.
            cleaned_key = key.strip().lstrip('.')

            if cleaned_key in name_to_code_map:
                normalized_data[name_to_code_map[cleaned_key]] = value
                continue
            lowered_code = lower_to_code.get(cleaned_key.lower())
            if lowered_code is not None:
                normalized_data[lowered_code] = value
                continue

            # Fuzzy match: one squashed form contains the other. An empty key
            # would be "contained" in every code, so it is never fuzzy-matched.
            squashed_key = squash(cleaned_key)
            fuzzy_code = None
            if squashed_key:
                fuzzy_code = next(
                    (code for code, squashed in squashed_codes
                     if squashed_key in squashed or squashed in squashed_key),
                    None,
                )
            # A fuzzy guess must not overwrite a value already assigned to
            # that code (e.g. "target_name_pinyin" vs "target_name").
            if fuzzy_code is not None and fuzzy_code not in normalized_data:
                normalized_data[fuzzy_code] = value
                print(f"[AI服务] 模糊匹配: '{cleaned_key}' -> '{fuzzy_code}'")
            else:
                print(f"[AI服务] 警告:无法匹配字段名 '{cleaned_key}',保留原字段名")
                normalized_data[cleaned_key] = value

        # Every requested field is present, even if empty.
        for field_code in field_code_map:
            normalized_data.setdefault(field_code, '')
        return normalized_data

    def _parse_text_response(self, text: str, output_fields: List[Dict]) -> Dict:
        """Fallback parser for replies that are not JSON (``label: value`` lines).

        For each field, finds the first occurrence of its display name
        (``field['name']``). The value is the stripped text after the first
        ASCII ':' or full-width colon (U+FF1A) that follows the name on the
        same line. A field maps to '' when it has no name, its name does not
        occur in *text*, or no colon follows the name on that line.

        Returns:
            Dict containing every ``field_code`` from *output_fields*.
        """
        colons = (':', '\uff1a')
        result = {}
        for field in output_fields:
            field_code = field['field_code']
            field_name = field.get('name', '')
            result[field_code] = ''
            # An empty name would "match" at index 0 and grab an unrelated value.
            if not field_name:
                continue
            name_idx = text.find(field_name)
            if name_idx == -1:
                continue
            line_end = text.find('\n', name_idx)
            if line_end == -1:
                line_end = len(text)
            # Only the rest of this line counts; otherwise a colon on a later
            # line would assign another field's value to this one.
            rest = text[name_idx + len(field_name):line_end]
            colon_positions = [pos for pos in (rest.find(c) for c in colons) if pos != -1]
            if colon_positions:
                result[field_code] = rest[min(colon_positions) + 1:].strip()
        return result