ai-business-write/template_ai_helper.py

305 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
模板AI辅助工具 - 使用AI智能分析文档内容并识别占位符替换位置
"""
import os
import json
import requests
from typing import Dict, List, Optional, Tuple
from dotenv import load_dotenv
# 加载环境变量
load_dotenv()
class TemplateAIHelper:
"""模板AI辅助类用于智能分析文档内容"""
def __init__(self):
self.api_key = os.getenv('SILICONFLOW_API_KEY')
self.model = os.getenv('SILICONFLOW_MODEL', 'deepseek-ai/DeepSeek-V3.2-Exp')
self.api_url = "https://api.siliconflow.cn/v1/chat/completions"
if not self.api_key:
raise Exception("未配置 SILICONFLOW_API_KEY请在 .env 文件中设置")
def test_api_connection(self) -> bool:
"""
测试API连接是否正常
Returns:
是否连接成功
"""
try:
print(" [测试] 正在测试硅基流动API连接...")
test_payload = {
"model": self.model,
"messages": [
{
"role": "user",
"content": "测试"
}
],
"max_tokens": 10
}
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
response = requests.post(
self.api_url,
json=test_payload,
headers=headers,
timeout=10
)
if response.status_code == 200:
print(" [测试] ✓ API连接正常")
return True
else:
print(f" [测试] ✗ API连接失败: {response.status_code} - {response.text[:200]}")
return False
except Exception as e:
print(f" [测试] ✗ API连接测试失败: {e}")
return False
def analyze_document_content(
self,
document_text: str,
available_fields: List[Dict],
document_type: str = "未知"
) -> List[Dict]:
"""
分析文档内容,识别需要替换为占位符的位置
Args:
document_text: 文档文本内容
available_fields: 可用字段列表,格式: [{"field_code": "xxx", "field_name": "xxx", "description": "xxx"}]
document_type: 文档类型
Returns:
替换建议列表,格式: [
{
"original_text": "原始文本",
"replacement": "{{field_code}}",
"field_code": "field_code",
"field_name": "字段名称",
"confidence": 0.9,
"position": "段落/表格"
}
]
"""
try:
print(f" [AI] 正在分析文本内容(长度: {len(document_text)} 字符)...")
# 构建字段信息字符串
fields_info = "\n".join([
f"- {field['field_name']} ({{{{field_code}}}}): {field.get('description', '')}"
for field in available_fields
])
# 构建提示词
prompt = f"""你是一个专业的文档模板分析助手。请分析以下文档内容,识别所有可以替换为占位符的位置。
文档类型:{document_type}
可用字段列表:
{fields_info}
文档内容:
{document_text}
请仔细分析文档内容,识别以下类型的可替换内容:
1. 明确的字段值(如姓名、单位、职务等)
2. 示例值(如"XXX""张三""某公司"等)
3. 组合字段(如"山西XXXX集团有限公司(职务+姓名)"应替换为对应的占位符组合)
4. 日期、时间等格式化的值
5. 任何看起来是示例或占位符的内容
对于组合字段,如果包含多个字段信息,请使用多个占位符的组合,例如:
- "山西XXXX集团有限公司(职务+姓名)""{{{{target_organization_and_position}}}}({{{{target_name}}}})"
- "张三1980年5月""{{{{target_name}}}}{{{{target_gender}}}}{{{{target_date_of_birth}}}}"
请以JSON格式返回分析结果格式如下
{{
"replacements": [
{{
"original_text": "原始文本内容",
"replacement": "{{{{field_code}}}}或组合占位符",
"field_code": "字段编码(如果是组合,用逗号分隔)",
"field_name": "字段名称",
"confidence": 0.9,
"position": "位置描述第X段、表格第X行第X列",
"reason": "替换原因"
}}
]
}}
只返回JSON不要其他说明文字。"""
# 调用AI API
payload = {
"model": self.model,
"messages": [
{
"role": "system",
"content": "你是一个专业的文档模板分析助手能够准确识别文档中需要替换为占位符的内容。请严格按照JSON格式返回结果不要添加任何解释性文字。"
},
{
"role": "user",
"content": prompt
}
],
"temperature": 0.2,
"max_tokens": 4000
}
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
print(f" [AI] 正在调用API...")
response = requests.post(
self.api_url,
json=payload,
headers=headers,
timeout=60
)
print(f" [AI] API响应状态: {response.status_code}")
if response.status_code != 200:
error_msg = response.text[:500] if len(response.text) > 500 else response.text
raise Exception(f"API调用失败: {response.status_code} - {error_msg}")
result = response.json()
print(f" [AI] API调用成功正在解析响应...")
# 提取AI返回的内容
if 'choices' in result and len(result['choices']) > 0:
content = result['choices'][0]['message']['content']
# 尝试解析JSON
try:
# 如果返回的是代码块提取JSON部分
if '```json' in content:
json_start = content.find('```json') + 7
json_end = content.find('```', json_start)
content = content[json_start:json_end].strip()
elif '```' in content:
json_start = content.find('```') + 3
json_end = content.find('```', json_start)
content = content[json_start:json_end].strip()
parsed_result = json.loads(content)
replacements = parsed_result.get('replacements', [])
print(f" [AI] ✓ 分析完成,识别到 {len(replacements)} 个替换建议")
return replacements
except json.JSONDecodeError as e:
print(f" [AI] ⚠ JSON解析失败: {e}")
print(f" [AI] 原始响应内容: {content[:200]}...")
# 如果JSON解析失败返回空列表
return []
else:
raise Exception("API返回格式异常")
except requests.exceptions.Timeout:
print(f" [AI] ✗ 请求超时60秒跳过AI分析")
return []
except Exception as e:
print(f" [AI] ✗ 分析失败: {e}")
import traceback
traceback.print_exc()
return []
def analyze_paragraph(
self,
paragraph_text: str,
available_fields: List[Dict],
document_type: str = "未知"
) -> List[Dict]:
"""
分析单个段落,识别需要替换的位置
Args:
paragraph_text: 段落文本
available_fields: 可用字段列表
document_type: 文档类型
Returns:
替换建议列表
"""
# 跳过空内容或太短的内容少于10个字符
if not paragraph_text or len(paragraph_text.strip()) < 10:
return []
# 如果文本已经包含占位符,跳过
if '{{' in paragraph_text and '}}' in paragraph_text:
return []
return self.analyze_document_content(paragraph_text, available_fields, document_type)
def analyze_table_cell(
self,
cell_text: str,
available_fields: List[Dict],
document_type: str = "未知",
row: int = 0,
col: int = 0
) -> List[Dict]:
"""
分析表格单元格,识别需要替换的位置
Args:
cell_text: 单元格文本
available_fields: 可用字段列表
document_type: 文档类型
row: 行号
col: 列号
Returns:
替换建议列表
"""
if not cell_text or len(cell_text.strip()) < 3:
return []
replacements = self.analyze_document_content(cell_text, available_fields, document_type)
# 添加位置信息
for replacement in replacements:
replacement['position'] = f"表格第{row+1}行第{col+1}"
return replacements
def get_available_fields_for_document(doc_config: Dict, field_name_to_code: Dict) -> List[Dict]:
"""
根据文档配置获取可用字段列表
Args:
doc_config: 文档配置
field_name_to_code: 字段名称到编码的映射
Returns:
可用字段列表
"""
available_fields = []
for field_code in doc_config.get('fields', []):
# 查找字段名称
field_name = None
for name, code in field_name_to_code.items():
if code == field_code:
field_name = name
break
if field_name:
available_fields.append({
'field_code': field_code,
'field_name': field_name,
'description': f"{field_name}字段"
})
return available_fields