# ai-business-write/test_field_relations.py
#
# 307 lines
# 12 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters.
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试字段关联关系 API验证返回结果是否与数据库一致
"""
import pymysql
import os
import json
import requests
import sys
from dotenv import load_dotenv
# Force UTF-8 console output on Windows so the Chinese report text prints
# correctly even when the console code page is not UTF-8.
if sys.platform == 'win32':
    if hasattr(sys.stdout, 'reconfigure'):
        # Python 3.7+: switch the existing stream's encoding in place. This keeps
        # its buffering mode; a brand-new TextIOWrapper (below) defaults to
        # line_buffering=False, so progress output could appear late.
        sys.stdout.reconfigure(encoding='utf-8')
    else:
        import io
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

# Load overrides (DB_HOST, DB_PASSWORD, ...) from a local .env file into the
# process environment before the configuration below reads it.
load_dotenv()
# Database connection settings; every value can be overridden via the
# environment (including the .env file loaded by load_dotenv()).
# NOTE(review): the fallback host/user/password below are committed to source
# control. They should be rotated and the defaults removed so the script fails
# fast when no .env is present, instead of silently using shared credentials.
DB_CONFIG = {
    'host': os.getenv('DB_HOST', '152.136.177.240'),
    'port': int(os.getenv('DB_PORT', 5012)),
    'user': os.getenv('DB_USER', 'finyx'),
    'password': os.getenv('DB_PASSWORD', '6QsGK6MpePZDE57Z'),
    'database': os.getenv('DB_NAME', 'finyx'),
    'charset': 'utf8mb4',
}

# Tenant whose templates/fields are compared (override with TEST_TENANT_ID).
TENANT_ID = int(os.getenv('TEST_TENANT_ID', 615873064429507639))

# Root URL of the service under test (override with TEST_API_BASE_URL).
API_BASE_URL = os.getenv('TEST_API_BASE_URL', 'http://localhost:7500')
def clean_query_result(data):
    """Recursively convert a DB query result into plain printable values.

    - ``bytes`` are decoded as UTF-8; undecodable bytes are dropped.
    - ``dict`` and ``list`` values are cleaned element by element; new
      containers are returned and the input is not modified.
    - ``int``/``float``/``str``/``bool``/``None`` are returned unchanged.
    - Anything else (e.g. ``Decimal``, ``datetime``) is converted with
      ``str()``; if ``str()`` itself fails the object is returned as-is.
    """
    if isinstance(data, bytes):
        # For valid UTF-8, errors='ignore' produces exactly the strict result,
        # so one call covers both the normal and the malformed case.
        return data.decode('utf-8', errors='ignore')
    if isinstance(data, dict):
        return {key: clean_query_result(value) for key, value in data.items()}
    if isinstance(data, list):
        return [clean_query_result(item) for item in data]
    if data is None or isinstance(data, (int, float, str)):  # bool is an int
        return data
    try:
        return str(data)
    except Exception:
        # Best effort: an object with a broken __str__ is passed through, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        return data
def _to_int(value, default=None):
    """Return ``int(value)``, or *default* when the value cannot be converted."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _fetch_fields(cursor, tenant_id, field_type):
    """Fetch a tenant's ``f_polic_field`` rows of one ``field_type``, ordered by name.

    ``state`` is coerced to ``int`` and falls back to 1 when it cannot be
    converted.
    """
    cursor.execute("""
        SELECT id, name, filed_code, field_type, state
        FROM f_polic_field
        WHERE tenant_id = %s AND field_type = %s
        ORDER BY name
    """, (tenant_id, field_type))
    fields = [clean_query_result(row) for row in cursor.fetchall()]
    for field in fields:
        if 'state' in field:
            # NOTE(review): if `state` is a BIT column, pymysql returns bytes
            # such as b'\x00'. clean_query_result turns that into a control
            # character that int() rejects, so such rows would all fall back
            # to 1. Confirm the column type.
            field['state'] = _to_int(field['state'], 1)
    return fields


