Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,311 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type PPTHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
rdb *redis.Client
|
||||
workerURL string
|
||||
}
|
||||
|
||||
func NewPPTHandler(pool *pgxpool.Pool, rdb *redis.Client, workerURL string) *PPTHandler {
|
||||
return &PPTHandler{pool: pool, rdb: rdb, workerURL: workerURL}
|
||||
}
|
||||
|
||||
// ==================== 请求/响应结构 ====================
|
||||
|
||||
type createPPTRequest struct {
|
||||
Title string `json:"title"`
|
||||
SourceType string `json:"source_type"` // text / url
|
||||
SourceContent string `json:"source_content"` // 文本内容或 URL
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
type pptTaskResponse struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
Progress int `json:"progress"`
|
||||
StatusMessage *string `json:"status_message,omitempty"`
|
||||
ErrorMessage *string `json:"error_message,omitempty"`
|
||||
OutputFile *string `json:"output_file,omitempty"`
|
||||
PageCount *int `json:"page_count,omitempty"`
|
||||
Title string `json:"title"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// ==================== 接口实现 ====================
|
||||
|
||||
// CreateTask 创建 PPT 生成任务(文本/URL 输入)
|
||||
func (h *PPTHandler) CreateTask(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req createPPTRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Title == "" {
|
||||
response.BadRequest(w, "标题不能为空")
|
||||
return
|
||||
}
|
||||
if req.SourceType == "" {
|
||||
req.SourceType = "text"
|
||||
}
|
||||
if req.SourceContent == "" {
|
||||
response.BadRequest(w, "请提供源内容")
|
||||
return
|
||||
}
|
||||
|
||||
taskID := uuid.New().String()
|
||||
configJSON, _ := json.Marshal(req.Config)
|
||||
|
||||
// 写入数据库
|
||||
_, err := h.pool.Exec(r.Context(),
|
||||
`INSERT INTO ppt_tasks (id, user_id, title, source_type, source_content, config)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
taskID, userID, req.Title, req.SourceType, req.SourceContent, configJSON,
|
||||
)
|
||||
if err != nil {
|
||||
response.InternalError(w, "创建任务失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 推送到 Redis 队列
|
||||
taskMsg, _ := json.Marshal(map[string]string{"task_id": taskID})
|
||||
h.rdb.LPush(r.Context(), "ppt:tasks", taskMsg)
|
||||
|
||||
response.JSON(w, http.StatusCreated, map[string]string{
|
||||
"task_id": taskID,
|
||||
"status": "pending",
|
||||
})
|
||||
}
|
||||
|
||||
// CreateTaskWithFile 创建带文件上传的 PPT 生成任务
|
||||
func (h *PPTHandler) CreateTaskWithFile(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
if err := r.ParseMultipartForm(50 << 20); err != nil { // 50MB 限制
|
||||
response.BadRequest(w, "文件过大或请求格式错误")
|
||||
return
|
||||
}
|
||||
|
||||
title := r.FormValue("title")
|
||||
if title == "" {
|
||||
response.BadRequest(w, "标题不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
configStr := r.FormValue("config")
|
||||
var taskConfig map[string]any
|
||||
if configStr != "" {
|
||||
json.Unmarshal([]byte(configStr), &taskConfig)
|
||||
}
|
||||
if taskConfig == nil {
|
||||
taskConfig = map[string]any{}
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
response.BadRequest(w, "请上传文件")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 将文件转发到 PPT Worker
|
||||
taskID := uuid.New().String()
|
||||
configJSON, _ := json.Marshal(taskConfig)
|
||||
|
||||
// 转发文件到 Worker 服务
|
||||
err = h.forwardFileToWorker(r.Context(), taskID, userID.String(), title, string(configJSON), file, header)
|
||||
if err != nil {
|
||||
response.InternalError(w, "提交任务失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusCreated, map[string]string{
|
||||
"task_id": taskID,
|
||||
"status": "pending",
|
||||
})
|
||||
}
|
||||
|
||||
// GetTaskStatus 查询任务状态
|
||||
func (h *PPTHandler) GetTaskStatus(w http.ResponseWriter, r *http.Request) {
|
||||
taskID := chi.URLParam(r, "taskId")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
// 先从 Redis 快速查询
|
||||
key := "ppt:status:" + taskID
|
||||
cached, err := h.rdb.HGetAll(r.Context(), key).Result()
|
||||
if err == nil && len(cached) > 0 {
|
||||
progress := 0
|
||||
fmt.Sscanf(cached["progress"], "%d", &progress)
|
||||
msg := cached["message"]
|
||||
response.JSON(w, http.StatusOK, pptTaskResponse{
|
||||
TaskID: taskID,
|
||||
Status: cached["status"],
|
||||
Progress: progress,
|
||||
StatusMessage: &msg,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 回退到数据库
|
||||
var task pptTaskResponse
|
||||
var statusMsg, errMsg, outputFile *string
|
||||
var pageCount *int
|
||||
var createdAt time.Time
|
||||
|
||||
err = h.pool.QueryRow(r.Context(),
|
||||
`SELECT id, status, progress, status_message, error_message, output_file, page_count, title, created_at
|
||||
FROM ppt_tasks WHERE id = $1 AND user_id = $2`, taskID, userID,
|
||||
).Scan(&task.TaskID, &task.Status, &task.Progress, &statusMsg, &errMsg, &outputFile, &pageCount, &task.Title, &createdAt)
|
||||
if err != nil {
|
||||
response.NotFound(w, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
task.StatusMessage = statusMsg
|
||||
task.ErrorMessage = errMsg
|
||||
task.OutputFile = outputFile
|
||||
task.PageCount = pageCount
|
||||
task.CreatedAt = createdAt.Format(time.RFC3339)
|
||||
|
||||
response.JSON(w, http.StatusOK, task)
|
||||
}
|
||||
|
||||
// ListTasks 列出用户的 PPT 任务
|
||||
func (h *PPTHandler) ListTasks(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
rows, err := h.pool.Query(r.Context(),
|
||||
`SELECT id, status, progress, status_message, error_message, output_file, page_count, title, created_at
|
||||
FROM ppt_tasks WHERE user_id = $1 ORDER BY created_at DESC LIMIT 50`, userID,
|
||||
)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tasks []pptTaskResponse
|
||||
for rows.Next() {
|
||||
var t pptTaskResponse
|
||||
var statusMsg, errMsg, outputFile *string
|
||||
var pageCount *int
|
||||
var createdAt time.Time
|
||||
|
||||
if err := rows.Scan(&t.TaskID, &t.Status, &t.Progress, &statusMsg, &errMsg, &outputFile, &pageCount, &t.Title, &createdAt); err != nil {
|
||||
continue
|
||||
}
|
||||
t.StatusMessage = statusMsg
|
||||
t.ErrorMessage = errMsg
|
||||
t.OutputFile = outputFile
|
||||
t.PageCount = pageCount
|
||||
t.CreatedAt = createdAt.Format(time.RFC3339)
|
||||
tasks = append(tasks, t)
|
||||
}
|
||||
|
||||
if tasks == nil {
|
||||
tasks = []pptTaskResponse{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, tasks)
|
||||
}
|
||||
|
||||
// DownloadTask 下载生成的 PPTX 文件
|
||||
func (h *PPTHandler) DownloadTask(w http.ResponseWriter, r *http.Request) {
|
||||
taskID := chi.URLParam(r, "taskId")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var status, title string
|
||||
var outputFile *string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT status, title, output_file FROM ppt_tasks WHERE id = $1 AND user_id = $2`,
|
||||
taskID, userID,
|
||||
).Scan(&status, &title, &outputFile)
|
||||
if err != nil {
|
||||
response.NotFound(w, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if status != "completed" {
|
||||
response.BadRequest(w, "任务未完成")
|
||||
return
|
||||
}
|
||||
|
||||
// 代理下载请求到 Worker 服务
|
||||
workerResp, err := http.Get(h.workerURL + "/api/tasks/" + taskID + "/download")
|
||||
if err != nil {
|
||||
response.InternalError(w, "下载服务不可用")
|
||||
return
|
||||
}
|
||||
defer workerResp.Body.Close()
|
||||
|
||||
if workerResp.StatusCode != http.StatusOK {
|
||||
response.InternalError(w, "文件下载失败")
|
||||
return
|
||||
}
|
||||
|
||||
filename := title + ".pptx"
|
||||
w.Header().Set("Content-Type", "application/vnd.openxmlformats-officedocument.presentationml.presentation")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filename))
|
||||
io.Copy(w, workerResp.Body)
|
||||
}
|
||||
|
||||
// ==================== 内部方法 ====================
|
||||
|
||||
func (h *PPTHandler) forwardFileToWorker(ctx context.Context, taskID, userID, title, configJSON string, file multipart.File, header *multipart.FileHeader) error {
|
||||
// 构造 multipart 请求转发到 Worker
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
|
||||
writer.WriteField("task_id", taskID)
|
||||
writer.WriteField("user_id", userID)
|
||||
writer.WriteField("title", title)
|
||||
writer.WriteField("config_json", configJSON)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filepath.Base(header.Filename))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
return err
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", h.workerURL+"/api/tasks/upload", &buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("worker 返回错误 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user