419 lines
11 KiB
Go
419 lines
11 KiB
Go
package handler
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"net/http"
|
||
"time"
|
||
|
||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||
"github.com/enterprise-ai-platform/server/internal/response"
|
||
"github.com/enterprise-ai-platform/server/pkg/auth"
|
||
"github.com/google/uuid"
|
||
"github.com/jackc/pgx/v5/pgxpool"
|
||
)
|
||
|
||
type AuthHandler struct {
|
||
pool *pgxpool.Pool
|
||
jwtMgr *auth.JWTManager
|
||
}
|
||
|
||
func NewAuthHandler(pool *pgxpool.Pool, jwtMgr *auth.JWTManager) *AuthHandler {
|
||
return &AuthHandler{pool: pool, jwtMgr: jwtMgr}
|
||
}
|
||
|
||
// loginRequest is the JSON body accepted by Login.
type loginRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
	// OrgID is the organization chosen on the login page. It may be empty;
	// when set, Login compares it with the account's own org_id.
	OrgID string `json:"org_id"`
}
|
||
|
||
// orgInfo is the compact organization summary embedded in user payloads and
// returned by SwitchOrg.
type orgInfo struct {
	ID   string `json:"id"`
	Name string `json:"name"`
	Slug string `json:"slug"`
	// ShortName is "" when the database column is NULL (the loader COALESCEs it).
	ShortName string `json:"short_name"`
}
|
||
|
||
type userResponse struct {
|
||
ID string `json:"id"`
|
||
Name string `json:"name"`
|
||
Email string `json:"email"`
|
||
AvatarURL *string `json:"avatar_url"`
|
||
Role string `json:"role"`
|
||
EmployeeID *string `json:"employee_id"`
|
||
OrgID *string `json:"org_id"`
|
||
Org *orgInfo `json:"org,omitempty"`
|
||
}
|
||
|
||
// registerRequest is the JSON body accepted by Register. All three fields are
// required by the handler.
type registerRequest struct {
	Name     string `json:"name"`
	Email    string `json:"email"`
	Password string `json:"password"`
}
|
||
|
||
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||
var req registerRequest
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
response.BadRequest(w, "无效的请求格式")
|
||
return
|
||
}
|
||
if req.Name == "" || req.Email == "" || req.Password == "" {
|
||
response.BadRequest(w, "姓名、邮箱和密码不能为空")
|
||
return
|
||
}
|
||
if len(req.Password) < 6 {
|
||
response.BadRequest(w, "密码长度不能少于6位")
|
||
return
|
||
}
|
||
|
||
var exists bool
|
||
h.pool.QueryRow(r.Context(), `SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)`, req.Email).Scan(&exists)
|
||
if exists {
|
||
response.Error(w, http.StatusConflict, 40901, "该邮箱已注册")
|
||
return
|
||
}
|
||
|
||
hash, err := auth.HashPassword(req.Password)
|
||
if err != nil {
|
||
response.InternalError(w, "密码加密失败")
|
||
return
|
||
}
|
||
|
||
id := uuid.New()
|
||
_, err = h.pool.Exec(r.Context(),
|
||
`INSERT INTO users (id, name, email, password_hash, role, status) VALUES ($1, $2, $3, $4, 'user', 'active')`,
|
||
id, req.Name, req.Email, hash)
|
||
if err != nil {
|
||
response.InternalError(w, "注册失败")
|
||
return
|
||
}
|
||
|
||
tokenPair, err := h.jwtMgr.GenerateTokenPair(id, req.Email, "user")
|
||
if err != nil {
|
||
response.InternalError(w, "生成Token失败")
|
||
return
|
||
}
|
||
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "access_token",
|
||
Value: tokenPair.AccessToken,
|
||
Path: "/",
|
||
HttpOnly: true,
|
||
SameSite: http.SameSiteLaxMode,
|
||
MaxAge: int(24 * time.Hour / time.Second),
|
||
})
|
||
|
||
response.JSON(w, http.StatusCreated, map[string]any{
|
||
"user": userResponse{
|
||
ID: id.String(), Name: req.Name, Email: req.Email, Role: "user",
|
||
},
|
||
"access_token": tokenPair.AccessToken,
|
||
"expires_at": tokenPair.ExpiresAt,
|
||
})
|
||
}
|
||
|
||
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||
var req loginRequest
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
response.BadRequest(w, "无效的请求格式")
|
||
return
|
||
}
|
||
|
||
if req.Email == "" || req.Password == "" {
|
||
response.BadRequest(w, "邮箱和密码不能为空")
|
||
return
|
||
}
|
||
|
||
var (
|
||
id string
|
||
name string
|
||
email string
|
||
passwordHash *string
|
||
avatarURL *string
|
||
role string
|
||
employeeID *string
|
||
status string
|
||
orgID *string
|
||
)
|
||
|
||
err := h.pool.QueryRow(r.Context(),
|
||
`SELECT id, name, email, password_hash, avatar_url, role, employee_id, status, org_id::text
|
||
FROM users WHERE email = $1`, req.Email,
|
||
).Scan(&id, &name, &email, &passwordHash, &avatarURL, &role, &employeeID, &status, &orgID)
|
||
|
||
if err != nil {
|
||
response.Unauthorized(w, "邮箱或密码错误")
|
||
return
|
||
}
|
||
|
||
if status != "active" {
|
||
response.Error(w, http.StatusForbidden, 40302, "账号已被禁用")
|
||
return
|
||
}
|
||
|
||
if passwordHash == nil || !auth.CheckPassword(req.Password, *passwordHash) {
|
||
response.Unauthorized(w, "邮箱或密码错误")
|
||
return
|
||
}
|
||
|
||
// 平台管理员不绑定机构,可登录任意机构入口
|
||
// 普通用户/机构管理员必须属于所选机构
|
||
if role != "super_admin" && req.OrgID != "" && orgID != nil && *orgID != req.OrgID {
|
||
response.Unauthorized(w, "该账号不属于所选机构,请选择正确的机构")
|
||
return
|
||
}
|
||
|
||
uid, _ := uuid.Parse(id)
|
||
tokenPair, err := h.jwtMgr.GenerateTokenPair(uid, email, role)
|
||
if err != nil {
|
||
response.InternalError(w, "生成Token失败")
|
||
return
|
||
}
|
||
|
||
// Update login info
|
||
go func() {
|
||
_, _ = h.pool.Exec(context.Background(),
|
||
`UPDATE users SET last_login_at = NOW(), login_count = login_count + 1 WHERE id = $1`, id)
|
||
}()
|
||
|
||
// Set cookies
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "access_token",
|
||
Value: tokenPair.AccessToken,
|
||
Path: "/",
|
||
HttpOnly: true,
|
||
SameSite: http.SameSiteLaxMode,
|
||
MaxAge: int(24 * time.Hour / time.Second),
|
||
})
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "refresh_token",
|
||
Value: tokenPair.RefreshToken,
|
||
Path: "/api/v1/auth/refresh",
|
||
HttpOnly: true,
|
||
SameSite: http.SameSiteLaxMode,
|
||
MaxAge: int(7 * 24 * time.Hour / time.Second),
|
||
})
|
||
|
||
usr := userResponse{
|
||
ID: id, Name: name, Email: email,
|
||
AvatarURL: avatarURL, Role: role, EmployeeID: employeeID,
|
||
OrgID: orgID,
|
||
}
|
||
if orgID != nil {
|
||
usr.Org = h.loadOrgInfo(r.Context(), *orgID)
|
||
}
|
||
|
||
response.JSON(w, http.StatusOK, map[string]any{
|
||
"user": usr,
|
||
"access_token": tokenPair.AccessToken,
|
||
"expires_at": tokenPair.ExpiresAt,
|
||
})
|
||
}
|
||
|
||
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "access_token",
|
||
Value: "",
|
||
Path: "/",
|
||
HttpOnly: true,
|
||
MaxAge: -1,
|
||
})
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "refresh_token",
|
||
Value: "",
|
||
Path: "/api/v1/auth/refresh",
|
||
HttpOnly: true,
|
||
MaxAge: -1,
|
||
})
|
||
response.JSON(w, http.StatusOK, nil)
|
||
}
|
||
|
||
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
|
||
userID := middleware.GetUserID(r.Context())
|
||
|
||
var u userResponse
|
||
err := h.pool.QueryRow(r.Context(),
|
||
`SELECT id, name, email, avatar_url, role, employee_id, org_id::text
|
||
FROM users WHERE id = $1`, userID,
|
||
).Scan(&u.ID, &u.Name, &u.Email, &u.AvatarURL, &u.Role, &u.EmployeeID, &u.OrgID)
|
||
|
||
if err != nil {
|
||
response.NotFound(w, "用户不存在")
|
||
return
|
||
}
|
||
|
||
if u.OrgID != nil {
|
||
u.Org = h.loadOrgInfo(r.Context(), *u.OrgID)
|
||
}
|
||
|
||
response.JSON(w, http.StatusOK, u)
|
||
}
|
||
|
||
// ListOrganizations 返回所有可用机构列表
|
||
func (h *AuthHandler) ListOrganizations(w http.ResponseWriter, r *http.Request) {
|
||
rows, err := h.pool.Query(r.Context(),
|
||
`SELECT id, name, slug, COALESCE(short_name,''), COALESCE(description,''), COALESCE(logo_url,'')
|
||
FROM organizations WHERE is_active = true ORDER BY sort_order ASC`)
|
||
if err != nil {
|
||
response.InternalError(w, "查询机构列表失败")
|
||
return
|
||
}
|
||
defer rows.Close()
|
||
|
||
var orgs []map[string]any
|
||
for rows.Next() {
|
||
var id, name, slug, shortName, desc, logo string
|
||
if rows.Scan(&id, &name, &slug, &shortName, &desc, &logo) != nil {
|
||
continue
|
||
}
|
||
orgs = append(orgs, map[string]any{
|
||
"id": id, "name": name, "slug": slug,
|
||
"short_name": shortName, "description": desc, "logo_url": logo,
|
||
})
|
||
}
|
||
if orgs == nil {
|
||
orgs = []map[string]any{}
|
||
}
|
||
response.JSON(w, http.StatusOK, orgs)
|
||
}
|
||
|
||
// SwitchOrg 切换机构:管理员及以上可用,自动切换为目标机构的管理员身份
|
||
func (h *AuthHandler) SwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||
currentRole := middleware.GetRole(r.Context())
|
||
if currentRole != "super_admin" && currentRole != "admin" {
|
||
response.Forbidden(w, "仅管理员可以切换机构")
|
||
return
|
||
}
|
||
|
||
var req struct {
|
||
OrgID string `json:"org_id"`
|
||
}
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.OrgID == "" {
|
||
response.BadRequest(w, "请选择机构")
|
||
return
|
||
}
|
||
|
||
// 校验机构是否存在
|
||
var exists bool
|
||
h.pool.QueryRow(r.Context(),
|
||
`SELECT EXISTS(SELECT 1 FROM organizations WHERE id = $1 AND is_active = true)`, req.OrgID).Scan(&exists)
|
||
if !exists {
|
||
response.NotFound(w, "机构不存在")
|
||
return
|
||
}
|
||
|
||
// 查找目标机构的管理员用户(优先admin,其次super_admin,最后creator)
|
||
var targetID uuid.UUID
|
||
var targetEmail, targetRole, targetName string
|
||
err := h.pool.QueryRow(r.Context(),
|
||
`SELECT id, email, role, name FROM users
|
||
WHERE org_id = $1 AND status = 'active'
|
||
ORDER BY CASE role WHEN 'admin' THEN 1 WHEN 'super_admin' THEN 2 WHEN 'creator' THEN 3 ELSE 4 END
|
||
LIMIT 1`, req.OrgID).Scan(&targetID, &targetEmail, &targetRole, &targetName)
|
||
if err != nil {
|
||
response.InternalError(w, "该机构暂无可用用户")
|
||
return
|
||
}
|
||
|
||
// 为目标用户生成新的JWT token
|
||
tokens, err := h.jwtMgr.GenerateTokenPair(targetID, targetEmail, targetRole)
|
||
if err != nil {
|
||
response.InternalError(w, "生成令牌失败")
|
||
return
|
||
}
|
||
|
||
org := h.loadOrgInfo(r.Context(), req.OrgID)
|
||
response.JSON(w, http.StatusOK, map[string]any{
|
||
"message": "已切换",
|
||
"org": org,
|
||
"token": tokens.AccessToken,
|
||
"user": map[string]any{
|
||
"id": targetID,
|
||
"name": targetName,
|
||
"email": targetEmail,
|
||
"role": targetRole,
|
||
"org_id": req.OrgID,
|
||
},
|
||
})
|
||
}
|
||
|
||
func (h *AuthHandler) loadOrgInfo(ctx context.Context, orgID string) *orgInfo {
|
||
var o orgInfo
|
||
err := h.pool.QueryRow(ctx,
|
||
`SELECT id, name, slug, COALESCE(short_name,'') FROM organizations WHERE id = $1`, orgID,
|
||
).Scan(&o.ID, &o.Name, &o.Slug, &o.ShortName)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
return &o
|
||
}
|
||
|
||
func (h *AuthHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) {
|
||
userID := middleware.GetUserID(r.Context())
|
||
|
||
var req struct {
|
||
Name string `json:"name"`
|
||
AvatarURL string `json:"avatar_url"`
|
||
}
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
response.BadRequest(w, "无效的请求格式")
|
||
return
|
||
}
|
||
|
||
_, err := h.pool.Exec(r.Context(),
|
||
`UPDATE users SET name = COALESCE(NULLIF($2,''), name),
|
||
avatar_url = COALESCE(NULLIF($3,''), avatar_url),
|
||
updated_at = NOW()
|
||
WHERE id = $1`, userID, req.Name, req.AvatarURL)
|
||
if err != nil {
|
||
response.InternalError(w, "更新个人信息失败")
|
||
return
|
||
}
|
||
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
|
||
}
|
||
|
||
func (h *AuthHandler) Refresh(w http.ResponseWriter, r *http.Request) {
|
||
cookie, err := r.Cookie("refresh_token")
|
||
if err != nil {
|
||
response.Unauthorized(w, "Refresh Token 不存在")
|
||
return
|
||
}
|
||
|
||
claims, err := h.jwtMgr.ValidateToken(cookie.Value)
|
||
if err != nil {
|
||
response.Error(w, http.StatusUnauthorized, 40102, "Refresh Token 已过期")
|
||
return
|
||
}
|
||
|
||
// Get current role from DB
|
||
var role string
|
||
err = h.pool.QueryRow(r.Context(),
|
||
`SELECT role FROM users WHERE id = $1 AND status = 'active'`, claims.UserID,
|
||
).Scan(&role)
|
||
if err != nil {
|
||
response.Unauthorized(w, "用户不存在或已被禁用")
|
||
return
|
||
}
|
||
|
||
tokenPair, err := h.jwtMgr.GenerateTokenPair(claims.UserID, claims.Email, role)
|
||
if err != nil {
|
||
response.InternalError(w, "生成Token失败")
|
||
return
|
||
}
|
||
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "access_token",
|
||
Value: tokenPair.AccessToken,
|
||
Path: "/",
|
||
HttpOnly: true,
|
||
SameSite: http.SameSiteLaxMode,
|
||
MaxAge: int(24 * time.Hour / time.Second),
|
||
})
|
||
|
||
response.JSON(w, http.StatusOK, map[string]any{
|
||
"access_token": tokenPair.AccessToken,
|
||
"expires_at": tokenPair.ExpiresAt,
|
||
})
|
||
}
|