def get_database_relations(conn, tenant_id):
    """Load templates, fields and template-field links for a tenant from MySQL.

    Args:
        conn: an open pymysql connection; it is not closed here.
        tenant_id: tenant whose rows are read.

    Returns:
        dict with keys:
          'templates'     - f_polic_file_config rows with state = 1
                            (id, name, template_code), ordered by name;
          'input_fields'  - f_polic_field rows with field_type = 1;
          'output_fields' - f_polic_field rows with field_type = 2;
          'relations'     - {file_id: [filed_id, ...]} from f_polic_file_field
                            rows with state = 1 whose file_id exists in
                            f_polic_file_config for the same tenant.
    """
    # The cursor context manager closes the cursor even if a query raises.
    with conn.cursor(pymysql.cursors.DictCursor) as cursor:
        cursor.execute("""
            SELECT id, name, template_code
            FROM f_polic_file_config
            WHERE tenant_id = %s AND state = 1
            ORDER BY name
        """, (tenant_id,))
        templates = [clean_query_result(row) for row in cursor.fetchall()]

        input_fields = _fetch_fields(cursor, tenant_id, 1)
        output_fields = _fetch_fields(cursor, tenant_id, 2)

        # NOTE(review): unlike the template query, this join does not filter
        # on fc.state, so links to disabled templates are included.
        cursor.execute("""
            SELECT fff.file_id, fff.filed_id
            FROM f_polic_file_field fff
            INNER JOIN f_polic_file_config fc
                ON fff.file_id = fc.id AND fff.tenant_id = fc.tenant_id
            WHERE fff.tenant_id = %s AND fff.state = 1
        """, (tenant_id,))
        link_rows = [clean_query_result(row) for row in cursor.fetchall()]

    relation_map = {}
    for row in link_rows:
        file_id = _to_int(row['file_id'])
        filed_id = _to_int(row['filed_id'])
        # Links whose IDs are not integer-like are skipped.
        if file_id is None or filed_id is None:
            continue
        relation_map.setdefault(file_id, []).append(filed_id)

    return {
        'templates': templates,
        'input_fields': input_fields,
        'output_fields': output_fields,
        'relations': relation_map,
    }
def get_api_relations(tenant_id, base_url=None, timeout=10):
    """Fetch template/field relation data for *tenant_id* from the HTTP API.

    Args:
        tenant_id: sent as the ``tenant_id`` query parameter.
        base_url: API root URL; defaults to the module-level API_BASE_URL.
        timeout: request timeout in seconds.

    Returns:
        The response's ``data`` member ({} when the key is absent) when the
        call returns HTTP 200 and a truthy ``isSuccess``; otherwise None, after
        printing the reason.
    """
    url = f'{base_url or API_BASE_URL}/api/template-field-relations'
    try:
        response = requests.get(url, params={'tenant_id': tenant_id}, timeout=timeout)
        if response.status_code != 200:
            print(f"API 请求失败: {response.status_code}")
            return None
        result = response.json()
        if not result.get('isSuccess'):
            print(f"API 返回错误: {result.get('errorMsg')}")
            return None
        return result.get('data', {})
    except Exception as e:
        # Deliberate boundary: network errors, invalid JSON and unexpected
        # payload shapes are all reported and mapped to None (= failure).
        print(f"API 请求异常: {e}")
        return None
