Files
Train/app/etl/importer.py
T
2026-06-16 00:55:20 +08:00

268 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
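"""ETL import pipeline: 12 CSV sheets -> cleaning -> SQLite + JSON + import report.

Corresponds to task T-1.2. Run from the Train/ directory:

    python3 -m app.etl.importer

or:

    python3 app/etl/importer.py

NOTE(review): this module uses package-relative imports (``from . import ...``),
which normally require the ``-m`` form; confirm that the direct-script form
still works before relying on it.
"""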
import csv
import json
import os
import sqlite3
import sys
from . import field_dict as fd
from . import clean as cl
# Directory layout, all resolved relative to the location of this file.
HERE = os.path.dirname(os.path.abspath(__file__))   # directory containing this module
APP_DIR = os.path.dirname(HERE)                     # parent of HERE
ROOT = os.path.dirname(APP_DIR)                     # Train/

# Input sheets and the schema used to create the database.
CSV_DIR = os.path.join(ROOT, "csv")
SCHEMA_PATH = os.path.join(HERE, "schema.sql")

# Output artifacts all live in one data directory under the app.
OUT_DIR = os.path.join(APP_DIR, "data")
DB_PATH, JSON_PATH, REPORT_PATH = (
    os.path.join(OUT_DIR, filename)
    for filename in ("machines.db", "machines.json", "import_report.md")
)
def find_header_row(rows):
    """Return the index of the header row, or ``None`` if there is none.

    The header is the first row that, after every cell has been passed
    through ``fd.normalize_header``, has at least three non-empty cells and
    shares at least one cell with ``fd.HEADER_TOKENS``.
    """
    for position, raw_cells in enumerate(rows):
        normalized = [fd.normalize_header(cell) for cell in raw_cells]
        filled = sum(1 for cell in normalized if cell)
        if filled < 3:
            # Too sparse to be a header (e.g. a title or spacer row).
            continue
        if not set(normalized).isdisjoint(fd.HEADER_TOKENS):
            return position
    return None
def build_column_map(header_row):
    """Map column indexes to canonical field names.

    Each header cell is resolved with ``fd.map_header``. When that yields
    nothing but the normalized header is non-empty, the column is kept under
    the name ``"raw::" + normalized header`` so its data is not lost. Columns
    whose header normalizes to an empty value are left out of the map.
    """
    mapping = {}
    for position, header in enumerate(header_row):
        canonical = fd.map_header(header)
        normalized = fd.normalize_header(header)
        if canonical:
            mapping[position] = canonical
            continue
        if normalized:
            # Unknown column: keep it, tagged with the raw:: prefix.
            mapping[position] = "raw::" + normalized
    return mapping
def clean_record(row, col_map):
    """Turn one CSV row into a dict keyed by canonical field name.

    Canonical columns are copied as-is (uncleaned). Columns mapped to a
    ``raw::`` name are cleaned with ``cl.clean_cell`` and, when the result is
    truthy, collected under the header name (prefix removed) in the nested
    ``"_raw_extra"`` dict. Cells beyond the end of a short row read as ``""``.
    """
    record = {}
    extras = {}
    width = len(row)
    for position, name in col_map.items():
        cell = row[position] if position < width else ""
        if not name.startswith("raw::"):
            record[name] = cell
            continue
        cleaned = cl.clean_cell(cell)
        if cleaned:
            # Drop the 5-character "raw::" prefix to recover the header text.
            extras[name[5:]] = cleaned
    record["_raw_extra"] = extras
    return record
def to_model_row(rec, category_id, sheet, series_value):
    """Build the insert dictionary for one row of the ``model`` table.

    Args:
        rec: record produced by ``clean_record`` (canonical fields plus
            ``"_raw_extra"``).
        category_id: id of the category row this model belongs to.
        sheet: name of the source sheet, stored as ``source_sheet``.
        series_value: fallback series used when the record's own ``series``
            value is falsy.

    Returns:
        A dict of column name -> value, including ``<field>_value`` /
        ``<field>_unit`` pairs for the numeric fields and a ``raw_json``
        string holding the original cell texts.
    """
    provenance = dict(rec.get("_raw_extra", {}))

    model = {
        "category_id": category_id,
        "series": cl.clean_cell(rec.get("series") or series_value),
    }
    for name in ("model_code", "full_name", "manufacturer"):
        model[name] = cl.clean_cell(rec.get(name))
    for name in ("first_year", "last_year"):
        model[name] = cl.parse_year(rec.get(name))
    model["status"] = cl.normalize_status(rec.get("status"))
    for name in ("usage", "production_count", "axle_arrangement",
                 "drive", "efficiency"):
        model[name] = cl.clean_cell(rec.get(name))
    model["country"] = "中国"
    model["country_type"] = cl.infer_country_type(
        rec.get("manufacturer"),
        rec.get("model_code"),
        rec.get("usage"),
        rec.get("production_count"),
    )
    model["source_sheet"] = sheet

    # Split "number + unit" fields into separate value and unit columns.
    for name, fallback_unit in fd.NUMERIC_UNIT_FIELDS.items():
        if name in fd.FIELD_SYNONYMS:
            source_text = rec.get(name)
            number, unit, _ = cl.parse_value_unit(source_text, fallback_unit)
            model[name + "_value"] = number
            model[name + "_unit"] = unit or fallback_unit
            if source_text is not None:
                provenance[name] = cl.clean_cell(source_text)

    # Keep the original text of every canonical field in raw_json as well.
    for key in rec:
        if key != "_raw_extra":
            cleaned = cl.clean_cell(rec[key])
            if cleaned:
                provenance[key] = cleaned

    model["raw_json"] = json.dumps(provenance, ensure_ascii=False)
    return model
def to_unit_row(rec, category_id, sheet):
    """Build the insert dictionary for one row of the ``unit`` table.

    ``model_name`` uses the record's ``full_name`` and falls back to
    ``model_code`` when that is falsy. ``raw_json`` holds the ``_raw_extra``
    entries plus every canonical field whose cleaned value is truthy.
    """
    provenance = dict(rec.get("_raw_extra", {}))
    for key in rec:
        if key != "_raw_extra":
            cleaned = cl.clean_cell(rec[key])
            if cleaned:
                provenance[key] = cleaned

    def text(name):
        # Cleaned text of one canonical field (the field may be absent).
        return cl.clean_cell(rec.get(name))

    unit = {"category_id": category_id, "car_number": text("car_number")}
    unit["model_name"] = cl.clean_cell(rec.get("full_name") or rec.get("model_code"))
    for name in ("function", "depot", "livery"):
        unit[name] = text(name)
    unit["status"] = cl.normalize_status(rec.get("status"))
    for name in ("location", "note"):
        unit[name] = text(name)
    unit["raw_json"] = json.dumps(provenance, ensure_ascii=False)
    unit["source_sheet"] = sheet
    return unit
def _insert(conn, table, row):
    """Insert ``row`` (a column -> value dict) into ``table`` on ``conn``.

    Values are bound through ``?`` placeholders. The table and column names
    are interpolated into the SQL text, so they must come from trusted code,
    not from external input.
    """
    columns = tuple(row)
    placeholders = ",".join("?" for _ in columns)
    statement = f"INSERT INTO {table} ({','.join(columns)}) VALUES ({placeholders})"
    conn.execute(statement, [row[name] for name in columns])
def import_all(csv_dir=CSV_DIR, db_path=DB_PATH):
    """Rebuild the database, JSON export and Markdown report from the CSV sheets.

    For every sheet listed in ``fd.CATEGORY_CONFIG`` the file
    ``<csv_dir>/<sheet>.csv`` is read, its header row is located, and each
    data row is inserted into either the ``unit`` or the ``model`` table
    depending on the sheet's ``grain``. Rows without a usable key are counted
    as skipped. Models whose first year is later than their last year are
    listed for manual review.

    Args:
        csv_dir: directory holding the ``<sheet>.csv`` files.
        db_path: SQLite file to (re)create. Any existing file is removed.

    Returns:
        The report dict with keys ``sheets``, ``models``, ``units``,
        ``skipped`` and ``review``.

    Side effects:
        Writes ``db_path``, ``JSON_PATH`` and (via ``_write_report``)
        ``REPORT_PATH``.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    # A caller-supplied db_path may live outside OUT_DIR; make sure its
    # directory exists so sqlite3.connect() does not fail on it.
    os.makedirs(os.path.dirname(os.path.abspath(db_path)), exist_ok=True)

    # Read the schema before deleting anything: if it is missing or
    # unreadable we fail here and the previous database stays intact.
    with open(SCHEMA_PATH, encoding="utf-8") as f:
        schema_sql = f.read()

    if os.path.exists(db_path):
        os.remove(db_path)

    report = {"sheets": [], "models": 0, "units": 0, "skipped": 0, "review": []}
    export = {"categories": [], "models": [], "units": []}
    cat_ids = {}

    conn = sqlite3.connect(db_path)
    try:
        conn.executescript(schema_sql)

        for sheet, cfg in fd.CATEGORY_CONFIG.items():
            path = os.path.join(csv_dir, sheet + ".csv")
            entry = {"sheet": sheet, "category": cfg["category"], "grain": cfg["grain"],
                     "rows": 0, "imported": 0, "skipped": 0, "note": ""}

            if not os.path.exists(path):
                entry["note"] = "文件缺失"
                report["sheets"].append(entry)
                continue

            # utf-8-sig transparently drops a leading BOM if one is present.
            with open(path, encoding="utf-8-sig") as fh:
                rows = list(csv.reader(fh))
            if not rows:
                entry["note"] = "空表(无数据)"
                report["sheets"].append(entry)
                continue

            # One category row per distinct (category, subcat) pair.
            key = (cfg["category"], cfg["subcat"])
            if key not in cat_ids:
                cur = conn.execute(
                    "INSERT INTO category(name, subcat, slug) VALUES (?,?,?)",
                    (cfg["category"], cfg["subcat"], None))
                cat_ids[key] = cur.lastrowid
                export["categories"].append(
                    {"id": cur.lastrowid, "name": cfg["category"], "subcat": cfg["subcat"]})
            category_id = cat_ids[key]

            hidx = find_header_row(rows)
            if hidx is None:
                entry["note"] = "未识别表头行"
                report["sheets"].append(entry)
                continue

            col_map = build_column_map(rows[hidx])
            # Per-sheet column overrides, for dirty sheets whose header labels
            # do not match the actual column contents.
            for idx, field in cfg.get("col_override", {}).items():
                col_map[idx] = field

            data_rows = rows[hidx + 1:]
            entry["rows"] = len(data_rows)

            # Forward-fill the series column (merged cells in the source sheet).
            series_col = next((i for i, f in col_map.items() if f == "series"), None)
            if series_col is not None:
                filled = cl.forward_fill([r[series_col] if series_col < len(r) else ""
                                          for r in data_rows])
            else:
                filled = [""] * len(data_rows)

            for r, series_value in zip(data_rows, filled):
                rec = clean_record(r, col_map)

                if cfg["grain"] == "unit":
                    # Unit rows are keyed by car number; skip rows without one.
                    car = cl.clean_cell(rec.get("car_number"))
                    if not car:
                        entry["skipped"] += 1
                        report["skipped"] += 1
                        continue
                    row = to_unit_row(rec, category_id, sheet)
                    _insert(conn, "unit", row)
                    export["units"].append(row)
                    report["units"] += 1
                    entry["imported"] += 1
                else:
                    # Model rows are keyed by the first available of
                    # model_code, tour_name, full_name.
                    code = cl.clean_cell(rec.get("model_code")) or \
                        cl.clean_cell(rec.get("tour_name")) or \
                        cl.clean_cell(rec.get("full_name"))
                    if not code:
                        entry["skipped"] += 1
                        report["skipped"] += 1
                        continue
                    row = to_model_row(rec, category_id, sheet, series_value)
                    if not row["model_code"]:
                        row["model_code"] = code
                    _insert(conn, "model", row)
                    export["models"].append(row)
                    report["models"] += 1
                    entry["imported"] += 1
                    # Year sanity check -> queue for manual review.
                    if (row["first_year"] and row["last_year"]
                            and row["first_year"] > row["last_year"]):
                        report["review"].append(
                            f"{sheet} / {code}: 首产年 {row['first_year']} > 停产年 {row['last_year']}")

            report["sheets"].append(entry)

        conn.commit()
    finally:
        # Always release the connection, including when a sheet fails
        # midway; uncommitted inserts are discarded on close.
        conn.close()

    with open(JSON_PATH, "w", encoding="utf-8") as f:
        json.dump(export, f, ensure_ascii=False, indent=2)
    _write_report(report)
    return report
def _write_report(report):
    """Render ``report`` as Markdown and write it to ``REPORT_PATH``.

    The output has a summary list, a per-sheet table built from
    ``report["sheets"]`` and, only when ``report["review"]`` is non-empty, a
    section listing the entries that need manual review.
    """
    review_items = report["review"]

    out = [
        "# ETL 导入报告\n",
        f"- 车型(Model)**{report['models']}**",
        f"- 个体(Unit)**{report['units']}**",
        f"- 跳过(无主键)**{report['skipped']}**",
        f"- 待人工复核:**{len(report['review'])}**\n",
        "## 分表明细\n",
        "| 分类表 | 分类 | 粒度 | 数据行 | 入库 | 跳过 | 备注 |",
        "|---|---|---|---|---|---|---|",
    ]

    # One table row per processed sheet.
    out.extend(
        f"| {s['sheet']} | {s['category']} | {s['grain']} | "
        f"{s['rows']} | {s['imported']} | {s['skipped']} | {s['note']} |"
        for s in report["sheets"]
    )

    if review_items:
        out.append("\n## 待人工复核\n")
        out.extend(f"- {item}" for item in review_items)

    text = "\n".join(out) + "\n"
    with open(REPORT_PATH, "w", encoding="utf-8") as handle:
        handle.write(text)
if __name__ == "__main__":
    # Script entry point: run the import with the default paths, print a
    # one-line summary and the output locations, then exit with status 0.
    rep = import_all()
    summary = (f"导入完成:Model={rep['models']} Unit={rep['units']} "
               f"跳过={rep['skipped']} 待复核={len(rep['review'])}")
    outputs = f"输出:\n {DB_PATH}\n {JSON_PATH}\n {REPORT_PATH}"
    print(summary)
    print(outputs)
    sys.exit(0)