package handler

import (
	"context"
	"encoding/json"
	"net/http"
	"time"

	"github.com/enterprise-ai-platform/server/internal/middleware"
	"github.com/enterprise-ai-platform/server/internal/response"
	"github.com/enterprise-ai-platform/server/pkg/auth"
	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"
)

// AuthHandler serves the authentication endpoints: registration, login,
// logout, token refresh, the current user's profile, the organization list
// and organization switching.
type AuthHandler struct {
	pool   *pgxpool.Pool
	jwtMgr *auth.JWTManager
}

// NewAuthHandler builds an AuthHandler backed by the given connection pool and
// JWT manager.
func NewAuthHandler(pool *pgxpool.Pool, jwtMgr *auth.JWTManager) *AuthHandler {
	return &AuthHandler{pool: pool, jwtMgr: jwtMgr}
}

// loginRequest is the JSON body accepted by Login. OrgID is optional and names
// the organization entry point the user is logging in through.
type loginRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
	OrgID    string `json:"org_id"`
}

// orgInfo is the compact organization summary embedded in user payloads.
type orgInfo struct {
	ID        string `json:"id"`
	Name      string `json:"name"`
	Slug      string `json:"slug"`
	ShortName string `json:"short_name"`
}

// userResponse is the user payload returned by Register, Login and Me.
// Nullable columns are pointers so they serialize as JSON null when absent.
type userResponse struct {
	ID         string   `json:"id"`
	Name       string   `json:"name"`
	Email      string   `json:"email"`
	AvatarURL  *string  `json:"avatar_url"`
	Role       string   `json:"role"`
	EmployeeID *string  `json:"employee_id"`
	OrgID      *string  `json:"org_id"`
	Org        *orgInfo `json:"org,omitempty"`
}

// registerRequest is the JSON body accepted by Register.
type registerRequest struct {
	Name     string `json:"name"`
	Email    string `json:"email"`
	Password string `json:"password"`
}

// setAccessTokenCookie stores the access token in an HttpOnly, SameSite=Lax
// cookie scoped to the whole site with a 24-hour lifetime.
//
// NOTE(review): the cookie is not marked Secure; consider enabling it when the
// service is only reachable over HTTPS.
func (h *AuthHandler) setAccessTokenCookie(w http.ResponseWriter, token string) {
	http.SetCookie(w, &http.Cookie{
		Name:     "access_token",
		Value:    token,
		Path:     "/",
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(24 * time.Hour / time.Second),
	})
}

// setRefreshTokenCookie stores the refresh token in an HttpOnly, SameSite=Lax
// cookie that the browser only sends to the refresh endpoint, with a 7-day
// lifetime.
func (h *AuthHandler) setRefreshTokenCookie(w http.ResponseWriter, token string) {
	http.SetCookie(w, &http.Cookie{
		Name:     "refresh_token",
		Value:    token,
		Path:     "/api/v1/auth/refresh",
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(7 * 24 * time.Hour / time.Second),
	})
}

// Register creates a new active account with the "user" role and signs it in.
// It sets the access and refresh token cookies and responds 201 with the user
// and access token.
//
// Error responses:
//   - 400: malformed JSON, missing fields, or a password shorter than 6 bytes.
//   - 409: the email is already registered.
//   - 500: storage, hashing or token failures.
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
	var req registerRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}
	if req.Name == "" || req.Email == "" || req.Password == "" {
		response.BadRequest(w, "姓名、邮箱和密码不能为空")
		return
	}
	// len counts bytes, not characters.
	if len(req.Password) < 6 {
		response.BadRequest(w, "密码长度不能少于6位")
		return
	}

	var exists bool
	if err := h.pool.QueryRow(r.Context(),
		`SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)`, req.Email,
	).Scan(&exists); err != nil {
		// A failed lookup must not be mistaken for "email is free".
		response.InternalError(w, "注册失败")
		return
	}
	if exists {
		response.Error(w, http.StatusConflict, 40901, "该邮箱已注册")
		return
	}

	hash, err := auth.HashPassword(req.Password)
	if err != nil {
		response.InternalError(w, "密码加密失败")
		return
	}

	id := uuid.New()
	_, err = h.pool.Exec(r.Context(),
		`INSERT INTO users (id, name, email, password_hash, role, status) VALUES ($1, $2, $3, $4, 'user', 'active')`,
		id, req.Name, req.Email, hash)
	if err != nil {
		response.InternalError(w, "注册失败")
		return
	}

	tokenPair, err := h.jwtMgr.GenerateTokenPair(id, req.Email, "user")
	if err != nil {
		response.InternalError(w, "生成Token失败")
		return
	}

	// Set both cookies so a freshly registered session can be refreshed,
	// consistent with Login.
	h.setAccessTokenCookie(w, tokenPair.AccessToken)
	h.setRefreshTokenCookie(w, tokenPair.RefreshToken)

	response.JSON(w, http.StatusCreated, map[string]any{
		"user": userResponse{
			ID:    id.String(),
			Name:  req.Name,
			Email: req.Email,
			Role:  "user",
		},
		"access_token": tokenPair.AccessToken,
		"expires_at":   tokenPair.ExpiresAt,
	})
}

