"""
模板AI辅助工具 - 使用AI智能分析文档内容并识别占位符替换位置

Template AI helper: uses an LLM to analyse document content and locate text
that should be replaced with placeholders.
"""
import os
import json
import requests
from typing import Dict, List, Optional, Tuple
from dotenv import load_dotenv
# 加载环境变量
# Load variables from a .env file (python-dotenv) at import time, before
# TemplateAIHelper.__init__ reads its configuration with os.getenv().
load_dotenv()


class TemplateAIHelper:
    """Ask an LLM where a document's text should be replaced by ``{{field_code}}`` placeholders.

    The backend is chosen from environment variables when the instance is created:

    * ``AI_PROVIDER``: ``'huawei'`` or ``'siliconflow'`` (default ``'siliconflow'``;
      case-insensitive, surrounding whitespace ignored). Any other value enables
      auto-detection, which prefers SiliconFlow and falls back to Huawei.
    * Huawei: ``HUAWEI_API_KEY`` (required), ``HUAWEI_API_ENDPOINT``, ``HUAWEI_MODEL``.
    * SiliconFlow: ``SILICONFLOW_API_KEY`` (required), ``SILICONFLOW_URL``,
      ``SILICONFLOW_MODEL``.

    After construction ``provider``, ``api_key``, ``model`` and ``api_url`` describe the
    selected backend. Requests are POSTed as chat-completion style JSON with a Bearer
    token, and the answer is read from ``choices[0]['message']['content']``.
    """

    # Request timeouts in seconds. These are class attributes so they can be tuned per
    # instance or subclass without changing any method signature.
    connection_test_timeout = 10
    analysis_timeout = 60

    # URL fragments that identify the Huawei deployment. They are checked in addition
    # to ``provider``, so an instance whose URL points at that deployment still gets
    # the Huawei-specific request parameters.
    _HUAWEI_URL_MARKERS = ('huawei', '10.100.31.26')

    # Extra sampling parameters sent to the Huawei deployment with analysis requests.
    _HUAWEI_ANALYSIS_PARAMS = {
        "stream": False,
        "presence_penalty": 1.03,
        "frequency_penalty": 1.0,
        "repetition_penalty": 1.0,
        "top_p": 0.95,
        "top_k": 1,
        "seed": 1,
        "n": 1,
    }

    # System prompt sent with every analysis request.
    _ANALYSIS_SYSTEM_PROMPT = (
        "你是一个专业的文档模板分析助手,能够准确识别文档中需要替换为占位符的内容。"
        "请严格按照JSON格式返回结果,不要添加任何解释性文字。"
    )

    def __init__(self):
        """Select and configure the AI backend from environment variables.

        Raises:
            Exception: if the explicitly requested backend, or in auto mode every
                backend, lacks its API key / endpoint.
        """
        ai_provider = os.getenv('AI_PROVIDER', 'siliconflow').strip().lower()

        # Huawei deployment. The API key deliberately has no built-in fallback:
        # credentials must come from the environment, never from source control.
        huawei_key = os.getenv('HUAWEI_API_KEY', '')
        huawei_endpoint = os.getenv('HUAWEI_API_ENDPOINT', 'http://10.100.31.26:3001/v1/chat/completions')
        huawei_model = os.getenv('HUAWEI_MODEL', 'DeepSeek-R1-Distill-Llama-70B')

        # SiliconFlow
        siliconflow_key = os.getenv('SILICONFLOW_API_KEY', '')
        siliconflow_url = os.getenv('SILICONFLOW_URL', 'https://api.siliconflow.cn/v1/chat/completions')
        siliconflow_model = os.getenv('SILICONFLOW_MODEL', 'deepseek-ai/DeepSeek-V3.2-Exp')

        if ai_provider == 'huawei':
            if not huawei_key or not huawei_endpoint:
                raise Exception("未配置华为大模型服务,请设置 HUAWEI_API_KEY 和 HUAWEI_API_ENDPOINT,或设置 AI_PROVIDER=siliconflow 使用硅基流动")
            self._use_backend('huawei', huawei_key, huawei_model, huawei_endpoint)
            print(f"[模板AI助手] 使用华为大模型: {huawei_model}")
        elif ai_provider == 'siliconflow':
            if not siliconflow_key:
                raise Exception("未配置硅基流动服务,请设置 SILICONFLOW_API_KEY,或设置 AI_PROVIDER=huawei 使用华为大模型")
            self._use_backend('siliconflow', siliconflow_key, siliconflow_model, siliconflow_url)
            print(f"[模板AI助手] 使用硅基流动: {siliconflow_model}")
        elif siliconflow_key and siliconflow_url:
            # Auto-detection: prefer SiliconFlow when it is configured ...
            self._use_backend('siliconflow', siliconflow_key, siliconflow_model, siliconflow_url)
            print(f"[模板AI助手] 自动选择硅基流动: {siliconflow_model}")
        elif huawei_key and huawei_endpoint:
            # ... otherwise fall back to the Huawei deployment.
            self._use_backend('huawei', huawei_key, huawei_model, huawei_endpoint)
            print(f"[模板AI助手] 自动选择华为大模型: {huawei_model}")
        else:
            raise Exception("未配置AI服务,请设置 AI_PROVIDER 环境变量('huawei' 或 'siliconflow'),并配置相应的API密钥")

    def _use_backend(self, provider: str, api_key: str, model: str, api_url: str) -> None:
        """Record the selected backend's identity and connection settings on the instance."""
        self.provider = provider
        self.api_key = api_key
        self.model = model
        self.api_url = api_url

    def _is_huawei(self) -> bool:
        """Return True when requests should carry the Huawei-specific parameters."""
        if getattr(self, 'provider', None) == 'huawei':
            return True
        url = (getattr(self, 'api_url', '') or '').lower()
        return any(marker in url for marker in self._HUAWEI_URL_MARKERS)

    def _build_headers(self) -> Dict[str, str]:
        """Return the HTTP headers for a JSON request authenticated with the API key."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def test_api_connection(self) -> bool:
        """Send a tiny chat request to check that the configured API is reachable.

        Returns:
            True if the API answered with HTTP 200, False on any other status or on
            any exception (network error, timeout, ...). Failures are printed, never raised.
        """
        try:
            print(" [测试] 正在测试API连接...")

            test_payload = {
                "model": self.model,
                "messages": [
                    {
                        "role": "user",
                        "content": "测试"
                    }
                ],
                "temperature": 0.5,
                "max_tokens": 10
            }
            # The Huawei deployment gets an explicit non-streaming flag.
            if self._is_huawei():
                test_payload["stream"] = False

            response = requests.post(
                self.api_url,
                json=test_payload,
                headers=self._build_headers(),
                timeout=self.connection_test_timeout,
            )

            if response.status_code == 200:
                print(" [测试] ✓ API连接正常")
                return True
            print(f" [测试] ✗ API连接失败: {response.status_code} - {response.text[:200]}")
            return False
        except Exception as e:
            # Best-effort probe: report and return False rather than raising.
            print(f" [测试] ✗ API连接测试失败: {e}")
            return False

    @staticmethod
    def _format_fields_info(available_fields: List[Dict]) -> str:
        """Render one ``- name ({{code}}): description`` line per candidate field."""
        lines = []
        for field in available_fields:
            # Plain concatenation (not an f-string) so the braces stay literal and the
            # model sees the real field code inside the placeholder.
            placeholder = '{{' + str(field.get('field_code', '')) + '}}'
            lines.append(f"- {field['field_name']} ({placeholder}): {field.get('description', '')}")
        return "\n".join(lines)

    def _build_analysis_prompt(
        self,
        document_text: str,
        available_fields: List[Dict],
        document_type: str
    ) -> str:
        """Build the user prompt listing the candidate fields and the text to analyse."""
        fields_info = self._format_fields_info(available_fields)
        # In this f-string ``{{``/``}}`` emit literal braces, so ``{{{{x}}}}`` renders as ``{{x}}``.
        return f"""你是一个专业的文档模板分析助手。请分析以下文档内容,识别所有可以替换为占位符的位置。

