package handler

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"mime"
	"mime/multipart"
	"net/http"
	"path/filepath"
	"strconv"
	"time"

	"github.com/enterprise-ai-platform/server/internal/middleware"
	"github.com/enterprise-ai-platform/server/internal/response"
	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/redis/go-redis/v9"
)

// PPTHandler serves the PPT generation endpoints. Task rows are read from and
// written to the ppt_tasks table, text/URL tasks are queued through Redis, and
// file uploads and downloads are proxied to the PPT worker at workerURL.
type PPTHandler struct {
	pool      *pgxpool.Pool
	rdb       *redis.Client
	workerURL string
}

// NewPPTHandler returns a PPTHandler backed by the given Postgres pool, Redis
// client and worker base URL. workerURL is used as a plain prefix for paths
// such as "/api/tasks/upload", so it should not end with a slash.
func NewPPTHandler(pool *pgxpool.Pool, rdb *redis.Client, workerURL string) *PPTHandler {
	return &PPTHandler{pool: pool, rdb: rdb, workerURL: workerURL}
}

const (
	// pptTaskQueueKey is the Redis list CreateTask pushes new task IDs onto.
	pptTaskQueueKey = "ppt:tasks"
	// pptStatusKeyPrefix prefixes the Redis hash holding a task's live status.
	pptStatusKeyPrefix = "ppt:status:"

	// pptMaxUploadSize is the largest file CreateTaskWithFile accepts. It is
	// also passed to ParseMultipartForm as the in-memory threshold.
	pptMaxUploadSize = 50 << 20
	// pptMultipartOverhead is extra request-body allowance for the form fields
	// and multipart boundaries surrounding the uploaded file.
	pptMultipartOverhead = 1 << 20
	// pptMaxWorkerErrBody caps how much of a worker error response is read;
	// that text ends up in the message returned to the client.
	pptMaxWorkerErrBody = 4 << 10
	// pptListLimit is the maximum number of tasks ListTasks returns.
	pptListLimit = 50
	// pptCleanupTimeout bounds the best-effort delete of an unqueued task.
	pptCleanupTimeout = 5 * time.Second

	// pptTaskColumns is the column list scanPPTTask expects, shared by
	// GetTaskStatus and ListTasks so the query and the scan cannot drift.
	pptTaskColumns = `id, status, progress, status_message, error_message, output_file, page_count, title, created_at`
)

// pptWorkerUploadClient is shared by all uploads instead of being rebuilt per
// request; the 30s timeout covers the whole upload round trip.
var pptWorkerUploadClient = &http.Client{Timeout: 30 * time.Second}

// ==================== Request / response types ====================

type createPPTRequest struct {
	Title         string         `json:"title"`
	SourceType    string         `json:"source_type"`    // "text" (default when empty) or "url"
	SourceContent string         `json:"source_content"` // the text itself, or the URL
	Config        map[string]any `json:"config"`
}

type pptTaskResponse struct {
	TaskID        string  `json:"task_id"`
	Status        string  `json:"status"`
	Progress      int     `json:"progress"`
	StatusMessage *string `json:"status_message,omitempty"`
	ErrorMessage  *string `json:"error_message,omitempty"`
	OutputFile    *string `json:"output_file,omitempty"`
	PageCount     *int    `json:"page_count,omitempty"`
	Title         string  `json:"title"`
	CreatedAt     string  `json:"created_at"`
}

// pptRowScanner is the Scan method shared by pgx's Row and Rows.
type pptRowScanner interface {
	Scan(dest ...any) error
}

// scanPPTTask scans one row selected with pptTaskColumns into a response,
// formatting created_at as RFC 3339. NULL columns become nil pointers.
func scanPPTTask(row pptRowScanner) (pptTaskResponse, error) {
	var t pptTaskResponse
	var createdAt time.Time
	if err := row.Scan(&t.TaskID, &t.Status, &t.Progress, &t.StatusMessage, &t.ErrorMessage,
		&t.OutputFile, &t.PageCount, &t.Title, &createdAt); err != nil {
		return pptTaskResponse{}, err
	}
	t.CreatedAt = createdAt.Format(time.RFC3339)
	return t, nil
}