// Login authenticates by email and password.
//
// On success it:
//   - checks membership of the optionally selected organization;
//   - issues a token pair and sets the access and refresh cookies;
//   - responds 200 with the user (including org summary) and access token.
//
// Error responses:
//   - 400: malformed or incomplete input.
//   - 401: bad credentials, or the account belongs to a different org.
//   - 403: the account is disabled.
//   - 500: token generation fails.
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
	var req loginRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}
	if req.Email == "" || req.Password == "" {
		response.BadRequest(w, "邮箱和密码不能为空")
		return
	}

	var (
		id           string
		name         string
		email        string
		passwordHash *string
		avatarURL    *string
		role         string
		employeeID   *string
		status       string
		orgID        *string
	)
	err := h.pool.QueryRow(r.Context(),
		`SELECT id, name, email, password_hash, avatar_url, role, employee_id, status, org_id::text FROM users WHERE email = $1`,
		req.Email,
	).Scan(&id, &name, &email, &passwordHash, &avatarURL, &role, &employeeID, &status, &orgID)
	if err != nil {
		// Unknown email and lookup failures both map to the generic credentials
		// error so the response does not reveal which accounts exist.
		response.Unauthorized(w, "邮箱或密码错误")
		return
	}

	// Verify the password BEFORE revealing the account status; otherwise the
	// "disabled" response would confirm an account exists to anyone who knows
	// only the email. Accounts without a password hash cannot log in here.
	if passwordHash == nil || !auth.CheckPassword(req.Password, *passwordHash) {
		response.Unauthorized(w, "邮箱或密码错误")
		return
	}
	if status != "active" {
		response.Error(w, http.StatusForbidden, 40302, "账号已被禁用")
		return
	}

	// Platform admins (super_admin) are not bound to an organization and may
	// log in through any organization entry point. Regular users and org
	// admins must belong to the selected organization.
	// NOTE(review): users whose org_id is NULL (e.g. accounts created by
	// Register) skip this check; confirm that is intended.
	if role != "super_admin" && req.OrgID != "" && orgID != nil && *orgID != req.OrgID {
		response.Unauthorized(w, "该账号不属于所选机构,请选择正确的机构")
		return
	}

	uid, err := uuid.Parse(id)
	if err != nil {
		response.InternalError(w, "生成Token失败")
		return
	}
	tokenPair, err := h.jwtMgr.GenerateTokenPair(uid, email, role)
	if err != nil {
		response.InternalError(w, "生成Token失败")
		return
	}

	// Record login statistics in the background so the response is not
	// delayed. The request context ends when the handler returns, so use a
	// detached context, but bound it so the goroutine cannot hang forever.
	// Failures are deliberately ignored (best effort).
	go func(userID string) {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_, _ = h.pool.Exec(ctx,
			`UPDATE users SET last_login_at = NOW(), login_count = login_count + 1 WHERE id = $1`,
			userID)
	}(id)

	h.setAccessTokenCookie(w, tokenPair.AccessToken)
	h.setRefreshTokenCookie(w, tokenPair.RefreshToken)

	usr := userResponse{
		ID:         id,
		Name:       name,
		Email:      email,
		AvatarURL:  avatarURL,
		Role:       role,
		EmployeeID: employeeID,
		OrgID:      orgID,
	}
	if orgID != nil {
		usr.Org = h.loadOrgInfo(r.Context(), *orgID)
	}

	response.JSON(w, http.StatusOK, map[string]any{
		"user":         usr,
		"access_token": tokenPair.AccessToken,
		"expires_at":   tokenPair.ExpiresAt,
	})
}

