732 lines
29 KiB
Python
732 lines
29 KiB
Python
"""PPT Master 管线封装 — 核心任务处理器"""
|
||
|
||
import subprocess
|
||
import sys
|
||
import json
|
||
import shutil
|
||
import re
|
||
from pathlib import Path
|
||
from typing import Callable
|
||
|
||
from config import config
|
||
from llm_client import llm_client
|
||
from wanx_client import wanx_client
|
||
from db import update_task_status
|
||
|
||
|
||
class PPTPipeline:
|
||
"""封装 PPT Master 完整管线"""
|
||
|
||
def __init__(self, task_id: str, task: dict, redis_callback=None):
|
||
self.task_id = task_id
|
||
self.task = task
|
||
self.task_config = task.get("config", {}) if isinstance(task.get("config"), dict) else json.loads(task.get("config", "{}"))
|
||
self.project_path: Path | None = None
|
||
self.project_name: str = ""
|
||
self._redis_callback = redis_callback
|
||
|
||
def run(self):
    """Run every pipeline stage in order and record the outcome.

    On success the task is marked ``completed`` with the exported file path
    and the page count. Any exception marks the task ``failed`` (the
    exception text becomes the error message) and is re-raised.
    """
    try:
        self._update("processing", 5, "开始处理任务...")

        # Stage 1 produces Markdown; stage 2 creates the project directory.
        markdown = self._step1_process_source()
        self._step2_init_project()

        # Stage 3 (templates) is skipped on purpose: free design is used.

        # Stage 4: strategist (design spec + spec lock).
        self._step4_strategist(markdown)

        # Stage 5: images, enabled unless the task config turns it off.
        wants_images = self.task_config.get("with_images", True)
        if wants_images:
            self._step5_images()

        # Stage 6: per-page SVG generation; stage 7: post-process + export.
        self._step6_executor()
        exported = self._step7_export()

        self._update(
            "completed", 100, "PPT 生成完成",
            output_file=str(exported),
            page_count=self._count_pages(),
        )
    except Exception as exc:
        self._update("failed", progress=0, status_message="生成失败", error_message=str(exc))
        raise
|
||
|
||
def _update(self, status: str, progress: int = None, status_message: str = None, **kwargs):
    """Persist a status change and mirror it to Redis when possible.

    Args:
        status: New task status string.
        progress: Progress value; when ``None`` the Redis mirror is skipped.
        status_message: Human-readable message for the status.
        **kwargs: Extra fields forwarded unchanged to ``update_task_status``.
    """
    update_task_status(
        self.task_id,
        status=status,
        progress=progress,
        status_message=status_message,
        **kwargs,
    )

    # The Redis mirror exists for fast front-end polling; it is only
    # notified when a callback was supplied and a progress value is known.
    notify = self._redis_callback
    if not notify or progress is None:
        return
    notify(self.task_id, status, progress, status_message or "")
|
||
|
||
def _run_script(self, script_name: str, *args, cwd: Path = None) -> str:
    """Run a PPT Master Python script in a subprocess and return its stdout.

    Args:
        script_name: Script path relative to ``config.SCRIPTS_DIR``.
        *args: Command-line arguments; each is converted with ``str()``.
        cwd: Working directory for the script; defaults to
            ``config.PPT_MASTER_PATH``.

    Returns:
        The captured standard output of the script.

    Raises:
        RuntimeError: If the script exits with a non-zero status or does not
            finish within the timeout. Timeouts are reported as
            ``RuntimeError`` so that callers guarding optional steps with
            ``except RuntimeError`` also tolerate a hung script.
    """
    timeout_seconds = 600
    script_path = config.SCRIPTS_DIR / script_name
    cmd = [sys.executable, str(script_path)] + [str(a) for a in args]

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            cwd=str(cwd or config.PPT_MASTER_PATH),
            env=self._build_env(),
            timeout=timeout_seconds,
        )
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError(
            f"脚本 {script_name} 执行超时({timeout_seconds} 秒)"
        ) from exc

    if result.returncode != 0:
        # Only the tail of each stream is kept so the error stays readable.
        raise RuntimeError(
            f"脚本 {script_name} 执行失败:\n"
            f"stdout: {result.stdout[-2000:]}\n"
            f"stderr: {result.stderr[-2000:]}"
        )
    return result.stdout
