Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,65 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// extractIP returns a clean IP without port. Falls back to "" so the INET
|
||||
// column can take NULL via $5 when the value is empty.
|
||||
func extractIP(r *http.Request) any {
|
||||
addr := r.RemoteAddr
|
||||
if host, _, err := net.SplitHostPort(addr); err == nil {
|
||||
addr = host
|
||||
}
|
||||
if addr == "" {
|
||||
return nil
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// AuditLog records API access to the audit_logs table for important operations.
|
||||
func AuditLog(pool *pgxpool.Pool) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
// Only audit write operations
|
||||
if r.Method == "GET" || r.Method == "OPTIONS" {
|
||||
return
|
||||
}
|
||||
|
||||
userID := GetUserID(r.Context())
|
||||
if userID.String() == "00000000-0000-0000-0000-000000000000" {
|
||||
return
|
||||
}
|
||||
|
||||
details, _ := json.Marshal(map[string]string{
|
||||
"method": r.Method,
|
||||
"path": r.URL.Path,
|
||||
})
|
||||
|
||||
ip := extractIP(r)
|
||||
ua := r.UserAgent()
|
||||
method := r.Method
|
||||
path := r.URL.Path
|
||||
|
||||
go func() {
|
||||
_, _ = pool.Exec(context.Background(),
|
||||
`INSERT INTO audit_logs (user_id, action, resource_type, resource_id, details, ip_address, user_agent)
|
||||
VALUES ($1, $2, $3, NULL, $4, $5, $6)`,
|
||||
userID,
|
||||
method+"."+path,
|
||||
"api",
|
||||
details,
|
||||
ip,
|
||||
ua,
|
||||
)
|
||||
}()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/enterprise-ai-platform/server/pkg/auth"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
UserIDKey contextKey = "user_id"
|
||||
EmailKey contextKey = "email"
|
||||
RoleKey contextKey = "role"
|
||||
)
|
||||
|
||||
func GetUserID(ctx context.Context) uuid.UUID {
|
||||
v, _ := ctx.Value(UserIDKey).(uuid.UUID)
|
||||
return v
|
||||
}
|
||||
|
||||
func GetRole(ctx context.Context) string {
|
||||
v, _ := ctx.Value(RoleKey).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
func GetEmail(ctx context.Context) string {
|
||||
v, _ := ctx.Value(EmailKey).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// Auth creates a middleware that validates JWT and injects user info into context.
|
||||
func Auth(jwtMgr *auth.JWTManager) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tokenStr := extractToken(r)
|
||||
if tokenStr == "" {
|
||||
response.Unauthorized(w, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := jwtMgr.ValidateToken(tokenStr)
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusUnauthorized, 40102, "Token 已过期或无效")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, UserIDKey, claims.UserID)
|
||||
ctx = context.WithValue(ctx, EmailKey, claims.Email)
|
||||
ctx = context.WithValue(ctx, RoleKey, claims.Role)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func extractToken(r *http.Request) string {
|
||||
// Try Authorization header first
|
||||
bearer := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(bearer, "Bearer ") {
|
||||
return strings.TrimPrefix(bearer, "Bearer ")
|
||||
}
|
||||
|
||||
// Then try cookie
|
||||
cookie, err := r.Cookie("access_token")
|
||||
if err == nil {
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RateLimit creates a per-user rate limiter using Redis sliding window.
|
||||
func RateLimit(rdb *redis.Client, maxRequests int, window time.Duration) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
userID := GetUserID(r.Context())
|
||||
key := fmt.Sprintf("rl:%s:%s", userID.String(), r.URL.Path)
|
||||
|
||||
ctx := context.Background()
|
||||
count, err := rdb.Incr(ctx, key).Result()
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if count == 1 {
|
||||
rdb.Expire(ctx, key, window)
|
||||
}
|
||||
|
||||
if count > int64(maxRequests) {
|
||||
response.TooManyRequests(w, "请求过于频繁,请稍后再试")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
)
|
||||
|
||||
var roleLevel = map[string]int{
|
||||
"user": 0,
|
||||
"creator": 1,
|
||||
"admin": 2,
|
||||
"super_admin": 3,
|
||||
}
|
||||
|
||||
// RequireRole returns middleware that checks if user has the minimum required role.
|
||||
func RequireRole(minRole string) func(http.Handler) http.Handler {
|
||||
minLevel := roleLevel[minRole]
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
role := GetRole(r.Context())
|
||||
if roleLevel[role] < minLevel {
|
||||
response.Forbidden(w, "权限不足")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSuperAdmin restricts access to platform-level (super_admin) operations only.
|
||||
// Unlike RequireRole("admin"),super admin 不受机构(org_id)限制,可执行跨机构操作。
|
||||
func RequireSuperAdmin(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if GetRole(r.Context()) != "super_admin" {
|
||||
response.Forbidden(w, "仅平台管理员可访问")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user