Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,614 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type AdminHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewAdminHandler(pool *pgxpool.Pool) *AdminHandler {
|
||||
return &AdminHandler{pool: pool}
|
||||
}
|
||||
|
||||
// getUserOrgID 通过当前登录用户ID查询其所属机构ID,用于多租户数据隔离
|
||||
func (h *AdminHandler) getUserOrgID(r *http.Request) (string, error) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
var orgID string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT COALESCE(org_id::text, '') FROM users WHERE id = $1`, userID).Scan(&orgID)
|
||||
return orgID, err
|
||||
}
|
||||
|
||||
// --- Overview Stats ---
|
||||
|
||||
func (h *AdminHandler) Overview(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, err := h.getUserOrgID(r)
|
||||
if err != nil || orgID == "" {
|
||||
response.InternalError(w, "无法确定当前机构")
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||
|
||||
var totalUsers, totalApps, activeUsers, todayConversations int
|
||||
var monthlyTokens int64
|
||||
var monthlyCost float64
|
||||
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users WHERE status = 'active' AND org_id = $1`, orgID).Scan(&totalUsers)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM applications WHERE status = 'approved' AND org_id = $1`, orgID).Scan(&totalApps)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(DISTINCT l.user_id) FROM app_usage_logs l JOIN users u ON l.user_id = u.id WHERE l.created_at >= $1 AND u.org_id = $2`, todayStart, orgID).Scan(&activeUsers)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM app_usage_logs l JOIN users u ON l.user_id = u.id WHERE l.created_at >= $1 AND u.org_id = $2`, todayStart, orgID).Scan(&todayConversations)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COALESCE(SUM(l.total_tokens), 0) FROM app_usage_logs l JOIN users u ON l.user_id = u.id WHERE l.created_at >= $1 AND u.org_id = $2`, monthStart, orgID).Scan(&monthlyTokens)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COALESCE(SUM(l.estimated_cost), 0) FROM app_usage_logs l JOIN users u ON l.user_id = u.id WHERE l.created_at >= $1 AND u.org_id = $2`, monthStart, orgID).Scan(&monthlyCost)
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"total_users": totalUsers,
|
||||
"total_apps": totalApps,
|
||||
"active_users": activeUsers,
|
||||
"total_conversations": todayConversations,
|
||||
"monthly_tokens": monthlyTokens,
|
||||
"monthly_cost": monthlyCost,
|
||||
})
|
||||
}
|
||||
|
||||
// --- User Management ---
|
||||
|
||||
func (h *AdminHandler) ListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, err := h.getUserOrgID(r)
|
||||
if err != nil || orgID == "" {
|
||||
response.InternalError(w, "无法确定当前机构")
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
page, _ := strconv.Atoi(q.Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize := 20
|
||||
offset := (page - 1) * pageSize
|
||||
search := q.Get("q")
|
||||
roleFilter := q.Get("role")
|
||||
statusFilter := q.Get("status")
|
||||
|
||||
query := `SELECT id, name, email, avatar_url, role, status, employee_id, last_login_at, login_count, created_at
|
||||
FROM users WHERE org_id = $1`
|
||||
args := []any{orgID}
|
||||
argIdx := 2
|
||||
|
||||
if search != "" {
|
||||
query += ` AND (name ILIKE '%' || $` + strconv.Itoa(argIdx) + ` || '%' OR email ILIKE '%' || $` + strconv.Itoa(argIdx) + ` || '%')`
|
||||
args = append(args, search)
|
||||
argIdx++
|
||||
}
|
||||
if roleFilter != "" {
|
||||
query += ` AND role = $` + strconv.Itoa(argIdx)
|
||||
args = append(args, roleFilter)
|
||||
argIdx++
|
||||
}
|
||||
if statusFilter != "" {
|
||||
query += ` AND status = $` + strconv.Itoa(argIdx)
|
||||
args = append(args, statusFilter)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
query += ` ORDER BY created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), query, args...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询用户失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []map[string]any
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, name, email, role, status string
|
||||
avatarURL, employeeID *string
|
||||
lastLoginAt *time.Time
|
||||
loginCount int
|
||||
createdAt time.Time
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &email, &avatarURL, &role, &status, &employeeID, &lastLoginAt, &loginCount, &createdAt); err != nil {
|
||||
continue
|
||||
}
|
||||
users = append(users, map[string]any{
|
||||
"id": id, "name": name, "email": email,
|
||||
"avatar_url": avatarURL, "role": role, "status": status,
|
||||
"employee_id": employeeID, "last_login_at": lastLoginAt,
|
||||
"login_count": loginCount, "created_at": createdAt,
|
||||
})
|
||||
}
|
||||
if users == nil {
|
||||
users = []map[string]any{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]any{"items": users, "page": page})
|
||||
}
|
||||
|
||||
func (h *AdminHandler) UpdateUserRole(w http.ResponseWriter, r *http.Request) {
|
||||
userID := chi.URLParam(r, "id")
|
||||
operatorRole := middleware.GetRole(r.Context())
|
||||
|
||||
var req struct {
|
||||
Role string `json:"role"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
|
||||
validRoles := map[string]bool{"user": true, "creator": true, "admin": true, "super_admin": true}
|
||||
if !validRoles[req.Role] {
|
||||
response.BadRequest(w, "无效的角色")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Role == "super_admin" && operatorRole != "super_admin" {
|
||||
response.Forbidden(w, "只有超级管理员才能设置超级管理员角色")
|
||||
return
|
||||
}
|
||||
|
||||
// 确保只能修改本机构用户
|
||||
orgID, orgErr := h.getUserOrgID(r)
|
||||
if orgErr != nil || orgID == "" {
|
||||
response.InternalError(w, "无法确定当前机构")
|
||||
return
|
||||
}
|
||||
|
||||
_, err := h.pool.Exec(r.Context(), `UPDATE users SET role = $2 WHERE id = $1 AND org_id = $3`, userID, req.Role, orgID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "更新角色失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) UpdateUserStatus(w http.ResponseWriter, r *http.Request) {
|
||||
userID := chi.URLParam(r, "id")
|
||||
|
||||
var req struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Status != "active" && req.Status != "disabled" {
|
||||
response.BadRequest(w, "无效的状态")
|
||||
return
|
||||
}
|
||||
|
||||
// 确保只能修改本机构用户
|
||||
orgID, orgErr := h.getUserOrgID(r)
|
||||
if orgErr != nil || orgID == "" {
|
||||
response.InternalError(w, "无法确定当前机构")
|
||||
return
|
||||
}
|
||||
|
||||
_, err := h.pool.Exec(r.Context(), `UPDATE users SET status = $2 WHERE id = $1 AND org_id = $3`, userID, req.Status, orgID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "更新状态失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
// --- App Management ---
|
||||
|
||||
func (h *AdminHandler) ListAllApps(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, err := h.getUserOrgID(r)
|
||||
if err != nil || orgID == "" {
|
||||
response.InternalError(w, "无法确定当前机构")
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
page, _ := strconv.Atoi(q.Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
offset := (page - 1) * 20
|
||||
statusFilter := q.Get("status")
|
||||
|
||||
query := `SELECT a.id, a.name, a.slug, a.description, a.icon_url,
|
||||
c.name as category_name, u.name as creator_name,
|
||||
a.dify_app_type, a.status, a.visibility, a.usage_count, a.created_at
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
LEFT JOIN users u ON a.creator_id = u.id
|
||||
WHERE a.org_id = $1`
|
||||
args := []any{orgID}
|
||||
argIdx := 2
|
||||
|
||||
if statusFilter != "" {
|
||||
query += ` AND a.status = $` + strconv.Itoa(argIdx)
|
||||
args = append(args, statusFilter)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
query += ` ORDER BY a.created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
|
||||
args = append(args, 20, offset)
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), query, args...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询应用失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var apps []map[string]any
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, name, slug, status, visibility string
|
||||
desc, iconURL, catName, creator *string
|
||||
appType *string
|
||||
usageCount int64
|
||||
createdAt time.Time
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &slug, &desc, &iconURL, &catName, &creator,
|
||||
&appType, &status, &visibility, &usageCount, &createdAt); err != nil {
|
||||
continue
|
||||
}
|
||||
apps = append(apps, map[string]any{
|
||||
"id": id, "name": name, "slug": slug, "description": desc,
|
||||
"icon_url": iconURL, "category_name": catName, "creator_name": creator,
|
||||
"dify_app_type": appType,
|
||||
"status": status, "visibility": visibility, "usage_count": usageCount,
|
||||
"created_at": createdAt,
|
||||
})
|
||||
}
|
||||
if apps == nil {
|
||||
apps = []map[string]any{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]any{"items": apps, "page": page})
|
||||
}
|
||||
|
||||
// --- Audit Logs ---
|
||||
|
||||
func (h *AdminHandler) ListAuditLogs(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, orgErr := h.getUserOrgID(r)
|
||||
if orgErr != nil || orgID == "" {
|
||||
response.InternalError(w, "无法确定当前机构")
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
page, _ := strconv.Atoi(q.Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
offset := (page - 1) * 50
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT al.id, al.action, al.resource_type, al.resource_id,
|
||||
al.details, al.ip_address, al.created_at,
|
||||
u.name as user_name
|
||||
FROM audit_logs al
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
WHERE u.org_id = $2
|
||||
ORDER BY al.created_at DESC
|
||||
LIMIT 50 OFFSET $1`, offset, orgID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询审计日志失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var logs []map[string]any
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, action, resType string
|
||||
resID *string
|
||||
details json.RawMessage
|
||||
ipAddr *string
|
||||
createdAt time.Time
|
||||
userName *string
|
||||
)
|
||||
if err := rows.Scan(&id, &action, &resType, &resID, &details, &ipAddr, &createdAt, &userName); err != nil {
|
||||
continue
|
||||
}
|
||||
logs = append(logs, map[string]any{
|
||||
"id": id, "action": action, "resource_type": resType,
|
||||
"resource_id": resID, "details": details,
|
||||
"ip_address": ipAddr, "created_at": createdAt, "user_name": userName,
|
||||
})
|
||||
}
|
||||
if logs == nil {
|
||||
logs = []map[string]any{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]any{"items": logs, "page": page})
|
||||
}
|
||||
|
||||
// --- Usage Analytics ---
|
||||
|
||||
func (h *AdminHandler) UsageAnalytics(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, orgErr := h.getUserOrgID(r)
|
||||
if orgErr != nil || orgID == "" {
|
||||
response.InternalError(w, "无法确定当前机构")
|
||||
return
|
||||
}
|
||||
|
||||
days := 7
|
||||
if d, err := strconv.Atoi(r.URL.Query().Get("days")); err == nil && d > 0 && d <= 90 {
|
||||
days = d
|
||||
}
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT DATE(l.created_at) as day, COUNT(*) as count, COALESCE(SUM(l.total_tokens), 0) as tokens
|
||||
FROM app_usage_logs l
|
||||
JOIN users u ON l.user_id = u.id
|
||||
WHERE l.created_at >= NOW() - $1::interval AND u.org_id = $2
|
||||
GROUP BY DATE(l.created_at)
|
||||
ORDER BY day`, strconv.Itoa(days)+" days", orgID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询使用统计失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var dailyStats []map[string]any
|
||||
for rows.Next() {
|
||||
var (
|
||||
day time.Time
|
||||
count int
|
||||
tokens int64
|
||||
)
|
||||
if err := rows.Scan(&day, &count, &tokens); err != nil {
|
||||
continue
|
||||
}
|
||||
dailyStats = append(dailyStats, map[string]any{
|
||||
"date": day.Format("2006-01-02"),
|
||||
"count": count,
|
||||
"total_tokens": tokens,
|
||||
})
|
||||
}
|
||||
if dailyStats == nil {
|
||||
dailyStats = []map[string]any{}
|
||||
}
|
||||
|
||||
// Top apps
|
||||
topRows, err := h.pool.Query(r.Context(), `
|
||||
SELECT a.name, COUNT(l.id) as usage_count
|
||||
FROM app_usage_logs l
|
||||
JOIN applications a ON l.app_id = a.id
|
||||
WHERE l.created_at >= NOW() - $1::interval AND a.org_id = $2
|
||||
GROUP BY a.name
|
||||
ORDER BY usage_count DESC LIMIT 10`, strconv.Itoa(days)+" days", orgID)
|
||||
if err != nil {
|
||||
response.JSON(w, http.StatusOK, map[string]any{"daily": dailyStats})
|
||||
return
|
||||
}
|
||||
defer topRows.Close()
|
||||
|
||||
var topApps []map[string]any
|
||||
for topRows.Next() {
|
||||
var name string
|
||||
var count int
|
||||
if err := topRows.Scan(&name, &count); err != nil {
|
||||
continue
|
||||
}
|
||||
topApps = append(topApps, map[string]any{"name": name, "count": count})
|
||||
}
|
||||
if topApps == nil {
|
||||
topApps = []map[string]any{}
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"daily": dailyStats,
|
||||
"top_apps": topApps,
|
||||
})
|
||||
}
|
||||
|
||||
// --- Review Management ---
|
||||
|
||||
func (h *AdminHandler) ListPendingReviews(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, orgErr := h.getUserOrgID(r)
|
||||
if orgErr != nil || orgID == "" {
|
||||
response.InternalError(w, "无法确定当前机构")
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT r.id, r.app_id, r.version, r.submit_comment, r.submitted_at,
|
||||
a.name as app_name, a.description as app_description, a.icon_url,
|
||||
u.name as submitter_name
|
||||
FROM app_reviews r
|
||||
JOIN applications a ON r.app_id = a.id
|
||||
JOIN users u ON r.submitter_id = u.id
|
||||
WHERE r.status = 'pending' AND a.org_id = $1
|
||||
ORDER BY r.submitted_at ASC LIMIT 50`, orgID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询审核列表失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var reviews []map[string]any
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, appID, version string
|
||||
comment *string
|
||||
submittedAt time.Time
|
||||
appName, appDesc *string
|
||||
appIcon *string
|
||||
submitterName string
|
||||
)
|
||||
if err := rows.Scan(&id, &appID, &version, &comment, &submittedAt,
|
||||
&appName, &appDesc, &appIcon, &submitterName); err != nil {
|
||||
continue
|
||||
}
|
||||
reviews = append(reviews, map[string]any{
|
||||
"id": id, "app_id": appID, "version": version,
|
||||
"submit_comment": comment, "submitted_at": submittedAt,
|
||||
"app_name": appName, "app_description": appDesc, "app_icon": appIcon,
|
||||
"submitter_name": submitterName,
|
||||
})
|
||||
}
|
||||
if reviews == nil {
|
||||
reviews = []map[string]any{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, reviews)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) ApproveReview(w http.ResponseWriter, r *http.Request) {
|
||||
reviewID := chi.URLParam(r, "id")
|
||||
reviewerID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req struct {
|
||||
Comment string `json:"comment"`
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Get app_id from review
|
||||
var appID string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT app_id FROM app_reviews WHERE id = $1 AND status = 'pending'`, reviewID).Scan(&appID)
|
||||
if err != nil {
|
||||
response.NotFound(w, "审核记录不存在或已处理")
|
||||
return
|
||||
}
|
||||
|
||||
tx, err := h.pool.Begin(r.Context())
|
||||
if err != nil {
|
||||
response.InternalError(w, "事务开始失败")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(r.Context())
|
||||
|
||||
tx.Exec(r.Context(), `
|
||||
UPDATE app_reviews SET status = 'approved', reviewer_id = $2, review_comment = $3, reviewed_at = NOW()
|
||||
WHERE id = $1`, reviewID, reviewerID, req.Comment)
|
||||
|
||||
tx.Exec(r.Context(), `
|
||||
UPDATE applications SET status = 'approved', published_at = NOW()
|
||||
WHERE id = $1`, appID)
|
||||
|
||||
if err := tx.Commit(r.Context()); err != nil {
|
||||
response.InternalError(w, "审核通过失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) RejectReview(w http.ResponseWriter, r *http.Request) {
|
||||
reviewID := chi.URLParam(r, "id")
|
||||
reviewerID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req struct {
|
||||
Comment string `json:"comment"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Comment == "" {
|
||||
response.BadRequest(w, "驳回必须填写原因")
|
||||
return
|
||||
}
|
||||
|
||||
var appID string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT app_id FROM app_reviews WHERE id = $1 AND status = 'pending'`, reviewID).Scan(&appID)
|
||||
if err != nil {
|
||||
response.NotFound(w, "审核记录不存在或已处理")
|
||||
return
|
||||
}
|
||||
|
||||
tx, err := h.pool.Begin(r.Context())
|
||||
if err != nil {
|
||||
response.InternalError(w, "事务开始失败")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(r.Context())
|
||||
|
||||
tx.Exec(r.Context(), `
|
||||
UPDATE app_reviews SET status = 'rejected', reviewer_id = $2, review_comment = $3, reviewed_at = NOW()
|
||||
WHERE id = $1`, reviewID, reviewerID, req.Comment)
|
||||
|
||||
tx.Exec(r.Context(), `
|
||||
UPDATE applications SET status = 'rejected'
|
||||
WHERE id = $1`, appID)
|
||||
|
||||
if err := tx.Commit(r.Context()); err != nil {
|
||||
response.InternalError(w, "驳回失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (h *AdminHandler) DelistApp(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, orgErr := h.getUserOrgID(r)
|
||||
if orgErr != nil || orgID == "" {
|
||||
response.InternalError(w, "无法确定当前机构")
|
||||
return
|
||||
}
|
||||
|
||||
appID := chi.URLParam(r, "id")
|
||||
|
||||
var status string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT status FROM applications WHERE id = $1 AND org_id = $2`, appID, orgID).Scan(&status)
|
||||
if err != nil {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
if status != "approved" {
|
||||
response.BadRequest(w, "只有已上架的应用可以撤架")
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.pool.Exec(r.Context(),
|
||||
`UPDATE applications SET status = 'archived' WHERE id = $1`, appID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "撤架失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已撤架"})
|
||||
}
|
||||
|
||||
func (h *AdminHandler) RelistApp(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, orgErr := h.getUserOrgID(r)
|
||||
if orgErr != nil || orgID == "" {
|
||||
response.InternalError(w, "无法确定当前机构")
|
||||
return
|
||||
}
|
||||
|
||||
appID := chi.URLParam(r, "id")
|
||||
|
||||
var status string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT status FROM applications WHERE id = $1 AND org_id = $2`, appID, orgID).Scan(&status)
|
||||
if err != nil {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
if status != "archived" {
|
||||
response.BadRequest(w, "只有已归档的应用可以重新上架")
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.pool.Exec(r.Context(),
|
||||
`UPDATE applications SET status = 'approved', published_at = NOW() WHERE id = $1`, appID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "重新上架失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已重新上架"})
|
||||
}
|
||||
@@ -0,0 +1,305 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/enterprise-ai-platform/server/pkg/llm"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type AnalysisTemplateHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
manager *llm.Manager
|
||||
provider string
|
||||
}
|
||||
|
||||
func NewAnalysisTemplateHandler(pool *pgxpool.Pool, manager *llm.Manager, provider string) *AnalysisTemplateHandler {
|
||||
return &AnalysisTemplateHandler{pool: pool, manager: manager, provider: provider}
|
||||
}
|
||||
|
||||
func (h *AnalysisTemplateHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := r.URL.Query().Get("org_id")
|
||||
query := `SELECT id, name, report_type, COALESCE(description,''), COALESCE(icon,''), steps, sort_order
|
||||
FROM analysis_report_templates
|
||||
WHERE is_active = true`
|
||||
var args []any
|
||||
if orgID != "" {
|
||||
query += ` AND (org_id = $1 OR org_id IS NULL)`
|
||||
args = append(args, orgID)
|
||||
}
|
||||
query += ` ORDER BY sort_order ASC`
|
||||
rows, err := h.pool.Query(r.Context(), query, args...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询模板失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var templates []map[string]any
|
||||
for rows.Next() {
|
||||
var id, name, reportType, desc, icon string
|
||||
var steps json.RawMessage
|
||||
var sortOrder int
|
||||
if err := rows.Scan(&id, &name, &reportType, &desc, &icon, &steps, &sortOrder); err != nil {
|
||||
continue
|
||||
}
|
||||
templates = append(templates, map[string]any{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"report_type": reportType,
|
||||
"description": desc,
|
||||
"icon": icon,
|
||||
"steps": steps,
|
||||
"sort_order": sortOrder,
|
||||
})
|
||||
}
|
||||
if templates == nil {
|
||||
templates = []map[string]any{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]any{"data": templates})
|
||||
}
|
||||
|
||||
func (h *AnalysisTemplateHandler) GetTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "templateId")
|
||||
|
||||
var name, reportType, desc, icon, promptTpl string
|
||||
var steps, outputSections json.RawMessage
|
||||
|
||||
err := h.pool.QueryRow(r.Context(), `
|
||||
SELECT name, report_type, COALESCE(description,''), COALESCE(icon,''),
|
||||
steps, prompt_template, COALESCE(output_sections, '[]')
|
||||
FROM analysis_report_templates WHERE id = $1 AND is_active = true`, id,
|
||||
).Scan(&name, &reportType, &desc, &icon, &steps, &promptTpl, &outputSections)
|
||||
if err != nil {
|
||||
response.NotFound(w, "模板不存在")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"report_type": reportType,
|
||||
"description": desc,
|
||||
"icon": icon,
|
||||
"steps": steps,
|
||||
"prompt_template": promptTpl,
|
||||
"output_sections": outputSections,
|
||||
})
|
||||
}
|
||||
|
||||
type generateAnalysisRequest struct {
|
||||
TemplateID string `json:"template_id"`
|
||||
FieldData map[string]string `json:"field_data"`
|
||||
}
|
||||
|
||||
func containsAny(list []string, targets ...string) bool {
|
||||
for _, item := range list {
|
||||
for _, t := range targets {
|
||||
if item == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *AnalysisTemplateHandler) queryEconomicData(ctx context.Context, districts, timeRange string) string {
|
||||
years := []int{2025}
|
||||
switch timeRange {
|
||||
case "2024-2025年两年对比", "2024_2025":
|
||||
years = []int{2024, 2025}
|
||||
case "2023-2025年三年趋势", "2023_2025":
|
||||
years = []int{2023, 2024, 2025}
|
||||
case "2024年全年", "2024_full":
|
||||
years = []int{2024}
|
||||
}
|
||||
|
||||
districtList := strings.Split(districts, ",")
|
||||
for i := range districtList {
|
||||
districtList[i] = strings.TrimSpace(districtList[i])
|
||||
}
|
||||
|
||||
rows, err := h.pool.Query(ctx, `
|
||||
SELECT district_name, year,
|
||||
COALESCE(gdp,0), COALESCE(gdp_growth,0), COALESCE(gdp_per_capita,0),
|
||||
COALESCE(fiscal_revenue,0), COALESCE(fiscal_revenue_growth,0),
|
||||
COALESCE(fixed_investment,0), COALESCE(fixed_investment_growth,0),
|
||||
COALESCE(retail_sales,0), COALESCE(retail_sales_growth,0),
|
||||
COALESCE(industrial_output,0), COALESCE(industrial_output_growth,0),
|
||||
COALESCE(tertiary_ratio,0),
|
||||
COALESCE(import_export,0), COALESCE(import_export_growth,0),
|
||||
COALESCE(actual_fdi,0),
|
||||
COALESCE(tech_expenditure,0), COALESCE(tech_expenditure_ratio,0),
|
||||
COALESCE(population,0), COALESCE(urban_income,0), COALESCE(rural_income,0)
|
||||
FROM regional_economic_data
|
||||
WHERE year = ANY($1)
|
||||
AND (cardinality($2::text[]) = 0 OR district_name = ANY($2))
|
||||
ORDER BY year ASC, gdp DESC`, years, districtList)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("\n\n## 【数据库真实数据】\n\n")
|
||||
sb.WriteString("| 区县 | 年份 | GDP(亿元) | GDP增速(%) | 人均GDP(元) | 财政收入(亿元) | 财政增速(%) | 固投(亿元) | 固投增速(%) | 社零(亿元) | 社零增速(%) | 规上工业(亿元) | 工业增速(%) | 三产占比(%) | 进出口(亿元) | 进出口增速(%) | 利用外资(亿美元) | R&D投入(亿元) | R&D/GDP(%) | 人口(万) | 城镇收入(元) | 农村收入(元) |\n")
|
||||
sb.WriteString("|------|------|-----------|-----------|------------|-------------|-----------|----------|-----------|----------|-----------|-------------|-----------|-----------|-----------|------------|-------------|------------|----------|----------|-----------|----------|\n")
|
||||
|
||||
for rows.Next() {
|
||||
var name string
|
||||
var year int
|
||||
var gdp, gdpG, gdpPC, fr, frG, fi, fiG, rs, rsG, io, ioG, tr, ie, ieG, fdi, te, teR, pop, ui, ri float64
|
||||
if err := rows.Scan(&name, &year, &gdp, &gdpG, &gdpPC, &fr, &frG, &fi, &fiG, &rs, &rsG, &io, &ioG, &tr, &ie, &ieG, &fdi, &te, &teR, &pop, &ui, &ri); err != nil {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("| %s | %d | %.1f | %.1f | %.0f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.2f | %.1f | %.2f | %.1f | %.0f | %.0f |\n",
|
||||
name, year, gdp, gdpG, gdpPC, fr, frG, fi, fiG, rs, rsG, io, ioG, tr, ie, ieG, fdi, te, teR, pop, ui, ri))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (h *AnalysisTemplateHandler) GenerateReport(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req generateAnalysisRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效请求格式")
|
||||
return
|
||||
}
|
||||
if req.TemplateID == "" {
|
||||
response.BadRequest(w, "请选择报告模板")
|
||||
return
|
||||
}
|
||||
|
||||
var appModel string
|
||||
var appMaxTokens int
|
||||
var appConfigRaw json.RawMessage
|
||||
_ = h.pool.QueryRow(r.Context(), `
|
||||
SELECT COALESCE(app_config->>'model', ''), COALESCE(max_tokens, 8192), COALESCE(app_config, '{}')
|
||||
FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&appModel, &appMaxTokens, &appConfigRaw)
|
||||
|
||||
var appCfg struct {
|
||||
Tools []string `json:"tools"`
|
||||
DataSources []string `json:"data_sources"`
|
||||
TemplateSet string `json:"template_set"`
|
||||
}
|
||||
_ = json.Unmarshal(appConfigRaw, &appCfg)
|
||||
|
||||
var promptTpl, tplName, reportType string
|
||||
err := h.pool.QueryRow(r.Context(), `
|
||||
SELECT name, prompt_template, report_type
|
||||
FROM analysis_report_templates WHERE id = $1 AND is_active = true`, req.TemplateID,
|
||||
).Scan(&tplName, &promptTpl, &reportType)
|
||||
if err != nil {
|
||||
response.NotFound(w, "模板不存在")
|
||||
return
|
||||
}
|
||||
|
||||
prompt := promptTpl
|
||||
for k, v := range req.FieldData {
|
||||
prompt = strings.ReplaceAll(prompt, "{{"+k+"}}", v)
|
||||
}
|
||||
prompt = strings.ReplaceAll(prompt, "{{", "")
|
||||
prompt = strings.ReplaceAll(prompt, "}}", "")
|
||||
|
||||
hasDataTool := containsAny(appCfg.Tools, "economic_data_query", "data_analysis")
|
||||
hasDataSource := containsAny(appCfg.DataSources, "regional_economic_data", "statistical_yearbook")
|
||||
useEconomicData := hasDataTool || hasDataSource || reportType == "economic_comparison" || reportType == "trend_analysis"
|
||||
|
||||
if useEconomicData {
|
||||
dataTable := h.queryEconomicData(r.Context(), req.FieldData["districts"], req.FieldData["time_range"])
|
||||
if dataTable != "" {
|
||||
prompt += "\n\n⚠️ 以下是从数据库中检索到的真实经济数据,请基于这些真实数据进行分析,不要编造数据:" + dataTable
|
||||
}
|
||||
}
|
||||
|
||||
llmReq := &llm.ChatRequest{
|
||||
Model: appModel,
|
||||
Messages: []llm.Message{
|
||||
{Role: llm.RoleSystem, Content: "你是一位资深的政务数据分析和研判专家。要求:1)基于提供的真实数据库数据进行分析 2)报告必须完整输出到最后一个章节,不能中途截断 3)保持精炼,每个章节500字以内 4)使用Markdown格式排版 5)表格数据必须引用真实数据 6)报告输出完毕后,必须在最后一行单独输出「---\\n\\n**【完稿】**」作为完成标记"},
|
||||
{Role: llm.RoleUser, Content: prompt},
|
||||
},
|
||||
Temperature: 0.4,
|
||||
MaxTokens: appMaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
body, err := h.manager.ChatStream(r.Context(), h.provider, llmReq)
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50202, "模型服务不可用: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
response.InternalError(w, "Streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
convID := uuid.New().String()
|
||||
msgID := uuid.New().String()
|
||||
var fullResponse strings.Builder
|
||||
|
||||
firstEvent := map[string]string{
|
||||
"conversation_id": convID,
|
||||
"message_id": msgID,
|
||||
"template_name": tplName,
|
||||
}
|
||||
data, _ := json.Marshal(firstEvent)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
|
||||
transform := llm.TransformOpenAIStream
|
||||
if h.provider == "anthropic" {
|
||||
transform = llm.TransformAnthropicStream
|
||||
}
|
||||
|
||||
var totalTokens int
|
||||
var modelName string
|
||||
|
||||
_ = transform(body, func(event llm.StreamEvent) {
|
||||
if event.Answer != "" {
|
||||
fullResponse.WriteString(event.Answer)
|
||||
}
|
||||
if event.Usage != nil {
|
||||
totalTokens = event.Usage.TotalTokens
|
||||
modelName = event.Usage.Model
|
||||
}
|
||||
data, _ := json.Marshal(event)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
})
|
||||
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
duration := time.Since(startTime).Milliseconds()
|
||||
userMsg := fmt.Sprintf("[研判报告] %s", tplName)
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
_, _ = h.pool.Exec(ctx, `
|
||||
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, user_message, ai_response, total_tokens, model_name, duration_ms, client_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'web')`,
|
||||
appID, userID, convID, userMsg, fullResponse.String(), totalTokens, modelName, duration)
|
||||
_, _ = h.pool.Exec(ctx,
|
||||
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,418 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/enterprise-ai-platform/server/pkg/auth"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type AuthHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
jwtMgr *auth.JWTManager
|
||||
}
|
||||
|
||||
func NewAuthHandler(pool *pgxpool.Pool, jwtMgr *auth.JWTManager) *AuthHandler {
|
||||
return &AuthHandler{pool: pool, jwtMgr: jwtMgr}
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
OrgID string `json:"org_id"`
|
||||
}
|
||||
|
||||
type orgInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
ShortName string `json:"short_name"`
|
||||
}
|
||||
|
||||
type userResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarURL *string `json:"avatar_url"`
|
||||
Role string `json:"role"`
|
||||
EmployeeID *string `json:"employee_id"`
|
||||
OrgID *string `json:"org_id"`
|
||||
Org *orgInfo `json:"org,omitempty"`
|
||||
}
|
||||
|
||||
type registerRequest struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||
var req registerRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.Name == "" || req.Email == "" || req.Password == "" {
|
||||
response.BadRequest(w, "姓名、邮箱和密码不能为空")
|
||||
return
|
||||
}
|
||||
if len(req.Password) < 6 {
|
||||
response.BadRequest(w, "密码长度不能少于6位")
|
||||
return
|
||||
}
|
||||
|
||||
var exists bool
|
||||
h.pool.QueryRow(r.Context(), `SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)`, req.Email).Scan(&exists)
|
||||
if exists {
|
||||
response.Error(w, http.StatusConflict, 40901, "该邮箱已注册")
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
response.InternalError(w, "密码加密失败")
|
||||
return
|
||||
}
|
||||
|
||||
id := uuid.New()
|
||||
_, err = h.pool.Exec(r.Context(),
|
||||
`INSERT INTO users (id, name, email, password_hash, role, status) VALUES ($1, $2, $3, $4, 'user', 'active')`,
|
||||
id, req.Name, req.Email, hash)
|
||||
if err != nil {
|
||||
response.InternalError(w, "注册失败")
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.jwtMgr.GenerateTokenPair(id, req.Email, "user")
|
||||
if err != nil {
|
||||
response.InternalError(w, "生成Token失败")
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "access_token",
|
||||
Value: tokenPair.AccessToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(24 * time.Hour / time.Second),
|
||||
})
|
||||
|
||||
response.JSON(w, http.StatusCreated, map[string]any{
|
||||
"user": userResponse{
|
||||
ID: id.String(), Name: req.Name, Email: req.Email, Role: "user",
|
||||
},
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"expires_at": tokenPair.ExpiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
var req loginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Email == "" || req.Password == "" {
|
||||
response.BadRequest(w, "邮箱和密码不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
id string
|
||||
name string
|
||||
email string
|
||||
passwordHash *string
|
||||
avatarURL *string
|
||||
role string
|
||||
employeeID *string
|
||||
status string
|
||||
orgID *string
|
||||
)
|
||||
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT id, name, email, password_hash, avatar_url, role, employee_id, status, org_id::text
|
||||
FROM users WHERE email = $1`, req.Email,
|
||||
).Scan(&id, &name, &email, &passwordHash, &avatarURL, &role, &employeeID, &status, &orgID)
|
||||
|
||||
if err != nil {
|
||||
response.Unauthorized(w, "邮箱或密码错误")
|
||||
return
|
||||
}
|
||||
|
||||
if status != "active" {
|
||||
response.Error(w, http.StatusForbidden, 40302, "账号已被禁用")
|
||||
return
|
||||
}
|
||||
|
||||
if passwordHash == nil || !auth.CheckPassword(req.Password, *passwordHash) {
|
||||
response.Unauthorized(w, "邮箱或密码错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 平台管理员不绑定机构,可登录任意机构入口
|
||||
// 普通用户/机构管理员必须属于所选机构
|
||||
if role != "super_admin" && req.OrgID != "" && orgID != nil && *orgID != req.OrgID {
|
||||
response.Unauthorized(w, "该账号不属于所选机构,请选择正确的机构")
|
||||
return
|
||||
}
|
||||
|
||||
uid, _ := uuid.Parse(id)
|
||||
tokenPair, err := h.jwtMgr.GenerateTokenPair(uid, email, role)
|
||||
if err != nil {
|
||||
response.InternalError(w, "生成Token失败")
|
||||
return
|
||||
}
|
||||
|
||||
// Update login info
|
||||
go func() {
|
||||
_, _ = h.pool.Exec(context.Background(),
|
||||
`UPDATE users SET last_login_at = NOW(), login_count = login_count + 1 WHERE id = $1`, id)
|
||||
}()
|
||||
|
||||
// Set cookies
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "access_token",
|
||||
Value: tokenPair.AccessToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(24 * time.Hour / time.Second),
|
||||
})
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: tokenPair.RefreshToken,
|
||||
Path: "/api/v1/auth/refresh",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(7 * 24 * time.Hour / time.Second),
|
||||
})
|
||||
|
||||
usr := userResponse{
|
||||
ID: id, Name: name, Email: email,
|
||||
AvatarURL: avatarURL, Role: role, EmployeeID: employeeID,
|
||||
OrgID: orgID,
|
||||
}
|
||||
if orgID != nil {
|
||||
usr.Org = h.loadOrgInfo(r.Context(), *orgID)
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"user": usr,
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"expires_at": tokenPair.ExpiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "access_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
MaxAge: -1,
|
||||
})
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: "",
|
||||
Path: "/api/v1/auth/refresh",
|
||||
HttpOnly: true,
|
||||
MaxAge: -1,
|
||||
})
|
||||
response.JSON(w, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var u userResponse
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT id, name, email, avatar_url, role, employee_id, org_id::text
|
||||
FROM users WHERE id = $1`, userID,
|
||||
).Scan(&u.ID, &u.Name, &u.Email, &u.AvatarURL, &u.Role, &u.EmployeeID, &u.OrgID)
|
||||
|
||||
if err != nil {
|
||||
response.NotFound(w, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if u.OrgID != nil {
|
||||
u.Org = h.loadOrgInfo(r.Context(), *u.OrgID)
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, u)
|
||||
}
|
||||
|
||||
// ListOrganizations 返回所有可用机构列表
|
||||
func (h *AuthHandler) ListOrganizations(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := h.pool.Query(r.Context(),
|
||||
`SELECT id, name, slug, COALESCE(short_name,''), COALESCE(description,''), COALESCE(logo_url,'')
|
||||
FROM organizations WHERE is_active = true ORDER BY sort_order ASC`)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询机构列表失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var orgs []map[string]any
|
||||
for rows.Next() {
|
||||
var id, name, slug, shortName, desc, logo string
|
||||
if rows.Scan(&id, &name, &slug, &shortName, &desc, &logo) != nil {
|
||||
continue
|
||||
}
|
||||
orgs = append(orgs, map[string]any{
|
||||
"id": id, "name": name, "slug": slug,
|
||||
"short_name": shortName, "description": desc, "logo_url": logo,
|
||||
})
|
||||
}
|
||||
if orgs == nil {
|
||||
orgs = []map[string]any{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, orgs)
|
||||
}
|
||||
|
||||
// SwitchOrg 切换机构:管理员及以上可用,自动切换为目标机构的管理员身份
|
||||
func (h *AuthHandler) SwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||
currentRole := middleware.GetRole(r.Context())
|
||||
if currentRole != "super_admin" && currentRole != "admin" {
|
||||
response.Forbidden(w, "仅管理员可以切换机构")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
OrgID string `json:"org_id"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.OrgID == "" {
|
||||
response.BadRequest(w, "请选择机构")
|
||||
return
|
||||
}
|
||||
|
||||
// 校验机构是否存在
|
||||
var exists bool
|
||||
h.pool.QueryRow(r.Context(),
|
||||
`SELECT EXISTS(SELECT 1 FROM organizations WHERE id = $1 AND is_active = true)`, req.OrgID).Scan(&exists)
|
||||
if !exists {
|
||||
response.NotFound(w, "机构不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 查找目标机构的管理员用户(优先admin,其次super_admin,最后creator)
|
||||
var targetID uuid.UUID
|
||||
var targetEmail, targetRole, targetName string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT id, email, role, name FROM users
|
||||
WHERE org_id = $1 AND status = 'active'
|
||||
ORDER BY CASE role WHEN 'admin' THEN 1 WHEN 'super_admin' THEN 2 WHEN 'creator' THEN 3 ELSE 4 END
|
||||
LIMIT 1`, req.OrgID).Scan(&targetID, &targetEmail, &targetRole, &targetName)
|
||||
if err != nil {
|
||||
response.InternalError(w, "该机构暂无可用用户")
|
||||
return
|
||||
}
|
||||
|
||||
// 为目标用户生成新的JWT token
|
||||
tokens, err := h.jwtMgr.GenerateTokenPair(targetID, targetEmail, targetRole)
|
||||
if err != nil {
|
||||
response.InternalError(w, "生成令牌失败")
|
||||
return
|
||||
}
|
||||
|
||||
org := h.loadOrgInfo(r.Context(), req.OrgID)
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"message": "已切换",
|
||||
"org": org,
|
||||
"token": tokens.AccessToken,
|
||||
"user": map[string]any{
|
||||
"id": targetID,
|
||||
"name": targetName,
|
||||
"email": targetEmail,
|
||||
"role": targetRole,
|
||||
"org_id": req.OrgID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) loadOrgInfo(ctx context.Context, orgID string) *orgInfo {
|
||||
var o orgInfo
|
||||
err := h.pool.QueryRow(ctx,
|
||||
`SELECT id, name, slug, COALESCE(short_name,'') FROM organizations WHERE id = $1`, orgID,
|
||||
).Scan(&o.ID, &o.Name, &o.Slug, &o.ShortName)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &o
|
||||
}
|
||||
|
||||
func (h *AuthHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
|
||||
_, err := h.pool.Exec(r.Context(),
|
||||
`UPDATE users SET name = COALESCE(NULLIF($2,''), name),
|
||||
avatar_url = COALESCE(NULLIF($3,''), avatar_url),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1`, userID, req.Name, req.AvatarURL)
|
||||
if err != nil {
|
||||
response.InternalError(w, "更新个人信息失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Refresh(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("refresh_token")
|
||||
if err != nil {
|
||||
response.Unauthorized(w, "Refresh Token 不存在")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := h.jwtMgr.ValidateToken(cookie.Value)
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusUnauthorized, 40102, "Refresh Token 已过期")
|
||||
return
|
||||
}
|
||||
|
||||
// Get current role from DB
|
||||
var role string
|
||||
err = h.pool.QueryRow(r.Context(),
|
||||
`SELECT role FROM users WHERE id = $1 AND status = 'active'`, claims.UserID,
|
||||
).Scan(&role)
|
||||
if err != nil {
|
||||
response.Unauthorized(w, "用户不存在或已被禁用")
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.jwtMgr.GenerateTokenPair(claims.UserID, claims.Email, role)
|
||||
if err != nil {
|
||||
response.InternalError(w, "生成Token失败")
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "access_token",
|
||||
Value: tokenPair.AccessToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(24 * time.Hour / time.Second),
|
||||
})
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"expires_at": tokenPair.ExpiresAt,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/enterprise-ai-platform/server/pkg/dify"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type ChatHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
dify *dify.Client
|
||||
}
|
||||
|
||||
func NewChatHandler(pool *pgxpool.Pool, difyClient *dify.Client) *ChatHandler {
|
||||
return &ChatHandler{pool: pool, dify: difyClient}
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Message string `json:"message"`
|
||||
ConversationID string `json:"conversation_id,omitempty"`
|
||||
Inputs map[string]any `json:"inputs,omitempty"`
|
||||
}
|
||||
|
||||
func (h *ChatHandler) Chat(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req chatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.Message == "" {
|
||||
response.BadRequest(w, "消息不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// Get app's Dify API key
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1 AND status = 'approved'`,
|
||||
appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在或未上架")
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Call Dify streaming chat
|
||||
difyReq := &dify.ChatRequest{
|
||||
Query: req.Message,
|
||||
Inputs: req.Inputs,
|
||||
ConversationID: req.ConversationID,
|
||||
User: userID.String(),
|
||||
ResponseMode: "streaming",
|
||||
}
|
||||
|
||||
body, err := h.dify.ChatStream(r.Context(), difyAPIKey, difyReq)
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "Dify 服务不可用: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
// Stream SSE to client
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
response.InternalError(w, "Streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
||||
|
||||
var totalTokens int
|
||||
var modelName string
|
||||
var conversationID string
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
// Forward SSE event to client
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
|
||||
// Parse for usage tracking
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err == nil {
|
||||
if cid, ok := event["conversation_id"].(string); ok && cid != "" {
|
||||
conversationID = cid
|
||||
}
|
||||
if event["event"] == "message_end" {
|
||||
if meta, ok := event["metadata"].(map[string]any); ok {
|
||||
if usage, ok := meta["usage"].(map[string]any); ok {
|
||||
if t, ok := usage["total_tokens"].(float64); ok {
|
||||
totalTokens = int(t)
|
||||
}
|
||||
if m, ok := usage["model"].(string); ok {
|
||||
modelName = m
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Record usage asynchronously
|
||||
duration := time.Since(startTime).Milliseconds()
|
||||
go func() {
|
||||
_, _ = h.pool.Exec(context.Background(), `
|
||||
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, total_tokens, model_name, duration_ms, client_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'web')`,
|
||||
appID, userID, conversationID, totalTokens, modelName, duration)
|
||||
|
||||
_, _ = h.pool.Exec(context.Background(),
|
||||
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *ChatHandler) Completion(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req chatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.Message == "" {
|
||||
response.BadRequest(w, "消息不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1 AND status = 'approved'`,
|
||||
appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在或未上架")
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
difyReq := &dify.ChatRequest{
|
||||
Query: req.Message,
|
||||
Inputs: req.Inputs,
|
||||
ConversationID: req.ConversationID,
|
||||
User: userID.String(),
|
||||
ResponseMode: "blocking",
|
||||
}
|
||||
|
||||
result, err := h.dify.ChatBlocking(r.Context(), difyAPIKey, difyReq)
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "Dify 服务不可用: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(startTime).Milliseconds()
|
||||
go func() {
|
||||
_, _ = h.pool.Exec(context.Background(), `
|
||||
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, total_tokens, duration_ms, client_type)
|
||||
VALUES ($1, $2, $3, $4, $5, 'web')`,
|
||||
appID, userID, result.ConversationID, 0, duration)
|
||||
_, _ = h.pool.Exec(context.Background(),
|
||||
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
|
||||
}()
|
||||
|
||||
response.JSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *ChatHandler) Feedback(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req struct {
|
||||
MessageID string `json:"message_id"`
|
||||
Rating string `json:"rating"` // "like" | "dislike" | null
|
||||
Comment string `json:"comment"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.MessageID == "" {
|
||||
response.BadRequest(w, "message_id 不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
|
||||
feedbackReq := &dify.FeedbackRequest{
|
||||
Rating: req.Rating,
|
||||
User: userID.String(),
|
||||
}
|
||||
if err := h.dify.SubmitFeedback(r.Context(), difyAPIKey, req.MessageID, feedbackReq); err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "提交反馈失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "反馈已提交"})
|
||||
}
|
||||
|
||||
func (h *ChatHandler) Conversations(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.dify.ListConversations(r.Context(), difyAPIKey, userID.String(), 20, "")
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "获取对话列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *ChatHandler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
convID := chi.URLParam(r, "convId")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.dify.ListMessages(r.Context(), difyAPIKey, userID.String(), convID, 100, "")
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "获取消息列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *ChatHandler) DeleteConversation(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
convID := chi.URLParam(r, "convId")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var difyAPIKey string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&difyAPIKey)
|
||||
if err != nil || difyAPIKey == "" {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.dify.DeleteConversation(r.Context(), difyAPIKey, userID.String(), convID); err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50201, "删除对话失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
// Suppress unused import warnings
|
||||
var _ = io.EOF
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,513 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/enterprise-ai-platform/server/pkg/dify"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type CreatorHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
dify *dify.Client
|
||||
}
|
||||
|
||||
func NewCreatorHandler(pool *pgxpool.Pool, difyClient *dify.Client) *CreatorHandler {
|
||||
return &CreatorHandler{pool: pool, dify: difyClient}
|
||||
}
|
||||
|
||||
type createAppRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
LongDescription string `json:"long_description"`
|
||||
CategoryID string `json:"category_id"`
|
||||
Visibility string `json:"visibility"`
|
||||
AppType string `json:"app_type"`
|
||||
SystemPrompt string `json:"system_prompt"`
|
||||
WelcomeMessage string `json:"welcome_message"`
|
||||
SuggestedPrompts []string `json:"suggested_prompts"`
|
||||
Model string `json:"model"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
KnowledgeBaseIDs []string `json:"knowledge_base_ids"`
|
||||
// Agent-type config
|
||||
Tools []string `json:"tools"`
|
||||
DataSources []string `json:"data_sources"`
|
||||
TemplateSet string `json:"template_set"`
|
||||
// Completion-type config
|
||||
InputLabel string `json:"input_label"`
|
||||
OutputLabel string `json:"output_label"`
|
||||
InputPlaceholder string `json:"input_placeholder"`
|
||||
FormatTemplates map[string]any `json:"format_templates"`
|
||||
}
|
||||
|
||||
func buildAppConfig(req *createAppRequest) json.RawMessage {
|
||||
cfg := map[string]any{
|
||||
"system_prompt": req.SystemPrompt,
|
||||
"model": req.Model,
|
||||
}
|
||||
if len(req.Tools) > 0 {
|
||||
cfg["tools"] = req.Tools
|
||||
}
|
||||
if len(req.DataSources) > 0 {
|
||||
cfg["data_sources"] = req.DataSources
|
||||
}
|
||||
if req.TemplateSet != "" {
|
||||
cfg["template_set"] = req.TemplateSet
|
||||
}
|
||||
if req.InputLabel != "" {
|
||||
cfg["input_label"] = req.InputLabel
|
||||
}
|
||||
if req.OutputLabel != "" {
|
||||
cfg["output_label"] = req.OutputLabel
|
||||
}
|
||||
if req.InputPlaceholder != "" {
|
||||
cfg["input_placeholder"] = req.InputPlaceholder
|
||||
}
|
||||
if len(req.FormatTemplates) > 0 {
|
||||
cfg["format_templates"] = req.FormatTemplates
|
||||
}
|
||||
b, _ := json.Marshal(cfg)
|
||||
return b
|
||||
}
|
||||
|
||||
func (h *CreatorHandler) ListMyApps(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
role := middleware.GetRole(r.Context())
|
||||
isAdmin := role == "admin" || role == "super_admin"
|
||||
|
||||
var rows pgx.Rows
|
||||
var err error
|
||||
if isAdmin {
|
||||
// 管理员查看本机构所有应用(通过用户表获取org_id)
|
||||
rows, err = h.pool.Query(r.Context(), `
|
||||
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
|
||||
c.name as category_name, a.dify_app_type, a.status, a.visibility,
|
||||
a.usage_count, a.updated_at
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE a.org_id = (SELECT org_id FROM users WHERE id = $1)
|
||||
ORDER BY a.updated_at DESC`, userID)
|
||||
} else {
|
||||
rows, err = h.pool.Query(r.Context(), `
|
||||
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
|
||||
c.name as category_name, a.dify_app_type, a.status, a.visibility,
|
||||
a.usage_count, a.updated_at
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE a.creator_id = $1
|
||||
ORDER BY a.updated_at DESC`, userID)
|
||||
}
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询应用失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var apps []map[string]any
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, name, slug, status, visibility string
|
||||
desc, iconURL, catName, appType *string
|
||||
usageCount int64
|
||||
updatedAt time.Time
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &slug, &desc, &iconURL, &catName,
|
||||
&appType, &status, &visibility, &usageCount, &updatedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
apps = append(apps, map[string]any{
|
||||
"id": id, "name": name, "slug": slug, "description": desc,
|
||||
"icon_url": iconURL, "category_name": catName,
|
||||
"dify_app_type": appType,
|
||||
"status": status, "visibility": visibility,
|
||||
"usage_count": usageCount, "updated_at": updatedAt,
|
||||
})
|
||||
}
|
||||
if apps == nil {
|
||||
apps = []map[string]any{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, apps)
|
||||
}
|
||||
|
||||
func (h *CreatorHandler) GetApp(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
role := middleware.GetRole(r.Context())
|
||||
isAdmin := role == "admin" || role == "super_admin"
|
||||
|
||||
var (
|
||||
id, name, slug, status, visibility, version string
|
||||
desc, longDesc, iconURL, catID, difyType *string
|
||||
welcomeMsg *string
|
||||
kbID *string
|
||||
appConfig json.RawMessage
|
||||
suggestedPrompts json.RawMessage
|
||||
maxTokens int
|
||||
temperature float32
|
||||
usageCount int64
|
||||
createdAt, updatedAt time.Time
|
||||
)
|
||||
|
||||
var query string
|
||||
var args []any
|
||||
if isAdmin {
|
||||
query = `SELECT a.id, a.name, a.slug, a.description, a.long_description,
|
||||
a.icon_url, a.category_id, a.dify_app_type,
|
||||
a.app_config, a.welcome_message, a.suggested_prompts,
|
||||
a.max_tokens, a.temperature, a.status, a.visibility,
|
||||
a.is_featured, a.usage_count, a.version,
|
||||
a.knowledge_base_id, a.created_at, a.updated_at
|
||||
FROM applications a WHERE a.id = $1`
|
||||
args = []any{appID}
|
||||
} else {
|
||||
query = `SELECT a.id, a.name, a.slug, a.description, a.long_description,
|
||||
a.icon_url, a.category_id, a.dify_app_type,
|
||||
a.app_config, a.welcome_message, a.suggested_prompts,
|
||||
a.max_tokens, a.temperature, a.status, a.visibility,
|
||||
a.is_featured, a.usage_count, a.version,
|
||||
a.knowledge_base_id, a.created_at, a.updated_at
|
||||
FROM applications a WHERE a.id = $1 AND a.creator_id = $2`
|
||||
args = []any{appID, userID}
|
||||
}
|
||||
|
||||
err := h.pool.QueryRow(r.Context(), query, args...).Scan(
|
||||
&id, &name, &slug, &desc, &longDesc,
|
||||
&iconURL, &catID, &difyType,
|
||||
&appConfig, &welcomeMsg, &suggestedPrompts,
|
||||
&maxTokens, &temperature, &status, &visibility,
|
||||
new(bool), &usageCount, &version,
|
||||
&kbID, &createdAt, &updatedAt)
|
||||
|
||||
if err != nil {
|
||||
response.NotFound(w, "应用不存在或无权访问")
|
||||
return
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"id": id, "name": name, "slug": slug, "description": desc,
|
||||
"long_description": longDesc, "icon_url": iconURL,
|
||||
"category_id": catID, "dify_app_type": difyType,
|
||||
"app_config": appConfig, "welcome_message": welcomeMsg,
|
||||
"suggested_prompts": suggestedPrompts,
|
||||
"max_tokens": maxTokens, "temperature": temperature,
|
||||
"status": status, "visibility": visibility,
|
||||
"usage_count": usageCount, "version": version,
|
||||
"knowledge_base_id": kbID,
|
||||
"created_at": createdAt, "updated_at": updatedAt,
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *CreatorHandler) CreateApp(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req createAppRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.Name == "" {
|
||||
response.BadRequest(w, "应用名称不能为空")
|
||||
return
|
||||
}
|
||||
if req.Visibility == "" {
|
||||
req.Visibility = "private"
|
||||
}
|
||||
if req.AppType == "" {
|
||||
req.AppType = "chatbot"
|
||||
}
|
||||
if req.Temperature == 0 {
|
||||
req.Temperature = 0.7
|
||||
}
|
||||
if req.MaxTokens == 0 {
|
||||
req.MaxTokens = 4096
|
||||
}
|
||||
|
||||
slug := generateSlug(req.Name)
|
||||
suggestedPromptsJSON, _ := json.Marshal(req.SuggestedPrompts)
|
||||
appConfig := buildAppConfig(&req)
|
||||
|
||||
difyAppID := ""
|
||||
difyAPIKey := ""
|
||||
|
||||
var appID string
|
||||
err := h.pool.QueryRow(r.Context(), `
|
||||
INSERT INTO applications (
|
||||
name, slug, description, long_description, category_id, creator_id,
|
||||
dify_app_id, dify_app_type, dify_api_key,
|
||||
app_config, welcome_message, suggested_prompts,
|
||||
max_tokens, temperature, status, visibility
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, 'draft', $15)
|
||||
RETURNING id`,
|
||||
req.Name, slug, req.Description, req.LongDescription,
|
||||
nilIfEmpty(req.CategoryID), userID,
|
||||
nilIfEmpty(difyAppID), req.AppType, nilIfEmpty(difyAPIKey),
|
||||
appConfig, req.WelcomeMessage, string(suggestedPromptsJSON),
|
||||
req.MaxTokens, req.Temperature, req.Visibility,
|
||||
).Scan(&appID)
|
||||
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "duplicate key") {
|
||||
response.Error(w, http.StatusConflict, 40901, "应用名称已存在")
|
||||
return
|
||||
}
|
||||
response.InternalError(w, "创建应用失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusCreated, map[string]any{
|
||||
"id": appID,
|
||||
"name": req.Name,
|
||||
"slug": slug,
|
||||
"status": "draft",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *CreatorHandler) UpdateApp(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
role := middleware.GetRole(r.Context())
|
||||
isAdmin := role == "admin" || role == "super_admin"
|
||||
|
||||
var status string
|
||||
var creatorID string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT status, creator_id FROM applications WHERE id = $1`, appID).Scan(&status, &creatorID)
|
||||
if err != nil {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
if !isAdmin && creatorID != userID.String() {
|
||||
response.Forbidden(w, "只能修改自己创建的应用")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
createAppRequest
|
||||
KnowledgeBaseID string `json:"knowledge_base_id"`
|
||||
AppType string `json:"app_type"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
|
||||
suggestedPromptsJSON, _ := json.Marshal(req.SuggestedPrompts)
|
||||
appConfig := buildAppConfig(&req.createAppRequest)
|
||||
|
||||
newStatus := status
|
||||
if status == "draft" || status == "rejected" {
|
||||
newStatus = "draft"
|
||||
}
|
||||
|
||||
_, err = h.pool.Exec(r.Context(), `
|
||||
UPDATE applications SET
|
||||
name = COALESCE(NULLIF($2, ''), name),
|
||||
description = COALESCE(NULLIF($3, ''), description),
|
||||
long_description = $4,
|
||||
category_id = COALESCE($5::UUID, category_id),
|
||||
app_config = $6,
|
||||
welcome_message = $7,
|
||||
suggested_prompts = $8,
|
||||
max_tokens = $9,
|
||||
temperature = $10,
|
||||
visibility = COALESCE(NULLIF($11, ''), visibility),
|
||||
knowledge_base_id = $12::UUID,
|
||||
dify_app_type = COALESCE(NULLIF($13, ''), dify_app_type),
|
||||
status = $14
|
||||
WHERE id = $1`,
|
||||
appID, req.Name, req.Description, req.LongDescription,
|
||||
nilIfEmpty(req.CategoryID),
|
||||
appConfig, req.WelcomeMessage, string(suggestedPromptsJSON),
|
||||
req.MaxTokens, req.Temperature, req.Visibility,
|
||||
nilIfEmpty(req.KnowledgeBaseID), req.AppType, newStatus,
|
||||
)
|
||||
if err != nil {
|
||||
response.InternalError(w, "更新应用失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "更新成功"})
|
||||
}
|
||||
|
||||
func (h *CreatorHandler) DeleteApp(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
role := middleware.GetRole(r.Context())
|
||||
isAdmin := role == "admin" || role == "super_admin"
|
||||
|
||||
var tag pgconn.CommandTag
|
||||
var err error
|
||||
if isAdmin {
|
||||
tag, err = h.pool.Exec(r.Context(),
|
||||
`DELETE FROM applications WHERE id = $1`, appID)
|
||||
} else {
|
||||
tag, err = h.pool.Exec(r.Context(),
|
||||
`DELETE FROM applications WHERE id = $1 AND creator_id = $2 AND status = 'draft'`,
|
||||
appID, userID)
|
||||
}
|
||||
if err != nil {
|
||||
response.InternalError(w, "删除失败")
|
||||
return
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
if isAdmin {
|
||||
response.NotFound(w, "应用不存在")
|
||||
} else {
|
||||
response.BadRequest(w, "只能删除草稿状态的应用")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (h *CreatorHandler) SubmitReview(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var status, creatorID, version string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT status, creator_id, version FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&status, &creatorID, &version)
|
||||
if err != nil {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
if creatorID != userID.String() {
|
||||
response.Forbidden(w, "只能提交自己创建的应用")
|
||||
return
|
||||
}
|
||||
if status != "draft" && status != "rejected" {
|
||||
response.BadRequest(w, "只有草稿或被驳回的应用可以提交审核")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Comment string `json:"comment"`
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
tx, err := h.pool.Begin(r.Context())
|
||||
if err != nil {
|
||||
response.InternalError(w, "事务开始失败")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(r.Context())
|
||||
|
||||
tx.Exec(r.Context(), `
|
||||
INSERT INTO app_reviews (app_id, version, submitter_id, submit_comment)
|
||||
VALUES ($1, $2, $3, $4)`, appID, version, userID, req.Comment)
|
||||
|
||||
tx.Exec(r.Context(), `
|
||||
UPDATE applications SET status = 'pending_review' WHERE id = $1`, appID)
|
||||
|
||||
if err := tx.Commit(r.Context()); err != nil {
|
||||
response.InternalError(w, "提交审核失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (h *CreatorHandler) WithdrawReview(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var creatorID string
|
||||
h.pool.QueryRow(r.Context(), `SELECT creator_id FROM applications WHERE id = $1`, appID).Scan(&creatorID)
|
||||
if creatorID != userID.String() {
|
||||
response.Forbidden(w, "只能撤回自己的审核")
|
||||
return
|
||||
}
|
||||
|
||||
tx, err := h.pool.Begin(r.Context())
|
||||
if err != nil {
|
||||
response.InternalError(w, "事务开始失败")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(r.Context())
|
||||
|
||||
tx.Exec(r.Context(), `
|
||||
UPDATE app_reviews SET status = 'withdrawn'
|
||||
WHERE app_id = $1 AND status = 'pending'`, appID)
|
||||
|
||||
tx.Exec(r.Context(), `
|
||||
UPDATE applications SET status = 'draft' WHERE id = $1 AND status = 'pending_review'`, appID)
|
||||
|
||||
tx.Commit(r.Context())
|
||||
response.JSON(w, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (h *CreatorHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
|
||||
c.name as category_name, c.slug as category_slug,
|
||||
a.usage_count, a.avg_rating, a.rating_count
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE a.is_template = true AND a.status = 'approved'
|
||||
ORDER BY a.usage_count DESC`)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询模板失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
apps := scanAppList(rows)
|
||||
response.JSON(w, http.StatusOK, apps)
|
||||
}
|
||||
|
||||
func generateSlug(name string) string {
|
||||
slug := strings.ToLower(strings.TrimSpace(name))
|
||||
slug = strings.ReplaceAll(slug, " ", "-")
|
||||
return fmt.Sprintf("%s-%d", slug, time.Now().UnixMilli()%10000)
|
||||
}
|
||||
|
||||
func (h *CreatorHandler) RequestDelist(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var status, creatorID string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT status, creator_id FROM applications WHERE id = $1`, appID).Scan(&status, &creatorID)
|
||||
if err != nil {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
if creatorID != userID.String() {
|
||||
response.Forbidden(w, "只能操作自己创建的应用")
|
||||
return
|
||||
}
|
||||
if status != "approved" {
|
||||
response.BadRequest(w, "只有已上架的应用可以申请下架")
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.pool.Exec(r.Context(),
|
||||
`UPDATE applications SET status = 'archived' WHERE id = $1`, appID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "申请下架失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已下架"})
|
||||
}
|
||||
|
||||
func nilIfEmpty(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return &s
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/enterprise-ai-platform/server/pkg/llm"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type DocTemplateHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
manager *llm.Manager
|
||||
provider string
|
||||
}
|
||||
|
||||
func NewDocTemplateHandler(pool *pgxpool.Pool, manager *llm.Manager, provider string) *DocTemplateHandler {
|
||||
return &DocTemplateHandler{pool: pool, manager: manager, provider: provider}
|
||||
}
|
||||
|
||||
func (h *DocTemplateHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := r.URL.Query().Get("org_id")
|
||||
query := `SELECT id, name, doc_type, COALESCE(description,''), COALESCE(icon,''), fields, sort_order
|
||||
FROM document_templates
|
||||
WHERE is_active = true`
|
||||
var args []any
|
||||
if orgID != "" {
|
||||
query += ` AND (org_id = $1 OR org_id IS NULL)`
|
||||
args = append(args, orgID)
|
||||
}
|
||||
query += ` ORDER BY sort_order ASC`
|
||||
rows, err := h.pool.Query(r.Context(), query, args...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询模板失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var templates []map[string]any
|
||||
for rows.Next() {
|
||||
var id, name, docType, desc, icon string
|
||||
var fields json.RawMessage
|
||||
var sortOrder int
|
||||
if err := rows.Scan(&id, &name, &docType, &desc, &icon, &fields, &sortOrder); err != nil {
|
||||
continue
|
||||
}
|
||||
templates = append(templates, map[string]any{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"doc_type": docType,
|
||||
"description": desc,
|
||||
"icon": icon,
|
||||
"fields": fields,
|
||||
"sort_order": sortOrder,
|
||||
})
|
||||
}
|
||||
if templates == nil {
|
||||
templates = []map[string]any{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]any{"data": templates})
|
||||
}
|
||||
|
||||
func (h *DocTemplateHandler) GetTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "templateId")
|
||||
|
||||
var name, docType, desc, icon, formatStd, promptTpl string
|
||||
var fields json.RawMessage
|
||||
var exampleOutput *string
|
||||
|
||||
err := h.pool.QueryRow(r.Context(), `
|
||||
SELECT name, doc_type, COALESCE(description,''), COALESCE(icon,''),
|
||||
format_standard, fields, prompt_template, example_output
|
||||
FROM document_templates WHERE id = $1 AND is_active = true`, id,
|
||||
).Scan(&name, &docType, &desc, &icon, &formatStd, &fields, &promptTpl, &exampleOutput)
|
||||
if err != nil {
|
||||
response.NotFound(w, "模板不存在")
|
||||
return
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"doc_type": docType,
|
||||
"description": desc,
|
||||
"icon": icon,
|
||||
"format_standard": formatStd,
|
||||
"fields": fields,
|
||||
"prompt_template": promptTpl,
|
||||
}
|
||||
if exampleOutput != nil {
|
||||
result["example_output"] = *exampleOutput
|
||||
}
|
||||
response.JSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
type generateDocRequest struct {
|
||||
TemplateID string `json:"template_id"`
|
||||
FieldData map[string]string `json:"field_data"`
|
||||
}
|
||||
|
||||
func (h *DocTemplateHandler) GenerateDocument(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req generateDocRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效请求格式")
|
||||
return
|
||||
}
|
||||
if req.TemplateID == "" {
|
||||
response.BadRequest(w, "请选择公文模板")
|
||||
return
|
||||
}
|
||||
|
||||
// Load app-specific model and max_tokens
|
||||
var appModel string
|
||||
var appMaxTokens int
|
||||
_ = h.pool.QueryRow(r.Context(), `
|
||||
SELECT COALESCE(app_config->>'model', ''), COALESCE(max_tokens, 4096)
|
||||
FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&appModel, &appMaxTokens)
|
||||
|
||||
var promptTpl, tplName string
|
||||
var fields json.RawMessage
|
||||
err := h.pool.QueryRow(r.Context(), `
|
||||
SELECT name, fields, prompt_template
|
||||
FROM document_templates WHERE id = $1 AND is_active = true`, req.TemplateID,
|
||||
).Scan(&tplName, &fields, &promptTpl)
|
||||
if err != nil {
|
||||
response.NotFound(w, "模板不存在")
|
||||
return
|
||||
}
|
||||
|
||||
var fieldDefs []struct {
|
||||
Key string `json:"key"`
|
||||
Label string `json:"label"`
|
||||
Required bool `json:"required"`
|
||||
}
|
||||
_ = json.Unmarshal(fields, &fieldDefs)
|
||||
for _, f := range fieldDefs {
|
||||
if f.Required {
|
||||
if val, ok := req.FieldData[f.Key]; !ok || strings.TrimSpace(val) == "" {
|
||||
response.BadRequest(w, fmt.Sprintf("请填写必填项:%s", f.Label))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompt := promptTpl
|
||||
for k, v := range req.FieldData {
|
||||
prompt = strings.ReplaceAll(prompt, "{{"+k+"}}", v)
|
||||
}
|
||||
prompt = strings.ReplaceAll(prompt, "{{", "")
|
||||
prompt = strings.ReplaceAll(prompt, "}}", "")
|
||||
|
||||
llmReq := &llm.ChatRequest{
|
||||
Model: appModel,
|
||||
Messages: []llm.Message{
|
||||
{Role: llm.RoleSystem, Content: "你是一个专业的政务公文写作专家,精通《党政机关公文格式》国家标准(GB/T 9704)和《党政机关公文处理工作条例》。请严格按照规范格式生成公文,确保行文庄重、严谨、准确。输出完整公文内容,使用Markdown格式排版。"},
|
||||
{Role: llm.RoleUser, Content: prompt},
|
||||
},
|
||||
Temperature: 0.3,
|
||||
MaxTokens: appMaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
body, err := h.manager.ChatStream(r.Context(), h.provider, llmReq)
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50202, "模型服务不可用: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
response.InternalError(w, "Streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
convID := uuid.New().String()
|
||||
msgID := uuid.New().String()
|
||||
var fullResponse strings.Builder
|
||||
|
||||
firstEvent := map[string]string{
|
||||
"conversation_id": convID,
|
||||
"message_id": msgID,
|
||||
"template_name": tplName,
|
||||
}
|
||||
data, _ := json.Marshal(firstEvent)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
|
||||
transform := llm.TransformOpenAIStream
|
||||
if h.provider == "anthropic" {
|
||||
transform = llm.TransformAnthropicStream
|
||||
}
|
||||
|
||||
var totalTokens int
|
||||
var modelName string
|
||||
|
||||
_ = transform(body, func(event llm.StreamEvent) {
|
||||
if event.Answer != "" {
|
||||
fullResponse.WriteString(event.Answer)
|
||||
}
|
||||
if event.Usage != nil {
|
||||
totalTokens = event.Usage.TotalTokens
|
||||
modelName = event.Usage.Model
|
||||
}
|
||||
data, _ := json.Marshal(event)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
})
|
||||
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
duration := time.Since(startTime).Milliseconds()
|
||||
userMsg := fmt.Sprintf("[公文生成] %s", tplName)
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
_, _ = h.pool.Exec(ctx, `
|
||||
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, user_message, ai_response, total_tokens, model_name, duration_ms, client_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'web')`,
|
||||
appID, userID, convID, userMsg, fullResponse.String(), totalTokens, modelName, duration)
|
||||
_, _ = h.pool.Exec(ctx,
|
||||
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type FavoriteHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewFavoriteHandler(pool *pgxpool.Pool) *FavoriteHandler {
|
||||
return &FavoriteHandler{pool: pool}
|
||||
}
|
||||
|
||||
func (h *FavoriteHandler) AddFavorite(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
_, err := h.pool.Exec(r.Context(),
|
||||
`INSERT INTO app_favorites (user_id, app_id) VALUES ($1, $2) ON CONFLICT DO NOTHING`,
|
||||
userID, appID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "收藏失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.pool.Exec(r.Context(),
|
||||
`UPDATE applications SET favorite_count = favorite_count + 1 WHERE id = $1`, appID)
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]bool{"favorited": true})
|
||||
}
|
||||
|
||||
func (h *FavoriteHandler) RemoveFavorite(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
tag, err := h.pool.Exec(r.Context(),
|
||||
`DELETE FROM app_favorites WHERE user_id = $1 AND app_id = $2`,
|
||||
userID, appID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "取消收藏失败")
|
||||
return
|
||||
}
|
||||
|
||||
if tag.RowsAffected() > 0 {
|
||||
h.pool.Exec(r.Context(),
|
||||
`UPDATE applications SET favorite_count = GREATEST(favorite_count - 1, 0) WHERE id = $1`, appID)
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]bool{"favorited": false})
|
||||
}
|
||||
|
||||
func (h *FavoriteHandler) ListFavorites(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
offset := (page - 1) * 20
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
|
||||
c.name as category_name, c.slug as category_slug,
|
||||
a.usage_count, a.avg_rating, a.rating_count
|
||||
FROM app_favorites f
|
||||
JOIN applications a ON f.app_id = a.id
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE f.user_id = $1
|
||||
ORDER BY f.created_at DESC
|
||||
LIMIT 20 OFFSET $2`, userID, offset)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询收藏失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
apps := scanAppList(rows)
|
||||
response.JSON(w, http.StatusOK, apps)
|
||||
}
|
||||
|
||||
// --- Rating ---
|
||||
|
||||
type ratingRequest struct {
|
||||
Score int `json:"score"`
|
||||
Comment string `json:"comment"`
|
||||
}
|
||||
|
||||
func (h *FavoriteHandler) AddRating(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req ratingRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.Score < 1 || req.Score > 5 {
|
||||
response.BadRequest(w, "评分必须在1-5之间")
|
||||
return
|
||||
}
|
||||
|
||||
_, err := h.pool.Exec(r.Context(), `
|
||||
INSERT INTO app_ratings (app_id, user_id, score, comment)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (app_id, user_id)
|
||||
DO UPDATE SET score = EXCLUDED.score, comment = EXCLUDED.comment`,
|
||||
appID, userID, req.Score, req.Comment)
|
||||
if err != nil {
|
||||
response.InternalError(w, "评分失败")
|
||||
return
|
||||
}
|
||||
|
||||
// Update app avg rating
|
||||
var avgRating float32
|
||||
var ratingCount int
|
||||
h.pool.QueryRow(r.Context(),
|
||||
`SELECT COALESCE(AVG(score)::REAL, 0), COUNT(*) FROM app_ratings WHERE app_id = $1`,
|
||||
appID).Scan(&avgRating, &ratingCount)
|
||||
|
||||
h.pool.Exec(r.Context(),
|
||||
`UPDATE applications SET avg_rating = $2, rating_count = $3 WHERE id = $1`,
|
||||
appID, avgRating, ratingCount)
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"avg_rating": avgRating,
|
||||
"rating_count": ratingCount,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *FavoriteHandler) ListRatings(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT r.id, r.score, r.comment, r.created_at,
|
||||
u.name as user_name, u.avatar_url
|
||||
FROM app_ratings r
|
||||
JOIN users u ON r.user_id = u.id
|
||||
WHERE r.app_id = $1
|
||||
ORDER BY r.created_at DESC LIMIT 50`, appID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询评分失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var ratings []map[string]any
|
||||
for rows.Next() {
|
||||
var id string
|
||||
var score int
|
||||
var comment *string
|
||||
var createdAt string
|
||||
var userName string
|
||||
var avatarURL *string
|
||||
if err := rows.Scan(&id, &score, &comment, &createdAt, &userName, &avatarURL); err != nil {
|
||||
continue
|
||||
}
|
||||
ratings = append(ratings, map[string]any{
|
||||
"id": id, "score": score, "comment": comment,
|
||||
"created_at": createdAt, "user_name": userName, "user_avatar": avatarURL,
|
||||
})
|
||||
}
|
||||
if ratings == nil {
|
||||
ratings = []map[string]any{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, ratings)
|
||||
}
|
||||
|
||||
func (h *FavoriteHandler) PersonalStats(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var totalConversations, totalTokens, favoriteCount int
|
||||
h.pool.QueryRow(r.Context(),
|
||||
`SELECT COUNT(*) FROM app_usage_logs WHERE user_id = $1`, userID).Scan(&totalConversations)
|
||||
h.pool.QueryRow(r.Context(),
|
||||
`SELECT COALESCE(SUM(total_tokens), 0) FROM app_usage_logs WHERE user_id = $1`, userID).Scan(&totalTokens)
|
||||
h.pool.QueryRow(r.Context(),
|
||||
`SELECT COUNT(*) FROM app_favorites WHERE user_id = $1`, userID).Scan(&favoriteCount)
|
||||
|
||||
var recentApps []map[string]any
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT DISTINCT ON (l.app_id) a.id, a.name, a.icon_url, l.created_at
|
||||
FROM app_usage_logs l
|
||||
JOIN applications a ON l.app_id = a.id
|
||||
WHERE l.user_id = $1
|
||||
ORDER BY l.app_id, l.created_at DESC
|
||||
LIMIT 5`, userID)
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var id, name string
|
||||
var icon *string
|
||||
var at string
|
||||
if rows.Scan(&id, &name, &icon, &at) == nil {
|
||||
recentApps = append(recentApps, map[string]any{
|
||||
"id": id, "name": name, "icon_url": icon, "last_used": at,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if recentApps == nil {
|
||||
recentApps = []map[string]any{}
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"total_conversations": totalConversations,
|
||||
"total_tokens": totalTokens,
|
||||
"favorite_count": favoriteCount,
|
||||
"recent_apps": recentApps,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
)
|
||||
|
||||
var startTime = time.Now()
|
||||
|
||||
func HealthCheck(w http.ResponseWriter, r *http.Request) {
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"status": "ok",
|
||||
"service": "aily-portal-api",
|
||||
"uptime": time.Since(startTime).String(),
|
||||
"go": runtime.Version(),
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,522 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/pkg/chunker"
|
||||
"github.com/enterprise-ai-platform/server/pkg/embedding"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
mw "github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
func internalErr(w http.ResponseWriter, err error) {
|
||||
response.InternalError(w, err.Error())
|
||||
}
|
||||
|
||||
type KnowledgeHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
embedder *embedding.Client
|
||||
}
|
||||
|
||||
func NewKnowledgeHandler(pool *pgxpool.Pool, embedder *embedding.Client) *KnowledgeHandler {
|
||||
return &KnowledgeHandler{pool: pool, embedder: embedder}
|
||||
}
|
||||
|
||||
type knowledgeBaseRow struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Visibility string `json:"visibility"`
|
||||
DocCount int `json:"document_count"`
|
||||
TotalChars int64 `json:"total_chars"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (h *KnowledgeHandler) ListKnowledgeBases(w http.ResponseWriter, r *http.Request) {
|
||||
userID := mw.GetUserID(r.Context())
|
||||
userRole := mw.GetRole(r.Context())
|
||||
|
||||
var query string
|
||||
var args []any
|
||||
|
||||
if userRole == "super_admin" {
|
||||
// 超级管理员查看全部知识库
|
||||
query = `SELECT id, name, COALESCE(description,''), visibility, doc_count, total_chars, status, created_at, updated_at
|
||||
FROM knowledge_bases ORDER BY updated_at DESC`
|
||||
args = []any{}
|
||||
} else {
|
||||
// 优先使用前端传入的 org_id 参数(切换机构后),否则从用户表获取
|
||||
orgFilter := r.URL.Query().Get("org_id")
|
||||
if orgFilter == "" {
|
||||
var userOrg *string
|
||||
_ = h.pool.QueryRow(r.Context(), `SELECT org_id::text FROM users WHERE id = $1`, userID).Scan(&userOrg)
|
||||
if userOrg != nil {
|
||||
orgFilter = *userOrg
|
||||
}
|
||||
}
|
||||
|
||||
query = `SELECT id, name, COALESCE(description,''), visibility, doc_count, total_chars, status, created_at, updated_at
|
||||
FROM knowledge_bases WHERE (owner_id = $1`
|
||||
args = []any{userID}
|
||||
if orgFilter != "" {
|
||||
query += ` OR org_id = $2`
|
||||
args = append(args, orgFilter)
|
||||
}
|
||||
query += `) ORDER BY updated_at DESC`
|
||||
}
|
||||
rows, err := h.pool.Query(r.Context(), query, args...)
|
||||
if err != nil {
|
||||
internalErr(w, err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []knowledgeBaseRow
|
||||
for rows.Next() {
|
||||
var kb knowledgeBaseRow
|
||||
if err := rows.Scan(&kb.ID, &kb.Name, &kb.Description, &kb.Visibility, &kb.DocCount, &kb.TotalChars, &kb.Status, &kb.CreatedAt, &kb.UpdatedAt); err != nil {
|
||||
internalErr(w, err)
|
||||
return
|
||||
}
|
||||
items = append(items, kb)
|
||||
}
|
||||
if items == nil {
|
||||
items = []knowledgeBaseRow{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, items)
|
||||
}
|
||||
|
||||
func (h *KnowledgeHandler) CreateKnowledgeBase(w http.ResponseWriter, r *http.Request) {
|
||||
userID := mw.GetUserID(r.Context())
|
||||
|
||||
var body struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Visibility string `json:"visibility"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
response.BadRequest(w, "无效的请求体")
|
||||
return
|
||||
}
|
||||
if body.Name == "" {
|
||||
response.BadRequest(w, "名称不能为空")
|
||||
return
|
||||
}
|
||||
if body.Visibility == "" {
|
||||
body.Visibility = "private"
|
||||
}
|
||||
|
||||
// 获取用户所属机构
|
||||
var userOrgID *string
|
||||
_ = h.pool.QueryRow(r.Context(), `SELECT org_id::text FROM users WHERE id = $1`, userID).Scan(&userOrgID)
|
||||
|
||||
id := uuid.New()
|
||||
_, err := h.pool.Exec(r.Context(),
|
||||
`INSERT INTO knowledge_bases (id, name, description, owner_id, visibility, org_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
id, body.Name, body.Description, userID, body.Visibility, userOrgID)
|
||||
if err != nil {
|
||||
internalErr(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusCreated, map[string]any{
|
||||
"id": id.String(),
|
||||
"name": body.Name,
|
||||
"description": body.Description,
|
||||
"visibility": body.Visibility,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *KnowledgeHandler) UpdateKnowledgeBase(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := uuid.Parse(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
response.BadRequest(w, "无效的ID")
|
||||
return
|
||||
}
|
||||
userID := mw.GetUserID(r.Context())
|
||||
|
||||
var body struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
response.BadRequest(w, "无效的请求体")
|
||||
return
|
||||
}
|
||||
|
||||
tag, err := h.pool.Exec(r.Context(),
|
||||
`UPDATE knowledge_bases SET name = COALESCE(NULLIF($1,''), name),
|
||||
description = $2, updated_at = NOW()
|
||||
WHERE id = $3 AND owner_id = $4`,
|
||||
body.Name, body.Description, id, userID)
|
||||
if err != nil {
|
||||
internalErr(w, err)
|
||||
return
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
response.NotFound(w, "知识库不存在")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
|
||||
}
|
||||
|
||||
func (h *KnowledgeHandler) DeleteKnowledgeBase(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := uuid.Parse(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
response.BadRequest(w, "无效的ID")
|
||||
return
|
||||
}
|
||||
userID := mw.GetUserID(r.Context())
|
||||
|
||||
tag, err := h.pool.Exec(r.Context(),
|
||||
`DELETE FROM knowledge_bases WHERE id = $1 AND owner_id = $2`, id, userID)
|
||||
if err != nil {
|
||||
internalErr(w, err)
|
||||
return
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
response.NotFound(w, "知识库不存在")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
|
||||
}
|
||||
|
||||
type documentRow struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"filename"`
|
||||
FileType string `json:"file_type"`
|
||||
FileSize int64 `json:"file_size"`
|
||||
IndexingStatus string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
func (h *KnowledgeHandler) ListDocuments(w http.ResponseWriter, r *http.Request) {
|
||||
kbID, err := uuid.Parse(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
response.BadRequest(w, "无效的ID")
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.pool.Query(r.Context(),
|
||||
`SELECT id, name, COALESCE(file_type,''), file_size, indexing_status, created_at
|
||||
FROM knowledge_documents WHERE kb_id = $1 ORDER BY created_at DESC`, kbID)
|
||||
if err != nil {
|
||||
internalErr(w, err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var docs []documentRow
|
||||
for rows.Next() {
|
||||
var d documentRow
|
||||
if err := rows.Scan(&d.ID, &d.Name, &d.FileType, &d.FileSize, &d.IndexingStatus, &d.CreatedAt); err != nil {
|
||||
internalErr(w, err)
|
||||
return
|
||||
}
|
||||
docs = append(docs, d)
|
||||
}
|
||||
if docs == nil {
|
||||
docs = []documentRow{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, docs)
|
||||
}
|
||||
|
||||
func (h *KnowledgeHandler) UploadDocument(w http.ResponseWriter, r *http.Request) {
|
||||
kbID, err := uuid.Parse(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
response.BadRequest(w, "无效的ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
response.BadRequest(w, "文件过大或格式错误")
|
||||
return
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
response.BadRequest(w, "请上传文件")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var exists bool
|
||||
err = h.pool.QueryRow(r.Context(),
|
||||
`SELECT EXISTS(SELECT 1 FROM knowledge_bases WHERE id = $1)`, kbID).Scan(&exists)
|
||||
if err != nil || !exists {
|
||||
response.NotFound(w, "知识库不存在")
|
||||
return
|
||||
}
|
||||
|
||||
fileType := ""
|
||||
ext := ""
|
||||
if dot := len(header.Filename) - 1; dot > 0 {
|
||||
for i := dot; i >= 0; i-- {
|
||||
if header.Filename[i] == '.' {
|
||||
ext = header.Filename[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
switch ext {
|
||||
case "pdf":
|
||||
fileType = "pdf"
|
||||
case "docx":
|
||||
fileType = "docx"
|
||||
case "txt":
|
||||
fileType = "txt"
|
||||
case "md":
|
||||
fileType = "md"
|
||||
case "csv":
|
||||
fileType = "csv"
|
||||
case "xlsx":
|
||||
fileType = "xlsx"
|
||||
default:
|
||||
fileType = "txt"
|
||||
}
|
||||
|
||||
// 读取文件内容(文本文件)
|
||||
var content string
|
||||
if fileType == "txt" || fileType == "md" || fileType == "csv" {
|
||||
data, err := io.ReadAll(file)
|
||||
if err == nil {
|
||||
content = string(data)
|
||||
}
|
||||
}
|
||||
|
||||
userID := mw.GetUserID(r.Context())
|
||||
docID := uuid.New()
|
||||
|
||||
// 计算分片数
|
||||
chunkCount := 0
|
||||
if content != "" {
|
||||
chunks := chunker.ChunkText(content, chunker.DefaultOptions())
|
||||
chunkCount = len(chunks)
|
||||
}
|
||||
|
||||
_, err = h.pool.Exec(r.Context(),
|
||||
`INSERT INTO knowledge_documents (id, kb_id, name, file_type, file_size, uploader_id, indexing_status, content, char_count, chunk_count)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'processing', $7, $8, $9)`,
|
||||
docID, kbID, header.Filename, fileType, header.Size, userID, content, utf8.RuneCountInString(content), chunkCount)
|
||||
if err != nil {
|
||||
internalErr(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = h.pool.Exec(r.Context(),
|
||||
`UPDATE knowledge_bases SET doc_count = doc_count + 1, updated_at = NOW() WHERE id = $1`, kbID)
|
||||
|
||||
// 异步执行分片和向量化
|
||||
go h.chunkAndEmbed(context.Background(), kbID, docID, content)
|
||||
|
||||
response.JSON(w, http.StatusCreated, map[string]any{
|
||||
"id": docID.String(),
|
||||
"filename": header.Filename,
|
||||
"size": header.Size,
|
||||
"chunks": chunkCount,
|
||||
"status": "processing",
|
||||
})
|
||||
}
|
||||
|
||||
// chunkAndEmbed 对文档内容执行分片和向量化(异步)
|
||||
func (h *KnowledgeHandler) chunkAndEmbed(ctx context.Context, kbID, docID uuid.UUID, content string) {
|
||||
if content == "" {
|
||||
h.pool.Exec(ctx, `UPDATE knowledge_documents SET indexing_status = 'completed' WHERE id = $1`, docID)
|
||||
return
|
||||
}
|
||||
|
||||
chunks := chunker.ChunkText(content, chunker.DefaultOptions())
|
||||
if len(chunks) == 0 {
|
||||
h.pool.Exec(ctx, `UPDATE knowledge_documents SET indexing_status = 'completed' WHERE id = $1`, docID)
|
||||
return
|
||||
}
|
||||
|
||||
embeddingAvailable := h.embedder != nil && h.embedder.IsConfigured()
|
||||
successCount := 0
|
||||
|
||||
for i, chunk := range chunks {
|
||||
chunkID := uuid.New()
|
||||
charCount := utf8.RuneCountInString(chunk)
|
||||
|
||||
if embeddingAvailable {
|
||||
emb, err := h.embedder.GetEmbedding(ctx, chunk)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("doc_id", docID.String()).Int("chunk", i).Msg("embedding failed")
|
||||
// 无 embedding 也插入 chunk
|
||||
h.pool.Exec(ctx,
|
||||
`INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
chunkID, kbID, docID, i, chunk, charCount)
|
||||
} else {
|
||||
// 将 []float32 转为 pgvector 格式字符串
|
||||
vecStr := float32SliceToVectorStr(emb)
|
||||
h.pool.Exec(ctx,
|
||||
`INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count, embedding)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7::vector)`,
|
||||
chunkID, kbID, docID, i, chunk, charCount, vecStr)
|
||||
successCount++
|
||||
}
|
||||
} else {
|
||||
// 没有 embedding 服务,只存储文本分片
|
||||
h.pool.Exec(ctx,
|
||||
`INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
chunkID, kbID, docID, i, chunk, charCount)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新文档状态
|
||||
h.pool.Exec(ctx, `UPDATE knowledge_documents SET indexing_status = 'completed', chunk_count = $2 WHERE id = $1`, docID, len(chunks))
|
||||
|
||||
log.Info().
|
||||
Str("doc_id", docID.String()).
|
||||
Int("chunks", len(chunks)).
|
||||
Int("embedded", successCount).
|
||||
Bool("embedding_available", embeddingAvailable).
|
||||
Msg("document chunked and embedded")
|
||||
}
|
||||
|
||||
// float32SliceToVectorStr 将 float32 切片转为 pgvector 格式字符串 "[0.1,0.2,...]"
|
||||
func float32SliceToVectorStr(v []float32) string {
|
||||
s := "["
|
||||
for i, f := range v {
|
||||
if i > 0 {
|
||||
s += ","
|
||||
}
|
||||
s += fmt.Sprintf("%g", f)
|
||||
}
|
||||
s += "]"
|
||||
return s
|
||||
}
|
||||
|
||||
// ReindexAll 对所有未分片的文档执行分片和向量化(管理端点)
|
||||
func (h *KnowledgeHandler) ReindexAll(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := h.pool.Query(r.Context(),
|
||||
`SELECT kd.id, kd.kb_id, kd.content FROM knowledge_documents kd
|
||||
WHERE kd.content IS NOT NULL AND kd.content != '' AND kd.chunk_count = 0`)
|
||||
if err != nil {
|
||||
internalErr(w, err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type docInfo struct {
|
||||
docID uuid.UUID
|
||||
kbID uuid.UUID
|
||||
content string
|
||||
}
|
||||
var docs []docInfo
|
||||
for rows.Next() {
|
||||
var d docInfo
|
||||
if err := rows.Scan(&d.docID, &d.kbID, &d.content); err != nil {
|
||||
continue
|
||||
}
|
||||
docs = append(docs, d)
|
||||
}
|
||||
|
||||
for _, d := range docs {
|
||||
h.chunkAndEmbed(r.Context(), d.kbID, d.docID, d.content)
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"message": "重新索引完成",
|
||||
"documents": len(docs),
|
||||
})
|
||||
}
|
||||
|
||||
// ReembedChunks 为已有分片但缺失embedding的chunks补充向量化(管理端点)
|
||||
func (h *KnowledgeHandler) ReembedChunks(w http.ResponseWriter, r *http.Request) {
|
||||
if h.embedder == nil || !h.embedder.IsConfigured() {
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"message": "embedding服务未配置",
|
||||
"updated": 0,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.pool.Query(r.Context(),
|
||||
`SELECT id, content FROM knowledge_chunks WHERE embedding IS NULL AND content IS NOT NULL AND content != '' LIMIT 200`)
|
||||
if err != nil {
|
||||
internalErr(w, err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type chunkInfo struct {
|
||||
id uuid.UUID
|
||||
content string
|
||||
}
|
||||
var chunks []chunkInfo
|
||||
for rows.Next() {
|
||||
var c chunkInfo
|
||||
if err := rows.Scan(&c.id, &c.content); err != nil {
|
||||
continue
|
||||
}
|
||||
chunks = append(chunks, c)
|
||||
}
|
||||
|
||||
updated := 0
|
||||
for _, c := range chunks {
|
||||
emb, err := h.embedder.GetEmbedding(r.Context(), c.content)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("chunk_id", c.id.String()).Msg("re-embed failed")
|
||||
continue
|
||||
}
|
||||
vecStr := float32SliceToVectorStr(emb)
|
||||
_, err = h.pool.Exec(r.Context(),
|
||||
`UPDATE knowledge_chunks SET embedding = $1::vector WHERE id = $2`, vecStr, c.id)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("chunk_id", c.id.String()).Msg("update embedding failed")
|
||||
continue
|
||||
}
|
||||
updated++
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"message": "向量补充完成",
|
||||
"total": len(chunks),
|
||||
"updated": updated,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *KnowledgeHandler) DeleteDocument(w http.ResponseWriter, r *http.Request) {
|
||||
kbID, err := uuid.Parse(chi.URLParam(r, "id"))
|
||||
if err != nil {
|
||||
response.BadRequest(w, "无效的知识库ID")
|
||||
return
|
||||
}
|
||||
docID, err := uuid.Parse(chi.URLParam(r, "docId"))
|
||||
if err != nil {
|
||||
response.BadRequest(w, "无效的文档ID")
|
||||
return
|
||||
}
|
||||
|
||||
tag, err := h.pool.Exec(r.Context(), `DELETE FROM knowledge_documents WHERE id = $1 AND kb_id = $2`, docID, kbID)
|
||||
if err != nil {
|
||||
internalErr(w, err)
|
||||
return
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
response.NotFound(w, "文档不存在")
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = h.pool.Exec(r.Context(),
|
||||
`UPDATE knowledge_bases SET doc_count = GREATEST(doc_count - 1, 0), updated_at = NOW() WHERE id = $1`, kbID)
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
|
||||
}
|
||||
@@ -0,0 +1,828 @@
|
||||
package handler
|
||||
|
||||
// 平台级管理 handler — 仅 super_admin 可访问,跨机构操作,不受 org_id 限制。
|
||||
// 与 admin.go(机构级)的核心区别:
|
||||
// - admin.go 所有查询都加 WHERE org_id = $orgID 过滤
|
||||
// - platform.go 不加 org_id 过滤,可看/操作所有机构数据
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type PlatformHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPlatformHandler(pool *pgxpool.Pool) *PlatformHandler {
|
||||
return &PlatformHandler{pool: pool}
|
||||
}
|
||||
|
||||
// ==================== 平台总览 ====================
|
||||
|
||||
// Overview 平台级数据看板:跨所有机构的聚合统计
|
||||
func (h *PlatformHandler) Overview(w http.ResponseWriter, r *http.Request) {
|
||||
now := time.Now()
|
||||
todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||
|
||||
var totalOrgs, activeOrgs, totalUsers, activeUsers, totalApps, approvedApps, todayLogins, todayConvs int
|
||||
var monthlyTokens int64
|
||||
var monthlyCost float64
|
||||
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM organizations`).Scan(&totalOrgs)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM organizations WHERE is_active = true`).Scan(&activeOrgs)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users`).Scan(&totalUsers)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users WHERE status = 'active'`).Scan(&activeUsers)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM applications`).Scan(&totalApps)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM applications WHERE status = 'approved'`).Scan(&approvedApps)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users WHERE last_login_at >= $1`, todayStart).Scan(&todayLogins)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM app_usage_logs WHERE created_at >= $1`, todayStart).Scan(&todayConvs)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COALESCE(SUM(total_tokens),0) FROM app_usage_logs WHERE created_at >= $1`, monthStart).Scan(&monthlyTokens)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COALESCE(SUM(estimated_cost),0) FROM app_usage_logs WHERE created_at >= $1`, monthStart).Scan(&monthlyCost)
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"total_orgs": totalOrgs,
|
||||
"active_orgs": activeOrgs,
|
||||
"total_users": totalUsers,
|
||||
"active_users": activeUsers,
|
||||
"total_apps": totalApps,
|
||||
"approved_apps": approvedApps,
|
||||
"today_logins": todayLogins,
|
||||
"today_convs": todayConvs,
|
||||
"monthly_tokens": monthlyTokens,
|
||||
"monthly_cost": monthlyCost,
|
||||
})
|
||||
}
|
||||
|
||||
// OrgRanking 各机构活跃度排行(用户数 / 应用数 / 月度对话数)
|
||||
func (h *PlatformHandler) OrgRanking(w http.ResponseWriter, r *http.Request) {
|
||||
now := time.Now()
|
||||
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT o.id, o.name, COALESCE(o.short_name, ''),
|
||||
COUNT(DISTINCT u.id) AS users,
|
||||
COUNT(DISTINCT a.id) FILTER (WHERE a.status='approved') AS apps,
|
||||
COALESCE(SUM(usage.cnt), 0) AS conversations,
|
||||
COALESCE(SUM(usage.tk), 0) AS tokens
|
||||
FROM organizations o
|
||||
LEFT JOIN users u ON u.org_id = o.id
|
||||
LEFT JOIN applications a ON a.org_id = o.id
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT COUNT(*) AS cnt, COALESCE(SUM(total_tokens),0) AS tk
|
||||
FROM app_usage_logs l
|
||||
JOIN users uu ON l.user_id = uu.id
|
||||
WHERE uu.org_id = o.id AND l.created_at >= $1
|
||||
) usage ON true
|
||||
WHERE o.is_active = true
|
||||
GROUP BY o.id, o.name, o.short_name
|
||||
ORDER BY conversations DESC, users DESC
|
||||
LIMIT 50`, monthStart)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := []map[string]any{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, name, short string
|
||||
users, apps int
|
||||
conversations, toks int64
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &short, &users, &apps, &conversations, &toks); err != nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, map[string]any{
|
||||
"id": id, "name": name, "short_name": short,
|
||||
"users": users, "apps": apps,
|
||||
"conversations": conversations, "tokens": toks,
|
||||
})
|
||||
}
|
||||
response.JSON(w, http.StatusOK, items)
|
||||
}
|
||||
|
||||
// ==================== 机构管理 ====================
|
||||
|
||||
func (h *PlatformHandler) ListOrgs(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT o.id, o.name, o.slug, COALESCE(o.short_name,''), COALESCE(o.description,''),
|
||||
COALESCE(o.logo_url,''), o.sort_order, o.is_active, o.created_at,
|
||||
COALESCE((SELECT COUNT(*) FROM users WHERE org_id = o.id), 0) AS user_count,
|
||||
COALESCE((SELECT COUNT(*) FROM applications WHERE org_id = o.id), 0) AS app_count
|
||||
FROM organizations o ORDER BY o.sort_order, o.created_at`)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询机构失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := []map[string]any{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, name, slug, short, desc, logo string
|
||||
sortOrder int
|
||||
isActive bool
|
||||
createdAt time.Time
|
||||
userCount, appCount int
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &slug, &short, &desc, &logo, &sortOrder, &isActive, &createdAt, &userCount, &appCount); err != nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, map[string]any{
|
||||
"id": id, "name": name, "slug": slug, "short_name": short,
|
||||
"description": desc, "logo_url": logo, "sort_order": sortOrder,
|
||||
"is_active": isActive, "created_at": createdAt,
|
||||
"user_count": userCount, "app_count": appCount,
|
||||
})
|
||||
}
|
||||
response.JSON(w, http.StatusOK, items)
|
||||
}
|
||||
|
||||
func (h *PlatformHandler) CreateOrg(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
ShortName string `json:"short_name"`
|
||||
Description string `json:"description"`
|
||||
LogoURL string `json:"logo_url"`
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.Name == "" || req.Slug == "" {
|
||||
response.BadRequest(w, "机构名称和标识不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var id string
|
||||
err := h.pool.QueryRow(r.Context(), `
|
||||
INSERT INTO organizations (name, slug, short_name, description, logo_url, sort_order)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id::text`,
|
||||
req.Name, req.Slug, req.ShortName, req.Description, req.LogoURL, req.SortOrder,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
response.InternalError(w, "创建机构失败:标识可能已存在")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusCreated, map[string]any{"id": id})
|
||||
}
|
||||
|
||||
func (h *PlatformHandler) UpdateOrg(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
var req struct {
|
||||
Name *string `json:"name"`
|
||||
ShortName *string `json:"short_name"`
|
||||
Description *string `json:"description"`
|
||||
LogoURL *string `json:"logo_url"`
|
||||
SortOrder *int `json:"sort_order"`
|
||||
IsActive *bool `json:"is_active"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
|
||||
_, err := h.pool.Exec(r.Context(), `
|
||||
UPDATE organizations SET
|
||||
name = COALESCE($2, name),
|
||||
short_name = COALESCE($3, short_name),
|
||||
description = COALESCE($4, description),
|
||||
logo_url = COALESCE($5, logo_url),
|
||||
sort_order = COALESCE($6, sort_order),
|
||||
is_active = COALESCE($7, is_active),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1`, id, req.Name, req.ShortName, req.Description, req.LogoURL, req.SortOrder, req.IsActive)
|
||||
if err != nil {
|
||||
response.InternalError(w, "更新失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
|
||||
}
|
||||
|
||||
func (h *PlatformHandler) DeleteOrg(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
// 安全检查:机构下还有用户或应用时禁止物理删除,引导走停用
|
||||
var userCount, appCount int
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users WHERE org_id = $1`, id).Scan(&userCount)
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM applications WHERE org_id = $1`, id).Scan(&appCount)
|
||||
if userCount > 0 || appCount > 0 {
|
||||
response.BadRequest(w, "该机构尚有用户或应用,无法删除。请先停用或迁移数据")
|
||||
return
|
||||
}
|
||||
_, err := h.pool.Exec(r.Context(), `DELETE FROM organizations WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
response.InternalError(w, "删除失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
|
||||
}
|
||||
|
||||
// ==================== 全局用户管理 ====================
|
||||
|
||||
// ListAllUsers 全局用户列表(支持分页、按机构/角色/状态过滤)
|
||||
func (h *PlatformHandler) ListAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
page, _ := strconv.Atoi(q.Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize := 20
|
||||
offset := (page - 1) * pageSize
|
||||
search := q.Get("q")
|
||||
roleFilter := q.Get("role")
|
||||
statusFilter := q.Get("status")
|
||||
orgFilter := q.Get("org_id")
|
||||
|
||||
// 构造 WHERE 子句(list 与 count 共用)
|
||||
where := ` WHERE 1=1`
|
||||
args := []any{}
|
||||
argIdx := 1
|
||||
|
||||
if search != "" {
|
||||
where += ` AND (u.name ILIKE '%'||$` + strconv.Itoa(argIdx) + `||'%' OR u.email ILIKE '%'||$` + strconv.Itoa(argIdx) + `||'%')`
|
||||
args = append(args, search)
|
||||
argIdx++
|
||||
}
|
||||
if roleFilter != "" {
|
||||
where += ` AND u.role = $` + strconv.Itoa(argIdx)
|
||||
args = append(args, roleFilter)
|
||||
argIdx++
|
||||
}
|
||||
if statusFilter != "" {
|
||||
where += ` AND u.status = $` + strconv.Itoa(argIdx)
|
||||
args = append(args, statusFilter)
|
||||
argIdx++
|
||||
}
|
||||
if orgFilter != "" {
|
||||
where += ` AND u.org_id = $` + strconv.Itoa(argIdx)
|
||||
args = append(args, orgFilter)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
// 先查总数
|
||||
var total int
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users u`+where, args...).Scan(&total)
|
||||
|
||||
// 再分页查列表
|
||||
listQuery := `SELECT u.id, u.name, u.email, u.avatar_url, u.role, u.status, u.employee_id,
|
||||
u.last_login_at, u.login_count, u.created_at,
|
||||
COALESCE(o.id::text, ''), COALESCE(o.name, ''), COALESCE(o.short_name, '')
|
||||
FROM users u LEFT JOIN organizations o ON u.org_id = o.id` + where +
|
||||
` ORDER BY u.created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
|
||||
listArgs := append(args, pageSize, offset)
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), listQuery, listArgs...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询用户失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := []map[string]any{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, name, email, role, status string
|
||||
avatarURL, employeeID *string
|
||||
lastLoginAt *time.Time
|
||||
loginCount int
|
||||
createdAt time.Time
|
||||
orgID, orgName, orgShort string
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &email, &avatarURL, &role, &status, &employeeID, &lastLoginAt, &loginCount, &createdAt, &orgID, &orgName, &orgShort); err != nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, map[string]any{
|
||||
"id": id, "name": name, "email": email, "avatar_url": avatarURL,
|
||||
"role": role, "status": status, "employee_id": employeeID,
|
||||
"last_login_at": lastLoginAt, "login_count": loginCount, "created_at": createdAt,
|
||||
"org_id": orgID, "org_name": orgName, "org_short": orgShort,
|
||||
})
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// AssignUserOrg 把用户分配/迁移到指定机构
|
||||
func (h *PlatformHandler) AssignUserOrg(w http.ResponseWriter, r *http.Request) {
|
||||
userID := chi.URLParam(r, "id")
|
||||
var req struct {
|
||||
OrgID string `json:"org_id"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.OrgID == "" {
|
||||
response.BadRequest(w, "机构ID不能为空")
|
||||
return
|
||||
}
|
||||
_, err := h.pool.Exec(r.Context(), `UPDATE users SET org_id = $2 WHERE id = $1`, userID, req.OrgID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "迁移失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已迁移"})
|
||||
}
|
||||
|
||||
// UpdateUserRole 平台管理员可设置任意角色(包括 super_admin)
|
||||
func (h *PlatformHandler) UpdateUserRole(w http.ResponseWriter, r *http.Request) {
|
||||
userID := chi.URLParam(r, "id")
|
||||
var req struct {
|
||||
Role string `json:"role"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
valid := map[string]bool{"user": true, "creator": true, "admin": true, "super_admin": true}
|
||||
if !valid[req.Role] {
|
||||
response.BadRequest(w, "无效的角色")
|
||||
return
|
||||
}
|
||||
_, err := h.pool.Exec(r.Context(), `UPDATE users SET role = $2 WHERE id = $1`, userID, req.Role)
|
||||
if err != nil {
|
||||
response.InternalError(w, "更新角色失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
|
||||
}
|
||||
|
||||
// UpdateUserStatus 启用/禁用任意用户
|
||||
func (h *PlatformHandler) UpdateUserStatus(w http.ResponseWriter, r *http.Request) {
|
||||
userID := chi.URLParam(r, "id")
|
||||
var req struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.Status != "active" && req.Status != "disabled" {
|
||||
response.BadRequest(w, "无效的状态")
|
||||
return
|
||||
}
|
||||
_, err := h.pool.Exec(r.Context(), `UPDATE users SET status = $2 WHERE id = $1`, userID, req.Status)
|
||||
if err != nil {
|
||||
response.InternalError(w, "更新状态失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
|
||||
}
|
||||
|
||||
// ==================== 全局应用管理 ====================
|
||||
|
||||
// ListAllApps 全局应用列表(支持分页、按机构/状态过滤)
|
||||
func (h *PlatformHandler) ListAllApps(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
page, _ := strconv.Atoi(q.Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize := 20
|
||||
offset := (page - 1) * pageSize
|
||||
statusFilter := q.Get("status")
|
||||
orgFilter := q.Get("org_id")
|
||||
|
||||
where := ` WHERE 1=1`
|
||||
args := []any{}
|
||||
argIdx := 1
|
||||
if statusFilter != "" {
|
||||
where += ` AND a.status = $` + strconv.Itoa(argIdx)
|
||||
args = append(args, statusFilter)
|
||||
argIdx++
|
||||
}
|
||||
if orgFilter != "" {
|
||||
where += ` AND a.org_id = $` + strconv.Itoa(argIdx)
|
||||
args = append(args, orgFilter)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
var total int
|
||||
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM applications a`+where, args...).Scan(&total)
|
||||
|
||||
listQuery := `SELECT a.id, a.name, a.slug, a.description, a.icon_url,
|
||||
a.dify_app_type, a.status, a.visibility, a.usage_count, a.is_featured, a.created_at,
|
||||
COALESCE(c.name, ''), COALESCE(u.name, ''),
|
||||
COALESCE(o.id::text, ''), COALESCE(o.name, ''), COALESCE(o.short_name, '')
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
LEFT JOIN users u ON a.creator_id = u.id
|
||||
LEFT JOIN organizations o ON a.org_id = o.id` + where +
|
||||
` ORDER BY a.usage_count DESC, a.created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
|
||||
listArgs := append(args, pageSize, offset)
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), listQuery, listArgs...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询应用失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := []map[string]any{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, name, slug, status, visibility string
|
||||
desc, iconURL *string
|
||||
appType *string
|
||||
usageCount int64
|
||||
isFeatured bool
|
||||
createdAt time.Time
|
||||
catName, creatorName string
|
||||
orgID, orgName, orgShort string
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &slug, &desc, &iconURL, &appType, &status, &visibility,
|
||||
&usageCount, &isFeatured, &createdAt, &catName, &creatorName, &orgID, &orgName, &orgShort); err != nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, map[string]any{
|
||||
"id": id, "name": name, "slug": slug, "description": desc, "icon_url": iconURL,
|
||||
"dify_app_type": appType, "status": status, "visibility": visibility,
|
||||
"usage_count": usageCount, "is_featured": isFeatured, "created_at": createdAt,
|
||||
"category_name": catName, "creator_name": creatorName,
|
||||
"org_id": orgID, "org_name": orgName, "org_short": orgShort,
|
||||
})
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// SetFeatured 平台级精选/取消精选
|
||||
func (h *PlatformHandler) SetFeatured(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
var req struct {
|
||||
IsFeatured bool `json:"is_featured"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
_, err := h.pool.Exec(r.Context(), `UPDATE applications SET is_featured = $2 WHERE id = $1`, appID, req.IsFeatured)
|
||||
if err != nil {
|
||||
response.InternalError(w, "操作失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
|
||||
}
|
||||
|
||||
// ForceDelist 平台强制下架(不受机构限制)
|
||||
func (h *PlatformHandler) ForceDelist(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
_, err := h.pool.Exec(r.Context(), `UPDATE applications SET status = 'archived' WHERE id = $1`, appID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "下架失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已强制下架"})
|
||||
}
|
||||
|
||||
// ==================== 全局审计日志 ====================
|
||||
|
||||
// ListAllAuditLogs 全局审计日志(支持分页、按机构/操作类型过滤)
|
||||
func (h *PlatformHandler) ListAllAuditLogs(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
page, _ := strconv.Atoi(q.Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize := 20
|
||||
offset := (page - 1) * pageSize
|
||||
orgFilter := q.Get("org_id")
|
||||
actionFilter := q.Get("action")
|
||||
|
||||
where := ` WHERE 1=1`
|
||||
args := []any{}
|
||||
argIdx := 1
|
||||
if orgFilter != "" {
|
||||
where += ` AND u.org_id = $` + strconv.Itoa(argIdx)
|
||||
args = append(args, orgFilter)
|
||||
argIdx++
|
||||
}
|
||||
if actionFilter != "" {
|
||||
where += ` AND al.action ILIKE '%'||$` + strconv.Itoa(argIdx) + `||'%'`
|
||||
args = append(args, actionFilter)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
var total int
|
||||
h.pool.QueryRow(r.Context(),
|
||||
`SELECT COUNT(*) FROM audit_logs al LEFT JOIN users u ON al.user_id = u.id`+where,
|
||||
args...).Scan(&total)
|
||||
|
||||
listQuery := `SELECT al.id, al.action, al.resource_type, al.resource_id::text, al.details,
|
||||
al.ip_address::text, al.created_at,
|
||||
COALESCE(u.name, ''), COALESCE(u.email, ''),
|
||||
COALESCE(o.name, ''), COALESCE(o.short_name, '')
|
||||
FROM audit_logs al
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
LEFT JOIN organizations o ON u.org_id = o.id` + where +
|
||||
` ORDER BY al.created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
|
||||
listArgs := append(args, pageSize, offset)
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), listQuery, listArgs...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询审计日志失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := []map[string]any{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, action, resType string
|
||||
resID, ipAddr *string
|
||||
details json.RawMessage
|
||||
createdAt time.Time
|
||||
userName, userEmail, orgName, orgShort string
|
||||
)
|
||||
if err := rows.Scan(&id, &action, &resType, &resID, &details, &ipAddr, &createdAt,
|
||||
&userName, &userEmail, &orgName, &orgShort); err != nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, map[string]any{
|
||||
"id": id, "action": action, "resource_type": resType, "resource_id": resID,
|
||||
"details": details, "ip_address": ipAddr, "created_at": createdAt,
|
||||
"user_name": userName, "user_email": userEmail,
|
||||
"org_name": orgName, "org_short": orgShort,
|
||||
})
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 模型提供商管理 ====================
|
||||
|
||||
func (h *PlatformHandler) ListProviders(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT id, name, base_url, models, is_active, priority, created_at, updated_at
|
||||
FROM model_providers ORDER BY priority DESC, created_at`)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := []map[string]any{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, name, baseURL string
|
||||
models json.RawMessage
|
||||
isActive bool
|
||||
priority int
|
||||
createdAt, updatedAt time.Time
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &baseURL, &models, &isActive, &priority, &createdAt, &updatedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, map[string]any{
|
||||
"id": id, "name": name, "base_url": baseURL, "models": models,
|
||||
"is_active": isActive, "priority": priority,
|
||||
"created_at": createdAt, "updated_at": updatedAt,
|
||||
})
|
||||
}
|
||||
response.JSON(w, http.StatusOK, items)
|
||||
}
|
||||
|
||||
func (h *PlatformHandler) CreateProvider(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIKey string `json:"api_key"`
|
||||
Models json.RawMessage `json:"models"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Priority int `json:"priority"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.Name == "" || req.BaseURL == "" || req.APIKey == "" {
|
||||
response.BadRequest(w, "名称、URL、API Key 不能为空")
|
||||
return
|
||||
}
|
||||
if len(req.Models) == 0 {
|
||||
req.Models = []byte(`[]`)
|
||||
}
|
||||
|
||||
var id string
|
||||
err := h.pool.QueryRow(r.Context(), `
|
||||
INSERT INTO model_providers (name, base_url, api_key_encrypted, models, is_active, priority)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id::text`,
|
||||
req.Name, req.BaseURL, req.APIKey, req.Models, req.IsActive, req.Priority,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
response.InternalError(w, "创建失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusCreated, map[string]any{"id": id})
|
||||
}
|
||||
|
||||
func (h *PlatformHandler) UpdateProvider(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
var req struct {
|
||||
Name *string `json:"name"`
|
||||
BaseURL *string `json:"base_url"`
|
||||
APIKey *string `json:"api_key"`
|
||||
Models json.RawMessage `json:"models"`
|
||||
IsActive *bool `json:"is_active"`
|
||||
Priority *int `json:"priority"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
|
||||
// 仅当 api_key 非空才更新(避免误清空)
|
||||
var apiKeyArg any = nil
|
||||
if req.APIKey != nil && *req.APIKey != "" {
|
||||
apiKeyArg = *req.APIKey
|
||||
}
|
||||
var modelsArg any = nil
|
||||
if len(req.Models) > 0 {
|
||||
modelsArg = req.Models
|
||||
}
|
||||
|
||||
_, err := h.pool.Exec(r.Context(), `
|
||||
UPDATE model_providers SET
|
||||
name = COALESCE($2, name),
|
||||
base_url = COALESCE($3, base_url),
|
||||
api_key_encrypted = COALESCE($4, api_key_encrypted),
|
||||
models = COALESCE($5::jsonb, models),
|
||||
is_active = COALESCE($6, is_active),
|
||||
priority = COALESCE($7, priority),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1`,
|
||||
id, req.Name, req.BaseURL, apiKeyArg, modelsArg, req.IsActive, req.Priority)
|
||||
if err != nil {
|
||||
response.InternalError(w, "更新失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
|
||||
}
|
||||
|
||||
func (h *PlatformHandler) DeleteProvider(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
_, err := h.pool.Exec(r.Context(), `DELETE FROM model_providers WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
response.InternalError(w, "删除失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
|
||||
}
|
||||
|
||||
// ==================== 全局配额管理 ====================
|
||||
|
||||
func (h *PlatformHandler) ListQuotas(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT q.id, q.target_type, q.target_id::text, q.model_name,
|
||||
q.daily_token_limit, q.monthly_token_limit, q.daily_request_limit,
|
||||
q.is_active, q.created_at,
|
||||
COALESCE(target_name.name, '')
|
||||
FROM model_quotas q
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT CASE q.target_type
|
||||
WHEN 'global' THEN '全局'
|
||||
WHEN 'department' THEN (SELECT name FROM departments WHERE id = q.target_id)
|
||||
WHEN 'user' THEN (SELECT name FROM users WHERE id = q.target_id)
|
||||
END AS name
|
||||
) target_name ON true
|
||||
ORDER BY q.target_type, q.created_at DESC LIMIT 200`)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询配额失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := []map[string]any{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, targetType string
|
||||
targetID *string
|
||||
modelName *string
|
||||
dailyTokens, monthlyTokens *int64
|
||||
dailyReqs *int
|
||||
isActive bool
|
||||
createdAt time.Time
|
||||
targetName string
|
||||
)
|
||||
if err := rows.Scan(&id, &targetType, &targetID, &modelName, &dailyTokens, &monthlyTokens,
|
||||
&dailyReqs, &isActive, &createdAt, &targetName); err != nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, map[string]any{
|
||||
"id": id, "target_type": targetType, "target_id": targetID,
|
||||
"target_name": targetName,
|
||||
"model_name": modelName,
|
||||
"daily_token_limit": dailyTokens,
|
||||
"monthly_token_limit": monthlyTokens,
|
||||
"daily_request_limit": dailyReqs,
|
||||
"is_active": isActive,
|
||||
"created_at": createdAt,
|
||||
})
|
||||
}
|
||||
response.JSON(w, http.StatusOK, items)
|
||||
}
|
||||
|
||||
func (h *PlatformHandler) UpsertQuota(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
ID *string `json:"id"`
|
||||
TargetType string `json:"target_type"`
|
||||
TargetID *string `json:"target_id"`
|
||||
ModelName *string `json:"model_name"`
|
||||
DailyTokenLimit *int64 `json:"daily_token_limit"`
|
||||
MonthlyTokenLimit *int64 `json:"monthly_token_limit"`
|
||||
DailyRequestLimit *int `json:"daily_request_limit"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
valid := map[string]bool{"global": true, "department": true, "user": true}
|
||||
if !valid[req.TargetType] {
|
||||
response.BadRequest(w, "无效的目标类型")
|
||||
return
|
||||
}
|
||||
if req.TargetType != "global" && (req.TargetID == nil || *req.TargetID == "") {
|
||||
response.BadRequest(w, "部门/用户配额必须指定 target_id")
|
||||
return
|
||||
}
|
||||
|
||||
if req.ID != nil && *req.ID != "" {
|
||||
_, err := h.pool.Exec(r.Context(), `
|
||||
UPDATE model_quotas SET
|
||||
target_type = $2, target_id = NULLIF($3,'')::uuid, model_name = $4,
|
||||
daily_token_limit = $5, monthly_token_limit = $6, daily_request_limit = $7,
|
||||
is_active = $8, updated_at = NOW()
|
||||
WHERE id = $1`,
|
||||
*req.ID, req.TargetType, nullableString(req.TargetID), req.ModelName,
|
||||
req.DailyTokenLimit, req.MonthlyTokenLimit, req.DailyRequestLimit, req.IsActive)
|
||||
if err != nil {
|
||||
response.InternalError(w, "更新配额失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"id": *req.ID})
|
||||
return
|
||||
}
|
||||
|
||||
var id string
|
||||
err := h.pool.QueryRow(r.Context(), `
|
||||
INSERT INTO model_quotas (target_type, target_id, model_name, daily_token_limit, monthly_token_limit, daily_request_limit, is_active)
|
||||
VALUES ($1, NULLIF($2,'')::uuid, $3, $4, $5, $6, $7)
|
||||
RETURNING id::text`,
|
||||
req.TargetType, nullableString(req.TargetID), req.ModelName,
|
||||
req.DailyTokenLimit, req.MonthlyTokenLimit, req.DailyRequestLimit, req.IsActive,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
response.InternalError(w, "创建配额失败")
|
||||
return
|
||||
}
|
||||
response.InternalError(w, "创建配额失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusCreated, map[string]string{"id": id})
|
||||
}
|
||||
|
||||
func (h *PlatformHandler) DeleteQuota(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
_, err := h.pool.Exec(r.Context(), `DELETE FROM model_quotas WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
response.InternalError(w, "删除失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
|
||||
}
|
||||
|
||||
func nullableString(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type PPTHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
rdb *redis.Client
|
||||
workerURL string
|
||||
}
|
||||
|
||||
func NewPPTHandler(pool *pgxpool.Pool, rdb *redis.Client, workerURL string) *PPTHandler {
|
||||
return &PPTHandler{pool: pool, rdb: rdb, workerURL: workerURL}
|
||||
}
|
||||
|
||||
// ==================== 请求/响应结构 ====================
|
||||
|
||||
type createPPTRequest struct {
|
||||
Title string `json:"title"`
|
||||
SourceType string `json:"source_type"` // text / url
|
||||
SourceContent string `json:"source_content"` // 文本内容或 URL
|
||||
Config map[string]any `json:"config"`
|
||||
}
|
||||
|
||||
type pptTaskResponse struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
Progress int `json:"progress"`
|
||||
StatusMessage *string `json:"status_message,omitempty"`
|
||||
ErrorMessage *string `json:"error_message,omitempty"`
|
||||
OutputFile *string `json:"output_file,omitempty"`
|
||||
PageCount *int `json:"page_count,omitempty"`
|
||||
Title string `json:"title"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// ==================== 接口实现 ====================
|
||||
|
||||
// CreateTask 创建 PPT 生成任务(文本/URL 输入)
|
||||
func (h *PPTHandler) CreateTask(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req createPPTRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Title == "" {
|
||||
response.BadRequest(w, "标题不能为空")
|
||||
return
|
||||
}
|
||||
if req.SourceType == "" {
|
||||
req.SourceType = "text"
|
||||
}
|
||||
if req.SourceContent == "" {
|
||||
response.BadRequest(w, "请提供源内容")
|
||||
return
|
||||
}
|
||||
|
||||
taskID := uuid.New().String()
|
||||
configJSON, _ := json.Marshal(req.Config)
|
||||
|
||||
// 写入数据库
|
||||
_, err := h.pool.Exec(r.Context(),
|
||||
`INSERT INTO ppt_tasks (id, user_id, title, source_type, source_content, config)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
taskID, userID, req.Title, req.SourceType, req.SourceContent, configJSON,
|
||||
)
|
||||
if err != nil {
|
||||
response.InternalError(w, "创建任务失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 推送到 Redis 队列
|
||||
taskMsg, _ := json.Marshal(map[string]string{"task_id": taskID})
|
||||
h.rdb.LPush(r.Context(), "ppt:tasks", taskMsg)
|
||||
|
||||
response.JSON(w, http.StatusCreated, map[string]string{
|
||||
"task_id": taskID,
|
||||
"status": "pending",
|
||||
})
|
||||
}
|
||||
|
||||
// CreateTaskWithFile 创建带文件上传的 PPT 生成任务
|
||||
func (h *PPTHandler) CreateTaskWithFile(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
if err := r.ParseMultipartForm(50 << 20); err != nil { // 50MB 限制
|
||||
response.BadRequest(w, "文件过大或请求格式错误")
|
||||
return
|
||||
}
|
||||
|
||||
title := r.FormValue("title")
|
||||
if title == "" {
|
||||
response.BadRequest(w, "标题不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
configStr := r.FormValue("config")
|
||||
var taskConfig map[string]any
|
||||
if configStr != "" {
|
||||
json.Unmarshal([]byte(configStr), &taskConfig)
|
||||
}
|
||||
if taskConfig == nil {
|
||||
taskConfig = map[string]any{}
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
response.BadRequest(w, "请上传文件")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 将文件转发到 PPT Worker
|
||||
taskID := uuid.New().String()
|
||||
configJSON, _ := json.Marshal(taskConfig)
|
||||
|
||||
// 转发文件到 Worker 服务
|
||||
err = h.forwardFileToWorker(r.Context(), taskID, userID.String(), title, string(configJSON), file, header)
|
||||
if err != nil {
|
||||
response.InternalError(w, "提交任务失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusCreated, map[string]string{
|
||||
"task_id": taskID,
|
||||
"status": "pending",
|
||||
})
|
||||
}
|
||||
|
||||
// GetTaskStatus 查询任务状态
|
||||
func (h *PPTHandler) GetTaskStatus(w http.ResponseWriter, r *http.Request) {
|
||||
taskID := chi.URLParam(r, "taskId")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
// 先从 Redis 快速查询
|
||||
key := "ppt:status:" + taskID
|
||||
cached, err := h.rdb.HGetAll(r.Context(), key).Result()
|
||||
if err == nil && len(cached) > 0 {
|
||||
progress := 0
|
||||
fmt.Sscanf(cached["progress"], "%d", &progress)
|
||||
msg := cached["message"]
|
||||
response.JSON(w, http.StatusOK, pptTaskResponse{
|
||||
TaskID: taskID,
|
||||
Status: cached["status"],
|
||||
Progress: progress,
|
||||
StatusMessage: &msg,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 回退到数据库
|
||||
var task pptTaskResponse
|
||||
var statusMsg, errMsg, outputFile *string
|
||||
var pageCount *int
|
||||
var createdAt time.Time
|
||||
|
||||
err = h.pool.QueryRow(r.Context(),
|
||||
`SELECT id, status, progress, status_message, error_message, output_file, page_count, title, created_at
|
||||
FROM ppt_tasks WHERE id = $1 AND user_id = $2`, taskID, userID,
|
||||
).Scan(&task.TaskID, &task.Status, &task.Progress, &statusMsg, &errMsg, &outputFile, &pageCount, &task.Title, &createdAt)
|
||||
if err != nil {
|
||||
response.NotFound(w, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
task.StatusMessage = statusMsg
|
||||
task.ErrorMessage = errMsg
|
||||
task.OutputFile = outputFile
|
||||
task.PageCount = pageCount
|
||||
task.CreatedAt = createdAt.Format(time.RFC3339)
|
||||
|
||||
response.JSON(w, http.StatusOK, task)
|
||||
}
|
||||
|
||||
// ListTasks 列出用户的 PPT 任务
|
||||
func (h *PPTHandler) ListTasks(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
rows, err := h.pool.Query(r.Context(),
|
||||
`SELECT id, status, progress, status_message, error_message, output_file, page_count, title, created_at
|
||||
FROM ppt_tasks WHERE user_id = $1 ORDER BY created_at DESC LIMIT 50`, userID,
|
||||
)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tasks []pptTaskResponse
|
||||
for rows.Next() {
|
||||
var t pptTaskResponse
|
||||
var statusMsg, errMsg, outputFile *string
|
||||
var pageCount *int
|
||||
var createdAt time.Time
|
||||
|
||||
if err := rows.Scan(&t.TaskID, &t.Status, &t.Progress, &statusMsg, &errMsg, &outputFile, &pageCount, &t.Title, &createdAt); err != nil {
|
||||
continue
|
||||
}
|
||||
t.StatusMessage = statusMsg
|
||||
t.ErrorMessage = errMsg
|
||||
t.OutputFile = outputFile
|
||||
t.PageCount = pageCount
|
||||
t.CreatedAt = createdAt.Format(time.RFC3339)
|
||||
tasks = append(tasks, t)
|
||||
}
|
||||
|
||||
if tasks == nil {
|
||||
tasks = []pptTaskResponse{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, tasks)
|
||||
}
|
||||
|
||||
// DownloadTask 下载生成的 PPTX 文件
|
||||
func (h *PPTHandler) DownloadTask(w http.ResponseWriter, r *http.Request) {
|
||||
taskID := chi.URLParam(r, "taskId")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var status, title string
|
||||
var outputFile *string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT status, title, output_file FROM ppt_tasks WHERE id = $1 AND user_id = $2`,
|
||||
taskID, userID,
|
||||
).Scan(&status, &title, &outputFile)
|
||||
if err != nil {
|
||||
response.NotFound(w, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if status != "completed" {
|
||||
response.BadRequest(w, "任务未完成")
|
||||
return
|
||||
}
|
||||
|
||||
// 代理下载请求到 Worker 服务
|
||||
workerResp, err := http.Get(h.workerURL + "/api/tasks/" + taskID + "/download")
|
||||
if err != nil {
|
||||
response.InternalError(w, "下载服务不可用")
|
||||
return
|
||||
}
|
||||
defer workerResp.Body.Close()
|
||||
|
||||
if workerResp.StatusCode != http.StatusOK {
|
||||
response.InternalError(w, "文件下载失败")
|
||||
return
|
||||
}
|
||||
|
||||
filename := title + ".pptx"
|
||||
w.Header().Set("Content-Type", "application/vnd.openxmlformats-officedocument.presentationml.presentation")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filename))
|
||||
io.Copy(w, workerResp.Body)
|
||||
}
|
||||
|
||||
// ==================== 内部方法 ====================
|
||||
|
||||
func (h *PPTHandler) forwardFileToWorker(ctx context.Context, taskID, userID, title, configJSON string, file multipart.File, header *multipart.FileHeader) error {
|
||||
// 构造 multipart 请求转发到 Worker
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
|
||||
writer.WriteField("task_id", taskID)
|
||||
writer.WriteField("user_id", userID)
|
||||
writer.WriteField("title", title)
|
||||
writer.WriteField("config_json", configJSON)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filepath.Base(header.Filename))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
return err
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", h.workerURL+"/api/tasks/upload", &buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("worker 返回错误 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,350 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type StoreHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewStoreHandler(pool *pgxpool.Pool) *StoreHandler {
|
||||
return &StoreHandler{pool: pool}
|
||||
}
|
||||
|
||||
func (h *StoreHandler) ListCategories(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := r.URL.Query().Get("org_id")
|
||||
query := `SELECT c.id, c.name, c.slug, c.icon, c.description, c.sort_order,
|
||||
COALESCE((SELECT COUNT(*) FROM applications a WHERE a.category_id = c.id AND a.status = 'approved'), 0) AS app_count
|
||||
FROM categories c WHERE c.status = 'active'`
|
||||
var args []any
|
||||
if orgID != "" {
|
||||
query += ` AND (c.org_id = $1 OR c.org_id IS NULL)`
|
||||
args = append(args, orgID)
|
||||
}
|
||||
query += ` ORDER BY c.sort_order ASC`
|
||||
rows, err := h.pool.Query(r.Context(), query, args...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询分类失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var cats []map[string]any
|
||||
for rows.Next() {
|
||||
var id, name, slug string
|
||||
var icon, desc *string
|
||||
var sortOrder int
|
||||
var appCount int
|
||||
if err := rows.Scan(&id, &name, &slug, &icon, &desc, &sortOrder, &appCount); err != nil {
|
||||
continue
|
||||
}
|
||||
cats = append(cats, map[string]any{
|
||||
"id": id, "name": name, "slug": slug,
|
||||
"icon": icon, "description": desc, "sort_order": sortOrder,
|
||||
"app_count": appCount,
|
||||
})
|
||||
}
|
||||
response.JSON(w, http.StatusOK, cats)
|
||||
}
|
||||
|
||||
func (h *StoreHandler) ListApps(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
page, _ := strconv.Atoi(q.Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(q.Get("page_size"))
|
||||
if pageSize < 1 || pageSize > 50 {
|
||||
pageSize = 20
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
search := q.Get("q")
|
||||
category := q.Get("category")
|
||||
sort := q.Get("sort")
|
||||
orgFilter := q.Get("org_id")
|
||||
if sort == "" {
|
||||
sort = "popular"
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
|
||||
c.name as category_name, c.slug as category_slug,
|
||||
u.name as creator_name,
|
||||
a.usage_count, a.favorite_count, a.avg_rating, a.rating_count,
|
||||
a.dify_app_type, a.welcome_message, a.published_at::text
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
LEFT JOIN users u ON a.creator_id = u.id
|
||||
WHERE a.status = 'approved' AND a.visibility = 'public'`
|
||||
|
||||
args := []any{}
|
||||
argIdx := 1
|
||||
|
||||
// 按机构过滤:显示指定机构的应用 + 无机构归属的全局应用
|
||||
if orgFilter != "" {
|
||||
query += ` AND (a.org_id = $` + strconv.Itoa(argIdx) + ` OR a.org_id IS NULL)`
|
||||
args = append(args, orgFilter)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if search != "" {
|
||||
query += ` AND to_tsvector('simple', a.name || ' ' || COALESCE(a.description, ''))
|
||||
@@ plainto_tsquery('simple', $` + strconv.Itoa(argIdx) + `)`
|
||||
args = append(args, search)
|
||||
argIdx++
|
||||
}
|
||||
if category != "" {
|
||||
query += ` AND c.slug = $` + strconv.Itoa(argIdx)
|
||||
args = append(args, category)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
switch sort {
|
||||
case "rating":
|
||||
query += ` ORDER BY a.avg_rating DESC, a.usage_count DESC`
|
||||
case "latest":
|
||||
query += ` ORDER BY a.published_at DESC NULLS LAST`
|
||||
default:
|
||||
query += ` ORDER BY a.usage_count DESC`
|
||||
}
|
||||
|
||||
query += ` LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), query, args...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询应用失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var apps []map[string]any
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, name, slug string
|
||||
desc, iconURL *string
|
||||
catName, catSlug *string
|
||||
creatorName *string
|
||||
usageCount int64
|
||||
favCount, ratingCt int
|
||||
avgRating float32
|
||||
difyType, welcome *string
|
||||
publishedAt *string
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &slug, &desc, &iconURL,
|
||||
&catName, &catSlug, &creatorName,
|
||||
&usageCount, &favCount, &avgRating, &ratingCt,
|
||||
&difyType, &welcome, &publishedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
apps = append(apps, map[string]any{
|
||||
"id": id, "name": name, "slug": slug,
|
||||
"description": desc, "icon_url": iconURL,
|
||||
"category_name": catName, "category_slug": catSlug,
|
||||
"creator_name": creatorName,
|
||||
"usage_count": usageCount, "favorite_count": favCount,
|
||||
"avg_rating": avgRating, "rating_count": ratingCt,
|
||||
"dify_app_type": difyType, "welcome_message": welcome,
|
||||
"published_at": publishedAt,
|
||||
})
|
||||
}
|
||||
|
||||
if apps == nil {
|
||||
apps = []map[string]any{}
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"items": apps,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *StoreHandler) GetApp(w http.ResponseWriter, r *http.Request) {
|
||||
slug := chi.URLParam(r, "slug")
|
||||
|
||||
var (
|
||||
id, name, appSlug string
|
||||
desc, longDesc *string
|
||||
iconURL *string
|
||||
catName, catSlug *string
|
||||
creatorName *string
|
||||
usageCount int64
|
||||
favCount, ratingCt int
|
||||
avgRating float32
|
||||
difyType, welcome *string
|
||||
suggestedPrompts *string
|
||||
appConfig *string
|
||||
publishedAt *string
|
||||
version string
|
||||
)
|
||||
|
||||
err := h.pool.QueryRow(r.Context(), `
|
||||
SELECT a.id, a.name, a.slug, a.description, a.long_description, a.icon_url,
|
||||
c.name, c.slug, u.name,
|
||||
a.usage_count, a.favorite_count, a.avg_rating, a.rating_count,
|
||||
a.dify_app_type, a.welcome_message, a.suggested_prompts::text,
|
||||
a.app_config::text,
|
||||
a.published_at::text, a.version
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
LEFT JOIN users u ON a.creator_id = u.id
|
||||
WHERE a.slug = $1 AND a.status = 'approved'`, slug,
|
||||
).Scan(&id, &name, &appSlug, &desc, &longDesc, &iconURL,
|
||||
&catName, &catSlug, &creatorName,
|
||||
&usageCount, &favCount, &avgRating, &ratingCt,
|
||||
&difyType, &welcome, &suggestedPrompts,
|
||||
&appConfig,
|
||||
&publishedAt, &version)
|
||||
|
||||
if err != nil {
|
||||
response.NotFound(w, "应用不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user favorited this app
|
||||
isFavorited := false
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
if userID.String() != "00000000-0000-0000-0000-000000000000" {
|
||||
_ = h.pool.QueryRow(r.Context(),
|
||||
`SELECT EXISTS(SELECT 1 FROM app_favorites WHERE user_id = $1 AND app_id = $2)`,
|
||||
userID, id).Scan(&isFavorited)
|
||||
}
|
||||
|
||||
// Parse app_config as JSON if present
|
||||
var configData any
|
||||
if appConfig != nil {
|
||||
_ = json.Unmarshal([]byte(*appConfig), &configData)
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"id": id, "name": name, "slug": appSlug,
|
||||
"description": desc, "long_description": longDesc, "icon_url": iconURL,
|
||||
"category_name": catName, "category_slug": catSlug,
|
||||
"creator_name": creatorName,
|
||||
"usage_count": usageCount, "favorite_count": favCount,
|
||||
"avg_rating": avgRating, "rating_count": ratingCt,
|
||||
"dify_app_type": difyType, "welcome_message": welcome,
|
||||
"suggested_prompts": suggestedPrompts,
|
||||
"app_config": configData,
|
||||
"published_at": publishedAt, "version": version,
|
||||
"is_favorited": isFavorited,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *StoreHandler) Featured(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := r.URL.Query().Get("org_id")
|
||||
query := `
|
||||
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
|
||||
c.name as category_name, c.slug as category_slug,
|
||||
a.usage_count, a.avg_rating, a.rating_count, a.dify_app_type
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE a.is_featured = true AND a.status = 'approved' AND a.visibility = 'public'`
|
||||
var args []any
|
||||
if orgID != "" {
|
||||
query += ` AND (a.org_id = $1 OR a.org_id IS NULL)`
|
||||
args = append(args, orgID)
|
||||
}
|
||||
query += ` ORDER BY a.usage_count DESC LIMIT 4`
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), query, args...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询精选应用失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
apps := scanAppList(rows)
|
||||
response.JSON(w, http.StatusOK, apps)
|
||||
}
|
||||
|
||||
func (h *StoreHandler) Rankings(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := r.URL.Query().Get("org_id")
|
||||
query := `
|
||||
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
|
||||
c.name as category_name, c.slug as category_slug,
|
||||
a.usage_count, a.avg_rating, a.rating_count, a.dify_app_type
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE a.status = 'approved' AND a.visibility = 'public'`
|
||||
var args []any
|
||||
if orgID != "" {
|
||||
query += ` AND (a.org_id = $1 OR a.org_id IS NULL)`
|
||||
args = append(args, orgID)
|
||||
}
|
||||
query += ` ORDER BY a.usage_count DESC LIMIT 50`
|
||||
|
||||
rows, err := h.pool.Query(r.Context(), query, args...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询排行榜失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
apps := scanAppList(rows)
|
||||
response.JSON(w, http.StatusOK, apps)
|
||||
}
|
||||
|
||||
func (h *StoreHandler) Recent(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
rows, err := h.pool.Query(r.Context(), `
|
||||
SELECT DISTINCT ON (a.id) a.id, a.name, a.slug, a.description, a.icon_url,
|
||||
c.name as category_name, c.slug as category_slug,
|
||||
a.usage_count, a.avg_rating, a.rating_count, a.dify_app_type
|
||||
FROM app_usage_logs l
|
||||
JOIN applications a ON l.app_id = a.id
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE l.user_id = $1
|
||||
ORDER BY a.id, l.created_at DESC
|
||||
LIMIT 10`, userID)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询最近使用失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
apps := scanAppList(rows)
|
||||
response.JSON(w, http.StatusOK, apps)
|
||||
}
|
||||
|
||||
func scanAppList(rows interface {
|
||||
Next() bool
|
||||
Scan(dest ...any) error
|
||||
}) []map[string]any {
|
||||
var apps []map[string]any
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, name, slug string
|
||||
desc, iconURL *string
|
||||
catName, catSlug *string
|
||||
usageCount int64
|
||||
avgRating float32
|
||||
ratingCt int
|
||||
difyType *string
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &slug, &desc, &iconURL,
|
||||
&catName, &catSlug, &usageCount, &avgRating, &ratingCt, &difyType); err != nil {
|
||||
continue
|
||||
}
|
||||
apps = append(apps, map[string]any{
|
||||
"id": id, "name": name, "slug": slug,
|
||||
"description": desc, "icon_url": iconURL,
|
||||
"category_name": catName, "category_slug": catSlug,
|
||||
"usage_count": usageCount, "avg_rating": avgRating, "rating_count": ratingCt,
|
||||
"dify_app_type": difyType,
|
||||
})
|
||||
}
|
||||
if apps == nil {
|
||||
apps = []map[string]any{}
|
||||
}
|
||||
return apps
|
||||
}
|
||||
Reference in New Issue
Block a user