Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
# PPT Worker 配置
|
||||
|
||||
# 服务配置
|
||||
WORKER_HOST=0.0.0.0
|
||||
WORKER_PORT=8090
|
||||
WORKER_CONCURRENCY=2
|
||||
|
||||
# 数据库
|
||||
DATABASE_URL=postgres://postgres:postgres@localhost:5432/govai
|
||||
|
||||
# Redis
|
||||
REDIS_URL=redis://localhost:6379/0
|
||||
|
||||
# PPT Master 路径
|
||||
PPT_MASTER_PATH=/Users/freedak/Documents/go-new/ppt-master
|
||||
|
||||
# 文件存储
|
||||
UPLOAD_DIR=/tmp/govai/uploads
|
||||
OUTPUT_DIR=/tmp/govai/outputs
|
||||
PROJECTS_DIR=/tmp/govai/ppt-projects
|
||||
|
||||
# LLM 配置(通义千问,DashScope OpenAI 兼容接口)
|
||||
LLM_PROVIDER=openai
|
||||
OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
|
||||
OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||
OPENAI_MODEL=qwen-max
|
||||
|
||||
# 图片生成 — 通义万相(DashScope 异步 API)
|
||||
# 与 LLM 共用同一个 DashScope API Key,无需额外配置
|
||||
IMAGE_BACKEND=wanx
|
||||
IMAGE_API_KEY=
|
||||
WANX_MODEL=wanx-v1
|
||||
@@ -0,0 +1,18 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 系统依赖
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libpango1.0-dev libcairo2-dev libffi-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8090
|
||||
|
||||
CMD ["python", "app.py"]
|
||||
@@ -0,0 +1,94 @@
|
||||
# PPT Worker 微服务
|
||||
|
||||
PPT Master 管线的 HTTP 服务封装,作为 GovAI 平台的 PPT 生成后端。
|
||||
|
||||
## 架构
|
||||
|
||||
```
|
||||
GovAI 前端 → Go 后端 → Redis 队列 → PPT Worker → PPTX 文件
|
||||
```
|
||||
|
||||
Worker 接收任务后执行完整 PPT Master 管线:
|
||||
1. **源内容转换** — PDF/DOCX/URL/文本 → Markdown
|
||||
2. **项目初始化** — 创建 PPT Master 项目结构
|
||||
3. **策略师阶段** — LLM 生成设计规范 + 执行锁定
|
||||
4. **图片获取** — AI 生图(可选)
|
||||
5. **执行器阶段** — LLM 逐页生成 SVG
|
||||
6. **后处理导出** — SVG 质量检查 → finalize → 导出 PPTX
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 前置条件
|
||||
|
||||
- Python 3.10+
|
||||
- PostgreSQL(GovAI 数据库)
|
||||
- Redis
|
||||
- PPT Master 项目(本地检出目录,路径通过环境变量 `PPT_MASTER_PATH` 配置)
|
||||
|
||||
### 安装
|
||||
|
||||
```bash
|
||||
cd ppt-worker
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 配置
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# 编辑 .env,填入数据库、Redis、LLM API Key 等配置
|
||||
```
|
||||
|
||||
### 启动
|
||||
|
||||
```bash
|
||||
# 同时启动 HTTP API + Worker
|
||||
python app.py
|
||||
|
||||
# 或仅启动 Worker(不含 HTTP API)
|
||||
python worker.py
|
||||
```
|
||||
|
||||
### Docker
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker compose up -d ppt-worker
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| POST | `/api/tasks` | 创建文本/URL 任务 |
|
||||
| POST | `/api/tasks/upload` | 创建带文件上传的任务 |
|
||||
| GET | `/api/tasks/{id}` | 查询任务状态 |
|
||||
| GET | `/api/tasks/{id}/download` | 下载 PPTX 文件 |
|
||||
| GET | `/health` | 健康检查 |
|
||||
|
||||
## 环境变量
|
||||
|
||||
| 变量 | 默认值 | 说明 |
|
||||
|------|--------|------|
|
||||
| `WORKER_PORT` | 8090 | HTTP 服务端口 |
|
||||
| `WORKER_CONCURRENCY` | 2 | 并发处理数 |
|
||||
| `DATABASE_URL` | postgres://postgres:postgres@localhost:5432/govai | PostgreSQL 连接串 |
|
||||
| `REDIS_URL` | redis://localhost:6379/0 | Redis 连接串 |
|
||||
| `PPT_MASTER_PATH` | 开发者本机路径(部署时必须显式设置) | PPT Master 项目路径 |
|
||||
| `OPENAI_API_KEY` | - | DashScope API Key(LLM + 图片共用) |
|
||||
| `OPENAI_BASE_URL` | https://dashscope.aliyuncs.com/compatible-mode/v1 | 千问 OpenAI 兼容端点 |
|
||||
| `OPENAI_MODEL` | qwen-max | 千问模型(推荐 qwen-max / qwen-plus) |
|
||||
| `IMAGE_BACKEND` | wanx | 图片后端(`wanx` 通义万相 / 其他回退脚本) |
|
||||
| `WANX_MODEL` | wanx-v1 | 万相模型名 |
|
||||
| `IMAGE_API_KEY` | (空,复用 OPENAI_API_KEY) | 单独的图片 API Key(可选) |
|
||||
|
||||
### 千问模型推荐
|
||||
|
||||
| 模型 | 上下文 | 适用场景 |
|
||||
|------|--------|---------|
|
||||
| `qwen-max` | 128K | SVG 生成首选,质量最好 |
|
||||
| `qwen-plus` | 128K | 性价比之选 |
|
||||
| `qwen-long` | 10M | 超长文档输入 |
|
||||
| `qwen-turbo` | 128K | 速度优先 |
|
||||
@@ -0,0 +1,184 @@
|
||||
"""PPT Worker HTTP API — FastAPI 服务,供 Go 后端调用"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import shutil
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
import redis
|
||||
import uvicorn
|
||||
|
||||
from config import config
|
||||
from db import get_task, update_task_status
|
||||
from worker import PPTWorker
|
||||
|
||||
# ASGI application object that all routes below are registered on.
app = FastAPI(version="1.0.0", title="PPT Worker API")

# Shared Redis client used for the task queue and the cached task status.
# decode_responses=True makes replies come back as str instead of bytes.
rdb = redis.from_url(config.REDIS_URL, decode_responses=True)
|
||||
|
||||
|
||||
# ==================== 数据模型 ====================
|
||||
|
||||
class CreateTaskRequest(BaseModel):
    """JSON body accepted by ``POST /api/tasks``."""

    # Owner of the task; written to ppt_tasks.user_id.
    user_id: str
    # Presentation title; written to ppt_tasks.title.
    title: str
    # Kind of source material: "text", "file" or "url".
    source_type: str = "text"
    # Raw text or URL, depending on source_type; may be omitted.
    source_content: Optional[str] = None
    # Generation options; serialized to JSON before being stored.
    config: dict = {}
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
    """Response body returned by ``GET /api/tasks/{task_id}``."""

    # Identifier of the task being reported.
    task_id: str
    # Current pipeline status string.
    status: str
    # Progress value reported by the pipeline.
    progress: int
    # Human-readable progress message, if any.
    status_message: Optional[str] = None
    # Error text recorded for the task, if any.
    error_message: Optional[str] = None
    # Path of the generated file, if any.
    output_file: Optional[str] = None
    # Number of generated pages, if known.
    page_count: Optional[int] = None
|
||||
|
||||
|
||||
# ==================== API 路由 ====================
|
||||
|
||||
@app.post("/api/tasks", response_model=dict)
async def create_task(req: CreateTaskRequest):
    """Create a PPT generation task and push it onto the Redis queue.

    A row is inserted into ``ppt_tasks`` and ``{"task_id": ...}`` is pushed
    onto ``config.TASK_QUEUE``. Both operations use blocking clients, so they
    run in the thread pool instead of on the event loop.

    Args:
        req: Validated request body.

    Returns:
        ``{"task_id": <uuid>, "status": "pending"}``.

    Raises:
        HTTPException: 503 when the task was stored but could not be queued;
            the row is marked ``failed`` so it does not stay pending forever.
    """
    import psycopg
    from fastapi.concurrency import run_in_threadpool

    task_id = str(uuid.uuid4())

    def _persist_and_enqueue() -> None:
        with psycopg.connect(config.DATABASE_URL) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """INSERT INTO ppt_tasks (id, user_id, title, source_type, source_content, config)
                       VALUES (%s, %s, %s, %s, %s, %s)""",
                    (
                        task_id,
                        req.user_id,
                        req.title,
                        req.source_type,
                        req.source_content,
                        json.dumps(req.config),
                    ),
                )
            conn.commit()

        try:
            rdb.lpush(config.TASK_QUEUE, json.dumps({"task_id": task_id}))
        except redis.RedisError as exc:
            # No worker will ever pick this row up; record that explicitly.
            update_task_status(
                task_id,
                "failed",
                status_message="任务入队失败",
                error_message=str(exc),
            )
            raise HTTPException(status_code=503, detail="任务队列不可用") from exc

    await run_in_threadpool(_persist_and_enqueue)

    return {"task_id": task_id, "status": "pending"}
|
||||
|
||||
|
||||
@app.post("/api/tasks/upload", response_model=dict)
async def create_task_with_file(
    user_id: str = Form(...),
    title: str = Form(...),
    config_json: str = Form(default="{}"),
    file: UploadFile = File(...),
):
    """Create a PPT generation task from an uploaded source file.

    The upload is streamed to ``config.UPLOAD_DIR`` as ``<task_id><ext>``,
    a ``ppt_tasks`` row with ``source_type='file'`` is inserted, and the task
    id is pushed onto ``config.TASK_QUEUE``.

    Args:
        user_id: Owner of the task.
        title: Presentation title.
        config_json: JSON object with generation options.
        file: Uploaded source document.

    Returns:
        ``{"task_id": <uuid>, "status": "pending"}``.

    Raises:
        HTTPException: 400 when ``config_json`` is not a JSON object; 503 when
            the task was stored but could not be queued.
    """
    import psycopg
    from fastapi.concurrency import run_in_threadpool

    # Validate options before touching the disk so a bad request leaves no file behind.
    try:
        task_config = json.loads(config_json)
    except json.JSONDecodeError as exc:
        raise HTTPException(status_code=400, detail="config_json 不是合法的 JSON") from exc
    if not isinstance(task_config, dict):
        raise HTTPException(status_code=400, detail="config_json 必须是 JSON 对象")

    config.ensure_dirs()

    task_id = str(uuid.uuid4())
    # Only the extension of the client-supplied name is kept.
    file_ext = Path(file.filename).suffix if file.filename else ".bin"
    saved_path = config.UPLOAD_DIR / f"{task_id}{file_ext}"

    def _store_and_enqueue() -> None:
        try:
            # Copy in chunks rather than loading the whole upload into memory.
            with open(saved_path, "wb") as dst:
                shutil.copyfileobj(file.file, dst)

            with psycopg.connect(config.DATABASE_URL) as conn:
                with conn.cursor() as cur:
                    cur.execute(
                        """INSERT INTO ppt_tasks (id, user_id, title, source_type, source_file, config)
                           VALUES (%s, %s, %s, 'file', %s, %s)""",
                        (task_id, user_id, title, str(saved_path), json.dumps(task_config)),
                    )
                conn.commit()
        except Exception:
            # Without a task row nothing references the file; remove it.
            saved_path.unlink(missing_ok=True)
            raise

        try:
            rdb.lpush(config.TASK_QUEUE, json.dumps({"task_id": task_id}))
        except redis.RedisError as exc:
            update_task_status(
                task_id,
                "failed",
                status_message="任务入队失败",
                error_message=str(exc),
            )
            raise HTTPException(status_code=503, detail="任务队列不可用") from exc

    await run_in_threadpool(_store_and_enqueue)

    return {"task_id": task_id, "status": "pending"}
|
||||
|
||||
|
||||
@app.get("/api/tasks/{task_id}", response_model=TaskStatusResponse)
async def get_task_status(task_id: str):
    """Return the current status of a task.

    The Redis hash ``<TASK_STATUS_PREFIX><task_id>`` is consulted first. It
    only carries ``status``, ``progress`` and ``message``, so it is used for
    in-flight tasks only; finished tasks (``completed`` / ``failed``) are read
    from the database so the response includes the output file, page count
    and error message.

    Args:
        task_id: Identifier of the task.

    Returns:
        A populated ``TaskStatusResponse``.

    Raises:
        HTTPException: 404 when the task is in neither Redis nor the database.
    """
    from fastapi.concurrency import run_in_threadpool

    key = f"{config.TASK_STATUS_PREFIX}{task_id}"
    try:
        cached = await run_in_threadpool(rdb.hgetall, key)
    except redis.RedisError:
        # The cache is an optimisation; fall back to the database.
        cached = {}

    if cached and cached.get("status") not in ("completed", "failed"):
        try:
            progress = int(cached.get("progress", 0))
        except (TypeError, ValueError):
            progress = 0
        return TaskStatusResponse(
            task_id=task_id,
            status=cached.get("status", "unknown"),
            progress=progress,
            status_message=cached.get("message"),
        )

    task = await run_in_threadpool(get_task, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    return TaskStatusResponse(
        task_id=task_id,
        status=task["status"],
        # NOTE(review): assumes progress may be NULL before the first update - confirm schema.
        progress=task["progress"] or 0,
        status_message=task.get("status_message"),
        error_message=task.get("error_message"),
        output_file=task.get("output_file"),
        page_count=task.get("page_count"),
    )
|
||||
|
||||
|
||||
@app.get("/api/tasks/{task_id}/download")
async def download_task_output(task_id: str):
    """Send the generated PPTX of a completed task as a file download.

    Args:
        task_id: Identifier of the task.

    Returns:
        A ``FileResponse`` whose download name is derived from the task title.

    Raises:
        HTTPException: 404 when the task or its output file does not exist;
            400 when the task has not completed.
    """
    from fastapi.concurrency import run_in_threadpool

    # get_task uses a blocking DB driver; keep it off the event loop.
    task = await run_in_threadpool(get_task, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    if task["status"] != "completed":
        raise HTTPException(status_code=400, detail="任务未完成")

    output_file = task.get("output_file")
    # is_file() also rejects a directory, which cannot be served.
    if not output_file or not Path(output_file).is_file():
        raise HTTPException(status_code=404, detail="输出文件不存在")

    # The title may be NULL or contain characters that are not valid in a file
    # name (path separators, control characters); normalise before use.
    title = (task.get("title") or "").strip()
    safe_title = "".join(
        "_" if ch in '\\/:*?"<>|' or ord(ch) < 32 else ch for ch in title
    )
    filename = f"{safe_title or 'presentation'}.pptx"

    return FileResponse(
        path=output_file,
        filename=filename,
        media_type="application/vnd.openxmlformats-officedocument.presentationml.presentation",
    )
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Health check: return a fixed payload naming this service."""
    payload = {"status": "ok", "service": "ppt-worker"}
    return payload
