103 lines
3.6 KiB
Python
103 lines
3.6 KiB
Python
"""
|
||
修复document_service.py中的tenant_id查询问题
|
||
|
||
问题:get_file_config_by_id方法没有检查tenant_id,导致查询可能失败
|
||
解决方案:在查询中添加tenant_id检查
|
||
"""
|
||
import re
|
||
from pathlib import Path
|
||
|
||
def fix_document_service():
|
||
"""修复document_service.py中的查询逻辑"""
|
||
file_path = Path("services/document_service.py")
|
||
|
||
if not file_path.exists():
|
||
print(f"[错误] 文件不存在: {file_path}")
|
||
return False
|
||
|
||
# 读取文件
|
||
with open(file_path, 'r', encoding='utf-8') as f:
|
||
content = f.read()
|
||
|
||
# 查找get_file_config_by_id方法
|
||
pattern = r'(def get_file_config_by_id\(self, file_id: int\) -> Optional\[Dict\]:.*?)(\s+sql = """.*?WHERE id = %s\s+AND state = 1\s+""".*?cursor\.execute\(sql, \(file_id,\)\))'
|
||
|
||
match = re.search(pattern, content, re.DOTALL)
|
||
|
||
if not match:
|
||
print("[错误] 未找到get_file_config_by_id方法或查询语句")
|
||
return False
|
||
|
||
old_code = match.group(0)
|
||
|
||
# 检查是否已经包含tenant_id
|
||
if 'tenant_id' in old_code:
|
||
print("[信息] 查询已经包含tenant_id检查,无需修复")
|
||
return True
|
||
|
||
# 生成新的代码
|
||
new_sql = ''' sql = """
|
||
SELECT id, name, file_path
|
||
FROM f_polic_file_config
|
||
WHERE id = %s
|
||
AND tenant_id = %s
|
||
AND state = 1
|
||
"""
|
||
# 获取tenant_id(从环境变量或请求中获取)
|
||
tenant_id = self.tenant_id if self.tenant_id else os.getenv('TENANT_ID', '1')
|
||
try:
|
||
tenant_id = int(tenant_id)
|
||
except (ValueError, TypeError):
|
||
tenant_id = 1 # 默认值
|
||
|
||
cursor.execute(sql, (file_id, tenant_id))'''
|
||
|
||
# 替换
|
||
new_code = re.sub(
|
||
r'sql = """.*?WHERE id = %s\s+AND state = 1\s+""".*?cursor\.execute\(sql, \(file_id,\)\)',
|
||
new_sql,
|
||
old_code,
|
||
flags=re.DOTALL
|
||
)
|
||
|
||
new_content = content.replace(old_code, new_code)
|
||
|
||
# 检查是否需要导入os
|
||
if 'import os' not in new_content and 'os.getenv' in new_content:
|
||
# 在文件开头添加import os(如果还没有)
|
||
if 'from dotenv import load_dotenv' in new_content:
|
||
new_content = new_content.replace('from dotenv import load_dotenv', 'from dotenv import load_dotenv\nimport os')
|
||
elif 'import pymysql' in new_content:
|
||
new_content = new_content.replace('import pymysql', 'import pymysql\nimport os')
|
||
else:
|
||
# 在文件开头添加
|
||
lines = new_content.split('\n')
|
||
import_line = 0
|
||
for i, line in enumerate(lines):
|
||
if line.startswith('import ') or line.startswith('from '):
|
||
import_line = i + 1
|
||
lines.insert(import_line, 'import os')
|
||
new_content = '\n'.join(lines)
|
||
|
||
# 写回文件
|
||
with open(file_path, 'w', encoding='utf-8') as f:
|
||
f.write(new_content)
|
||
|
||
print("[成功] 已修复get_file_config_by_id方法,添加了tenant_id检查")
|
||
return True
|
||
|
||
|
||
if __name__ == "__main__":
|
||
print("="*70)
|
||
print("修复document_service.py中的tenant_id查询问题")
|
||
print("="*70)
|
||
|
||
if fix_document_service():
|
||
print("\n修复完成!")
|
||
print("\n注意:")
|
||
print("1. 请确保.env文件中配置了TENANT_ID")
|
||
print("2. 或者确保应用程序在调用时正确传递tenant_id")
|
||
print("3. 建议在app.py中从请求中获取tenant_id并传递给document_service")
|
||
else:
|
||
print("\n修复失败,请手动检查代码")
|