"""PPT Worker HTTP API — FastAPI 服务,供 Go 后端调用""" import json import uuid import shutil import threading from pathlib import Path from typing import Optional from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi.responses import FileResponse from pydantic import BaseModel import redis import uvicorn from config import config from db import get_task, update_task_status from worker import PPTWorker app = FastAPI(title="PPT Worker API", version="1.0.0") # Redis 客户端 rdb = redis.from_url(config.REDIS_URL, decode_responses=True) # ==================== 数据模型 ==================== class CreateTaskRequest(BaseModel): user_id: str title: str source_type: str = "text" # text / file / url source_content: Optional[str] = None config: dict = {} class TaskStatusResponse(BaseModel): task_id: str status: str progress: int status_message: Optional[str] = None error_message: Optional[str] = None output_file: Optional[str] = None page_count: Optional[int] = None # ==================== API 路由 ==================== @app.post("/api/tasks", response_model=dict) async def create_task(req: CreateTaskRequest): """创建 PPT 生成任务""" import psycopg task_id = str(uuid.uuid4()) with psycopg.connect(config.DATABASE_URL) as conn: with conn.cursor() as cur: cur.execute( """INSERT INTO ppt_tasks (id, user_id, title, source_type, source_content, config) VALUES (%s, %s, %s, %s, %s, %s)""", (task_id, req.user_id, req.title, req.source_type, req.source_content, json.dumps(req.config)), ) conn.commit() # 推送到 Redis 队列 rdb.lpush(config.TASK_QUEUE, json.dumps({"task_id": task_id})) return {"task_id": task_id, "status": "pending"} @app.post("/api/tasks/upload", response_model=dict) async def create_task_with_file( user_id: str = Form(...), title: str = Form(...), config_json: str = Form(default="{}"), file: UploadFile = File(...), ): """创建带文件上传的 PPT 生成任务""" import psycopg config.ensure_dirs() task_id = str(uuid.uuid4()) # 保存上传文件 file_ext = Path(file.filename).suffix if file.filename else ".bin" saved_path = config.UPLOAD_DIR / f"{task_id}{file_ext}" with open(saved_path, "wb") as f: content = await file.read() f.write(content) task_config = json.loads(config_json) with psycopg.connect(config.DATABASE_URL) as conn: with conn.cursor() as cur: cur.execute( """INSERT INTO ppt_tasks (id, user_id, title, source_type, source_file, config) VALUES (%s, %s, %s, 'file', %s, %s)""", (task_id, user_id, title, str(saved_path), json.dumps(task_config)), ) conn.commit() # 推送到 Redis 队列 rdb.lpush(config.TASK_QUEUE, json.dumps({"task_id": task_id})) return {"task_id": task_id, "status": "pending"} @app.get("/api/tasks/{task_id}", response_model=TaskStatusResponse) async def get_task_status(task_id: str): """查询任务状态""" # 先从 Redis 快速查询 key = f"{config.TASK_STATUS_PREFIX}{task_id}" cached = rdb.hgetall(key) if cached: return TaskStatusResponse( task_id=task_id, status=cached.get("status", "unknown"), progress=int(cached.get("progress", 0)), status_message=cached.get("message"), ) # 回退到数据库查询 task = get_task(task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") return TaskStatusResponse( task_id=task_id, status=task["status"], progress=task["progress"], status_message=task.get("status_message"), error_message=task.get("error_message"), output_file=task.get("output_file"), page_count=task.get("page_count"), ) @app.get("/api/tasks/{task_id}/download") async def download_task_output(task_id: str): """下载生成的 PPTX 文件""" task = get_task(task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") if task["status"] != "completed": raise HTTPException(status_code=400, detail="任务未完成") output_file = task.get("output_file") if not output_file or not Path(output_file).exists(): raise HTTPException(status_code=404, detail="输出文件不存在") filename = f"{task.get('title', 'presentation')}.pptx" return FileResponse( path=output_file, filename=filename, media_type="application/vnd.openxmlformats-officedocument.presentationml.presentation", ) @app.get("/health") async def health(): """健康检查""" return {"status": "ok", "service": "ppt-worker"} # ==================== 启动 ==================== def start_worker_thread(): """在后台线程中启动 Worker""" worker = PPTWorker() worker.start() if __name__ == "__main__": config.ensure_dirs() # 启动后台 Worker 线程 worker_thread = threading.Thread(target=start_worker_thread, daemon=True) worker_thread.start() # 启动 HTTP 服务 uvicorn.run(app, host=config.HOST, port=config.PORT)