Files
2026-06-15 23:48:37 +08:00

829 lines
28 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
// 平台级管理 handler — 仅 super_admin 可访问,跨机构操作,不受 org_id 限制。
// 与 admin.go(机构级)的核心区别:
// - admin.go 所有查询都加 WHERE org_id = $orgID 过滤
// - platform.go 不加 org_id 过滤,可看/操作所有机构数据
import (
"encoding/json"
"errors"
"net/http"
"strconv"
"time"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// PlatformHandler groups the platform-level admin HTTP handlers defined in
// this file. Its queries operate across all organizations (no org_id filter
// is applied unless the caller passes one explicitly).
// NOTE(review): the file header says these routes are for super_admin only;
// the access check itself is not visible in this file — confirm it is
// enforced by the router/middleware.
type PlatformHandler struct {
// pool is the PostgreSQL connection pool every handler method queries.
pool *pgxpool.Pool
}
// NewPlatformHandler returns a PlatformHandler whose queries run on pool.
func NewPlatformHandler(pool *pgxpool.Pool) *PlatformHandler {
	handler := new(PlatformHandler)
	handler.pool = pool
	return handler
}
// ==================== 平台总览 ====================
// Overview responds with the platform-wide dashboard: aggregate counts and
// usage totals computed across every organization (no org_id filter).
//
// "Today" and "this month" begin at local midnight and at the first day of
// the current month, both in the server's time zone (time.Now().Location()).
// If any of the underlying queries fails the handler answers 500 rather than
// silently reporting the affected figure as zero.
func (h *PlatformHandler) Overview(w http.ResponseWriter, r *http.Request) {
	now := time.Now()
	todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
	monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())

	var totalOrgs, activeOrgs, totalUsers, activeUsers, totalApps, approvedApps, todayLogins, todayConvs int
	var monthlyTokens int64
	var monthlyCost float64

	// Each entry is one single-value query together with its scan target.
	stats := []struct {
		query string
		args  []any
		dest  any
	}{
		{`SELECT COUNT(*) FROM organizations`, nil, &totalOrgs},
		{`SELECT COUNT(*) FROM organizations WHERE is_active = true`, nil, &activeOrgs},
		{`SELECT COUNT(*) FROM users`, nil, &totalUsers},
		{`SELECT COUNT(*) FROM users WHERE status = 'active'`, nil, &activeUsers},
		{`SELECT COUNT(*) FROM applications`, nil, &totalApps},
		{`SELECT COUNT(*) FROM applications WHERE status = 'approved'`, nil, &approvedApps},
		{`SELECT COUNT(*) FROM users WHERE last_login_at >= $1`, []any{todayStart}, &todayLogins},
		{`SELECT COUNT(*) FROM app_usage_logs WHERE created_at >= $1`, []any{todayStart}, &todayConvs},
		{`SELECT COALESCE(SUM(total_tokens),0) FROM app_usage_logs WHERE created_at >= $1`, []any{monthStart}, &monthlyTokens},
		{`SELECT COALESCE(SUM(estimated_cost),0) FROM app_usage_logs WHERE created_at >= $1`, []any{monthStart}, &monthlyCost},
	}

	ctx := r.Context()
	for _, s := range stats {
		if err := h.pool.QueryRow(ctx, s.query, s.args...).Scan(s.dest); err != nil {
			response.InternalError(w, "查询失败")
			return
		}
	}

	response.JSON(w, http.StatusOK, map[string]any{
		"total_orgs":     totalOrgs,
		"active_orgs":    activeOrgs,
		"total_users":    totalUsers,
		"active_users":   activeUsers,
		"total_apps":     totalApps,
		"approved_apps":  approvedApps,
		"today_logins":   todayLogins,
		"today_convs":    todayConvs,
		"monthly_tokens": monthlyTokens,
		"monthly_cost":   monthlyCost,
	})
}
// OrgRanking responds with up to 50 active organizations ranked by the
// number of usage-log entries recorded since the start of the current month
// (server time zone), then by user count. Each item carries the
// organization's id, name, short name, user count, approved-application
// count, and this month's conversation and token totals.
//
// The per-organization usage figures come from a LATERAL subquery that
// yields exactly one row per organization. Because the users and
// applications joins repeat that row once per (user, application) pair, the
// figures are collapsed with MAX; summing them would multiply the totals by
// the number of joined rows.
func (h *PlatformHandler) OrgRanking(w http.ResponseWriter, r *http.Request) {
	now := time.Now()
	monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())

	rows, err := h.pool.Query(r.Context(), `
SELECT o.id, o.name, COALESCE(o.short_name, ''),
COUNT(DISTINCT u.id) AS users,
COUNT(DISTINCT a.id) FILTER (WHERE a.status='approved') AS apps,
COALESCE(MAX(usage.cnt), 0) AS conversations,
COALESCE(MAX(usage.tk), 0) AS tokens
FROM organizations o
LEFT JOIN users u ON u.org_id = o.id
LEFT JOIN applications a ON a.org_id = o.id
LEFT JOIN LATERAL (
SELECT COUNT(*) AS cnt, COALESCE(SUM(total_tokens),0) AS tk
FROM app_usage_logs l
JOIN users uu ON l.user_id = uu.id
WHERE uu.org_id = o.id AND l.created_at >= $1
) usage ON true
WHERE o.is_active = true
GROUP BY o.id, o.name, o.short_name
ORDER BY conversations DESC, users DESC
LIMIT 50`, monthStart)
	if err != nil {
		response.InternalError(w, "查询失败")
		return
	}
	defer rows.Close()

	// Non-nil so that an empty result encodes as [] rather than null.
	items := []map[string]any{}
	for rows.Next() {
		var (
			id, name, short     string
			users, apps         int
			conversations, toks int64
		)
		if err := rows.Scan(&id, &name, &short, &users, &apps, &conversations, &toks); err != nil {
			response.InternalError(w, "查询失败")
			return
		}
		items = append(items, map[string]any{
			"id": id, "name": name, "short_name": short,
			"users": users, "apps": apps,
			"conversations": conversations, "tokens": toks,
		})
	}
	// Surface errors that ended the iteration early instead of returning a
	// truncated ranking with status 200.
	if err := rows.Err(); err != nil {
		response.InternalError(w, "查询失败")
		return
	}
	response.JSON(w, http.StatusOK, items)
}
// ==================== 机构管理 ====================
// ListOrgs responds with every organization (active or not), ordered by
// sort_order and then creation time, together with the number of users and
// applications that reference it.
//
// A row that cannot be scanned, or an error that interrupts the iteration,
// results in a 500 instead of a silently shortened list.
func (h *PlatformHandler) ListOrgs(w http.ResponseWriter, r *http.Request) {
	rows, err := h.pool.Query(r.Context(), `
SELECT o.id, o.name, o.slug, COALESCE(o.short_name,''), COALESCE(o.description,''),
COALESCE(o.logo_url,''), o.sort_order, o.is_active, o.created_at,
COALESCE((SELECT COUNT(*) FROM users WHERE org_id = o.id), 0) AS user_count,
COALESCE((SELECT COUNT(*) FROM applications WHERE org_id = o.id), 0) AS app_count
FROM organizations o ORDER BY o.sort_order, o.created_at`)
	if err != nil {
		response.InternalError(w, "查询机构失败")
		return
	}
	defer rows.Close()

	// Non-nil so that an empty result encodes as [] rather than null.
	items := []map[string]any{}
	for rows.Next() {
		var (
			id, name, slug, short, desc, logo string
			sortOrder                         int
			isActive                          bool
			createdAt                         time.Time
			userCount, appCount               int
		)
		if err := rows.Scan(&id, &name, &slug, &short, &desc, &logo, &sortOrder, &isActive, &createdAt, &userCount, &appCount); err != nil {
			response.InternalError(w, "查询机构失败")
			return
		}
		items = append(items, map[string]any{
			"id": id, "name": name, "slug": slug, "short_name": short,
			"description": desc, "logo_url": logo, "sort_order": sortOrder,
			"is_active": isActive, "created_at": createdAt,
			"user_count": userCount, "app_count": appCount,
		})
	}
	if err := rows.Err(); err != nil {
		response.InternalError(w, "查询机构失败")
		return
	}
	response.JSON(w, http.StatusOK, items)
}
// CreateOrg inserts a new organization described by the JSON request body
// and responds 201 with the generated id.
//
// name and slug are required; short_name, description, logo_url and
// sort_order are stored as sent. A body that is not valid JSON, or that
// lacks name or slug, yields 400. Any insert error yields 500 with a hint
// that the slug may already be taken.
func (h *PlatformHandler) CreateOrg(w http.ResponseWriter, r *http.Request) {
	var in struct {
		Name        string `json:"name"`
		Slug        string `json:"slug"`
		ShortName   string `json:"short_name"`
		Description string `json:"description"`
		LogoURL     string `json:"logo_url"`
		SortOrder   int    `json:"sort_order"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&in); decodeErr != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}
	if !(in.Name != "" && in.Slug != "") {
		response.BadRequest(w, "机构名称和标识不能为空")
		return
	}

	const insertOrg = `
INSERT INTO organizations (name, slug, short_name, description, logo_url, sort_order)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id::text`

	var newID string
	row := h.pool.QueryRow(r.Context(), insertOrg,
		in.Name, in.Slug, in.ShortName, in.Description, in.LogoURL, in.SortOrder)
	if scanErr := row.Scan(&newID); scanErr != nil {
		response.InternalError(w, "创建机构失败:标识可能已存在")
		return
	}
	response.JSON(w, http.StatusCreated, map[string]any{"id": newID})
}
// UpdateOrg applies a partial update to the organization named by the "id"
// URL parameter. Every body field is optional: a field that is absent or
// null reaches the statement as NULL and COALESCE keeps the stored value.
// updated_at is always refreshed. The slug is not updatable here.
//
// Responds 400 for an undecodable body, 500 when the statement fails, and
// 200 otherwise.
func (h *PlatformHandler) UpdateOrg(w http.ResponseWriter, r *http.Request) {
	orgID := chi.URLParam(r, "id")

	var patch struct {
		Name        *string `json:"name"`
		ShortName   *string `json:"short_name"`
		Description *string `json:"description"`
		LogoURL     *string `json:"logo_url"`
		SortOrder   *int    `json:"sort_order"`
		IsActive    *bool   `json:"is_active"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&patch); decodeErr != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}

	const updateOrg = `
UPDATE organizations SET
name = COALESCE($2, name),
short_name = COALESCE($3, short_name),
description = COALESCE($4, description),
logo_url = COALESCE($5, logo_url),
sort_order = COALESCE($6, sort_order),
is_active = COALESCE($7, is_active),
updated_at = NOW()
WHERE id = $1`

	if _, execErr := h.pool.Exec(r.Context(), updateOrg,
		orgID, patch.Name, patch.ShortName, patch.Description,
		patch.LogoURL, patch.SortOrder, patch.IsActive); execErr != nil {
		response.InternalError(w, "更新失败")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
// DeleteOrg physically deletes the organization named by the "id" URL
// parameter, but only when no user and no application still references it;
// otherwise it answers 400 and asks the caller to deactivate the
// organization or migrate its data first.
//
// The reference counts are a safety check, so a failure to obtain either
// one aborts the request with 500 instead of being read as "zero
// references" and letting the delete proceed.
func (h *PlatformHandler) DeleteOrg(w http.ResponseWriter, r *http.Request) {
	id := chi.URLParam(r, "id")
	ctx := r.Context()

	var userCount, appCount int
	if err := h.pool.QueryRow(ctx, `SELECT COUNT(*) FROM users WHERE org_id = $1`, id).Scan(&userCount); err != nil {
		response.InternalError(w, "删除失败")
		return
	}
	if err := h.pool.QueryRow(ctx, `SELECT COUNT(*) FROM applications WHERE org_id = $1`, id).Scan(&appCount); err != nil {
		response.InternalError(w, "删除失败")
		return
	}
	if userCount > 0 || appCount > 0 {
		response.BadRequest(w, "该机构尚有用户或应用,无法删除。请先停用或迁移数据")
		return
	}

	if _, err := h.pool.Exec(ctx, `DELETE FROM organizations WHERE id = $1`, id); err != nil {
		response.InternalError(w, "删除失败")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
}
// ==================== 全局用户管理 ====================
// ListAllUsers responds with one page (20 per page, newest first) of users
// across all organizations, plus the total number of users matching the
// filters.
//
// Query parameters (all optional):
//   - page:   1-based page number; missing, malformed or < 1 means page 1
//   - q:      case-insensitive substring matched against name or email
//   - role:   exact role
//   - status: exact status
//   - org_id: exact organization id
//
// A failure of the count query, of the list query, or while reading rows
// yields 500 rather than a zero total or a silently shortened page.
func (h *PlatformHandler) ListAllUsers(w http.ResponseWriter, r *http.Request) {
	q := r.URL.Query()
	page, _ := strconv.Atoi(q.Get("page")) // parse errors leave page at 0, corrected below
	if page < 1 {
		page = 1
	}
	pageSize := 20
	offset := (page - 1) * pageSize
	search := q.Get("q")
	roleFilter := q.Get("role")
	statusFilter := q.Get("status")
	orgFilter := q.Get("org_id")

	// WHERE clause shared by the count and the list query. User input is
	// only ever passed as a bound parameter, never spliced into the SQL.
	where := ` WHERE 1=1`
	args := []any{}
	argIdx := 1
	if search != "" {
		where += ` AND (u.name ILIKE '%'||$` + strconv.Itoa(argIdx) + `||'%' OR u.email ILIKE '%'||$` + strconv.Itoa(argIdx) + `||'%')`
		args = append(args, search)
		argIdx++
	}
	if roleFilter != "" {
		where += ` AND u.role = $` + strconv.Itoa(argIdx)
		args = append(args, roleFilter)
		argIdx++
	}
	if statusFilter != "" {
		where += ` AND u.status = $` + strconv.Itoa(argIdx)
		args = append(args, statusFilter)
		argIdx++
	}
	if orgFilter != "" {
		where += ` AND u.org_id = $` + strconv.Itoa(argIdx)
		args = append(args, orgFilter)
		argIdx++
	}

	// Total first.
	var total int
	if err := h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users u`+where, args...).Scan(&total); err != nil {
		response.InternalError(w, "查询用户失败")
		return
	}

	// Then the requested page.
	listQuery := `SELECT u.id, u.name, u.email, u.avatar_url, u.role, u.status, u.employee_id,
u.last_login_at, u.login_count, u.created_at,
COALESCE(o.id::text, ''), COALESCE(o.name, ''), COALESCE(o.short_name, '')
FROM users u LEFT JOIN organizations o ON u.org_id = o.id` + where +
		` ORDER BY u.created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
	// The full slice expression forces append to copy, so args is never
	// written through a shared backing array.
	listArgs := append(args[:len(args):len(args)], pageSize, offset)
	rows, err := h.pool.Query(r.Context(), listQuery, listArgs...)
	if err != nil {
		response.InternalError(w, "查询用户失败")
		return
	}
	defer rows.Close()

	// Non-nil so that an empty page encodes as [] rather than null.
	items := []map[string]any{}
	for rows.Next() {
		var (
			id, name, email, role, status string
			avatarURL, employeeID         *string
			lastLoginAt                   *time.Time
			loginCount                    int
			createdAt                     time.Time
			orgID, orgName, orgShort      string
		)
		if err := rows.Scan(&id, &name, &email, &avatarURL, &role, &status, &employeeID, &lastLoginAt, &loginCount, &createdAt, &orgID, &orgName, &orgShort); err != nil {
			response.InternalError(w, "查询用户失败")
			return
		}
		items = append(items, map[string]any{
			"id": id, "name": name, "email": email, "avatar_url": avatarURL,
			"role": role, "status": status, "employee_id": employeeID,
			"last_login_at": lastLoginAt, "login_count": loginCount, "created_at": createdAt,
			"org_id": orgID, "org_name": orgName, "org_short": orgShort,
		})
	}
	if err := rows.Err(); err != nil {
		response.InternalError(w, "查询用户失败")
		return
	}

	response.JSON(w, http.StatusOK, map[string]any{
		"items":     items,
		"total":     total,
		"page":      page,
		"page_size": pageSize,
	})
}
// AssignUserOrg moves the user named by the "id" URL parameter into the
// organization given as org_id in the JSON body.
//
// Responds 400 when the body cannot be decoded or org_id is empty, 500 when
// the update statement fails, and 200 otherwise.
func (h *PlatformHandler) AssignUserOrg(w http.ResponseWriter, r *http.Request) {
	var body struct {
		OrgID string `json:"org_id"`
	}
	targetUser := chi.URLParam(r, "id")

	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}
	if len(body.OrgID) == 0 {
		response.BadRequest(w, "机构ID不能为空")
		return
	}

	const moveUser = `UPDATE users SET org_id = $2 WHERE id = $1`
	if _, execErr := h.pool.Exec(r.Context(), moveUser, targetUser, body.OrgID); execErr != nil {
		response.InternalError(w, "迁移失败")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已迁移"})
}
// UpdateUserRole sets the role of the user named by the "id" URL parameter.
// At platform level any of the roles user, creator, admin and super_admin
// may be assigned (including super_admin).
//
// Responds 400 when the body cannot be decoded or the role is not one of
// those four, 500 when the update statement fails, and 200 otherwise.
func (h *PlatformHandler) UpdateUserRole(w http.ResponseWriter, r *http.Request) {
	var body struct {
		Role string `json:"role"`
	}
	targetUser := chi.URLParam(r, "id")

	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}

	switch body.Role {
	case "user", "creator", "admin", "super_admin":
		// accepted role
	default:
		response.BadRequest(w, "无效的角色")
		return
	}

	const setRole = `UPDATE users SET role = $2 WHERE id = $1`
	if _, execErr := h.pool.Exec(r.Context(), setRole, targetUser, body.Role); execErr != nil {
		response.InternalError(w, "更新角色失败")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
// UpdateUserStatus enables or disables the user named by the "id" URL
// parameter by writing status "active" or "disabled".
//
// Responds 400 when the body cannot be decoded or the status is any other
// value, 500 when the update statement fails, and 200 otherwise.
func (h *PlatformHandler) UpdateUserStatus(w http.ResponseWriter, r *http.Request) {
	var body struct {
		Status string `json:"status"`
	}
	targetUser := chi.URLParam(r, "id")

	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}

	switch body.Status {
	case "active", "disabled":
		// accepted status
	default:
		response.BadRequest(w, "无效的状态")
		return
	}

	const setStatus = `UPDATE users SET status = $2 WHERE id = $1`
	if _, execErr := h.pool.Exec(r.Context(), setStatus, targetUser, body.Status); execErr != nil {
		response.InternalError(w, "更新状态失败")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
// ==================== 全局应用管理 ====================
// ListAllApps responds with one page (20 per page) of applications across
// all organizations, ordered by usage_count and then creation time (both
// descending), plus the total number of applications matching the filters.
//
// Query parameters (all optional):
//   - page:   1-based page number; missing, malformed or < 1 means page 1
//   - status: exact application status
//   - org_id: exact organization id
//
// A failure of the count query, of the list query, or while reading rows
// yields 500 rather than a zero total or a silently shortened page.
func (h *PlatformHandler) ListAllApps(w http.ResponseWriter, r *http.Request) {
	q := r.URL.Query()
	page, _ := strconv.Atoi(q.Get("page")) // parse errors leave page at 0, corrected below
	if page < 1 {
		page = 1
	}
	pageSize := 20
	offset := (page - 1) * pageSize
	statusFilter := q.Get("status")
	orgFilter := q.Get("org_id")

	// WHERE clause shared by the count and the list query.
	where := ` WHERE 1=1`
	args := []any{}
	argIdx := 1
	if statusFilter != "" {
		where += ` AND a.status = $` + strconv.Itoa(argIdx)
		args = append(args, statusFilter)
		argIdx++
	}
	if orgFilter != "" {
		where += ` AND a.org_id = $` + strconv.Itoa(argIdx)
		args = append(args, orgFilter)
		argIdx++
	}

	var total int
	if err := h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM applications a`+where, args...).Scan(&total); err != nil {
		response.InternalError(w, "查询应用失败")
		return
	}

	listQuery := `SELECT a.id, a.name, a.slug, a.description, a.icon_url,
a.dify_app_type, a.status, a.visibility, a.usage_count, a.is_featured, a.created_at,
COALESCE(c.name, ''), COALESCE(u.name, ''),
COALESCE(o.id::text, ''), COALESCE(o.name, ''), COALESCE(o.short_name, '')
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
LEFT JOIN users u ON a.creator_id = u.id
LEFT JOIN organizations o ON a.org_id = o.id` + where +
		` ORDER BY a.usage_count DESC, a.created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
	// The full slice expression forces append to copy, so args is never
	// written through a shared backing array.
	listArgs := append(args[:len(args):len(args)], pageSize, offset)
	rows, err := h.pool.Query(r.Context(), listQuery, listArgs...)
	if err != nil {
		response.InternalError(w, "查询应用失败")
		return
	}
	defer rows.Close()

	// Non-nil so that an empty page encodes as [] rather than null.
	items := []map[string]any{}
	for rows.Next() {
		var (
			id, name, slug, status, visibility string
			desc, iconURL                      *string
			appType                            *string
			usageCount                         int64
			isFeatured                         bool
			createdAt                          time.Time
			catName, creatorName               string
			orgID, orgName, orgShort           string
		)
		if err := rows.Scan(&id, &name, &slug, &desc, &iconURL, &appType, &status, &visibility,
			&usageCount, &isFeatured, &createdAt, &catName, &creatorName, &orgID, &orgName, &orgShort); err != nil {
			response.InternalError(w, "查询应用失败")
			return
		}
		items = append(items, map[string]any{
			"id": id, "name": name, "slug": slug, "description": desc, "icon_url": iconURL,
			"dify_app_type": appType, "status": status, "visibility": visibility,
			"usage_count": usageCount, "is_featured": isFeatured, "created_at": createdAt,
			"category_name": catName, "creator_name": creatorName,
			"org_id": orgID, "org_name": orgName, "org_short": orgShort,
		})
	}
	if err := rows.Err(); err != nil {
		response.InternalError(w, "查询应用失败")
		return
	}

	response.JSON(w, http.StatusOK, map[string]any{
		"items":     items,
		"total":     total,
		"page":      page,
		"page_size": pageSize,
	})
}
// SetFeatured marks or unmarks the application named by the "id" URL
// parameter as featured, according to is_featured in the JSON body (an
// omitted field decodes to false).
//
// Responds 400 when the body cannot be decoded, 500 when the update
// statement fails, and 200 otherwise.
func (h *PlatformHandler) SetFeatured(w http.ResponseWriter, r *http.Request) {
	var body struct {
		IsFeatured bool `json:"is_featured"`
	}
	targetApp := chi.URLParam(r, "id")

	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}

	const setFeatured = `UPDATE applications SET is_featured = $2 WHERE id = $1`
	if _, execErr := h.pool.Exec(r.Context(), setFeatured, targetApp, body.IsFeatured); execErr != nil {
		response.InternalError(w, "操作失败")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
// ForceDelist takes the application named by the "id" URL parameter off the
// shelf by setting its status to 'archived', regardless of which
// organization owns it. Responds 500 when the update statement fails and
// 200 otherwise.
func (h *PlatformHandler) ForceDelist(w http.ResponseWriter, r *http.Request) {
	const archiveApp = `UPDATE applications SET status = 'archived' WHERE id = $1`

	targetApp := chi.URLParam(r, "id")
	if _, execErr := h.pool.Exec(r.Context(), archiveApp, targetApp); execErr != nil {
		response.InternalError(w, "下架失败")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已强制下架"})
}
// ==================== 全局审计日志 ====================
// ListAllAuditLogs responds with one page (20 per page, newest first) of
// audit-log entries across all organizations, plus the total number of
// entries matching the filters. Each entry is joined with the acting user
// and, through that user, the user's organization.
//
// Query parameters (all optional):
//   - page:   1-based page number; missing, malformed or < 1 means page 1
//   - org_id: exact organization id of the acting user
//   - action: case-insensitive substring of the action name
//
// A failure of the count query, of the list query, or while reading rows
// yields 500 rather than a zero total or a silently shortened page.
func (h *PlatformHandler) ListAllAuditLogs(w http.ResponseWriter, r *http.Request) {
	q := r.URL.Query()
	page, _ := strconv.Atoi(q.Get("page")) // parse errors leave page at 0, corrected below
	if page < 1 {
		page = 1
	}
	pageSize := 20
	offset := (page - 1) * pageSize
	orgFilter := q.Get("org_id")
	actionFilter := q.Get("action")

	// WHERE clause shared by the count and the list query.
	where := ` WHERE 1=1`
	args := []any{}
	argIdx := 1
	if orgFilter != "" {
		where += ` AND u.org_id = $` + strconv.Itoa(argIdx)
		args = append(args, orgFilter)
		argIdx++
	}
	if actionFilter != "" {
		where += ` AND al.action ILIKE '%'||$` + strconv.Itoa(argIdx) + `||'%'`
		args = append(args, actionFilter)
		argIdx++
	}

	var total int
	if err := h.pool.QueryRow(r.Context(),
		`SELECT COUNT(*) FROM audit_logs al LEFT JOIN users u ON al.user_id = u.id`+where,
		args...).Scan(&total); err != nil {
		response.InternalError(w, "查询审计日志失败")
		return
	}

	listQuery := `SELECT al.id, al.action, al.resource_type, al.resource_id::text, al.details,
al.ip_address::text, al.created_at,
COALESCE(u.name, ''), COALESCE(u.email, ''),
COALESCE(o.name, ''), COALESCE(o.short_name, '')
FROM audit_logs al
LEFT JOIN users u ON al.user_id = u.id
LEFT JOIN organizations o ON u.org_id = o.id` + where +
		` ORDER BY al.created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
	// The full slice expression forces append to copy, so args is never
	// written through a shared backing array.
	listArgs := append(args[:len(args):len(args)], pageSize, offset)
	rows, err := h.pool.Query(r.Context(), listQuery, listArgs...)
	if err != nil {
		response.InternalError(w, "查询审计日志失败")
		return
	}
	defer rows.Close()

	// Non-nil so that an empty page encodes as [] rather than null.
	items := []map[string]any{}
	for rows.Next() {
		var (
			id, action, resType                    string
			resID, ipAddr                          *string
			details                                json.RawMessage
			createdAt                              time.Time
			userName, userEmail, orgName, orgShort string
		)
		if err := rows.Scan(&id, &action, &resType, &resID, &details, &ipAddr, &createdAt,
			&userName, &userEmail, &orgName, &orgShort); err != nil {
			response.InternalError(w, "查询审计日志失败")
			return
		}
		items = append(items, map[string]any{
			"id": id, "action": action, "resource_type": resType, "resource_id": resID,
			"details": details, "ip_address": ipAddr, "created_at": createdAt,
			"user_name": userName, "user_email": userEmail,
			"org_name": orgName, "org_short": orgShort,
		})
	}
	if err := rows.Err(); err != nil {
		response.InternalError(w, "查询审计日志失败")
		return
	}

	response.JSON(w, http.StatusOK, map[string]any{
		"items":     items,
		"total":     total,
		"page":      page,
		"page_size": pageSize,
	})
}
// ==================== 模型提供商管理 ====================
// ListProviders responds with every model provider, highest priority first
// and then by creation time. The stored API key column is not selected, so
// keys are never part of the response.
//
// A row that cannot be scanned, or an error that interrupts the iteration,
// results in a 500 instead of a silently shortened list.
func (h *PlatformHandler) ListProviders(w http.ResponseWriter, r *http.Request) {
	rows, err := h.pool.Query(r.Context(), `
SELECT id, name, base_url, models, is_active, priority, created_at, updated_at
FROM model_providers ORDER BY priority DESC, created_at`)
	if err != nil {
		response.InternalError(w, "查询失败")
		return
	}
	defer rows.Close()

	// Non-nil so that an empty result encodes as [] rather than null.
	items := []map[string]any{}
	for rows.Next() {
		var (
			id, name, baseURL    string
			models               json.RawMessage
			isActive             bool
			priority             int
			createdAt, updatedAt time.Time
		)
		if err := rows.Scan(&id, &name, &baseURL, &models, &isActive, &priority, &createdAt, &updatedAt); err != nil {
			response.InternalError(w, "查询失败")
			return
		}
		items = append(items, map[string]any{
			"id": id, "name": name, "base_url": baseURL, "models": models,
			"is_active": isActive, "priority": priority,
			"created_at": createdAt, "updated_at": updatedAt,
		})
	}
	if err := rows.Err(); err != nil {
		response.InternalError(w, "查询失败")
		return
	}
	response.JSON(w, http.StatusOK, items)
}
// CreateProvider inserts a model provider described by the JSON request
// body and responds 201 with the generated id.
//
// name, base_url and api_key are required. models is stored as raw JSON and
// defaults to an empty array when the field is absent or explicitly null.
// Responds 400 for an undecodable body or a missing required field, and 500
// when the insert fails.
func (h *PlatformHandler) CreateProvider(w http.ResponseWriter, r *http.Request) {
	var req struct {
		Name     string          `json:"name"`
		BaseURL  string          `json:"base_url"`
		APIKey   string          `json:"api_key"`
		Models   json.RawMessage `json:"models"`
		IsActive bool            `json:"is_active"`
		Priority int             `json:"priority"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}
	if req.Name == "" || req.BaseURL == "" || req.APIKey == "" {
		response.BadRequest(w, "名称、URL、API Key 不能为空")
		return
	}
	// json.RawMessage keeps an explicit JSON null as the four bytes "null",
	// so a length check alone would let it through and store JSON null
	// instead of the intended empty list.
	if len(req.Models) == 0 || string(req.Models) == "null" {
		req.Models = []byte(`[]`)
	}

	// NOTE(review): the key is written to api_key_encrypted exactly as
	// received; no encryption step is visible in this handler — confirm it
	// is encrypted elsewhere (e.g. by the database) before it is persisted.
	var id string
	err := h.pool.QueryRow(r.Context(), `
INSERT INTO model_providers (name, base_url, api_key_encrypted, models, is_active, priority)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id::text`,
		req.Name, req.BaseURL, req.APIKey, req.Models, req.IsActive, req.Priority,
	).Scan(&id)
	if err != nil {
		response.InternalError(w, "创建失败")
		return
	}
	response.JSON(w, http.StatusCreated, map[string]any{"id": id})
}
// UpdateProvider applies a partial update to the model provider named by
// the "id" URL parameter. A field that is absent or null reaches the
// statement as NULL and COALESCE keeps the stored value; updated_at is
// always refreshed.
//
// Two fields get extra care so that they cannot be wiped by accident:
//   - api_key is only replaced when a non-empty string is sent
//   - models is only replaced when a non-null JSON value is sent
//
// Responds 400 for an undecodable body, 500 when the statement fails, and
// 200 otherwise.
func (h *PlatformHandler) UpdateProvider(w http.ResponseWriter, r *http.Request) {
	id := chi.URLParam(r, "id")
	var req struct {
		Name     *string         `json:"name"`
		BaseURL  *string         `json:"base_url"`
		APIKey   *string         `json:"api_key"`
		Models   json.RawMessage `json:"models"`
		IsActive *bool           `json:"is_active"`
		Priority *int            `json:"priority"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}

	// Only overwrite the key when a non-empty one is supplied.
	var apiKeyArg any
	if req.APIKey != nil && *req.APIKey != "" {
		apiKeyArg = *req.APIKey
	}
	// json.RawMessage keeps an explicit JSON null as the bytes "null", which
	// is not SQL NULL and would overwrite the stored list with JSON null.
	// Treat it like every other null field: no change.
	var modelsArg any
	if len(req.Models) > 0 && string(req.Models) != "null" {
		modelsArg = req.Models
	}

	_, err := h.pool.Exec(r.Context(), `
UPDATE model_providers SET
name = COALESCE($2, name),
base_url = COALESCE($3, base_url),
api_key_encrypted = COALESCE($4, api_key_encrypted),
models = COALESCE($5::jsonb, models),
is_active = COALESCE($6, is_active),
priority = COALESCE($7, priority),
updated_at = NOW()
WHERE id = $1`,
		id, req.Name, req.BaseURL, apiKeyArg, modelsArg, req.IsActive, req.Priority)
	if err != nil {
		response.InternalError(w, "更新失败")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
// DeleteProvider removes the model provider named by the "id" URL
// parameter. Responds 500 when the delete statement fails and 200
// otherwise.
func (h *PlatformHandler) DeleteProvider(w http.ResponseWriter, r *http.Request) {
	const dropProvider = `DELETE FROM model_providers WHERE id = $1`

	providerID := chi.URLParam(r, "id")
	if _, execErr := h.pool.Exec(r.Context(), dropProvider, providerID); execErr != nil {
		response.InternalError(w, "删除失败")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
}
// ==================== 全局配额管理 ====================
// ListQuotas responds with up to 200 model quotas ordered by target type
// and then newest first. Each quota carries a display name for its target,
// resolved in SQL: a fixed label for 'global', the department name for
// 'department', the user name for 'user', and an empty string when the
// target cannot be resolved.
//
// A row that cannot be scanned, or an error that interrupts the iteration,
// results in a 500 instead of a silently shortened list.
func (h *PlatformHandler) ListQuotas(w http.ResponseWriter, r *http.Request) {
	rows, err := h.pool.Query(r.Context(), `
SELECT q.id, q.target_type, q.target_id::text, q.model_name,
q.daily_token_limit, q.monthly_token_limit, q.daily_request_limit,
q.is_active, q.created_at,
COALESCE(target_name.name, '')
FROM model_quotas q
LEFT JOIN LATERAL (
SELECT CASE q.target_type
WHEN 'global' THEN '全局'
WHEN 'department' THEN (SELECT name FROM departments WHERE id = q.target_id)
WHEN 'user' THEN (SELECT name FROM users WHERE id = q.target_id)
END AS name
) target_name ON true
ORDER BY q.target_type, q.created_at DESC LIMIT 200`)
	if err != nil {
		response.InternalError(w, "查询配额失败")
		return
	}
	defer rows.Close()

	// Non-nil so that an empty result encodes as [] rather than null.
	items := []map[string]any{}
	for rows.Next() {
		var (
			id, targetType             string
			targetID                   *string
			modelName                  *string
			dailyTokens, monthlyTokens *int64
			dailyReqs                  *int
			isActive                   bool
			createdAt                  time.Time
			targetName                 string
		)
		if err := rows.Scan(&id, &targetType, &targetID, &modelName, &dailyTokens, &monthlyTokens,
			&dailyReqs, &isActive, &createdAt, &targetName); err != nil {
			response.InternalError(w, "查询配额失败")
			return
		}
		items = append(items, map[string]any{
			"id": id, "target_type": targetType, "target_id": targetID,
			"target_name":         targetName,
			"model_name":          modelName,
			"daily_token_limit":   dailyTokens,
			"monthly_token_limit": monthlyTokens,
			"daily_request_limit": dailyReqs,
			"is_active":           isActive,
			"created_at":          createdAt,
		})
	}
	if err := rows.Err(); err != nil {
		response.InternalError(w, "查询配额失败")
		return
	}
	response.JSON(w, http.StatusOK, items)
}
// UpsertQuota creates a model quota, or overwrites an existing one when the
// body carries a non-empty id.
//
// target_type must be global, department or user; for the latter two a
// non-empty target_id is required. An empty or missing target_id is stored
// as NULL. The limit fields and model_name are written as sent, so null
// clears them on update.
//
// Responses:
//   - 400: undecodable body, invalid target_type, missing target_id, or an
//     update whose id matches no quota
//   - 500: the statement failed
//   - 200 with {"id"}: existing quota updated
//   - 201 with {"id"}: new quota created
func (h *PlatformHandler) UpsertQuota(w http.ResponseWriter, r *http.Request) {
	var req struct {
		ID                *string `json:"id"`
		TargetType        string  `json:"target_type"`
		TargetID          *string `json:"target_id"`
		ModelName         *string `json:"model_name"`
		DailyTokenLimit   *int64  `json:"daily_token_limit"`
		MonthlyTokenLimit *int64  `json:"monthly_token_limit"`
		DailyRequestLimit *int    `json:"daily_request_limit"`
		IsActive          bool    `json:"is_active"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}
	valid := map[string]bool{"global": true, "department": true, "user": true}
	if !valid[req.TargetType] {
		response.BadRequest(w, "无效的目标类型")
		return
	}
	if req.TargetType != "global" && (req.TargetID == nil || *req.TargetID == "") {
		response.BadRequest(w, "部门/用户配额必须指定 target_id")
		return
	}

	if req.ID != nil && *req.ID != "" {
		// RETURNING turns "no row matched" into pgx.ErrNoRows, so an unknown
		// id is reported to the caller instead of being answered with 200.
		var matchedID string
		err := h.pool.QueryRow(r.Context(), `
UPDATE model_quotas SET
target_type = $2, target_id = NULLIF($3,'')::uuid, model_name = $4,
daily_token_limit = $5, monthly_token_limit = $6, daily_request_limit = $7,
is_active = $8, updated_at = NOW()
WHERE id = $1
RETURNING id::text`,
			*req.ID, req.TargetType, nullableString(req.TargetID), req.ModelName,
			req.DailyTokenLimit, req.MonthlyTokenLimit, req.DailyRequestLimit, req.IsActive,
		).Scan(&matchedID)
		if errors.Is(err, pgx.ErrNoRows) {
			response.BadRequest(w, "配额不存在")
			return
		}
		if err != nil {
			response.InternalError(w, "更新配额失败")
			return
		}
		// Echo the id exactly as the caller sent it.
		response.JSON(w, http.StatusOK, map[string]string{"id": *req.ID})
		return
	}

	var id string
	err := h.pool.QueryRow(r.Context(), `
INSERT INTO model_quotas (target_type, target_id, model_name, daily_token_limit, monthly_token_limit, daily_request_limit, is_active)
VALUES ($1, NULLIF($2,'')::uuid, $3, $4, $5, $6, $7)
RETURNING id::text`,
		req.TargetType, nullableString(req.TargetID), req.ModelName,
		req.DailyTokenLimit, req.MonthlyTokenLimit, req.DailyRequestLimit, req.IsActive,
	).Scan(&id)
	if err != nil {
		response.InternalError(w, "创建配额失败")
		return
	}
	response.JSON(w, http.StatusCreated, map[string]string{"id": id})
}
// DeleteQuota removes the model quota named by the "id" URL parameter.
// Responds 500 when the delete statement fails and 200 otherwise.
func (h *PlatformHandler) DeleteQuota(w http.ResponseWriter, r *http.Request) {
	const dropQuota = `DELETE FROM model_quotas WHERE id = $1`

	quotaID := chi.URLParam(r, "id")
	if _, execErr := h.pool.Exec(r.Context(), dropQuota, quotaID); execErr != nil {
		response.InternalError(w, "删除失败")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
}
// nullableString dereferences s, mapping a nil pointer to the empty string.
func nullableString(s *string) string {
	var value string
	if s != nil {
		value = *s
	}
	return value
}