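"""Knowledge-graph repository: entity/relationship writes and multi-hop traversal (recursive CTE).

Implements requirement R2: supports look-through analysis of hidden actual
controllers, related-party networks, and "shell" (front) suppliers.
The unified traversal query service (P1.2.5) wraps this module in an external
API, hiding from upper layers whether the underlying store is a set of
relational tables or a graph database.
"""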
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
|
|
from sqlalchemy import text
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.datahub.models import Entity, EntityRelationship
|
|
from app.datahub.ontology import EntityType, RelationshipType, is_valid_relationship
|
|
|
|
|
|
class OntologyViolationError(ValueError):
    """Raised when a relationship does not satisfy the ontology constraints."""
|
|
|
|
|
|
def upsert_entity(
    session: Session,
    entity_type: EntityType,
    business_key: str,
    display_name: str | None = None,
    attributes: dict | None = None,
    data_version_id: uuid.UUID | None = None,
) -> Entity:
    """Idempotently write an entity keyed by (entity type, business key).

    The (type, business key) pair is the normalization anchor used for
    master-data alignment: calling this repeatedly with the same pair
    returns the same entity instead of creating duplicates.

    Args:
        session: SQLAlchemy session used for the lookup and the insert.
        entity_type: Ontology type of the entity; its ``value`` is stored.
        business_key: Business identifier, matched together with the type.
        display_name: Optional display name. On an existing entity it is
            overwritten only when a non-None value is supplied.
        attributes: Optional attribute mapping. On an existing entity a
            non-empty mapping is shallow-merged over the stored attributes,
            with the supplied keys winning.
        data_version_id: Optional data-version identifier. Only applied when
            a new entity is created; an existing entity keeps its value.

    Returns:
        The existing entity (possibly updated in place) or the newly created
        entity, which has been added to the session and flushed.

    Raises:
        sqlalchemy.exc.MultipleResultsFound: If more than one stored entity
            matches the (type, business key) pair (raised by
            ``one_or_none``).
    """
    lookup = session.query(Entity).filter(
        Entity.entity_type == entity_type.value,
        Entity.business_key == business_key,
    )
    found = lookup.one_or_none()

    # Nothing stored yet: create the entity and flush so the INSERT is issued.
    if found is None:
        created = Entity(
            entity_type=entity_type.value,
            business_key=business_key,
            display_name=display_name,
            attributes=attributes or {},
            data_version_id=data_version_id,
        )
        session.add(created)
        session.flush()
        return created

    # Already stored: only touch the fields the caller actually provided.
    if display_name is not None:
        found.display_name = display_name
    if attributes:
        # Assign a fresh dict rather than mutating the stored one in place.
        merged = dict(found.attributes or {})
        merged.update(attributes)
        found.attributes = merged
    return found
|
|
|
|
|
|
def add_relationship(
    session: Session,
    rel_type: RelationshipType,
    source: Entity,
    target: Entity,
    attributes: dict | None = None,
    data_version_id: uuid.UUID | None = None,
) -> EntityRelationship:
    """Create a relationship edge after validating it against the ontology.

    Args:
        session: SQLAlchemy session the new edge is added to.
        rel_type: Ontology relationship type; its ``value`` is stored.
        source: Entity the edge starts from.
        target: Entity the edge points to.
        attributes: Optional attribute mapping; an empty dict is stored when
            omitted or empty.
        data_version_id: Optional data-version identifier stored on the edge.

    Returns:
        The new relationship, added to the session and flushed.

    Raises:
        OntologyViolationError: If ``is_valid_relationship`` rejects this
            relationship type for the given source/target entity types.
        ValueError: If an entity's stored ``entity_type`` is not a valid
            ``EntityType`` value.
    """
    # Convert the stored type strings back to enum members for the check.
    source_kind = EntityType(source.entity_type)
    target_kind = EntityType(target.entity_type)

    allowed = is_valid_relationship(rel_type, source_kind, target_kind)
    if not allowed:
        message = (
            f"关系 {rel_type.value} 不允许从 {source_kind.value} 指向 {target_kind.value}"
        )
        raise OntologyViolationError(message)

    edge = EntityRelationship(
        rel_type=rel_type.value,
        source_id=source.id,
        target_id=target.id,
        attributes=attributes or {},
        data_version_id=data_version_id,
    )
    session.add(edge)
    session.flush()
    return edge
|
|
|
|
|
|
# Multi-hop traversal: walks relationship edges as undirected (an edge is
# followed from either endpoint) and returns every entity connected to the
# start entity within :max_depth hops, together with its minimum hop count.
# Used to identify "suspected same actual controller / related-party network"
# clusters.
#
# How the query works:
#   * The anchor row is the start entity at depth 0 with a one-element path.
#   * Each recursive step joins edges touching the current entity and moves to
#     the opposite endpoint of the edge.
#   * `path` records the entities visited along the current walk; a neighbour
#     already present in it is skipped, which prevents cycles.
#   * The final SELECT drops the start entity and keeps the smallest depth per
#     entity. GROUP BY already yields one row per entity, so no DISTINCT is
#     needed.
#
# NOTE: because cycle detection is per path (UNION ALL, no global visited
# set), the CTE enumerates every simple path up to :max_depth; keep
# :max_depth small on densely connected graphs.
# NOTE(review): ARRAY[...] / = ANY(...) is array syntax as found in
# PostgreSQL — confirm the target database before reusing elsewhere.
_TRAVERSE_SQL = text(
    """
    WITH RECURSIVE reachable(entity_id, depth, path) AS (
        SELECT :start_id, 0, ARRAY[:start_id]
        UNION ALL
        SELECT
            CASE WHEN r.source_id = rc.entity_id THEN r.target_id ELSE r.source_id END,
            rc.depth + 1,
            rc.path || CASE WHEN r.source_id = rc.entity_id THEN r.target_id ELSE r.source_id END
        FROM reachable rc
        JOIN entity_relationship r
          ON (r.source_id = rc.entity_id OR r.target_id = rc.entity_id)
        WHERE rc.depth < :max_depth
          AND NOT (
              CASE WHEN r.source_id = rc.entity_id THEN r.target_id ELSE r.source_id END
              = ANY(rc.path)
          )
    )
    SELECT entity_id, MIN(depth) AS depth
    FROM reachable
    WHERE entity_id <> :start_id
    GROUP BY entity_id
    ORDER BY depth;
    """
)
|
|
|
|
|
|
def find_related_entities(
    session: Session, start_id: uuid.UUID, max_depth: int = 3
) -> list[tuple[uuid.UUID, int]]:
    """Return the entities connected to the start entity within ``max_depth`` hops.

    Executes ``_TRAVERSE_SQL`` with the given start entity and depth limit.

    Args:
        session: SQLAlchemy session used to execute the traversal query.
        start_id: Identifier of the entity the traversal starts from.
        max_depth: Maximum number of hops to follow (default 3).

    Returns:
        A list of ``(entity_id, hop_count)`` tuples built from the first two
        columns of each result row, in the order the query returns them.
    """
    bind_params = {"start_id": start_id, "max_depth": max_depth}
    result = session.execute(_TRAVERSE_SQL, bind_params)
    # Convert result rows into plain (id, depth) tuples for callers.
    return [(row[0], row[1]) for row in result.all()]
|