77 lines
1.7 KiB
Go
77 lines
1.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/enterprise-ai-platform/server/internal/response"
|
|
"github.com/enterprise-ai-platform/server/pkg/auth"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// contextKey is the unexported type used for the request-context keys defined
// in this package. Using a dedicated type keeps these keys from colliding with
// context keys of other types.
type contextKey string
|
|
|
|
const (
|
|
UserIDKey contextKey = "user_id"
|
|
EmailKey contextKey = "email"
|
|
RoleKey contextKey = "role"
|
|
)
|
|
|
|
func GetUserID(ctx context.Context) uuid.UUID {
|
|
v, _ := ctx.Value(UserIDKey).(uuid.UUID)
|
|
return v
|
|
}
|
|
|
|
func GetRole(ctx context.Context) string {
|
|
v, _ := ctx.Value(RoleKey).(string)
|
|
return v
|
|
}
|
|
|
|
func GetEmail(ctx context.Context) string {
|
|
v, _ := ctx.Value(EmailKey).(string)
|
|
return v
|
|
}
|
|
|
|
// Auth creates a middleware that validates JWT and injects user info into context.
|
|
func Auth(jwtMgr *auth.JWTManager) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tokenStr := extractToken(r)
|
|
if tokenStr == "" {
|
|
response.Unauthorized(w, "未登录")
|
|
return
|
|
}
|
|
|
|
claims, err := jwtMgr.ValidateToken(tokenStr)
|
|
if err != nil {
|
|
response.Error(w, http.StatusUnauthorized, 40102, "Token 已过期或无效")
|
|
return
|
|
}
|
|
|
|
ctx := r.Context()
|
|
ctx = context.WithValue(ctx, UserIDKey, claims.UserID)
|
|
ctx = context.WithValue(ctx, EmailKey, claims.Email)
|
|
ctx = context.WithValue(ctx, RoleKey, claims.Role)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// extractToken returns the access token carried by r, or "" when none is
// present.
//
// Sources are tried in order:
//  1. the Authorization header, when it uses the Bearer scheme and carries a
//     non-empty token;
//  2. the "access_token" cookie.
//
// The scheme name is compared case-insensitively (HTTP auth schemes are
// case-insensitive), and whitespace around the token is trimmed. A Bearer
// header with an empty token does not short-circuit the cookie fallback.
func extractToken(r *http.Request) string {
	const scheme = "Bearer "

	header := r.Header.Get("Authorization")
	if len(header) >= len(scheme) && strings.EqualFold(header[:len(scheme)], scheme) {
		if token := strings.TrimSpace(header[len(scheme):]); token != "" {
			return token
		}
	}

	// No usable bearer token in the header: fall back to the cookie.
	if cookie, err := r.Cookie("access_token"); err == nil {
		return cookie.Value
	}

	return ""
}
|