|
||
|
||
def _build_env(self) -> dict:
    """Return the environment passed to PPT Master scripts.

    The current process environment is copied and the LLM / image settings
    from ``config`` are layered on top. ``IMAGE_API_KEY`` is only set when
    the configured value is truthy.
    """
    import os

    env = dict(os.environ)
    env.update(
        OPENAI_API_KEY=config.OPENAI_API_KEY,
        OPENAI_BASE_URL=config.OPENAI_BASE_URL,
        OPENAI_MODEL=config.OPENAI_MODEL,
    )
    if config.IMAGE_API_KEY:
        env["IMAGE_API_KEY"] = config.IMAGE_API_KEY
    env["IMAGE_BACKEND"] = config.IMAGE_BACKEND
    return env
|
||
|
||
# ==================== Step 1: 源内容处理 ====================
|
||
|
||
def _step1_process_source(self) -> str:
|
||
"""处理源文件/内容,返回 Markdown 文本"""
|
||
self._update("converting", 10, "正在转换源文件...")
|
||
|
||
source_type = self.task.get("source_type", "text")
|
||
|
||
if source_type == "text":
|
||
return self.task.get("source_content", "")
|
||
|
||
elif source_type == "url":
|
||
url = self.task.get("source_content", "")
|
||
output = self._run_script("source_to_md/web_to_md.py", url)
|
||
# web_to_md.py 输出转换后的 md 文件路径
|
||
md_path = self._extract_output_path(output)
|
||
if md_path and Path(md_path).exists():
|
||
return Path(md_path).read_text(encoding="utf-8")
|
||
return output
|
||
|
||
elif source_type == "file":
|
||
file_path = self.task.get("source_file", "")
|
||
if not file_path or not Path(file_path).exists():
|
||
raise FileNotFoundError(f"源文件不存在: {file_path}")
|
||
|
||
suffix = Path(file_path).suffix.lower()
|
||
if suffix == ".pdf":
|
||
output = self._run_script("source_to_md/pdf_to_md.py", file_path)
|
||
elif suffix in (".docx", ".doc", ".html", ".epub"):
|
||
output = self._run_script("source_to_md/doc_to_md.py", file_path)
|
||
elif suffix in (".xlsx", ".xlsm"):
|
||
output = self._run_script("source_to_md/excel_to_md.py", file_path)
|
||
elif suffix in (".pptx",):
|
||
output = self._run_script("source_to_md/ppt_to_md.py", file_path)
|
||
elif suffix == ".md":
|
||
return Path(file_path).read_text(encoding="utf-8")
|
||
else:
|
||
output = self._run_script("source_to_md/doc_to_md.py", file_path)
|
||
|
||
md_path = self._extract_output_path(output)
|
||
if md_path and Path(md_path).exists():
|
||
return Path(md_path).read_text(encoding="utf-8")
|
||
return output
|
||
|
||
return ""
|
||
|
||
def _extract_output_path(self, script_output: str) -> str | None:
|
||
"""从脚本输出中提取生成的文件路径"""
|
||
for line in script_output.strip().split("\n"):
|
||
line = line.strip()
|
||
if line.endswith(".md") and Path(line).exists():
|
||
return line
|
||
# 匹配 "Output: /path/to/file.md" 格式
|
||
match = re.search(r"(?:Output|Saved|Written):\s*(.+\.md)", line, re.IGNORECASE)
|
||
if match:
|
||
path = match.group(1).strip()
|
||
if Path(path).exists():
|
||
return path
|
||
return None
|
||
|
||
def _resolve_project_path(self, script_output: str) -> Path:
    """Locate the project directory created by ``project_manager.py``.

    Three strategies are tried in order:

    1. A path reported in the script output after ``Project created:`` or
       ``[OK] Project initialized:``, taken relative to
       ``config.PROJECTS_DIR``.
    2. A directory named exactly ``self.project_name`` under
       ``PROJECTS_DIR/projects`` or ``PROJECTS_DIR``.
    3. The most recently modified entry whose name starts with
       ``self.project_name`` in those same directories.

    Raises:
        RuntimeError: If none of the strategies finds a directory.
    """
    # Strategy 1: path announced by the script itself.
    announcements = ("Project created:", "[OK] Project initialized:")
    for line in script_output.strip().split("\n"):
        for marker in announcements:
            if marker not in line:
                continue
            reported = config.PROJECTS_DIR / line.split(marker, 1)[1].strip()
            if reported.exists():
                return reported

    # Strategy 2: exact directory name.
    for root in (config.PROJECTS_DIR / "projects", config.PROJECTS_DIR):
        exact = root / self.project_name
        if exact.exists():
            return exact

    # Strategy 3: prefix match, newest first (a suffix may have been
    # appended to the requested name).
    for root in (config.PROJECTS_DIR / "projects", config.PROJECTS_DIR):
        if not root.exists():
            continue
        newest_first = sorted(
            root.glob(f"{self.project_name}*"),
            key=lambda entry: entry.stat().st_mtime,
            reverse=True,
        )
        if newest_first:
            return newest_first[0]

    raise RuntimeError(f"项目初始化失败,找不到项目目录: {script_output}")
