package handler import ( "bufio" "context" "encoding/json" "fmt" "io" "net/http" "strings" "time" "github.com/enterprise-ai-platform/server/internal/middleware" "github.com/enterprise-ai-platform/server/internal/response" "github.com/enterprise-ai-platform/server/pkg/dify" "github.com/go-chi/chi/v5" "github.com/jackc/pgx/v5/pgxpool" ) type ChatHandler struct { pool *pgxpool.Pool dify *dify.Client } func NewChatHandler(pool *pgxpool.Pool, difyClient *dify.Client) *ChatHandler { return &ChatHandler{pool: pool, dify: difyClient} } type chatRequest struct { Message string `json:"message"` ConversationID string `json:"conversation_id,omitempty"` Inputs map[string]any `json:"inputs,omitempty"` } func (h *ChatHandler) Chat(w http.ResponseWriter, r *http.Request) { appID := chi.URLParam(r, "id") userID := middleware.GetUserID(r.Context()) var req chatRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { response.BadRequest(w, "无效的请求格式") return } if req.Message == "" { response.BadRequest(w, "消息不能为空") return } // Get app's Dify API key var difyAPIKey string err := h.pool.QueryRow(r.Context(), `SELECT dify_api_key FROM applications WHERE id = $1 AND status = 'approved'`, appID, ).Scan(&difyAPIKey) if err != nil || difyAPIKey == "" { response.NotFound(w, "应用不存在或未上架") return } startTime := time.Now() // Call Dify streaming chat difyReq := &dify.ChatRequest{ Query: req.Message, Inputs: req.Inputs, ConversationID: req.ConversationID, User: userID.String(), ResponseMode: "streaming", } body, err := h.dify.ChatStream(r.Context(), difyAPIKey, difyReq) if err != nil { response.Error(w, http.StatusBadGateway, 50201, "Dify 服务不可用: "+err.Error()) return } defer body.Close() // Stream SSE to client w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") flusher, ok := w.(http.Flusher) if !ok { response.InternalError(w, "Streaming not supported") return } scanner := bufio.NewScanner(body) scanner.Buffer(make([]byte, 64*1024), 256*1024) var totalTokens int var modelName string var conversationID string for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimPrefix(line, "data: ") if data == "[DONE]" { fmt.Fprintf(w, "data: [DONE]\n\n") flusher.Flush() break } // Forward SSE event to client fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() // Parse for usage tracking var event map[string]any if err := json.Unmarshal([]byte(data), &event); err == nil { if cid, ok := event["conversation_id"].(string); ok && cid != "" { conversationID = cid } if event["event"] == "message_end" { if meta, ok := event["metadata"].(map[string]any); ok { if usage, ok := meta["usage"].(map[string]any); ok { if t, ok := usage["total_tokens"].(float64); ok { totalTokens = int(t) } if m, ok := usage["model"].(string); ok { modelName = m } } } } } } // Record usage asynchronously duration := time.Since(startTime).Milliseconds() go func() { _, _ = h.pool.Exec(context.Background(), ` INSERT INTO app_usage_logs (app_id, user_id, conversation_id, total_tokens, model_name, duration_ms, client_type) VALUES ($1, $2, $3, $4, $5, $6, 'web')`, appID, userID, conversationID, totalTokens, modelName, duration) _, _ = h.pool.Exec(context.Background(), `UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID) }() } func (h *ChatHandler) Completion(w http.ResponseWriter, r *http.Request) { appID := chi.URLParam(r, "id") userID := middleware.GetUserID(r.Context()) var req chatRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { response.BadRequest(w, "无效的请求格式") return } if req.Message == "" { response.BadRequest(w, "消息不能为空") return } var difyAPIKey string err := h.pool.QueryRow(r.Context(), `SELECT dify_api_key FROM applications WHERE id = $1 AND status = 'approved'`, appID, ).Scan(&difyAPIKey) if err != nil || difyAPIKey == "" { response.NotFound(w, "应用不存在或未上架") return } startTime := time.Now() difyReq := &dify.ChatRequest{ Query: req.Message, Inputs: req.Inputs, ConversationID: req.ConversationID, User: userID.String(), ResponseMode: "blocking", } result, err := h.dify.ChatBlocking(r.Context(), difyAPIKey, difyReq) if err != nil { response.Error(w, http.StatusBadGateway, 50201, "Dify 服务不可用: "+err.Error()) return } duration := time.Since(startTime).Milliseconds() go func() { _, _ = h.pool.Exec(context.Background(), ` INSERT INTO app_usage_logs (app_id, user_id, conversation_id, total_tokens, duration_ms, client_type) VALUES ($1, $2, $3, $4, $5, 'web')`, appID, userID, result.ConversationID, 0, duration) _, _ = h.pool.Exec(context.Background(), `UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID) }() response.JSON(w, http.StatusOK, result) } func (h *ChatHandler) Feedback(w http.ResponseWriter, r *http.Request) { appID := chi.URLParam(r, "id") userID := middleware.GetUserID(r.Context()) var req struct { MessageID string `json:"message_id"` Rating string `json:"rating"` // "like" | "dislike" | null Comment string `json:"comment"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { response.BadRequest(w, "无效的请求格式") return } if req.MessageID == "" { response.BadRequest(w, "message_id 不能为空") return } var difyAPIKey string err := h.pool.QueryRow(r.Context(), `SELECT dify_api_key FROM applications WHERE id = $1`, appID, ).Scan(&difyAPIKey) if err != nil || difyAPIKey == "" { response.NotFound(w, "应用不存在") return } feedbackReq := &dify.FeedbackRequest{ Rating: req.Rating, User: userID.String(), } if err := h.dify.SubmitFeedback(r.Context(), difyAPIKey, req.MessageID, feedbackReq); err != nil { response.Error(w, http.StatusBadGateway, 50201, "提交反馈失败: "+err.Error()) return } response.JSON(w, http.StatusOK, map[string]string{"message": "反馈已提交"}) } func (h *ChatHandler) Conversations(w http.ResponseWriter, r *http.Request) { appID := chi.URLParam(r, "id") userID := middleware.GetUserID(r.Context()) var difyAPIKey string err := h.pool.QueryRow(r.Context(), `SELECT dify_api_key FROM applications WHERE id = $1`, appID, ).Scan(&difyAPIKey) if err != nil || difyAPIKey == "" { response.NotFound(w, "应用不存在") return } result, err := h.dify.ListConversations(r.Context(), difyAPIKey, userID.String(), 20, "") if err != nil { response.Error(w, http.StatusBadGateway, 50201, "获取对话列表失败") return } response.JSON(w, http.StatusOK, result) } func (h *ChatHandler) Messages(w http.ResponseWriter, r *http.Request) { appID := chi.URLParam(r, "id") convID := chi.URLParam(r, "convId") userID := middleware.GetUserID(r.Context()) var difyAPIKey string err := h.pool.QueryRow(r.Context(), `SELECT dify_api_key FROM applications WHERE id = $1`, appID, ).Scan(&difyAPIKey) if err != nil || difyAPIKey == "" { response.NotFound(w, "应用不存在") return } result, err := h.dify.ListMessages(r.Context(), difyAPIKey, userID.String(), convID, 100, "") if err != nil { response.Error(w, http.StatusBadGateway, 50201, "获取消息列表失败") return } response.JSON(w, http.StatusOK, result) } func (h *ChatHandler) DeleteConversation(w http.ResponseWriter, r *http.Request) { appID := chi.URLParam(r, "id") convID := chi.URLParam(r, "convId") userID := middleware.GetUserID(r.Context()) var difyAPIKey string err := h.pool.QueryRow(r.Context(), `SELECT dify_api_key FROM applications WHERE id = $1`, appID, ).Scan(&difyAPIKey) if err != nil || difyAPIKey == "" { response.NotFound(w, "应用不存在") return } if err := h.dify.DeleteConversation(r.Context(), difyAPIKey, userID.String(), convID); err != nil { response.Error(w, http.StatusBadGateway, 50201, "删除对话失败") return } response.JSON(w, http.StatusOK, nil) } // Suppress unused import warnings var _ = io.EOF