// ==================== Handlers ====================

// CreateTask creates a PPT generation task from text or URL input.
//
// It inserts a row into ppt_tasks and pushes {"task_id": ...} onto the Redis
// queue, then replies 201 with the task ID and status "pending". Missing
// title/content yields 400; database or queue failures yield 500.
func (h *PPTHandler) CreateTask(w http.ResponseWriter, r *http.Request) {
	userID := middleware.GetUserID(r.Context())

	var req createPPTRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}
	if req.Title == "" {
		response.BadRequest(w, "标题不能为空")
		return
	}
	if req.SourceType == "" {
		req.SourceType = "text"
	}
	if req.SourceContent == "" {
		response.BadRequest(w, "请提供源内容")
		return
	}
	// Store {} rather than JSON null when no config is sent, matching
	// CreateTaskWithFile.
	if req.Config == nil {
		req.Config = map[string]any{}
	}

	taskID := uuid.New().String()
	configJSON, err := json.Marshal(req.Config)
	if err != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}

	_, err = h.pool.Exec(r.Context(),
		`INSERT INTO ppt_tasks (id, user_id, title, source_type, source_content, config)
		 VALUES ($1, $2, $3, $4, $5, $6)`,
		taskID, userID, req.Title, req.SourceType, req.SourceContent, configJSON,
	)
	if err != nil {
		response.InternalError(w, "创建任务失败: "+err.Error())
		return
	}

	// Enqueue for the worker. If the push fails the row would never be
	// picked up, so remove it and report the failure instead of a fake 201.
	taskMsg, _ := json.Marshal(map[string]string{"task_id": taskID}) // cannot fail for map[string]string
	if err := h.rdb.LPush(r.Context(), pptTaskQueueKey, taskMsg).Err(); err != nil {
		h.discardTask(taskID)
		response.InternalError(w, "创建任务失败: "+err.Error())
		return
	}

	response.JSON(w, http.StatusCreated, map[string]string{
		"task_id": taskID,
		"status":  "pending",
	})
}

// discardTask best-effort deletes a task row whose enqueue failed. It uses a
// fresh, bounded context so cleanup still runs if the request context is done.
func (h *PPTHandler) discardTask(taskID string) {
	ctx, cancel := context.WithTimeout(context.Background(), pptCleanupTimeout)
	defer cancel()
	// Best effort: on failure an orphaned row remains, but the client is
	// still told that task creation failed.
	_, _ = h.pool.Exec(ctx, `DELETE FROM ppt_tasks WHERE id = $1`, taskID)
}

