312 lines
8.6 KiB
Go
312 lines
8.6 KiB
Go
package handler
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"mime"
	"mime/multipart"
	"net/http"
	"path/filepath"
	"time"

	"github.com/enterprise-ai-platform/server/internal/middleware"
	"github.com/enterprise-ai-platform/server/internal/response"
	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/redis/go-redis/v9"
)
|
|
|
|
type PPTHandler struct {
|
|
pool *pgxpool.Pool
|
|
rdb *redis.Client
|
|
workerURL string
|
|
}
|
|
|
|
func NewPPTHandler(pool *pgxpool.Pool, rdb *redis.Client, workerURL string) *PPTHandler {
|
|
return &PPTHandler{pool: pool, rdb: rdb, workerURL: workerURL}
|
|
}
|
|
|
|
// ==================== 请求/响应结构 ====================
|
|
|
|
// createPPTRequest is the JSON body accepted by CreateTask.
type createPPTRequest struct {
	// Title is the presentation title; CreateTask rejects an empty value.
	Title string `json:"title"`

	// SourceType says how SourceContent is interpreted ("text" or "url").
	// CreateTask substitutes "text" when it is empty.
	SourceType string `json:"source_type"`

	// SourceContent is the raw text, or the URL when SourceType is "url".
	SourceContent string `json:"source_content"`

	// Config holds free-form generation options, persisted as JSON with the task.
	Config map[string]any `json:"config"`
}
|
|
|
|
// pptTaskResponse is the JSON view of a task returned by GetTaskStatus and
// ListTasks. Optional fields are pointers and are left out of the JSON when nil.
type pptTaskResponse struct {
	TaskID   string `json:"task_id"`
	Status   string `json:"status"`
	Progress int    `json:"progress"`

	StatusMessage *string `json:"status_message,omitempty"`
	ErrorMessage  *string `json:"error_message,omitempty"`
	OutputFile    *string `json:"output_file,omitempty"`
	PageCount     *int    `json:"page_count,omitempty"`

	Title     string `json:"title"`
	CreatedAt string `json:"created_at"` // handlers fill this with time.RFC3339 formatting
}
|
|
|
|
// ==================== 接口实现 ====================
|
|
|
|
// CreateTask 创建 PPT 生成任务(文本/URL 输入)
|
|
func (h *PPTHandler) CreateTask(w http.ResponseWriter, r *http.Request) {
|
|
userID := middleware.GetUserID(r.Context())
|
|
|
|
var req createPPTRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
response.BadRequest(w, "无效的请求格式")
|
|
return
|
|
}
|
|
|
|
if req.Title == "" {
|
|
response.BadRequest(w, "标题不能为空")
|
|
return
|
|
}
|
|
if req.SourceType == "" {
|
|
req.SourceType = "text"
|
|
}
|
|
if req.SourceContent == "" {
|
|
response.BadRequest(w, "请提供源内容")
|
|
return
|
|
}
|
|
|
|
taskID := uuid.New().String()
|
|
configJSON, _ := json.Marshal(req.Config)
|
|
|
|
// 写入数据库
|
|
_, err := h.pool.Exec(r.Context(),
|
|
`INSERT INTO ppt_tasks (id, user_id, title, source_type, source_content, config)
|
|
VALUES ($1, $2, $3, $4, $5, $6)`,
|
|
taskID, userID, req.Title, req.SourceType, req.SourceContent, configJSON,
|
|
)
|
|
if err != nil {
|
|
response.InternalError(w, "创建任务失败: "+err.Error())
|
|
return
|
|
}
|
|
|
|
// 推送到 Redis 队列
|
|
taskMsg, _ := json.Marshal(map[string]string{"task_id": taskID})
|
|
h.rdb.LPush(r.Context(), "ppt:tasks", taskMsg)
|
|
|
|
response.JSON(w, http.StatusCreated, map[string]string{
|
|
"task_id": taskID,
|
|
"status": "pending",
|
|
})
|
|
}
|
|
|
|
// CreateTaskWithFile 创建带文件上传的 PPT 生成任务
|
|
func (h *PPTHandler) CreateTaskWithFile(w http.ResponseWriter, r *http.Request) {
|
|
userID := middleware.GetUserID(r.Context())
|
|
|
|
if err := r.ParseMultipartForm(50 << 20); err != nil { // 50MB 限制
|
|
response.BadRequest(w, "文件过大或请求格式错误")
|
|
return
|
|
}
|
|
|
|
title := r.FormValue("title")
|
|
if title == "" {
|
|
response.BadRequest(w, "标题不能为空")
|
|
return
|
|
}
|
|
|
|
configStr := r.FormValue("config")
|
|
var taskConfig map[string]any
|
|
if configStr != "" {
|
|
json.Unmarshal([]byte(configStr), &taskConfig)
|
|
}
|
|
if taskConfig == nil {
|
|
taskConfig = map[string]any{}
|
|
}
|
|
|
|
file, header, err := r.FormFile("file")
|
|
if err != nil {
|
|
response.BadRequest(w, "请上传文件")
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
// 将文件转发到 PPT Worker
|
|
taskID := uuid.New().String()
|
|
configJSON, _ := json.Marshal(taskConfig)
|
|
|
|
// 转发文件到 Worker 服务
|
|
err = h.forwardFileToWorker(r.Context(), taskID, userID.String(), title, string(configJSON), file, header)
|
|
if err != nil {
|
|
response.InternalError(w, "提交任务失败: "+err.Error())
|
|
return
|
|
}
|
|
|
|
response.JSON(w, http.StatusCreated, map[string]string{
|
|
"task_id": taskID,
|
|
"status": "pending",
|
|
})
|
|
}
|
|
|
|
// GetTaskStatus 查询任务状态
|
|
func (h *PPTHandler) GetTaskStatus(w http.ResponseWriter, r *http.Request) {
|
|
taskID := chi.URLParam(r, "taskId")
|
|
userID := middleware.GetUserID(r.Context())
|
|
|
|
// 先从 Redis 快速查询
|
|
key := "ppt:status:" + taskID
|
|
cached, err := h.rdb.HGetAll(r.Context(), key).Result()
|
|
if err == nil && len(cached) > 0 {
|
|
progress := 0
|
|
fmt.Sscanf(cached["progress"], "%d", &progress)
|
|
msg := cached["message"]
|
|
response.JSON(w, http.StatusOK, pptTaskResponse{
|
|
TaskID: taskID,
|
|
Status: cached["status"],
|
|
Progress: progress,
|
|
StatusMessage: &msg,
|
|
})
|
|
return
|
|
}
|
|
|
|
// 回退到数据库
|
|
var task pptTaskResponse
|
|
var statusMsg, errMsg, outputFile *string
|
|
var pageCount *int
|
|
var createdAt time.Time
|
|
|
|
err = h.pool.QueryRow(r.Context(),
|
|
`SELECT id, status, progress, status_message, error_message, output_file, page_count, title, created_at
|
|
FROM ppt_tasks WHERE id = $1 AND user_id = $2`, taskID, userID,
|
|
).Scan(&task.TaskID, &task.Status, &task.Progress, &statusMsg, &errMsg, &outputFile, &pageCount, &task.Title, &createdAt)
|
|
if err != nil {
|
|
response.NotFound(w, "任务不存在")
|
|
return
|
|
}
|
|
|
|
task.StatusMessage = statusMsg
|
|
task.ErrorMessage = errMsg
|
|
task.OutputFile = outputFile
|
|
task.PageCount = pageCount
|
|
task.CreatedAt = createdAt.Format(time.RFC3339)
|
|
|
|
response.JSON(w, http.StatusOK, task)
|
|
}
|
|
|
|
// ListTasks 列出用户的 PPT 任务
|
|
func (h *PPTHandler) ListTasks(w http.ResponseWriter, r *http.Request) {
|
|
userID := middleware.GetUserID(r.Context())
|
|
|
|
rows, err := h.pool.Query(r.Context(),
|
|
`SELECT id, status, progress, status_message, error_message, output_file, page_count, title, created_at
|
|
FROM ppt_tasks WHERE user_id = $1 ORDER BY created_at DESC LIMIT 50`, userID,
|
|
)
|
|
if err != nil {
|
|
response.InternalError(w, "查询失败")
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tasks []pptTaskResponse
|
|
for rows.Next() {
|
|
var t pptTaskResponse
|
|
var statusMsg, errMsg, outputFile *string
|
|
var pageCount *int
|
|
var createdAt time.Time
|
|
|
|
if err := rows.Scan(&t.TaskID, &t.Status, &t.Progress, &statusMsg, &errMsg, &outputFile, &pageCount, &t.Title, &createdAt); err != nil {
|
|
continue
|
|
}
|
|
t.StatusMessage = statusMsg
|
|
t.ErrorMessage = errMsg
|
|
t.OutputFile = outputFile
|
|
t.PageCount = pageCount
|
|
t.CreatedAt = createdAt.Format(time.RFC3339)
|
|
tasks = append(tasks, t)
|
|
}
|
|
|
|
if tasks == nil {
|
|
tasks = []pptTaskResponse{}
|
|
}
|
|
response.JSON(w, http.StatusOK, tasks)
|
|
}
|
|
|
|
// DownloadTask 下载生成的 PPTX 文件
|
|
func (h *PPTHandler) DownloadTask(w http.ResponseWriter, r *http.Request) {
|
|
taskID := chi.URLParam(r, "taskId")
|
|
userID := middleware.GetUserID(r.Context())
|
|
|
|
var status, title string
|
|
var outputFile *string
|
|
err := h.pool.QueryRow(r.Context(),
|
|
`SELECT status, title, output_file FROM ppt_tasks WHERE id = $1 AND user_id = $2`,
|
|
taskID, userID,
|
|
).Scan(&status, &title, &outputFile)
|
|
if err != nil {
|
|
response.NotFound(w, "任务不存在")
|
|
return
|
|
}
|
|
|
|
if status != "completed" {
|
|
response.BadRequest(w, "任务未完成")
|
|
return
|
|
}
|
|
|
|
// 代理下载请求到 Worker 服务
|
|
workerResp, err := http.Get(h.workerURL + "/api/tasks/" + taskID + "/download")
|
|
if err != nil {
|
|
response.InternalError(w, "下载服务不可用")
|
|
return
|
|
}
|
|
defer workerResp.Body.Close()
|
|
|
|
if workerResp.StatusCode != http.StatusOK {
|
|
response.InternalError(w, "文件下载失败")
|
|
return
|
|
}
|
|
|
|
filename := title + ".pptx"
|
|
w.Header().Set("Content-Type", "application/vnd.openxmlformats-officedocument.presentationml.presentation")
|
|
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filename))
|
|
io.Copy(w, workerResp.Body)
|
|
}
|
|
|
|
// ==================== 内部方法 ====================
|
|
|
|
func (h *PPTHandler) forwardFileToWorker(ctx context.Context, taskID, userID, title, configJSON string, file multipart.File, header *multipart.FileHeader) error {
|
|
// 构造 multipart 请求转发到 Worker
|
|
var buf bytes.Buffer
|
|
writer := multipart.NewWriter(&buf)
|
|
|
|
writer.WriteField("task_id", taskID)
|
|
writer.WriteField("user_id", userID)
|
|
writer.WriteField("title", title)
|
|
writer.WriteField("config_json", configJSON)
|
|
|
|
part, err := writer.CreateFormFile("file", filepath.Base(header.Filename))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err := io.Copy(part, file); err != nil {
|
|
return err
|
|
}
|
|
writer.Close()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", h.workerURL+"/api/tasks/upload", &buf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
|
|
|
client := &http.Client{Timeout: 30 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("worker 返回错误 %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
return nil
|
|
}
|