def _as_id(value):
    """Normalize an ID to ``int`` so database ints and JSON values compare equal.

    ``json`` always decodes object keys as strings, so the API's ``relations``
    keys arrive as e.g. ``'123'``. IDs inside the payload may also be strings.
    Values that ``int()`` rejects are returned unchanged.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        return value


def _sorted_ids(ids):
    """Sort IDs numerically, with any non-integer IDs after them by text.

    A plain ``sorted`` raises TypeError when ints and strings are mixed.
    """
    return sorted(ids, key=lambda v: (0, v, '') if isinstance(v, int) else (1, 0, str(v)))


def _index_by_id(records):
    """Map normalized ``record['id']`` -> record; None/missing lists give {}."""
    return {_as_id(record['id']): record for record in records or []}


def _normalize_relations(relations):
    """Return ``{template_id: set(field_ids)}`` with every ID normalized."""
    return {
        _as_id(template_id): {_as_id(field_id) for field_id in field_ids or []}
        for template_id, field_ids in (relations or {}).items()
    }


def _compare_id_sets(label, db_items, api_items):
    """Print counts and ID differences for one entity type; True if equal."""
    print(f" 数据库{label}数: {len(db_items)}")
    print(f" API {label}数: {len(api_items)}")
    db_ids, api_ids = set(db_items), set(api_items)
    if db_ids != api_ids:
        print(f" [ERROR] {label}ID不一致")
        print(f" 数据库有但API没有: {db_ids - api_ids}")
        print(f" API有但数据库没有: {api_ids - db_ids}")
        return False
    print(f" [OK] {label}ID一致")
    return True


def _template_name(template_id, *catalogs):
    """Return the first known template name, falling back to 'ID:<id>'."""
    for catalog in catalogs:
        name = catalog.get(template_id, {}).get('name')
        if name is not None:
            return name
    return f'ID:{template_id}'


def _preview(ids, limit=5):
    """Format '(<count>): [first *limit* ids]', appending '...' if truncated."""
    suffix = '...' if len(ids) > limit else ''
    return f"({len(ids)}): {ids[:limit]}{suffix}"


def compare_results(db_data, api_data):
    """Print a comparison of the database data with the API data.

    Args:
        db_data: dict with 'templates', 'input_fields' and 'output_fields'
            (lists of records carrying an 'id') and 'relations' (template id
            -> list of field ids), as built by get_database_relations.
        api_data: the API payload with the same keys; missing or null
            sections are treated as empty.

    All IDs are normalized with ``_as_id`` first, so a JSON key ``'123'``
    matches database id ``123``.

    Returns:
        True when template, field and relation IDs all match, else False.
    """
    print("=" * 80)
    print("对比数据库和 API 返回的数据")
    print("=" * 80)

    db_templates = _index_by_id(db_data['templates'])
    api_templates = _index_by_id(api_data.get('templates'))
    db_input_fields = _index_by_id(db_data['input_fields'])
    api_input_fields = _index_by_id(api_data.get('input_fields'))
    db_output_fields = _index_by_id(db_data['output_fields'])
    api_output_fields = _index_by_id(api_data.get('output_fields'))

    print("\n1. 对比模板:")
    templates_ok = _compare_id_sets("模板", db_templates, api_templates)
    print("\n2. 对比输入字段:")
    inputs_ok = _compare_id_sets("输入字段", db_input_fields, api_input_fields)
    print("\n3. 对比输出字段:")
    outputs_ok = _compare_id_sets("输出字段", db_output_fields, api_output_fields)

    print("\n4. 对比关联关系:")
    db_relations = _normalize_relations(db_data['relations'])
    api_relations = _normalize_relations(api_data.get('relations'))
    print(f" 数据库关联模板数: {len(db_relations)}")
    print(f" API 关联模板数: {len(api_relations)}")

    # Loop-invariant ID sets used to split linked fields into input/output.
    db_input_ids, db_output_ids = set(db_input_fields), set(db_output_fields)
    api_input_ids, api_output_ids = set(api_input_fields), set(api_output_fields)

    all_template_ids = _sorted_ids(set(db_relations) | set(api_relations))
    mismatch_count = 0
    for template_id in all_template_ids:
        db_field_ids = db_relations.get(template_id, set())
        api_field_ids = api_relations.get(template_id, set())
        if db_field_ids == api_field_ids:
            continue
        mismatch_count += 1
        template_name = _template_name(template_id, db_templates, api_templates)
        print(f"\n [ERROR] 模板 '{template_name}' (ID: {template_id}) 关联关系不一致:")
        print(f" 数据库关联字段: {_sorted_ids(db_field_ids)}")
        print(f" API 关联字段: {_sorted_ids(api_field_ids)}")
        print(f" 数据库有但API没有: {_sorted_ids(db_field_ids - api_field_ids)}")
        print(f" API有但数据库没有: {_sorted_ids(api_field_ids - db_field_ids)}")
        # Show how each side classifies the linked fields (input vs output).
        db_input = _sorted_ids(db_field_ids & db_input_ids)
        db_output = _sorted_ids(db_field_ids & db_output_ids)
        api_input = _sorted_ids(api_field_ids & api_input_ids)
        api_output = _sorted_ids(api_field_ids & api_output_ids)
        print(f" 数据库 - 输入字段: {db_input}, 输出字段: {db_output}")
        print(f" API - 输入字段: {api_input}, 输出字段: {api_output}")

    if mismatch_count == 0:
        print(" [OK] 所有模板的关联关系都一致")
    else:
        print(f"\n [ERROR] 共 {mismatch_count} 个模板的关联关系不一致")

    # Detailed view of the first 3 templates (lowest IDs first).
    print("\n5. 详细检查前3个模板的关联关系:")
    for template_id in all_template_ids[:3]:
        template_name = _template_name(template_id, db_templates, api_templates)
        db_field_ids = db_relations.get(template_id, set())
        api_field_ids = api_relations.get(template_id, set())
        print(f"\n 模板: {template_name} (ID: {template_id})")
        print(f" 数据库关联字段数: {len(db_field_ids)}")
        print(f" API 关联字段数: {len(api_field_ids)}")
        if db_field_ids:
            print(f" 数据库关联 - 输入字段{_preview(_sorted_ids(db_field_ids & db_input_ids))}")
            print(f" 数据库关联 - 输出字段{_preview(_sorted_ids(db_field_ids & db_output_ids))}")
        if api_field_ids:
            print(f" API 关联 - 输入字段{_preview(_sorted_ids(api_field_ids & api_input_ids))}")
            print(f" API 关联 - 输出字段{_preview(_sorted_ids(api_field_ids & api_output_ids))}")

    return templates_ok and inputs_ok and outputs_ok and mismatch_count == 0
def main():
    """Compare the template/field relations in the database with the API's.

    Reads the data for TENANT_ID from the database configured by DB_CONFIG,
    fetches the same data from the API at API_BASE_URL, prints summaries for
    both and then the comparison report.

    Returns:
        1 if the API call fails or compare_results returns False, else 0.
    """
    print("开始测试字段关联关系 API...")
    print(f"租户ID: {TENANT_ID}")
    print(f"API地址: {API_BASE_URL}")

    # Load everything from the database first; the connection is always closed.
    print("\n从数据库获取数据...")
    conn = pymysql.connect(**DB_CONFIG)
    try:
        db_data = get_database_relations(conn, TENANT_ID)
    finally:
        conn.close()
    print("[OK] 数据库查询完成")
    print(f" 模板数: {len(db_data['templates'])}")
    print(f" 输入字段数: {len(db_data['input_fields'])}")
    print(f" 输出字段数: {len(db_data['output_fields'])}")
    print(f" 关联关系数: {sum(len(v) for v in db_data['relations'].values())}")

    print("\n从 API 获取数据...")
    api_data = get_api_relations(TENANT_ID)
    # None is get_api_relations' failure signal. An empty payload ({}) is a
    # successful response; it is compared like any other, so the report shows
    # what is missing instead of a misleading "query failed".
    if api_data is None:
        print("[ERROR] API 查询失败")
        return 1
    api_relations = api_data.get('relations') or {}
    print("[OK] API 查询完成")
    print(f" 模板数: {len(api_data.get('templates') or [])}")
    print(f" 输入字段数: {len(api_data.get('input_fields') or [])}")
    print(f" 输出字段数: {len(api_data.get('output_fields') or [])}")
    print(f" 关联关系数: {sum(len(v or []) for v in api_relations.values())}")

    consistent = compare_results(db_data, api_data)

    print("\n" + "=" * 80)
    print("测试完成")
    print("=" * 80)
    # Only an explicit False counts as a mismatch; a None result (no verdict)
    # is treated as success.
    return 1 if consistent is False else 0
if __name__ == '__main__':
    # Use main()'s return value as the process exit status (None -> 0), so CI
    # can detect a failed check.
    sys.exit(main())