Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,79 @@
|
||||
"""数据库操作模块"""
|
||||
|
||||
import psycopg
|
||||
from datetime import datetime, timezone
|
||||
from config import config
|
||||
|
||||
|
||||
def get_connection():
|
||||
"""获取数据库连接"""
|
||||
return psycopg.connect(config.DATABASE_URL)
|
||||
|
||||
|
||||
def update_task_status(
|
||||
task_id: str,
|
||||
status: str,
|
||||
progress: int = None,
|
||||
status_message: str = None,
|
||||
error_message: str = None,
|
||||
output_file: str = None,
|
||||
page_count: int = None,
|
||||
project_path: str = None,
|
||||
):
|
||||
"""更新任务状态"""
|
||||
fields = ["status = %(status)s", "updated_at = NOW()"]
|
||||
params = {"task_id": task_id, "status": status}
|
||||
|
||||
if progress is not None:
|
||||
fields.append("progress = %(progress)s")
|
||||
params["progress"] = progress
|
||||
|
||||
if status_message is not None:
|
||||
fields.append("status_message = %(status_message)s")
|
||||
params["status_message"] = status_message
|
||||
|
||||
if error_message is not None:
|
||||
fields.append("error_message = %(error_message)s")
|
||||
params["error_message"] = error_message
|
||||
|
||||
if output_file is not None:
|
||||
fields.append("output_file = %(output_file)s")
|
||||
params["output_file"] = output_file
|
||||
|
||||
if page_count is not None:
|
||||
fields.append("page_count = %(page_count)s")
|
||||
params["page_count"] = page_count
|
||||
|
||||
if project_path is not None:
|
||||
fields.append("project_path = %(project_path)s")
|
||||
params["project_path"] = project_path
|
||||
|
||||
if status == "processing":
|
||||
fields.append("started_at = NOW()")
|
||||
elif status in ("completed", "failed"):
|
||||
fields.append("completed_at = NOW()")
|
||||
|
||||
sql = f"UPDATE ppt_tasks SET {', '.join(fields)} WHERE id = %(task_id)s"
|
||||
|
||||
with get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, params)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def get_task(task_id: str) -> dict | None:
|
||||
"""获取任务详情"""
|
||||
with get_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT id, user_id, title, source_type, source_content, source_file, "
|
||||
"config, status, progress, status_message, error_message, "
|
||||
"output_file, page_count, project_path, created_at "
|
||||
"FROM ppt_tasks WHERE id = %s",
|
||||
(task_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
cols = [desc[0] for desc in cur.description]
|
||||
return dict(zip(cols, row))
|
||||
Reference in New Issue
Block a user