Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,307 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/enterprise-ai-platform/server/pkg/dify"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type ChatHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
dify *dify.Client
|
||||
}
|
||||
|
||||
func NewChatHandler(pool *pgxpool.Pool, difyClient *dify.Client) *ChatHandler {
|
||||
return &ChatHandler{pool: pool, dify: difyClient}
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Message string `json:"message"`
|
||||
ConversationID string `json:"conversation_id,omitempty"`
|
||||
Inputs map[string]any `json:"inputs,omitempty"`
|
||||
}
|
||||
|
||||
func (h *ChatHandler) Chat(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req chatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.Message == "" {
|
||||
response.BadRequest(w, "消息不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// Get app's Dify API key
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1 AND status = 'approved'`,
|
||||
appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在或未上架")
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Call Dify streaming chat
|
||||
difyReq := &dify.ChatRequest{
|
||||
Query: req.Message,
|
||||
Inputs: req.Inputs,
|
||||
ConversationID: req.ConversationID,
|
||||
User: userID.String(),
|
||||
ResponseMode: "streaming",
|
||||
}
|
||||
|
||||
body, err := h.dify.ChatStream(r.Context(), difyAPIKey, difyReq)
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "Dify 服务不可用: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
// Stream SSE to client
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
response.InternalError(w, "Streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
||||
|
||||
var totalTokens int
|
||||
var modelName string
|
||||
var conversationID string
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
// Forward SSE event to client
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
|
||||
// Parse for usage tracking
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err == nil {
|
||||
if cid, ok := event["conversation_id"].(string); ok && cid != "" {
|
||||
conversationID = cid
|
||||
}
|
||||
if event["event"] == "message_end" {
|
||||
if meta, ok := event["metadata"].(map[string]any); ok {
|
||||
if usage, ok := meta["usage"].(map[string]any); ok {
|
||||
if t, ok := usage["total_tokens"].(float64); ok {
|
||||
totalTokens = int(t)
|
||||
}
|
||||
if m, ok := usage["model"].(string); ok {
|
||||
modelName = m
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Record usage asynchronously
|
||||
duration := time.Since(startTime).Milliseconds()
|
||||
go func() {
|
||||
_, _ = h.pool.Exec(context.Background(), `
|
||||
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, total_tokens, model_name, duration_ms, client_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'web')`,
|
||||
appID, userID, conversationID, totalTokens, modelName, duration)
|
||||
|
||||
_, _ = h.pool.Exec(context.Background(),
|
||||
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *ChatHandler) Completion(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req chatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.Message == "" {
|
||||
response.BadRequest(w, "消息不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1 AND status = 'approved'`,
|
||||
appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在或未上架")
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
difyReq := &dify.ChatRequest{
|
||||
Query: req.Message,
|
||||
Inputs: req.Inputs,
|
||||
ConversationID: req.ConversationID,
|
||||
User: userID.String(),
|
||||
ResponseMode: "blocking",
|
||||
}
|
||||
|
||||
result, err := h.dify.ChatBlocking(r.Context(), difyAPIKey, difyReq)
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "Dify 服务不可用: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(startTime).Milliseconds()
|
||||
go func() {
|
||||
_, _ = h.pool.Exec(context.Background(), `
|
||||
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, total_tokens, duration_ms, client_type)
|
||||
VALUES ($1, $2, $3, $4, $5, 'web')`,
|
||||
appID, userID, result.ConversationID, 0, duration)
|
||||
_, _ = h.pool.Exec(context.Background(),
|
||||
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
|
||||
}()
|
||||
|
||||
response.JSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *ChatHandler) Feedback(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req struct {
|
||||
MessageID string `json:"message_id"`
|
||||
Rating string `json:"rating"` // "like" | "dislike" | null
|
||||
Comment string `json:"comment"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.MessageID == "" {
|
||||
response.BadRequest(w, "message_id 不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
|
||||
feedbackReq := &dify.FeedbackRequest{
|
||||
Rating: req.Rating,
|
||||
User: userID.String(),
|
||||
}
|
||||
if err := h.dify.SubmitFeedback(r.Context(), difyAPIKey, req.MessageID, feedbackReq); err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "提交反馈失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "反馈已提交"})
|
||||
}
|
||||
|
||||
func (h *ChatHandler) Conversations(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.dify.ListConversations(r.Context(), difyAPIKey, userID.String(), 20, "")
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "获取对话列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *ChatHandler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
convID := chi.URLParam(r, "convId")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.dify.ListMessages(r.Context(), difyAPIKey, userID.String(), convID, 100, "")
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "获取消息列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *ChatHandler) DeleteConversation(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
convID := chi.URLParam(r, "convId")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.dify.DeleteConversation(r.Context(), difyAPIKey, userID.String(), convID); err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "删除对话失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
// Suppress unused import warnings
|
||||
var _ = io.EOF
|
||||
Reference in New Issue
Block a user