113 lines
3.8 KiB
Python
113 lines
3.8 KiB
Python
"""任务消费者 — 从 Redis 队列中获取任务并执行 PPT Master 管线"""
|
||
|
||
import json
|
||
import time
|
||
import signal
|
||
import threading
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
import redis
|
||
|
||
from config import config
|
||
from db import get_task, update_task_status
|
||
from pipeline import PPTPipeline
|
||
|
||
|
||
class PPTWorker:
    """PPT generation task consumer.

    Pops task messages (JSON objects carrying a ``task_id`` key) from the
    Redis list ``config.TASK_QUEUE`` and runs each task through
    ``PPTPipeline`` on a thread pool of ``config.CONCURRENCY`` threads.
    Per-task status/progress is mirrored into a Redis hash so the frontend
    can poll it cheaply.
    """

    # Lifetime (seconds) of each per-task status hash in Redis.
    STATUS_TTL_SECONDS = 3600

    def __init__(self):
        self.redis = redis.from_url(config.REDIS_URL, decode_responses=True)
        self.executor = ThreadPoolExecutor(max_workers=config.CONCURRENCY)
        # One slot per pool thread. The main loop only pops a message from
        # Redis once a thread is free, so undispatched tasks stay in Redis
        # (visible to other workers, not lost if this process dies) instead
        # of piling up in the executor's unbounded in-memory queue.
        self._slots = threading.BoundedSemaphore(config.CONCURRENCY)
        self.running = True
        self._setup_signal_handlers()

    def _setup_signal_handlers(self):
        """Install SIGINT/SIGTERM handlers that request a graceful shutdown."""
        # signal.signal() only works from the main thread; elsewhere it
        # raises ValueError and registration is skipped.
        try:
            signal.signal(signal.SIGINT, self._shutdown)
            signal.signal(signal.SIGTERM, self._shutdown)
        except ValueError:
            pass

    def _shutdown(self, signum, frame):
        """Signal handler: stop the consume loop after the current iteration."""
        print(f"\n[Worker] 收到信号 {signum},正在优雅关闭...")
        self.running = False

    def start(self):
        """Run the consume loop until a shutdown signal arrives.

        Blocks the calling thread. Once the loop exits, waits for all
        in-flight tasks to finish before returning.
        """
        config.ensure_dirs()
        print(f"[Worker] 启动,并发数: {config.CONCURRENCY}")
        print(f"[Worker] 监听队列: {config.TASK_QUEUE}")
        print(f"[Worker] PPT Master 路径: {config.PPT_MASTER_PATH}")

        while self.running:
            # Wait for a free pool thread before taking work off the queue;
            # the timeout keeps the loop responsive to shutdown requests.
            if not self._slots.acquire(timeout=1):
                continue
            dispatched = False
            try:
                # BRPOP blocks for up to 5 seconds waiting for a task.
                result = self.redis.brpop(config.TASK_QUEUE, timeout=5)
                if result is None:
                    continue

                _, task_data = result
                task_id = self._parse_task_id(task_data)
                if not task_id:
                    print(f"[Worker] 无效任务消息: {task_data}")
                    continue

                print(f"[Worker] 收到任务: {task_id}")
                future = self.executor.submit(self._process_task, task_id)
                # From here on the slot is owned by the task and is released
                # by _on_task_done when it finishes.
                dispatched = True
                future.add_done_callback(self._on_task_done)

            except redis.ConnectionError as e:
                print(f"[Worker] Redis 连接失败: {e},5 秒后重试...")
                time.sleep(5)
            except Exception as e:
                print(f"[Worker] 未知错误: {e}")
                time.sleep(1)
            finally:
                if not dispatched:
                    self._slots.release()

        print("[Worker] 等待正在执行的任务完成...")
        self.executor.shutdown(wait=True)
        print("[Worker] 已关闭")

    @staticmethod
    def _parse_task_id(task_data):
        """Return the ``task_id`` from a raw queue message, or None if malformed.

        A message is malformed when it is not valid JSON or does not decode
        to a JSON object.
        """
        try:
            task_msg = json.loads(task_data)
        except ValueError:
            return None
        if not isinstance(task_msg, dict):
            return None
        return task_msg.get("task_id")

    def _on_task_done(self, future):
        """Executor callback: free the dispatch slot and log escaped errors."""
        self._slots.release()
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            # Without this, exceptions raised in a pool thread are dropped
            # silently because nobody inspects the future.
            print(f"[Worker] 任务线程异常: {exc}")

    def _process_task(self, task_id: str):
        """Run one task through the PPT pipeline and record the outcome.

        Tasks that no longer exist or whose status is not ``pending`` are
        skipped. A pipeline error marks the task ``failed`` in both the
        database and the Redis status hash.
        """
        try:
            task = get_task(task_id)
            if not task:
                print(f"[Worker] 任务不存在: {task_id}")
                return

            if task["status"] not in ("pending",):
                print(f"[Worker] 任务状态非 pending,跳过: {task_id} ({task['status']})")
                return

            # Update the Redis status hash (polled by the frontend).
            self._set_redis_status(task_id, "processing", 5, "开始处理...")

            pipeline = PPTPipeline(task_id, task, redis_callback=self._set_redis_status)
            pipeline.run()

        except Exception as e:
            print(f"[Worker] 任务失败: {task_id} - {e}")
            self._record_failure(task_id, e)

        else:
            # Kept outside the try above: a Redis hiccup here must not mark a
            # successfully generated task as failed in the database.
            try:
                self._set_redis_status(task_id, "completed", 100, "生成完成")
            except redis.RedisError as e:
                print(f"[Worker] 更新完成状态失败: {task_id} - {e}")
            print(f"[Worker] 任务完成: {task_id}")

    def _record_failure(self, task_id: str, error: Exception):
        """Best-effort: mark the task failed in the database and in Redis.

        Each store is updated independently so that one being unavailable
        does not prevent the other from being updated. Errors are logged
        rather than raised into the executor thread.
        """
        try:
            update_task_status(task_id, "failed", error_message=str(error))
        except Exception as db_err:
            print(f"[Worker] 写入失败状态到数据库出错: {task_id} - {db_err}")
        try:
            self._set_redis_status(task_id, "failed", 0, f"失败: {str(error)[:200]}")
        except Exception as redis_err:
            print(f"[Worker] 写入失败状态到 Redis 出错: {task_id} - {redis_err}")

    def _set_redis_status(self, task_id: str, status: str, progress: int, message: str):
        """Write the task's status hash in Redis for fast frontend polling.

        Also handed to ``PPTPipeline`` as ``redis_callback``. The hash fields
        and the TTL are sent in a single MULTI/EXEC transaction, so the key
        is never left behind without an expiry.

        Args:
            task_id: Task identifier; appended to ``config.TASK_STATUS_PREFIX``.
            status: Status label stored in the ``status`` field.
            progress: Progress value, stored as a string.
            message: Human-readable progress message.
        """
        key = f"{config.TASK_STATUS_PREFIX}{task_id}"
        with self.redis.pipeline() as pipe:
            pipe.hset(key, mapping={
                "status": status,
                "progress": str(progress),
                "message": message,
            })
            pipe.expire(key, self.STATUS_TTL_SECONDS)
            pipe.execute()
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: build a worker and block in its consume loop.
    PPTWorker().start()
|