Files
GovAI/ppt-worker/app.py
T
2026-06-15 23:48:37 +08:00

185 lines
5.3 KiB
Python

"""PPT Worker HTTP API — FastAPI 服务,供 Go 后端调用"""
import json
import uuid
import shutil
import threading
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
import redis
import uvicorn
from config import config
from db import get_task, update_task_status
from worker import PPTWorker
# FastAPI application that the route decorators below register against.
app = FastAPI(
    title="PPT Worker API",
    version="1.0.0",
)
# Module-wide Redis client. decode_responses=True makes replies come back as
# str instead of bytes, so hash fields can be read with plain str keys.
rdb = redis.from_url(
    config.REDIS_URL,
    decode_responses=True,
)
# ==================== Data models ====================
class CreateTaskRequest(BaseModel):
    """JSON request body for creating a PPT generation task.

    All fields are written to a new ``ppt_tasks`` row; ``config`` is stored
    JSON-encoded.
    """

    user_id: str
    title: str
    # One of "text" / "file" / "url" (the value is not validated here).
    source_type: str = "text"
    source_content: Optional[str] = None
    # Free-form per-task options; pydantic copies mutable defaults per
    # instance, so the shared {} literal is not mutated across requests.
    config: dict = {}
class TaskStatusResponse(BaseModel):
    """Response body describing the current state of a PPT task.

    Only ``task_id``, ``status`` and ``progress`` are required; the optional
    fields are ``None`` when the data source does not supply them.
    """

    task_id: str
    status: str
    # Progress counter as reported by the worker/DB (presumably 0-100; not
    # enforced here).
    progress: int
    status_message: Optional[str] = None
    error_message: Optional[str] = None
    # Filesystem path of the generated .pptx (served by the download route).
    output_file: Optional[str] = None
    page_count: Optional[int] = None
# ==================== API routes ====================
@app.post("/api/tasks", response_model=dict)
async def create_task(req: CreateTaskRequest):
    """Create a PPT generation task from a JSON request body.

    Inserts a row into ``ppt_tasks`` and then pushes ``{"task_id": ...}`` onto
    the Redis list ``config.TASK_QUEUE`` for the background worker.

    Args:
        req: Validated task parameters.

    Returns:
        ``{"task_id": <uuid4 string>, "status": "pending"}``.
    """
    import psycopg
    from fastapi.concurrency import run_in_threadpool

    task_id = str(uuid.uuid4())

    def _persist_and_enqueue() -> None:
        # psycopg.connect and the redis client are synchronous. Calling them
        # directly inside this coroutine would block the event loop (and all
        # other in-flight requests) for the duration of the round-trips.
        with psycopg.connect(config.DATABASE_URL) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """INSERT INTO ppt_tasks (id, user_id, title, source_type, source_content, config)
                    VALUES (%s, %s, %s, %s, %s, %s)""",
                    (task_id, req.user_id, req.title, req.source_type, req.source_content,
                     json.dumps(req.config)),
                )
            conn.commit()
        # Enqueue only after the commit so the row already exists when a
        # consumer reads the queue entry.
        # NOTE(review): if this push fails, the row is committed but never queued.
        rdb.lpush(config.TASK_QUEUE, json.dumps({"task_id": task_id}))

    await run_in_threadpool(_persist_and_enqueue)
    return {"task_id": task_id, "status": "pending"}
@app.post("/api/tasks/upload", response_model=dict)
async def create_task_with_file(
    user_id: str = Form(...),
    title: str = Form(...),
    config_json: str = Form(default="{}"),
    file: UploadFile = File(...),
):
    """Create a PPT generation task whose source is an uploaded file.

    The upload is saved as ``config.UPLOAD_DIR / f"{task_id}{ext}"``. ``ext``
    is taken from the client filename, or is ``.bin`` when no filename was
    sent. Then a ``source_type='file'`` row is inserted into ``ppt_tasks`` and
    the task id is pushed onto ``config.TASK_QUEUE``.

    Args:
        user_id: Owner of the task.
        title: Presentation title.
        config_json: JSON object with per-task options (default ``"{}"``).
        file: The uploaded source document.

    Returns:
        ``{"task_id": <uuid4 string>, "status": "pending"}``.

    Raises:
        HTTPException: 400 if ``config_json`` is not valid JSON or is not a
            JSON object.
    """
    import psycopg
    from fastapi.concurrency import run_in_threadpool

    # Validate form input before touching the disk. A malformed config_json
    # then becomes a 400 and leaves no orphaned upload behind.
    try:
        task_config = json.loads(config_json)
    except json.JSONDecodeError as exc:
        raise HTTPException(status_code=400, detail="config_json 不是合法的 JSON") from exc
    if not isinstance(task_config, dict):
        # Same shape the JSON endpoint enforces via CreateTaskRequest.config.
        raise HTTPException(status_code=400, detail="config_json 必须是 JSON 对象")

    config.ensure_dirs()
    task_id = str(uuid.uuid4())
    # Only the suffix of the client-supplied name is used, so the stored
    # path is always <UPLOAD_DIR>/<uuid><ext>.
    file_ext = Path(file.filename).suffix if file.filename else ".bin"
    saved_path = config.UPLOAD_DIR / f"{task_id}{file_ext}"

    def _save_and_register() -> None:
        try:
            # Stream to disk in chunks instead of reading the whole upload
            # into memory at once.
            with open(saved_path, "wb") as out:
                shutil.copyfileobj(file.file, out)
            with psycopg.connect(config.DATABASE_URL) as conn:
                with conn.cursor() as cur:
                    cur.execute(
                        """INSERT INTO ppt_tasks (id, user_id, title, source_type, source_file, config)
                        VALUES (%s, %s, %s, 'file', %s, %s)""",
                        (task_id, user_id, title, str(saved_path), json.dumps(task_config)),
                    )
                conn.commit()
        except Exception:
            # No task row references the file, so do not leave it on disk.
            saved_path.unlink(missing_ok=True)
            raise
        # Enqueue only after the row is committed.
        rdb.lpush(config.TASK_QUEUE, json.dumps({"task_id": task_id}))

    # File, database and Redis I/O are all blocking; run them in the
    # threadpool so the event loop keeps serving other requests.
    await run_in_threadpool(_save_and_register)
    return {"task_id": task_id, "status": "pending"}
