185 lines
5.3 KiB
Python
185 lines
5.3 KiB
Python
"""PPT Worker HTTP API — FastAPI 服务,供 Go 后端调用"""
|
|
|
|
import json
|
|
import uuid
|
|
import shutil
|
|
import threading
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
|
from fastapi.responses import FileResponse
|
|
from pydantic import BaseModel
|
|
import redis
|
|
import uvicorn
|
|
|
|
from config import config
|
|
from db import get_task, update_task_status
|
|
from worker import PPTWorker
|
|
|
|
# FastAPI application exposing the worker's HTTP API.
app = FastAPI(version="1.0.0", title="PPT Worker API")

# Redis client used by the route handlers below. With decode_responses=True,
# replies such as hgetall() come back as str instead of bytes.
rdb = redis.from_url(
    config.REDIS_URL,
    decode_responses=True,
)
|
|
|
|
|
|
# ==================== 数据模型 ====================
|
|
|
|
class CreateTaskRequest(BaseModel):
    """JSON body for ``POST /api/tasks`` (a task created from inline content)."""

    # Owner of the task; stored in ppt_tasks.user_id.
    user_id: str
    # Presentation title; the download endpoint also uses it as the .pptx file name.
    title: str
    source_type: str = "text"  # text / file / url
    # Source material; presumably the text itself for "text" and a URL for "url" — confirm with the worker.
    source_content: Optional[str] = None
    # Free-form generation options, serialized with json.dumps into ppt_tasks.config.
    # NOTE(review): relies on pydantic copying mutable defaults per instance, so this {} is not shared.
    config: dict = {}
|
|
|
|
|
|
class TaskStatusResponse(BaseModel):
    """Response body for ``GET /api/tasks/{task_id}``."""

    # ID of the task being reported on.
    task_id: str
    # Task state such as "pending" or "completed"; "unknown" if the Redis cache entry lacks it.
    status: str
    # Progress value; presumably a 0-100 percentage — TODO confirm with the worker.
    progress: int
    # Human-readable progress message.
    status_message: Optional[str] = None
    # Failure details taken from the database record, if any.
    error_message: Optional[str] = None
    # Path of the generated .pptx file on the worker's filesystem.
    output_file: Optional[str] = None
    # Presumably the number of slides in the generated deck — confirm against the worker.
    page_count: Optional[int] = None
|
|
|
|
|
|
# ==================== API 路由 ====================
|
|
|
|
@app.post("/api/tasks", response_model=dict)
async def create_task(req: CreateTaskRequest):
    """Create a PPT generation task from inline source content.

    Inserts a row into ``ppt_tasks`` and then pushes ``{"task_id": ...}`` onto
    the Redis list named by ``config.TASK_QUEUE``.

    Args:
        req: Request body with owner, title, source type/content and the
            per-task generation config.

    Returns:
        ``{"task_id": <uuid4 string>, "status": "pending"}``.
    """
    import psycopg
    from fastapi.concurrency import run_in_threadpool

    task_id = str(uuid.uuid4())

    def _persist_and_enqueue() -> None:
        # psycopg.connect and redis-py are synchronous. Calling them directly in
        # this coroutine would stall the event loop for every other request,
        # so this helper is executed in FastAPI's thread pool instead.
        with psycopg.connect(config.DATABASE_URL) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """INSERT INTO ppt_tasks (id, user_id, title, source_type, source_content, config)
                       VALUES (%s, %s, %s, %s, %s, %s)""",
                    (task_id, req.user_id, req.title, req.source_type,
                     req.source_content, json.dumps(req.config)),
                )
            conn.commit()

        # Enqueue only after the commit, so the row already exists by the time
        # any consumer pops this id.
        rdb.lpush(config.TASK_QUEUE, json.dumps({"task_id": task_id}))

    await run_in_threadpool(_persist_and_enqueue)

    return {"task_id": task_id, "status": "pending"}
|
|
|
|
|
|
@app.post("/api/tasks/upload", response_model=dict)
async def create_task_with_file(
    user_id: str = Form(...),
    title: str = Form(...),
    config_json: str = Form(default="{}"),
    file: UploadFile = File(...),
):
    """Create a PPT generation task whose source material is an uploaded file.

    The upload is written to ``config.UPLOAD_DIR / f"{task_id}{ext}"``. A row
    with ``source_type='file'`` is inserted into ``ppt_tasks``, and
    ``{"task_id": ...}`` is pushed onto the Redis list ``config.TASK_QUEUE``.

    Args:
        user_id: Owner of the task (multipart form field).
        title: Presentation title (multipart form field).
        config_json: Generation config as a JSON object string; defaults to ``"{}"``.
        file: The uploaded source document.

    Returns:
        ``{"task_id": <uuid4 string>, "status": "pending"}``.

    Raises:
        HTTPException: 400 if ``config_json`` is not valid JSON or is not a JSON object.
    """
    import psycopg
    from fastapi.concurrency import run_in_threadpool

    # Validate client input before anything is written to disk. A bad config
    # is then a 400 and never leaves an orphaned upload behind.
    try:
        task_config = json.loads(config_json)
    except json.JSONDecodeError as exc:
        raise HTTPException(status_code=400, detail="config_json 不是合法的 JSON") from exc
    # Same contract as CreateTaskRequest.config, which only accepts an object.
    if not isinstance(task_config, dict):
        raise HTTPException(status_code=400, detail="config_json 必须是 JSON 对象")

    config.ensure_dirs()

    task_id = str(uuid.uuid4())

    # Only the extension of the client-supplied name is kept; the stored name is
    # the server-generated task id. Use ".bin" when there is no extension at all.
    file_ext = (Path(file.filename).suffix if file.filename else "") or ".bin"
    saved_path = config.UPLOAD_DIR / f"{task_id}{file_ext}"

    def _insert_task_row() -> None:
        # Synchronous DB work; executed in the thread pool so the event loop stays free.
        with psycopg.connect(config.DATABASE_URL) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """INSERT INTO ppt_tasks (id, user_id, title, source_type, source_file, config)
                       VALUES (%s, %s, %s, 'file', %s, %s)""",
                    (task_id, user_id, title, str(saved_path), json.dumps(task_config)),
                )
            conn.commit()

    try:
        # Stream the upload in 1 MiB chunks instead of holding it all in memory.
        with open(saved_path, "wb") as out:
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                out.write(chunk)
        await run_in_threadpool(_insert_task_row)
    except BaseException:
        # No task row references the file, so nothing would ever consume it.
        # Remove it (best effort) and propagate the original error.
        try:
            saved_path.unlink()
        except OSError:
            pass
        raise

    # Enqueue only after the row is committed. The row references the file,
    # so the file is kept even if this push fails.
    await run_in_threadpool(rdb.lpush, config.TASK_QUEUE, json.dumps({"task_id": task_id}))

    return {"task_id": task_id, "status": "pending"}
