80 lines
2.4 KiB
Python
80 lines
2.4 KiB
Python
"""数据库操作模块"""
|
|
|
|
import psycopg
|
|
from datetime import datetime, timezone
|
|
from config import config
|
|
|
|
|
|
def get_connection():
    """Open and return a new psycopg connection.

    The connection string is read from ``config.DATABASE_URL`` on every call,
    and ``psycopg.connect`` opens a new connection each time it is invoked.
    """
    dsn = config.DATABASE_URL
    return psycopg.connect(dsn)
|
|
|
|
|
|
def update_task_status(
    task_id: str,
    status: str,
    progress: int | None = None,
    status_message: str | None = None,
    error_message: str | None = None,
    output_file: str | None = None,
    page_count: int | None = None,
    project_path: str | None = None,
) -> None:
    """Update the status (and optionally other columns) of a ``ppt_tasks`` row.

    ``status`` is always written, and ``updated_at`` is always set to ``NOW()``.
    Each optional argument is written only when it is not ``None``. Passing
    ``None`` therefore leaves that column unchanged; this function cannot reset a
    column to NULL. Falsy-but-set values such as ``progress=0`` or
    ``status_message=""`` are still written.

    Timestamp side effects depend on ``status``:
    ``"processing"`` also sets ``started_at = NOW()``, and ``"completed"`` or
    ``"failed"`` also sets ``completed_at = NOW()``.

    Args:
        task_id: Value matched against ``ppt_tasks.id``.
        status: New value for the ``status`` column.
        progress: New ``progress`` value, or ``None`` to leave it unchanged.
        status_message: New ``status_message``, or ``None`` to leave it unchanged.
        error_message: New ``error_message``, or ``None`` to leave it unchanged.
        output_file: New ``output_file``, or ``None`` to leave it unchanged.
        page_count: New ``page_count``, or ``None`` to leave it unchanged.
        project_path: New ``project_path``, or ``None`` to leave it unchanged.
    """
    # Column names come only from this fixed mapping, never from caller input.
    # That makes interpolating them into the SQL text below safe; every value
    # is still bound as a named query parameter. Insertion order is preserved,
    # so the SET clause order is deterministic.
    optional_columns = {
        "progress": progress,
        "status_message": status_message,
        "error_message": error_message,
        "output_file": output_file,
        "page_count": page_count,
        "project_path": project_path,
    }

    fields = ["status = %(status)s", "updated_at = NOW()"]
    params: dict[str, object] = {"task_id": task_id, "status": status}

    for column, value in optional_columns.items():
        if value is not None:
            fields.append(f"{column} = %({column})s")
            params[column] = value

    if status == "processing":
        # NOTE(review): this runs on every "processing" update (e.g. repeated
        # progress ticks), so started_at ends up as the time of the latest such
        # update rather than the first. If the first start time is intended,
        # consider "started_at = COALESCE(started_at, NOW())". Confirm against
        # the retry semantics before changing.
        fields.append("started_at = NOW()")
    elif status in ("completed", "failed"):
        fields.append("completed_at = NOW()")

    sql = f"UPDATE ppt_tasks SET {', '.join(fields)} WHERE id = %(task_id)s"

    with get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(sql, params)
        conn.commit()
|
|
|
|
|
|
def get_task(task_id: str) -> dict | None:
    """Look up a single task in ``ppt_tasks`` by its id.

    Returns:
        A dict keyed by the selected column names (id, user_id, title, ...,
        created_at) for the matching row, or ``None`` if no row matches.
    """
    query = (
        "SELECT id, user_id, title, source_type, source_content, source_file, "
        "config, status, progress, status_message, error_message, "
        "output_file, page_count, project_path, created_at "
        "FROM ppt_tasks WHERE id = %s"
    )
    with get_connection() as conn, conn.cursor() as cur:
        cur.execute(query, (task_id,))
        record = cur.fetchone()
        if not record:
            return None
        # cur.description lists the result columns in SELECT order; element 0
        # of each entry is the column name.
        names = (column[0] for column in cur.description)
        return dict(zip(names, record))
|