@app.get("/api/tasks/{task_id}", response_model=TaskStatusResponse)
async def get_task_status(task_id: str):
    """Return the current status of a task.

    The Redis hash ``config.TASK_STATUS_PREFIX + task_id`` is checked first,
    which keeps progress polling cheap. Only its ``status``, ``progress`` and
    ``message`` fields are read. Once it reports a finished task, the database
    row is also read to fill in ``error_message``, ``output_file`` and
    ``page_count``. With no cached entry, the database row is used directly.

    Raises:
        HTTPException: 404 if the task is in neither Redis nor the database.
    """
    from fastapi.concurrency import run_in_threadpool

    # Cached statuses for which the DB row is consulted for final details.
    # NOTE(review): "failed" is assumed to be the worker's failure status;
    # confirm against the worker implementation.
    finished_statuses = ("completed", "failed")

    key = f"{config.TASK_STATUS_PREFIX}{task_id}"
    # The Redis client and get_task() both block; keep them off the event loop.
    cached = await run_in_threadpool(rdb.hgetall, key)
    if cached:
        status = cached.get("status", "unknown")
        progress = int(cached.get("progress", 0))
        details = {}
        if status in finished_statuses:
            task = await run_in_threadpool(get_task, task_id)
            if task:
                details = {
                    "error_message": task.get("error_message"),
                    "output_file": task.get("output_file"),
                    "page_count": task.get("page_count"),
                }
        return TaskStatusResponse(
            task_id=task_id,
            status=status,
            progress=progress,
            status_message=cached.get("message"),
            **details,
        )

    # No cache entry: fall back to the database.
    task = await run_in_threadpool(get_task, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")
    return TaskStatusResponse(
        task_id=task_id,
        status=task["status"],
        progress=task["progress"],
        status_message=task.get("status_message"),
        error_message=task.get("error_message"),
        output_file=task.get("output_file"),
        page_count=task.get("page_count"),
    )
@app.get("/api/tasks/{task_id}/download")
async def download_task_output(task_id: str):
    """Send the generated PPTX file of a completed task as an attachment.

    The download name is ``<title>.pptx``. ``presentation.pptx`` is used when
    the stored title is missing or blank.

    Raises:
        HTTPException: 404 if the task or its output file does not exist;
            400 if the task status is not ``completed``.
    """
    from fastapi.concurrency import run_in_threadpool

    # get_task() is a blocking DB call; keep it off the event loop.
    task = await run_in_threadpool(get_task, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")
    if task["status"] != "completed":
        raise HTTPException(status_code=400, detail="任务未完成")
    output_file = task.get("output_file")
    # is_file() rather than exists(): a directory must not be served.
    if not output_file or not Path(output_file).is_file():
        raise HTTPException(status_code=404, detail="输出文件不存在")
    # A NULL or blank title would otherwise give "None.pptx" / ".pptx".
    # Path separators are replaced so the client is never handed a nested path.
    title = (task.get("title") or "").strip()
    safe_title = title.replace("/", "_").replace("\\", "_") or "presentation"
    return FileResponse(
        path=output_file,
        filename=f"{safe_title}.pptx",
        media_type="application/vnd.openxmlformats-officedocument.presentationml.presentation",
    )
@app.get("/health")
async def health():
    """Health-check endpoint; always reports the service as up."""
    payload = {"status": "ok", "service": "ppt-worker"}
    return payload
# ==================== Startup ====================
def start_worker_thread():
    """Construct a PPTWorker and call its ``start()`` in the current thread.

    Used as the ``target`` of the background thread launched in ``__main__``.
    Whether ``start()`` blocks or returns is defined by ``worker.PPTWorker``.
    """
    PPTWorker().start()
if __name__ == "__main__":
    # Create the configured directories before either the worker or the API uses them.
    config.ensure_dirs()
    # Daemon thread: it does not keep the process alive once the HTTP server exits.
    worker_thread = threading.Thread(
        target=start_worker_thread,
        daemon=True,
    )
    worker_thread.start()
    # Serve the HTTP API in the main thread (blocks until shutdown).
    uvicorn.run(
        app,
        host=config.HOST,
        port=config.PORT,
    )