Initial commit: GovAI 政务AI平台

This commit is contained in:
freedakgmail
2026-06-15 23:48:37 +08:00
commit 0f490f72a9
245 changed files with 51669 additions and 0 deletions
+135
View File
@@ -0,0 +1,135 @@
package config
import (
"os"
"time"
)
type Config struct {
Server ServerConfig
Database DatabaseConfig
Redis RedisConfig
JWT JWTConfig
Dify DifyConfig
LLM LLMConfig
Embedding EmbeddingConfig
Gateway GatewayConfig
MinIO MinIOConfig
PPTWorker PPTWorkerConfig
}
type PPTWorkerConfig struct {
URL string // PPT Worker 微服务地址
}
type LLMConfig struct {
Provider string // "openai" or "anthropic"
OpenAIKey string
OpenAIBaseURL string
OpenAIModel string
AnthropicKey string
AnthropicBaseURL string
AnthropicModel string
}
type EmbeddingConfig struct {
APIKey string // Embedding API 密钥
BaseURL string // Embedding API 基础 URLOpenAI 兼容格式)
Model string // 向量模型名称
Dimensions int // 向量维度
}
type ServerConfig struct {
Host string
Port string
}
type DatabaseConfig struct {
URL string
}
type RedisConfig struct {
URL string
}
type JWTConfig struct {
Secret string
AccessExpiry time.Duration
RefreshExpiry time.Duration
}
type DifyConfig struct {
APIURL string
APIKey string
}
type GatewayConfig struct {
URL string
}
type MinIOConfig struct {
Endpoint string
AccessKey string
SecretKey string
Bucket string
UseSSL bool
}
func Load() *Config {
return &Config{
Server: ServerConfig{
Host: getEnv("SERVER_HOST", "0.0.0.0"),
Port: getEnv("SERVER_PORT", "8080"),
},
Database: DatabaseConfig{
URL: getEnv("DATABASE_URL", "postgres://localhost:5432/govai_portal?sslmode=disable"),
},
Redis: RedisConfig{
URL: getEnv("REDIS_URL", "redis://localhost:6379"),
},
JWT: JWTConfig{
Secret: getEnv("JWT_SECRET", "dev-secret-change-in-production"),
AccessExpiry: 24 * time.Hour,
RefreshExpiry: 7 * 24 * time.Hour,
},
Dify: DifyConfig{
APIURL: getEnv("DIFY_API_URL", "http://localhost:5001/v1"),
APIKey: getEnv("DIFY_API_KEY", ""),
},
LLM: LLMConfig{
Provider: getEnv("LLM_PROVIDER", "openai"),
OpenAIKey: getEnv("OPENAI_API_KEY", ""),
OpenAIBaseURL: getEnv("OPENAI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
OpenAIModel: getEnv("OPENAI_MODEL", "qwen-plus"),
AnthropicKey: getEnv("ANTHROPIC_API_KEY", ""),
AnthropicBaseURL: getEnv("ANTHROPIC_BASE_URL", "https://api.anthropic.com"),
AnthropicModel: getEnv("ANTHROPIC_MODEL", "claude-sonnet-4-20250514"),
},
Embedding: EmbeddingConfig{
APIKey: getEnv("EMBEDDING_API_KEY", ""),
BaseURL: getEnv("EMBEDDING_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
Model: getEnv("EMBEDDING_MODEL", "text-embedding-v3"),
Dimensions: 1024,
},
Gateway: GatewayConfig{
URL: getEnv("MODEL_GATEWAY_URL", "http://localhost:8081"),
},
MinIO: MinIOConfig{
Endpoint: getEnv("MINIO_ENDPOINT", "localhost:9000"),
AccessKey: getEnv("MINIO_ACCESS_KEY", "minioadmin"),
SecretKey: getEnv("MINIO_SECRET_KEY", "minioadmin"),
Bucket: getEnv("MINIO_BUCKET", "aily-files"),
UseSSL: false,
},
PPTWorker: PPTWorkerConfig{
URL: getEnv("PPT_WORKER_URL", "http://localhost:8090"),
},
}
}
func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
+614
View File
@@ -0,0 +1,614 @@
package handler
import (
"encoding/json"
"net/http"
"strconv"
"time"
"github.com/enterprise-ai-platform/server/internal/middleware"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
type AdminHandler struct {
pool *pgxpool.Pool
}
func NewAdminHandler(pool *pgxpool.Pool) *AdminHandler {
return &AdminHandler{pool: pool}
}
// getUserOrgID 通过当前登录用户ID查询其所属机构ID,用于多租户数据隔离
func (h *AdminHandler) getUserOrgID(r *http.Request) (string, error) {
userID := middleware.GetUserID(r.Context())
var orgID string
err := h.pool.QueryRow(r.Context(),
`SELECT COALESCE(org_id::text, '') FROM users WHERE id = $1`, userID).Scan(&orgID)
return orgID, err
}
// --- Overview Stats ---
func (h *AdminHandler) Overview(w http.ResponseWriter, r *http.Request) {
orgID, err := h.getUserOrgID(r)
if err != nil || orgID == "" {
response.InternalError(w, "无法确定当前机构")
return
}
now := time.Now()
todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
var totalUsers, totalApps, activeUsers, todayConversations int
var monthlyTokens int64
var monthlyCost float64
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users WHERE status = 'active' AND org_id = $1`, orgID).Scan(&totalUsers)
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM applications WHERE status = 'approved' AND org_id = $1`, orgID).Scan(&totalApps)
h.pool.QueryRow(r.Context(), `SELECT COUNT(DISTINCT l.user_id) FROM app_usage_logs l JOIN users u ON l.user_id = u.id WHERE l.created_at >= $1 AND u.org_id = $2`, todayStart, orgID).Scan(&activeUsers)
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM app_usage_logs l JOIN users u ON l.user_id = u.id WHERE l.created_at >= $1 AND u.org_id = $2`, todayStart, orgID).Scan(&todayConversations)
h.pool.QueryRow(r.Context(), `SELECT COALESCE(SUM(l.total_tokens), 0) FROM app_usage_logs l JOIN users u ON l.user_id = u.id WHERE l.created_at >= $1 AND u.org_id = $2`, monthStart, orgID).Scan(&monthlyTokens)
h.pool.QueryRow(r.Context(), `SELECT COALESCE(SUM(l.estimated_cost), 0) FROM app_usage_logs l JOIN users u ON l.user_id = u.id WHERE l.created_at >= $1 AND u.org_id = $2`, monthStart, orgID).Scan(&monthlyCost)
response.JSON(w, http.StatusOK, map[string]any{
"total_users": totalUsers,
"total_apps": totalApps,
"active_users": activeUsers,
"total_conversations": todayConversations,
"monthly_tokens": monthlyTokens,
"monthly_cost": monthlyCost,
})
}
// --- User Management ---
func (h *AdminHandler) ListUsers(w http.ResponseWriter, r *http.Request) {
orgID, err := h.getUserOrgID(r)
if err != nil || orgID == "" {
response.InternalError(w, "无法确定当前机构")
return
}
q := r.URL.Query()
page, _ := strconv.Atoi(q.Get("page"))
if page < 1 {
page = 1
}
pageSize := 20
offset := (page - 1) * pageSize
search := q.Get("q")
roleFilter := q.Get("role")
statusFilter := q.Get("status")
query := `SELECT id, name, email, avatar_url, role, status, employee_id, last_login_at, login_count, created_at
FROM users WHERE org_id = $1`
args := []any{orgID}
argIdx := 2
if search != "" {
query += ` AND (name ILIKE '%' || $` + strconv.Itoa(argIdx) + ` || '%' OR email ILIKE '%' || $` + strconv.Itoa(argIdx) + ` || '%')`
args = append(args, search)
argIdx++
}
if roleFilter != "" {
query += ` AND role = $` + strconv.Itoa(argIdx)
args = append(args, roleFilter)
argIdx++
}
if statusFilter != "" {
query += ` AND status = $` + strconv.Itoa(argIdx)
args = append(args, statusFilter)
argIdx++
}
query += ` ORDER BY created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
args = append(args, pageSize, offset)
rows, err := h.pool.Query(r.Context(), query, args...)
if err != nil {
response.InternalError(w, "查询用户失败")
return
}
defer rows.Close()
var users []map[string]any
for rows.Next() {
var (
id, name, email, role, status string
avatarURL, employeeID *string
lastLoginAt *time.Time
loginCount int
createdAt time.Time
)
if err := rows.Scan(&id, &name, &email, &avatarURL, &role, &status, &employeeID, &lastLoginAt, &loginCount, &createdAt); err != nil {
continue
}
users = append(users, map[string]any{
"id": id, "name": name, "email": email,
"avatar_url": avatarURL, "role": role, "status": status,
"employee_id": employeeID, "last_login_at": lastLoginAt,
"login_count": loginCount, "created_at": createdAt,
})
}
if users == nil {
users = []map[string]any{}
}
response.JSON(w, http.StatusOK, map[string]any{"items": users, "page": page})
}
func (h *AdminHandler) UpdateUserRole(w http.ResponseWriter, r *http.Request) {
userID := chi.URLParam(r, "id")
operatorRole := middleware.GetRole(r.Context())
var req struct {
Role string `json:"role"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
validRoles := map[string]bool{"user": true, "creator": true, "admin": true, "super_admin": true}
if !validRoles[req.Role] {
response.BadRequest(w, "无效的角色")
return
}
if req.Role == "super_admin" && operatorRole != "super_admin" {
response.Forbidden(w, "只有超级管理员才能设置超级管理员角色")
return
}
// 确保只能修改本机构用户
orgID, orgErr := h.getUserOrgID(r)
if orgErr != nil || orgID == "" {
response.InternalError(w, "无法确定当前机构")
return
}
_, err := h.pool.Exec(r.Context(), `UPDATE users SET role = $2 WHERE id = $1 AND org_id = $3`, userID, req.Role, orgID)
if err != nil {
response.InternalError(w, "更新角色失败")
return
}
response.JSON(w, http.StatusOK, nil)
}
func (h *AdminHandler) UpdateUserStatus(w http.ResponseWriter, r *http.Request) {
userID := chi.URLParam(r, "id")
var req struct {
Status string `json:"status"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Status != "active" && req.Status != "disabled" {
response.BadRequest(w, "无效的状态")
return
}
// 确保只能修改本机构用户
orgID, orgErr := h.getUserOrgID(r)
if orgErr != nil || orgID == "" {
response.InternalError(w, "无法确定当前机构")
return
}
_, err := h.pool.Exec(r.Context(), `UPDATE users SET status = $2 WHERE id = $1 AND org_id = $3`, userID, req.Status, orgID)
if err != nil {
response.InternalError(w, "更新状态失败")
return
}
response.JSON(w, http.StatusOK, nil)
}
// --- App Management ---
func (h *AdminHandler) ListAllApps(w http.ResponseWriter, r *http.Request) {
orgID, err := h.getUserOrgID(r)
if err != nil || orgID == "" {
response.InternalError(w, "无法确定当前机构")
return
}
q := r.URL.Query()
page, _ := strconv.Atoi(q.Get("page"))
if page < 1 {
page = 1
}
offset := (page - 1) * 20
statusFilter := q.Get("status")
query := `SELECT a.id, a.name, a.slug, a.description, a.icon_url,
c.name as category_name, u.name as creator_name,
a.dify_app_type, a.status, a.visibility, a.usage_count, a.created_at
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
LEFT JOIN users u ON a.creator_id = u.id
WHERE a.org_id = $1`
args := []any{orgID}
argIdx := 2
if statusFilter != "" {
query += ` AND a.status = $` + strconv.Itoa(argIdx)
args = append(args, statusFilter)
argIdx++
}
query += ` ORDER BY a.created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
args = append(args, 20, offset)
rows, err := h.pool.Query(r.Context(), query, args...)
if err != nil {
response.InternalError(w, "查询应用失败")
return
}
defer rows.Close()
var apps []map[string]any
for rows.Next() {
var (
id, name, slug, status, visibility string
desc, iconURL, catName, creator *string
appType *string
usageCount int64
createdAt time.Time
)
if err := rows.Scan(&id, &name, &slug, &desc, &iconURL, &catName, &creator,
&appType, &status, &visibility, &usageCount, &createdAt); err != nil {
continue
}
apps = append(apps, map[string]any{
"id": id, "name": name, "slug": slug, "description": desc,
"icon_url": iconURL, "category_name": catName, "creator_name": creator,
"dify_app_type": appType,
"status": status, "visibility": visibility, "usage_count": usageCount,
"created_at": createdAt,
})
}
if apps == nil {
apps = []map[string]any{}
}
response.JSON(w, http.StatusOK, map[string]any{"items": apps, "page": page})
}
// --- Audit Logs ---
func (h *AdminHandler) ListAuditLogs(w http.ResponseWriter, r *http.Request) {
orgID, orgErr := h.getUserOrgID(r)
if orgErr != nil || orgID == "" {
response.InternalError(w, "无法确定当前机构")
return
}
q := r.URL.Query()
page, _ := strconv.Atoi(q.Get("page"))
if page < 1 {
page = 1
}
offset := (page - 1) * 50
rows, err := h.pool.Query(r.Context(), `
SELECT al.id, al.action, al.resource_type, al.resource_id,
al.details, al.ip_address, al.created_at,
u.name as user_name
FROM audit_logs al
LEFT JOIN users u ON al.user_id = u.id
WHERE u.org_id = $2
ORDER BY al.created_at DESC
LIMIT 50 OFFSET $1`, offset, orgID)
if err != nil {
response.InternalError(w, "查询审计日志失败")
return
}
defer rows.Close()
var logs []map[string]any
for rows.Next() {
var (
id, action, resType string
resID *string
details json.RawMessage
ipAddr *string
createdAt time.Time
userName *string
)
if err := rows.Scan(&id, &action, &resType, &resID, &details, &ipAddr, &createdAt, &userName); err != nil {
continue
}
logs = append(logs, map[string]any{
"id": id, "action": action, "resource_type": resType,
"resource_id": resID, "details": details,
"ip_address": ipAddr, "created_at": createdAt, "user_name": userName,
})
}
if logs == nil {
logs = []map[string]any{}
}
response.JSON(w, http.StatusOK, map[string]any{"items": logs, "page": page})
}
// --- Usage Analytics ---
func (h *AdminHandler) UsageAnalytics(w http.ResponseWriter, r *http.Request) {
orgID, orgErr := h.getUserOrgID(r)
if orgErr != nil || orgID == "" {
response.InternalError(w, "无法确定当前机构")
return
}
days := 7
if d, err := strconv.Atoi(r.URL.Query().Get("days")); err == nil && d > 0 && d <= 90 {
days = d
}
rows, err := h.pool.Query(r.Context(), `
SELECT DATE(l.created_at) as day, COUNT(*) as count, COALESCE(SUM(l.total_tokens), 0) as tokens
FROM app_usage_logs l
JOIN users u ON l.user_id = u.id
WHERE l.created_at >= NOW() - $1::interval AND u.org_id = $2
GROUP BY DATE(l.created_at)
ORDER BY day`, strconv.Itoa(days)+" days", orgID)
if err != nil {
response.InternalError(w, "查询使用统计失败")
return
}
defer rows.Close()
var dailyStats []map[string]any
for rows.Next() {
var (
day time.Time
count int
tokens int64
)
if err := rows.Scan(&day, &count, &tokens); err != nil {
continue
}
dailyStats = append(dailyStats, map[string]any{
"date": day.Format("2006-01-02"),
"count": count,
"total_tokens": tokens,
})
}
if dailyStats == nil {
dailyStats = []map[string]any{}
}
// Top apps
topRows, err := h.pool.Query(r.Context(), `
SELECT a.name, COUNT(l.id) as usage_count
FROM app_usage_logs l
JOIN applications a ON l.app_id = a.id
WHERE l.created_at >= NOW() - $1::interval AND a.org_id = $2
GROUP BY a.name
ORDER BY usage_count DESC LIMIT 10`, strconv.Itoa(days)+" days", orgID)
if err != nil {
response.JSON(w, http.StatusOK, map[string]any{"daily": dailyStats})
return
}
defer topRows.Close()
var topApps []map[string]any
for topRows.Next() {
var name string
var count int
if err := topRows.Scan(&name, &count); err != nil {
continue
}
topApps = append(topApps, map[string]any{"name": name, "count": count})
}
if topApps == nil {
topApps = []map[string]any{}
}
response.JSON(w, http.StatusOK, map[string]any{
"daily": dailyStats,
"top_apps": topApps,
})
}
// --- Review Management ---
func (h *AdminHandler) ListPendingReviews(w http.ResponseWriter, r *http.Request) {
orgID, orgErr := h.getUserOrgID(r)
if orgErr != nil || orgID == "" {
response.InternalError(w, "无法确定当前机构")
return
}
rows, err := h.pool.Query(r.Context(), `
SELECT r.id, r.app_id, r.version, r.submit_comment, r.submitted_at,
a.name as app_name, a.description as app_description, a.icon_url,
u.name as submitter_name
FROM app_reviews r
JOIN applications a ON r.app_id = a.id
JOIN users u ON r.submitter_id = u.id
WHERE r.status = 'pending' AND a.org_id = $1
ORDER BY r.submitted_at ASC LIMIT 50`, orgID)
if err != nil {
response.InternalError(w, "查询审核列表失败")
return
}
defer rows.Close()
var reviews []map[string]any
for rows.Next() {
var (
id, appID, version string
comment *string
submittedAt time.Time
appName, appDesc *string
appIcon *string
submitterName string
)
if err := rows.Scan(&id, &appID, &version, &comment, &submittedAt,
&appName, &appDesc, &appIcon, &submitterName); err != nil {
continue
}
reviews = append(reviews, map[string]any{
"id": id, "app_id": appID, "version": version,
"submit_comment": comment, "submitted_at": submittedAt,
"app_name": appName, "app_description": appDesc, "app_icon": appIcon,
"submitter_name": submitterName,
})
}
if reviews == nil {
reviews = []map[string]any{}
}
response.JSON(w, http.StatusOK, reviews)
}
func (h *AdminHandler) ApproveReview(w http.ResponseWriter, r *http.Request) {
reviewID := chi.URLParam(r, "id")
reviewerID := middleware.GetUserID(r.Context())
var req struct {
Comment string `json:"comment"`
}
json.NewDecoder(r.Body).Decode(&req)
// Get app_id from review
var appID string
err := h.pool.QueryRow(r.Context(),
`SELECT app_id FROM app_reviews WHERE id = $1 AND status = 'pending'`, reviewID).Scan(&appID)
if err != nil {
response.NotFound(w, "审核记录不存在或已处理")
return
}
tx, err := h.pool.Begin(r.Context())
if err != nil {
response.InternalError(w, "事务开始失败")
return
}
defer tx.Rollback(r.Context())
tx.Exec(r.Context(), `
UPDATE app_reviews SET status = 'approved', reviewer_id = $2, review_comment = $3, reviewed_at = NOW()
WHERE id = $1`, reviewID, reviewerID, req.Comment)
tx.Exec(r.Context(), `
UPDATE applications SET status = 'approved', published_at = NOW()
WHERE id = $1`, appID)
if err := tx.Commit(r.Context()); err != nil {
response.InternalError(w, "审核通过失败")
return
}
response.JSON(w, http.StatusOK, nil)
}
func (h *AdminHandler) RejectReview(w http.ResponseWriter, r *http.Request) {
reviewID := chi.URLParam(r, "id")
reviewerID := middleware.GetUserID(r.Context())
var req struct {
Comment string `json:"comment"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Comment == "" {
response.BadRequest(w, "驳回必须填写原因")
return
}
var appID string
err := h.pool.QueryRow(r.Context(),
`SELECT app_id FROM app_reviews WHERE id = $1 AND status = 'pending'`, reviewID).Scan(&appID)
if err != nil {
response.NotFound(w, "审核记录不存在或已处理")
return
}
tx, err := h.pool.Begin(r.Context())
if err != nil {
response.InternalError(w, "事务开始失败")
return
}
defer tx.Rollback(r.Context())
tx.Exec(r.Context(), `
UPDATE app_reviews SET status = 'rejected', reviewer_id = $2, review_comment = $3, reviewed_at = NOW()
WHERE id = $1`, reviewID, reviewerID, req.Comment)
tx.Exec(r.Context(), `
UPDATE applications SET status = 'rejected'
WHERE id = $1`, appID)
if err := tx.Commit(r.Context()); err != nil {
response.InternalError(w, "驳回失败")
return
}
response.JSON(w, http.StatusOK, nil)
}
func (h *AdminHandler) DelistApp(w http.ResponseWriter, r *http.Request) {
orgID, orgErr := h.getUserOrgID(r)
if orgErr != nil || orgID == "" {
response.InternalError(w, "无法确定当前机构")
return
}
appID := chi.URLParam(r, "id")
var status string
err := h.pool.QueryRow(r.Context(),
`SELECT status FROM applications WHERE id = $1 AND org_id = $2`, appID, orgID).Scan(&status)
if err != nil {
response.NotFound(w, "应用不存在")
return
}
if status != "approved" {
response.BadRequest(w, "只有已上架的应用可以撤架")
return
}
_, err = h.pool.Exec(r.Context(),
`UPDATE applications SET status = 'archived' WHERE id = $1`, appID)
if err != nil {
response.InternalError(w, "撤架失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已撤架"})
}
func (h *AdminHandler) RelistApp(w http.ResponseWriter, r *http.Request) {
orgID, orgErr := h.getUserOrgID(r)
if orgErr != nil || orgID == "" {
response.InternalError(w, "无法确定当前机构")
return
}
appID := chi.URLParam(r, "id")
var status string
err := h.pool.QueryRow(r.Context(),
`SELECT status FROM applications WHERE id = $1 AND org_id = $2`, appID, orgID).Scan(&status)
if err != nil {
response.NotFound(w, "应用不存在")
return
}
if status != "archived" {
response.BadRequest(w, "只有已归档的应用可以重新上架")
return
}
_, err = h.pool.Exec(r.Context(),
`UPDATE applications SET status = 'approved', published_at = NOW() WHERE id = $1`, appID)
if err != nil {
response.InternalError(w, "重新上架失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已重新上架"})
}
@@ -0,0 +1,305 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/enterprise-ai-platform/server/internal/middleware"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/enterprise-ai-platform/server/pkg/llm"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
type AnalysisTemplateHandler struct {
pool *pgxpool.Pool
manager *llm.Manager
provider string
}
func NewAnalysisTemplateHandler(pool *pgxpool.Pool, manager *llm.Manager, provider string) *AnalysisTemplateHandler {
return &AnalysisTemplateHandler{pool: pool, manager: manager, provider: provider}
}
func (h *AnalysisTemplateHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
orgID := r.URL.Query().Get("org_id")
query := `SELECT id, name, report_type, COALESCE(description,''), COALESCE(icon,''), steps, sort_order
FROM analysis_report_templates
WHERE is_active = true`
var args []any
if orgID != "" {
query += ` AND (org_id = $1 OR org_id IS NULL)`
args = append(args, orgID)
}
query += ` ORDER BY sort_order ASC`
rows, err := h.pool.Query(r.Context(), query, args...)
if err != nil {
response.InternalError(w, "查询模板失败")
return
}
defer rows.Close()
var templates []map[string]any
for rows.Next() {
var id, name, reportType, desc, icon string
var steps json.RawMessage
var sortOrder int
if err := rows.Scan(&id, &name, &reportType, &desc, &icon, &steps, &sortOrder); err != nil {
continue
}
templates = append(templates, map[string]any{
"id": id,
"name": name,
"report_type": reportType,
"description": desc,
"icon": icon,
"steps": steps,
"sort_order": sortOrder,
})
}
if templates == nil {
templates = []map[string]any{}
}
response.JSON(w, http.StatusOK, map[string]any{"data": templates})
}
func (h *AnalysisTemplateHandler) GetTemplate(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "templateId")
var name, reportType, desc, icon, promptTpl string
var steps, outputSections json.RawMessage
err := h.pool.QueryRow(r.Context(), `
SELECT name, report_type, COALESCE(description,''), COALESCE(icon,''),
steps, prompt_template, COALESCE(output_sections, '[]')
FROM analysis_report_templates WHERE id = $1 AND is_active = true`, id,
).Scan(&name, &reportType, &desc, &icon, &steps, &promptTpl, &outputSections)
if err != nil {
response.NotFound(w, "模板不存在")
return
}
response.JSON(w, http.StatusOK, map[string]any{
"id": id,
"name": name,
"report_type": reportType,
"description": desc,
"icon": icon,
"steps": steps,
"prompt_template": promptTpl,
"output_sections": outputSections,
})
}
type generateAnalysisRequest struct {
TemplateID string `json:"template_id"`
FieldData map[string]string `json:"field_data"`
}
func containsAny(list []string, targets ...string) bool {
for _, item := range list {
for _, t := range targets {
if item == t {
return true
}
}
}
return false
}
func (h *AnalysisTemplateHandler) queryEconomicData(ctx context.Context, districts, timeRange string) string {
years := []int{2025}
switch timeRange {
case "2024-2025年两年对比", "2024_2025":
years = []int{2024, 2025}
case "2023-2025年三年趋势", "2023_2025":
years = []int{2023, 2024, 2025}
case "2024年全年", "2024_full":
years = []int{2024}
}
districtList := strings.Split(districts, ",")
for i := range districtList {
districtList[i] = strings.TrimSpace(districtList[i])
}
rows, err := h.pool.Query(ctx, `
SELECT district_name, year,
COALESCE(gdp,0), COALESCE(gdp_growth,0), COALESCE(gdp_per_capita,0),
COALESCE(fiscal_revenue,0), COALESCE(fiscal_revenue_growth,0),
COALESCE(fixed_investment,0), COALESCE(fixed_investment_growth,0),
COALESCE(retail_sales,0), COALESCE(retail_sales_growth,0),
COALESCE(industrial_output,0), COALESCE(industrial_output_growth,0),
COALESCE(tertiary_ratio,0),
COALESCE(import_export,0), COALESCE(import_export_growth,0),
COALESCE(actual_fdi,0),
COALESCE(tech_expenditure,0), COALESCE(tech_expenditure_ratio,0),
COALESCE(population,0), COALESCE(urban_income,0), COALESCE(rural_income,0)
FROM regional_economic_data
WHERE year = ANY($1)
AND (cardinality($2::text[]) = 0 OR district_name = ANY($2))
ORDER BY year ASC, gdp DESC`, years, districtList)
if err != nil {
return ""
}
defer rows.Close()
var sb strings.Builder
sb.WriteString("\n\n## 【数据库真实数据】\n\n")
sb.WriteString("| 区县 | 年份 | GDP(亿元) | GDP增速(%) | 人均GDP(元) | 财政收入(亿元) | 财政增速(%) | 固投(亿元) | 固投增速(%) | 社零(亿元) | 社零增速(%) | 规上工业(亿元) | 工业增速(%) | 三产占比(%) | 进出口(亿元) | 进出口增速(%) | 利用外资(亿美元) | R&D投入(亿元) | R&D/GDP(%) | 人口(万) | 城镇收入(元) | 农村收入(元) |\n")
sb.WriteString("|------|------|-----------|-----------|------------|-------------|-----------|----------|-----------|----------|-----------|-------------|-----------|-----------|-----------|------------|-------------|------------|----------|----------|-----------|----------|\n")
for rows.Next() {
var name string
var year int
var gdp, gdpG, gdpPC, fr, frG, fi, fiG, rs, rsG, io, ioG, tr, ie, ieG, fdi, te, teR, pop, ui, ri float64
if err := rows.Scan(&name, &year, &gdp, &gdpG, &gdpPC, &fr, &frG, &fi, &fiG, &rs, &rsG, &io, &ioG, &tr, &ie, &ieG, &fdi, &te, &teR, &pop, &ui, &ri); err != nil {
continue
}
sb.WriteString(fmt.Sprintf("| %s | %d | %.1f | %.1f | %.0f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.2f | %.1f | %.2f | %.1f | %.0f | %.0f |\n",
name, year, gdp, gdpG, gdpPC, fr, frG, fi, fiG, rs, rsG, io, ioG, tr, ie, ieG, fdi, te, teR, pop, ui, ri))
}
return sb.String()
}
func (h *AnalysisTemplateHandler) GenerateReport(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var req generateAnalysisRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效请求格式")
return
}
if req.TemplateID == "" {
response.BadRequest(w, "请选择报告模板")
return
}
var appModel string
var appMaxTokens int
var appConfigRaw json.RawMessage
_ = h.pool.QueryRow(r.Context(), `
SELECT COALESCE(app_config->>'model', ''), COALESCE(max_tokens, 8192), COALESCE(app_config, '{}')
FROM applications WHERE id = $1`, appID,
).Scan(&appModel, &appMaxTokens, &appConfigRaw)
var appCfg struct {
Tools []string `json:"tools"`
DataSources []string `json:"data_sources"`
TemplateSet string `json:"template_set"`
}
_ = json.Unmarshal(appConfigRaw, &appCfg)
var promptTpl, tplName, reportType string
err := h.pool.QueryRow(r.Context(), `
SELECT name, prompt_template, report_type
FROM analysis_report_templates WHERE id = $1 AND is_active = true`, req.TemplateID,
).Scan(&tplName, &promptTpl, &reportType)
if err != nil {
response.NotFound(w, "模板不存在")
return
}
prompt := promptTpl
for k, v := range req.FieldData {
prompt = strings.ReplaceAll(prompt, "{{"+k+"}}", v)
}
prompt = strings.ReplaceAll(prompt, "{{", "")
prompt = strings.ReplaceAll(prompt, "}}", "")
hasDataTool := containsAny(appCfg.Tools, "economic_data_query", "data_analysis")
hasDataSource := containsAny(appCfg.DataSources, "regional_economic_data", "statistical_yearbook")
useEconomicData := hasDataTool || hasDataSource || reportType == "economic_comparison" || reportType == "trend_analysis"
if useEconomicData {
dataTable := h.queryEconomicData(r.Context(), req.FieldData["districts"], req.FieldData["time_range"])
if dataTable != "" {
prompt += "\n\n⚠️ 以下是从数据库中检索到的真实经济数据,请基于这些真实数据进行分析,不要编造数据:" + dataTable
}
}
llmReq := &llm.ChatRequest{
Model: appModel,
Messages: []llm.Message{
{Role: llm.RoleSystem, Content: "你是一位资深的政务数据分析和研判专家。要求:1)基于提供的真实数据库数据进行分析 2)报告必须完整输出到最后一个章节,不能中途截断 3)保持精炼,每个章节500字以内 4)使用Markdown格式排版 5)表格数据必须引用真实数据 6)报告输出完毕后,必须在最后一行单独输出「---\\n\\n**【完稿】**」作为完成标记"},
{Role: llm.RoleUser, Content: prompt},
},
Temperature: 0.4,
MaxTokens: appMaxTokens,
Stream: true,
}
startTime := time.Now()
body, err := h.manager.ChatStream(r.Context(), h.provider, llmReq)
if err != nil {
response.Error(w, http.StatusBadGateway, 50202, "模型服务不可用: "+err.Error())
return
}
defer body.Close()
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
flusher, ok := w.(http.Flusher)
if !ok {
response.InternalError(w, "Streaming not supported")
return
}
convID := uuid.New().String()
msgID := uuid.New().String()
var fullResponse strings.Builder
firstEvent := map[string]string{
"conversation_id": convID,
"message_id": msgID,
"template_name": tplName,
}
data, _ := json.Marshal(firstEvent)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
transform := llm.TransformOpenAIStream
if h.provider == "anthropic" {
transform = llm.TransformAnthropicStream
}
var totalTokens int
var modelName string
_ = transform(body, func(event llm.StreamEvent) {
if event.Answer != "" {
fullResponse.WriteString(event.Answer)
}
if event.Usage != nil {
totalTokens = event.Usage.TotalTokens
modelName = event.Usage.Model
}
data, _ := json.Marshal(event)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
})
fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
duration := time.Since(startTime).Milliseconds()
userMsg := fmt.Sprintf("[研判报告] %s", tplName)
go func() {
ctx := context.Background()
_, _ = h.pool.Exec(ctx, `
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, user_message, ai_response, total_tokens, model_name, duration_ms, client_type)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'web')`,
appID, userID, convID, userMsg, fullResponse.String(), totalTokens, modelName, duration)
_, _ = h.pool.Exec(ctx,
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
}()
}
+418
View File
@@ -0,0 +1,418 @@
package handler
import (
"context"
"encoding/json"
"net/http"
"time"
"github.com/enterprise-ai-platform/server/internal/middleware"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/enterprise-ai-platform/server/pkg/auth"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
type AuthHandler struct {
pool *pgxpool.Pool
jwtMgr *auth.JWTManager
}
func NewAuthHandler(pool *pgxpool.Pool, jwtMgr *auth.JWTManager) *AuthHandler {
return &AuthHandler{pool: pool, jwtMgr: jwtMgr}
}
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
OrgID string `json:"org_id"`
}
type orgInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
ShortName string `json:"short_name"`
}
type userResponse struct {
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL *string `json:"avatar_url"`
Role string `json:"role"`
EmployeeID *string `json:"employee_id"`
OrgID *string `json:"org_id"`
Org *orgInfo `json:"org,omitempty"`
}
type registerRequest struct {
Name string `json:"name"`
Email string `json:"email"`
Password string `json:"password"`
}
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
var req registerRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Name == "" || req.Email == "" || req.Password == "" {
response.BadRequest(w, "姓名、邮箱和密码不能为空")
return
}
if len(req.Password) < 6 {
response.BadRequest(w, "密码长度不能少于6位")
return
}
var exists bool
h.pool.QueryRow(r.Context(), `SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)`, req.Email).Scan(&exists)
if exists {
response.Error(w, http.StatusConflict, 40901, "该邮箱已注册")
return
}
hash, err := auth.HashPassword(req.Password)
if err != nil {
response.InternalError(w, "密码加密失败")
return
}
id := uuid.New()
_, err = h.pool.Exec(r.Context(),
`INSERT INTO users (id, name, email, password_hash, role, status) VALUES ($1, $2, $3, $4, 'user', 'active')`,
id, req.Name, req.Email, hash)
if err != nil {
response.InternalError(w, "注册失败")
return
}
tokenPair, err := h.jwtMgr.GenerateTokenPair(id, req.Email, "user")
if err != nil {
response.InternalError(w, "生成Token失败")
return
}
http.SetCookie(w, &http.Cookie{
Name: "access_token",
Value: tokenPair.AccessToken,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(24 * time.Hour / time.Second),
})
response.JSON(w, http.StatusCreated, map[string]any{
"user": userResponse{
ID: id.String(), Name: req.Name, Email: req.Email, Role: "user",
},
"access_token": tokenPair.AccessToken,
"expires_at": tokenPair.ExpiresAt,
})
}
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
var req loginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Email == "" || req.Password == "" {
response.BadRequest(w, "邮箱和密码不能为空")
return
}
var (
id string
name string
email string
passwordHash *string
avatarURL *string
role string
employeeID *string
status string
orgID *string
)
err := h.pool.QueryRow(r.Context(),
`SELECT id, name, email, password_hash, avatar_url, role, employee_id, status, org_id::text
FROM users WHERE email = $1`, req.Email,
).Scan(&id, &name, &email, &passwordHash, &avatarURL, &role, &employeeID, &status, &orgID)
if err != nil {
response.Unauthorized(w, "邮箱或密码错误")
return
}
if status != "active" {
response.Error(w, http.StatusForbidden, 40302, "账号已被禁用")
return
}
if passwordHash == nil || !auth.CheckPassword(req.Password, *passwordHash) {
response.Unauthorized(w, "邮箱或密码错误")
return
}
// 平台管理员不绑定机构,可登录任意机构入口
// 普通用户/机构管理员必须属于所选机构
if role != "super_admin" && req.OrgID != "" && orgID != nil && *orgID != req.OrgID {
response.Unauthorized(w, "该账号不属于所选机构,请选择正确的机构")
return
}
uid, _ := uuid.Parse(id)
tokenPair, err := h.jwtMgr.GenerateTokenPair(uid, email, role)
if err != nil {
response.InternalError(w, "生成Token失败")
return
}
// Update login info
go func() {
_, _ = h.pool.Exec(context.Background(),
`UPDATE users SET last_login_at = NOW(), login_count = login_count + 1 WHERE id = $1`, id)
}()
// Set cookies
http.SetCookie(w, &http.Cookie{
Name: "access_token",
Value: tokenPair.AccessToken,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(24 * time.Hour / time.Second),
})
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: tokenPair.RefreshToken,
Path: "/api/v1/auth/refresh",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(7 * 24 * time.Hour / time.Second),
})
usr := userResponse{
ID: id, Name: name, Email: email,
AvatarURL: avatarURL, Role: role, EmployeeID: employeeID,
OrgID: orgID,
}
if orgID != nil {
usr.Org = h.loadOrgInfo(r.Context(), *orgID)
}
response.JSON(w, http.StatusOK, map[string]any{
"user": usr,
"access_token": tokenPair.AccessToken,
"expires_at": tokenPair.ExpiresAt,
})
}
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: "access_token",
Value: "",
Path: "/",
HttpOnly: true,
MaxAge: -1,
})
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: "",
Path: "/api/v1/auth/refresh",
HttpOnly: true,
MaxAge: -1,
})
response.JSON(w, http.StatusOK, nil)
}
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
userID := middleware.GetUserID(r.Context())
var u userResponse
err := h.pool.QueryRow(r.Context(),
`SELECT id, name, email, avatar_url, role, employee_id, org_id::text
FROM users WHERE id = $1`, userID,
).Scan(&u.ID, &u.Name, &u.Email, &u.AvatarURL, &u.Role, &u.EmployeeID, &u.OrgID)
if err != nil {
response.NotFound(w, "用户不存在")
return
}
if u.OrgID != nil {
u.Org = h.loadOrgInfo(r.Context(), *u.OrgID)
}
response.JSON(w, http.StatusOK, u)
}
// ListOrganizations 返回所有可用机构列表
func (h *AuthHandler) ListOrganizations(w http.ResponseWriter, r *http.Request) {
rows, err := h.pool.Query(r.Context(),
`SELECT id, name, slug, COALESCE(short_name,''), COALESCE(description,''), COALESCE(logo_url,'')
FROM organizations WHERE is_active = true ORDER BY sort_order ASC`)
if err != nil {
response.InternalError(w, "查询机构列表失败")
return
}
defer rows.Close()
var orgs []map[string]any
for rows.Next() {
var id, name, slug, shortName, desc, logo string
if rows.Scan(&id, &name, &slug, &shortName, &desc, &logo) != nil {
continue
}
orgs = append(orgs, map[string]any{
"id": id, "name": name, "slug": slug,
"short_name": shortName, "description": desc, "logo_url": logo,
})
}
if orgs == nil {
orgs = []map[string]any{}
}
response.JSON(w, http.StatusOK, orgs)
}
// SwitchOrg 切换机构:管理员及以上可用,自动切换为目标机构的管理员身份
func (h *AuthHandler) SwitchOrg(w http.ResponseWriter, r *http.Request) {
currentRole := middleware.GetRole(r.Context())
if currentRole != "super_admin" && currentRole != "admin" {
response.Forbidden(w, "仅管理员可以切换机构")
return
}
var req struct {
OrgID string `json:"org_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.OrgID == "" {
response.BadRequest(w, "请选择机构")
return
}
// 校验机构是否存在
var exists bool
h.pool.QueryRow(r.Context(),
`SELECT EXISTS(SELECT 1 FROM organizations WHERE id = $1 AND is_active = true)`, req.OrgID).Scan(&exists)
if !exists {
response.NotFound(w, "机构不存在")
return
}
// 查找目标机构的管理员用户(优先admin,其次super_admin,最后creator
var targetID uuid.UUID
var targetEmail, targetRole, targetName string
err := h.pool.QueryRow(r.Context(),
`SELECT id, email, role, name FROM users
WHERE org_id = $1 AND status = 'active'
ORDER BY CASE role WHEN 'admin' THEN 1 WHEN 'super_admin' THEN 2 WHEN 'creator' THEN 3 ELSE 4 END
LIMIT 1`, req.OrgID).Scan(&targetID, &targetEmail, &targetRole, &targetName)
if err != nil {
response.InternalError(w, "该机构暂无可用用户")
return
}
// 为目标用户生成新的JWT token
tokens, err := h.jwtMgr.GenerateTokenPair(targetID, targetEmail, targetRole)
if err != nil {
response.InternalError(w, "生成令牌失败")
return
}
org := h.loadOrgInfo(r.Context(), req.OrgID)
response.JSON(w, http.StatusOK, map[string]any{
"message": "已切换",
"org": org,
"token": tokens.AccessToken,
"user": map[string]any{
"id": targetID,
"name": targetName,
"email": targetEmail,
"role": targetRole,
"org_id": req.OrgID,
},
})
}
func (h *AuthHandler) loadOrgInfo(ctx context.Context, orgID string) *orgInfo {
var o orgInfo
err := h.pool.QueryRow(ctx,
`SELECT id, name, slug, COALESCE(short_name,'') FROM organizations WHERE id = $1`, orgID,
).Scan(&o.ID, &o.Name, &o.Slug, &o.ShortName)
if err != nil {
return nil
}
return &o
}
func (h *AuthHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) {
userID := middleware.GetUserID(r.Context())
var req struct {
Name string `json:"name"`
AvatarURL string `json:"avatar_url"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
_, err := h.pool.Exec(r.Context(),
`UPDATE users SET name = COALESCE(NULLIF($2,''), name),
avatar_url = COALESCE(NULLIF($3,''), avatar_url),
updated_at = NOW()
WHERE id = $1`, userID, req.Name, req.AvatarURL)
if err != nil {
response.InternalError(w, "更新个人信息失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
func (h *AuthHandler) Refresh(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("refresh_token")
if err != nil {
response.Unauthorized(w, "Refresh Token 不存在")
return
}
claims, err := h.jwtMgr.ValidateToken(cookie.Value)
if err != nil {
response.Error(w, http.StatusUnauthorized, 40102, "Refresh Token 已过期")
return
}
// Get current role from DB
var role string
err = h.pool.QueryRow(r.Context(),
`SELECT role FROM users WHERE id = $1 AND status = 'active'`, claims.UserID,
).Scan(&role)
if err != nil {
response.Unauthorized(w, "用户不存在或已被禁用")
return
}
tokenPair, err := h.jwtMgr.GenerateTokenPair(claims.UserID, claims.Email, role)
if err != nil {
response.InternalError(w, "生成Token失败")
return
}
http.SetCookie(w, &http.Cookie{
Name: "access_token",
Value: tokenPair.AccessToken,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(24 * time.Hour / time.Second),
})
response.JSON(w, http.StatusOK, map[string]any{
"access_token": tokenPair.AccessToken,
"expires_at": tokenPair.ExpiresAt,
})
}
+307
View File
@@ -0,0 +1,307 @@
package handler
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/enterprise-ai-platform/server/internal/middleware"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/enterprise-ai-platform/server/pkg/dify"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
type ChatHandler struct {
pool *pgxpool.Pool
dify *dify.Client
}
func NewChatHandler(pool *pgxpool.Pool, difyClient *dify.Client) *ChatHandler {
return &ChatHandler{pool: pool, dify: difyClient}
}
type chatRequest struct {
Message string `json:"message"`
ConversationID string `json:"conversation_id,omitempty"`
Inputs map[string]any `json:"inputs,omitempty"`
}
func (h *ChatHandler) Chat(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var req chatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Message == "" {
response.BadRequest(w, "消息不能为空")
return
}
// Get app's Dify API key
var difyAPIKey string
err := h.pool.QueryRow(r.Context(),
`SELECT dify_api_key FROM applications WHERE id = $1 AND status = 'approved'`,
appID,
).Scan(&difyAPIKey)
if err != nil || difyAPIKey == "" {
response.NotFound(w, "应用不存在或未上架")
return
}
startTime := time.Now()
// Call Dify streaming chat
difyReq := &dify.ChatRequest{
Query: req.Message,
Inputs: req.Inputs,
ConversationID: req.ConversationID,
User: userID.String(),
ResponseMode: "streaming",
}
body, err := h.dify.ChatStream(r.Context(), difyAPIKey, difyReq)
if err != nil {
response.Error(w, http.StatusBadGateway, 50201, "Dify 服务不可用: "+err.Error())
return
}
defer body.Close()
// Stream SSE to client
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
flusher, ok := w.(http.Flusher)
if !ok {
response.InternalError(w, "Streaming not supported")
return
}
scanner := bufio.NewScanner(body)
scanner.Buffer(make([]byte, 64*1024), 256*1024)
var totalTokens int
var modelName string
var conversationID string
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
break
}
// Forward SSE event to client
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
// Parse for usage tracking
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err == nil {
if cid, ok := event["conversation_id"].(string); ok && cid != "" {
conversationID = cid
}
if event["event"] == "message_end" {
if meta, ok := event["metadata"].(map[string]any); ok {
if usage, ok := meta["usage"].(map[string]any); ok {
if t, ok := usage["total_tokens"].(float64); ok {
totalTokens = int(t)
}
if m, ok := usage["model"].(string); ok {
modelName = m
}
}
}
}
}
}
// Record usage asynchronously
duration := time.Since(startTime).Milliseconds()
go func() {
_, _ = h.pool.Exec(context.Background(), `
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, total_tokens, model_name, duration_ms, client_type)
VALUES ($1, $2, $3, $4, $5, $6, 'web')`,
appID, userID, conversationID, totalTokens, modelName, duration)
_, _ = h.pool.Exec(context.Background(),
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
}()
}
func (h *ChatHandler) Completion(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var req chatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Message == "" {
response.BadRequest(w, "消息不能为空")
return
}
var difyAPIKey string
err := h.pool.QueryRow(r.Context(),
`SELECT dify_api_key FROM applications WHERE id = $1 AND status = 'approved'`,
appID,
).Scan(&difyAPIKey)
if err != nil || difyAPIKey == "" {
response.NotFound(w, "应用不存在或未上架")
return
}
startTime := time.Now()
difyReq := &dify.ChatRequest{
Query: req.Message,
Inputs: req.Inputs,
ConversationID: req.ConversationID,
User: userID.String(),
ResponseMode: "blocking",
}
result, err := h.dify.ChatBlocking(r.Context(), difyAPIKey, difyReq)
if err != nil {
response.Error(w, http.StatusBadGateway, 50201, "Dify 服务不可用: "+err.Error())
return
}
duration := time.Since(startTime).Milliseconds()
go func() {
_, _ = h.pool.Exec(context.Background(), `
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, total_tokens, duration_ms, client_type)
VALUES ($1, $2, $3, $4, $5, 'web')`,
appID, userID, result.ConversationID, 0, duration)
_, _ = h.pool.Exec(context.Background(),
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
}()
response.JSON(w, http.StatusOK, result)
}
func (h *ChatHandler) Feedback(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var req struct {
MessageID string `json:"message_id"`
Rating string `json:"rating"` // "like" | "dislike" | null
Comment string `json:"comment"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.MessageID == "" {
response.BadRequest(w, "message_id 不能为空")
return
}
var difyAPIKey string
err := h.pool.QueryRow(r.Context(),
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
).Scan(&difyAPIKey)
if err != nil || difyAPIKey == "" {
response.NotFound(w, "应用不存在")
return
}
feedbackReq := &dify.FeedbackRequest{
Rating: req.Rating,
User: userID.String(),
}
if err := h.dify.SubmitFeedback(r.Context(), difyAPIKey, req.MessageID, feedbackReq); err != nil {
response.Error(w, http.StatusBadGateway, 50201, "提交反馈失败: "+err.Error())
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "反馈已提交"})
}
func (h *ChatHandler) Conversations(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var difyAPIKey string
err := h.pool.QueryRow(r.Context(),
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
).Scan(&difyAPIKey)
if err != nil || difyAPIKey == "" {
response.NotFound(w, "应用不存在")
return
}
result, err := h.dify.ListConversations(r.Context(), difyAPIKey, userID.String(), 20, "")
if err != nil {
response.Error(w, http.StatusBadGateway, 50201, "获取对话列表失败")
return
}
response.JSON(w, http.StatusOK, result)
}
func (h *ChatHandler) Messages(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
convID := chi.URLParam(r, "convId")
userID := middleware.GetUserID(r.Context())
var difyAPIKey string
err := h.pool.QueryRow(r.Context(),
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
).Scan(&difyAPIKey)
if err != nil || difyAPIKey == "" {
response.NotFound(w, "应用不存在")
return
}
result, err := h.dify.ListMessages(r.Context(), difyAPIKey, userID.String(), convID, 100, "")
if err != nil {
response.Error(w, http.StatusBadGateway, 50201, "获取消息列表失败")
return
}
response.JSON(w, http.StatusOK, result)
}
func (h *ChatHandler) DeleteConversation(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
convID := chi.URLParam(r, "convId")
userID := middleware.GetUserID(r.Context())
var difyAPIKey string
err := h.pool.QueryRow(r.Context(),
`SELECT dify_api_key FROM applications WHERE id = $1`, appID,
).Scan(&difyAPIKey)
if err != nil || difyAPIKey == "" {
response.NotFound(w, "应用不存在")
return
}
if err := h.dify.DeleteConversation(r.Context(), difyAPIKey, userID.String(), convID); err != nil {
response.Error(w, http.StatusBadGateway, 50201, "删除对话失败")
return
}
response.JSON(w, http.StatusOK, nil)
}
// Suppress unused import warnings
var _ = io.EOF
File diff suppressed because it is too large Load Diff
+513
View File
@@ -0,0 +1,513 @@
package handler
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/enterprise-ai-platform/server/internal/middleware"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/enterprise-ai-platform/server/pkg/dify"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
type CreatorHandler struct {
pool *pgxpool.Pool
dify *dify.Client
}
func NewCreatorHandler(pool *pgxpool.Pool, difyClient *dify.Client) *CreatorHandler {
return &CreatorHandler{pool: pool, dify: difyClient}
}
type createAppRequest struct {
Name string `json:"name"`
Description string `json:"description"`
LongDescription string `json:"long_description"`
CategoryID string `json:"category_id"`
Visibility string `json:"visibility"`
AppType string `json:"app_type"`
SystemPrompt string `json:"system_prompt"`
WelcomeMessage string `json:"welcome_message"`
SuggestedPrompts []string `json:"suggested_prompts"`
Model string `json:"model"`
Temperature float32 `json:"temperature"`
MaxTokens int `json:"max_tokens"`
KnowledgeBaseIDs []string `json:"knowledge_base_ids"`
// Agent-type config
Tools []string `json:"tools"`
DataSources []string `json:"data_sources"`
TemplateSet string `json:"template_set"`
// Completion-type config
InputLabel string `json:"input_label"`
OutputLabel string `json:"output_label"`
InputPlaceholder string `json:"input_placeholder"`
FormatTemplates map[string]any `json:"format_templates"`
}
func buildAppConfig(req *createAppRequest) json.RawMessage {
cfg := map[string]any{
"system_prompt": req.SystemPrompt,
"model": req.Model,
}
if len(req.Tools) > 0 {
cfg["tools"] = req.Tools
}
if len(req.DataSources) > 0 {
cfg["data_sources"] = req.DataSources
}
if req.TemplateSet != "" {
cfg["template_set"] = req.TemplateSet
}
if req.InputLabel != "" {
cfg["input_label"] = req.InputLabel
}
if req.OutputLabel != "" {
cfg["output_label"] = req.OutputLabel
}
if req.InputPlaceholder != "" {
cfg["input_placeholder"] = req.InputPlaceholder
}
if len(req.FormatTemplates) > 0 {
cfg["format_templates"] = req.FormatTemplates
}
b, _ := json.Marshal(cfg)
return b
}
func (h *CreatorHandler) ListMyApps(w http.ResponseWriter, r *http.Request) {
userID := middleware.GetUserID(r.Context())
role := middleware.GetRole(r.Context())
isAdmin := role == "admin" || role == "super_admin"
var rows pgx.Rows
var err error
if isAdmin {
// 管理员查看本机构所有应用(通过用户表获取org_id)
rows, err = h.pool.Query(r.Context(), `
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
c.name as category_name, a.dify_app_type, a.status, a.visibility,
a.usage_count, a.updated_at
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
WHERE a.org_id = (SELECT org_id FROM users WHERE id = $1)
ORDER BY a.updated_at DESC`, userID)
} else {
rows, err = h.pool.Query(r.Context(), `
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
c.name as category_name, a.dify_app_type, a.status, a.visibility,
a.usage_count, a.updated_at
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
WHERE a.creator_id = $1
ORDER BY a.updated_at DESC`, userID)
}
if err != nil {
response.InternalError(w, "查询应用失败")
return
}
defer rows.Close()
var apps []map[string]any
for rows.Next() {
var (
id, name, slug, status, visibility string
desc, iconURL, catName, appType *string
usageCount int64
updatedAt time.Time
)
if err := rows.Scan(&id, &name, &slug, &desc, &iconURL, &catName,
&appType, &status, &visibility, &usageCount, &updatedAt); err != nil {
continue
}
apps = append(apps, map[string]any{
"id": id, "name": name, "slug": slug, "description": desc,
"icon_url": iconURL, "category_name": catName,
"dify_app_type": appType,
"status": status, "visibility": visibility,
"usage_count": usageCount, "updated_at": updatedAt,
})
}
if apps == nil {
apps = []map[string]any{}
}
response.JSON(w, http.StatusOK, apps)
}
func (h *CreatorHandler) GetApp(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
role := middleware.GetRole(r.Context())
isAdmin := role == "admin" || role == "super_admin"
var (
id, name, slug, status, visibility, version string
desc, longDesc, iconURL, catID, difyType *string
welcomeMsg *string
kbID *string
appConfig json.RawMessage
suggestedPrompts json.RawMessage
maxTokens int
temperature float32
usageCount int64
createdAt, updatedAt time.Time
)
var query string
var args []any
if isAdmin {
query = `SELECT a.id, a.name, a.slug, a.description, a.long_description,
a.icon_url, a.category_id, a.dify_app_type,
a.app_config, a.welcome_message, a.suggested_prompts,
a.max_tokens, a.temperature, a.status, a.visibility,
a.is_featured, a.usage_count, a.version,
a.knowledge_base_id, a.created_at, a.updated_at
FROM applications a WHERE a.id = $1`
args = []any{appID}
} else {
query = `SELECT a.id, a.name, a.slug, a.description, a.long_description,
a.icon_url, a.category_id, a.dify_app_type,
a.app_config, a.welcome_message, a.suggested_prompts,
a.max_tokens, a.temperature, a.status, a.visibility,
a.is_featured, a.usage_count, a.version,
a.knowledge_base_id, a.created_at, a.updated_at
FROM applications a WHERE a.id = $1 AND a.creator_id = $2`
args = []any{appID, userID}
}
err := h.pool.QueryRow(r.Context(), query, args...).Scan(
&id, &name, &slug, &desc, &longDesc,
&iconURL, &catID, &difyType,
&appConfig, &welcomeMsg, &suggestedPrompts,
&maxTokens, &temperature, &status, &visibility,
new(bool), &usageCount, &version,
&kbID, &createdAt, &updatedAt)
if err != nil {
response.NotFound(w, "应用不存在或无权访问")
return
}
result := map[string]any{
"id": id, "name": name, "slug": slug, "description": desc,
"long_description": longDesc, "icon_url": iconURL,
"category_id": catID, "dify_app_type": difyType,
"app_config": appConfig, "welcome_message": welcomeMsg,
"suggested_prompts": suggestedPrompts,
"max_tokens": maxTokens, "temperature": temperature,
"status": status, "visibility": visibility,
"usage_count": usageCount, "version": version,
"knowledge_base_id": kbID,
"created_at": createdAt, "updated_at": updatedAt,
}
response.JSON(w, http.StatusOK, result)
}
func (h *CreatorHandler) CreateApp(w http.ResponseWriter, r *http.Request) {
userID := middleware.GetUserID(r.Context())
var req createAppRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Name == "" {
response.BadRequest(w, "应用名称不能为空")
return
}
if req.Visibility == "" {
req.Visibility = "private"
}
if req.AppType == "" {
req.AppType = "chatbot"
}
if req.Temperature == 0 {
req.Temperature = 0.7
}
if req.MaxTokens == 0 {
req.MaxTokens = 4096
}
slug := generateSlug(req.Name)
suggestedPromptsJSON, _ := json.Marshal(req.SuggestedPrompts)
appConfig := buildAppConfig(&req)
difyAppID := ""
difyAPIKey := ""
var appID string
err := h.pool.QueryRow(r.Context(), `
INSERT INTO applications (
name, slug, description, long_description, category_id, creator_id,
dify_app_id, dify_app_type, dify_api_key,
app_config, welcome_message, suggested_prompts,
max_tokens, temperature, status, visibility
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, 'draft', $15)
RETURNING id`,
req.Name, slug, req.Description, req.LongDescription,
nilIfEmpty(req.CategoryID), userID,
nilIfEmpty(difyAppID), req.AppType, nilIfEmpty(difyAPIKey),
appConfig, req.WelcomeMessage, string(suggestedPromptsJSON),
req.MaxTokens, req.Temperature, req.Visibility,
).Scan(&appID)
if err != nil {
if strings.Contains(err.Error(), "duplicate key") {
response.Error(w, http.StatusConflict, 40901, "应用名称已存在")
return
}
response.InternalError(w, "创建应用失败: "+err.Error())
return
}
response.JSON(w, http.StatusCreated, map[string]any{
"id": appID,
"name": req.Name,
"slug": slug,
"status": "draft",
})
}
func (h *CreatorHandler) UpdateApp(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
role := middleware.GetRole(r.Context())
isAdmin := role == "admin" || role == "super_admin"
var status string
var creatorID string
err := h.pool.QueryRow(r.Context(),
`SELECT status, creator_id FROM applications WHERE id = $1`, appID).Scan(&status, &creatorID)
if err != nil {
response.NotFound(w, "应用不存在")
return
}
if !isAdmin && creatorID != userID.String() {
response.Forbidden(w, "只能修改自己创建的应用")
return
}
var req struct {
createAppRequest
KnowledgeBaseID string `json:"knowledge_base_id"`
AppType string `json:"app_type"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
suggestedPromptsJSON, _ := json.Marshal(req.SuggestedPrompts)
appConfig := buildAppConfig(&req.createAppRequest)
newStatus := status
if status == "draft" || status == "rejected" {
newStatus = "draft"
}
_, err = h.pool.Exec(r.Context(), `
UPDATE applications SET
name = COALESCE(NULLIF($2, ''), name),
description = COALESCE(NULLIF($3, ''), description),
long_description = $4,
category_id = COALESCE($5::UUID, category_id),
app_config = $6,
welcome_message = $7,
suggested_prompts = $8,
max_tokens = $9,
temperature = $10,
visibility = COALESCE(NULLIF($11, ''), visibility),
knowledge_base_id = $12::UUID,
dify_app_type = COALESCE(NULLIF($13, ''), dify_app_type),
status = $14
WHERE id = $1`,
appID, req.Name, req.Description, req.LongDescription,
nilIfEmpty(req.CategoryID),
appConfig, req.WelcomeMessage, string(suggestedPromptsJSON),
req.MaxTokens, req.Temperature, req.Visibility,
nilIfEmpty(req.KnowledgeBaseID), req.AppType, newStatus,
)
if err != nil {
response.InternalError(w, "更新应用失败: "+err.Error())
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "更新成功"})
}
func (h *CreatorHandler) DeleteApp(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
role := middleware.GetRole(r.Context())
isAdmin := role == "admin" || role == "super_admin"
var tag pgconn.CommandTag
var err error
if isAdmin {
tag, err = h.pool.Exec(r.Context(),
`DELETE FROM applications WHERE id = $1`, appID)
} else {
tag, err = h.pool.Exec(r.Context(),
`DELETE FROM applications WHERE id = $1 AND creator_id = $2 AND status = 'draft'`,
appID, userID)
}
if err != nil {
response.InternalError(w, "删除失败")
return
}
if tag.RowsAffected() == 0 {
if isAdmin {
response.NotFound(w, "应用不存在")
} else {
response.BadRequest(w, "只能删除草稿状态的应用")
}
return
}
response.JSON(w, http.StatusOK, nil)
}
func (h *CreatorHandler) SubmitReview(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var status, creatorID, version string
err := h.pool.QueryRow(r.Context(),
`SELECT status, creator_id, version FROM applications WHERE id = $1`, appID,
).Scan(&status, &creatorID, &version)
if err != nil {
response.NotFound(w, "应用不存在")
return
}
if creatorID != userID.String() {
response.Forbidden(w, "只能提交自己创建的应用")
return
}
if status != "draft" && status != "rejected" {
response.BadRequest(w, "只有草稿或被驳回的应用可以提交审核")
return
}
var req struct {
Comment string `json:"comment"`
}
json.NewDecoder(r.Body).Decode(&req)
tx, err := h.pool.Begin(r.Context())
if err != nil {
response.InternalError(w, "事务开始失败")
return
}
defer tx.Rollback(r.Context())
tx.Exec(r.Context(), `
INSERT INTO app_reviews (app_id, version, submitter_id, submit_comment)
VALUES ($1, $2, $3, $4)`, appID, version, userID, req.Comment)
tx.Exec(r.Context(), `
UPDATE applications SET status = 'pending_review' WHERE id = $1`, appID)
if err := tx.Commit(r.Context()); err != nil {
response.InternalError(w, "提交审核失败")
return
}
response.JSON(w, http.StatusOK, nil)
}
func (h *CreatorHandler) WithdrawReview(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var creatorID string
h.pool.QueryRow(r.Context(), `SELECT creator_id FROM applications WHERE id = $1`, appID).Scan(&creatorID)
if creatorID != userID.String() {
response.Forbidden(w, "只能撤回自己的审核")
return
}
tx, err := h.pool.Begin(r.Context())
if err != nil {
response.InternalError(w, "事务开始失败")
return
}
defer tx.Rollback(r.Context())
tx.Exec(r.Context(), `
UPDATE app_reviews SET status = 'withdrawn'
WHERE app_id = $1 AND status = 'pending'`, appID)
tx.Exec(r.Context(), `
UPDATE applications SET status = 'draft' WHERE id = $1 AND status = 'pending_review'`, appID)
tx.Commit(r.Context())
response.JSON(w, http.StatusOK, nil)
}
func (h *CreatorHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
rows, err := h.pool.Query(r.Context(), `
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
c.name as category_name, c.slug as category_slug,
a.usage_count, a.avg_rating, a.rating_count
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
WHERE a.is_template = true AND a.status = 'approved'
ORDER BY a.usage_count DESC`)
if err != nil {
response.InternalError(w, "查询模板失败")
return
}
defer rows.Close()
apps := scanAppList(rows)
response.JSON(w, http.StatusOK, apps)
}
func generateSlug(name string) string {
slug := strings.ToLower(strings.TrimSpace(name))
slug = strings.ReplaceAll(slug, " ", "-")
return fmt.Sprintf("%s-%d", slug, time.Now().UnixMilli()%10000)
}
func (h *CreatorHandler) RequestDelist(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var status, creatorID string
err := h.pool.QueryRow(r.Context(),
`SELECT status, creator_id FROM applications WHERE id = $1`, appID).Scan(&status, &creatorID)
if err != nil {
response.NotFound(w, "应用不存在")
return
}
if creatorID != userID.String() {
response.Forbidden(w, "只能操作自己创建的应用")
return
}
if status != "approved" {
response.BadRequest(w, "只有已上架的应用可以申请下架")
return
}
_, err = h.pool.Exec(r.Context(),
`UPDATE applications SET status = 'archived' WHERE id = $1`, appID)
if err != nil {
response.InternalError(w, "申请下架失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已下架"})
}
func nilIfEmpty(s string) *string {
if s == "" {
return nil
}
return &s
}
+242
View File
@@ -0,0 +1,242 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/enterprise-ai-platform/server/internal/middleware"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/enterprise-ai-platform/server/pkg/llm"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
type DocTemplateHandler struct {
pool *pgxpool.Pool
manager *llm.Manager
provider string
}
func NewDocTemplateHandler(pool *pgxpool.Pool, manager *llm.Manager, provider string) *DocTemplateHandler {
return &DocTemplateHandler{pool: pool, manager: manager, provider: provider}
}
func (h *DocTemplateHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
orgID := r.URL.Query().Get("org_id")
query := `SELECT id, name, doc_type, COALESCE(description,''), COALESCE(icon,''), fields, sort_order
FROM document_templates
WHERE is_active = true`
var args []any
if orgID != "" {
query += ` AND (org_id = $1 OR org_id IS NULL)`
args = append(args, orgID)
}
query += ` ORDER BY sort_order ASC`
rows, err := h.pool.Query(r.Context(), query, args...)
if err != nil {
response.InternalError(w, "查询模板失败")
return
}
defer rows.Close()
var templates []map[string]any
for rows.Next() {
var id, name, docType, desc, icon string
var fields json.RawMessage
var sortOrder int
if err := rows.Scan(&id, &name, &docType, &desc, &icon, &fields, &sortOrder); err != nil {
continue
}
templates = append(templates, map[string]any{
"id": id,
"name": name,
"doc_type": docType,
"description": desc,
"icon": icon,
"fields": fields,
"sort_order": sortOrder,
})
}
if templates == nil {
templates = []map[string]any{}
}
response.JSON(w, http.StatusOK, map[string]any{"data": templates})
}
func (h *DocTemplateHandler) GetTemplate(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "templateId")
var name, docType, desc, icon, formatStd, promptTpl string
var fields json.RawMessage
var exampleOutput *string
err := h.pool.QueryRow(r.Context(), `
SELECT name, doc_type, COALESCE(description,''), COALESCE(icon,''),
format_standard, fields, prompt_template, example_output
FROM document_templates WHERE id = $1 AND is_active = true`, id,
).Scan(&name, &docType, &desc, &icon, &formatStd, &fields, &promptTpl, &exampleOutput)
if err != nil {
response.NotFound(w, "模板不存在")
return
}
result := map[string]any{
"id": id,
"name": name,
"doc_type": docType,
"description": desc,
"icon": icon,
"format_standard": formatStd,
"fields": fields,
"prompt_template": promptTpl,
}
if exampleOutput != nil {
result["example_output"] = *exampleOutput
}
response.JSON(w, http.StatusOK, result)
}
type generateDocRequest struct {
TemplateID string `json:"template_id"`
FieldData map[string]string `json:"field_data"`
}
func (h *DocTemplateHandler) GenerateDocument(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var req generateDocRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效请求格式")
return
}
if req.TemplateID == "" {
response.BadRequest(w, "请选择公文模板")
return
}
// Load app-specific model and max_tokens
var appModel string
var appMaxTokens int
_ = h.pool.QueryRow(r.Context(), `
SELECT COALESCE(app_config->>'model', ''), COALESCE(max_tokens, 4096)
FROM applications WHERE id = $1`, appID,
).Scan(&appModel, &appMaxTokens)
var promptTpl, tplName string
var fields json.RawMessage
err := h.pool.QueryRow(r.Context(), `
SELECT name, fields, prompt_template
FROM document_templates WHERE id = $1 AND is_active = true`, req.TemplateID,
).Scan(&tplName, &fields, &promptTpl)
if err != nil {
response.NotFound(w, "模板不存在")
return
}
var fieldDefs []struct {
Key string `json:"key"`
Label string `json:"label"`
Required bool `json:"required"`
}
_ = json.Unmarshal(fields, &fieldDefs)
for _, f := range fieldDefs {
if f.Required {
if val, ok := req.FieldData[f.Key]; !ok || strings.TrimSpace(val) == "" {
response.BadRequest(w, fmt.Sprintf("请填写必填项:%s", f.Label))
return
}
}
}
prompt := promptTpl
for k, v := range req.FieldData {
prompt = strings.ReplaceAll(prompt, "{{"+k+"}}", v)
}
prompt = strings.ReplaceAll(prompt, "{{", "")
prompt = strings.ReplaceAll(prompt, "}}", "")
llmReq := &llm.ChatRequest{
Model: appModel,
Messages: []llm.Message{
{Role: llm.RoleSystem, Content: "你是一个专业的政务公文写作专家,精通《党政机关公文格式》国家标准(GB/T 9704)和《党政机关公文处理工作条例》。请严格按照规范格式生成公文,确保行文庄重、严谨、准确。输出完整公文内容,使用Markdown格式排版。"},
{Role: llm.RoleUser, Content: prompt},
},
Temperature: 0.3,
MaxTokens: appMaxTokens,
Stream: true,
}
startTime := time.Now()
body, err := h.manager.ChatStream(r.Context(), h.provider, llmReq)
if err != nil {
response.Error(w, http.StatusBadGateway, 50202, "模型服务不可用: "+err.Error())
return
}
defer body.Close()
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
flusher, ok := w.(http.Flusher)
if !ok {
response.InternalError(w, "Streaming not supported")
return
}
convID := uuid.New().String()
msgID := uuid.New().String()
var fullResponse strings.Builder
firstEvent := map[string]string{
"conversation_id": convID,
"message_id": msgID,
"template_name": tplName,
}
data, _ := json.Marshal(firstEvent)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
transform := llm.TransformOpenAIStream
if h.provider == "anthropic" {
transform = llm.TransformAnthropicStream
}
var totalTokens int
var modelName string
_ = transform(body, func(event llm.StreamEvent) {
if event.Answer != "" {
fullResponse.WriteString(event.Answer)
}
if event.Usage != nil {
totalTokens = event.Usage.TotalTokens
modelName = event.Usage.Model
}
data, _ := json.Marshal(event)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
})
fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
duration := time.Since(startTime).Milliseconds()
userMsg := fmt.Sprintf("[公文生成] %s", tplName)
go func() {
ctx := context.Background()
_, _ = h.pool.Exec(ctx, `
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, user_message, ai_response, total_tokens, model_name, duration_ms, client_type)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'web')`,
appID, userID, convID, userMsg, fullResponse.String(), totalTokens, modelName, duration)
_, _ = h.pool.Exec(ctx,
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
}()
}
+217
View File
@@ -0,0 +1,217 @@
package handler
import (
"encoding/json"
"net/http"
"strconv"
"github.com/enterprise-ai-platform/server/internal/middleware"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
type FavoriteHandler struct {
pool *pgxpool.Pool
}
func NewFavoriteHandler(pool *pgxpool.Pool) *FavoriteHandler {
return &FavoriteHandler{pool: pool}
}
func (h *FavoriteHandler) AddFavorite(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
_, err := h.pool.Exec(r.Context(),
`INSERT INTO app_favorites (user_id, app_id) VALUES ($1, $2) ON CONFLICT DO NOTHING`,
userID, appID)
if err != nil {
response.InternalError(w, "收藏失败")
return
}
h.pool.Exec(r.Context(),
`UPDATE applications SET favorite_count = favorite_count + 1 WHERE id = $1`, appID)
response.JSON(w, http.StatusOK, map[string]bool{"favorited": true})
}
func (h *FavoriteHandler) RemoveFavorite(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
tag, err := h.pool.Exec(r.Context(),
`DELETE FROM app_favorites WHERE user_id = $1 AND app_id = $2`,
userID, appID)
if err != nil {
response.InternalError(w, "取消收藏失败")
return
}
if tag.RowsAffected() > 0 {
h.pool.Exec(r.Context(),
`UPDATE applications SET favorite_count = GREATEST(favorite_count - 1, 0) WHERE id = $1`, appID)
}
response.JSON(w, http.StatusOK, map[string]bool{"favorited": false})
}
func (h *FavoriteHandler) ListFavorites(w http.ResponseWriter, r *http.Request) {
userID := middleware.GetUserID(r.Context())
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
if page < 1 {
page = 1
}
offset := (page - 1) * 20
rows, err := h.pool.Query(r.Context(), `
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
c.name as category_name, c.slug as category_slug,
a.usage_count, a.avg_rating, a.rating_count
FROM app_favorites f
JOIN applications a ON f.app_id = a.id
LEFT JOIN categories c ON a.category_id = c.id
WHERE f.user_id = $1
ORDER BY f.created_at DESC
LIMIT 20 OFFSET $2`, userID, offset)
if err != nil {
response.InternalError(w, "查询收藏失败")
return
}
defer rows.Close()
apps := scanAppList(rows)
response.JSON(w, http.StatusOK, apps)
}
// --- Rating ---
type ratingRequest struct {
Score int `json:"score"`
Comment string `json:"comment"`
}
func (h *FavoriteHandler) AddRating(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var req ratingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Score < 1 || req.Score > 5 {
response.BadRequest(w, "评分必须在1-5之间")
return
}
_, err := h.pool.Exec(r.Context(), `
INSERT INTO app_ratings (app_id, user_id, score, comment)
VALUES ($1, $2, $3, $4)
ON CONFLICT (app_id, user_id)
DO UPDATE SET score = EXCLUDED.score, comment = EXCLUDED.comment`,
appID, userID, req.Score, req.Comment)
if err != nil {
response.InternalError(w, "评分失败")
return
}
// Update app avg rating
var avgRating float32
var ratingCount int
h.pool.QueryRow(r.Context(),
`SELECT COALESCE(AVG(score)::REAL, 0), COUNT(*) FROM app_ratings WHERE app_id = $1`,
appID).Scan(&avgRating, &ratingCount)
h.pool.Exec(r.Context(),
`UPDATE applications SET avg_rating = $2, rating_count = $3 WHERE id = $1`,
appID, avgRating, ratingCount)
response.JSON(w, http.StatusOK, map[string]any{
"avg_rating": avgRating,
"rating_count": ratingCount,
})
}
func (h *FavoriteHandler) ListRatings(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
rows, err := h.pool.Query(r.Context(), `
SELECT r.id, r.score, r.comment, r.created_at,
u.name as user_name, u.avatar_url
FROM app_ratings r
JOIN users u ON r.user_id = u.id
WHERE r.app_id = $1
ORDER BY r.created_at DESC LIMIT 50`, appID)
if err != nil {
response.InternalError(w, "查询评分失败")
return
}
defer rows.Close()
var ratings []map[string]any
for rows.Next() {
var id string
var score int
var comment *string
var createdAt string
var userName string
var avatarURL *string
if err := rows.Scan(&id, &score, &comment, &createdAt, &userName, &avatarURL); err != nil {
continue
}
ratings = append(ratings, map[string]any{
"id": id, "score": score, "comment": comment,
"created_at": createdAt, "user_name": userName, "user_avatar": avatarURL,
})
}
if ratings == nil {
ratings = []map[string]any{}
}
response.JSON(w, http.StatusOK, ratings)
}
func (h *FavoriteHandler) PersonalStats(w http.ResponseWriter, r *http.Request) {
userID := middleware.GetUserID(r.Context())
var totalConversations, totalTokens, favoriteCount int
h.pool.QueryRow(r.Context(),
`SELECT COUNT(*) FROM app_usage_logs WHERE user_id = $1`, userID).Scan(&totalConversations)
h.pool.QueryRow(r.Context(),
`SELECT COALESCE(SUM(total_tokens), 0) FROM app_usage_logs WHERE user_id = $1`, userID).Scan(&totalTokens)
h.pool.QueryRow(r.Context(),
`SELECT COUNT(*) FROM app_favorites WHERE user_id = $1`, userID).Scan(&favoriteCount)
var recentApps []map[string]any
rows, err := h.pool.Query(r.Context(), `
SELECT DISTINCT ON (l.app_id) a.id, a.name, a.icon_url, l.created_at
FROM app_usage_logs l
JOIN applications a ON l.app_id = a.id
WHERE l.user_id = $1
ORDER BY l.app_id, l.created_at DESC
LIMIT 5`, userID)
if err == nil {
defer rows.Close()
for rows.Next() {
var id, name string
var icon *string
var at string
if rows.Scan(&id, &name, &icon, &at) == nil {
recentApps = append(recentApps, map[string]any{
"id": id, "name": name, "icon_url": icon, "last_used": at,
})
}
}
}
if recentApps == nil {
recentApps = []map[string]any{}
}
response.JSON(w, http.StatusOK, map[string]any{
"total_conversations": totalConversations,
"total_tokens": totalTokens,
"favorite_count": favoriteCount,
"recent_apps": recentApps,
})
}
+20
View File
@@ -0,0 +1,20 @@
package handler
import (
"net/http"
"runtime"
"time"
"github.com/enterprise-ai-platform/server/internal/response"
)
var startTime = time.Now()
func HealthCheck(w http.ResponseWriter, r *http.Request) {
response.JSON(w, http.StatusOK, map[string]any{
"status": "ok",
"service": "aily-portal-api",
"uptime": time.Since(startTime).String(),
"go": runtime.Version(),
})
}
+522
View File
@@ -0,0 +1,522 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"unicode/utf8"
"github.com/enterprise-ai-platform/server/pkg/chunker"
"github.com/enterprise-ai-platform/server/pkg/embedding"
"github.com/rs/zerolog/log"
mw "github.com/enterprise-ai-platform/server/internal/middleware"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
func internalErr(w http.ResponseWriter, err error) {
response.InternalError(w, err.Error())
}
type KnowledgeHandler struct {
pool *pgxpool.Pool
embedder *embedding.Client
}
func NewKnowledgeHandler(pool *pgxpool.Pool, embedder *embedding.Client) *KnowledgeHandler {
return &KnowledgeHandler{pool: pool, embedder: embedder}
}
type knowledgeBaseRow struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Visibility string `json:"visibility"`
DocCount int `json:"document_count"`
TotalChars int64 `json:"total_chars"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (h *KnowledgeHandler) ListKnowledgeBases(w http.ResponseWriter, r *http.Request) {
userID := mw.GetUserID(r.Context())
userRole := mw.GetRole(r.Context())
var query string
var args []any
if userRole == "super_admin" {
// 超级管理员查看全部知识库
query = `SELECT id, name, COALESCE(description,''), visibility, doc_count, total_chars, status, created_at, updated_at
FROM knowledge_bases ORDER BY updated_at DESC`
args = []any{}
} else {
// 优先使用前端传入的 org_id 参数(切换机构后),否则从用户表获取
orgFilter := r.URL.Query().Get("org_id")
if orgFilter == "" {
var userOrg *string
_ = h.pool.QueryRow(r.Context(), `SELECT org_id::text FROM users WHERE id = $1`, userID).Scan(&userOrg)
if userOrg != nil {
orgFilter = *userOrg
}
}
query = `SELECT id, name, COALESCE(description,''), visibility, doc_count, total_chars, status, created_at, updated_at
FROM knowledge_bases WHERE (owner_id = $1`
args = []any{userID}
if orgFilter != "" {
query += ` OR org_id = $2`
args = append(args, orgFilter)
}
query += `) ORDER BY updated_at DESC`
}
rows, err := h.pool.Query(r.Context(), query, args...)
if err != nil {
internalErr(w, err)
return
}
defer rows.Close()
var items []knowledgeBaseRow
for rows.Next() {
var kb knowledgeBaseRow
if err := rows.Scan(&kb.ID, &kb.Name, &kb.Description, &kb.Visibility, &kb.DocCount, &kb.TotalChars, &kb.Status, &kb.CreatedAt, &kb.UpdatedAt); err != nil {
internalErr(w, err)
return
}
items = append(items, kb)
}
if items == nil {
items = []knowledgeBaseRow{}
}
response.JSON(w, http.StatusOK, items)
}
func (h *KnowledgeHandler) CreateKnowledgeBase(w http.ResponseWriter, r *http.Request) {
userID := mw.GetUserID(r.Context())
var body struct {
Name string `json:"name"`
Description string `json:"description"`
Visibility string `json:"visibility"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
response.BadRequest(w, "无效的请求体")
return
}
if body.Name == "" {
response.BadRequest(w, "名称不能为空")
return
}
if body.Visibility == "" {
body.Visibility = "private"
}
// 获取用户所属机构
var userOrgID *string
_ = h.pool.QueryRow(r.Context(), `SELECT org_id::text FROM users WHERE id = $1`, userID).Scan(&userOrgID)
id := uuid.New()
_, err := h.pool.Exec(r.Context(),
`INSERT INTO knowledge_bases (id, name, description, owner_id, visibility, org_id)
VALUES ($1, $2, $3, $4, $5, $6)`,
id, body.Name, body.Description, userID, body.Visibility, userOrgID)
if err != nil {
internalErr(w, err)
return
}
response.JSON(w, http.StatusCreated, map[string]any{
"id": id.String(),
"name": body.Name,
"description": body.Description,
"visibility": body.Visibility,
})
}
func (h *KnowledgeHandler) UpdateKnowledgeBase(w http.ResponseWriter, r *http.Request) {
id, err := uuid.Parse(chi.URLParam(r, "id"))
if err != nil {
response.BadRequest(w, "无效的ID")
return
}
userID := mw.GetUserID(r.Context())
var body struct {
Name string `json:"name"`
Description string `json:"description"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
response.BadRequest(w, "无效的请求体")
return
}
tag, err := h.pool.Exec(r.Context(),
`UPDATE knowledge_bases SET name = COALESCE(NULLIF($1,''), name),
description = $2, updated_at = NOW()
WHERE id = $3 AND owner_id = $4`,
body.Name, body.Description, id, userID)
if err != nil {
internalErr(w, err)
return
}
if tag.RowsAffected() == 0 {
response.NotFound(w, "知识库不存在")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
func (h *KnowledgeHandler) DeleteKnowledgeBase(w http.ResponseWriter, r *http.Request) {
id, err := uuid.Parse(chi.URLParam(r, "id"))
if err != nil {
response.BadRequest(w, "无效的ID")
return
}
userID := mw.GetUserID(r.Context())
tag, err := h.pool.Exec(r.Context(),
`DELETE FROM knowledge_bases WHERE id = $1 AND owner_id = $2`, id, userID)
if err != nil {
internalErr(w, err)
return
}
if tag.RowsAffected() == 0 {
response.NotFound(w, "知识库不存在")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
}
type documentRow struct {
ID string `json:"id"`
Name string `json:"filename"`
FileType string `json:"file_type"`
FileSize int64 `json:"file_size"`
IndexingStatus string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
func (h *KnowledgeHandler) ListDocuments(w http.ResponseWriter, r *http.Request) {
kbID, err := uuid.Parse(chi.URLParam(r, "id"))
if err != nil {
response.BadRequest(w, "无效的ID")
return
}
rows, err := h.pool.Query(r.Context(),
`SELECT id, name, COALESCE(file_type,''), file_size, indexing_status, created_at
FROM knowledge_documents WHERE kb_id = $1 ORDER BY created_at DESC`, kbID)
if err != nil {
internalErr(w, err)
return
}
defer rows.Close()
var docs []documentRow
for rows.Next() {
var d documentRow
if err := rows.Scan(&d.ID, &d.Name, &d.FileType, &d.FileSize, &d.IndexingStatus, &d.CreatedAt); err != nil {
internalErr(w, err)
return
}
docs = append(docs, d)
}
if docs == nil {
docs = []documentRow{}
}
response.JSON(w, http.StatusOK, docs)
}
func (h *KnowledgeHandler) UploadDocument(w http.ResponseWriter, r *http.Request) {
kbID, err := uuid.Parse(chi.URLParam(r, "id"))
if err != nil {
response.BadRequest(w, "无效的ID")
return
}
if err := r.ParseMultipartForm(32 << 20); err != nil {
response.BadRequest(w, "文件过大或格式错误")
return
}
file, header, err := r.FormFile("file")
if err != nil {
response.BadRequest(w, "请上传文件")
return
}
defer file.Close()
var exists bool
err = h.pool.QueryRow(r.Context(),
`SELECT EXISTS(SELECT 1 FROM knowledge_bases WHERE id = $1)`, kbID).Scan(&exists)
if err != nil || !exists {
response.NotFound(w, "知识库不存在")
return
}
fileType := ""
ext := ""
if dot := len(header.Filename) - 1; dot > 0 {
for i := dot; i >= 0; i-- {
if header.Filename[i] == '.' {
ext = header.Filename[i+1:]
break
}
}
}
switch ext {
case "pdf":
fileType = "pdf"
case "docx":
fileType = "docx"
case "txt":
fileType = "txt"
case "md":
fileType = "md"
case "csv":
fileType = "csv"
case "xlsx":
fileType = "xlsx"
default:
fileType = "txt"
}
// 读取文件内容(文本文件)
var content string
if fileType == "txt" || fileType == "md" || fileType == "csv" {
data, err := io.ReadAll(file)
if err == nil {
content = string(data)
}
}
userID := mw.GetUserID(r.Context())
docID := uuid.New()
// 计算分片数
chunkCount := 0
if content != "" {
chunks := chunker.ChunkText(content, chunker.DefaultOptions())
chunkCount = len(chunks)
}
_, err = h.pool.Exec(r.Context(),
`INSERT INTO knowledge_documents (id, kb_id, name, file_type, file_size, uploader_id, indexing_status, content, char_count, chunk_count)
VALUES ($1, $2, $3, $4, $5, $6, 'processing', $7, $8, $9)`,
docID, kbID, header.Filename, fileType, header.Size, userID, content, utf8.RuneCountInString(content), chunkCount)
if err != nil {
internalErr(w, err)
return
}
_, _ = h.pool.Exec(r.Context(),
`UPDATE knowledge_bases SET doc_count = doc_count + 1, updated_at = NOW() WHERE id = $1`, kbID)
// 异步执行分片和向量化
go h.chunkAndEmbed(context.Background(), kbID, docID, content)
response.JSON(w, http.StatusCreated, map[string]any{
"id": docID.String(),
"filename": header.Filename,
"size": header.Size,
"chunks": chunkCount,
"status": "processing",
})
}
// chunkAndEmbed 对文档内容执行分片和向量化(异步)
func (h *KnowledgeHandler) chunkAndEmbed(ctx context.Context, kbID, docID uuid.UUID, content string) {
if content == "" {
h.pool.Exec(ctx, `UPDATE knowledge_documents SET indexing_status = 'completed' WHERE id = $1`, docID)
return
}
chunks := chunker.ChunkText(content, chunker.DefaultOptions())
if len(chunks) == 0 {
h.pool.Exec(ctx, `UPDATE knowledge_documents SET indexing_status = 'completed' WHERE id = $1`, docID)
return
}
embeddingAvailable := h.embedder != nil && h.embedder.IsConfigured()
successCount := 0
for i, chunk := range chunks {
chunkID := uuid.New()
charCount := utf8.RuneCountInString(chunk)
if embeddingAvailable {
emb, err := h.embedder.GetEmbedding(ctx, chunk)
if err != nil {
log.Warn().Err(err).Str("doc_id", docID.String()).Int("chunk", i).Msg("embedding failed")
// 无 embedding 也插入 chunk
h.pool.Exec(ctx,
`INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count)
VALUES ($1, $2, $3, $4, $5, $6)`,
chunkID, kbID, docID, i, chunk, charCount)
} else {
// 将 []float32 转为 pgvector 格式字符串
vecStr := float32SliceToVectorStr(emb)
h.pool.Exec(ctx,
`INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count, embedding)
VALUES ($1, $2, $3, $4, $5, $6, $7::vector)`,
chunkID, kbID, docID, i, chunk, charCount, vecStr)
successCount++
}
} else {
// 没有 embedding 服务,只存储文本分片
h.pool.Exec(ctx,
`INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count)
VALUES ($1, $2, $3, $4, $5, $6)`,
chunkID, kbID, docID, i, chunk, charCount)
}
}
// 更新文档状态
h.pool.Exec(ctx, `UPDATE knowledge_documents SET indexing_status = 'completed', chunk_count = $2 WHERE id = $1`, docID, len(chunks))
log.Info().
Str("doc_id", docID.String()).
Int("chunks", len(chunks)).
Int("embedded", successCount).
Bool("embedding_available", embeddingAvailable).
Msg("document chunked and embedded")
}
// float32SliceToVectorStr 将 float32 切片转为 pgvector 格式字符串 "[0.1,0.2,...]"
func float32SliceToVectorStr(v []float32) string {
s := "["
for i, f := range v {
if i > 0 {
s += ","
}
s += fmt.Sprintf("%g", f)
}
s += "]"
return s
}
// ReindexAll 对所有未分片的文档执行分片和向量化(管理端点)
func (h *KnowledgeHandler) ReindexAll(w http.ResponseWriter, r *http.Request) {
rows, err := h.pool.Query(r.Context(),
`SELECT kd.id, kd.kb_id, kd.content FROM knowledge_documents kd
WHERE kd.content IS NOT NULL AND kd.content != '' AND kd.chunk_count = 0`)
if err != nil {
internalErr(w, err)
return
}
defer rows.Close()
type docInfo struct {
docID uuid.UUID
kbID uuid.UUID
content string
}
var docs []docInfo
for rows.Next() {
var d docInfo
if err := rows.Scan(&d.docID, &d.kbID, &d.content); err != nil {
continue
}
docs = append(docs, d)
}
for _, d := range docs {
h.chunkAndEmbed(r.Context(), d.kbID, d.docID, d.content)
}
response.JSON(w, http.StatusOK, map[string]any{
"message": "重新索引完成",
"documents": len(docs),
})
}
// ReembedChunks 为已有分片但缺失embedding的chunks补充向量化(管理端点)
func (h *KnowledgeHandler) ReembedChunks(w http.ResponseWriter, r *http.Request) {
if h.embedder == nil || !h.embedder.IsConfigured() {
response.JSON(w, http.StatusOK, map[string]any{
"message": "embedding服务未配置",
"updated": 0,
})
return
}
rows, err := h.pool.Query(r.Context(),
`SELECT id, content FROM knowledge_chunks WHERE embedding IS NULL AND content IS NOT NULL AND content != '' LIMIT 200`)
if err != nil {
internalErr(w, err)
return
}
defer rows.Close()
type chunkInfo struct {
id uuid.UUID
content string
}
var chunks []chunkInfo
for rows.Next() {
var c chunkInfo
if err := rows.Scan(&c.id, &c.content); err != nil {
continue
}
chunks = append(chunks, c)
}
updated := 0
for _, c := range chunks {
emb, err := h.embedder.GetEmbedding(r.Context(), c.content)
if err != nil {
log.Warn().Err(err).Str("chunk_id", c.id.String()).Msg("re-embed failed")
continue
}
vecStr := float32SliceToVectorStr(emb)
_, err = h.pool.Exec(r.Context(),
`UPDATE knowledge_chunks SET embedding = $1::vector WHERE id = $2`, vecStr, c.id)
if err != nil {
log.Warn().Err(err).Str("chunk_id", c.id.String()).Msg("update embedding failed")
continue
}
updated++
}
response.JSON(w, http.StatusOK, map[string]any{
"message": "向量补充完成",
"total": len(chunks),
"updated": updated,
})
}
func (h *KnowledgeHandler) DeleteDocument(w http.ResponseWriter, r *http.Request) {
kbID, err := uuid.Parse(chi.URLParam(r, "id"))
if err != nil {
response.BadRequest(w, "无效的知识库ID")
return
}
docID, err := uuid.Parse(chi.URLParam(r, "docId"))
if err != nil {
response.BadRequest(w, "无效的文档ID")
return
}
tag, err := h.pool.Exec(r.Context(), `DELETE FROM knowledge_documents WHERE id = $1 AND kb_id = $2`, docID, kbID)
if err != nil {
internalErr(w, err)
return
}
if tag.RowsAffected() == 0 {
response.NotFound(w, "文档不存在")
return
}
_, _ = h.pool.Exec(r.Context(),
`UPDATE knowledge_bases SET doc_count = GREATEST(doc_count - 1, 0), updated_at = NOW() WHERE id = $1`, kbID)
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
}
+828
View File
@@ -0,0 +1,828 @@
package handler
// 平台级管理 handler — 仅 super_admin 可访问,跨机构操作,不受 org_id 限制。
// 与 admin.go(机构级)的核心区别:
// - admin.go 所有查询都加 WHERE org_id = $orgID 过滤
// - platform.go 不加 org_id 过滤,可看/操作所有机构数据
import (
"encoding/json"
"errors"
"net/http"
"strconv"
"time"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
type PlatformHandler struct {
pool *pgxpool.Pool
}
func NewPlatformHandler(pool *pgxpool.Pool) *PlatformHandler {
return &PlatformHandler{pool: pool}
}
// ==================== 平台总览 ====================
// Overview 平台级数据看板:跨所有机构的聚合统计
func (h *PlatformHandler) Overview(w http.ResponseWriter, r *http.Request) {
now := time.Now()
todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
var totalOrgs, activeOrgs, totalUsers, activeUsers, totalApps, approvedApps, todayLogins, todayConvs int
var monthlyTokens int64
var monthlyCost float64
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM organizations`).Scan(&totalOrgs)
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM organizations WHERE is_active = true`).Scan(&activeOrgs)
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users`).Scan(&totalUsers)
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users WHERE status = 'active'`).Scan(&activeUsers)
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM applications`).Scan(&totalApps)
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM applications WHERE status = 'approved'`).Scan(&approvedApps)
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users WHERE last_login_at >= $1`, todayStart).Scan(&todayLogins)
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM app_usage_logs WHERE created_at >= $1`, todayStart).Scan(&todayConvs)
h.pool.QueryRow(r.Context(), `SELECT COALESCE(SUM(total_tokens),0) FROM app_usage_logs WHERE created_at >= $1`, monthStart).Scan(&monthlyTokens)
h.pool.QueryRow(r.Context(), `SELECT COALESCE(SUM(estimated_cost),0) FROM app_usage_logs WHERE created_at >= $1`, monthStart).Scan(&monthlyCost)
response.JSON(w, http.StatusOK, map[string]any{
"total_orgs": totalOrgs,
"active_orgs": activeOrgs,
"total_users": totalUsers,
"active_users": activeUsers,
"total_apps": totalApps,
"approved_apps": approvedApps,
"today_logins": todayLogins,
"today_convs": todayConvs,
"monthly_tokens": monthlyTokens,
"monthly_cost": monthlyCost,
})
}
// OrgRanking 各机构活跃度排行(用户数 / 应用数 / 月度对话数)
func (h *PlatformHandler) OrgRanking(w http.ResponseWriter, r *http.Request) {
now := time.Now()
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
rows, err := h.pool.Query(r.Context(), `
SELECT o.id, o.name, COALESCE(o.short_name, ''),
COUNT(DISTINCT u.id) AS users,
COUNT(DISTINCT a.id) FILTER (WHERE a.status='approved') AS apps,
COALESCE(SUM(usage.cnt), 0) AS conversations,
COALESCE(SUM(usage.tk), 0) AS tokens
FROM organizations o
LEFT JOIN users u ON u.org_id = o.id
LEFT JOIN applications a ON a.org_id = o.id
LEFT JOIN LATERAL (
SELECT COUNT(*) AS cnt, COALESCE(SUM(total_tokens),0) AS tk
FROM app_usage_logs l
JOIN users uu ON l.user_id = uu.id
WHERE uu.org_id = o.id AND l.created_at >= $1
) usage ON true
WHERE o.is_active = true
GROUP BY o.id, o.name, o.short_name
ORDER BY conversations DESC, users DESC
LIMIT 50`, monthStart)
if err != nil {
response.InternalError(w, "查询失败")
return
}
defer rows.Close()
items := []map[string]any{}
for rows.Next() {
var (
id, name, short string
users, apps int
conversations, toks int64
)
if err := rows.Scan(&id, &name, &short, &users, &apps, &conversations, &toks); err != nil {
continue
}
items = append(items, map[string]any{
"id": id, "name": name, "short_name": short,
"users": users, "apps": apps,
"conversations": conversations, "tokens": toks,
})
}
response.JSON(w, http.StatusOK, items)
}
// ==================== 机构管理 ====================
func (h *PlatformHandler) ListOrgs(w http.ResponseWriter, r *http.Request) {
rows, err := h.pool.Query(r.Context(), `
SELECT o.id, o.name, o.slug, COALESCE(o.short_name,''), COALESCE(o.description,''),
COALESCE(o.logo_url,''), o.sort_order, o.is_active, o.created_at,
COALESCE((SELECT COUNT(*) FROM users WHERE org_id = o.id), 0) AS user_count,
COALESCE((SELECT COUNT(*) FROM applications WHERE org_id = o.id), 0) AS app_count
FROM organizations o ORDER BY o.sort_order, o.created_at`)
if err != nil {
response.InternalError(w, "查询机构失败")
return
}
defer rows.Close()
items := []map[string]any{}
for rows.Next() {
var (
id, name, slug, short, desc, logo string
sortOrder int
isActive bool
createdAt time.Time
userCount, appCount int
)
if err := rows.Scan(&id, &name, &slug, &short, &desc, &logo, &sortOrder, &isActive, &createdAt, &userCount, &appCount); err != nil {
continue
}
items = append(items, map[string]any{
"id": id, "name": name, "slug": slug, "short_name": short,
"description": desc, "logo_url": logo, "sort_order": sortOrder,
"is_active": isActive, "created_at": createdAt,
"user_count": userCount, "app_count": appCount,
})
}
response.JSON(w, http.StatusOK, items)
}
func (h *PlatformHandler) CreateOrg(w http.ResponseWriter, r *http.Request) {
var req struct {
Name string `json:"name"`
Slug string `json:"slug"`
ShortName string `json:"short_name"`
Description string `json:"description"`
LogoURL string `json:"logo_url"`
SortOrder int `json:"sort_order"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Name == "" || req.Slug == "" {
response.BadRequest(w, "机构名称和标识不能为空")
return
}
var id string
err := h.pool.QueryRow(r.Context(), `
INSERT INTO organizations (name, slug, short_name, description, logo_url, sort_order)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id::text`,
req.Name, req.Slug, req.ShortName, req.Description, req.LogoURL, req.SortOrder,
).Scan(&id)
if err != nil {
response.InternalError(w, "创建机构失败:标识可能已存在")
return
}
response.JSON(w, http.StatusCreated, map[string]any{"id": id})
}
func (h *PlatformHandler) UpdateOrg(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
var req struct {
Name *string `json:"name"`
ShortName *string `json:"short_name"`
Description *string `json:"description"`
LogoURL *string `json:"logo_url"`
SortOrder *int `json:"sort_order"`
IsActive *bool `json:"is_active"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
_, err := h.pool.Exec(r.Context(), `
UPDATE organizations SET
name = COALESCE($2, name),
short_name = COALESCE($3, short_name),
description = COALESCE($4, description),
logo_url = COALESCE($5, logo_url),
sort_order = COALESCE($6, sort_order),
is_active = COALESCE($7, is_active),
updated_at = NOW()
WHERE id = $1`, id, req.Name, req.ShortName, req.Description, req.LogoURL, req.SortOrder, req.IsActive)
if err != nil {
response.InternalError(w, "更新失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
func (h *PlatformHandler) DeleteOrg(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
// 安全检查:机构下还有用户或应用时禁止物理删除,引导走停用
var userCount, appCount int
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users WHERE org_id = $1`, id).Scan(&userCount)
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM applications WHERE org_id = $1`, id).Scan(&appCount)
if userCount > 0 || appCount > 0 {
response.BadRequest(w, "该机构尚有用户或应用,无法删除。请先停用或迁移数据")
return
}
_, err := h.pool.Exec(r.Context(), `DELETE FROM organizations WHERE id = $1`, id)
if err != nil {
response.InternalError(w, "删除失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
}
// ==================== 全局用户管理 ====================
// ListAllUsers 全局用户列表(支持分页、按机构/角色/状态过滤)
func (h *PlatformHandler) ListAllUsers(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
page, _ := strconv.Atoi(q.Get("page"))
if page < 1 {
page = 1
}
pageSize := 20
offset := (page - 1) * pageSize
search := q.Get("q")
roleFilter := q.Get("role")
statusFilter := q.Get("status")
orgFilter := q.Get("org_id")
// 构造 WHERE 子句(list 与 count 共用)
where := ` WHERE 1=1`
args := []any{}
argIdx := 1
if search != "" {
where += ` AND (u.name ILIKE '%'||$` + strconv.Itoa(argIdx) + `||'%' OR u.email ILIKE '%'||$` + strconv.Itoa(argIdx) + `||'%')`
args = append(args, search)
argIdx++
}
if roleFilter != "" {
where += ` AND u.role = $` + strconv.Itoa(argIdx)
args = append(args, roleFilter)
argIdx++
}
if statusFilter != "" {
where += ` AND u.status = $` + strconv.Itoa(argIdx)
args = append(args, statusFilter)
argIdx++
}
if orgFilter != "" {
where += ` AND u.org_id = $` + strconv.Itoa(argIdx)
args = append(args, orgFilter)
argIdx++
}
// 先查总数
var total int
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM users u`+where, args...).Scan(&total)
// 再分页查列表
listQuery := `SELECT u.id, u.name, u.email, u.avatar_url, u.role, u.status, u.employee_id,
u.last_login_at, u.login_count, u.created_at,
COALESCE(o.id::text, ''), COALESCE(o.name, ''), COALESCE(o.short_name, '')
FROM users u LEFT JOIN organizations o ON u.org_id = o.id` + where +
` ORDER BY u.created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
listArgs := append(args, pageSize, offset)
rows, err := h.pool.Query(r.Context(), listQuery, listArgs...)
if err != nil {
response.InternalError(w, "查询用户失败")
return
}
defer rows.Close()
items := []map[string]any{}
for rows.Next() {
var (
id, name, email, role, status string
avatarURL, employeeID *string
lastLoginAt *time.Time
loginCount int
createdAt time.Time
orgID, orgName, orgShort string
)
if err := rows.Scan(&id, &name, &email, &avatarURL, &role, &status, &employeeID, &lastLoginAt, &loginCount, &createdAt, &orgID, &orgName, &orgShort); err != nil {
continue
}
items = append(items, map[string]any{
"id": id, "name": name, "email": email, "avatar_url": avatarURL,
"role": role, "status": status, "employee_id": employeeID,
"last_login_at": lastLoginAt, "login_count": loginCount, "created_at": createdAt,
"org_id": orgID, "org_name": orgName, "org_short": orgShort,
})
}
response.JSON(w, http.StatusOK, map[string]any{
"items": items,
"total": total,
"page": page,
"page_size": pageSize,
})
}
// AssignUserOrg 把用户分配/迁移到指定机构
func (h *PlatformHandler) AssignUserOrg(w http.ResponseWriter, r *http.Request) {
userID := chi.URLParam(r, "id")
var req struct {
OrgID string `json:"org_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.OrgID == "" {
response.BadRequest(w, "机构ID不能为空")
return
}
_, err := h.pool.Exec(r.Context(), `UPDATE users SET org_id = $2 WHERE id = $1`, userID, req.OrgID)
if err != nil {
response.InternalError(w, "迁移失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已迁移"})
}
// UpdateUserRole 平台管理员可设置任意角色(包括 super_admin
func (h *PlatformHandler) UpdateUserRole(w http.ResponseWriter, r *http.Request) {
userID := chi.URLParam(r, "id")
var req struct {
Role string `json:"role"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
valid := map[string]bool{"user": true, "creator": true, "admin": true, "super_admin": true}
if !valid[req.Role] {
response.BadRequest(w, "无效的角色")
return
}
_, err := h.pool.Exec(r.Context(), `UPDATE users SET role = $2 WHERE id = $1`, userID, req.Role)
if err != nil {
response.InternalError(w, "更新角色失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
// UpdateUserStatus 启用/禁用任意用户
func (h *PlatformHandler) UpdateUserStatus(w http.ResponseWriter, r *http.Request) {
userID := chi.URLParam(r, "id")
var req struct {
Status string `json:"status"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Status != "active" && req.Status != "disabled" {
response.BadRequest(w, "无效的状态")
return
}
_, err := h.pool.Exec(r.Context(), `UPDATE users SET status = $2 WHERE id = $1`, userID, req.Status)
if err != nil {
response.InternalError(w, "更新状态失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
// ==================== 全局应用管理 ====================
// ListAllApps 全局应用列表(支持分页、按机构/状态过滤)
func (h *PlatformHandler) ListAllApps(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
page, _ := strconv.Atoi(q.Get("page"))
if page < 1 {
page = 1
}
pageSize := 20
offset := (page - 1) * pageSize
statusFilter := q.Get("status")
orgFilter := q.Get("org_id")
where := ` WHERE 1=1`
args := []any{}
argIdx := 1
if statusFilter != "" {
where += ` AND a.status = $` + strconv.Itoa(argIdx)
args = append(args, statusFilter)
argIdx++
}
if orgFilter != "" {
where += ` AND a.org_id = $` + strconv.Itoa(argIdx)
args = append(args, orgFilter)
argIdx++
}
var total int
h.pool.QueryRow(r.Context(), `SELECT COUNT(*) FROM applications a`+where, args...).Scan(&total)
listQuery := `SELECT a.id, a.name, a.slug, a.description, a.icon_url,
a.dify_app_type, a.status, a.visibility, a.usage_count, a.is_featured, a.created_at,
COALESCE(c.name, ''), COALESCE(u.name, ''),
COALESCE(o.id::text, ''), COALESCE(o.name, ''), COALESCE(o.short_name, '')
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
LEFT JOIN users u ON a.creator_id = u.id
LEFT JOIN organizations o ON a.org_id = o.id` + where +
` ORDER BY a.usage_count DESC, a.created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
listArgs := append(args, pageSize, offset)
rows, err := h.pool.Query(r.Context(), listQuery, listArgs...)
if err != nil {
response.InternalError(w, "查询应用失败")
return
}
defer rows.Close()
items := []map[string]any{}
for rows.Next() {
var (
id, name, slug, status, visibility string
desc, iconURL *string
appType *string
usageCount int64
isFeatured bool
createdAt time.Time
catName, creatorName string
orgID, orgName, orgShort string
)
if err := rows.Scan(&id, &name, &slug, &desc, &iconURL, &appType, &status, &visibility,
&usageCount, &isFeatured, &createdAt, &catName, &creatorName, &orgID, &orgName, &orgShort); err != nil {
continue
}
items = append(items, map[string]any{
"id": id, "name": name, "slug": slug, "description": desc, "icon_url": iconURL,
"dify_app_type": appType, "status": status, "visibility": visibility,
"usage_count": usageCount, "is_featured": isFeatured, "created_at": createdAt,
"category_name": catName, "creator_name": creatorName,
"org_id": orgID, "org_name": orgName, "org_short": orgShort,
})
}
response.JSON(w, http.StatusOK, map[string]any{
"items": items,
"total": total,
"page": page,
"page_size": pageSize,
})
}
// SetFeatured 平台级精选/取消精选
func (h *PlatformHandler) SetFeatured(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
var req struct {
IsFeatured bool `json:"is_featured"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
_, err := h.pool.Exec(r.Context(), `UPDATE applications SET is_featured = $2 WHERE id = $1`, appID, req.IsFeatured)
if err != nil {
response.InternalError(w, "操作失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
// ForceDelist 平台强制下架(不受机构限制)
func (h *PlatformHandler) ForceDelist(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
_, err := h.pool.Exec(r.Context(), `UPDATE applications SET status = 'archived' WHERE id = $1`, appID)
if err != nil {
response.InternalError(w, "下架失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已强制下架"})
}
// ==================== 全局审计日志 ====================
// ListAllAuditLogs 全局审计日志(支持分页、按机构/操作类型过滤)
func (h *PlatformHandler) ListAllAuditLogs(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
page, _ := strconv.Atoi(q.Get("page"))
if page < 1 {
page = 1
}
pageSize := 20
offset := (page - 1) * pageSize
orgFilter := q.Get("org_id")
actionFilter := q.Get("action")
where := ` WHERE 1=1`
args := []any{}
argIdx := 1
if orgFilter != "" {
where += ` AND u.org_id = $` + strconv.Itoa(argIdx)
args = append(args, orgFilter)
argIdx++
}
if actionFilter != "" {
where += ` AND al.action ILIKE '%'||$` + strconv.Itoa(argIdx) + `||'%'`
args = append(args, actionFilter)
argIdx++
}
var total int
h.pool.QueryRow(r.Context(),
`SELECT COUNT(*) FROM audit_logs al LEFT JOIN users u ON al.user_id = u.id`+where,
args...).Scan(&total)
listQuery := `SELECT al.id, al.action, al.resource_type, al.resource_id::text, al.details,
al.ip_address::text, al.created_at,
COALESCE(u.name, ''), COALESCE(u.email, ''),
COALESCE(o.name, ''), COALESCE(o.short_name, '')
FROM audit_logs al
LEFT JOIN users u ON al.user_id = u.id
LEFT JOIN organizations o ON u.org_id = o.id` + where +
` ORDER BY al.created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
listArgs := append(args, pageSize, offset)
rows, err := h.pool.Query(r.Context(), listQuery, listArgs...)
if err != nil {
response.InternalError(w, "查询审计日志失败")
return
}
defer rows.Close()
items := []map[string]any{}
for rows.Next() {
var (
id, action, resType string
resID, ipAddr *string
details json.RawMessage
createdAt time.Time
userName, userEmail, orgName, orgShort string
)
if err := rows.Scan(&id, &action, &resType, &resID, &details, &ipAddr, &createdAt,
&userName, &userEmail, &orgName, &orgShort); err != nil {
continue
}
items = append(items, map[string]any{
"id": id, "action": action, "resource_type": resType, "resource_id": resID,
"details": details, "ip_address": ipAddr, "created_at": createdAt,
"user_name": userName, "user_email": userEmail,
"org_name": orgName, "org_short": orgShort,
})
}
response.JSON(w, http.StatusOK, map[string]any{
"items": items,
"total": total,
"page": page,
"page_size": pageSize,
})
}
// ==================== 模型提供商管理 ====================
func (h *PlatformHandler) ListProviders(w http.ResponseWriter, r *http.Request) {
rows, err := h.pool.Query(r.Context(), `
SELECT id, name, base_url, models, is_active, priority, created_at, updated_at
FROM model_providers ORDER BY priority DESC, created_at`)
if err != nil {
response.InternalError(w, "查询失败")
return
}
defer rows.Close()
items := []map[string]any{}
for rows.Next() {
var (
id, name, baseURL string
models json.RawMessage
isActive bool
priority int
createdAt, updatedAt time.Time
)
if err := rows.Scan(&id, &name, &baseURL, &models, &isActive, &priority, &createdAt, &updatedAt); err != nil {
continue
}
items = append(items, map[string]any{
"id": id, "name": name, "base_url": baseURL, "models": models,
"is_active": isActive, "priority": priority,
"created_at": createdAt, "updated_at": updatedAt,
})
}
response.JSON(w, http.StatusOK, items)
}
func (h *PlatformHandler) CreateProvider(w http.ResponseWriter, r *http.Request) {
var req struct {
Name string `json:"name"`
BaseURL string `json:"base_url"`
APIKey string `json:"api_key"`
Models json.RawMessage `json:"models"`
IsActive bool `json:"is_active"`
Priority int `json:"priority"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Name == "" || req.BaseURL == "" || req.APIKey == "" {
response.BadRequest(w, "名称、URL、API Key 不能为空")
return
}
if len(req.Models) == 0 {
req.Models = []byte(`[]`)
}
var id string
err := h.pool.QueryRow(r.Context(), `
INSERT INTO model_providers (name, base_url, api_key_encrypted, models, is_active, priority)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id::text`,
req.Name, req.BaseURL, req.APIKey, req.Models, req.IsActive, req.Priority,
).Scan(&id)
if err != nil {
response.InternalError(w, "创建失败")
return
}
response.JSON(w, http.StatusCreated, map[string]any{"id": id})
}
func (h *PlatformHandler) UpdateProvider(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
var req struct {
Name *string `json:"name"`
BaseURL *string `json:"base_url"`
APIKey *string `json:"api_key"`
Models json.RawMessage `json:"models"`
IsActive *bool `json:"is_active"`
Priority *int `json:"priority"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
// 仅当 api_key 非空才更新(避免误清空)
var apiKeyArg any = nil
if req.APIKey != nil && *req.APIKey != "" {
apiKeyArg = *req.APIKey
}
var modelsArg any = nil
if len(req.Models) > 0 {
modelsArg = req.Models
}
_, err := h.pool.Exec(r.Context(), `
UPDATE model_providers SET
name = COALESCE($2, name),
base_url = COALESCE($3, base_url),
api_key_encrypted = COALESCE($4, api_key_encrypted),
models = COALESCE($5::jsonb, models),
is_active = COALESCE($6, is_active),
priority = COALESCE($7, priority),
updated_at = NOW()
WHERE id = $1`,
id, req.Name, req.BaseURL, apiKeyArg, modelsArg, req.IsActive, req.Priority)
if err != nil {
response.InternalError(w, "更新失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
func (h *PlatformHandler) DeleteProvider(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
_, err := h.pool.Exec(r.Context(), `DELETE FROM model_providers WHERE id = $1`, id)
if err != nil {
response.InternalError(w, "删除失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
}
// ==================== 全局配额管理 ====================
func (h *PlatformHandler) ListQuotas(w http.ResponseWriter, r *http.Request) {
rows, err := h.pool.Query(r.Context(), `
SELECT q.id, q.target_type, q.target_id::text, q.model_name,
q.daily_token_limit, q.monthly_token_limit, q.daily_request_limit,
q.is_active, q.created_at,
COALESCE(target_name.name, '')
FROM model_quotas q
LEFT JOIN LATERAL (
SELECT CASE q.target_type
WHEN 'global' THEN '全局'
WHEN 'department' THEN (SELECT name FROM departments WHERE id = q.target_id)
WHEN 'user' THEN (SELECT name FROM users WHERE id = q.target_id)
END AS name
) target_name ON true
ORDER BY q.target_type, q.created_at DESC LIMIT 200`)
if err != nil {
response.InternalError(w, "查询配额失败")
return
}
defer rows.Close()
items := []map[string]any{}
for rows.Next() {
var (
id, targetType string
targetID *string
modelName *string
dailyTokens, monthlyTokens *int64
dailyReqs *int
isActive bool
createdAt time.Time
targetName string
)
if err := rows.Scan(&id, &targetType, &targetID, &modelName, &dailyTokens, &monthlyTokens,
&dailyReqs, &isActive, &createdAt, &targetName); err != nil {
continue
}
items = append(items, map[string]any{
"id": id, "target_type": targetType, "target_id": targetID,
"target_name": targetName,
"model_name": modelName,
"daily_token_limit": dailyTokens,
"monthly_token_limit": monthlyTokens,
"daily_request_limit": dailyReqs,
"is_active": isActive,
"created_at": createdAt,
})
}
response.JSON(w, http.StatusOK, items)
}
func (h *PlatformHandler) UpsertQuota(w http.ResponseWriter, r *http.Request) {
var req struct {
ID *string `json:"id"`
TargetType string `json:"target_type"`
TargetID *string `json:"target_id"`
ModelName *string `json:"model_name"`
DailyTokenLimit *int64 `json:"daily_token_limit"`
MonthlyTokenLimit *int64 `json:"monthly_token_limit"`
DailyRequestLimit *int `json:"daily_request_limit"`
IsActive bool `json:"is_active"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
valid := map[string]bool{"global": true, "department": true, "user": true}
if !valid[req.TargetType] {
response.BadRequest(w, "无效的目标类型")
return
}
if req.TargetType != "global" && (req.TargetID == nil || *req.TargetID == "") {
response.BadRequest(w, "部门/用户配额必须指定 target_id")
return
}
if req.ID != nil && *req.ID != "" {
_, err := h.pool.Exec(r.Context(), `
UPDATE model_quotas SET
target_type = $2, target_id = NULLIF($3,'')::uuid, model_name = $4,
daily_token_limit = $5, monthly_token_limit = $6, daily_request_limit = $7,
is_active = $8, updated_at = NOW()
WHERE id = $1`,
*req.ID, req.TargetType, nullableString(req.TargetID), req.ModelName,
req.DailyTokenLimit, req.MonthlyTokenLimit, req.DailyRequestLimit, req.IsActive)
if err != nil {
response.InternalError(w, "更新配额失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"id": *req.ID})
return
}
var id string
err := h.pool.QueryRow(r.Context(), `
INSERT INTO model_quotas (target_type, target_id, model_name, daily_token_limit, monthly_token_limit, daily_request_limit, is_active)
VALUES ($1, NULLIF($2,'')::uuid, $3, $4, $5, $6, $7)
RETURNING id::text`,
req.TargetType, nullableString(req.TargetID), req.ModelName,
req.DailyTokenLimit, req.MonthlyTokenLimit, req.DailyRequestLimit, req.IsActive,
).Scan(&id)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
response.InternalError(w, "创建配额失败")
return
}
response.InternalError(w, "创建配额失败")
return
}
response.JSON(w, http.StatusCreated, map[string]string{"id": id})
}
func (h *PlatformHandler) DeleteQuota(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
_, err := h.pool.Exec(r.Context(), `DELETE FROM model_quotas WHERE id = $1`, id)
if err != nil {
response.InternalError(w, "删除失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
}
func nullableString(s *string) string {
if s == nil {
return ""
}
return *s
}
+311
View File
@@ -0,0 +1,311 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"path/filepath"
"time"
"github.com/enterprise-ai-platform/server/internal/middleware"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
)
type PPTHandler struct {
pool *pgxpool.Pool
rdb *redis.Client
workerURL string
}
func NewPPTHandler(pool *pgxpool.Pool, rdb *redis.Client, workerURL string) *PPTHandler {
return &PPTHandler{pool: pool, rdb: rdb, workerURL: workerURL}
}
// ==================== 请求/响应结构 ====================
type createPPTRequest struct {
Title string `json:"title"`
SourceType string `json:"source_type"` // text / url
SourceContent string `json:"source_content"` // 文本内容或 URL
Config map[string]any `json:"config"`
}
type pptTaskResponse struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
Progress int `json:"progress"`
StatusMessage *string `json:"status_message,omitempty"`
ErrorMessage *string `json:"error_message,omitempty"`
OutputFile *string `json:"output_file,omitempty"`
PageCount *int `json:"page_count,omitempty"`
Title string `json:"title"`
CreatedAt string `json:"created_at"`
}
// ==================== 接口实现 ====================
// CreateTask 创建 PPT 生成任务(文本/URL 输入)
func (h *PPTHandler) CreateTask(w http.ResponseWriter, r *http.Request) {
userID := middleware.GetUserID(r.Context())
var req createPPTRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Title == "" {
response.BadRequest(w, "标题不能为空")
return
}
if req.SourceType == "" {
req.SourceType = "text"
}
if req.SourceContent == "" {
response.BadRequest(w, "请提供源内容")
return
}
taskID := uuid.New().String()
configJSON, _ := json.Marshal(req.Config)
// 写入数据库
_, err := h.pool.Exec(r.Context(),
`INSERT INTO ppt_tasks (id, user_id, title, source_type, source_content, config)
VALUES ($1, $2, $3, $4, $5, $6)`,
taskID, userID, req.Title, req.SourceType, req.SourceContent, configJSON,
)
if err != nil {
response.InternalError(w, "创建任务失败: "+err.Error())
return
}
// 推送到 Redis 队列
taskMsg, _ := json.Marshal(map[string]string{"task_id": taskID})
h.rdb.LPush(r.Context(), "ppt:tasks", taskMsg)
response.JSON(w, http.StatusCreated, map[string]string{
"task_id": taskID,
"status": "pending",
})
}
// CreateTaskWithFile 创建带文件上传的 PPT 生成任务
func (h *PPTHandler) CreateTaskWithFile(w http.ResponseWriter, r *http.Request) {
userID := middleware.GetUserID(r.Context())
if err := r.ParseMultipartForm(50 << 20); err != nil { // 50MB 限制
response.BadRequest(w, "文件过大或请求格式错误")
return
}
title := r.FormValue("title")
if title == "" {
response.BadRequest(w, "标题不能为空")
return
}
configStr := r.FormValue("config")
var taskConfig map[string]any
if configStr != "" {
json.Unmarshal([]byte(configStr), &taskConfig)
}
if taskConfig == nil {
taskConfig = map[string]any{}
}
file, header, err := r.FormFile("file")
if err != nil {
response.BadRequest(w, "请上传文件")
return
}
defer file.Close()
// 将文件转发到 PPT Worker
taskID := uuid.New().String()
configJSON, _ := json.Marshal(taskConfig)
// 转发文件到 Worker 服务
err = h.forwardFileToWorker(r.Context(), taskID, userID.String(), title, string(configJSON), file, header)
if err != nil {
response.InternalError(w, "提交任务失败: "+err.Error())
return
}
response.JSON(w, http.StatusCreated, map[string]string{
"task_id": taskID,
"status": "pending",
})
}
// GetTaskStatus 查询任务状态
func (h *PPTHandler) GetTaskStatus(w http.ResponseWriter, r *http.Request) {
taskID := chi.URLParam(r, "taskId")
userID := middleware.GetUserID(r.Context())
// 先从 Redis 快速查询
key := "ppt:status:" + taskID
cached, err := h.rdb.HGetAll(r.Context(), key).Result()
if err == nil && len(cached) > 0 {
progress := 0
fmt.Sscanf(cached["progress"], "%d", &progress)
msg := cached["message"]
response.JSON(w, http.StatusOK, pptTaskResponse{
TaskID: taskID,
Status: cached["status"],
Progress: progress,
StatusMessage: &msg,
})
return
}
// 回退到数据库
var task pptTaskResponse
var statusMsg, errMsg, outputFile *string
var pageCount *int
var createdAt time.Time
err = h.pool.QueryRow(r.Context(),
`SELECT id, status, progress, status_message, error_message, output_file, page_count, title, created_at
FROM ppt_tasks WHERE id = $1 AND user_id = $2`, taskID, userID,
).Scan(&task.TaskID, &task.Status, &task.Progress, &statusMsg, &errMsg, &outputFile, &pageCount, &task.Title, &createdAt)
if err != nil {
response.NotFound(w, "任务不存在")
return
}
task.StatusMessage = statusMsg
task.ErrorMessage = errMsg
task.OutputFile = outputFile
task.PageCount = pageCount
task.CreatedAt = createdAt.Format(time.RFC3339)
response.JSON(w, http.StatusOK, task)
}
// ListTasks 列出用户的 PPT 任务
func (h *PPTHandler) ListTasks(w http.ResponseWriter, r *http.Request) {
userID := middleware.GetUserID(r.Context())
rows, err := h.pool.Query(r.Context(),
`SELECT id, status, progress, status_message, error_message, output_file, page_count, title, created_at
FROM ppt_tasks WHERE user_id = $1 ORDER BY created_at DESC LIMIT 50`, userID,
)
if err != nil {
response.InternalError(w, "查询失败")
return
}
defer rows.Close()
var tasks []pptTaskResponse
for rows.Next() {
var t pptTaskResponse
var statusMsg, errMsg, outputFile *string
var pageCount *int
var createdAt time.Time
if err := rows.Scan(&t.TaskID, &t.Status, &t.Progress, &statusMsg, &errMsg, &outputFile, &pageCount, &t.Title, &createdAt); err != nil {
continue
}
t.StatusMessage = statusMsg
t.ErrorMessage = errMsg
t.OutputFile = outputFile
t.PageCount = pageCount
t.CreatedAt = createdAt.Format(time.RFC3339)
tasks = append(tasks, t)
}
if tasks == nil {
tasks = []pptTaskResponse{}
}
response.JSON(w, http.StatusOK, tasks)
}
// DownloadTask 下载生成的 PPTX 文件
func (h *PPTHandler) DownloadTask(w http.ResponseWriter, r *http.Request) {
taskID := chi.URLParam(r, "taskId")
userID := middleware.GetUserID(r.Context())
var status, title string
var outputFile *string
err := h.pool.QueryRow(r.Context(),
`SELECT status, title, output_file FROM ppt_tasks WHERE id = $1 AND user_id = $2`,
taskID, userID,
).Scan(&status, &title, &outputFile)
if err != nil {
response.NotFound(w, "任务不存在")
return
}
if status != "completed" {
response.BadRequest(w, "任务未完成")
return
}
// 代理下载请求到 Worker 服务
workerResp, err := http.Get(h.workerURL + "/api/tasks/" + taskID + "/download")
if err != nil {
response.InternalError(w, "下载服务不可用")
return
}
defer workerResp.Body.Close()
if workerResp.StatusCode != http.StatusOK {
response.InternalError(w, "文件下载失败")
return
}
filename := title + ".pptx"
w.Header().Set("Content-Type", "application/vnd.openxmlformats-officedocument.presentationml.presentation")
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filename))
io.Copy(w, workerResp.Body)
}
// ==================== 内部方法 ====================
func (h *PPTHandler) forwardFileToWorker(ctx context.Context, taskID, userID, title, configJSON string, file multipart.File, header *multipart.FileHeader) error {
// 构造 multipart 请求转发到 Worker
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
writer.WriteField("task_id", taskID)
writer.WriteField("user_id", userID)
writer.WriteField("title", title)
writer.WriteField("config_json", configJSON)
part, err := writer.CreateFormFile("file", filepath.Base(header.Filename))
if err != nil {
return err
}
if _, err := io.Copy(part, file); err != nil {
return err
}
writer.Close()
req, err := http.NewRequestWithContext(ctx, "POST", h.workerURL+"/api/tasks/upload", &buf)
if err != nil {
return err
}
req.Header.Set("Content-Type", writer.FormDataContentType())
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("worker 返回错误 %d: %s", resp.StatusCode, string(body))
}
return nil
}
+350
View File
@@ -0,0 +1,350 @@
package handler
import (
"encoding/json"
"net/http"
"strconv"
"github.com/enterprise-ai-platform/server/internal/middleware"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
type StoreHandler struct {
pool *pgxpool.Pool
}
func NewStoreHandler(pool *pgxpool.Pool) *StoreHandler {
return &StoreHandler{pool: pool}
}
func (h *StoreHandler) ListCategories(w http.ResponseWriter, r *http.Request) {
orgID := r.URL.Query().Get("org_id")
query := `SELECT c.id, c.name, c.slug, c.icon, c.description, c.sort_order,
COALESCE((SELECT COUNT(*) FROM applications a WHERE a.category_id = c.id AND a.status = 'approved'), 0) AS app_count
FROM categories c WHERE c.status = 'active'`
var args []any
if orgID != "" {
query += ` AND (c.org_id = $1 OR c.org_id IS NULL)`
args = append(args, orgID)
}
query += ` ORDER BY c.sort_order ASC`
rows, err := h.pool.Query(r.Context(), query, args...)
if err != nil {
response.InternalError(w, "查询分类失败")
return
}
defer rows.Close()
var cats []map[string]any
for rows.Next() {
var id, name, slug string
var icon, desc *string
var sortOrder int
var appCount int
if err := rows.Scan(&id, &name, &slug, &icon, &desc, &sortOrder, &appCount); err != nil {
continue
}
cats = append(cats, map[string]any{
"id": id, "name": name, "slug": slug,
"icon": icon, "description": desc, "sort_order": sortOrder,
"app_count": appCount,
})
}
response.JSON(w, http.StatusOK, cats)
}
func (h *StoreHandler) ListApps(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
page, _ := strconv.Atoi(q.Get("page"))
if page < 1 {
page = 1
}
pageSize, _ := strconv.Atoi(q.Get("page_size"))
if pageSize < 1 || pageSize > 50 {
pageSize = 20
}
offset := (page - 1) * pageSize
search := q.Get("q")
category := q.Get("category")
sort := q.Get("sort")
orgFilter := q.Get("org_id")
if sort == "" {
sort = "popular"
}
query := `
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
c.name as category_name, c.slug as category_slug,
u.name as creator_name,
a.usage_count, a.favorite_count, a.avg_rating, a.rating_count,
a.dify_app_type, a.welcome_message, a.published_at::text
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
LEFT JOIN users u ON a.creator_id = u.id
WHERE a.status = 'approved' AND a.visibility = 'public'`
args := []any{}
argIdx := 1
// 按机构过滤:显示指定机构的应用 + 无机构归属的全局应用
if orgFilter != "" {
query += ` AND (a.org_id = $` + strconv.Itoa(argIdx) + ` OR a.org_id IS NULL)`
args = append(args, orgFilter)
argIdx++
}
if search != "" {
query += ` AND to_tsvector('simple', a.name || ' ' || COALESCE(a.description, ''))
@@ plainto_tsquery('simple', $` + strconv.Itoa(argIdx) + `)`
args = append(args, search)
argIdx++
}
if category != "" {
query += ` AND c.slug = $` + strconv.Itoa(argIdx)
args = append(args, category)
argIdx++
}
switch sort {
case "rating":
query += ` ORDER BY a.avg_rating DESC, a.usage_count DESC`
case "latest":
query += ` ORDER BY a.published_at DESC NULLS LAST`
default:
query += ` ORDER BY a.usage_count DESC`
}
query += ` LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1)
args = append(args, pageSize, offset)
rows, err := h.pool.Query(r.Context(), query, args...)
if err != nil {
response.InternalError(w, "查询应用失败")
return
}
defer rows.Close()
var apps []map[string]any
for rows.Next() {
var (
id, name, slug string
desc, iconURL *string
catName, catSlug *string
creatorName *string
usageCount int64
favCount, ratingCt int
avgRating float32
difyType, welcome *string
publishedAt *string
)
if err := rows.Scan(&id, &name, &slug, &desc, &iconURL,
&catName, &catSlug, &creatorName,
&usageCount, &favCount, &avgRating, &ratingCt,
&difyType, &welcome, &publishedAt); err != nil {
continue
}
apps = append(apps, map[string]any{
"id": id, "name": name, "slug": slug,
"description": desc, "icon_url": iconURL,
"category_name": catName, "category_slug": catSlug,
"creator_name": creatorName,
"usage_count": usageCount, "favorite_count": favCount,
"avg_rating": avgRating, "rating_count": ratingCt,
"dify_app_type": difyType, "welcome_message": welcome,
"published_at": publishedAt,
})
}
if apps == nil {
apps = []map[string]any{}
}
response.JSON(w, http.StatusOK, map[string]any{
"items": apps,
"page": page,
"page_size": pageSize,
})
}
func (h *StoreHandler) GetApp(w http.ResponseWriter, r *http.Request) {
slug := chi.URLParam(r, "slug")
var (
id, name, appSlug string
desc, longDesc *string
iconURL *string
catName, catSlug *string
creatorName *string
usageCount int64
favCount, ratingCt int
avgRating float32
difyType, welcome *string
suggestedPrompts *string
appConfig *string
publishedAt *string
version string
)
err := h.pool.QueryRow(r.Context(), `
SELECT a.id, a.name, a.slug, a.description, a.long_description, a.icon_url,
c.name, c.slug, u.name,
a.usage_count, a.favorite_count, a.avg_rating, a.rating_count,
a.dify_app_type, a.welcome_message, a.suggested_prompts::text,
a.app_config::text,
a.published_at::text, a.version
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
LEFT JOIN users u ON a.creator_id = u.id
WHERE a.slug = $1 AND a.status = 'approved'`, slug,
).Scan(&id, &name, &appSlug, &desc, &longDesc, &iconURL,
&catName, &catSlug, &creatorName,
&usageCount, &favCount, &avgRating, &ratingCt,
&difyType, &welcome, &suggestedPrompts,
&appConfig,
&publishedAt, &version)
if err != nil {
response.NotFound(w, "应用不存在")
return
}
// Check if user favorited this app
isFavorited := false
userID := middleware.GetUserID(r.Context())
if userID.String() != "00000000-0000-0000-0000-000000000000" {
_ = h.pool.QueryRow(r.Context(),
`SELECT EXISTS(SELECT 1 FROM app_favorites WHERE user_id = $1 AND app_id = $2)`,
userID, id).Scan(&isFavorited)
}
// Parse app_config as JSON if present
var configData any
if appConfig != nil {
_ = json.Unmarshal([]byte(*appConfig), &configData)
}
response.JSON(w, http.StatusOK, map[string]any{
"id": id, "name": name, "slug": appSlug,
"description": desc, "long_description": longDesc, "icon_url": iconURL,
"category_name": catName, "category_slug": catSlug,
"creator_name": creatorName,
"usage_count": usageCount, "favorite_count": favCount,
"avg_rating": avgRating, "rating_count": ratingCt,
"dify_app_type": difyType, "welcome_message": welcome,
"suggested_prompts": suggestedPrompts,
"app_config": configData,
"published_at": publishedAt, "version": version,
"is_favorited": isFavorited,
})
}
func (h *StoreHandler) Featured(w http.ResponseWriter, r *http.Request) {
orgID := r.URL.Query().Get("org_id")
query := `
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
c.name as category_name, c.slug as category_slug,
a.usage_count, a.avg_rating, a.rating_count, a.dify_app_type
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
WHERE a.is_featured = true AND a.status = 'approved' AND a.visibility = 'public'`
var args []any
if orgID != "" {
query += ` AND (a.org_id = $1 OR a.org_id IS NULL)`
args = append(args, orgID)
}
query += ` ORDER BY a.usage_count DESC LIMIT 4`
rows, err := h.pool.Query(r.Context(), query, args...)
if err != nil {
response.InternalError(w, "查询精选应用失败")
return
}
defer rows.Close()
apps := scanAppList(rows)
response.JSON(w, http.StatusOK, apps)
}
func (h *StoreHandler) Rankings(w http.ResponseWriter, r *http.Request) {
orgID := r.URL.Query().Get("org_id")
query := `
SELECT a.id, a.name, a.slug, a.description, a.icon_url,
c.name as category_name, c.slug as category_slug,
a.usage_count, a.avg_rating, a.rating_count, a.dify_app_type
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
WHERE a.status = 'approved' AND a.visibility = 'public'`
var args []any
if orgID != "" {
query += ` AND (a.org_id = $1 OR a.org_id IS NULL)`
args = append(args, orgID)
}
query += ` ORDER BY a.usage_count DESC LIMIT 50`
rows, err := h.pool.Query(r.Context(), query, args...)
if err != nil {
response.InternalError(w, "查询排行榜失败")
return
}
defer rows.Close()
apps := scanAppList(rows)
response.JSON(w, http.StatusOK, apps)
}
func (h *StoreHandler) Recent(w http.ResponseWriter, r *http.Request) {
userID := middleware.GetUserID(r.Context())
rows, err := h.pool.Query(r.Context(), `
SELECT DISTINCT ON (a.id) a.id, a.name, a.slug, a.description, a.icon_url,
c.name as category_name, c.slug as category_slug,
a.usage_count, a.avg_rating, a.rating_count, a.dify_app_type
FROM app_usage_logs l
JOIN applications a ON l.app_id = a.id
LEFT JOIN categories c ON a.category_id = c.id
WHERE l.user_id = $1
ORDER BY a.id, l.created_at DESC
LIMIT 10`, userID)
if err != nil {
response.InternalError(w, "查询最近使用失败")
return
}
defer rows.Close()
apps := scanAppList(rows)
response.JSON(w, http.StatusOK, apps)
}
func scanAppList(rows interface {
Next() bool
Scan(dest ...any) error
}) []map[string]any {
var apps []map[string]any
for rows.Next() {
var (
id, name, slug string
desc, iconURL *string
catName, catSlug *string
usageCount int64
avgRating float32
ratingCt int
difyType *string
)
if err := rows.Scan(&id, &name, &slug, &desc, &iconURL,
&catName, &catSlug, &usageCount, &avgRating, &ratingCt, &difyType); err != nil {
continue
}
apps = append(apps, map[string]any{
"id": id, "name": name, "slug": slug,
"description": desc, "icon_url": iconURL,
"category_name": catName, "category_slug": catSlug,
"usage_count": usageCount, "avg_rating": avgRating, "rating_count": ratingCt,
"dify_app_type": difyType,
})
}
if apps == nil {
apps = []map[string]any{}
}
return apps
}
+65
View File
@@ -0,0 +1,65 @@
package middleware
import (
"context"
"encoding/json"
"net"
"net/http"
"github.com/jackc/pgx/v5/pgxpool"
)
// extractIP returns a clean IP without port. Falls back to "" so the INET
// column can take NULL via $5 when the value is empty.
func extractIP(r *http.Request) any {
addr := r.RemoteAddr
if host, _, err := net.SplitHostPort(addr); err == nil {
addr = host
}
if addr == "" {
return nil
}
return addr
}
// AuditLog records API access to the audit_logs table for important operations.
func AuditLog(pool *pgxpool.Pool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
// Only audit write operations
if r.Method == "GET" || r.Method == "OPTIONS" {
return
}
userID := GetUserID(r.Context())
if userID.String() == "00000000-0000-0000-0000-000000000000" {
return
}
details, _ := json.Marshal(map[string]string{
"method": r.Method,
"path": r.URL.Path,
})
ip := extractIP(r)
ua := r.UserAgent()
method := r.Method
path := r.URL.Path
go func() {
_, _ = pool.Exec(context.Background(),
`INSERT INTO audit_logs (user_id, action, resource_type, resource_id, details, ip_address, user_agent)
VALUES ($1, $2, $3, NULL, $4, $5, $6)`,
userID,
method+"."+path,
"api",
details,
ip,
ua,
)
}()
})
}
}
+76
View File
@@ -0,0 +1,76 @@
package middleware
import (
"context"
"net/http"
"strings"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/enterprise-ai-platform/server/pkg/auth"
"github.com/google/uuid"
)
type contextKey string
const (
UserIDKey contextKey = "user_id"
EmailKey contextKey = "email"
RoleKey contextKey = "role"
)
func GetUserID(ctx context.Context) uuid.UUID {
v, _ := ctx.Value(UserIDKey).(uuid.UUID)
return v
}
func GetRole(ctx context.Context) string {
v, _ := ctx.Value(RoleKey).(string)
return v
}
func GetEmail(ctx context.Context) string {
v, _ := ctx.Value(EmailKey).(string)
return v
}
// Auth creates a middleware that validates JWT and injects user info into context.
func Auth(jwtMgr *auth.JWTManager) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenStr := extractToken(r)
if tokenStr == "" {
response.Unauthorized(w, "未登录")
return
}
claims, err := jwtMgr.ValidateToken(tokenStr)
if err != nil {
response.Error(w, http.StatusUnauthorized, 40102, "Token 已过期或无效")
return
}
ctx := r.Context()
ctx = context.WithValue(ctx, UserIDKey, claims.UserID)
ctx = context.WithValue(ctx, EmailKey, claims.Email)
ctx = context.WithValue(ctx, RoleKey, claims.Role)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
func extractToken(r *http.Request) string {
// Try Authorization header first
bearer := r.Header.Get("Authorization")
if strings.HasPrefix(bearer, "Bearer ") {
return strings.TrimPrefix(bearer, "Bearer ")
}
// Then try cookie
cookie, err := r.Cookie("access_token")
if err == nil {
return cookie.Value
}
return ""
}
+39
View File
@@ -0,0 +1,39 @@
package middleware
import (
"context"
"fmt"
"net/http"
"time"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/redis/go-redis/v9"
)
// RateLimit creates a per-user rate limiter using Redis sliding window.
func RateLimit(rdb *redis.Client, maxRequests int, window time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userID := GetUserID(r.Context())
key := fmt.Sprintf("rl:%s:%s", userID.String(), r.URL.Path)
ctx := context.Background()
count, err := rdb.Incr(ctx, key).Result()
if err != nil {
next.ServeHTTP(w, r)
return
}
if count == 1 {
rdb.Expire(ctx, key, window)
}
if count > int64(maxRequests) {
response.TooManyRequests(w, "请求过于频繁,请稍后再试")
return
}
next.ServeHTTP(w, r)
})
}
}
+41
View File
@@ -0,0 +1,41 @@
package middleware
import (
"net/http"
"github.com/enterprise-ai-platform/server/internal/response"
)
var roleLevel = map[string]int{
"user": 0,
"creator": 1,
"admin": 2,
"super_admin": 3,
}
// RequireRole returns middleware that checks if user has the minimum required role.
func RequireRole(minRole string) func(http.Handler) http.Handler {
minLevel := roleLevel[minRole]
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
role := GetRole(r.Context())
if roleLevel[role] < minLevel {
response.Forbidden(w, "权限不足")
return
}
next.ServeHTTP(w, r)
})
}
}
// RequireSuperAdmin restricts access to platform-level (super_admin) operations only.
// Unlike RequireRole("admin")super admin 不受机构(org_id)限制,可执行跨机构操作。
func RequireSuperAdmin(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if GetRole(r.Context()) != "super_admin" {
response.Forbidden(w, "仅平台管理员可访问")
return
}
next.ServeHTTP(w, r)
})
}
+56
View File
@@ -0,0 +1,56 @@
package response
import (
"encoding/json"
"net/http"
)
type APIResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data"`
}
func JSON(w http.ResponseWriter, status int, data any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(APIResponse{
Code: 0,
Message: "success",
Data: data,
})
}
func Error(w http.ResponseWriter, status int, code int, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(APIResponse{
Code: code,
Message: message,
Data: nil,
})
}
func BadRequest(w http.ResponseWriter, message string) {
Error(w, http.StatusBadRequest, 40001, message)
}
func Unauthorized(w http.ResponseWriter, message string) {
Error(w, http.StatusUnauthorized, 40101, message)
}
func Forbidden(w http.ResponseWriter, message string) {
Error(w, http.StatusForbidden, 40301, message)
}
func NotFound(w http.ResponseWriter, message string) {
Error(w, http.StatusNotFound, 40401, message)
}
func InternalError(w http.ResponseWriter, message string) {
Error(w, http.StatusInternalServerError, 50001, message)
}
func TooManyRequests(w http.ResponseWriter, message string) {
Error(w, http.StatusTooManyRequests, 42901, message)
}