// CreateTaskWithFile creates a PPT generation task from an uploaded file.
//
// Multipart fields: "title" (required), "config" (optional JSON object) and
// "file" (required, at most pptMaxUploadSize bytes). The file and metadata are
// forwarded to the worker's /api/tasks/upload endpoint. This handler does not
// write ppt_tasks itself — NOTE(review): presumably the worker creates the row
// from the forwarded user_id; confirm.
func (h *PPTHandler) CreateTaskWithFile(w http.ResponseWriter, r *http.Request) {
	userID := middleware.GetUserID(r.Context())

	// ParseMultipartForm's argument only bounds memory use (bigger parts spill
	// to temp files); MaxBytesReader is what actually caps the request size.
	r.Body = http.MaxBytesReader(w, r.Body, pptMaxUploadSize+pptMultipartOverhead)
	if err := r.ParseMultipartForm(pptMaxUploadSize); err != nil {
		response.BadRequest(w, "文件过大或请求格式错误")
		return
	}

	title := r.FormValue("title")
	if title == "" {
		response.BadRequest(w, "标题不能为空")
		return
	}

	taskConfig := map[string]any{}
	if configStr := r.FormValue("config"); configStr != "" {
		if err := json.Unmarshal([]byte(configStr), &taskConfig); err != nil {
			response.BadRequest(w, "配置格式错误")
			return
		}
		// A literal "null" resets the map to nil; normalize back to {}.
		if taskConfig == nil {
			taskConfig = map[string]any{}
		}
	}

	file, header, err := r.FormFile("file")
	if err != nil {
		response.BadRequest(w, "请上传文件")
		return
	}
	defer file.Close()

	// The body allowance includes multipart overhead, so enforce the exact
	// per-file limit here as well.
	if header.Size > pptMaxUploadSize {
		response.BadRequest(w, "文件过大或请求格式错误")
		return
	}

	taskID := uuid.New().String()
	configJSON, err := json.Marshal(taskConfig)
	if err != nil {
		response.BadRequest(w, "配置格式错误")
		return
	}

	err = h.forwardFileToWorker(r.Context(), taskID, userID.String(), title, string(configJSON), file, header)
	if err != nil {
		response.InternalError(w, "提交任务失败: "+err.Error())
		return
	}

	response.JSON(w, http.StatusCreated, map[string]string{
		"task_id": taskID,
		"status":  "pending",
	})
}

// GetTaskStatus returns the status of one of the caller's tasks.
//
// Ownership is always verified against ppt_tasks (id + user_id); an unknown
// task or one owned by another user yields 404. When the Redis hash
// "ppt:status:<taskId>" exists, its status/progress/message fields override
// the database values, since Redis was previously preferred over the database.
// NOTE(review): a task present only in Redis (no DB row yet) now reports 404.
func (h *PPTHandler) GetTaskStatus(w http.ResponseWriter, r *http.Request) {
	taskID := chi.URLParam(r, "taskId")
	userID := middleware.GetUserID(r.Context())

	row := h.pool.QueryRow(r.Context(),
		`SELECT `+pptTaskColumns+` FROM ppt_tasks WHERE id = $1 AND user_id = $2`,
		taskID, userID,
	)
	task, err := scanPPTTask(row)
	if err != nil {
		response.NotFound(w, "任务不存在")
		return
	}

	// Overlay live progress from Redis when available. Redis errors are
	// ignored and the database values are returned as-is.
	cached, err := h.rdb.HGetAll(r.Context(), pptStatusKeyPrefix+taskID).Result()
	if err == nil && len(cached) > 0 {
		if status, ok := cached["status"]; ok {
			task.Status = status
		}
		var progress int
		if _, scanErr := fmt.Sscanf(cached["progress"], "%d", &progress); scanErr == nil {
			task.Progress = progress
		}
		if msg, ok := cached["message"]; ok {
			task.StatusMessage = &msg
		}
	}

	response.JSON(w, http.StatusOK, task)
}

// ListTasks returns the caller's most recent tasks (newest first, at most
// pptListLimit). An empty result is encoded as [] rather than null.
func (h *PPTHandler) ListTasks(w http.ResponseWriter, r *http.Request) {
	userID := middleware.GetUserID(r.Context())

	rows, err := h.pool.Query(r.Context(),
		`SELECT `+pptTaskColumns+` FROM ppt_tasks WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2`,
		userID, pptListLimit,
	)
	if err != nil {
		response.InternalError(w, "查询失败")
		return
	}
	defer rows.Close()

	tasks := make([]pptTaskResponse, 0, pptListLimit)
	for rows.Next() {
		t, err := scanPPTTask(rows)
		if err != nil {
			// Skipping would hide schema/type mismatches, and pgx stops the
			// result set after a failed Scan anyway, truncating the list.
			response.InternalError(w, "查询失败")
			return
		}
		tasks = append(tasks, t)
	}
	// rows.Err reports errors that ended iteration early (e.g. a dropped
	// connection); without it a truncated list would look complete.
	if err := rows.Err(); err != nil {
		response.InternalError(w, "查询失败")
		return
	}

	response.JSON(w, http.StatusOK, tasks)
}

