305 lines
11 KiB
Python
305 lines
11 KiB
Python
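"""
模板AI辅助工具 - 使用AI智能分析文档内容并识别占位符替换位置

Template AI helper: uses an LLM to analyze document content and identify
which pieces of text should be replaced with template placeholders.
"""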
|
||
import os
|
||
import json
|
||
import requests
|
||
from typing import Dict, List, Optional, Tuple
|
||
from dotenv import load_dotenv
|
||
|
||
# Load variables from a local .env file into the process environment when this
# module is imported, so TemplateAIHelper can read SILICONFLOW_API_KEY /
# SILICONFLOW_MODEL through os.getenv. override=False (python-dotenv's default)
# means variables already set in the real environment win over .env values.
load_dotenv(override=False)
|
||
|
||
|
||
class TemplateAIHelper:
    """Ask an LLM (SiliconFlow chat-completions API) where a document should use placeholders.

    Given document text and a list of available template fields, the helper
    asks the model which snippets should become ``{{field_code}}``
    placeholders and returns the model's suggestions as a list of dicts.
    The analysis methods are best-effort: network and parse failures are
    printed to stdout and yield an empty list instead of raising.
    """

    # Chat-completions endpoint used when no ``api_url`` is passed to the constructor.
    DEFAULT_API_URL = "https://api.siliconflow.cn/v1/chat/completions"
    # Model used when neither ``model`` nor the SILICONFLOW_MODEL env var is given.
    DEFAULT_MODEL = 'deepseek-ai/DeepSeek-V3.2-Exp'
    # Timeout (seconds) for the lightweight connectivity probe.
    CONNECTION_TEST_TIMEOUT = 10

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        api_url: Optional[str] = None,
        timeout: float = 60,
    ):
        """Configure the API client.

        Args:
            api_key: API key; falls back to the SILICONFLOW_API_KEY env var.
            model: model name; falls back to SILICONFLOW_MODEL, then DEFAULT_MODEL.
            api_url: chat-completions endpoint; defaults to DEFAULT_API_URL.
            timeout: timeout in seconds for analysis requests.

        Raises:
            Exception: if no API key is available from either source.
        """
        self.api_key = api_key or os.getenv('SILICONFLOW_API_KEY')
        self.model = model or os.getenv('SILICONFLOW_MODEL', self.DEFAULT_MODEL)
        self.api_url = api_url or self.DEFAULT_API_URL
        self.timeout = timeout

        if not self.api_key:
            raise Exception("未配置 SILICONFLOW_API_KEY,请在 .env 文件中设置")

    def _headers(self) -> Dict[str, str]:
        """Return the HTTP headers shared by every API request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def test_api_connection(self) -> bool:
        """Send a tiny chat request to check that the API is reachable.

        Returns:
            True if the API answered with HTTP 200; False on any other status
            or on any exception (the reason is printed).
        """
        try:
            print(" [测试] 正在测试硅基流动API连接...")
            test_payload = {
                "model": self.model,
                "messages": [{"role": "user", "content": "测试"}],
                "max_tokens": 10,
            }
            response = requests.post(
                self.api_url,
                json=test_payload,
                headers=self._headers(),
                timeout=self.CONNECTION_TEST_TIMEOUT,
            )
            if response.status_code == 200:
                print(" [测试] ✓ API连接正常")
                return True
            print(f" [测试] ✗ API连接失败: {response.status_code} - {response.text[:200]}")
            return False
        except Exception as e:
            # Best-effort probe: any failure (network, DNS, timeout) just means "not connected".
            print(f" [测试] ✗ API连接测试失败: {e}")
            return False

    def analyze_document_content(
        self,
        document_text: str,
        available_fields: List[Dict],
        document_type: str = "未知"
    ) -> List[Dict]:
        """Ask the model which snippets of *document_text* should become placeholders.

        Args:
            document_text: text to analyze (a paragraph, a table cell or a whole document).
            available_fields: fields the model may use, each shaped like
                {"field_code": "xxx", "field_name": "xxx", "description": "xxx"}.
            document_type: free-text document type included in the prompt.

        Returns:
            The model's suggestions (only dict entries are kept), requested in the form
            {"original_text", "replacement", "field_code", "field_name",
             "confidence", "position", "reason"}. Keys are whatever the model
            returned; they are not validated. Returns [] on timeout, HTTP
            error, malformed response or unparseable JSON; failures are printed.
        """
        try:
            print(f" [AI] 正在分析文本内容(长度: {len(document_text)} 字符)...")
            prompt = self._build_prompt(document_text, available_fields, document_type)
            content = self._request_completion(prompt)

            try:
                replacements = self._parse_replacements(content)
            except json.JSONDecodeError as e:
                print(f" [AI] ⚠ JSON解析失败: {e}")
                print(f" [AI] 原始响应内容: {content[:200]}...")
                # Unparseable model output is treated as "no suggestions".
                return []

            print(f" [AI] ✓ 分析完成,识别到 {len(replacements)} 个替换建议")
            return replacements

        except requests.exceptions.Timeout:
            print(f" [AI] ✗ 请求超时({self.timeout}秒),跳过AI分析")
            return []
        except Exception as e:
            # Deliberate best-effort boundary: log with traceback, never raise to the caller.
            print(f" [AI] ✗ 分析失败: {e}")
            import traceback
            traceback.print_exc()
            return []

    @staticmethod
    def _format_fields_info(available_fields: List[Dict]) -> str:
        """Render one "- name ({{code}}): description" line per available field."""
        lines = []
        for field in available_fields:
            # Build the placeholder outside the f-string: the original
            # f"({{{{field_code}}}})" rendered the literal text "{{field_code}}"
            # instead of the field's real code.
            placeholder = "{{" + str(field['field_code']) + "}}"
            lines.append(f"- {field['field_name']} ({placeholder}): {field.get('description', '')}")
        return "\n".join(lines)

    @classmethod
    def _build_prompt(cls, document_text: str, available_fields: List[Dict], document_type: str) -> str:
        """Build the user prompt sent to the model."""
        fields_info = cls._format_fields_info(available_fields)
        # Doubled braces below are f-string escapes: "{{{{x}}}}" renders as "{{x}}".
        return f"""你是一个专业的文档模板分析助手。请分析以下文档内容,识别所有可以替换为占位符的位置。

文档类型:{document_type}

可用字段列表:
{fields_info}

文档内容:
{document_text}

请仔细分析文档内容,识别以下类型的可替换内容:
1. 明确的字段值(如姓名、单位、职务等)
2. 示例值(如"XXX"、"张三"、"某公司"等)
3. 组合字段(如"山西XXXX集团有限公司(职务+姓名)"应替换为对应的占位符组合)
4. 日期、时间等格式化的值
5. 任何看起来是示例或占位符的内容

对于组合字段,如果包含多个字段信息,请使用多个占位符的组合,例如:
- "山西XXXX集团有限公司(职务+姓名)" → "{{{{target_organization_and_position}}}}({{{{target_name}}}})"
- "张三,男,1980年5月" → "{{{{target_name}}}},{{{{target_gender}}}},{{{{target_date_of_birth}}}}"

请以JSON格式返回分析结果,格式如下:
{{
  "replacements": [
    {{
      "original_text": "原始文本内容",
      "replacement": "{{{{field_code}}}}或组合占位符",
      "field_code": "字段编码(如果是组合,用逗号分隔)",
      "field_name": "字段名称",
      "confidence": 0.9,
      "position": "位置描述(如:第X段、表格第X行第X列)",
      "reason": "替换原因"
    }}
  ]
}}

只返回JSON,不要其他说明文字。"""

    def _request_completion(self, prompt: str) -> str:
        """POST *prompt* to the chat-completions API and return the reply text.

        Raises:
            requests.exceptions.Timeout: if the request exceeds ``self.timeout``.
            Exception: on a non-200 status or a response without ``choices``.
        """
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "system",
                    "content": "你是一个专业的文档模板分析助手,能够准确识别文档中需要替换为占位符的内容。请严格按照JSON格式返回结果,不要添加任何解释性文字。"
                },
                {"role": "user", "content": prompt},
            ],
            "temperature": 0.2,
            "max_tokens": 4000,
        }

        print(" [AI] 正在调用API...")
        response = requests.post(
            self.api_url,
            json=payload,
            headers=self._headers(),
            timeout=self.timeout,
        )
        print(f" [AI] API响应状态: {response.status_code}")

        if response.status_code != 200:
            raise Exception(f"API调用失败: {response.status_code} - {response.text[:500]}")

        result = response.json()
        print(" [AI] API调用成功,正在解析响应...")

        choices = result.get('choices') if isinstance(result, dict) else None
        if not choices:
            raise Exception("API返回格式异常")
        message = choices[0].get('message') or {}
        # content may be missing or null in some responses; normalise to "" so
        # parsing fails with a clean JSONDecodeError instead of a TypeError.
        return message.get('content') or ''

    @staticmethod
    def _extract_json_text(content: str) -> str:
        """Strip a Markdown code fence (with optional language tag) around the model's JSON."""
        fence = content.find('```')
        if fence == -1:
            return content.strip()
        start = fence + 3
        end = content.find('```', start)
        # An unterminated fence (e.g. a reply cut off at max_tokens) keeps the rest of the text;
        # slicing to -1 here would silently drop the final character.
        body = content[start:] if end == -1 else content[start:end]
        first_line, sep, rest = body.partition('\n')
        tag = first_line.strip()
        if sep and tag.isascii() and tag.isalnum():  # e.g. ```json / ```JSON
            body = rest
        return body.strip()

    @classmethod
    def _parse_replacements(cls, content: str) -> List[Dict]:
        """Parse the model reply into a list of suggestion dicts.

        Accepts {"replacements": [...]} or a bare list. Non-dict items are
        dropped so callers can safely assign keys on every element.

        Raises:
            json.JSONDecodeError: if no JSON can be recovered from *content*.
        """
        text = cls._extract_json_text(content)
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # Fall back to the outermost {...} span in case the JSON is wrapped in prose.
            start, end = text.find('{'), text.rfind('}')
            if start == -1 or end <= start:
                raise
            parsed = json.loads(text[start:end + 1])

        if isinstance(parsed, dict):
            items = parsed.get('replacements', [])
        elif isinstance(parsed, list):
            items = parsed
        else:
            items = []
        if not isinstance(items, list):
            return []
        return [item for item in items if isinstance(item, dict)]

    @staticmethod
    def _already_templated(text: str) -> bool:
        """Return True if *text* appears to contain a ``{{...}}`` placeholder already."""
        return '{{' in text and '}}' in text

    def analyze_paragraph(
        self,
        paragraph_text: str,
        available_fields: List[Dict],
        document_type: str = "未知"
    ) -> List[Dict]:
        """Analyze a single paragraph for placeholder replacements.

        Args:
            paragraph_text: paragraph text.
            available_fields: available fields (see analyze_document_content).
            document_type: document type.

        Returns:
            Replacement suggestions; [] for empty/short (< 10 chars after strip)
            or already-templated text, without calling the API.
        """
        if not paragraph_text or len(paragraph_text.strip()) < 10:
            return []
        if self._already_templated(paragraph_text):
            return []
        return self.analyze_document_content(paragraph_text, available_fields, document_type)

    def analyze_table_cell(
        self,
        cell_text: str,
        available_fields: List[Dict],
        document_type: str = "未知",
        row: int = 0,
        col: int = 0
    ) -> List[Dict]:
        """Analyze a single table cell for placeholder replacements.

        Args:
            cell_text: cell text.
            available_fields: available fields (see analyze_document_content).
            document_type: document type.
            row: 0-based row index (reported 1-based in ``position``).
            col: 0-based column index (reported 1-based in ``position``).

        Returns:
            Replacement suggestions with ``position`` overwritten by the cell
            location; [] for empty/short (< 3 chars after strip) or
            already-templated text, without calling the API.
        """
        if not cell_text or len(cell_text.strip()) < 3:
            return []
        # Same rule as analyze_paragraph: don't spend an API call re-analyzing templated text.
        if self._already_templated(cell_text):
            return []

        replacements = self.analyze_document_content(cell_text, available_fields, document_type)

        position = f"表格第{row + 1}行第{col + 1}列"
        for replacement in replacements:
            replacement['position'] = position
        return replacements
|
||
|
||
|
||
def get_available_fields_for_document(doc_config: Dict, field_name_to_code: Dict) -> List[Dict]:
    """Build the list of available fields for a document from its configuration.

    Args:
        doc_config: document configuration; its ``fields`` entry lists field codes.
        field_name_to_code: mapping from field display name to field code.

    Returns:
        One {"field_code", "field_name", "description"} dict per configured code
        that has a (non-empty) name in ``field_name_to_code``, in config order.
        Codes with no matching name are skipped.
    """
    # Invert the name -> code mapping once instead of scanning it per field.
    # setdefault keeps the first name seen for each code, matching a linear
    # search over field_name_to_code in insertion order.
    code_to_name: Dict = {}
    for name, code in field_name_to_code.items():
        code_to_name.setdefault(code, name)

    available_fields = []
    # "or []" also tolerates an explicit "fields": None in the config.
    for field_code in doc_config.get('fields') or []:
        field_name = code_to_name.get(field_code)
        if field_name:
            available_fields.append({
                'field_code': field_code,
                'field_name': field_name,
                'description': f"{field_name}字段"
            })

    return available_fields
|