|
|
|
|
|
|
@app.get("/api/tasks/{task_id}", response_model=TaskStatusResponse)
async def get_task_status(task_id: str):
    """Return the current status of a task.

    The Redis hash ``f"{config.TASK_STATUS_PREFIX}{task_id}"`` is consulted
    first as a fast path. Only its ``status``, ``progress`` and ``message``
    fields are read. If that hash is absent, the database record is used.

    Raises:
        HTTPException: 404 if neither Redis nor the database knows the task.
    """
    from fastapi.concurrency import run_in_threadpool

    # Statuses after which the DB record carries result fields
    # (error_message / output_file / page_count).
    # NOTE(review): "failed" is an assumed status name — confirm against the worker.
    terminal_statuses = {"completed", "failed"}

    key = f"{config.TASK_STATUS_PREFIX}{task_id}"
    # redis-py and get_task are synchronous; keep them off the event loop.
    cached = await run_in_threadpool(rdb.hgetall, key)

    if cached:
        status = cached.get("status", "unknown")
        # The cached hash provides only status/progress/message. For a finished
        # task, fill in the result fields from the DB as well; otherwise clients
        # polling while the hash is still alive never see output_file / page_count.
        result_fields = {}
        if status in terminal_statuses:
            task = await run_in_threadpool(get_task, task_id)
            if task:
                result_fields = {
                    "error_message": task.get("error_message"),
                    "output_file": task.get("output_file"),
                    "page_count": task.get("page_count"),
                }
        return TaskStatusResponse(
            task_id=task_id,
            status=status,
            progress=int(cached.get("progress", 0)),
            status_message=cached.get("message"),
            **result_fields,
        )

    # No cache entry: fall back to the database record.
    task = await run_in_threadpool(get_task, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    return TaskStatusResponse(
        task_id=task_id,
        status=task["status"],
        progress=task["progress"],
        status_message=task.get("status_message"),
        error_message=task.get("error_message"),
        output_file=task.get("output_file"),
        page_count=task.get("page_count"),
    )
|
|
|
|
|
|
@app.get("/api/tasks/{task_id}/download")
async def download_task_output(task_id: str):
    """Download the generated PPTX file of a completed task.

    Raises:
        HTTPException: 404 if the task is unknown or its output file is missing;
            400 if the task status is not ``"completed"``.
    """
    from fastapi.concurrency import run_in_threadpool

    # get_task is a synchronous call; run it in the thread pool.
    task = await run_in_threadpool(get_task, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")

    if task["status"] != "completed":
        raise HTTPException(status_code=400, detail="任务未完成")

    output_file = task.get("output_file")
    # is_file() rather than exists(): a directory path would pass exists() and
    # then make FileResponse fail with a 500 instead of a clean 404.
    if not output_file or not Path(output_file).is_file():
        raise HTTPException(status_code=404, detail="输出文件不存在")

    # dict.get(key, default) does not cover a NULL or blank title, which would
    # yield "None.pptx" or ".pptx". Also keep path separators out of the
    # suggested download name.
    title = (task.get("title") or "").strip()
    safe_title = title.replace("/", "_").replace("\\", "_") or "presentation"
    filename = f"{safe_title}.pptx"

    return FileResponse(
        path=output_file,
        filename=filename,
        media_type="application/vnd.openxmlformats-officedocument.presentationml.presentation",
    )
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness check: report that this service process is up.

    Returns a fixed payload; no database or Redis connectivity is checked.
    """
    payload = dict(status="ok", service="ppt-worker")
    return payload
|
|
|
|
|
|
# ==================== 启动 ====================
|
|
|
|
def start_worker_thread():
    """Construct a PPTWorker and start it in the calling thread.

    Used as the ``target`` of the background thread created by the
    ``__main__`` entry point. PPTWorker.start presumably blocks while consuming
    the task queue — confirm in worker.py.
    """
    PPTWorker().start()
|
|
|
|
|
|
if __name__ == "__main__":
    # Presumably creates the directories the service writes to (e.g. uploads).
    config.ensure_dirs()

    # Worker runs in a daemon thread, so it does not keep the process alive
    # once the HTTP server below returns.
    worker_thread = threading.Thread(daemon=True, target=start_worker_thread)
    worker_thread.start()

    # Serve the HTTP API in the main thread; uvicorn.run blocks until shutdown.
    uvicorn.run(app, port=config.PORT, host=config.HOST)
|