|
||
|
||
# ==================== Step 2: 项目初始化 ====================
|
||
|
||
def _step2_init_project(self):
    """Create the PPT Master project directory for this task.

    Derives a filesystem-safe project name from the task title and task id,
    runs ``project_manager.py init`` and stores the resolved directory in
    ``self.project_path`` (also persisted on the task record).

    Raises:
        RuntimeError: If the init script fails or the created directory
            cannot be located.
    """
    self._update("processing", 15, "初始化项目...")

    # A missing, None or empty title falls back to "untitled" so the regex
    # below always receives a non-empty string.
    title = self.task.get("title") or "untitled"
    # Replace every character that is not a word character, CJK ideograph
    # or hyphen, and cap the length of the title part.
    safe_name = re.sub(r'[^\w\u4e00-\u9fff-]', '_', str(title))[:50]
    self.project_name = f"govai_{safe_name}_{self.task_id[:8]}"

    fmt = self.task_config.get("format", "ppt169")

    output = self._run_script(
        "project_manager.py", "init", self.project_name,
        "--format", fmt,
        cwd=config.PROJECTS_DIR,
    )

    # The script may adjust the directory name, so resolve the real path.
    self.project_path = self._resolve_project_path(output)

    update_task_status(self.task_id, status="processing", project_path=str(self.project_path))