|
||||
|
||||
|
||||
# ==================== 启动 ====================
|
||||
|
||||
def start_worker_thread():
    """Thread target: construct a ``PPTWorker`` and call its ``start()``."""
    PPTWorker().start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Create the storage directories before anything tries to write to them.
    config.ensure_dirs()

    # Run the queue worker in a daemon thread so it does not keep the process
    # alive once the HTTP server returns.
    worker_thread = threading.Thread(daemon=True, target=start_worker_thread)
    worker_thread.start()

    # Serve the HTTP API; this call blocks the main thread.
    uvicorn.run(app, port=config.PORT, host=config.HOST)
|
||||
@@ -0,0 +1,55 @@
|
||||
"""PPT Worker 配置模块"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
    """Read integer environment variable *name*.

    An unset or blank value (e.g. ``WORKER_PORT=`` in a ``.env`` file) falls
    back to *default* instead of crashing at import time.

    Raises:
        ValueError: If the value is set but is not an integer; the message
            names the offending variable.
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError as exc:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from exc


class Config:
    """Settings for the PPT Worker, read from the environment at import time."""

    # HTTP service
    HOST: str = os.getenv("WORKER_HOST", "0.0.0.0")
    PORT: int = _env_int("WORKER_PORT", 8090)
    CONCURRENCY: int = _env_int("WORKER_CONCURRENCY", 2)

    # Database
    DATABASE_URL: str = os.getenv("DATABASE_URL", "postgres://postgres:postgres@localhost:5432/govai")

    # Redis connection, queue name and status-key prefix
    REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0")
    TASK_QUEUE: str = "ppt:tasks"
    TASK_STATUS_PREFIX: str = "ppt:status:"

    # PPT Master checkout and the directories derived from it.
    # NOTE(review): the default is a developer-local path; deployments must set PPT_MASTER_PATH.
    PPT_MASTER_PATH: Path = Path(os.getenv("PPT_MASTER_PATH", "/Users/freedak/Documents/go-new/ppt-master"))
    SKILL_DIR: Path = PPT_MASTER_PATH / "skills" / "ppt-master"
    SCRIPTS_DIR: Path = SKILL_DIR / "scripts"

    # File storage
    UPLOAD_DIR: Path = Path(os.getenv("UPLOAD_DIR", "/tmp/govai/uploads"))
    OUTPUT_DIR: Path = Path(os.getenv("OUTPUT_DIR", "/tmp/govai/outputs"))
    PROJECTS_DIR: Path = Path(os.getenv("PROJECTS_DIR", "/tmp/govai/ppt-projects"))

    # LLM (Qwen through the DashScope OpenAI-compatible endpoint)
    LLM_PROVIDER: str = os.getenv("LLM_PROVIDER", "openai")
    OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
    OPENAI_BASE_URL: str = os.getenv("OPENAI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
    OPENAI_MODEL: str = os.getenv("OPENAI_MODEL", "qwen-max")

    # Image generation (Wanx)
    IMAGE_BACKEND: str = os.getenv("IMAGE_BACKEND", "wanx")
    # Intended to fall back to OPENAI_API_KEY when empty; the fallback is not implemented here.
    IMAGE_API_KEY: str = os.getenv("IMAGE_API_KEY", "")
    WANX_MODEL: str = os.getenv("WANX_MODEL", "wanx-v1")
    DASHSCOPE_IMAGE_URL: str = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis"
    DASHSCOPE_TASK_URL: str = "https://dashscope.aliyuncs.com/api/v1/tasks"

    @classmethod
    def ensure_dirs(cls):
        """Create the upload, output and projects directories if missing."""
        for directory in (cls.UPLOAD_DIR, cls.OUTPUT_DIR, cls.PROJECTS_DIR):
            directory.mkdir(parents=True, exist_ok=True)


# Shared settings object imported by the other modules.
config = Config()
|
||||
@@ -0,0 +1,79 @@
|
||||
"""数据库操作模块"""
|
||||
|
||||
import psycopg
|
||||
from datetime import datetime, timezone
|
||||
from config import config
|
||||
|
||||
|
||||
def get_connection():
    """Open and return a new psycopg connection to ``config.DATABASE_URL``."""
    dsn = config.DATABASE_URL
    return psycopg.connect(dsn)
|
||||
|
||||
|
||||
def update_task_status(
    task_id: str,
    status: str,
    progress: int | None = None,
    status_message: str | None = None,
    error_message: str | None = None,
    output_file: str | None = None,
    page_count: int | None = None,
    project_path: str | None = None,
):
    """Update the status of a task row and any optional columns supplied.

    ``status`` and ``updated_at`` are always written; each optional argument
    is written only when it is not ``None``.

    Args:
        task_id: Primary key of the ``ppt_tasks`` row.
        status: New status string.
        progress: Progress value to store.
        status_message: Human-readable progress message.
        error_message: Error text to store.
        output_file: Path of the generated file.
        page_count: Number of generated pages.
        project_path: Path of the PPT Master project directory.
    """
    params = {"task_id": task_id, "status": status}
    assignments = ["status = %(status)s", "updated_at = NOW()"]

    # Column names come from this fixed mapping, never from caller input, so
    # interpolating them into the statement is safe.
    optional_columns = {
        "progress": progress,
        "status_message": status_message,
        "error_message": error_message,
        "output_file": output_file,
        "page_count": page_count,
        "project_path": project_path,
    }
    for column, value in optional_columns.items():
        if value is not None:
            assignments.append(f"{column} = %({column})s")
            params[column] = value

    if status == "processing":
        # The pipeline reports "processing" several times; keep the first
        # timestamp rather than overwriting it on every update.
        assignments.append("started_at = COALESCE(started_at, NOW())")
    elif status in ("completed", "failed"):
        assignments.append("completed_at = NOW()")

    sql = f"UPDATE ppt_tasks SET {', '.join(assignments)} WHERE id = %(task_id)s"

    with get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(sql, params)
        conn.commit()
|
||||
|
||||
|
||||
def get_task(task_id: str) -> dict | None:
    """Load one task row by id.

    Returns:
        A dict mapping the selected column names to their values, or ``None``
        when no row has that id.
    """
    query = (
        "SELECT id, user_id, title, source_type, source_content, source_file, "
        "config, status, progress, status_message, error_message, "
        "output_file, page_count, project_path, created_at "
        "FROM ppt_tasks WHERE id = %s"
    )
    with get_connection() as conn, conn.cursor() as cur:
        cur.execute(query, (task_id,))
        record = cur.fetchone()
        if not record:
            return None
        # Pair each value with its column name from the cursor metadata.
        names = [column[0] for column in cur.description]
        return dict(zip(names, record))
|
||||
@@ -0,0 +1,80 @@
|
||||
"""LLM 客户端 — 用于 PPT Master 管线中的策略师和执行器阶段"""
|
||||
|
||||
import httpx
|
||||
import json
|
||||
from typing import Generator
|
||||
from config import config
|
||||
|
||||
|
||||
class LLMClient:
    """Client for an OpenAI-compatible ``/chat/completions`` endpoint.

    Connection settings come from ``config``. The instance owns an
    ``httpx.Client``; call ``close()`` or use it as a context manager to
    release the connection pool.
    """

    def __init__(self):
        self.api_key = config.OPENAI_API_KEY
        self.base_url = config.OPENAI_BASE_URL.rstrip("/")
        self.model = config.OPENAI_MODEL
        # Generous timeout: a single completion can take minutes.
        self.client = httpx.Client(timeout=300.0)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 16384,
        stream: bool = False,
    ) -> str:
        """Run a chat completion and return the assistant text.

        Args:
            messages: Chat messages in OpenAI format.
            temperature: Sampling temperature.
            max_tokens: Upper bound on generated tokens.
            stream: When true, consume the response as a stream and return
                the concatenated text.

        Returns:
            The assistant message content; an empty string when the API
            returns a null content.

        Raises:
            httpx.HTTPStatusError: If the endpoint answers with an error status.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": stream,
        }

        if stream:
            return self._stream_chat(headers, payload)

        resp = self.client.post(
            f"{self.base_url}/chat/completions",
            headers=headers,
            json=payload,
        )
        resp.raise_for_status()
        data = resp.json()
        # Content can be null; honour the declared ``str`` return type.
        return data["choices"][0]["message"]["content"] or ""

    def _stream_chat(self, headers: dict, payload: dict) -> str:
        """Consume a server-sent-event stream and return the joined content."""
        pieces: list[str] = []
        with self.client.stream(
            "POST",
            f"{self.base_url}/chat/completions",
            headers=headers,
            json=payload,
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                # The space after "data:" is optional in the SSE format.
                if not line.startswith("data:"):
                    continue
                data_str = line[len("data:"):].strip()
                if data_str == "[DONE]":
                    break
                try:
                    chunk = json.loads(data_str)
                    delta = chunk["choices"][0].get("delta") or {}
                    content = delta.get("content")
                except (json.JSONDecodeError, KeyError, IndexError):
                    # Skip chunks that carry no usable choice.
                    continue
                if content:
                    pieces.append(content)
        return "".join(pieces)

    def close(self):
        """Close the underlying HTTP client."""
        self.client.close()
|
||||
|
||||
|
||||
# Module-level client instance; the pipeline module imports it as `llm_client`.
llm_client = LLMClient()
|
||||
@@ -0,0 +1,731 @@
|
||||
"""PPT Master 管线封装 — 核心任务处理器"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import json
|
||||
import shutil
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from config import config
|
||||
from llm_client import llm_client
|
||||
from wanx_client import wanx_client
|
||||
from db import update_task_status
|
||||
|
||||
|
||||
class PPTPipeline:
|
||||
"""封装 PPT Master 完整管线"""
|
||||
|
||||
def __init__(self, task_id: str, task: dict, redis_callback=None):
|
||||
self.task_id = task_id
|
||||
self.task = task
|
||||
self.task_config = task.get("config", {}) if isinstance(task.get("config"), dict) else json.loads(task.get("config", "{}"))
|
||||
self.project_path: Path | None = None
|
||||
self.project_name: str = ""
|
||||
self._redis_callback = redis_callback
|
||||
|
||||
def run(self):
|
||||
"""执行完整管线"""
|
||||
try:
|
||||
self._update("processing", 5, "开始处理任务...")
|
||||
|
||||
# Step 1: 源内容处理
|
||||
source_md = self._step1_process_source()
|
||||
|
||||
# Step 2: 项目初始化
|
||||
self._step2_init_project()
|
||||
|
||||
# Step 3: 跳过模板(使用自由设计)
|
||||
|
||||
# Step 4: 策略师阶段
|
||||
self._step4_strategist(source_md)
|
||||
|
||||
# Step 5: 图片获取(如配置)
|
||||
if self.task_config.get("with_images", True):
|
||||
self._step5_images()
|
||||
|
||||
# Step 6: 执行器阶段(SVG 生成)
|
||||
self._step6_executor()
|
||||
|
||||
# Step 7: 后处理与导出
|
||||
output_path = self._step7_export()
|
||||
|
||||
# 完成
|
||||
self._update(
|
||||
"completed", 100, "PPT 生成完成",
|
||||
output_file=str(output_path),
|
||||
page_count=self._count_pages(),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._update("failed", progress=0, status_message="生成失败", error_message=str(e))
|
||||
raise
|
||||
|
||||
def _update(self, status: str, progress: int = None, status_message: str = None, **kwargs):
    """Persist a status change and forward it to the Redis callback.

    The database is always updated. The callback is invoked only when one
    was supplied and ``progress`` is not ``None``.
    """
    update_task_status(
        self.task_id,
        status=status,
        progress=progress,
        status_message=status_message,
        **kwargs,
    )

    notify = self._redis_callback
    if not notify or progress is None:
        return
    # The callback always receives a string message, never None.
    notify(self.task_id, status, progress, status_message or "")
|
||||
|
||||
def _run_script(self, script_name: str, *args, cwd: Path = None) -> str:
    """Run a script from ``config.SCRIPTS_DIR`` and return its stdout.

    Args:
        script_name: Path of the script relative to the scripts directory.
        *args: Command-line arguments; each is converted with ``str``.
        cwd: Working directory; defaults to ``config.PPT_MASTER_PATH``.

    Raises:
        RuntimeError: If the script exits with a non-zero status; the message
            carries the last 2000 characters of stdout and stderr.
    """
    command = [sys.executable, str(config.SCRIPTS_DIR / script_name)]
    command.extend(str(arg) for arg in args)
    workdir = cwd or config.PPT_MASTER_PATH

    completed = subprocess.run(
        command,
        capture_output=True,
        text=True,
        cwd=str(workdir),
        env=self._build_env(),
        timeout=600,
    )

    if completed.returncode == 0:
        return completed.stdout

    raise RuntimeError(
        f"脚本 {script_name} 执行失败:\n"
        f"stdout: {completed.stdout[-2000:]}\n"
        f"stderr: {completed.stderr[-2000:]}"
    )
|
||||
|
||||
def _build_env(self) -> dict:
    """Return a copy of the process environment with LLM/image settings applied."""
    import os

    env = dict(os.environ)
    env.update(
        {
            "OPENAI_API_KEY": config.OPENAI_API_KEY,
            "OPENAI_BASE_URL": config.OPENAI_BASE_URL,
            "OPENAI_MODEL": config.OPENAI_MODEL,
            "IMAGE_BACKEND": config.IMAGE_BACKEND,
        }
    )
    # Export the image key only when one is configured.
    if config.IMAGE_API_KEY:
        env["IMAGE_API_KEY"] = config.IMAGE_API_KEY
    return env
|
||||
|
||||
# ==================== Step 1: 源内容处理 ====================
|
||||
|
||||
def _step1_process_source(self) -> str:
|
||||
"""处理源文件/内容,返回 Markdown 文本"""
|
||||
self._update("converting", 10, "正在转换源文件...")
|
||||
|
||||
source_type = self.task.get("source_type", "text")
|
||||
|
||||
if source_type == "text":
|
||||
return self.task.get("source_content", "")
|
||||
|
||||
elif source_type == "url":
|
||||
url = self.task.get("source_content", "")
|
||||
output = self._run_script("source_to_md/web_to_md.py", url)
|
||||
# web_to_md.py 输出转换后的 md 文件路径
|
||||
md_path = self._extract_output_path(output)
|
||||
if md_path and Path(md_path).exists():
|
||||
return Path(md_path).read_text(encoding="utf-8")
|
||||
return output
|
||||
|
||||
elif source_type == "file":
|
||||
file_path = self.task.get("source_file", "")
|
||||
if not file_path or not Path(file_path).exists():
|
||||
raise FileNotFoundError(f"源文件不存在: {file_path}")
|
||||
|
||||
suffix = Path(file_path).suffix.lower()
|
||||
if suffix == ".pdf":
|
||||
output = self._run_script("source_to_md/pdf_to_md.py", file_path)
|
||||
elif suffix in (".docx", ".doc", ".html", ".epub"):
|
||||
output = self._run_script("source_to_md/doc_to_md.py", file_path)
|
||||
elif suffix in (".xlsx", ".xlsm"):
|
||||
output = self._run_script("source_to_md/excel_to_md.py", file_path)
|
||||
elif suffix in (".pptx",):
|
||||
output = self._run_script("source_to_md/ppt_to_md.py", file_path)
|
||||
elif suffix == ".md":
|
||||
return Path(file_path).read_text(encoding="utf-8")
|
||||
else:
|
||||
output = self._run_script("source_to_md/doc_to_md.py", file_path)
|
||||
|
||||
md_path = self._extract_output_path(output)
|
||||
if md_path and Path(md_path).exists():
|
||||
return Path(md_path).read_text(encoding="utf-8")
|
||||
return output
|
||||
|
||||
return ""
|
||||
|
||||
def _extract_output_path(self, script_output: str) -> str | None:
|
||||
"""从脚本输出中提取生成的文件路径"""
|
||||
for line in script_output.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if line.endswith(".md") and Path(line).exists():
|
||||
return line
|
||||
# 匹配 "Output: /path/to/file.md" 格式
|
||||
match = re.search(r"(?:Output|Saved|Written):\s*(.+\.md)", line, re.IGNORECASE)
|
||||
if match:
|
||||
path = match.group(1).strip()
|
||||
if Path(path).exists():
|
||||
return path
|
||||
return None
|
||||
|
||||
def _resolve_project_path(self, script_output: str) -> Path:
    """Locate the project directory created by ``project_manager.py``.

    Three strategies are tried in order:

    1. a path printed after ``Project created:`` or
       ``[OK] Project initialized:``, resolved against ``config.PROJECTS_DIR``;
    2. a directory named exactly ``self.project_name`` under
       ``PROJECTS_DIR/projects`` or ``PROJECTS_DIR``;
    3. the most recently modified entry whose name starts with
       ``self.project_name`` in those same directories.

    Raises:
        RuntimeError: If no strategy finds an existing directory.
    """
    markers = ("Project created:", "[OK] Project initialized:")
    search_roots = (config.PROJECTS_DIR / "projects", config.PROJECTS_DIR)

    # Strategy 1: path announced in the script output.
    for line in script_output.strip().split("\n"):
        for marker in markers:
            if marker not in line:
                continue
            relative = line.partition(marker)[2].strip()
            announced = config.PROJECTS_DIR / relative
            if announced.exists():
                return announced

    # Strategy 2: exact directory name.
    for root in search_roots:
        exact = root / self.project_name
        if exact.exists():
            return exact

    # Strategy 3: prefix match; the newest entry wins.
    for root in search_roots:
        if not root.exists():
            continue
        newest = max(
            root.glob(f"{self.project_name}*"),
            key=lambda p: p.stat().st_mtime,
            default=None,
        )
        if newest is not None:
            return newest

    raise RuntimeError(f"项目初始化失败,找不到项目目录: {script_output}")
|
||||
|
||||
# ==================== Step 2: 项目初始化 ====================
|
||||
|
||||
def _step2_init_project(self):
    """Create the PPT Master project for this task.

    Builds a project name from the sanitised title and the first eight
    characters of the task id, runs ``project_manager.py init`` inside
    ``config.PROJECTS_DIR``, resolves the directory that was created, and
    stores its path on the task row.
    """
    self._update("processing", 15, "初始化项目...")

    # The title may be missing or None; fall back so re.sub gets a string.
    title = str(self.task.get("title") or "untitled")
    # Keep word characters, CJK and hyphens; cap the length at 50.
    safe_name = re.sub(r'[^\w\u4e00-\u9fff-]', '_', title)[:50]
    self.project_name = f"govai_{safe_name}_{self.task_id[:8]}"

    fmt = self.task_config.get("format", "ppt169")

    output = self._run_script(
        "project_manager.py", "init", self.project_name,
        "--format", fmt,
        cwd=config.PROJECTS_DIR,
    )

    # The script may rename the directory, so resolve what actually exists.
    self.project_path = self._resolve_project_path(output)

    update_task_status(self.task_id, status="processing", project_path=str(self.project_path))
|
||||
|
||||
# ==================== Step 4: 策略师阶段 ====================
|
||||
|
||||
def _step4_strategist(self, source_md: str):
    """Strategist stage: have the LLM write design_spec.md and spec_lock.md.

    Reference documents are read from ``config.SKILL_DIR`` and embedded
    (truncated) in the system prompt. The reply is split on the
    ``===..._START===`` / ``===..._END===`` markers. If the design spec
    cannot be extracted the whole reply is used; if the spec lock cannot be
    extracted it is generated by a second LLM call. Both files are written
    to the project directory.

    Args:
        source_md: Source content as Markdown; the first 50000 characters
            are sent to the model.
    """
    self._update("designing", 20, "AI 正在分析内容并设计方案...")

    # Reference material for the strategist role.
    references_dir = config.SKILL_DIR / "references"
    templates_dir = config.SKILL_DIR / "templates"
    strategist_ref = (references_dir / "strategist.md").read_text(encoding="utf-8")
    design_spec_ref = (templates_dir / "design_spec_reference.md").read_text(encoding="utf-8")
    spec_lock_ref = (templates_dir / "spec_lock_reference.md").read_text(encoding="utf-8")

    options = self.task_config
    page_count = options.get("page_count", 10)
    style = options.get("style", "general")
    fmt = options.get("format", "ppt169")
    language = options.get("language", "zh")

    system_prompt = (
        "你是 PPT Master 的策略师角色。你的任务是根据用户提供的源内容,"
        "生成完整的 design_spec.md(设计规范)和 spec_lock.md(执行锁定)。\n\n"
        "你必须严格遵循以下参考模板的结构来生成输出:\n\n"
        "## 策略师角色定义(摘要)\n"
        f"{strategist_ref[:8000]}\n\n"
        "## design_spec.md 参考模板\n"
        f"{design_spec_ref[:10000]}\n\n"
        "## spec_lock.md 参考模板\n"
        f"{spec_lock_ref[:8000]}\n\n"
        "请直接输出两个文件的完整内容,用以下分隔符分开:\n"
        "===DESIGN_SPEC_START===\n[design_spec.md 内容]\n===DESIGN_SPEC_END===\n"
        "===SPEC_LOCK_START===\n[spec_lock.md 内容]\n===SPEC_LOCK_END===\n"
    )
    user_prompt = (
        f"请根据以下源内容生成演示文稿的设计规范。\n\n"
        f"## 生成要求\n"
        f"- 画布格式: {fmt}\n"
        f"- 目标页数: {page_count} 页\n"
        f"- 设计风格: {style}\n"
        f"- 语言: {language}\n\n"
        f"## 源内容\n\n{source_md[:50000]}"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    response = llm_client.chat(messages, temperature=0.7, max_tokens=16384)

    # Fall back to the whole reply when the design-spec markers are missing.
    design_spec = (
        self._extract_section(response, "DESIGN_SPEC_START", "DESIGN_SPEC_END")
        or response
    )
    # Ask the model again when the spec lock could not be separated out.
    spec_lock = (
        self._extract_section(response, "SPEC_LOCK_START", "SPEC_LOCK_END")
        or self._generate_spec_lock(design_spec, source_md)
    )

    (self.project_path / "design_spec.md").write_text(design_spec, encoding="utf-8")
    (self.project_path / "spec_lock.md").write_text(spec_lock, encoding="utf-8")

    self._update("designing", 35, "设计规范生成完成")
|
||||
|
||||
def _generate_spec_lock(self, design_spec: str, source_md: str) -> str:
    """Generate spec_lock.md from an existing design spec with one LLM call.

    Args:
        design_spec: Generated design spec; the first 15000 characters are sent.
        source_md: Accepted for signature compatibility; not used here.

    Returns:
        The model's reply, used verbatim as the spec_lock.md content.
    """
    template = (config.SKILL_DIR / "templates" / "spec_lock_reference.md").read_text(encoding="utf-8")

    system_prompt = (
        "根据已生成的 design_spec.md,生成对应的 spec_lock.md(执行锁定文件)。\n"
        "严格遵循以下参考模板结构:\n\n"
        f"{template[:8000]}\n\n"
        "直接输出 spec_lock.md 的完整内容,不要加任何额外说明。"
    )
    user_prompt = f"## design_spec.md\n\n{design_spec[:15000]}"

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    return llm_client.chat(conversation, temperature=0.5, max_tokens=8192)
|
||||
|
||||
def _extract_section(self, text: str, start_marker: str, end_marker: str) -> str:
|
||||
"""从文本中提取指定标记之间的内容"""
|
||||
start = text.find(f"==={start_marker}===")
|
||||
end = text.find(f"==={end_marker}===")
|
||||
if start == -1 or end == -1:
|
||||
return ""
|
||||
start += len(f"==={start_marker}===")
|
||||
return text[start:end].strip()
|
||||
|
||||
# ==================== Step 5: 图片获取 ====================
|
||||
|
||||
def _step5_images(self):
    """Image stage: ask the LLM for image prompts and generate the images.

    Skips silently when the project has no design_spec.md, and reports
    "no images needed" when the spec contains neither ``Acquire Via: ai``
    nor ``Acquire Via: web``. Otherwise the LLM is asked for
    ``filename|prompt`` lines, of which at most 10 are processed. A failure
    on one image is logged and does not abort the pipeline.
    """
    import logging

    log = logging.getLogger(__name__)

    self._update("generating_images", 40, "正在生成/搜索图片...")

    design_spec_path = self.project_path / "design_spec.md"
    if not design_spec_path.exists():
        return

    design_spec = design_spec_path.read_text(encoding="utf-8")

    # Nothing to do unless the spec asks for acquired images.
    if "Acquire Via: ai" not in design_spec and "Acquire Via: web" not in design_spec:
        self._update("generating_images", 45, "无需额外图片")
        return

    messages = [
        {
            "role": "system",
            "content": (
                "根据设计规范中的图片需求,生成图片生成 prompt。\n"
                "每行一个,格式为: filename|prompt\n"
                "例如: cover_bg.png|A modern city skyline at sunset, soft gradient sky\n"
                "只输出需要生成的图片,不要解释。"
            ),
        },
        {
            "role": "user",
            "content": design_spec[:10000],
        },
    ]

    response = llm_client.chat(messages, temperature=0.5, max_tokens=4096)

    images_dir = self.project_path / "images"
    images_dir.mkdir(exist_ok=True)

    # Image size depends on the canvas format.
    fmt = self.task_config.get("format", "ppt169")
    wanx_size_map = {
        "ppt169": "1024*576",
        "ppt43": "1024*768",
        "xhs": "720*1280",
        "story": "720*1280",
    }
    size = wanx_size_map.get(fmt, "1024*576")

    # Cap the work list first so progress is measured against what is
    # actually processed (at most 10 images).
    requests = [l.strip() for l in response.strip().split("\n") if "|" in l][:10]
    total = len(requests)

    for index, line in enumerate(requests, start=1):
        raw_name, _, prompt = line.partition("|")
        # The name comes from the LLM: keep only its final path component so
        # it cannot escape the images directory.
        filename = Path(raw_name.strip()).name
        prompt = prompt.strip()

        if filename and prompt:
            try:
                self._generate_image(prompt, images_dir, filename, size)
            except Exception:
                # Best effort: a failed image must not abort the pipeline,
                # but it should be visible in the logs.
                log.warning("image generation failed for %s", filename, exc_info=True)
        else:
            log.warning("skipping malformed image request: %r", line)

        progress = 40 + int(index / total * 10)
        self._update("generating_images", progress, f"图片生成中 ({index}/{total})")
|
||||
|
||||
def _generate_image(self, prompt: str, output_dir: Path, filename: str, size: str):
    """Generate one image with the configured backend.

    With ``config.IMAGE_BACKEND == "wanx"`` the Wanx client is used and
    ``size`` is passed through. Any other backend value runs PPT Master's
    ``image_gen.py`` script with a fixed 16:9 aspect ratio and "1K" size;
    ``size`` is not used on that path. The backend is chosen by
    configuration only - there is no automatic fallback when a call fails.

    Args:
        prompt: Text prompt for the image.
        output_dir: Directory the image is written to.
        filename: Output file name.
        size: Wanx size string such as ``"1024*576"``.
    """
    if config.IMAGE_BACKEND != "wanx":
        # Script backend bundled with PPT Master.
        self._run_script(
            "image_gen.py", prompt,
            "--aspect_ratio", "16:9",
            "--image_size", "1K",
            "-o", str(output_dir),
            "--filename", filename,
        )
        return

    wanx_client.generate(
        prompt=prompt,
        output_dir=output_dir,
        filename=filename,
        size=size,
    )
|
||||
|
||||
# ==================== Step 6: 执行器阶段 ====================
|
||||
|
||||
def _step6_executor(self):
    """Executor stage: generate one SVG per page, then notes and a quality check.

    Reads ``spec_lock.md`` and ``design_spec.md`` from the project directory
    (a missing file is treated as empty text), loads the executor reference
    documents from ``config.SKILL_DIR / "references"``, generates every page
    into ``svg_output/``, generates speaker notes, and finally runs
    ``svg_quality_checker.py``. A ``RuntimeError`` from the quality checker is
    ignored so that it never blocks the pipeline.
    """
    self._update("generating_svg", 50, "AI 正在逐页生成幻灯片...")

    def _read_if_present(path):
        # Missing spec files are tolerated and read as an empty string.
        return path.read_text(encoding="utf-8") if path.exists() else ""

    spec_lock = _read_if_present(self.project_path / "spec_lock.md")
    design_spec = _read_if_present(self.project_path / "design_spec.md")

    # Executor reference documents; unlike the spec files these must exist.
    requested_style = self.task_config.get("style", "general")
    references_dir = config.SKILL_DIR / "references"
    executor_base = (references_dir / "executor-base.md").read_text(encoding="utf-8")
    shared_standards = (references_dir / "shared-standards.md").read_text(encoding="utf-8")

    style_documents = {
        "general": "executor-general.md",
        "consultant": "executor-consultant.md",
        "consultant-top": "executor-consultant-top.md",
    }
    # Unknown styles fall back to the general executor document.
    style_document = style_documents.get(requested_style, "executor-general.md")
    executor_style = (references_dir / style_document).read_text(encoding="utf-8")

    # A page count declared in spec_lock overrides the one in the task config.
    declared = re.search(r"page_count:\s*(\d+)", spec_lock)
    if declared:
        page_count = int(declared.group(1))
    else:
        page_count = self.task_config.get("page_count", 10)

    svg_output_dir = self.project_path / "svg_output"
    svg_output_dir.mkdir(exist_ok=True)

    # Pages are generated sequentially; progress moves from 50 to 80.
    for page_num in range(1, page_count + 1):
        self._generate_svg_page(
            page_num,
            page_count,
            spec_lock,
            design_spec,
            executor_base,
            shared_standards,
            executor_style,
            svg_output_dir,
        )
        progress = 50 + int(page_num / page_count * 30)
        self._update("generating_svg", progress, f"生成第 {page_num}/{page_count} 页")

    self._generate_speaker_notes(design_spec, page_count)

    try:
        self._run_script("svg_quality_checker.py", str(self.project_path))
    except RuntimeError:
        # A failing quality check must not stop the pipeline.
        pass
|
||||
|
||||
def _generate_svg_page(
    self, page_num: int, total_pages: int,
    spec_lock: str, design_spec: str,
    executor_base: str, shared_standards: str, executor_style: str,
    svg_output_dir: Path,
):
    """Ask the LLM for one page of SVG and write it to ``page_NN.svg``.

    Args:
        page_num: 1-based number of the page to generate.
        total_pages: Total number of pages, used only in the prompt.
        spec_lock: Text of ``spec_lock.md`` (first 12000 characters are sent).
        design_spec: Text of ``design_spec.md``. Only the first 8000
            characters are sent, for every page alike.
        executor_base: Executor reference text (first 6000 characters are sent).
        shared_standards: Technical standards text (first 6000 characters).
        executor_style: Style reference text (first 4000 characters).
        svg_output_dir: Directory that receives ``page_NN.svg``.

    The reply is cleaned before writing: a surrounding Markdown code fence is
    removed, anything before the first ``<svg`` is dropped, and anything after
    the last ``</svg>`` is dropped.
    """
    # Canvas size depends on the requested output format.
    fmt = self.task_config.get("format", "ppt169")
    canvas_map = {
        "ppt169": (1280, 720),
        "ppt43": (1024, 768),
        "xhs": (1080, 1440),
        "story": (1080, 1920),
    }
    width, height = canvas_map.get(fmt, (1280, 720))

    # Per-page prompt.
    messages = [
        {
            "role": "system",
            "content": (
                f"你是 PPT Master 的执行器角色。现在需要生成第 {page_num}/{total_pages} 页的 SVG 代码。\n\n"
                f"## 关键技术约束(摘要)\n"
                f"- 画布: {width}x{height}\n"
                f"- viewBox=\"0 0 {width} {height}\"\n"
                f"- 所有文本必须用 <text> 元素,禁止 <foreignObject>\n"
                f"- 长文本必须手动分行(多个 <text> 或 <tspan dy=\"...\")\n"
                f"- 图片使用 <image href=\"../images/filename.ext\">\n"
                f"- 每个顶层元素用 <g id=\"elem_N\"> 包裹\n"
                f"- 颜色、字体必须严格来自 spec_lock\n\n"
                f"## 执行器规范(摘要)\n{executor_base[:6000]}\n\n"
                f"## 风格规范(摘要)\n{executor_style[:4000]}\n\n"
                f"## 技术标准(摘要)\n{shared_standards[:6000]}\n\n"
                f"直接输出完整的 SVG 代码,不要用 ```svg 包裹,不要加解释。"
            ),
        },
        {
            "role": "user",
            "content": (
                f"## spec_lock.md\n\n{spec_lock[:12000]}\n\n"
                f"## design_spec.md(第 {page_num} 页相关内容)\n\n{design_spec[:8000]}\n\n"
                f"请生成第 {page_num} 页的完整 SVG。"
            ),
        },
    ]

    svg_content = llm_client.chat(messages, temperature=0.6, max_tokens=16384)

    # Remove a Markdown code fence if the model added one despite the prompt.
    svg_content = svg_content.strip()
    if svg_content.startswith("```"):
        lines = svg_content.split("\n")
        svg_content = "\n".join(lines[1:])
    if svg_content.endswith("```"):
        svg_content = svg_content[:-3].strip()

    # Drop any preamble so the document starts at the <svg root element.
    svg_start = svg_content.find("<svg")
    if svg_start > 0:
        svg_content = svg_content[svg_start:]

    # Drop anything after the final </svg> (a closing fence followed by an
    # explanation, for instance). Text after the root element makes the file
    # invalid XML, and the image-reference cleanup skips files it cannot parse.
    svg_end = svg_content.rfind("</svg>")
    if svg_end != -1:
        svg_content = svg_content[:svg_end + len("</svg>")]

    filename = f"page_{page_num:02d}.svg"
    (svg_output_dir / filename).write_text(svg_content, encoding="utf-8")
|
||||
|
||||
def _generate_speaker_notes(self, design_spec: str, page_count: int):
    """Generate speaker notes for the whole deck and save them as ``notes/total.md``.

    Args:
        design_spec: Design specification text; only the first 15000
            characters are sent to the model.
        page_count: Number of pages, stated to the model in the user message.
    """
    notes_dir = self.project_path / "notes"
    notes_dir.mkdir(exist_ok=True)

    # The system prompt asks for one "## Page N" section per page.
    system_prompt = (
        "根据设计规范为每一页生成演讲者备注。\n"
        "格式为每页用 ## Page N 分隔,内容为该页的讲解要点,100-200字。\n"
        "直接输出,不加额外说明。"
    )
    user_prompt = f"共 {page_count} 页。设计规范:\n\n{design_spec[:15000]}"

    notes = llm_client.chat(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.7,
        max_tokens=8192,
    )
    (notes_dir / "total.md").write_text(notes, encoding="utf-8")
|
||||
|
||||
# ==================== Step 7: 后处理与导出 ====================
|
||||
|
||||
def _step7_export(self) -> Path:
    """Post-process the generated SVGs and export the deck as a PPTX file.

    Steps, in order: split the speaker notes (a ``RuntimeError`` is ignored),
    strip ``<image>`` elements that point at missing files, run
    ``finalize_svg.py``, validate and repair the SVGs, run ``svg_to_pptx.py``,
    then locate the newest ``.pptx`` and copy it to
    ``config.OUTPUT_DIR / "<task_id>.pptx"``.

    Returns:
        Path of the copied PPTX in the shared output directory.

    Raises:
        RuntimeError: If no ``.pptx`` file can be found after the export, or
            if ``finalize_svg.py`` / ``svg_to_pptx.py`` raises one.
    """
    project_dir = str(self.project_path)

    self._update("exporting", 85, "正在处理并导出 PPTX...")
    try:
        self._run_script("total_md_split.py", project_dir)
    except RuntimeError:
        # Failing to split the notes must not block the export.
        pass

    self._update("exporting", 88, "清理无效图片引用...")
    self._strip_missing_images()

    self._update("exporting", 90, "SVG 后处理中...")
    self._run_script("finalize_svg.py", project_dir)

    # The LLM may have produced malformed XML; repair or drop those pages.
    self._validate_and_repair_svgs()

    self._update("exporting", 95, "导出 PPTX 文件...")
    self._run_script("svg_to_pptx.py", project_dir)

    # Prefer the project's own exports directory; otherwise look in the
    # exports directory next to the projects.
    local_exports = self.project_path / "exports"
    search_dir = local_exports if local_exports.exists() else config.PROJECTS_DIR / "exports"
    candidates = list(search_dir.glob("*.pptx")) if search_dir.exists() else []

    # Last resort: any .pptx anywhere below the project directory.
    if not candidates:
        candidates = list(self.project_path.rglob("*.pptx"))
    if not candidates:
        raise RuntimeError("PPTX 导出失败:找不到输出文件")

    def _modified_at(path):
        return path.stat().st_mtime

    # The most recently modified file is taken as this run's output.
    newest = max(candidates, key=_modified_at)

    final_path = config.OUTPUT_DIR / f"{self.task_id}.pptx"
    shutil.copy2(newest, final_path)
    return final_path
|
||||
|
||||
def _strip_missing_images(self):
    """Remove ``<image>`` elements whose referenced local file does not exist.

    Every ``*.svg`` in ``<project>/svg_output`` is parsed. An ``<image>`` is
    removed when its ``href`` (or ``xlink:href``) is a path that, resolved
    relative to ``svg_output``, does not exist. ``data:`` URIs and
    ``http://`` / ``https://`` URLs are left untouched. A file is rewritten
    only if something was removed from it.

    This is best effort: a file that cannot be parsed or processed is reported
    on stdout and skipped, and the method itself never raises for it.
    """
    svg_dir = self.project_path / "svg_output"
    if not svg_dir.exists():
        return

    import xml.etree.ElementTree as ET

    svg_ns = "http://www.w3.org/2000/svg"
    xlink_ns = "http://www.w3.org/1999/xlink"
    # Serialize SVG as the default namespace instead of generated ns0: prefixes.
    ET.register_namespace("", svg_ns)
    ET.register_namespace("xlink", xlink_ns)
    xlink_href = f"{{{xlink_ns}}}href"

    # Only real URL schemes are exempt; a bare "http" prefix would also match
    # local files such as "http_banner.png".
    external_prefixes = ("data:", "http://", "https://")

    for svg_file in sorted(svg_dir.glob("*.svg")):
        try:
            tree = ET.parse(svg_file)
            changed = False

            # Snapshot the elements first so removals cannot disturb traversal.
            for parent in list(tree.getroot().iter()):
                for child in list(parent):
                    # Compare the local tag name, ignoring any "{namespace}" prefix.
                    if child.tag.rsplit("}", 1)[-1] != "image":
                        continue

                    href = child.get("href") or child.get(xlink_href) or ""
                    if href.startswith(external_prefixes):
                        continue

                    if not (svg_dir / href).resolve().exists():
                        parent.remove(child)
                        changed = True

            if changed:
                # Write explicit UTF-8 so the bytes match how the rest of the
                # pipeline reads these files (read_text(encoding="utf-8")).
                tree.write(svg_file, encoding="utf-8", xml_declaration=True)
        except Exception as exc:
            # Unparseable files are handled later by the validate/repair step.
            print(f"[Pipeline] 跳过图片引用清理: {svg_file.name} ({exc})")
|
||||
|
||||
def _validate_and_repair_svgs(self):
    """Check every generated SVG for well-formed XML; repair or delete bad ones.

    Each ``*.svg`` in ``<project>/svg_output`` is parsed. A file that fails
    with ``ParseError`` is handed to ``_try_repair_svg``; if that returns a
    falsy value the file is deleted and a message is printed. Nothing happens
    when the directory does not exist.
    """
    import xml.etree.ElementTree as ET

    output_dir = self.project_path / "svg_output"
    if output_dir.exists():
        for page in sorted(output_dir.glob("*.svg")):
            try:
                ET.parse(page)
            except ET.ParseError:
                if self._try_repair_svg(page):
                    continue
                # Beyond repair: delete the page so it is not exported.
                page.unlink()
                print(f"[Pipeline] 移除无效 SVG: {page.name}")
|
||||
|
||||
def _try_repair_svg(self, svg_file: Path) -> bool:
    """Try to turn a malformed SVG file into well-formed XML, in place.

    Repairs are applied in this order:

    0. Remove a closing tag that follows a self-closing tag of the same name
       on a later line (``<x .../>`` newline ``</x>``).
    1. If there is no ``</svg>``, cut after the line holding the last closing
       construct and append ``</svg>``.
    2. Drop everything after the first ``</svg>``.
    3. If the result still does not parse, also try a copy in which every
       ``&`` that does not start an entity or character reference is escaped
       as ``&amp;``.
    4. As a last resort, truncate after the last ``</g>``, ``</text>``,
       ``</rect>`` or ``</circle>`` and close the document with ``</svg>``.

    The file is rewritten only when a candidate parses.

    Args:
        svg_file: Path of the SVG file to repair; read and written as UTF-8.

    Returns:
        True if a parseable version was written back, False otherwise (the
        file is then left unchanged).
    """
    import re
    import xml.etree.ElementTree as ET

    content = svg_file.read_text(encoding="utf-8")

    # Repair 0: "<feDropShadow .../>\n</feDropShadow>" -> "<feDropShadow .../>"
    content = re.sub(
        r'<(\w+)\b([^>]*)/>\s*\n\s*</\1>',
        r'<\1\2/>',
        content,
    )

    # Repair 1: make sure the document is closed with </svg>.
    if "</svg>" not in content:
        # Locate the last self-closing tag, or failing that the last "</".
        last_close = content.rfind("/>")
        if last_close == -1:
            last_close = content.rfind("</")
        if last_close > 0:
            # Keep through the end of that line and discard what follows.
            line_end = content.find("\n", last_close)
            if line_end > 0:
                content = content[:line_end + 1] + "</svg>\n"
            else:
                content = content[:last_close + 2] + "\n</svg>\n"

    # Repair 2: drop anything after the closing </svg>.
    svg_end = content.find("</svg>")
    if svg_end > 0:
        content = content[:svg_end + len("</svg>")]

    # Repair 3: a bare "&" (as in "R&D") is never valid XML. The escaped copy
    # is only a second choice, so input that parses without it is written
    # unchanged. NOTE: a bare "&" inside a CDATA section is escaped too.
    candidates = [content]
    escaped = re.sub(
        r'&(?!(?:[A-Za-z_][\w.\-]*|#[0-9]+|#x[0-9A-Fa-f]+);)',
        '&amp;',
        content,
    )
    if escaped != content:
        candidates.append(escaped)

    for candidate in candidates:
        try:
            ET.fromstring(candidate)
        except ET.ParseError:
            continue
        svg_file.write_text(candidate, encoding="utf-8")
        return True

    # Repair 4: brute-force truncation after the last occurrence of a common
    # closing tag, then close the document.
    for base in candidates:
        for tag in ("</g>", "</text>", "</rect>", "</circle>"):
            last_pos = base.rfind(tag)
            if last_pos > 0:
                truncated = base[:last_pos + len(tag)] + "\n</svg>"
                try:
                    ET.fromstring(truncated)
                except ET.ParseError:
                    continue
                svg_file.write_text(truncated, encoding="utf-8")
                return True

    return False
|
||||
|
||||
def _count_pages(self) -> int:
    """Return the number of ``*.svg`` files directly inside ``<project>/svg_output``.

    Returns 0 when that directory does not exist.
    """
    svg_dir = self.project_path / "svg_output"
    if not svg_dir.exists():
        return 0
    return sum(1 for _ in svg_dir.glob("*.svg"))
|
||||
@@ -0,0 +1,10 @@
|
||||
fastapi>=0.104.0
|
||||
uvicorn>=0.24.0
|
||||
redis>=5.0.0
|
||||
python-multipart>=0.0.6
|
||||
httpx>=0.25.0
|
||||
psycopg[binary]>=3.1.0
|
||||
python-dotenv>=1.0.0
|
||||
Pillow>=10.0.0
|
||||
python-pptx>=1.0.0
|
||||
pydantic>=2.5.0
|
||||
@@ -0,0 +1,126 @@
|
||||
"""通义万相(Wanx)图片生成客户端 — DashScope 异步任务 API"""
|
||||
|
||||
import time
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from config import config
|
||||
|
||||
|
||||
class WanxClient:
|
||||
"""通义万相文生图客户端,使用 DashScope 异步任务 API"""
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = config.IMAGE_API_KEY or config.OPENAI_API_KEY
|
||||
self.model = config.WANX_MODEL
|
||||
self.create_url = config.DASHSCOPE_IMAGE_URL
|
||||
self.task_url = config.DASHSCOPE_TASK_URL
|
||||
self.client = httpx.Client(timeout=120.0)
|
||||
|
||||
def _headers(self) -> dict:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"X-DashScope-Async": "enable",
|
||||
}
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
output_dir: Path,
|
||||
filename: str = "image.png",
|
||||
size: str = "1024*576",
|
||||
n: int = 1,
|
||||
negative_prompt: str = "",
|
||||
style: str = "<auto>",
|
||||
) -> Path | None:
|
||||
"""
|
||||
生成图片并下载到本地。
|
||||
|
||||
Args:
|
||||
prompt: 正向提示词
|
||||
output_dir: 输出目录
|
||||
filename: 输出文件名
|
||||
size: 图片尺寸,格式 "宽*高",如 "1024*576"
|
||||
n: 生成数量(1-4)
|
||||
negative_prompt: 反向提示词
|
||||
style: 风格,<auto> 自动
|
||||
|
||||
Returns:
|
||||
下载后的本地文件路径,失败返回 None
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("未配置 DashScope API Key(OPENAI_API_KEY 或 IMAGE_API_KEY)")
|
||||
|
||||
# Step 1: 提交任务
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": {
|
||||
"prompt": prompt,
|
||||
},
|
||||
"parameters": {
|
||||
"size": size,
|
||||
"n": n,
|
||||
"style": style,
|
||||
},
|
||||
}
|
||||
if negative_prompt:
|
||||
payload["input"]["negative_prompt"] = negative_prompt
|
||||
|
||||
resp = self.client.post(self.create_url, headers=self._headers(), json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if data.get("code"):
|
||||
raise RuntimeError(f"万相任务创建失败: {data.get('message', data)}")
|
||||
|
||||
task_id = data["output"]["task_id"]
|
||||
|
||||
# Step 2: 轮询任务状态
|
||||
image_url = self._poll_task(task_id)
|
||||
if not image_url:
|
||||
return None
|
||||
|
||||
# Step 3: 下载图片
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = output_dir / filename
|
||||
|
||||
img_resp = self.client.get(image_url)
|
||||
img_resp.raise_for_status()
|
||||
output_path.write_bytes(img_resp.content)
|
||||
|
||||
return output_path
|
||||
|
||||
def _poll_task(self, task_id: str, max_wait: int = 120, interval: int = 3) -> str | None:
|
||||
"""轮询任务状态,返回图片 URL"""
|
||||
url = f"{self.task_url}/{task_id}"
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
elapsed = 0
|
||||
while elapsed < max_wait:
|
||||
resp = self.client.get(url, headers=headers)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
status = data.get("output", {}).get("task_status", "")
|
||||
|
||||
if status == "SUCCEEDED":
|
||||
results = data["output"].get("results", [])
|
||||
if results and results[0].get("url"):
|
||||
return results[0]["url"]
|
||||
return None
|
||||
|
||||
if status in ("FAILED", "UNKNOWN"):
|
||||
msg = data["output"].get("message", "未知错误")
|
||||
raise RuntimeError(f"万相任务失败 (task_id={task_id}): {msg}")
|
||||
|
||||
time.sleep(interval)
|
||||
elapsed += interval
|
||||
|
||||
raise TimeoutError(f"万相任务超时 (task_id={task_id}),已等待 {max_wait}s")
|
||||
|
||||
def close(self):
|
||||
self.client.close()
|
||||
|
||||
|
||||
# 单例
|
||||
# Shared module-level instance; it is constructed when this module is imported.
wanx_client = WanxClient()
|
||||
@@ -0,0 +1,112 @@
|
||||
"""任务消费者 — 从 Redis 队列中获取任务并执行 PPT Master 管线"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import signal
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import redis
|
||||
|
||||
from config import config
|
||||
from db import get_task, update_task_status
|
||||
from pipeline import PPTPipeline
|
||||
|
||||
|
||||
class PPTWorker:
    """Consumer that pops PPT generation tasks from a Redis list and runs them.

    Tasks are popped in the calling thread and executed on a thread pool of
    ``config.CONCURRENCY`` workers. Progress is mirrored into a Redis hash per
    task so that clients can poll it cheaply.
    """

    def __init__(self):
        self.redis = redis.from_url(config.REDIS_URL, decode_responses=True)
        self.executor = ThreadPoolExecutor(max_workers=config.CONCURRENCY)
        self.running = True
        self._setup_signal_handlers()

    def _setup_signal_handlers(self):
        """Register SIGINT/SIGTERM handlers that request a graceful stop."""
        # Signal handlers can only be registered from the main thread; when the
        # worker is created elsewhere, registration is skipped.
        try:
            signal.signal(signal.SIGINT, self._shutdown)
            signal.signal(signal.SIGTERM, self._shutdown)
        except ValueError:
            pass

    def _shutdown(self, signum, frame):
        """Signal handler: ask the main loop to stop after the current poll."""
        print(f"\n[Worker] 收到信号 {signum},正在优雅关闭...")
        self.running = False

    def start(self):
        """Run the consume loop until a shutdown is requested.

        After the loop ends, waits for tasks already handed to the executor.
        """
        config.ensure_dirs()
        print(f"[Worker] 启动,并发数: {config.CONCURRENCY}")
        print(f"[Worker] 监听队列: {config.TASK_QUEUE}")
        print(f"[Worker] PPT Master 路径: {config.PPT_MASTER_PATH}")

        while self.running:
            try:
                # Block for at most 5 seconds so the running flag is re-checked.
                result = self.redis.brpop(config.TASK_QUEUE, timeout=5)
                if result is None:
                    continue

                _, task_data = result
                task_msg = json.loads(task_data)
                task_id = task_msg.get("task_id")

                if not task_id:
                    print(f"[Worker] 无效任务消息: {task_data}")
                    continue

                print(f"[Worker] 收到任务: {task_id}")
                self.executor.submit(self._process_task, task_id)

            except redis.ConnectionError as e:
                print(f"[Worker] Redis 连接失败: {e},5 秒后重试...")
                time.sleep(5)
            except Exception as e:
                # Keep the loop alive on any other error (e.g. malformed JSON).
                print(f"[Worker] 未知错误: {e}")
                time.sleep(1)

        print("[Worker] 等待正在执行的任务完成...")
        self.executor.shutdown(wait=True)
        print("[Worker] 已关闭")

    def _process_task(self, task_id: str):
        """Run the pipeline for one task and record the outcome.

        Tasks that do not exist or are not in ``pending`` status are skipped.
        Any exception from the pipeline marks the task as failed in both the
        database and Redis; this method itself does not raise.
        """
        try:
            task = get_task(task_id)
            if not task:
                print(f"[Worker] 任务不存在: {task_id}")
                return

            if task["status"] != "pending":
                print(f"[Worker] 任务状态非 pending,跳过: {task_id} ({task['status']})")
                return

            # Mirror the state into Redis for fast polling.
            self._set_redis_status(task_id, "processing", 5, "开始处理...")

            pipeline = PPTPipeline(task_id, task, redis_callback=self._set_redis_status)
            pipeline.run()

            self._set_redis_status(task_id, "completed", 100, "生成完成")
            print(f"[Worker] 任务完成: {task_id}")

        except Exception as e:
            print(f"[Worker] 任务失败: {task_id} - {e}")
            # Report to the database and to Redis independently: a failure of
            # one store must not keep the other from seeing the "failed" state.
            # Errors raised here would otherwise vanish inside the executor
            # future, which nobody inspects.
            try:
                update_task_status(task_id, "failed", error_message=str(e))
            except Exception as db_err:
                print(f"[Worker] 更新任务状态失败: {task_id} - {db_err}")
            try:
                self._set_redis_status(task_id, "failed", 0, f"失败: {str(e)[:200]}")
            except Exception as redis_err:
                print(f"[Worker] 更新 Redis 状态失败: {task_id} - {redis_err}")

    def _set_redis_status(self, task_id: str, status: str, progress: int, message: str):
        """Store status, progress and message for a task in a Redis hash.

        The hash key is ``config.TASK_STATUS_PREFIX`` followed by the task id,
        and it expires one hour after the most recent update.
        """
        key = f"{config.TASK_STATUS_PREFIX}{task_id}"
        self.redis.hset(key, mapping={
            "status": status,
            "progress": str(progress),
            "message": message,
        })
        self.redis.expire(key, 3600)  # expire after one hour
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: build the consumer and block in its polling loop
    # until a shutdown signal clears its running flag.
    worker = PPTWorker()
    worker.start()
|
||||
Reference in New Issue
Block a user