Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,418 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/enterprise-ai-platform/server/pkg/auth"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type AuthHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
jwtMgr *auth.JWTManager
|
||||
}
|
||||
|
||||
func NewAuthHandler(pool *pgxpool.Pool, jwtMgr *auth.JWTManager) *AuthHandler {
|
||||
return &AuthHandler{pool: pool, jwtMgr: jwtMgr}
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
OrgID string `json:"org_id"`
|
||||
}
|
||||
|
||||
type orgInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
ShortName string `json:"short_name"`
|
||||
}
|
||||
|
||||
type userResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarURL *string `json:"avatar_url"`
|
||||
Role string `json:"role"`
|
||||
EmployeeID *string `json:"employee_id"`
|
||||
OrgID *string `json:"org_id"`
|
||||
Org *orgInfo `json:"org,omitempty"`
|
||||
}
|
||||
|
||||
type registerRequest struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||
var req registerRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
if req.Name == "" || req.Email == "" || req.Password == "" {
|
||||
response.BadRequest(w, "姓名、邮箱和密码不能为空")
|
||||
return
|
||||
}
|
||||
if len(req.Password) < 6 {
|
||||
response.BadRequest(w, "密码长度不能少于6位")
|
||||
return
|
||||
}
|
||||
|
||||
var exists bool
|
||||
h.pool.QueryRow(r.Context(), `SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)`, req.Email).Scan(&exists)
|
||||
if exists {
|
||||
response.Error(w, http.StatusConflict, 40901, "该邮箱已注册")
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
response.InternalError(w, "密码加密失败")
|
||||
return
|
||||
}
|
||||
|
||||
id := uuid.New()
|
||||
_, err = h.pool.Exec(r.Context(),
|
||||
`INSERT INTO users (id, name, email, password_hash, role, status) VALUES ($1, $2, $3, $4, 'user', 'active')`,
|
||||
id, req.Name, req.Email, hash)
|
||||
if err != nil {
|
||||
response.InternalError(w, "注册失败")
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.jwtMgr.GenerateTokenPair(id, req.Email, "user")
|
||||
if err != nil {
|
||||
response.InternalError(w, "生成Token失败")
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "access_token",
|
||||
Value: tokenPair.AccessToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(24 * time.Hour / time.Second),
|
||||
})
|
||||
|
||||
response.JSON(w, http.StatusCreated, map[string]any{
|
||||
"user": userResponse{
|
||||
ID: id.String(), Name: req.Name, Email: req.Email, Role: "user",
|
||||
},
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"expires_at": tokenPair.ExpiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
var req loginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Email == "" || req.Password == "" {
|
||||
response.BadRequest(w, "邮箱和密码不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
id string
|
||||
name string
|
||||
email string
|
||||
passwordHash *string
|
||||
avatarURL *string
|
||||
role string
|
||||
employeeID *string
|
||||
status string
|
||||
orgID *string
|
||||
)
|
||||
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT id, name, email, password_hash, avatar_url, role, employee_id, status, org_id::text
|
||||
FROM users WHERE email = $1`, req.Email,
|
||||
).Scan(&id, &name, &email, &passwordHash, &avatarURL, &role, &employeeID, &status, &orgID)
|
||||
|
||||
if err != nil {
|
||||
response.Unauthorized(w, "邮箱或密码错误")
|
||||
return
|
||||
}
|
||||
|
||||
if status != "active" {
|
||||
response.Error(w, http.StatusForbidden, 40302, "账号已被禁用")
|
||||
return
|
||||
}
|
||||
|
||||
if passwordHash == nil || !auth.CheckPassword(req.Password, *passwordHash) {
|
||||
response.Unauthorized(w, "邮箱或密码错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 平台管理员不绑定机构,可登录任意机构入口
|
||||
// 普通用户/机构管理员必须属于所选机构
|
||||
if role != "super_admin" && req.OrgID != "" && orgID != nil && *orgID != req.OrgID {
|
||||
response.Unauthorized(w, "该账号不属于所选机构,请选择正确的机构")
|
||||
return
|
||||
}
|
||||
|
||||
uid, _ := uuid.Parse(id)
|
||||
tokenPair, err := h.jwtMgr.GenerateTokenPair(uid, email, role)
|
||||
if err != nil {
|
||||
response.InternalError(w, "生成Token失败")
|
||||
return
|
||||
}
|
||||
|
||||
// Update login info
|
||||
go func() {
|
||||
_, _ = h.pool.Exec(context.Background(),
|
||||
`UPDATE users SET last_login_at = NOW(), login_count = login_count + 1 WHERE id = $1`, id)
|
||||
}()
|
||||
|
||||
// Set cookies
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "access_token",
|
||||
Value: tokenPair.AccessToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(24 * time.Hour / time.Second),
|
||||
})
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: tokenPair.RefreshToken,
|
||||
Path: "/api/v1/auth/refresh",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(7 * 24 * time.Hour / time.Second),
|
||||
})
|
||||
|
||||
usr := userResponse{
|
||||
ID: id, Name: name, Email: email,
|
||||
AvatarURL: avatarURL, Role: role, EmployeeID: employeeID,
|
||||
OrgID: orgID,
|
||||
}
|
||||
if orgID != nil {
|
||||
usr.Org = h.loadOrgInfo(r.Context(), *orgID)
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"user": usr,
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"expires_at": tokenPair.ExpiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "access_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
MaxAge: -1,
|
||||
})
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: "",
|
||||
Path: "/api/v1/auth/refresh",
|
||||
HttpOnly: true,
|
||||
MaxAge: -1,
|
||||
})
|
||||
response.JSON(w, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var u userResponse
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT id, name, email, avatar_url, role, employee_id, org_id::text
|
||||
FROM users WHERE id = $1`, userID,
|
||||
).Scan(&u.ID, &u.Name, &u.Email, &u.AvatarURL, &u.Role, &u.EmployeeID, &u.OrgID)
|
||||
|
||||
if err != nil {
|
||||
response.NotFound(w, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if u.OrgID != nil {
|
||||
u.Org = h.loadOrgInfo(r.Context(), *u.OrgID)
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, u)
|
||||
}
|
||||
|
||||
// ListOrganizations 返回所有可用机构列表
|
||||
func (h *AuthHandler) ListOrganizations(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := h.pool.Query(r.Context(),
|
||||
`SELECT id, name, slug, COALESCE(short_name,''), COALESCE(description,''), COALESCE(logo_url,'')
|
||||
FROM organizations WHERE is_active = true ORDER BY sort_order ASC`)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询机构列表失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var orgs []map[string]any
|
||||
for rows.Next() {
|
||||
var id, name, slug, shortName, desc, logo string
|
||||
if rows.Scan(&id, &name, &slug, &shortName, &desc, &logo) != nil {
|
||||
continue
|
||||
}
|
||||
orgs = append(orgs, map[string]any{
|
||||
"id": id, "name": name, "slug": slug,
|
||||
"short_name": shortName, "description": desc, "logo_url": logo,
|
||||
})
|
||||
}
|
||||
if orgs == nil {
|
||||
orgs = []map[string]any{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, orgs)
|
||||
}
|
||||
|
||||
// SwitchOrg 切换机构:管理员及以上可用,自动切换为目标机构的管理员身份
|
||||
func (h *AuthHandler) SwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||
currentRole := middleware.GetRole(r.Context())
|
||||
if currentRole != "super_admin" && currentRole != "admin" {
|
||||
response.Forbidden(w, "仅管理员可以切换机构")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
OrgID string `json:"org_id"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.OrgID == "" {
|
||||
response.BadRequest(w, "请选择机构")
|
||||
return
|
||||
}
|
||||
|
||||
// 校验机构是否存在
|
||||
var exists bool
|
||||
h.pool.QueryRow(r.Context(),
|
||||
`SELECT EXISTS(SELECT 1 FROM organizations WHERE id = $1 AND is_active = true)`, req.OrgID).Scan(&exists)
|
||||
if !exists {
|
||||
response.NotFound(w, "机构不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 查找目标机构的管理员用户(优先admin,其次super_admin,最后creator)
|
||||
var targetID uuid.UUID
|
||||
var targetEmail, targetRole, targetName string
|
||||
err := h.pool.QueryRow(r.Context(),
|
||||
`SELECT id, email, role, name FROM users
|
||||
WHERE org_id = $1 AND status = 'active'
|
||||
ORDER BY CASE role WHEN 'admin' THEN 1 WHEN 'super_admin' THEN 2 WHEN 'creator' THEN 3 ELSE 4 END
|
||||
LIMIT 1`, req.OrgID).Scan(&targetID, &targetEmail, &targetRole, &targetName)
|
||||
if err != nil {
|
||||
response.InternalError(w, "该机构暂无可用用户")
|
||||
return
|
||||
}
|
||||
|
||||
// 为目标用户生成新的JWT token
|
||||
tokens, err := h.jwtMgr.GenerateTokenPair(targetID, targetEmail, targetRole)
|
||||
if err != nil {
|
||||
response.InternalError(w, "生成令牌失败")
|
||||
return
|
||||
}
|
||||
|
||||
org := h.loadOrgInfo(r.Context(), req.OrgID)
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"message": "已切换",
|
||||
"org": org,
|
||||
"token": tokens.AccessToken,
|
||||
"user": map[string]any{
|
||||
"id": targetID,
|
||||
"name": targetName,
|
||||
"email": targetEmail,
|
||||
"role": targetRole,
|
||||
"org_id": req.OrgID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) loadOrgInfo(ctx context.Context, orgID string) *orgInfo {
|
||||
var o orgInfo
|
||||
err := h.pool.QueryRow(ctx,
|
||||
`SELECT id, name, slug, COALESCE(short_name,'') FROM organizations WHERE id = $1`, orgID,
|
||||
).Scan(&o.ID, &o.Name, &o.Slug, &o.ShortName)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &o
|
||||
}
|
||||
|
||||
func (h *AuthHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) {
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效的请求格式")
|
||||
return
|
||||
}
|
||||
|
||||
_, err := h.pool.Exec(r.Context(),
|
||||
`UPDATE users SET name = COALESCE(NULLIF($2,''), name),
|
||||
avatar_url = COALESCE(NULLIF($3,''), avatar_url),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1`, userID, req.Name, req.AvatarURL)
|
||||
if err != nil {
|
||||
response.InternalError(w, "更新个人信息失败")
|
||||
return
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Refresh(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("refresh_token")
|
||||
if err != nil {
|
||||
response.Unauthorized(w, "Refresh Token 不存在")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := h.jwtMgr.ValidateToken(cookie.Value)
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusUnauthorized, 40102, "Refresh Token 已过期")
|
||||
return
|
||||
}
|
||||
|
||||
// Get current role from DB
|
||||
var role string
|
||||
err = h.pool.QueryRow(r.Context(),
|
||||
`SELECT role FROM users WHERE id = $1 AND status = 'active'`, claims.UserID,
|
||||
).Scan(&role)
|
||||
if err != nil {
|
||||
response.Unauthorized(w, "用户不存在或已被禁用")
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.jwtMgr.GenerateTokenPair(claims.UserID, claims.Email, role)
|
||||
if err != nil {
|
||||
response.InternalError(w, "生成Token失败")
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "access_token",
|
||||
Value: tokenPair.AccessToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(24 * time.Hour / time.Second),
|
||||
})
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"expires_at": tokenPair.ExpiresAt,
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user