|
||
|
||
# ==================== Step 4: 策略师阶段 ====================
|
||
|
||
def _step4_strategist(self, source_md: str):
    """Strategist stage: produce ``design_spec.md`` and ``spec_lock.md``.

    Sends the reference documents and the source content to the LLM, splits
    the answer on the ``===..._START===`` / ``===..._END===`` markers and
    writes both files into the project directory. If the design spec cannot
    be separated the whole response is used; if the spec lock is missing it
    is generated with a second LLM call.

    Args:
        source_md: Markdown source content of the presentation.
    """
    self._update("designing", 20, "AI 正在分析内容并设计方案...")

    def read_reference(*relative_parts: str) -> str:
        return config.SKILL_DIR.joinpath(*relative_parts).read_text(encoding="utf-8")

    # Reference documents that steer the strategist.
    strategist_ref = read_reference("references", "strategist.md")
    design_spec_ref = read_reference("templates", "design_spec_reference.md")
    spec_lock_ref = read_reference("templates", "spec_lock_reference.md")

    settings = self.task_config
    page_count = settings.get("page_count", 10)
    style = settings.get("style", "general")
    fmt = settings.get("format", "ppt169")
    language = settings.get("language", "zh")

    # Each reference is truncated to keep the prompt size bounded.
    system_prompt = (
        "你是 PPT Master 的策略师角色。你的任务是根据用户提供的源内容,"
        "生成完整的 design_spec.md(设计规范)和 spec_lock.md(执行锁定)。\n\n"
        "你必须严格遵循以下参考模板的结构来生成输出:\n\n"
        "## 策略师角色定义(摘要)\n"
        f"{strategist_ref[:8000]}\n\n"
        "## design_spec.md 参考模板\n"
        f"{design_spec_ref[:10000]}\n\n"
        "## spec_lock.md 参考模板\n"
        f"{spec_lock_ref[:8000]}\n\n"
        "请直接输出两个文件的完整内容,用以下分隔符分开:\n"
        "===DESIGN_SPEC_START===\n[design_spec.md 内容]\n===DESIGN_SPEC_END===\n"
        "===SPEC_LOCK_START===\n[spec_lock.md 内容]\n===SPEC_LOCK_END===\n"
    )
    user_prompt = (
        "请根据以下源内容生成演示文稿的设计规范。\n\n"
        "## 生成要求\n"
        f"- 画布格式: {fmt}\n"
        f"- 目标页数: {page_count} 页\n"
        f"- 设计风格: {style}\n"
        f"- 语言: {language}\n\n"
        f"## 源内容\n\n{source_md[:50000]}"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    response = llm_client.chat(messages, temperature=0.7, max_tokens=16384)

    # Split the answer into its two documents.
    design_spec = self._extract_section(response, "DESIGN_SPEC_START", "DESIGN_SPEC_END")
    spec_lock = self._extract_section(response, "SPEC_LOCK_START", "SPEC_LOCK_END")

    # Fallbacks: the whole response as design spec, a dedicated LLM call
    # for the spec lock.
    design_spec = design_spec or response
    spec_lock = spec_lock or self._generate_spec_lock(design_spec, source_md)

    project_dir = self.project_path
    (project_dir / "design_spec.md").write_text(design_spec, encoding="utf-8")
    (project_dir / "spec_lock.md").write_text(spec_lock, encoding="utf-8")

    self._update("designing", 35, "设计规范生成完成")
|
||
|
||
def _generate_spec_lock(self, design_spec: str, source_md: str) -> str:
    """Generate ``spec_lock.md`` on its own from an existing design spec.

    Args:
        design_spec: Content of ``design_spec.md``; the first 15000
            characters are sent to the LLM.
        source_md: Source content. Accepted for call compatibility; it is
            not used by this method.

    Returns:
        The LLM response, expected to be the spec lock content.
    """
    template_path = config.SKILL_DIR / "templates" / "spec_lock_reference.md"
    spec_lock_ref = template_path.read_text(encoding="utf-8")

    instructions = (
        "根据已生成的 design_spec.md,生成对应的 spec_lock.md(执行锁定文件)。\n"
        "严格遵循以下参考模板结构:\n\n"
        f"{spec_lock_ref[:8000]}\n\n"
        "直接输出 spec_lock.md 的完整内容,不要加任何额外说明。"
    )
    request = f"## design_spec.md\n\n{design_spec[:15000]}"

    messages = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": request},
    ]
    return llm_client.chat(messages, temperature=0.5, max_tokens=8192)