// Logout clears the access and refresh token cookies. Each cookie is expired
// with the same Path it was set with, so the browser actually removes it.
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
	http.SetCookie(w, &http.Cookie{
		Name:     "access_token",
		Value:    "",
		Path:     "/",
		HttpOnly: true,
		MaxAge:   -1,
	})
	http.SetCookie(w, &http.Cookie{
		Name:     "refresh_token",
		Value:    "",
		Path:     "/api/v1/auth/refresh",
		HttpOnly: true,
		MaxAge:   -1,
	})
	response.JSON(w, http.StatusOK, nil)
}

// Me returns the profile of the authenticated user. The user ID is read from
// the request context via the middleware package. Responds 404 when the
// lookup fails.
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
	userID := middleware.GetUserID(r.Context())

	var u userResponse
	err := h.pool.QueryRow(r.Context(),
		`SELECT id, name, email, avatar_url, role, employee_id, org_id::text FROM users WHERE id = $1`,
		userID,
	).Scan(&u.ID, &u.Name, &u.Email, &u.AvatarURL, &u.Role, &u.EmployeeID, &u.OrgID)
	if err != nil {
		response.NotFound(w, "用户不存在")
		return
	}
	if u.OrgID != nil {
		u.Org = h.loadOrgInfo(r.Context(), *u.OrgID)
	}
	response.JSON(w, http.StatusOK, u)
}

// ListOrganizations returns all active organizations ordered by sort_order.
// Nullable text columns are returned as empty strings, and an empty result is
// encoded as [] rather than null.
func (h *AuthHandler) ListOrganizations(w http.ResponseWriter, r *http.Request) {
	rows, err := h.pool.Query(r.Context(),
		`SELECT id, name, slug, COALESCE(short_name,''), COALESCE(description,''), COALESCE(logo_url,'') FROM organizations WHERE is_active = true ORDER BY sort_order ASC`)
	if err != nil {
		response.InternalError(w, "查询机构列表失败")
		return
	}
	defer rows.Close()

	orgs := []map[string]any{}
	for rows.Next() {
		var id, name, slug, shortName, desc, logo string
		if rows.Scan(&id, &name, &slug, &shortName, &desc, &logo) != nil {
			// Best effort: skip rows that fail to scan instead of failing the
			// whole list.
			continue
		}
		orgs = append(orgs, map[string]any{
			"id":          id,
			"name":        name,
			"slug":        slug,
			"short_name":  shortName,
			"description": desc,
			"logo_url":    logo,
		})
	}
	// Iteration can stop early on a connection or query error; without this
	// check a truncated list would be reported as success.
	if err := rows.Err(); err != nil {
		response.InternalError(w, "查询机构列表失败")
		return
	}

	response.JSON(w, http.StatusOK, orgs)
}

