Files
2026-06-15 23:48:37 +08:00

732 lines
29 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""PPT Master 管线封装 — 核心任务处理器"""
import subprocess
import sys
import json
import shutil
import re
from pathlib import Path
from typing import Callable
from config import config
from llm_client import llm_client
from wanx_client import wanx_client
from db import update_task_status
class PPTPipeline:
"""封装 PPT Master 完整管线"""
def __init__(self, task_id: str, task: dict, redis_callback=None):
self.task_id = task_id
self.task = task
self.task_config = task.get("config", {}) if isinstance(task.get("config"), dict) else json.loads(task.get("config", "{}"))
self.project_path: Path | None = None
self.project_name: str = ""
self._redis_callback = redis_callback
def run(self):
    """Run every pipeline stage in order and record the outcome.

    On success the task is marked ``completed`` with the exported file path
    and the page count. On any exception the task is marked ``failed`` with
    the exception text, and the exception is re-raised to the caller.
    """
    try:
        self._update("processing", 5, "开始处理任务...")
        # Step 1: turn the source into Markdown.
        markdown = self._step1_process_source()
        # Step 2: create the project directory.
        self._step2_init_project()
        # Step 3 (templates) is skipped: free design is used.
        # Step 4: strategist stage.
        self._step4_strategist(markdown)
        # Step 5: image acquisition, enabled unless configured otherwise.
        wants_images = self.task_config.get("with_images", True)
        if wants_images:
            self._step5_images()
        # Step 6: executor stage (SVG generation).
        self._step6_executor()
        # Step 7: post-processing and export.
        exported = self._step7_export()
        completion = {
            "output_file": str(exported),
            "page_count": self._count_pages(),
        }
        self._update("completed", 100, "PPT 生成完成", **completion)
    except Exception as exc:
        failure = {
            "progress": 0,
            "status_message": "生成失败",
            "error_message": str(exc),
        }
        self._update("failed", **failure)
        raise
def _update(self, status: str, progress: int = None, status_message: str = None, **kwargs):
    """Persist the task status and mirror it to the Redis callback.

    Args:
        status: New status value.
        progress: Progress value; when ``None`` the Redis callback is skipped.
        status_message: Optional human-readable message.
        **kwargs: Extra fields forwarded unchanged to ``update_task_status``.
    """
    fields = dict(status=status, progress=progress, status_message=status_message)
    update_task_status(self.task_id, **fields, **kwargs)

    # Mirror to Redis so the frontend can poll quickly.
    notify = self._redis_callback
    if not notify or progress is None:
        return
    notify(self.task_id, status, progress, status_message or "")
def _run_script(self, script_name: str, *args, cwd: Path = None) -> str:
    """Run a PPT Master Python script and return its standard output.

    The script is resolved under ``config.SCRIPTS_DIR`` and executed with the
    current interpreter and the environment from ``_build_env``.

    Args:
        script_name: Script path relative to ``config.SCRIPTS_DIR``.
        *args: Command-line arguments; each is converted with ``str``.
        cwd: Working directory; defaults to ``config.PPT_MASTER_PATH``.

    Returns:
        The captured stdout of the script.

    Raises:
        RuntimeError: If the script exits with a non-zero code (the message
            carries the last 2000 characters of stdout and stderr) or does
            not finish within the timeout.
    """
    timeout_seconds = 600
    script_path = config.SCRIPTS_DIR / script_name
    cmd = [sys.executable, str(script_path), *(str(a) for a in args)]
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            cwd=str(cwd or config.PPT_MASTER_PATH),
            env=self._build_env(),
            timeout=timeout_seconds,
        )
    except subprocess.TimeoutExpired as exc:
        # Report a timeout as the same failure type as a non-zero exit, so
        # callers that tolerate a failed script also tolerate a hung one.
        raise RuntimeError(
            f"脚本 {script_name} 执行超时 ({timeout_seconds}s)"
        ) from exc
    if result.returncode != 0:
        raise RuntimeError(
            f"脚本 {script_name} 执行失败:\n"
            f"stdout: {result.stdout[-2000:]}\n"
            f"stderr: {result.stderr[-2000:]}"
        )
    return result.stdout
def _build_env(self) -> dict:
    """Build the environment variables passed to pipeline scripts.

    Starts from a copy of the current process environment and adds the
    OpenAI and image settings read from ``config``. ``IMAGE_API_KEY`` is
    only set when it is truthy.
    """
    import os

    env = dict(os.environ)
    env.update(
        OPENAI_API_KEY=config.OPENAI_API_KEY,
        OPENAI_BASE_URL=config.OPENAI_BASE_URL,
        OPENAI_MODEL=config.OPENAI_MODEL,
    )
    if config.IMAGE_API_KEY:
        env["IMAGE_API_KEY"] = config.IMAGE_API_KEY
    env["IMAGE_BACKEND"] = config.IMAGE_BACKEND
    return env
# ==================== Step 1: 源内容处理 ====================
def _step1_process_source(self) -> str:
"""处理源文件/内容,返回 Markdown 文本"""
self._update("converting", 10, "正在转换源文件...")
source_type = self.task.get("source_type", "text")
if source_type == "text":
return self.task.get("source_content", "")
elif source_type == "url":
url = self.task.get("source_content", "")
output = self._run_script("source_to_md/web_to_md.py", url)
# web_to_md.py 输出转换后的 md 文件路径
md_path = self._extract_output_path(output)
if md_path and Path(md_path).exists():
return Path(md_path).read_text(encoding="utf-8")
return output
elif source_type == "file":
file_path = self.task.get("source_file", "")
if not file_path or not Path(file_path).exists():
raise FileNotFoundError(f"源文件不存在: {file_path}")
suffix = Path(file_path).suffix.lower()
if suffix == ".pdf":
output = self._run_script("source_to_md/pdf_to_md.py", file_path)
elif suffix in (".docx", ".doc", ".html", ".epub"):
output = self._run_script("source_to_md/doc_to_md.py", file_path)
elif suffix in (".xlsx", ".xlsm"):
output = self._run_script("source_to_md/excel_to_md.py", file_path)
elif suffix in (".pptx",):
output = self._run_script("source_to_md/ppt_to_md.py", file_path)
elif suffix == ".md":
return Path(file_path).read_text(encoding="utf-8")
else:
output = self._run_script("source_to_md/doc_to_md.py", file_path)
md_path = self._extract_output_path(output)
if md_path and Path(md_path).exists():
return Path(md_path).read_text(encoding="utf-8")
return output
return ""
def _extract_output_path(self, script_output: str) -> str | None:
"""从脚本输出中提取生成的文件路径"""
for line in script_output.strip().split("\n"):
line = line.strip()
if line.endswith(".md") and Path(line).exists():
return line
# 匹配 "Output: /path/to/file.md" 格式
match = re.search(r"(?:Output|Saved|Written):\s*(.+\.md)", line, re.IGNORECASE)
if match:
path = match.group(1).strip()
if Path(path).exists():
return path
return None
def _resolve_project_path(self, script_output: str) -> Path:
    """Locate the project directory created by ``project_manager.py``.

    Three strategies are tried in order: a path reported in the script
    output, an exact match on ``self.project_name``, and finally the most
    recently modified directory whose name starts with the project name.

    Raises:
        RuntimeError: If no existing directory is found.
    """
    # Strategy 1: path reported after a known marker, e.g.
    # "Project created: projects/xxx".
    markers = ("Project created:", "[OK] Project initialized:")
    for line in script_output.strip().split("\n"):
        for marker in markers:
            if marker not in line:
                continue
            reported = config.PROJECTS_DIR / line.split(marker, 1)[1].strip()
            if reported.exists():
                return reported

    # Strategy 2: exact directory name.
    for root in (config.PROJECTS_DIR / "projects", config.PROJECTS_DIR):
        exact = root / self.project_name
        if exact.exists():
            return exact

    # Strategy 3: prefix match, since project_manager may append a format
    # and date suffix; the newest match wins.
    for root in (config.PROJECTS_DIR / "projects", config.PROJECTS_DIR):
        if not root.exists():
            continue
        hits = list(root.glob(f"{self.project_name}*"))
        if hits:
            return max(hits, key=lambda p: p.stat().st_mtime)

    raise RuntimeError(f"项目初始化失败,找不到项目目录: {script_output}")
# ==================== Step 2: 项目初始化 ====================
def _step2_init_project(self):
    """Initialise the PPT Master project directory for this task.

    Builds a filesystem-safe project name from the task title and the first
    eight characters of the task id, runs ``project_manager.py init``,
    resolves the created directory and records it on the task.

    Raises:
        RuntimeError: If the init script fails or the project directory
            cannot be found afterwards.
    """
    self._update("processing", 15, "初始化项目...")
    # A title stored as None or "" falls back to "untitled"; passing None to
    # re.sub would raise TypeError.
    title = self.task.get("title") or "untitled"
    # Keep word characters, CJK ideographs and hyphens; cap at 50 characters.
    safe_name = re.sub(r'[^\w\u4e00-\u9fff-]', '_', str(title))[:50]
    self.project_name = f"govai_{safe_name}_{self.task_id[:8]}"
    fmt = self.task_config.get("format", "ppt169")
    output = self._run_script(
        "project_manager.py", "init", self.project_name,
        "--format", fmt,
        cwd=config.PROJECTS_DIR,
    )
    # The actual directory name is taken from the script output.
    self.project_path = self._resolve_project_path(output)
    update_task_status(self.task_id, status="processing", project_path=str(self.project_path))
# ==================== Step 4: 策略师阶段 ====================
def _step4_strategist(self, source_md: str):
    """Strategist stage: produce ``design_spec.md`` and ``spec_lock.md``.

    Sends the reference documents, the generation options and the source
    Markdown to the LLM, splits the reply on the section markers and writes
    both files into the project directory. If the design-spec markers are
    missing the whole reply is used as the design spec; if the spec-lock
    section is missing it is generated by a second LLM call.

    Args:
        source_md: Source content as Markdown (truncated to 50000 chars in
            the prompt).
    """
    self._update("designing", 20, "AI 正在分析内容并设计方案...")

    # Reference documents that define the expected output structure.
    skill_dir = config.SKILL_DIR
    strategist_ref = (skill_dir / "references" / "strategist.md").read_text(encoding="utf-8")
    design_spec_ref = (skill_dir / "templates" / "design_spec_reference.md").read_text(encoding="utf-8")
    spec_lock_ref = (skill_dir / "templates" / "spec_lock_reference.md").read_text(encoding="utf-8")

    options = self.task_config
    page_count = options.get("page_count", 10)
    style = options.get("style", "general")
    fmt = options.get("format", "ppt169")
    language = options.get("language", "zh")

    system_prompt = (
        "你是 PPT Master 的策略师角色。你的任务是根据用户提供的源内容,"
        "生成完整的 design_spec.md(设计规范)和 spec_lock.md(执行锁定)。\n\n"
        "你必须严格遵循以下参考模板的结构来生成输出:\n\n"
        "## 策略师角色定义(摘要)\n"
        f"{strategist_ref[:8000]}\n\n"
        "## design_spec.md 参考模板\n"
        f"{design_spec_ref[:10000]}\n\n"
        "## spec_lock.md 参考模板\n"
        f"{spec_lock_ref[:8000]}\n\n"
        "请直接输出两个文件的完整内容,用以下分隔符分开:\n"
        "===DESIGN_SPEC_START===\n[design_spec.md 内容]\n===DESIGN_SPEC_END===\n"
        "===SPEC_LOCK_START===\n[spec_lock.md 内容]\n===SPEC_LOCK_END===\n"
    )
    user_prompt = (
        f"请根据以下源内容生成演示文稿的设计规范。\n\n"
        f"## 生成要求\n"
        f"- 画布格式: {fmt}\n"
        f"- 目标页数: {page_count}\n"
        f"- 设计风格: {style}\n"
        f"- 语言: {language}\n\n"
        f"## 源内容\n\n{source_md[:50000]}"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    response = llm_client.chat(messages, temperature=0.7, max_tokens=16384)

    # Split the reply; fall back to the whole reply as the design spec.
    design_spec = self._extract_section(response, "DESIGN_SPEC_START", "DESIGN_SPEC_END") or response
    spec_lock = self._extract_section(response, "SPEC_LOCK_START", "SPEC_LOCK_END")
    if not spec_lock:
        # No separate spec_lock section: ask the LLM for it on its own.
        spec_lock = self._generate_spec_lock(design_spec, source_md)

    project_dir = self.project_path
    (project_dir / "design_spec.md").write_text(design_spec, encoding="utf-8")
    (project_dir / "spec_lock.md").write_text(spec_lock, encoding="utf-8")
    self._update("designing", 35, "设计规范生成完成")
def _generate_spec_lock(self, design_spec: str, source_md: str) -> str:
    """Generate ``spec_lock.md`` on its own from an existing design spec.

    Args:
        design_spec: Content of ``design_spec.md`` (first 15000 characters
            are sent to the LLM).
        source_md: Accepted for call compatibility; not used here.

    Returns:
        The raw LLM reply, used as the ``spec_lock.md`` content.
    """
    template_path = config.SKILL_DIR / "templates" / "spec_lock_reference.md"
    spec_lock_ref = template_path.read_text(encoding="utf-8")

    system_prompt = (
        "根据已生成的 design_spec.md,生成对应的 spec_lock.md(执行锁定文件)。\n"
        "严格遵循以下参考模板结构:\n\n"
        f"{spec_lock_ref[:8000]}\n\n"
        "直接输出 spec_lock.md 的完整内容,不要加任何额外说明。"
    )
    user_prompt = f"## design_spec.md\n\n{design_spec[:15000]}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    return llm_client.chat(messages, temperature=0.5, max_tokens=8192)
def _extract_section(self, text: str, start_marker: str, end_marker: str) -> str:
"""从文本中提取指定标记之间的内容"""
start = text.find(f"==={start_marker}===")
end = text.find(f"==={end_marker}===")
if start == -1 or end == -1:
return ""
start += len(f"==={start_marker}===")
return text[start:end].strip()
# ==================== Step 5: 图片获取 ====================
def _step5_images(self):
"""图片获取阶段(通义万相 / 脚本回退)"""
self._update("generating_images", 40, "正在生成/搜索图片...")
# 从 design_spec 中提取图片需求
design_spec_path = self.project_path / "design_spec.md"
if not design_spec_path.exists():
return
design_spec = design_spec_path.read_text(encoding="utf-8")
# 检查是否有图片需求
if "Acquire Via: ai" not in design_spec and "Acquire Via: web" not in design_spec:
self._update("generating_images", 45, "无需额外图片")
return
# 使用 LLM 生成图片 prompt
messages = [
{
"role": "system",
"content": (
"根据设计规范中的图片需求,生成图片生成 prompt。\n"
"每行一个,格式为: filename|prompt\n"
"例如: cover_bg.png|A modern city skyline at sunset, soft gradient sky\n"
"只输出需要生成的图片,不要解释。"
),
},
{
"role": "user",
"content": design_spec[:10000],
},
]
response = llm_client.chat(messages, temperature=0.5, max_tokens=4096)
images_dir = self.project_path / "images"
images_dir.mkdir(exist_ok=True)
# 根据画布格式确定图片尺寸
fmt = self.task_config.get("format", "ppt169")
wanx_size_map = {
"ppt169": "1024*576",
"ppt43": "1024*768",
"xhs": "720*1280",
"story": "720*1280",
}
size = wanx_size_map.get(fmt, "1024*576")
# 逐个生成图片
lines = [l.strip() for l in response.strip().split("\n") if "|" in l]
for i, line in enumerate(lines[:10]): # 最多 10 张
parts = line.split("|", 1)
if len(parts) != 2:
continue
filename, prompt = parts[0].strip(), parts[1].strip()
try:
self._generate_image(prompt, images_dir, filename, size)
except Exception:
# 图片生成失败不中断管线
pass
progress = 40 + int((i + 1) / max(len(lines), 1) * 10)
self._update("generating_images", progress, f"图片生成中 ({i+1}/{len(lines)})")
def _generate_image(self, prompt: str, output_dir: Path, filename: str, size: str):
    """Generate a single image with the configured backend.

    When ``config.IMAGE_BACKEND`` is ``"wanx"`` the Wanx client is used;
    any other value runs PPT Master's ``image_gen.py`` script. The backend
    is chosen by configuration only; there is no retry with the other
    backend on failure.

    Args:
        prompt: Image generation prompt.
        output_dir: Directory that receives the image.
        filename: Target file name.
        size: Size string; only passed to the Wanx backend.
    """
    if config.IMAGE_BACKEND != "wanx":
        # NOTE(review): the script path always requests 16:9 and ignores
        # `size` -- confirm this is intended for non-16:9 canvas formats.
        self._run_script(
            "image_gen.py", prompt,
            "--aspect_ratio", "16:9",
            "--image_size", "1K",
            "-o", str(output_dir),
            "--filename", filename,
        )
        return
    wanx_client.generate(
        prompt=prompt,
        output_dir=output_dir,
        filename=filename,
        size=size,
    )
# ==================== Step 6: 执行器阶段 ====================
def _step6_executor(self):
    """Executor stage: generate one SVG per page, then the speaker notes.

    The page count comes from a ``page_count: N`` entry in ``spec_lock.md``
    when present, otherwise from the task configuration (default 10).
    Progress moves from 50 to 80 as pages are produced. The final SVG
    quality check is best effort: a ``RuntimeError`` from it is ignored.
    """
    self._update("generating_svg", 50, "AI 正在逐页生成幻灯片...")

    project_dir = self.project_path
    spec_lock_path = project_dir / "spec_lock.md"
    design_spec_path = project_dir / "design_spec.md"
    spec_lock = spec_lock_path.read_text(encoding="utf-8") if spec_lock_path.exists() else ""
    design_spec = design_spec_path.read_text(encoding="utf-8") if design_spec_path.exists() else ""

    # Executor reference documents; unknown styles use the general one.
    style = self.task_config.get("style", "general")
    references_dir = config.SKILL_DIR / "references"
    executor_base = (references_dir / "executor-base.md").read_text(encoding="utf-8")
    shared_standards = (references_dir / "shared-standards.md").read_text(encoding="utf-8")
    style_file_map = {
        "general": "executor-general.md",
        "consultant": "executor-consultant.md",
        "consultant-top": "executor-consultant-top.md",
    }
    style_file = style_file_map.get(style, "executor-general.md")
    executor_style = (references_dir / style_file).read_text(encoding="utf-8")

    # The page count planned in spec_lock overrides the configured one.
    page_count = self.task_config.get("page_count", 10)
    declared = re.search(r"page_count:\s*(\d+)", spec_lock)
    if declared is not None:
        page_count = int(declared.group(1))

    svg_output_dir = project_dir / "svg_output"
    svg_output_dir.mkdir(exist_ok=True)

    page_num = 1
    while page_num <= page_count:
        self._generate_svg_page(
            page_num, page_count,
            spec_lock, design_spec,
            executor_base, shared_standards, executor_style,
            svg_output_dir,
        )
        progress = 50 + int(page_num / page_count * 30)
        self._update("generating_svg", progress, f"生成第 {page_num}/{page_count}")
        page_num += 1

    self._generate_speaker_notes(design_spec, page_count)

    try:
        self._run_script("svg_quality_checker.py", str(project_dir))
    except RuntimeError:
        # A failed quality check must not block the pipeline.
        pass
def _generate_svg_page(
    self, page_num: int, total_pages: int,
    spec_lock: str, design_spec: str,
    executor_base: str, shared_standards: str, executor_style: str,
    svg_output_dir: Path,
):
    """Generate the SVG of one page and write it as ``page_NN.svg``.

    The LLM reply is cleaned before writing: a leading Markdown code-fence
    line and a trailing fence are removed, and anything before the first
    ``<svg`` is dropped.

    Args:
        page_num: 1-based page number.
        total_pages: Total number of pages.
        spec_lock: Content of ``spec_lock.md``.
        design_spec: Content of ``design_spec.md``.
        executor_base: Executor base reference text.
        shared_standards: Shared technical standards text.
        executor_style: Style-specific executor reference text.
        svg_output_dir: Directory that receives the SVG file.
    """
    # Canvas size per format; unknown formats use 1280x720.
    canvas_map = {
        "ppt169": (1280, 720),
        "ppt43": (1024, 768),
        "xhs": (1080, 1440),
        "story": (1080, 1920),
    }
    fmt = self.task_config.get("format", "ppt169")
    width, height = canvas_map.get(fmt, (1280, 720))

    system_prompt = (
        f"你是 PPT Master 的执行器角色。现在需要生成第 {page_num}/{total_pages} 页的 SVG 代码。\n\n"
        f"## 关键技术约束(摘要)\n"
        f"- 画布: {width}x{height}\n"
        f"- viewBox=\"0 0 {width} {height}\"\n"
        f"- 所有文本必须用 <text> 元素,禁止 <foreignObject>\n"
        f"- 长文本必须手动分行(多个 <text> 或 <tspan dy=\"...\"\n"
        f"- 图片使用 <image href=\"../images/filename.ext\">\n"
        f"- 每个顶层元素用 <g id=\"elem_N\"> 包裹\n"
        f"- 颜色、字体必须严格来自 spec_lock\n\n"
        f"## 执行器规范(摘要)\n{executor_base[:6000]}\n\n"
        f"## 风格规范(摘要)\n{executor_style[:4000]}\n\n"
        f"## 技术标准(摘要)\n{shared_standards[:6000]}\n\n"
        f"直接输出完整的 SVG 代码,不要用 ```svg 包裹,不要加解释。"
    )
    user_prompt = (
        f"## spec_lock.md\n\n{spec_lock[:12000]}\n\n"
        f"## design_spec.md(第 {page_num} 页相关内容)\n\n{design_spec[:8000]}\n\n"
        f"请生成第 {page_num} 页的完整 SVG。"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    svg = llm_client.chat(messages, temperature=0.6, max_tokens=16384).strip()

    # Drop a Markdown code fence if the model added one anyway.
    fence = "```"
    if svg.startswith(fence):
        # Everything after the first line, i.e. after the opening fence.
        svg = svg.partition("\n")[2]
    if svg.endswith(fence):
        svg = svg[:-len(fence)].strip()

    # Discard any preamble in front of the <svg element.
    svg_start = svg.find("<svg")
    if svg_start > 0:
        svg = svg[svg_start:]

    target = svg_output_dir / f"page_{page_num:02d}.svg"
    target.write_text(svg, encoding="utf-8")
def _generate_speaker_notes(self, design_spec: str, page_count: int):
    """Generate speaker notes for all pages and write ``notes/total.md``.

    Args:
        design_spec: Content of ``design_spec.md`` (first 15000 characters
            are sent to the LLM).
        page_count: Number of pages to write notes for.
    """
    notes_dir = self.project_path / "notes"
    notes_dir.mkdir(exist_ok=True)

    system_prompt = (
        "根据设计规范为每一页生成演讲者备注。\n"
        "格式为每页用 ## Page N 分隔,内容为该页的讲解要点,100-200字。\n"
        "直接输出,不加额外说明。"
    )
    user_prompt = f"共 {page_count} 页。设计规范:\n\n{design_spec[:15000]}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    notes = llm_client.chat(messages, temperature=0.7, max_tokens=8192)
    (notes_dir / "total.md").write_text(notes, encoding="utf-8")
# ==================== Step 7: 后处理与导出 ====================
def _step7_export(self) -> Path:
    """Post-process the SVGs, export the PPTX and copy it to the output dir.

    Steps: split speaker notes (best effort), strip dangling image
    references, run ``finalize_svg.py``, validate/repair the SVGs, then run
    ``svg_to_pptx.py``. The newest ``.pptx`` found is copied to
    ``config.OUTPUT_DIR / "<task_id>.pptx"``.

    Returns:
        Path of the copied PPTX file.

    Raises:
        RuntimeError: If a required script fails or no PPTX can be found.
    """
    project_dir = self.project_path
    self._update("exporting", 85, "正在处理并导出 PPTX...")

    # 7.1: split the speaker notes; a failure here is tolerated.
    try:
        self._run_script("total_md_split.py", str(project_dir))
    except RuntimeError:
        pass

    self._update("exporting", 88, "清理无效图片引用...")
    self._strip_missing_images()

    # 7.2: SVG post-processing.
    self._update("exporting", 90, "SVG 后处理中...")
    self._run_script("finalize_svg.py", str(project_dir))
    # 7.2.5: the LLM may have produced malformed XML.
    self._validate_and_repair_svgs()

    # 7.3: export.
    self._update("exporting", 95, "导出 PPTX 文件...")
    self._run_script("svg_to_pptx.py", str(project_dir))

    # Locate the result: project exports dir, else the shared exports dir,
    # else anywhere below the project directory.
    exports_dir = project_dir / "exports"
    if not exports_dir.exists():
        exports_dir = config.PROJECTS_DIR / "exports"
    candidates = list(exports_dir.glob("*.pptx")) if exports_dir.exists() else []
    candidates = candidates or list(project_dir.rglob("*.pptx"))
    if not candidates:
        raise RuntimeError("PPTX 导出失败:找不到输出文件")

    newest = max(candidates, key=lambda f: f.stat().st_mtime)
    final_path = config.OUTPUT_DIR / f"{self.task_id}.pptx"
    shutil.copy2(newest, final_path)
    return final_path
def _strip_missing_images(self):
"""移除 SVG 中引用不存在图片文件的 <image> 元素"""
svg_dir = self.project_path / "svg_output"
if not svg_dir.exists():
return
import xml.etree.ElementTree as ET
ns = {"svg": "http://www.w3.org/2000/svg", "xlink": "http://www.w3.org/1999/xlink"}
ET.register_namespace("", "http://www.w3.org/2000/svg")
ET.register_namespace("xlink", "http://www.w3.org/1999/xlink")
for svg_file in sorted(svg_dir.glob("*.svg")):
try:
tree = ET.parse(svg_file)
root = tree.getroot()
changed = False
# 递归查找所有 <image> 元素
for parent in root.iter():
to_remove = []
for child in parent:
tag = child.tag.split("}")[-1] if "}" in child.tag else child.tag
if tag != "image":
continue
href = (child.get("href") or
child.get("{http://www.w3.org/1999/xlink}href") or "")
# 跳过 data URI 和绝对 URL
if href.startswith("data:") or href.startswith("http"):
continue
# 检查相对路径是否存在
img_path = (svg_dir / href).resolve()
if not img_path.exists():
to_remove.append(child)
for elem in to_remove:
parent.remove(elem)
changed = True
if changed:
tree.write(svg_file, encoding="unicode", xml_declaration=True)
except Exception:
pass # 解析失败跳过
def _validate_and_repair_svgs(self):
"""验证所有 SVG 文件,修复或移除无效文件"""
import xml.etree.ElementTree as ET
svg_dir = self.project_path / "svg_output"
if not svg_dir.exists():
return
for svg_file in sorted(svg_dir.glob("*.svg")):
try:
ET.parse(svg_file)
except ET.ParseError:
# 尝试修复常见问题
repaired = self._try_repair_svg(svg_file)
if not repaired:
# 无法修复,移除该文件让导出跳过这页
svg_file.unlink()
print(f"[Pipeline] 移除无效 SVG: {svg_file.name}")
def _try_repair_svg(self, svg_file: Path) -> bool:
"""尝试修复无效 SVG,返回是否修复成功"""
import re
import xml.etree.ElementTree as ET
content = svg_file.read_text(encoding="utf-8")
# 修复0: 移除自闭合标签后的同名多余闭合标签(LLM 常见错误)
# 例: <feDropShadow .../>\n</feDropShadow> → <feDropShadow .../>
content = re.sub(
r'<(\w+)\b([^>]*)/>\s*\n\s*</\1>',
r'<\1\2/>',
content,
)
# 修复1: 确保以 </svg> 结尾
if "</svg>" not in content:
# 找到最后一个完整的闭合标签,截断并添加 </svg>
last_close = content.rfind("/>")
if last_close == -1:
last_close = content.rfind("</")
if last_close > 0:
# 找到该行末尾
line_end = content.find("\n", last_close)
if line_end > 0:
content = content[:line_end + 1] + "</svg>\n"
else:
content = content[:last_close + 2] + "\n</svg>\n"
# 修复2: 移除 </svg> 之后的垃圾内容
svg_end = content.find("</svg>")
if svg_end > 0:
content = content[:svg_end + len("</svg>")]
# 修复3: 尝试逐行移除导致解析失败的行
try:
ET.fromstring(content)
svg_file.write_text(content, encoding="utf-8")
return True
except ET.ParseError:
pass
# 修复4: 暴力截断 — 找到最后一个有效的闭合 </g> 或 </text>,截断后闭合 </svg>
for tag in ("</g>", "</text>", "</rect>", "</circle>"):
last_pos = content.rfind(tag)
if last_pos > 0:
candidate = content[:last_pos + len(tag)] + "\n</svg>"
try:
ET.fromstring(candidate)
svg_file.write_text(candidate, encoding="utf-8")
return True
except ET.ParseError:
continue
return False
def _count_pages(self) -> int:
"""统计生成的页数"""
svg_dir = self.project_path / "svg_output"
if svg_dir.exists():
return len(list(svg_dir.glob("*.svg")))
return 0