|
||
|
||
def _extract_section(self, text: str, start_marker: str, end_marker: str) -> str:
|
||
"""从文本中提取指定标记之间的内容"""
|
||
start = text.find(f"==={start_marker}===")
|
||
end = text.find(f"==={end_marker}===")
|
||
if start == -1 or end == -1:
|
||
return ""
|
||
start += len(f"==={start_marker}===")
|
||
return text[start:end].strip()
|
||
|
||
# ==================== Step 5: 图片获取 ====================
|
||
|
||
def _step5_images(self):
    """Image stage: ask the LLM for image prompts and generate the images.

    Nothing is generated when ``design_spec.md`` is missing or contains
    neither ``Acquire Via: ai`` nor ``Acquire Via: web``. At most 10 images
    are generated; a failure for one image is reported and skipped so the
    pipeline continues.
    """
    self._update("generating_images", 40, "正在生成/搜索图片...")

    design_spec_path = self.project_path / "design_spec.md"
    if not design_spec_path.exists():
        return

    design_spec = design_spec_path.read_text(encoding="utf-8")

    # Skip the stage when the design spec requests no acquired images.
    if "Acquire Via: ai" not in design_spec and "Acquire Via: web" not in design_spec:
        self._update("generating_images", 45, "无需额外图片")
        return

    # Let the LLM turn the image requirements into "filename|prompt" lines.
    messages = [
        {
            "role": "system",
            "content": (
                "根据设计规范中的图片需求,生成图片生成 prompt。\n"
                "每行一个,格式为: filename|prompt\n"
                "例如: cover_bg.png|A modern city skyline at sunset, soft gradient sky\n"
                "只输出需要生成的图片,不要解释。"
            ),
        },
        {
            "role": "user",
            "content": design_spec[:10000],
        },
    ]

    response = llm_client.chat(messages, temperature=0.5, max_tokens=4096)

    images_dir = self.project_path / "images"
    images_dir.mkdir(exist_ok=True)

    # Image size follows the canvas format.
    fmt = self.task_config.get("format", "ppt169")
    wanx_size_map = {
        "ppt169": "1024*576",
        "ppt43": "1024*768",
        "xhs": "720*1280",
        "story": "720*1280",
    }
    size = wanx_size_map.get(fmt, "1024*576")

    # Cap the work list first so the progress figures describe what is
    # actually processed.
    max_images = 10
    requests = [l.strip() for l in response.strip().split("\n") if "|" in l][:max_images]
    total = len(requests)

    for index, line in enumerate(requests, start=1):
        raw_name, _, prompt = line.partition("|")
        # The file name comes from model output: keep only the final path
        # component so nothing is written outside the images directory.
        filename = Path(raw_name.strip().replace("\\", "/")).name
        prompt = prompt.strip()

        if filename not in ("", ".", "..") and prompt:
            try:
                self._generate_image(prompt, images_dir, filename, size)
            except Exception as exc:
                # Best effort: a failed image must not stop the pipeline.
                print(f"[Pipeline] 图片生成失败 {filename}: {exc}")

        progress = 40 + int(index / total * 10)
        self._update("generating_images", progress, f"图片生成中 ({index}/{total})")
|
||
|
||
def _generate_image(self, prompt: str, output_dir: Path, filename: str, size: str):
    """Generate one image with the backend selected by ``config.IMAGE_BACKEND``.

    With the ``"wanx"`` backend the image is produced by ``wanx_client``;
    any other value runs PPT Master's ``image_gen.py`` script. The choice is
    made up front: there is no automatic retry with the other backend.

    Args:
        prompt: Text prompt describing the image.
        output_dir: Directory the image is written to.
        filename: Target file name.
        size: Size string passed to ``wanx_client``.
    """
    if config.IMAGE_BACKEND != "wanx":
        # NOTE(review): the script path always requests 16:9 / 1K and does
        # not use ``size`` — confirm this is intended for non-16:9 formats.
        self._run_script(
            "image_gen.py", prompt,
            "--aspect_ratio", "16:9",
            "--image_size", "1K",
            "-o", str(output_dir),
            "--filename", filename,
        )
        return

    wanx_client.generate(
        prompt=prompt,
        output_dir=output_dir,
        filename=filename,
        size=size,
    )
|
||
|
||
# ==================== Step 6: 执行器阶段 ====================
|
||
|
||
def _step6_executor(self):
    """Executor stage: generate one SVG per page, then the speaker notes.

    The page count comes from a ``page_count: N`` entry in ``spec_lock.md``
    when present, otherwise from the task config (default 10). A failing
    SVG quality check does not stop the pipeline.
    """
    self._update("generating_svg", 50, "AI 正在逐页生成幻灯片...")

    def read_if_present(path: Path) -> str:
        return path.read_text(encoding="utf-8") if path.exists() else ""

    spec_lock = read_if_present(self.project_path / "spec_lock.md")
    design_spec = read_if_present(self.project_path / "design_spec.md")

    # Executor reference documents; the style document depends on the
    # configured style and falls back to the general one.
    references_dir = config.SKILL_DIR / "references"
    style = self.task_config.get("style", "general")
    executor_base = (references_dir / "executor-base.md").read_text(encoding="utf-8")
    shared_standards = (references_dir / "shared-standards.md").read_text(encoding="utf-8")

    style_documents = {
        "general": "executor-general.md",
        "consultant": "executor-consultant.md",
        "consultant-top": "executor-consultant-top.md",
    }
    style_document = style_documents.get(style, "executor-general.md")
    executor_style = (references_dir / style_document).read_text(encoding="utf-8")

    # Prefer the page count planned in the spec lock over the config value.
    planned = re.search(r"page_count:\s*(\d+)", spec_lock)
    if planned:
        page_count = int(planned.group(1))
    else:
        page_count = self.task_config.get("page_count", 10)

    svg_output_dir = self.project_path / "svg_output"
    svg_output_dir.mkdir(exist_ok=True)

    # One LLM call per page; progress moves from 50 to 80 across pages.
    for page_num in range(1, page_count + 1):
        self._generate_svg_page(
            page_num, page_count,
            spec_lock, design_spec,
            executor_base, shared_standards, executor_style,
            svg_output_dir,
        )
        progress = 50 + int(page_num / page_count * 30)
        self._update("generating_svg", progress, f"生成第 {page_num}/{page_count} 页")

    self._generate_speaker_notes(design_spec, page_count)

    # Quality check is advisory only.
    try:
        self._run_script("svg_quality_checker.py", str(self.project_path))
    except RuntimeError:
        pass
