Files
GovAI/server/internal/handler/auth.go
T
2026-06-15 23:48:37 +08:00

419 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
	"context"
	"encoding/json"
	"net/http"
	"time"
	"unicode/utf8"

	"github.com/enterprise-ai-platform/server/internal/middleware"
	"github.com/enterprise-ai-platform/server/internal/response"
	"github.com/enterprise-ai-platform/server/pkg/auth"
	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"
)
// AuthHandler serves the authentication endpoints (Register, Login, Logout,
// Refresh, Me, UpdateProfile) and the organization endpoints
// (ListOrganizations, SwitchOrg).
type AuthHandler struct {
// pool is the database pool used for every user and organization query.
pool *pgxpool.Pool
// jwtMgr generates access/refresh token pairs and validates refresh tokens.
jwtMgr *auth.JWTManager
}
// NewAuthHandler returns an AuthHandler that uses pool for database access and
// jwtMgr for issuing and validating tokens.
func NewAuthHandler(pool *pgxpool.Pool, jwtMgr *auth.JWTManager) *AuthHandler {
	h := new(AuthHandler)
	h.pool = pool
	h.jwtMgr = jwtMgr
	return h
}
// loginRequest is the JSON body accepted by Login. Email and Password are
// required.
type loginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
// OrgID is the organization entry point the client selected. When empty,
// Login skips the organization membership check.
OrgID string `json:"org_id"`
}
// orgInfo is the organization summary embedded in user responses; it is
// populated by loadOrgInfo.
type orgInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
// ShortName is "" when the organization has no short name.
ShortName string `json:"short_name"`
}
// userResponse is the user representation returned by Register, Login and Me.
// Pointer fields hold nullable columns and encode as JSON null when unset.
type userResponse struct {
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL *string `json:"avatar_url"`
Role string `json:"role"`
EmployeeID *string `json:"employee_id"`
OrgID *string `json:"org_id"`
// Org is omitted from the JSON when the user has no organization or it
// could not be loaded.
Org *orgInfo `json:"org,omitempty"`
}
// registerRequest is the JSON body accepted by Register; all three fields
// are required.
type registerRequest struct {
Name string `json:"name"`
Email string `json:"email"`
Password string `json:"password"`
}
// Register creates a new active account with role "user" and signs the new
// user in.
//
// The request must carry a non-empty name, email and password, and the
// password must be at least 6 characters long. An already-registered email
// yields 409 (code 40901). On success it sets the access_token cookie
// (path "/", 24h) and, like Login, the refresh_token cookie
// (path "/api/v1/auth/refresh", 7 days), then responds 201 with the user,
// the access token and its expiry.
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
	var req registerRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}
	if req.Name == "" || req.Email == "" || req.Password == "" {
		response.BadRequest(w, "姓名、邮箱和密码不能为空")
		return
	}
	// Count characters, not bytes: len() would let a two-character CJK
	// password (6 bytes in UTF-8) satisfy the "at least 6" rule.
	if utf8.RuneCountInString(req.Password) < 6 {
		response.BadRequest(w, "密码长度不能少于6位")
		return
	}

	var exists bool
	if err := h.pool.QueryRow(r.Context(),
		`SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)`, req.Email,
	).Scan(&exists); err != nil {
		// A failed lookup must not be mistaken for "email is free".
		response.InternalError(w, "注册失败")
		return
	}
	if exists {
		response.Error(w, http.StatusConflict, 40901, "该邮箱已注册")
		return
	}
	// NOTE(review): two concurrent registrations can both pass the check above;
	// the losing INSERT then fails (presumably on a unique index on email —
	// confirm) and is reported as a 500 rather than 409.

	hash, err := auth.HashPassword(req.Password)
	if err != nil {
		response.InternalError(w, "密码加密失败")
		return
	}
	id := uuid.New()
	_, err = h.pool.Exec(r.Context(),
		`INSERT INTO users (id, name, email, password_hash, role, status) VALUES ($1, $2, $3, $4, 'user', 'active')`,
		id, req.Name, req.Email, hash)
	if err != nil {
		response.InternalError(w, "注册失败")
		return
	}

	tokenPair, err := h.jwtMgr.GenerateTokenPair(id, req.Email, "user")
	if err != nil {
		response.InternalError(w, "生成Token失败")
		return
	}
	http.SetCookie(w, &http.Cookie{
		Name:     "access_token",
		Value:    tokenPair.AccessToken,
		Path:     "/",
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(24 * time.Hour / time.Second),
	})
	// Issue the refresh cookie exactly as Login does; without it Refresh
	// (which reads this cookie) could never renew a freshly registered session.
	http.SetCookie(w, &http.Cookie{
		Name:     "refresh_token",
		Value:    tokenPair.RefreshToken,
		Path:     "/api/v1/auth/refresh",
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(7 * 24 * time.Hour / time.Second),
	})
	response.JSON(w, http.StatusCreated, map[string]any{
		"user": userResponse{
			ID: id.String(), Name: req.Name, Email: req.Email, Role: "user",
		},
		"access_token": tokenPair.AccessToken,
		"expires_at":   tokenPair.ExpiresAt,
	})
}
// Login authenticates a user by email and password.
//
// Unknown emails and wrong passwords both yield the same 401, and a disabled
// account (status != "active") yields 403 only after the correct password
// was supplied. Users other than super_admin must belong to req.OrgID when
// it is set. On success it sets the access_token cookie (path "/", 24h) and
// the refresh_token cookie (path "/api/v1/auth/refresh", 7 days), records
// the login in the background, and responds with the user (plus their
// organization summary when org_id is set), the access token and its expiry.
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
	var req loginRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}
	if req.Email == "" || req.Password == "" {
		response.BadRequest(w, "邮箱和密码不能为空")
		return
	}

	var (
		id           string
		name         string
		email        string
		passwordHash *string
		avatarURL    *string
		role         string
		employeeID   *string
		status       string
		orgID        *string
	)
	err := h.pool.QueryRow(r.Context(),
		`SELECT id, name, email, password_hash, avatar_url, role, employee_id, status, org_id::text
FROM users WHERE email = $1`, req.Email,
	).Scan(&id, &name, &email, &passwordHash, &avatarURL, &role, &employeeID, &status, &orgID)
	if err != nil {
		// NOTE(review): every query error, not only "no rows", is reported as
		// bad credentials, so a database outage looks like a failed login.
		response.Unauthorized(w, "邮箱或密码错误")
		return
	}
	// Verify the password before revealing anything about the account:
	// checking status first would tell anyone who knows an email that the
	// account exists and is disabled, without needing the password.
	if passwordHash == nil || !auth.CheckPassword(req.Password, *passwordHash) {
		response.Unauthorized(w, "邮箱或密码错误")
		return
	}
	if status != "active" {
		response.Error(w, http.StatusForbidden, 40302, "账号已被禁用")
		return
	}
	// Platform administrators (super_admin) are not bound to an organization
	// and may sign in through any organization entry point; other users must
	// belong to the selected organization.
	// NOTE(review): a non-super_admin whose org_id is NULL also passes this
	// check for any selected organization — confirm that is intended.
	if role != "super_admin" && req.OrgID != "" && orgID != nil && *orgID != req.OrgID {
		response.Unauthorized(w, "该账号不属于所选机构,请选择正确的机构")
		return
	}

	uid, err := uuid.Parse(id)
	if err != nil {
		// Ignoring this error would sign a token for the zero UUID.
		response.InternalError(w, "生成Token失败")
		return
	}
	tokenPair, err := h.jwtMgr.GenerateTokenPair(uid, email, role)
	if err != nil {
		response.InternalError(w, "生成Token失败")
		return
	}

	// Record the login without delaying the response. The update is
	// best-effort (its error is deliberately ignored); it runs on a detached
	// context so it survives the request, bounded so it cannot hang forever.
	go func(userID string) {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_, _ = h.pool.Exec(ctx,
			`UPDATE users SET last_login_at = NOW(), login_count = login_count + 1 WHERE id = $1`, userID)
	}(id)

	http.SetCookie(w, &http.Cookie{
		Name:     "access_token",
		Value:    tokenPair.AccessToken,
		Path:     "/",
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(24 * time.Hour / time.Second),
	})
	http.SetCookie(w, &http.Cookie{
		Name:     "refresh_token",
		Value:    tokenPair.RefreshToken,
		Path:     "/api/v1/auth/refresh",
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(7 * 24 * time.Hour / time.Second),
	})

	usr := userResponse{
		ID: id, Name: name, Email: email,
		AvatarURL: avatarURL, Role: role, EmployeeID: employeeID,
		OrgID: orgID,
	}
	if orgID != nil {
		usr.Org = h.loadOrgInfo(r.Context(), *orgID)
	}
	response.JSON(w, http.StatusOK, map[string]any{
		"user":         usr,
		"access_token": tokenPair.AccessToken,
		"expires_at":   tokenPair.ExpiresAt,
	})
}
// Logout clears both authentication cookies by overwriting them with empty,
// immediately expiring (MaxAge -1) HttpOnly cookies on the same paths they
// were issued on, then responds 200 with a nil payload.
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
	expired := []struct{ name, path string }{
		{name: "access_token", path: "/"},
		{name: "refresh_token", path: "/api/v1/auth/refresh"},
	}
	for _, c := range expired {
		http.SetCookie(w, &http.Cookie{
			Name:     c.name,
			Value:    "",
			Path:     c.path,
			HttpOnly: true,
			MaxAge:   -1,
		})
	}
	response.JSON(w, http.StatusOK, nil)
}
// Me returns the authenticated user's profile, including their organization
// summary when org_id is set. Any lookup failure is reported as 404.
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	var me userResponse
	row := h.pool.QueryRow(ctx,
		`SELECT id, name, email, avatar_url, role, employee_id, org_id::text
FROM users WHERE id = $1`, middleware.GetUserID(ctx),
	)
	if err := row.Scan(&me.ID, &me.Name, &me.Email, &me.AvatarURL, &me.Role, &me.EmployeeID, &me.OrgID); err != nil {
		response.NotFound(w, "用户不存在")
		return
	}
	if me.OrgID != nil {
		me.Org = h.loadOrgInfo(ctx, *me.OrgID)
	}
	response.JSON(w, http.StatusOK, me)
}
// ListOrganizations returns all active organizations ordered by sort_order.
// Nullable text columns are reported as "". The response is always a JSON
// array (never null). A query or iteration failure yields 500.
func (h *AuthHandler) ListOrganizations(w http.ResponseWriter, r *http.Request) {
	rows, err := h.pool.Query(r.Context(),
		`SELECT id, name, slug, COALESCE(short_name,''), COALESCE(description,''), COALESCE(logo_url,'')
FROM organizations WHERE is_active = true ORDER BY sort_order ASC`)
	if err != nil {
		response.InternalError(w, "查询机构列表失败")
		return
	}
	defer rows.Close()

	// Start non-nil so an empty result encodes as [] rather than null.
	orgs := []map[string]any{}
	for rows.Next() {
		var id, name, slug, shortName, desc, logo string
		if err := rows.Scan(&id, &name, &slug, &shortName, &desc, &logo); err != nil {
			// Best effort: skip a row that cannot be scanned rather than
			// failing the whole list.
			continue
		}
		orgs = append(orgs, map[string]any{
			"id": id, "name": name, "slug": slug,
			"short_name": shortName, "description": desc, "logo_url": logo,
		})
	}
	// rows.Next() returns false on errors too; without this check a failure
	// mid-iteration would be served as a silently truncated list.
	if err := rows.Err(); err != nil {
		response.InternalError(w, "查询机构列表失败")
		return
	}
	response.JSON(w, http.StatusOK, orgs)
}
// SwitchOrg switches the caller into another organization. Only the roles
// "admin" and "super_admin" may use it.
//
// It picks one active user of the target organization, preferring role
// admin, then super_admin, then creator, then any other role, and mints a
// token pair for that user. Only the access token is returned (under the
// "token" key) together with the organization summary and the chosen user;
// no cookies are set.
//
// NOTE(security): any caller with role "admin" can switch into ANY active
// organization and receive a token for that organization's user, and the
// fallback ordering may select a plain user. Confirm this cross-tenant
// access is intended; if organization admins should stay within their own
// organization, restrict it here.
func (h *AuthHandler) SwitchOrg(w http.ResponseWriter, r *http.Request) {
	currentRole := middleware.GetRole(r.Context())
	if currentRole != "super_admin" && currentRole != "admin" {
		response.Forbidden(w, "仅管理员可以切换机构")
		return
	}
	var req struct {
		OrgID string `json:"org_id"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.OrgID == "" {
		response.BadRequest(w, "请选择机构")
		return
	}
	// Reject malformed ids up front (organization ids are assumed to be UUIDs)
	// so that a real database error below can be reported as such instead of
	// being mistaken for "organization not found".
	if _, err := uuid.Parse(req.OrgID); err != nil {
		response.NotFound(w, "机构不存在")
		return
	}

	// The target organization must exist and be active.
	var exists bool
	if err := h.pool.QueryRow(r.Context(),
		`SELECT EXISTS(SELECT 1 FROM organizations WHERE id = $1 AND is_active = true)`, req.OrgID,
	).Scan(&exists); err != nil {
		response.InternalError(w, "查询机构失败")
		return
	}
	if !exists {
		response.NotFound(w, "机构不存在")
		return
	}

	// Pick the target organization's user: admin first, then super_admin,
	// then creator, then anyone else.
	var targetID uuid.UUID
	var targetEmail, targetRole, targetName string
	err := h.pool.QueryRow(r.Context(),
		`SELECT id, email, role, name FROM users
WHERE org_id = $1 AND status = 'active'
ORDER BY CASE role WHEN 'admin' THEN 1 WHEN 'super_admin' THEN 2 WHEN 'creator' THEN 3 ELSE 4 END
LIMIT 1`, req.OrgID).Scan(&targetID, &targetEmail, &targetRole, &targetName)
	if err != nil {
		response.InternalError(w, "该机构暂无可用用户")
		return
	}

	tokens, err := h.jwtMgr.GenerateTokenPair(targetID, targetEmail, targetRole)
	if err != nil {
		response.InternalError(w, "生成令牌失败")
		return
	}
	org := h.loadOrgInfo(r.Context(), req.OrgID)
	response.JSON(w, http.StatusOK, map[string]any{
		"message": "已切换",
		"org":     org,
		"token":   tokens.AccessToken,
		"user": map[string]any{
			"id":     targetID,
			"name":   targetName,
			"email":  targetEmail,
			"role":   targetRole,
			"org_id": req.OrgID,
		},
	})
}
// loadOrgInfo fetches the summary of organization orgID, reporting a NULL
// short_name as "". It returns nil if the lookup fails for any reason,
// including an unknown id.
func (h *AuthHandler) loadOrgInfo(ctx context.Context, orgID string) *orgInfo {
	info := new(orgInfo)
	row := h.pool.QueryRow(ctx,
		`SELECT id, name, slug, COALESCE(short_name,'') FROM organizations WHERE id = $1`, orgID,
	)
	if row.Scan(&info.ID, &info.Name, &info.Slug, &info.ShortName) != nil {
		return nil
	}
	return info
}
// UpdateProfile updates the authenticated user's name and/or avatar URL.
//
// An empty field leaves the stored value unchanged (NULLIF/COALESCE in the
// UPDATE), so this endpoint cannot clear a value. It responds 404 when the
// authenticated user no longer exists, matching Me.
func (h *AuthHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) {
	userID := middleware.GetUserID(r.Context())
	var req struct {
		Name      string `json:"name"`
		AvatarURL string `json:"avatar_url"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}
	tag, err := h.pool.Exec(r.Context(),
		`UPDATE users SET name = COALESCE(NULLIF($2,''), name),
avatar_url = COALESCE(NULLIF($3,''), avatar_url),
updated_at = NOW()
WHERE id = $1`, userID, req.Name, req.AvatarURL)
	if err != nil {
		response.InternalError(w, "更新个人信息失败")
		return
	}
	// An UPDATE matching no row still succeeds at the SQL level; without this
	// check a deleted account holding a valid token would be told "updated".
	if tag.RowsAffected() == 0 {
		response.NotFound(w, "用户不存在")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
// Refresh issues a new access token from the refresh_token cookie.
//
// The user's role and email are re-read from the database, so the new token
// reflects the current account state, and users that are missing or not
// active are rejected with 401. Only the access_token cookie (path "/", 24h)
// is reissued; the refresh_token cookie is left unchanged, so the session
// still ends when the original refresh token expires.
func (h *AuthHandler) Refresh(w http.ResponseWriter, r *http.Request) {
	cookie, err := r.Cookie("refresh_token")
	if err != nil {
		response.Unauthorized(w, "Refresh Token 不存在")
		return
	}
	// NOTE(review): confirm ValidateToken accepts only refresh tokens; if it
	// also accepts access tokens, either kind placed in this cookie works.
	claims, err := h.jwtMgr.ValidateToken(cookie.Value)
	if err != nil {
		response.Error(w, http.StatusUnauthorized, 40102, "Refresh Token 已过期")
		return
	}

	// Take role and email from the database rather than from the (possibly
	// stale) claims, so the new token matches the current account.
	var role, email string
	err = h.pool.QueryRow(r.Context(),
		`SELECT role, email FROM users WHERE id = $1 AND status = 'active'`, claims.UserID,
	).Scan(&role, &email)
	if err != nil {
		response.Unauthorized(w, "用户不存在或已被禁用")
		return
	}

	tokenPair, err := h.jwtMgr.GenerateTokenPair(claims.UserID, email, role)
	if err != nil {
		response.InternalError(w, "生成Token失败")
		return
	}
	http.SetCookie(w, &http.Cookie{
		Name:     "access_token",
		Value:    tokenPair.AccessToken,
		Path:     "/",
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(24 * time.Hour / time.Second),
	})
	response.JSON(w, http.StatusOK, map[string]any{
		"access_token": tokenPair.AccessToken,
		"expires_at":   tokenPair.ExpiresAt,
	})
}