"""PPT Master pipeline wrapper — core task processor."""

import json
import logging
import os
import re
import shutil
import subprocess
import sys
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Callable

from config import config
from llm_client import llm_client
from wanx_client import wanx_client
from db import update_task_status

logger = logging.getLogger(__name__)


class PPTPipeline:
    """Run the complete PPT Master pipeline for a single task.

    The pipeline converts the task's source into Markdown, initialises a
    PPT Master project, asks the LLM for ``design_spec.md`` / ``spec_lock.md``,
    optionally generates images, asks the LLM for one SVG per page, and
    finally post-processes the SVGs and exports a PPTX file.
    """

    # Converter script used when the file suffix has no dedicated entry.
    _DEFAULT_CONVERTER = "source_to_md/doc_to_md.py"

    # File suffix -> conversion script (relative to ``config.SCRIPTS_DIR``).
    _SOURCE_CONVERTERS = {
        ".pdf": "source_to_md/pdf_to_md.py",
        ".docx": "source_to_md/doc_to_md.py",
        ".doc": "source_to_md/doc_to_md.py",
        ".html": "source_to_md/doc_to_md.py",
        ".epub": "source_to_md/doc_to_md.py",
        ".xlsx": "source_to_md/excel_to_md.py",
        ".xlsm": "source_to_md/excel_to_md.py",
        ".pptx": "source_to_md/ppt_to_md.py",
    }

    # Canvas format -> image size string passed to the wanx client.
    _WANX_SIZES = {
        "ppt169": "1024*576",
        "ppt43": "1024*768",
        "xhs": "720*1280",
        "story": "720*1280",
    }

    # Canvas format -> (width, height) of the SVG canvas.
    _CANVAS_SIZES = {
        "ppt169": (1280, 720),
        "ppt43": (1024, 768),
        "xhs": (1080, 1440),
        "story": (1080, 1920),
    }

    # Executor style name -> style reference document.
    _STYLE_FILES = {
        "general": "executor-general.md",
        "consultant": "executor-consultant.md",
        "consultant-top": "executor-consultant-top.md",
    }

    # Upper bound on the number of images generated per task.
    _MAX_IMAGES = 10

    def __init__(
        self,
        task_id: str,
        task: dict,
        redis_callback: Callable[[str, str, int, str], None] | None = None,
    ):
        """Store the task and normalise its configuration.

        Args:
            task_id: Identifier of the task being processed.
            task: Task record; ``task["config"]`` may be a dict, a JSON
                string, or missing/empty.
            redis_callback: Optional callable invoked as
                ``callback(task_id, status, progress, message)`` on every
                progress update.

        Raises:
            ValueError: If ``task["config"]`` is a string that is not valid
                JSON, or decodes to something other than an object.
        """
        self.task_id = task_id
        self.task = task
        self.task_config = self._parse_config(task.get("config"))
        self.project_path: Path | None = None
        self.project_name: str = ""
        self._redis_callback = redis_callback

    @staticmethod
    def _parse_config(raw) -> dict:
        """Return the task configuration as a dict; ``None``/empty yields ``{}``."""
        if isinstance(raw, dict):
            return raw
        if raw is None or (isinstance(raw, str) and not raw.strip()):
            return {}
        parsed = json.loads(raw)
        if parsed is None:
            return {}
        if not isinstance(parsed, dict):
            raise ValueError(
                f"task config must be a JSON object, got {type(parsed).__name__}"
            )
        return parsed

    def run(self):
        """Execute the whole pipeline.

        Any exception marks the task as failed and is then re-raised.
        """
        try:
            self._update("processing", 5, "开始处理任务...")

            # Step 1: source content processing
            source_md = self._step1_process_source()

            # Step 2: project initialisation
            self._step2_init_project()

            # Step 3: templates are skipped (free design is used)

            # Step 4: strategist phase
            self._step4_strategist(source_md)

            # Step 5: image acquisition (if configured)
            if self.task_config.get("with_images", True):
                self._step5_images()

            # Step 6: executor phase (SVG generation)
            self._step6_executor()

            # Step 7: post-processing and export
            output_path = self._step7_export()

            self._update(
                "completed",
                100,
                "PPT 生成完成",
                output_file=str(output_path),
                page_count=self._count_pages(),
            )
        except Exception as e:
            # Reporting the failure must never hide the original error.
            try:
                self._update(
                    "failed",
                    progress=0,
                    status_message="生成失败",
                    error_message=str(e),
                )
            except Exception:
                logger.exception(
                    "Could not record failure status for task %s", self.task_id
                )
            raise

    def _update(
        self,
        status: str,
        progress: int | None = None,
        status_message: str | None = None,
        **kwargs,
    ):
        """Persist the task status and mirror it to the Redis callback.

        Extra keyword arguments are forwarded to ``update_task_status``.
        The callback is only invoked when ``progress`` is given.
        """
        update_task_status(
            self.task_id,
            status=status,
            progress=progress,
            status_message=status_message,
            **kwargs,
        )
        # Mirror to Redis (used by the frontend for fast polling).
        if self._redis_callback and progress is not None:
            self._redis_callback(self.task_id, status, progress, status_message or "")

    def _run_script(
        self,
        script_name: str,
        *args,
        cwd: Path | None = None,
        timeout: int = 600,
    ) -> str:
        """Run a PPT Master Python script and return its stdout.

        Args:
            script_name: Script path relative to ``config.SCRIPTS_DIR``.
            *args: Command-line arguments; each is converted with ``str``.
            cwd: Working directory; defaults to ``config.PPT_MASTER_PATH``.
            timeout: Seconds to wait before the run is aborted.

        Raises:
            RuntimeError: If the script exits non-zero or times out. A
                timeout is reported as ``RuntimeError`` so that callers
                treating script failures as non-fatal also cover it.
        """
        script_path = config.SCRIPTS_DIR / script_name
        cmd = [sys.executable, str(script_path), *(str(a) for a in args)]
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                cwd=str(cwd or config.PPT_MASTER_PATH),
                env=self._build_env(),
                timeout=timeout,
            )
        except subprocess.TimeoutExpired as exc:
            raise RuntimeError(
                f"脚本 {script_name} 执行超时 ({timeout}s)"
            ) from exc

        if result.returncode != 0:
            raise RuntimeError(
                f"脚本 {script_name} 执行失败:\n"
                f"stdout: {result.stdout[-2000:]}\n"
                f"stderr: {result.stderr[-2000:]}"
            )
        return result.stdout

    def _build_env(self) -> dict:
        """Build the environment variables for script execution.

        Starts from a copy of ``os.environ`` and overlays the OpenAI and
        image settings from ``config``. Unset (``None``) values are skipped
        because ``subprocess`` only accepts string values.
        """
        env = os.environ.copy()
        for key in ("OPENAI_API_KEY", "OPENAI_BASE_URL", "OPENAI_MODEL"):
            value = getattr(config, key)
            if value is not None:
                env[key] = str(value)
        # NOTE(review): IMAGE_BACKEND is exported only together with
        # IMAGE_API_KEY here — confirm that is the intended grouping.
        if config.IMAGE_API_KEY:
            env["IMAGE_API_KEY"] = str(config.IMAGE_API_KEY)
            if config.IMAGE_BACKEND is not None:
                env["IMAGE_BACKEND"] = str(config.IMAGE_BACKEND)
        return env

    # ==================== Step 1: source content processing ====================

    def _step1_process_source(self) -> str:
        """Turn the task source (text, URL or file) into Markdown text.

        Returns:
            The Markdown text, or ``""`` for an unknown ``source_type``.

        Raises:
            FileNotFoundError: If ``source_type`` is ``"file"`` and the
                source file is missing.
            RuntimeError: If a conversion script fails.
        """
        self._update("converting", 10, "正在转换源文件...")

        source_type = self.task.get("source_type", "text")

        if source_type == "text":
            return self.task.get("source_content") or ""

        if source_type == "url":
            url = self.task.get("source_content") or ""
            output = self._run_script("source_to_md/web_to_md.py", url)
            return self._read_converted(output)

        if source_type == "file":
            file_path = self.task.get("source_file") or ""
            if not file_path or not Path(file_path).exists():
                raise FileNotFoundError(f"源文件不存在: {file_path}")

            suffix = Path(file_path).suffix.lower()
            if suffix == ".md":
                return Path(file_path).read_text(encoding="utf-8")

            script = self._SOURCE_CONVERTERS.get(suffix, self._DEFAULT_CONVERTER)
            output = self._run_script(script, file_path)
            return self._read_converted(output)

        return ""

    def _read_converted(self, script_output: str) -> str:
        """Return the Markdown file named in the script output, else the output itself."""
        md_path = self._extract_output_path(script_output)
        if md_path and Path(md_path).exists():
            return Path(md_path).read_text(encoding="utf-8")
        return script_output

    def _extract_output_path(self, script_output: str) -> str | None:
        """Extract the path of the generated ``.md`` file from script output.

        A line is accepted when it is itself an existing ``.md`` path, or
        when it matches ``Output:/Saved:/Written: <path>.md`` and that path
        exists. Returns ``None`` when no line qualifies.
        """
        labelled = re.compile(r"(?:Output|Saved|Written):\s*(.+\.md)", re.IGNORECASE)
        for raw_line in script_output.strip().split("\n"):
            line = raw_line.strip()
            if line.endswith(".md") and Path(line).exists():
                return line
            match = labelled.search(line)
            if match:
                path = match.group(1).strip()
                if Path(path).exists():
                    return path
        return None

    def _resolve_project_path(self, script_output: str) -> Path:
        """Resolve the real project directory from ``project_manager.py`` output.

        Tries, in order: a path announced in the output, an exact name
        match, and a most-recently-modified glob match on the project name.

        Raises:
            RuntimeError: If no existing directory can be found.
        """
        # 1) Parse "Project created: ..." / "[OK] Project initialized: ...".
        for line in script_output.strip().split("\n"):
            for prefix in ("Project created:", "[OK] Project initialized:"):
                if prefix not in line:
                    continue
                rel = line.split(prefix, 1)[1].strip()
                if not rel:
                    # An empty path would resolve to PROJECTS_DIR itself.
                    continue
                candidate = config.PROJECTS_DIR / rel
                if candidate.exists():
                    return candidate

        search_roots = (config.PROJECTS_DIR / "projects", config.PROJECTS_DIR)

        # 2) Exact name match.
        for base in search_roots:
            candidate = base / self.project_name
            if candidate.exists():
                return candidate

        # 3) Prefix match; the original notes that project_manager appends
        #    format and date suffixes to the name.
        for base in search_roots:
            if not base.exists():
                continue
            matches = sorted(
                base.glob(f"{self.project_name}*"),
                key=lambda p: p.stat().st_mtime,
                reverse=True,
            )
            if matches:
                return matches[0]

        raise RuntimeError(f"项目初始化失败,找不到项目目录: {script_output}")

    # ==================== Step 2: project initialisation ====================

    def _step2_init_project(self):
        """Initialise the PPT Master project and record its path.

        Sets ``self.project_name`` and ``self.project_path``.
        """
        self._update("processing", 15, "初始化项目...")

        # A missing or None title falls back to "untitled".
        title = self.task.get("title") or "untitled"
        # Keep word characters, CJK ideographs and '-'; everything else -> '_'.
        safe_name = re.sub(r'[^\w\u4e00-\u9fff-]', '_', str(title))[:50]
        self.project_name = f"govai_{safe_name}_{self.task_id[:8]}"

        fmt = self.task_config.get("format", "ppt169")
        output = self._run_script(
            "project_manager.py",
            "init",
            self.project_name,
            "--format",
            fmt,
            cwd=config.PROJECTS_DIR,
        )

        self.project_path = self._resolve_project_path(output)
        update_task_status(
            self.task_id, status="processing", project_path=str(self.project_path)
        )

    # ==================== Step 4: strategist phase ====================

    def _step4_strategist(self, source_md: str):
        """Ask the LLM for ``design_spec.md`` and ``spec_lock.md`` and write them.

        If the response cannot be split by the delimiters, the whole
        response becomes the design spec and the spec lock is generated by
        a second LLM call.
        """
        self._update("designing", 20, "AI 正在分析内容并设计方案...")

        # Reference documents embedded (truncated) into the system prompt.
        strategist_ref = (
            config.SKILL_DIR / "references" / "strategist.md"
        ).read_text(encoding="utf-8")
        design_spec_ref = (
            config.SKILL_DIR / "templates" / "design_spec_reference.md"
        ).read_text(encoding="utf-8")
        spec_lock_ref = (
            config.SKILL_DIR / "templates" / "spec_lock_reference.md"
        ).read_text(encoding="utf-8")

        page_count = self.task_config.get("page_count", 10)
        style = self.task_config.get("style", "general")
        fmt = self.task_config.get("format", "ppt169")
        language = self.task_config.get("language", "zh")

        messages = [
            {
                "role": "system",
                "content": (
                    "你是 PPT Master 的策略师角色。你的任务是根据用户提供的源内容,"
                    "生成完整的 design_spec.md(设计规范)和 spec_lock.md(执行锁定)。\n\n"
                    "你必须严格遵循以下参考模板的结构来生成输出:\n\n"
                    "## 策略师角色定义(摘要)\n"
                    f"{strategist_ref[:8000]}\n\n"
                    "## design_spec.md 参考模板\n"
                    f"{design_spec_ref[:10000]}\n\n"
                    "## spec_lock.md 参考模板\n"
                    f"{spec_lock_ref[:8000]}\n\n"
                    "请直接输出两个文件的完整内容,用以下分隔符分开:\n"
                    "===DESIGN_SPEC_START===\n[design_spec.md 内容]\n===DESIGN_SPEC_END===\n"
                    "===SPEC_LOCK_START===\n[spec_lock.md 内容]\n===SPEC_LOCK_END===\n"
                ),
            },
            {
                "role": "user",
                "content": (
                    f"请根据以下源内容生成演示文稿的设计规范。\n\n"
                    f"## 生成要求\n"
                    f"- 画布格式: {fmt}\n"
                    f"- 目标页数: {page_count} 页\n"
                    f"- 设计风格: {style}\n"
                    f"- 语言: {language}\n\n"
                    f"## 源内容\n\n{source_md[:50000]}"
                ),
            },
        ]

        response = llm_client.chat(messages, temperature=0.7, max_tokens=16384)

        design_spec = self._extract_section(
            response, "DESIGN_SPEC_START", "DESIGN_SPEC_END"
        )
        spec_lock = self._extract_section(response, "SPEC_LOCK_START", "SPEC_LOCK_END")

        if not design_spec:
            # Delimiters missing: treat the whole response as the design spec.
            design_spec = response
        if not spec_lock:
            # No separate spec lock: generate it with a dedicated call.
            spec_lock = self._generate_spec_lock(design_spec, source_md)

        (self.project_path / "design_spec.md").write_text(design_spec, encoding="utf-8")
        (self.project_path / "spec_lock.md").write_text(spec_lock, encoding="utf-8")

        self._update("designing", 35, "设计规范生成完成")

    def _generate_spec_lock(self, design_spec: str, source_md: str) -> str:
        """Generate ``spec_lock.md`` on its own from an existing design spec.

        ``source_md`` is accepted for signature compatibility but is not
        used in the prompt.
        """
        spec_lock_ref = (
            config.SKILL_DIR / "templates" / "spec_lock_reference.md"
        ).read_text(encoding="utf-8")
        messages = [
            {
                "role": "system",
                "content": (
                    "根据已生成的 design_spec.md,生成对应的 spec_lock.md(执行锁定文件)。\n"
                    "严格遵循以下参考模板结构:\n\n"
                    f"{spec_lock_ref[:8000]}\n\n"
                    "直接输出 spec_lock.md 的完整内容,不要加任何额外说明。"
                ),
            },
            {
                "role": "user",
                "content": f"## design_spec.md\n\n{design_spec[:15000]}",
            },
        ]
        return llm_client.chat(messages, temperature=0.5, max_tokens=8192)

    def _extract_section(self, text: str, start_marker: str, end_marker: str) -> str:
        """Return the stripped text between ``===start===`` and ``===end===``.

        The end marker is searched only after the start marker, so an
        earlier occurrence of it (for example the model echoing the
        instructions) does not empty the result. Returns ``""`` when either
        marker is missing.
        """
        start_token = f"==={start_marker}==="
        end_token = f"==={end_marker}==="

        start = text.find(start_token)
        if start == -1:
            return ""
        start += len(start_token)

        end = text.find(end_token, start)
        if end == -1:
            return ""
        return text[start:end].strip()

    # ==================== Step 5: image acquisition ====================

    def _step5_images(self):
        """Generate the images requested by the design spec.

        Does nothing if ``design_spec.md`` is missing or lists no image with
        ``Acquire Via: ai`` / ``Acquire Via: web``. At most ``_MAX_IMAGES``
        images are generated; a failing image is logged and skipped so it
        never aborts the pipeline.
        """
        self._update("generating_images", 40, "正在生成/搜索图片...")

        design_spec_path = self.project_path / "design_spec.md"
        if not design_spec_path.exists():
            return

        design_spec = design_spec_path.read_text(encoding="utf-8")

        if "Acquire Via: ai" not in design_spec and "Acquire Via: web" not in design_spec:
            self._update("generating_images", 45, "无需额外图片")
            return

        # Ask the LLM for one "filename|prompt" line per image.
        messages = [
            {
                "role": "system",
                "content": (
                    "根据设计规范中的图片需求,生成图片生成 prompt。\n"
                    "每行一个,格式为: filename|prompt\n"
                    "例如: cover_bg.png|A modern city skyline at sunset, soft gradient sky\n"
                    "只输出需要生成的图片,不要解释。"
                ),
            },
            {
                "role": "user",
                "content": design_spec[:10000],
            },
        ]

        response = llm_client.chat(messages, temperature=0.5, max_tokens=4096)

        images_dir = self.project_path / "images"
        images_dir.mkdir(exist_ok=True)

        fmt = self.task_config.get("format", "ppt169")
        size = self._WANX_SIZES.get(fmt, "1024*576")

        requests = [
            entry.strip() for entry in response.strip().split("\n") if "|" in entry
        ][: self._MAX_IMAGES]
        total = len(requests)

        for index, entry in enumerate(requests, start=1):
            raw_name, _, prompt = entry.partition("|")
            prompt = prompt.strip()
            # The filename comes from the LLM: keep only its final path
            # component so the image cannot be written outside images_dir.
            filename = Path(raw_name.strip().replace("\\", "/")).name

            if filename and filename not in (".", ".."):
                try:
                    self._generate_image(prompt, images_dir, filename, size)
                except Exception:
                    # Image generation failures do not interrupt the pipeline.
                    logger.warning(
                        "Image generation failed for %r (task %s)",
                        filename,
                        self.task_id,
                        exc_info=True,
                    )
            else:
                logger.warning("Skipping image with unusable filename %r", raw_name)

            progress = 40 + int(index / max(total, 1) * 10)
            self._update(
                "generating_images", progress, f"图片生成中 ({index}/{total})"
            )

    def _generate_image(self, prompt: str, output_dir: Path, filename: str, size: str):
        """Generate one image with the backend selected by ``config.IMAGE_BACKEND``.

        ``"wanx"`` uses ``wanx_client``; any other value runs the bundled
        ``image_gen.py`` script. There is no automatic fallback from one
        backend to the other: errors propagate to the caller.
        """
        if config.IMAGE_BACKEND == "wanx":
            wanx_client.generate(
                prompt=prompt,
                output_dir=output_dir,
                filename=filename,
                size=size,
            )
            return

        # NOTE(review): the script path ignores ``size`` and always requests
        # 16:9 — confirm whether non-16:9 canvases need a different ratio.
        self._run_script(
            "image_gen.py",
            prompt,
            "--aspect_ratio",
            "16:9",
            "--image_size",
            "1K",
            "-o",
            str(output_dir),
            "--filename",
            filename,
        )

    # ==================== Step 6: executor phase ====================

    def _step6_executor(self):
        """Generate one SVG per page, then speaker notes, then run the quality check.

        The page count comes from the task config and is overridden by a
        ``page_count: N`` entry in ``spec_lock.md`` when present.
        """
        self._update("generating_svg", 50, "AI 正在逐页生成幻灯片...")

        spec_lock_path = self.project_path / "spec_lock.md"
        design_spec_path = self.project_path / "design_spec.md"

        spec_lock = (
            spec_lock_path.read_text(encoding="utf-8") if spec_lock_path.exists() else ""
        )
        design_spec = (
            design_spec_path.read_text(encoding="utf-8")
            if design_spec_path.exists()
            else ""
        )

        # Executor reference documents.
        style = self.task_config.get("style", "general")
        references_dir = config.SKILL_DIR / "references"
        executor_base = (references_dir / "executor-base.md").read_text(encoding="utf-8")
        shared_standards = (references_dir / "shared-standards.md").read_text(
            encoding="utf-8"
        )
        style_file = self._STYLE_FILES.get(style, "executor-general.md")
        executor_style = (references_dir / style_file).read_text(encoding="utf-8")

        page_count = self.task_config.get("page_count", 10)
        page_match = re.search(r"page_count:\s*(\d+)", spec_lock)
        if page_match:
            page_count = int(page_match.group(1))

        svg_output_dir = self.project_path / "svg_output"
        svg_output_dir.mkdir(exist_ok=True)

        for page_num in range(1, page_count + 1):
            self._generate_svg_page(
                page_num,
                page_count,
                spec_lock,
                design_spec,
                executor_base,
                shared_standards,
                executor_style,
                svg_output_dir,
            )
            progress = 50 + int(page_num / page_count * 30)
            self._update(
                "generating_svg", progress, f"生成第 {page_num}/{page_count} 页"
            )

        self._generate_speaker_notes(design_spec, page_count)

        # A failing quality check does not block the pipeline.
        try:
            self._run_script("svg_quality_checker.py", str(self.project_path))
        except RuntimeError:
            logger.warning(
                "SVG quality check failed for task %s", self.task_id, exc_info=True
            )

    def _generate_svg_page(
        self,
        page_num: int,
        total_pages: int,
        spec_lock: str,
        design_spec: str,
        executor_base: str,
        shared_standards: str,
        executor_style: str,
        svg_output_dir: Path,
    ):
        """Ask the LLM for one page of SVG and write it as ``page_NN.svg``."""
        fmt = self.task_config.get("format", "ppt169")
        width, height = self._CANVAS_SIZES.get(fmt, (1280, 720))

        # NOTE(review): the element names in the three constraint lines
        # (<text>, <foreignObject>, <tspan>, <g>) were reconstructed from a
        # copy of this file that had lost its markup — confirm the wording.
        # NOTE(review): every page receives the same leading 8000 characters
        # of design_spec, although the heading speaks of page-specific content.
        messages = [
            {
                "role": "system",
                "content": (
                    f"你是 PPT Master 的执行器角色。现在需要生成第 {page_num}/{total_pages} 页的 SVG 代码。\n\n"
                    f"## 关键技术约束(摘要)\n"
                    f"- 画布: {width}x{height}\n"
                    f"- viewBox=\"0 0 {width} {height}\"\n"
                    f"- 所有文本必须用 <text> 元素,禁止 <foreignObject>\n"
                    f"- 长文本必须手动分行(多个 <tspan>)\n"
                    f"- 每个顶层元素用 <g> 包裹\n"
                    f"- 颜色、字体必须严格来自 spec_lock\n\n"
                    f"## 执行器规范(摘要)\n{executor_base[:6000]}\n\n"
                    f"## 风格规范(摘要)\n{executor_style[:4000]}\n\n"
                    f"## 技术标准(摘要)\n{shared_standards[:6000]}\n\n"
                    f"直接输出完整的 SVG 代码,不要用 ```svg 包裹,不要加解释。"
                ),
            },
            {
                "role": "user",
                "content": (
                    f"## spec_lock.md\n\n{spec_lock[:12000]}\n\n"
                    f"## design_spec.md(第 {page_num} 页相关内容)\n\n{design_spec[:8000]}\n\n"
                    f"请生成第 {page_num} 页的完整 SVG。"
                ),
            },
        ]

        raw = llm_client.chat(messages, temperature=0.6, max_tokens=16384)
        svg_content = self._clean_svg_response(raw)

        filename = f"page_{page_num:02d}.svg"
        (svg_output_dir / filename).write_text(svg_content, encoding="utf-8")

    @staticmethod
    def _clean_svg_response(raw: str) -> str:
        """Strip a Markdown code fence and any text before the first ``<svg``."""
        svg_content = raw.strip()
        if svg_content.startswith("```"):
            # Drop the opening fence line (``` or ```svg).
            svg_content = "\n".join(svg_content.split("\n")[1:])
        if svg_content.endswith("```"):
            svg_content = svg_content[:-3].strip()

        # Make sure the content starts at the <svg tag.
        svg_start = svg_content.find("<svg")
        if svg_start > 0:
            svg_content = svg_content[svg_start:]
        return svg_content

    def _generate_speaker_notes(self, design_spec: str, page_count: int):
        """Ask the LLM for speaker notes and write them to ``notes/total.md``."""
        notes_dir = self.project_path / "notes"
        notes_dir.mkdir(exist_ok=True)

        messages = [
            {
                "role": "system",
                "content": (
                    "根据设计规范为每一页生成演讲者备注。\n"
                    "格式为每页用 ## Page N 分隔,内容为该页的讲解要点,100-200字。\n"
                    "直接输出,不加额外说明。"
                ),
            },
            {
                "role": "user",
                "content": f"共 {page_count} 页。设计规范:\n\n{design_spec[:15000]}",
            },
        ]

        notes = llm_client.chat(messages, temperature=0.7, max_tokens=8192)
        (notes_dir / "total.md").write_text(notes, encoding="utf-8")

    # ==================== Step 7: post-processing and export ====================

    def _step7_export(self) -> Path:
        """Post-process the SVGs, export the PPTX and copy it to ``config.OUTPUT_DIR``.

        Returns:
            Path of the final ``<task_id>.pptx``.

        Raises:
            RuntimeError: If a mandatory script fails or no PPTX is found.
        """
        self._update("exporting", 85, "正在处理并导出 PPTX...")

        # Step 7.1: split speaker notes (failure does not block the export).
        try:
            self._run_script("total_md_split.py", str(self.project_path))
        except RuntimeError:
            logger.warning(
                "Speaker-note split failed for task %s", self.task_id, exc_info=True
            )

        self._update("exporting", 88, "清理无效图片引用...")
        self._strip_missing_images()

        self._update("exporting", 90, "SVG 后处理中...")

        # Step 7.2: SVG post-processing.
        self._run_script("finalize_svg.py", str(self.project_path))

        # Step 7.2.5: validate / repair SVGs (the LLM may emit malformed XML).
        self._validate_and_repair_svgs()

        self._update("exporting", 95, "导出 PPTX 文件...")

        # Step 7.3: export the PPTX.
        self._run_script("svg_to_pptx.py", str(self.project_path))

        exports_dir = self.project_path / "exports"
        if not exports_dir.exists():
            # Some versions write to the parent-level exports directory.
            exports_dir = config.PROJECTS_DIR / "exports"

        pptx_files = list(exports_dir.glob("*.pptx")) if exports_dir.exists() else []
        if not pptx_files:
            # Last resort: any pptx anywhere under the project directory.
            pptx_files = list(self.project_path.rglob("*.pptx"))

        if not pptx_files:
            raise RuntimeError("PPTX 导出失败:找不到输出文件")

        newest = max(pptx_files, key=lambda f: f.stat().st_mtime)

        output_dir = Path(config.OUTPUT_DIR)
        # The copy below fails if the output directory does not exist yet.
        output_dir.mkdir(parents=True, exist_ok=True)
        final_path = output_dir / f"{self.task_id}.pptx"
        shutil.copy2(newest, final_path)

        return final_path

    def _strip_missing_images(self):
        """Remove ``<image>`` elements whose referenced file does not exist.

        ``data:`` URIs and ``http`` URLs are left alone. Files that cannot
        be parsed or rewritten are logged and skipped.
        """
        svg_dir = self.project_path / "svg_output"
        if not svg_dir.exists():
            return

        xlink_href = "{http://www.w3.org/1999/xlink}href"
        # Keep the default SVG / xlink prefixes when the tree is re-serialised.
        ET.register_namespace("", "http://www.w3.org/2000/svg")
        ET.register_namespace("xlink", "http://www.w3.org/1999/xlink")

        for svg_file in sorted(svg_dir.glob("*.svg")):
            try:
                tree = ET.parse(svg_file)
                changed = False

                # Snapshot the elements first: the tree is mutated below.
                for parent in list(tree.getroot().iter()):
                    stale = []
                    for child in parent:
                        if child.tag.split("}")[-1] != "image":
                            continue
                        href = child.get("href") or child.get(xlink_href) or ""
                        if href.startswith(("data:", "http")):
                            continue
                        if not (svg_dir / href).resolve().exists():
                            stale.append(child)
                    for elem in stale:
                        parent.remove(elem)
                        changed = True

                if changed:
                    # Explicit UTF-8 keeps the file encoding independent of
                    # the platform default.
                    tree.write(svg_file, encoding="utf-8", xml_declaration=True)
            except Exception:
                # Best effort: an unparsable file is handled later by
                # _validate_and_repair_svgs.
                logger.warning(
                    "Could not clean image references in %s",
                    svg_file.name,
                    exc_info=True,
                )

    def _validate_and_repair_svgs(self):
        """Parse every SVG; repair the invalid ones or delete them.

        A file that cannot be repaired is removed from ``svg_output``.
        """
        svg_dir = self.project_path / "svg_output"
        if not svg_dir.exists():
            return

        for svg_file in sorted(svg_dir.glob("*.svg")):
            try:
                ET.parse(svg_file)
            except ET.ParseError:
                if self._try_repair_svg(svg_file):
                    continue
                # Not repairable: remove the file so the export skips the page.
                svg_file.unlink()
                logger.warning("[Pipeline] 移除无效 SVG: %s", svg_file.name)

    def _try_repair_svg(self, svg_file: Path) -> bool:
        """Try to repair an invalid SVG in place.

        Returns:
            ``True`` if a parsable version was written back, else ``False``
            (the file is then left untouched).
        """
        # Undecodable bytes are replaced rather than aborting the repair.
        content = svg_file.read_text(encoding="utf-8", errors="replace")

        # Fix 0: drop a redundant closing tag right after a self-closing tag
        # of the same name, e.g. "<rect .../>\n</rect>" -> "<rect .../>".
        content = re.sub(
            r'<(\w+)\b([^>]*)/>\s*\n\s*</\1>',
            r'<\1\2/>',
            content,
        )

        # Fix 1: make sure the document is closed with </svg>.
        if "</svg>" not in content:
            # Cut after the last complete tag and append </svg>.
            last_close = content.rfind("/>")
            if last_close == -1:
                last_close = content.rfind("</")
            if last_close > 0:
                line_end = content.find("\n", last_close)
                if line_end > 0:
                    content = content[:line_end + 1] + "</svg>\n"
                else:
                    content = content[:last_close + 2] + "\n</svg>\n"

        # Fix 2: discard anything after </svg>.
        svg_end = content.find("</svg>")
        if svg_end > 0:
            content = content[:svg_end + len("</svg>")]

        # Fix 3: accept the result if it parses now.
        if self._is_well_formed(content):
            svg_file.write_text(content, encoding="utf-8")
            return True

        # Fix 4: brute-force truncation after the last occurrence of a
        # closing tag, then close the document.
        # NOTE(review): this tag list was reconstructed from a copy of this
        # file that had lost its markup — confirm the candidates.
        for tag in ("</g>", "</text>", "</rect>", "</defs>"):
            last_pos = content.rfind(tag)
            if last_pos <= 0:
                continue
            candidate = content[:last_pos + len(tag)] + "\n</svg>"
            if self._is_well_formed(candidate):
                svg_file.write_text(candidate, encoding="utf-8")
                return True

        return False

    @staticmethod
    def _is_well_formed(xml_text: str) -> bool:
        """Return whether ``xml_text`` parses as XML."""
        try:
            ET.fromstring(xml_text)
        except ET.ParseError:
            return False
        return True

    def _count_pages(self) -> int:
        """Return the number of ``*.svg`` files in ``svg_output`` (0 if absent)."""
        svg_dir = self.project_path / "svg_output"
        if not svg_dir.exists():
            return 0
        return sum(1 for _ in svg_dir.glob("*.svg"))