|
||
|
||
def _generate_svg_page(
    self, page_num: int, total_pages: int,
    spec_lock: str, design_spec: str,
    executor_base: str, shared_standards: str, executor_style: str,
    svg_output_dir: Path,
):
    """Ask the LLM for one page's SVG and write it to ``page_NN.svg``.

    Args:
        page_num: 1-based page number.
        total_pages: Total number of pages in the deck.
        spec_lock: Content of ``spec_lock.md``.
        design_spec: Content of ``design_spec.md``.
        executor_base: Executor reference text.
        shared_standards: Shared technical standards text.
        executor_style: Style-specific executor reference text.
        svg_output_dir: Directory the SVG file is written to.
    """
    # Canvas size per format; unknown formats use 1280x720.
    canvas_sizes = {
        "ppt169": (1280, 720),
        "ppt43": (1024, 768),
        "xhs": (1080, 1440),
        "story": (1080, 1920),
    }
    fmt = self.task_config.get("format", "ppt169")
    width, height = canvas_sizes.get(fmt, (1280, 720))

    system_prompt = (
        f"你是 PPT Master 的执行器角色。现在需要生成第 {page_num}/{total_pages} 页的 SVG 代码。\n\n"
        "## 关键技术约束(摘要)\n"
        f"- 画布: {width}x{height}\n"
        f'- viewBox="0 0 {width} {height}"\n'
        "- 所有文本必须用 <text> 元素,禁止 <foreignObject>\n"
        '- 长文本必须手动分行(多个 <text> 或 <tspan dy="...")\n'
        '- 图片使用 <image href="../images/filename.ext">\n'
        '- 每个顶层元素用 <g id="elem_N"> 包裹\n'
        "- 颜色、字体必须严格来自 spec_lock\n\n"
        f"## 执行器规范(摘要)\n{executor_base[:6000]}\n\n"
        f"## 风格规范(摘要)\n{executor_style[:4000]}\n\n"
        f"## 技术标准(摘要)\n{shared_standards[:6000]}\n\n"
        "直接输出完整的 SVG 代码,不要用 ```svg 包裹,不要加解释。"
    )
    user_prompt = (
        f"## spec_lock.md\n\n{spec_lock[:12000]}\n\n"
        f"## design_spec.md(第 {page_num} 页相关内容)\n\n{design_spec[:8000]}\n\n"
        f"请生成第 {page_num} 页的完整 SVG。"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    svg = llm_client.chat(messages, temperature=0.6, max_tokens=16384).strip()

    # Drop a Markdown code fence if the model added one anyway: the opening
    # fence line first, then a trailing fence.
    fence = "```"
    if svg.startswith(fence):
        _opening_line, _, svg = svg.partition("\n")
    if svg.endswith(fence):
        svg = svg[:-len(fence)].strip()

    # Discard anything that precedes the <svg element.
    svg_at = svg.find("<svg")
    if svg_at > 0:
        svg = svg[svg_at:]

    target = svg_output_dir / f"page_{page_num:02d}.svg"
    target.write_text(svg, encoding="utf-8")
|
||
|
||
def _generate_speaker_notes(self, design_spec: str, page_count: int):
    """Generate speaker notes for all pages into ``notes/total.md``.

    Args:
        design_spec: Content of ``design_spec.md``; the first 15000
            characters are sent to the LLM.
        page_count: Number of pages the notes should cover.
    """
    notes_dir = self.project_path / "notes"
    notes_dir.mkdir(exist_ok=True)

    instructions = (
        "根据设计规范为每一页生成演讲者备注。\n"
        "格式为每页用 ## Page N 分隔,内容为该页的讲解要点,100-200字。\n"
        "直接输出,不加额外说明。"
    )
    request = f"共 {page_count} 页。设计规范:\n\n{design_spec[:15000]}"

    messages = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": request},
    ]

    notes = llm_client.chat(messages, temperature=0.7, max_tokens=8192)
    (notes_dir / "total.md").write_text(notes, encoding="utf-8")