// DownloadTask streams the generated PPTX for one of the caller's completed
// tasks, proxied from the worker's /api/tasks/{id}/download endpoint. The
// download filename is the task title plus ".pptx".
func (h *PPTHandler) DownloadTask(w http.ResponseWriter, r *http.Request) {
	taskID := chi.URLParam(r, "taskId")
	userID := middleware.GetUserID(r.Context())

	var status, title string
	err := h.pool.QueryRow(r.Context(),
		`SELECT status, title FROM ppt_tasks WHERE id = $1 AND user_id = $2`,
		taskID, userID,
	).Scan(&status, &title)
	if err != nil {
		response.NotFound(w, "任务不存在")
		return
	}
	if status != "completed" {
		response.BadRequest(w, "任务未完成")
		return
	}

	// Tie the worker request to the client request so it is abandoned when
	// the client disconnects (plain http.Get had no context or timeout).
	req, err := http.NewRequestWithContext(r.Context(), http.MethodGet,
		h.workerURL+"/api/tasks/"+taskID+"/download", nil)
	if err != nil {
		response.InternalError(w, "下载服务不可用")
		return
	}
	workerResp, err := http.DefaultClient.Do(req)
	if err != nil {
		response.InternalError(w, "下载服务不可用")
		return
	}
	defer workerResp.Body.Close()

	if workerResp.StatusCode != http.StatusOK {
		response.InternalError(w, "文件下载失败")
		return
	}

	// FormatMediaType quotes/escapes the title and switches to the RFC 2231
	// filename* form for non-ASCII titles, which a raw filename="%s" cannot
	// carry safely (titles containing quotes broke the header).
	disposition := mime.FormatMediaType("attachment", map[string]string{"filename": title + ".pptx"})
	if disposition == "" {
		disposition = "attachment"
	}
	w.Header().Set("Content-Type", "application/vnd.openxmlformats-officedocument.presentationml.presentation")
	w.Header().Set("Content-Disposition", disposition)
	if workerResp.ContentLength >= 0 {
		w.Header().Set("Content-Length", strconv.FormatInt(workerResp.ContentLength, 10))
	}

	// Headers are committed once copying starts, so a mid-stream error
	// cannot be reported to the client anymore.
	_, _ = io.Copy(w, workerResp.Body)
}

// ==================== Internal helpers ====================

// forwardFileToWorker re-encodes the upload plus task metadata (task_id,
// user_id, title, config_json) as a multipart form and POSTs it to the
// worker's /api/tasks/upload endpoint. Any status other than 200 is returned
// as an error that includes (a bounded prefix of) the worker's response body.
func (h *PPTHandler) forwardFileToWorker(ctx context.Context, taskID, userID, title, configJSON string, file multipart.File, header *multipart.FileHeader) error {
	var body bytes.Buffer
	mw := multipart.NewWriter(&body)

	fields := [...][2]string{
		{"task_id", taskID},
		{"user_id", userID},
		{"title", title},
		{"config_json", configJSON},
	}
	for _, f := range fields {
		if err := mw.WriteField(f[0], f[1]); err != nil {
			return err
		}
	}

	part, err := mw.CreateFormFile("file", filepath.Base(header.Filename))
	if err != nil {
		return err
	}
	if _, err := io.Copy(part, file); err != nil {
		return err
	}
	// Close writes the terminating boundary; without it the body is malformed.
	if err := mw.Close(); err != nil {
		return err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.workerURL+"/api/tasks/upload", &body)
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", mw.FormDataContentType())

	resp, err := pptWorkerUploadClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Bounded read: this text is echoed back to the API client.
		msg, _ := io.ReadAll(io.LimitReader(resp.Body, pptMaxWorkerErrBody))
		return fmt.Errorf("worker 返回错误 %d: %s", resp.StatusCode, string(msg))
	}
	return nil
}