"""Knowledge-graph repository: entity/relationship writes and multi-hop traversal.

Corresponds to requirement R2: supports look-through analysis of hidden actual
controllers, related-party networks, "shell" suppliers and similar cases.

The unified traversal query service (P1.2.5) wraps the outward-facing API on
top of this module, so upper layers do not need to know whether the backing
store is a set of relational tables or a graph database.
"""

from __future__ import annotations

import uuid

from sqlalchemy import text
from sqlalchemy.orm import Session

from app.datahub.models import Entity, EntityRelationship
from app.datahub.ontology import EntityType, RelationshipType, is_valid_relationship


class OntologyViolationError(ValueError):
    """Raised when a relationship does not satisfy the ontology constraints."""


def upsert_entity(
    session: Session,
    entity_type: EntityType,
    business_key: str,
    display_name: str | None = None,
    attributes: dict | None = None,
    data_version_id: uuid.UUID | None = None,
) -> Entity:
    """Idempotently write an entity keyed by (entity type, business key).

    If a matching entity already exists it is updated in place and returned:
    ``display_name`` is overwritten only when a non-``None`` value is given,
    and a non-empty ``attributes`` dict is merged over the stored attributes
    (new keys win). ``data_version_id`` is NOT touched on an existing entity.

    Otherwise a new entity is added to the session and flushed.

    Args:
        session: Active SQLAlchemy session; the caller owns the transaction.
        entity_type: Ontology type of the entity.
        business_key: Business key that, together with the type, identifies
            the entity.
        display_name: Optional human-readable name.
        attributes: Optional free-form attributes.
        data_version_id: Optional data-version id recorded on creation.

    Returns:
        The existing (possibly updated) or newly created entity.

    Raises:
        sqlalchemy.exc.MultipleResultsFound: If more than one row matches the
            (type, business key) pair (raised by ``one_or_none``).
    """
    existing = (
        session.query(Entity)
        .filter(
            Entity.entity_type == entity_type.value,
            Entity.business_key == business_key,
        )
        .one_or_none()
    )

    if existing is not None:
        if display_name is not None:
            existing.display_name = display_name
        if attributes:
            # Assign a fresh dict rather than mutating in place, so the
            # attribute is rebound and the change is visible to the ORM.
            existing.attributes = {**(existing.attributes or {}), **attributes}
        return existing

    entity = Entity(
        entity_type=entity_type.value,
        business_key=business_key,
        display_name=display_name,
        attributes=attributes or {},
        data_version_id=data_version_id,
    )
    session.add(entity)
    # Flush so the INSERT is emitted now (presumably so the generated id is
    # available to callers such as add_relationship -- confirm against model).
    session.flush()
    return entity


def add_relationship(
    session: Session,
    rel_type: RelationshipType,
    source: Entity,
    target: Entity,
    attributes: dict | None = None,
    data_version_id: uuid.UUID | None = None,
) -> EntityRelationship:
    """Add one relationship edge after validating it against the ontology.

    Args:
        session: Active SQLAlchemy session; the caller owns the transaction.
        rel_type: Ontology relationship type of the edge.
        source: Entity the edge starts from.
        target: Entity the edge points to.
        attributes: Optional free-form attributes stored on the edge.
        data_version_id: Optional data-version id recorded on the edge.

    Returns:
        The newly added and flushed relationship.

    Raises:
        OntologyViolationError: If ``is_valid_relationship`` rejects the
            (relationship type, source type, target type) combination.
    """
    # NOTE(review): EntityType(...) presumably raises ValueError when the
    # stored type string is not a known member -- confirm against ontology.
    src_type = EntityType(source.entity_type)
    tgt_type = EntityType(target.entity_type)

    if not is_valid_relationship(rel_type, src_type, tgt_type):
        raise OntologyViolationError(
            f"关系 {rel_type.value} 不允许从 {src_type.value} 指向 {tgt_type.value}"
        )

    rel = EntityRelationship(
        rel_type=rel_type.value,
        source_id=source.id,
        target_id=target.id,
        attributes=attributes or {},
        data_version_id=data_version_id,
    )
    session.add(rel)
    session.flush()
    return rel


# Multi-hop traversal: walk relationship edges as if they were undirected and
# return every entity connected to the start within :max_depth hops. Used to
# identify "suspected same actual controller / related-party network" cases.
#
# The recursion is bounded by :max_depth, so no per-row path array is needed
# for termination. UNION (not UNION ALL) de-duplicates (entity_id, depth)
# rows, which keeps the working set at no more than nodes x (max_depth + 1)
# rows instead of one row per simple path. The reachable set and the minimum
# depth per entity are the same either way, because a shortest walk never
# revisits a node.
_TRAVERSE_SQL = text(
    """
    WITH RECURSIVE reachable(entity_id, depth) AS (
        SELECT :start_id, 0
        UNION
        SELECT
            CASE
                WHEN r.source_id = rc.entity_id THEN r.target_id
                ELSE r.source_id
            END,
            rc.depth + 1
        FROM reachable rc
        JOIN entity_relationship r
            ON (r.source_id = rc.entity_id OR r.target_id = rc.entity_id)
        WHERE rc.depth < :max_depth
    )
    SELECT entity_id, MIN(depth) AS depth
    FROM reachable
    WHERE entity_id <> :start_id
    GROUP BY entity_id
    ORDER BY depth, entity_id;
    """
)


def find_related_entities(
    session: Session, start_id: uuid.UUID, max_depth: int = 3
) -> list[tuple[uuid.UUID, int]]:
    """Return the entities connected to the start entity within ``max_depth`` hops.

    Edges are followed in both directions. The start entity itself is never
    included in the result.

    Args:
        session: Active SQLAlchemy session used to run the traversal query.
        start_id: Id of the entity to start from.
        max_depth: Maximum number of hops to follow. A value of 0 or less
            yields an empty list, because no edge is ever expanded.

    Returns:
        A list of ``(entity id, shortest hop count)`` tuples, ordered by hop
        count and then by entity id.
    """
    rows = session.execute(
        _TRAVERSE_SQL, {"start_id": start_id, "max_depth": max_depth}
    ).all()
    return [(row[0], row[1]) for row in rows]