308 lines
8.3 KiB
Go
308 lines
8.3 KiB
Go
package handler
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/enterprise-ai-platform/server/internal/middleware"
|
|
"github.com/enterprise-ai-platform/server/internal/response"
|
|
"github.com/enterprise-ai-platform/server/pkg/dify"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
type ChatHandler struct {
|
|
pool *pgxpool.Pool
|
|
dify *dify.Client
|
|
}
|
|
|
|
func NewChatHandler(pool *pgxpool.Pool, difyClient *dify.Client) *ChatHandler {
|
|
return &ChatHandler{pool: pool, dify: difyClient}
|
|
}
|
|
|
|
// chatRequest is the JSON request body accepted by Chat and Completion.
type chatRequest struct {
	// Message is the user's query text; the handlers reject an empty value.
	Message string `json:"message"`
	// ConversationID is forwarded to Dify unchanged; presumably an empty
	// value starts a new conversation (TODO confirm against Dify docs).
	ConversationID string `json:"conversation_id,omitempty"`
	// Inputs is passed through unchanged as the Dify request's inputs.
	Inputs map[string]any `json:"inputs,omitempty"`
}
|
|
|
|
func (h *ChatHandler) Chat(w http.ResponseWriter, r *http.Request) {
|
|
appID := chi.URLParam(r, "id")
|
|
userID := middleware.GetUserID(r.Context())
|
|
|
|
var req chatRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
response.BadRequest(w, "无效的请求格式")
|
|
return
|
|
}
|
|
if req.Message == "" {
|
|
response.BadRequest(w, "消息不能为空")
|
|
return
|
|
}
|
|
|
|
// Get app's Dify API key
|
|
var difyAPIKey string
|
|
err := h.pool.QueryRow(r.Context(),
|
|
`SELECT dify_api_key FROM applications WHERE id = $1 AND status = 'approved'`,
|
|
appID,
|
|
).Scan(&difyAPIKey)
|
|
if err != nil || difyAPIKey == "" {
|
|
response.NotFound(w, "应用不存在或未上架")
|
|
return
|
|
}
|
|
|
|
startTime := time.Now()
|
|
|
|
// Call Dify streaming chat
|
|
difyReq := &dify.ChatRequest{
|
|
Query: req.Message,
|
|
Inputs: req.Inputs,
|
|
ConversationID: req.ConversationID,
|
|
User: userID.String(),
|
|
ResponseMode: "streaming",
|
|
}
|
|
|
|
body, err := h.dify.ChatStream(r.Context(), difyAPIKey, difyReq)
|
|
if err != nil {
|
|
response.Error(w, http.StatusBadGateway, 50201, "Dify 服务不可用: "+err.Error())
|
|
return
|
|
}
|
|
defer body.Close()
|
|
|
|
// Stream SSE to client
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.Header().Set("X-Accel-Buffering", "no")
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
response.InternalError(w, "Streaming not supported")
|
|
return
|
|
}
|
|
|
|
scanner := bufio.NewScanner(body)
|
|
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
|
|
|
var totalTokens int
|
|
var modelName string
|
|
var conversationID string
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
flusher.Flush()
|
|
break
|
|
}
|
|
|
|
// Forward SSE event to client
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
flusher.Flush()
|
|
|
|
// Parse for usage tracking
|
|
var event map[string]any
|
|
if err := json.Unmarshal([]byte(data), &event); err == nil {
|
|
if cid, ok := event["conversation_id"].(string); ok && cid != "" {
|
|
conversationID = cid
|
|
}
|
|
if event["event"] == "message_end" {
|
|
if meta, ok := event["metadata"].(map[string]any); ok {
|
|
if usage, ok := meta["usage"].(map[string]any); ok {
|
|
if t, ok := usage["total_tokens"].(float64); ok {
|
|
totalTokens = int(t)
|
|
}
|
|
if m, ok := usage["model"].(string); ok {
|
|
modelName = m
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Record usage asynchronously
|
|
duration := time.Since(startTime).Milliseconds()
|
|
go func() {
|
|
_, _ = h.pool.Exec(context.Background(), `
|
|
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, total_tokens, model_name, duration_ms, client_type)
|
|
VALUES ($1, $2, $3, $4, $5, $6, 'web')`,
|
|
appID, userID, conversationID, totalTokens, modelName, duration)
|
|
|
|
_, _ = h.pool.Exec(context.Background(),
|
|
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
|
|
}()
|
|
}
|
|
|
|
func (h *ChatHandler) Completion(w http.ResponseWriter, r *http.Request) {
|
|
appID := chi.URLParam(r, "id")
|
|
userID := middleware.GetUserID(r.Context())
|
|
|
|
var req chatRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
response.BadRequest(w, "无效的请求格式")
|
|
return
|
|
}
|
|
if req.Message == "" {
|
|
response.BadRequest(w, "消息不能为空")
|
|
return
|
|
}
|
|
|
|
var difyAPIKey string
|
|
err := h.pool.QueryRow(r.Context(),
|
|
`SELECT dify_api_key FROM applications WHERE id = $1 AND status = 'approved'`,
|
|
appID,
|
|
).Scan(&difyAPIKey)
|
|
if err != nil || difyAPIKey == "" {
|
|
response.NotFound(w, "应用不存在或未上架")
|
|
return
|
|
}
|
|
|
|
startTime := time.Now()
|
|
difyReq := &dify.ChatRequest{
|
|
Query: req.Message,
|
|
Inputs: req.Inputs,
|
|
ConversationID: req.ConversationID,
|
|
User: userID.String(),
|
|
ResponseMode: "blocking",
|
|
}
|
|
|
|
result, err := h.dify.ChatBlocking(r.Context(), difyAPIKey, difyReq)
|
|
if err != nil {
|
|
response.Error(w, http.StatusBadGateway, 50201, "Dify 服务不可用: "+err.Error())
|
|
return
|
|
}
|
|
|
|
duration := time.Since(startTime).Milliseconds()
|
|
go func() {
|
|
_, _ = h.pool.Exec(context.Background(), `
|
|
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, total_tokens, duration_ms, client_type)
|
|
VALUES ($1, $2, $3, $4, $5, 'web')`,
|
|
appID, userID, result.ConversationID, 0, duration)
|
|
_, _ = h.pool.Exec(context.Background(),
|
|
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
|
|
}()
|
|
|
|
response.JSON(w, http.StatusOK, result)
|
|
}
|
|
|
|
func (h *ChatHandler) Feedback(w http.ResponseWriter, r *http.Request) {
|
|
appID := chi.URLParam(r, "id")
|
|
userID := middleware.GetUserID(r.Context())
|
|
|
|
var req struct {
|
|
MessageID string `json:"message_id"`
|
|
Rating string `json:"rating"` // "like" | "dislike" | null
|
|
Comment string `json:"comment"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
response.BadRequest(w, "无效的请求格式")
|
|
return
|
|
}
|
|
if req.MessageID == "" {
|
|
response.BadRequest(w, "message_id 不能为空")
|
|
return
|
|
}
|
|
|
|
var difyAPIKey string
|
|
err := h.pool.QueryRow(r.Context(),
|
|
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
|
).Scan(&difyAPIKey)
|
|
if err != nil || difyAPIKey == "" {
|
|
response.NotFound(w, "应用不存在")
|
|
return
|
|
}
|
|
|
|
feedbackReq := &dify.FeedbackRequest{
|
|
Rating: req.Rating,
|
|
User: userID.String(),
|
|
}
|
|
if err := h.dify.SubmitFeedback(r.Context(), difyAPIKey, req.MessageID, feedbackReq); err != nil {
|
|
response.Error(w, http.StatusBadGateway, 50201, "提交反馈失败: "+err.Error())
|
|
return
|
|
}
|
|
|
|
response.JSON(w, http.StatusOK, map[string]string{"message": "反馈已提交"})
|
|
}
|
|
|
|
func (h *ChatHandler) Conversations(w http.ResponseWriter, r *http.Request) {
|
|
appID := chi.URLParam(r, "id")
|
|
userID := middleware.GetUserID(r.Context())
|
|
|
|
var difyAPIKey string
|
|
err := h.pool.QueryRow(r.Context(),
|
|
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
|
).Scan(&difyAPIKey)
|
|
if err != nil || difyAPIKey == "" {
|
|
response.NotFound(w, "应用不存在")
|
|
return
|
|
}
|
|
|
|
result, err := h.dify.ListConversations(r.Context(), difyAPIKey, userID.String(), 20, "")
|
|
if err != nil {
|
|
response.Error(w, http.StatusBadGateway, 50201, "获取对话列表失败")
|
|
return
|
|
}
|
|
|
|
response.JSON(w, http.StatusOK, result)
|
|
}
|
|
|
|
func (h *ChatHandler) Messages(w http.ResponseWriter, r *http.Request) {
|
|
appID := chi.URLParam(r, "id")
|
|
convID := chi.URLParam(r, "convId")
|
|
userID := middleware.GetUserID(r.Context())
|
|
|
|
var difyAPIKey string
|
|
err := h.pool.QueryRow(r.Context(),
|
|
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
|
).Scan(&difyAPIKey)
|
|
if err != nil || difyAPIKey == "" {
|
|
response.NotFound(w, "应用不存在")
|
|
return
|
|
}
|
|
|
|
result, err := h.dify.ListMessages(r.Context(), difyAPIKey, userID.String(), convID, 100, "")
|
|
if err != nil {
|
|
response.Error(w, http.StatusBadGateway, 50201, "获取消息列表失败")
|
|
return
|
|
}
|
|
|
|
response.JSON(w, http.StatusOK, result)
|
|
}
|
|
|
|
func (h *ChatHandler) DeleteConversation(w http.ResponseWriter, r *http.Request) {
|
|
appID := chi.URLParam(r, "id")
|
|
convID := chi.URLParam(r, "convId")
|
|
userID := middleware.GetUserID(r.Context())
|
|
|
|
var difyAPIKey string
|
|
err := h.pool.QueryRow(r.Context(),
|
|
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
|
).Scan(&difyAPIKey)
|
|
if err != nil || difyAPIKey == "" {
|
|
response.NotFound(w, "应用不存在")
|
|
return
|
|
}
|
|
|
|
if err := h.dify.DeleteConversation(r.Context(), difyAPIKey, userID.String(), convID); err != nil {
|
|
response.Error(w, http.StatusBadGateway, 50201, "删除对话失败")
|
|
return
|
|
}
|
|
|
|
response.JSON(w, http.StatusOK, nil)
|
|
}
|
|
|
|
// The io package is imported but otherwise unused in this file. This blank
// declaration references it so the build does not fail with "imported and not
// used". Remove both this line and the import once io is no longer needed.
var _ io.Reader
|