ai-business-write/fix_document_service_tenant_id.py
2025-12-30 10:41:35 +08:00

103 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
修复document_service.py中的tenant_id查询问题
问题get_file_config_by_id方法没有检查tenant_id导致查询可能失败
解决方案在查询中添加tenant_id检查
"""
import re
from pathlib import Path
def fix_document_service():
"""修复document_service.py中的查询逻辑"""
file_path = Path("services/document_service.py")
if not file_path.exists():
print(f"[错误] 文件不存在: {file_path}")
return False
# 读取文件
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
# 查找get_file_config_by_id方法
pattern = r'(def get_file_config_by_id\(self, file_id: int\) -> Optional\[Dict\]:.*?)(\s+sql = """.*?WHERE id = %s\s+AND state = 1\s+""".*?cursor\.execute\(sql, \(file_id,\)\))'
match = re.search(pattern, content, re.DOTALL)
if not match:
print("[错误] 未找到get_file_config_by_id方法或查询语句")
return False
old_code = match.group(0)
# 检查是否已经包含tenant_id
if 'tenant_id' in old_code:
print("[信息] 查询已经包含tenant_id检查无需修复")
return True
# 生成新的代码
new_sql = ''' sql = """
SELECT id, name, file_path
FROM f_polic_file_config
WHERE id = %s
AND tenant_id = %s
AND state = 1
"""
# 获取tenant_id从环境变量或请求中获取
tenant_id = self.tenant_id if self.tenant_id else os.getenv('TENANT_ID', '1')
try:
tenant_id = int(tenant_id)
except (ValueError, TypeError):
tenant_id = 1 # 默认值
cursor.execute(sql, (file_id, tenant_id))'''
# 替换
new_code = re.sub(
r'sql = """.*?WHERE id = %s\s+AND state = 1\s+""".*?cursor\.execute\(sql, \(file_id,\)\)',
new_sql,
old_code,
flags=re.DOTALL
)
new_content = content.replace(old_code, new_code)
# 检查是否需要导入os
if 'import os' not in new_content and 'os.getenv' in new_content:
# 在文件开头添加import os如果还没有
if 'from dotenv import load_dotenv' in new_content:
new_content = new_content.replace('from dotenv import load_dotenv', 'from dotenv import load_dotenv\nimport os')
elif 'import pymysql' in new_content:
new_content = new_content.replace('import pymysql', 'import pymysql\nimport os')
else:
# 在文件开头添加
lines = new_content.split('\n')
import_line = 0
for i, line in enumerate(lines):
if line.startswith('import ') or line.startswith('from '):
import_line = i + 1
lines.insert(import_line, 'import os')
new_content = '\n'.join(lines)
# 写回文件
with open(file_path, 'w', encoding='utf-8') as f:
f.write(new_content)
print("[成功] 已修复get_file_config_by_id方法添加了tenant_id检查")
return True
if __name__ == "__main__":
print("="*70)
print("修复document_service.py中的tenant_id查询问题")
print("="*70)
if fix_document_service():
print("\n修复完成!")
print("\n注意:")
print("1. 请确保.env文件中配置了TENANT_ID")
print("2. 或者确保应用程序在调用时正确传递tenant_id")
print("3. 建议在app.py中从请求中获取tenant_id并传递给document_service")
else:
print("\n修复失败,请手动检查代码")