文档类型:{document_type}

可用字段列表:
{fields_info}

文档内容:
{document_text}

请仔细分析文档内容,识别以下类型的可替换内容:
1. 明确的字段值(如姓名、单位、职务等)
2. 示例值(如"XXX"、"张三"、"某公司"等)
3. 组合字段(如"山西XXXX集团有限公司(职务+姓名)"应替换为对应的占位符组合)
4. 日期、时间等格式化的值
5. 任何看起来是示例或占位符的内容

对于组合字段,如果包含多个字段信息,请使用多个占位符的组合,例如:
- "山西XXXX集团有限公司(职务+姓名)" → "{{{{target_organization_and_position}}}}({{{{target_name}}}})"
- "张三,男,1980年5月" → "{{{{target_name}}}},{{{{target_gender}}}},{{{{target_date_of_birth}}}}"

请以JSON格式返回分析结果,格式如下:
{{
  "replacements": [
    {{
      "original_text": "原始文本内容",
      "replacement": "{{{{field_code}}}}或组合占位符",
      "field_code": "字段编码(如果是组合,用逗号分隔)",
      "field_name": "字段名称",
      "confidence": 0.9,
      "position": "位置描述(如:第X段、表格第X行第X列)",
      "reason": "替换原因"
    }}
  ]
}}

