Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,184 @@
|
||||
"""PPT Worker HTTP API — FastAPI 服务,供 Go 后端调用"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import shutil
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
import redis
|
||||
import uvicorn
|
||||
|
||||
from config import config
|
||||
from db import get_task, update_task_status
|
||||
from worker import PPTWorker
|
||||
|
||||
app = FastAPI(title="PPT Worker API", version="1.0.0")
|
||||
|
||||
# Redis 客户端
|
||||
rdb = redis.from_url(config.REDIS_URL, decode_responses=True)
|
||||
|
||||
|
||||
# ==================== 数据模型 ====================
|
||||
|
||||
class CreateTaskRequest(BaseModel):
|
||||
user_id: str
|
||||
title: str
|
||||
source_type: str = "text" # text / file / url
|
||||
source_content: Optional[str] = None
|
||||
config: dict = {}
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
task_id: str
|
||||
status: str
|
||||
progress: int
|
||||
status_message: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
output_file: Optional[str] = None
|
||||
page_count: Optional[int] = None
|
||||
|
||||
|
||||
# ==================== API 路由 ====================
|
||||
|
||||
@app.post("/api/tasks", response_model=dict)
|
||||
async def create_task(req: CreateTaskRequest):
|
||||
"""创建 PPT 生成任务"""
|
||||
import psycopg
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
with psycopg.connect(config.DATABASE_URL) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""INSERT INTO ppt_tasks (id, user_id, title, source_type, source_content, config)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)""",
|
||||
(task_id, req.user_id, req.title, req.source_type, req.source_content,
|
||||
json.dumps(req.config)),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# 推送到 Redis 队列
|
||||
rdb.lpush(config.TASK_QUEUE, json.dumps({"task_id": task_id}))
|
||||
|
||||
return {"task_id": task_id, "status": "pending"}
|
||||
|
||||
|
||||
@app.post("/api/tasks/upload", response_model=dict)
|
||||
async def create_task_with_file(
|
||||
user_id: str = Form(...),
|
||||
title: str = Form(...),
|
||||
config_json: str = Form(default="{}"),
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
"""创建带文件上传的 PPT 生成任务"""
|
||||
import psycopg
|
||||
|
||||
config.ensure_dirs()
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 保存上传文件
|
||||
file_ext = Path(file.filename).suffix if file.filename else ".bin"
|
||||
saved_path = config.UPLOAD_DIR / f"{task_id}{file_ext}"
|
||||
with open(saved_path, "wb") as f:
|
||||
content = await file.read()
|
||||
f.write(content)
|
||||
|
||||
task_config = json.loads(config_json)
|
||||
|
||||
with psycopg.connect(config.DATABASE_URL) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""INSERT INTO ppt_tasks (id, user_id, title, source_type, source_file, config)
|
||||
VALUES (%s, %s, %s, 'file', %s, %s)""",
|
||||
(task_id, user_id, title, str(saved_path), json.dumps(task_config)),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# 推送到 Redis 队列
|
||||
rdb.lpush(config.TASK_QUEUE, json.dumps({"task_id": task_id}))
|
||||
|
||||
return {"task_id": task_id, "status": "pending"}
|
||||
|
||||
|
||||
@app.get("/api/tasks/{task_id}", response_model=TaskStatusResponse)
|
||||
async def get_task_status(task_id: str):
|
||||
"""查询任务状态"""
|
||||
# 先从 Redis 快速查询
|
||||
key = f"{config.TASK_STATUS_PREFIX}{task_id}"
|
||||
cached = rdb.hgetall(key)
|
||||
|
||||
if cached:
|
||||
return TaskStatusResponse(
|
||||
task_id=task_id,
|
||||
status=cached.get("status", "unknown"),
|
||||
progress=int(cached.get("progress", 0)),
|
||||
status_message=cached.get("message"),
|
||||
)
|
||||
|
||||
# 回退到数据库查询
|
||||
task = get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
return TaskStatusResponse(
|
||||
task_id=task_id,
|
||||
status=task["status"],
|
||||
progress=task["progress"],
|
||||
status_message=task.get("status_message"),
|
||||
error_message=task.get("error_message"),
|
||||
output_file=task.get("output_file"),
|
||||
page_count=task.get("page_count"),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/tasks/{task_id}/download")
|
||||
async def download_task_output(task_id: str):
|
||||
"""下载生成的 PPTX 文件"""
|
||||
task = get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
if task["status"] != "completed":
|
||||
raise HTTPException(status_code=400, detail="任务未完成")
|
||||
|
||||
output_file = task.get("output_file")
|
||||
if not output_file or not Path(output_file).exists():
|
||||
raise HTTPException(status_code=404, detail="输出文件不存在")
|
||||
|
||||
filename = f"{task.get('title', 'presentation')}.pptx"
|
||||
return FileResponse(
|
||||
path=output_file,
|
||||
filename=filename,
|
||||
media_type="application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""健康检查"""
|
||||
return {"status": "ok", "service": "ppt-worker"}
|
||||
|
||||
|
||||
# ==================== 启动 ====================
|
||||
|
||||
def start_worker_thread():
|
||||
"""在后台线程中启动 Worker"""
|
||||
worker = PPTWorker()
|
||||
worker.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.ensure_dirs()
|
||||
|
||||
# 启动后台 Worker 线程
|
||||
worker_thread = threading.Thread(target=start_worker_thread, daemon=True)
|
||||
worker_thread.start()
|
||||
|
||||
# 启动 HTTP 服务
|
||||
uvicorn.run(app, host=config.HOST, port=config.PORT)
|
||||
Reference in New Issue
Block a user