|
||
|
||
# ==================== Step 7: 后处理与导出 ====================
|
||
|
||
def _step7_export(self) -> Path:
    """Post-process the generated SVGs and export the PPTX file.

    Steps: split the speaker notes (optional), strip ``<image>`` elements
    that reference missing files, run ``finalize_svg.py``, validate/repair
    the SVG files, run ``svg_to_pptx.py`` and copy the newest resulting
    ``.pptx`` to ``config.OUTPUT_DIR/<task_id>.pptx``.

    Returns:
        Path of the copied PPTX file in the output directory.

    Raises:
        RuntimeError: If a mandatory script fails or no ``.pptx`` file can
            be found after the export.
    """
    self._update("exporting", 85, "正在处理并导出 PPTX...")

    # Step 7.1: split speaker notes; failure here is tolerated.
    try:
        self._run_script("total_md_split.py", str(self.project_path))
    except RuntimeError:
        pass

    self._update("exporting", 88, "清理无效图片引用...")
    self._strip_missing_images()

    self._update("exporting", 90, "SVG 后处理中...")

    # Step 7.2: SVG post-processing.
    self._run_script("finalize_svg.py", str(self.project_path))

    # Step 7.2.5: the LLM may have produced malformed XML; repair or drop it.
    self._validate_and_repair_svgs()

    self._update("exporting", 95, "导出 PPTX 文件...")

    # Step 7.3: export.
    self._run_script("svg_to_pptx.py", str(self.project_path))

    # Locate the exported file: project exports dir first, then the shared
    # exports dir, then anywhere under the project.
    exports_dir = self.project_path / "exports"
    if not exports_dir.exists():
        exports_dir = config.PROJECTS_DIR / "exports"

    pptx_files = list(exports_dir.glob("*.pptx")) if exports_dir.exists() else []

    if not pptx_files:
        pptx_files = list(self.project_path.rglob("*.pptx"))

    if not pptx_files:
        raise RuntimeError("PPTX 导出失败:找不到输出文件")

    # Newest file wins.
    output_pptx = max(pptx_files, key=lambda f: f.stat().st_mtime)

    # Copy to the unified output directory, creating it on first use so the
    # finished export is not lost to a missing directory.
    final_path = config.OUTPUT_DIR / f"{self.task_id}.pptx"
    final_path.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(output_pptx, final_path)

    return final_path