只返回JSON,不要其他说明文字。"""

    def _build_analysis_payload(self, prompt: str) -> Dict:
        """Build the chat-completion request body for an analysis prompt."""
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self._ANALYSIS_SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "temperature": 0.5,
            "max_tokens": 8192,
        }
        if self._is_huawei():
            payload.update(self._HUAWEI_ANALYSIS_PARAMS)
        return payload

    @staticmethod
    def _extract_json_text(content: str) -> str:
        """Return the part of a model reply that should contain the JSON result."""
        text = content or ''
        # Reasoning models (the default Huawei model is a DeepSeek-R1 distill) may emit a
        # <think>...</think> section first; drop it so fences or braces inside the
        # reasoning are not mistaken for the answer.
        think_end = text.rfind('</think>')
        if think_end != -1:
            text = text[think_end + len('</think>'):]

        # Prefer a ```json fence, then any ``` fence.
        for fence in ('```json', '```'):
            start = text.find(fence)
            if start == -1:
                continue
            start += len(fence)
            end = text.find('```', start)
            # If the closing fence is missing (e.g. a truncated reply), take everything
            # after the opening fence; slicing with end == -1 would drop the last character.
            body = text[start:] if end == -1 else text[start:end]
            return body.strip()

        stripped = text.strip()
        if stripped[:1] in ('{', '['):
            return stripped
        # Unfenced JSON surrounded by prose: fall back to the outermost {...} span.
        first = stripped.find('{')
        last = stripped.rfind('}')
        if first != -1 and last > first:
            return stripped[first:last + 1]
        return stripped

    @staticmethod
    def _normalize_replacements(parsed_result) -> List[Dict]:
        """Extract the list of suggestion dicts from a decoded JSON reply."""
        if isinstance(parsed_result, dict):
            items = parsed_result.get('replacements', [])
        elif isinstance(parsed_result, list):
            # The model returned the list without the {"replacements": ...} wrapper.
            items = parsed_result
        else:
            items = []
        if not isinstance(items, list):
            return []
        # Callers (e.g. analyze_table_cell) set keys on each item, so keep dicts only.
        return [item for item in items if isinstance(item, dict)]

    def analyze_document_content(
        self,
        document_text: str,
        available_fields: List[Dict],
        document_type: str = "未知"
    ) -> List[Dict]:
        """Ask the model which parts of ``document_text`` should become placeholders.

        Args:
            document_text: Text of the document (or fragment) to analyse.
            available_fields: Candidate fields, e.g.
                ``[{"field_code": "xxx", "field_name": "xxx", "description": "xxx"}]``.
                ``field_name`` is required; ``field_code`` and ``description`` are
                shown to the model when present.
            document_type: Document type shown to the model.

        Returns:
            The suggestion dicts returned by the model, requested in the shape::

                {"original_text": "...", "replacement": "{{field_code}}",
                 "field_code": "...", "field_name": "...", "confidence": 0.9,
                 "position": "...", "reason": "..."}

            Only dict items are kept and their keys are not validated. Any failure
            (HTTP error, timeout, malformed reply) is printed and yields ``[]``.
        """
        try:
            print(f" [AI] 正在分析文本内容(长度: {len(document_text)} 字符)...")
            prompt = self._build_analysis_prompt(document_text, available_fields, document_type)
            payload = self._build_analysis_payload(prompt)

            print(" [AI] 正在调用API...")
            response = requests.post(
                self.api_url,
                json=payload,
                headers=self._build_headers(),
                timeout=self.analysis_timeout,
            )
            print(f" [AI] API响应状态: {response.status_code}")

            if response.status_code != 200:
                # Slicing already yields the whole text when it is shorter than 500 chars.
                raise Exception(f"API调用失败: {response.status_code} - {response.text[:500]}")

            result = response.json()
            print(" [AI] API调用成功,正在解析响应...")

            choices = result.get('choices') if isinstance(result, dict) else None
            if not choices:
                raise Exception("API返回格式异常")
            # Guard against a null ``content`` (assumed possible for some backends).
            content = choices[0]['message']['content'] or ''

            json_text = self._extract_json_text(content)
            try:
                parsed_result = json.loads(json_text)
            except json.JSONDecodeError as e:
                print(f" [AI] ⚠ JSON解析失败: {e}")
                print(f" [AI] 原始响应内容: {json_text[:200]}...")
                # An unparseable reply means "no suggestions", not an error.
                return []

            replacements = self._normalize_replacements(parsed_result)
            print(f" [AI] ✓ 分析完成,识别到 {len(replacements)} 个替换建议")
            return replacements

        except requests.exceptions.Timeout:
            print(f" [AI] ✗ 请求超时({self.analysis_timeout}秒),跳过AI分析")
            return []
        except Exception as e:
            # Analysis is best-effort: report with a traceback and return no suggestions.
            print(f" [AI] ✗ 分析失败: {e}")
            import traceback
            traceback.print_exc()
            return []

    def analyze_paragraph(
        self,
        paragraph_text: str,
        available_fields: List[Dict],
        document_type: str = "未知"
    ) -> List[Dict]:
        """Analyse one paragraph for placeholder candidates.

        Args:
            paragraph_text: Paragraph text.
            available_fields: Candidate fields (see ``analyze_document_content``).
            document_type: Document type shown to the model.

        Returns:
            Suggestions from ``analyze_document_content``. Returns ``[]`` without
            calling the API for empty text or text shorter than 10 characters after
            stripping whitespace. It also returns ``[]`` when the text already
            contains both ``{{`` and ``}}``.
        """
        # Skip empty or very short text (fewer than 10 characters after strip()).
        if not paragraph_text or len(paragraph_text.strip()) < 10:
            return []

        # Text that already contains placeholders is treated as already templated.
        if '{{' in paragraph_text and '}}' in paragraph_text:
            return []

        return self.analyze_document_content(paragraph_text, available_fields, document_type)

    def analyze_table_cell(
        self,
        cell_text: str,
        available_fields: List[Dict],
        document_type: str = "未知",
        row: int = 0,
        col: int = 0
    ) -> List[Dict]:
        """Analyse one table cell for placeholder candidates.

        Args:
            cell_text: Cell text.
            available_fields: Candidate fields (see ``analyze_document_content``).
            document_type: Document type shown to the model.
            row: Zero-based row index, used only for the position label.
            col: Zero-based column index, used only for the position label.

        Returns:
            Suggestions from ``analyze_document_content``, each with ``position``
            overwritten to a 1-based "表格第{row+1}行第{col+1}列" label. Returns ``[]``
            without calling the API for empty text or text shorter than 3 characters
            after stripping whitespace.
        """
        if not cell_text or len(cell_text.strip()) < 3:
            return []

        replacements = self.analyze_document_content(cell_text, available_fields, document_type)

        # The caller knows the exact cell, so replace whatever position the model guessed.
        for replacement in replacements:
            replacement['position'] = f"表格第{row+1}行第{col+1}列"

        return replacements


def get_available_fields_for_document(doc_config: Dict, field_name_to_code: Dict) -> List[Dict]:
    """Build the list of candidate fields configured for one document.

    Args:
        doc_config: Document configuration. Its ``'fields'`` entry lists field codes;
            a missing or ``None`` entry means no fields.
        field_name_to_code: Mapping from display name to field code. When several
            names map to the same code, the first one in iteration order is used.
            Field codes are assumed hashable (e.g. strings).

    Returns:
        One ``{'field_code', 'field_name', 'description'}`` dict per configured code
        that has a non-empty name, in the order of ``doc_config['fields']``. Codes
        with no name are skipped. ``description`` is ``"<name>字段"``.
    """
    # Invert the mapping once (O(n + m)) instead of scanning it for every field code.
    # setdefault keeps the first name seen, matching a first-match linear search.
    code_to_name = {}
    for name, code in field_name_to_code.items():
        code_to_name.setdefault(code, name)

    available_fields = []
    for field_code in doc_config.get('fields') or []:
        field_name = code_to_name.get(field_code)
        if field_name:
            available_fields.append({
                'field_code': field_code,
                'field_name': field_name,
                'description': f"{field_name}字段",
            })

    return available_fields