// SwitchOrg lets an admin or super_admin switch into another organization by
// issuing a token for that organization's highest-ranked active user.
// Candidates are ranked admin, then super_admin, then creator, then others.
//
// A caller with role "admin" is never handed a super_admin identity; that
// would be a privilege escalation. Only a super_admin caller may land on a
// super_admin account.
//
// NOTE(review): this is impersonation by design: the returned token belongs
// to a different user, and it is only returned in the body, not as a cookie.
// Confirm both are intended.
func (h *AuthHandler) SwitchOrg(w http.ResponseWriter, r *http.Request) {
	currentRole := middleware.GetRole(r.Context())
	if currentRole != "super_admin" && currentRole != "admin" {
		response.Forbidden(w, "仅管理员可以切换机构")
		return
	}

	var req struct {
		OrgID string `json:"org_id"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.OrgID == "" {
		response.BadRequest(w, "请选择机构")
		return
	}

	// Make sure the target organization exists and is active.
	var exists bool
	if err := h.pool.QueryRow(r.Context(),
		`SELECT EXISTS(SELECT 1 FROM organizations WHERE id = $1 AND is_active = true)`,
		req.OrgID,
	).Scan(&exists); err != nil {
		response.InternalError(w, "查询机构失败")
		return
	}
	if !exists {
		response.NotFound(w, "机构不存在")
		return
	}

	// Pick the target organization's user to act as, ranked admin >
	// super_admin > creator > others. $2 gates super_admin candidates on the
	// caller already being a super_admin.
	allowSuperAdmin := currentRole == "super_admin"
	var targetID uuid.UUID
	var targetEmail, targetRole, targetName string
	err := h.pool.QueryRow(r.Context(),
		`SELECT id, email, role, name FROM users
		 WHERE org_id = $1 AND status = 'active' AND ($2 OR role <> 'super_admin')
		 ORDER BY CASE role WHEN 'admin' THEN 1 WHEN 'super_admin' THEN 2 WHEN 'creator' THEN 3 ELSE 4 END
		 LIMIT 1`,
		req.OrgID, allowSuperAdmin,
	).Scan(&targetID, &targetEmail, &targetRole, &targetName)
	if err != nil {
		response.InternalError(w, "该机构暂无可用用户")
		return
	}

	// Issue a fresh token pair for the target user.
	tokens, err := h.jwtMgr.GenerateTokenPair(targetID, targetEmail, targetRole)
	if err != nil {
		response.InternalError(w, "生成令牌失败")
		return
	}

	org := h.loadOrgInfo(r.Context(), req.OrgID)
	response.JSON(w, http.StatusOK, map[string]any{
		"message": "已切换",
		"org":     org,
		"token":   tokens.AccessToken,
		"user": map[string]any{
			"id":     targetID,
			"name":   targetName,
			"email":  targetEmail,
			"role":   targetRole,
			"org_id": req.OrgID,
		},
	})
}

// loadOrgInfo fetches the summary of an organization by ID. It returns nil on
// any error, including a missing row, so callers can omit the org field.
func (h *AuthHandler) loadOrgInfo(ctx context.Context, orgID string) *orgInfo {
	var o orgInfo
	err := h.pool.QueryRow(ctx,
		`SELECT id, name, slug, COALESCE(short_name,'') FROM organizations WHERE id = $1`,
		orgID,
	).Scan(&o.ID, &o.Name, &o.Slug, &o.ShortName)
	if err != nil {
		return nil
	}
	return &o
}

// UpdateProfile updates the authenticated user's name and/or avatar URL.
// Empty fields in the request leave the stored value unchanged, via
// NULLIF + COALESCE in the UPDATE statement.
func (h *AuthHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) {
	userID := middleware.GetUserID(r.Context())

	var req struct {
		Name      string `json:"name"`
		AvatarURL string `json:"avatar_url"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		response.BadRequest(w, "无效的请求格式")
		return
	}

	_, err := h.pool.Exec(r.Context(),
		`UPDATE users SET name = COALESCE(NULLIF($2,''), name), avatar_url = COALESCE(NULLIF($3,''), avatar_url), updated_at = NOW() WHERE id = $1`,
		userID, req.Name, req.AvatarURL)
	if err != nil {
		response.InternalError(w, "更新个人信息失败")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}

// Refresh exchanges the refresh_token cookie for a new access token.
// It sets the new access cookie and responds with the token and its expiry.
//
// Role and email are re-read from the database so that changes, and disabled
// accounts, take effect without waiting for the old token to expire.
//
// NOTE(review): ValidateToken is used for the refresh token too; confirm that
// JWTManager distinguishes access tokens from refresh tokens. Otherwise an
// access token presented as refresh_token would be accepted.
func (h *AuthHandler) Refresh(w http.ResponseWriter, r *http.Request) {
	cookie, err := r.Cookie("refresh_token")
	if err != nil {
		response.Unauthorized(w, "Refresh Token 不存在")
		return
	}

	claims, err := h.jwtMgr.ValidateToken(cookie.Value)
	if err != nil {
		response.Error(w, http.StatusUnauthorized, 40102, "Refresh Token 已过期")
		return
	}

	// Read current role and email from the DB rather than trusting the
	// possibly stale claims.
	var role, email string
	err = h.pool.QueryRow(r.Context(),
		`SELECT role, email FROM users WHERE id = $1 AND status = 'active'`,
		claims.UserID,
	).Scan(&role, &email)
	if err != nil {
		response.Unauthorized(w, "用户不存在或已被禁用")
		return
	}

	tokenPair, err := h.jwtMgr.GenerateTokenPair(claims.UserID, email, role)
	if err != nil {
		response.InternalError(w, "生成Token失败")
		return
	}

	h.setAccessTokenCookie(w, tokenPair.AccessToken)

	response.JSON(w, http.StatusOK, map[string]any{
		"access_token": tokenPair.AccessToken,
		"expires_at":   tokenPair.ExpiresAt,
	})
}