|
||
|
||
def _strip_missing_images(self):
|
||
"""移除 SVG 中引用不存在图片文件的 <image> 元素"""
|
||
svg_dir = self.project_path / "svg_output"
|
||
if not svg_dir.exists():
|
||
return
|
||
|
||
import xml.etree.ElementTree as ET
|
||
|
||
ns = {"svg": "http://www.w3.org/2000/svg", "xlink": "http://www.w3.org/1999/xlink"}
|
||
ET.register_namespace("", "http://www.w3.org/2000/svg")
|
||
ET.register_namespace("xlink", "http://www.w3.org/1999/xlink")
|
||
|
||
for svg_file in sorted(svg_dir.glob("*.svg")):
|
||
try:
|
||
tree = ET.parse(svg_file)
|
||
root = tree.getroot()
|
||
changed = False
|
||
|
||
# 递归查找所有 <image> 元素
|
||
for parent in root.iter():
|
||
to_remove = []
|
||
for child in parent:
|
||
tag = child.tag.split("}")[-1] if "}" in child.tag else child.tag
|
||
if tag != "image":
|
||
continue
|
||
|
||
href = (child.get("href") or
|
||
child.get("{http://www.w3.org/1999/xlink}href") or "")
|
||
|
||
# 跳过 data URI 和绝对 URL
|
||
if href.startswith("data:") or href.startswith("http"):
|
||
continue
|
||
|
||
# 检查相对路径是否存在
|
||
img_path = (svg_dir / href).resolve()
|
||
if not img_path.exists():
|
||
to_remove.append(child)
|
||
|
||
for elem in to_remove:
|
||
parent.remove(elem)
|
||
changed = True
|
||
|
||
if changed:
|
||
tree.write(svg_file, encoding="unicode", xml_declaration=True)
|
||
except Exception:
|
||
pass # 解析失败跳过
|
||
|
||
def _validate_and_repair_svgs(self):
|
||
"""验证所有 SVG 文件,修复或移除无效文件"""
|
||
import xml.etree.ElementTree as ET
|
||
|
||
svg_dir = self.project_path / "svg_output"
|
||
if not svg_dir.exists():
|
||
return
|
||
|
||
for svg_file in sorted(svg_dir.glob("*.svg")):
|
||
try:
|
||
ET.parse(svg_file)
|
||
except ET.ParseError:
|
||
# 尝试修复常见问题
|
||
repaired = self._try_repair_svg(svg_file)
|
||
if not repaired:
|
||
# 无法修复,移除该文件让导出跳过这页
|
||
svg_file.unlink()
|
||
print(f"[Pipeline] 移除无效 SVG: {svg_file.name}")
|
||
|
||
def _try_repair_svg(self, svg_file: Path) -> bool:
|
||
"""尝试修复无效 SVG,返回是否修复成功"""
|
||
import re
|
||
import xml.etree.ElementTree as ET
|
||
|
||
content = svg_file.read_text(encoding="utf-8")
|
||
|
||
# 修复0: 移除自闭合标签后的同名多余闭合标签(LLM 常见错误)
|
||
# 例: <feDropShadow .../>\n</feDropShadow> → <feDropShadow .../>
|
||
content = re.sub(
|
||
r'<(\w+)\b([^>]*)/>\s*\n\s*</\1>',
|
||
r'<\1\2/>',
|
||
content,
|
||
)
|
||
|
||
# 修复1: 确保以 </svg> 结尾
|
||
if "</svg>" not in content:
|
||
# 找到最后一个完整的闭合标签,截断并添加 </svg>
|
||
last_close = content.rfind("/>")
|
||
if last_close == -1:
|
||
last_close = content.rfind("</")
|
||
if last_close > 0:
|
||
# 找到该行末尾
|
||
line_end = content.find("\n", last_close)
|
||
if line_end > 0:
|
||
content = content[:line_end + 1] + "</svg>\n"
|
||
else:
|
||
content = content[:last_close + 2] + "\n</svg>\n"
|
||
|
||
# 修复2: 移除 </svg> 之后的垃圾内容
|
||
svg_end = content.find("</svg>")
|
||
if svg_end > 0:
|
||
content = content[:svg_end + len("</svg>")]
|
||
|
||
# 修复3: 尝试逐行移除导致解析失败的行
|
||
try:
|
||
ET.fromstring(content)
|
||
svg_file.write_text(content, encoding="utf-8")
|
||
return True
|
||
except ET.ParseError:
|
||
pass
|
||
|
||
# 修复4: 暴力截断 — 找到最后一个有效的闭合 </g> 或 </text>,截断后闭合 </svg>
|
||
for tag in ("</g>", "</text>", "</rect>", "</circle>"):
|
||
last_pos = content.rfind(tag)
|
||
if last_pos > 0:
|
||
candidate = content[:last_pos + len(tag)] + "\n</svg>"
|
||
try:
|
||
ET.fromstring(candidate)
|
||
svg_file.write_text(candidate, encoding="utf-8")
|
||
return True
|
||
except ET.ParseError:
|
||
continue
|
||
|
||
return False
|
||
|
||
def _count_pages(self) -> int:
|
||
"""统计生成的页数"""
|
||
svg_dir = self.project_path / "svg_output"
|
||
if svg_dir.exists():
|
||
return len(list(svg_dir.glob("*.svg")))
|
||
return 0
|