Files
GovAI/server/internal/handler/knowledge.go
T
2026-06-15 23:48:37 +08:00

523 lines
14 KiB
Go

package handler
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
	"unicode/utf8"

	mw "github.com/enterprise-ai-platform/server/internal/middleware"
	"github.com/enterprise-ai-platform/server/internal/response"
	"github.com/enterprise-ai-platform/server/pkg/chunker"
	"github.com/enterprise-ai-platform/server/pkg/embedding"
	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/rs/zerolog/log"
)
// internalErr hands the text of err to response.InternalError so the request
// is answered with an internal-error response. err must be non-nil.
func internalErr(w http.ResponseWriter, err error) {
	msg := err.Error()
	response.InternalError(w, msg)
}
// KnowledgeHandler groups the HTTP handlers that manage knowledge bases,
// their documents and the text chunks derived from those documents.
type KnowledgeHandler struct {
// pool is the PostgreSQL connection pool all handlers run their queries on.
pool *pgxpool.Pool
// embedder produces chunk embeddings. The handlers treat a nil or
// unconfigured embedder as "embedding unavailable".
embedder *embedding.Client
}
// NewKnowledgeHandler builds a KnowledgeHandler that uses pool for database
// access and embedder for computing chunk embeddings.
func NewKnowledgeHandler(pool *pgxpool.Pool, embedder *embedding.Client) *KnowledgeHandler {
	h := new(KnowledgeHandler)
	h.pool = pool
	h.embedder = embedder
	return h
}
// knowledgeBaseRow is the JSON shape of one knowledge base as returned by
// ListKnowledgeBases. Each field is scanned from the knowledge_bases column
// named in its comment.
type knowledgeBaseRow struct {
// id column.
ID string `json:"id"`
// name column.
Name string `json:"name"`
// description column; the listing query maps NULL to "".
Description string `json:"description"`
// visibility column.
Visibility string `json:"visibility"`
// doc_count column, serialized as "document_count".
DocCount int `json:"document_count"`
// total_chars column.
TotalChars int64 `json:"total_chars"`
// status column.
Status string `json:"status"`
// created_at column.
CreatedAt time.Time `json:"created_at"`
// updated_at column; the listing is ordered by it, newest first.
UpdatedAt time.Time `json:"updated_at"`
}
// ListKnowledgeBases responds with the knowledge bases visible to the caller,
// most recently updated first.
//
// A caller whose role is "super_admin" receives every knowledge base. Any
// other caller receives the ones they own plus, when an organisation filter is
// known, the ones belonging to that organisation. The filter is the "org_id"
// query parameter when present, otherwise the caller's org_id from the users
// table. The response body is always a JSON array (empty when nothing
// matches); any query, scan or iteration error yields an internal-error
// response instead.
func (h *KnowledgeHandler) ListKnowledgeBases(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	userID := mw.GetUserID(ctx)
	userRole := mw.GetRole(ctx)

	var query string
	var args []any
	if userRole == "super_admin" {
		// Super administrators see every knowledge base.
		query = `SELECT id, name, COALESCE(description,''), visibility, doc_count, total_chars, status, created_at, updated_at
FROM knowledge_bases ORDER BY updated_at DESC`
		args = []any{}
	} else {
		// Prefer the org_id sent by the front end (after switching
		// organisation); otherwise fall back to the one stored on the user.
		// NOTE(review): the org_id query parameter is not checked against the
		// caller's own organisation(s) here — confirm that this is enforced
		// elsewhere, otherwise any user can list another organisation's
		// knowledge bases.
		orgFilter := r.URL.Query().Get("org_id")
		if orgFilter == "" {
			var userOrg *string
			err := h.pool.QueryRow(ctx, `SELECT org_id::text FROM users WHERE id = $1`, userID).Scan(&userOrg)
			if err != nil {
				// Best effort: without an organisation the listing falls back
				// to owner-only, but the failure should not go unnoticed.
				log.Warn().Err(err).Msg("knowledge base list: user organisation lookup failed")
			}
			if userOrg != nil {
				orgFilter = *userOrg
			}
		}
		query = `SELECT id, name, COALESCE(description,''), visibility, doc_count, total_chars, status, created_at, updated_at
FROM knowledge_bases WHERE (owner_id = $1`
		args = []any{userID}
		if orgFilter != "" {
			query += ` OR org_id = $2`
			args = append(args, orgFilter)
		}
		query += `) ORDER BY updated_at DESC`
	}

	rows, err := h.pool.Query(ctx, query, args...)
	if err != nil {
		internalErr(w, err)
		return
	}
	defer rows.Close()

	// Start from an empty (non-nil) slice so the JSON body is [] and not null.
	items := []knowledgeBaseRow{}
	for rows.Next() {
		var kb knowledgeBaseRow
		if err := rows.Scan(&kb.ID, &kb.Name, &kb.Description, &kb.Visibility, &kb.DocCount, &kb.TotalChars, &kb.Status, &kb.CreatedAt, &kb.UpdatedAt); err != nil {
			internalErr(w, err)
			return
		}
		items = append(items, kb)
	}
	// Next returning false can also mean the iteration failed; without this
	// check a truncated list would be returned as a success.
	if err := rows.Err(); err != nil {
		internalErr(w, err)
		return
	}
	response.JSON(w, http.StatusOK, items)
}
// CreateKnowledgeBase creates a knowledge base owned by the caller from a JSON
// body with the fields "name" (required), "description" and "visibility"
// (defaults to "private" when empty).
//
// The new row's org_id is copied from the caller's users row. On success the
// handler responds with 201 and the new id, name, description and visibility;
// an undecodable body or an empty name yields a bad-request response, and an
// insert failure yields an internal-error response.
func (h *KnowledgeHandler) CreateKnowledgeBase(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	userID := mw.GetUserID(ctx)

	var body struct {
		Name        string `json:"name"`
		Description string `json:"description"`
		Visibility  string `json:"visibility"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		response.BadRequest(w, "无效的请求体")
		return
	}
	if body.Name == "" {
		response.BadRequest(w, "名称不能为空")
		return
	}
	if body.Visibility == "" {
		body.Visibility = "private"
	}

	// Look up the organisation the user belongs to. This stays best effort
	// (the knowledge base is created with a NULL org_id when it fails), but
	// the failure is logged instead of being silently discarded.
	var userOrgID *string
	if err := h.pool.QueryRow(ctx, `SELECT org_id::text FROM users WHERE id = $1`, userID).Scan(&userOrgID); err != nil {
		log.Warn().Err(err).Msg("knowledge base create: user organisation lookup failed")
	}

	id := uuid.New()
	_, err := h.pool.Exec(ctx,
		`INSERT INTO knowledge_bases (id, name, description, owner_id, visibility, org_id)
VALUES ($1, $2, $3, $4, $5, $6)`,
		id, body.Name, body.Description, userID, body.Visibility, userOrgID)
	if err != nil {
		internalErr(w, err)
		return
	}
	response.JSON(w, http.StatusCreated, map[string]any{
		"id":          id.String(),
		"name":        body.Name,
		"description": body.Description,
		"visibility":  body.Visibility,
	})
}
// UpdateKnowledgeBase updates the name and/or description of the knowledge
// base identified by the "id" URL parameter, provided the caller owns it.
//
// Both fields are optional in the JSON body: an absent or empty "name" keeps
// the stored name, and an absent (or null) "description" keeps the stored
// description. Sending "description": "" still clears it. updated_at is
// refreshed on every successful update. The handler responds with a
// bad-request response for a malformed id or body, a not-found response when
// no row with that id is owned by the caller, and an internal-error response
// when the statement fails.
func (h *KnowledgeHandler) UpdateKnowledgeBase(w http.ResponseWriter, r *http.Request) {
	id, err := uuid.Parse(chi.URLParam(r, "id"))
	if err != nil {
		response.BadRequest(w, "无效的ID")
		return
	}
	userID := mw.GetUserID(r.Context())

	var body struct {
		Name string `json:"name"`
		// Description is a pointer so that an omitted field (nil) can be told
		// apart from an explicit empty string; previously a name-only update
		// silently erased the stored description.
		Description *string `json:"description"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		response.BadRequest(w, "无效的请求体")
		return
	}

	// A nil Description is sent as NULL, which COALESCE turns into "keep the
	// current value".
	tag, err := h.pool.Exec(r.Context(),
		`UPDATE knowledge_bases SET name = COALESCE(NULLIF($1,''), name),
description = COALESCE($2, description), updated_at = NOW()
WHERE id = $3 AND owner_id = $4`,
		body.Name, body.Description, id, userID)
	if err != nil {
		internalErr(w, err)
		return
	}
	if tag.RowsAffected() == 0 {
		response.NotFound(w, "知识库不存在")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}
// DeleteKnowledgeBase removes the knowledge base named by the "id" URL
// parameter when it is owned by the caller. A malformed id produces a
// bad-request response, a failed statement an internal-error response, and a
// delete that matches no row a not-found response.
func (h *KnowledgeHandler) DeleteKnowledgeBase(w http.ResponseWriter, r *http.Request) {
	rawID := chi.URLParam(r, "id")
	kbID, parseErr := uuid.Parse(rawID)
	if parseErr != nil {
		response.BadRequest(w, "无效的ID")
		return
	}

	ownerID := mw.GetUserID(r.Context())
	const stmt = `DELETE FROM knowledge_bases WHERE id = $1 AND owner_id = $2`
	result, execErr := h.pool.Exec(r.Context(), stmt, kbID, ownerID)

	switch {
	case execErr != nil:
		internalErr(w, execErr)
	case result.RowsAffected() == 0:
		// Either the id is unknown or the row belongs to someone else.
		response.NotFound(w, "知识库不存在")
	default:
		response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
	}
}
// documentRow is the JSON shape of one document as returned by ListDocuments.
// Each field is scanned from the knowledge_documents column named in its
// comment.
type documentRow struct {
// id column.
ID string `json:"id"`
// name column, serialized as "filename".
Name string `json:"filename"`
// file_type column; the listing query maps NULL to "".
FileType string `json:"file_type"`
// file_size column.
FileSize int64 `json:"file_size"`
// indexing_status column, serialized as "status".
IndexingStatus string `json:"status"`
// created_at column; the listing is ordered by it, newest first.
CreatedAt time.Time `json:"created_at"`
}
// ListDocuments responds with the documents of the knowledge base named by the
// "id" URL parameter, newest first. The body is always a JSON array (empty
// when the knowledge base has no documents). A malformed id yields a
// bad-request response; a query, scan or iteration failure yields an
// internal-error response.
//
// NOTE(review): no ownership or visibility check is performed in this handler;
// confirm that access control is enforced elsewhere (e.g. in middleware).
func (h *KnowledgeHandler) ListDocuments(w http.ResponseWriter, r *http.Request) {
	kbID, err := uuid.Parse(chi.URLParam(r, "id"))
	if err != nil {
		response.BadRequest(w, "无效的ID")
		return
	}

	rows, err := h.pool.Query(r.Context(),
		`SELECT id, name, COALESCE(file_type,''), file_size, indexing_status, created_at
FROM knowledge_documents WHERE kb_id = $1 ORDER BY created_at DESC`, kbID)
	if err != nil {
		internalErr(w, err)
		return
	}
	defer rows.Close()

	// Start from an empty (non-nil) slice so the JSON body is [] and not null.
	docs := []documentRow{}
	for rows.Next() {
		var d documentRow
		if err := rows.Scan(&d.ID, &d.Name, &d.FileType, &d.FileSize, &d.IndexingStatus, &d.CreatedAt); err != nil {
			internalErr(w, err)
			return
		}
		docs = append(docs, d)
	}
	// Next returning false can also mean the iteration failed; without this
	// check a truncated list would be returned as a success.
	if err := rows.Err(); err != nil {
		internalErr(w, err)
		return
	}
	response.JSON(w, http.StatusOK, docs)
}
// UploadDocument stores the multipart file sent in the "file" form field as a
// new document of the knowledge base named by the "id" URL parameter.
//
// The file type is derived from the filename extension, compared
// case-insensitively: pdf, docx, txt, md, csv and xlsx are recognised and
// anything else is treated as "txt". Only txt, md and csv files have their
// content read and stored; for the other types the document row is created
// with empty content. The document is inserted with status "processing", the
// knowledge base's doc_count is incremented, and chunking/embedding is started
// in a background goroutine. The handler then responds with 201 and the
// document id, filename, size, precomputed chunk count and status.
//
// Errors: bad-request response for a malformed id, an unparsable form or a
// missing file; not-found response when the knowledge base does not exist;
// internal-error response when a database statement or reading the file fails.
//
// NOTE(review): the 32 MiB passed to ParseMultipartForm only bounds how much
// of the form is kept in memory, not the upload size; if a hard limit is
// intended, confirm it is applied elsewhere (e.g. http.MaxBytesReader).
// NOTE(review): no ownership check is performed on the knowledge base here;
// confirm that access control is enforced elsewhere.
func (h *KnowledgeHandler) UploadDocument(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	kbID, err := uuid.Parse(chi.URLParam(r, "id"))
	if err != nil {
		response.BadRequest(w, "无效的ID")
		return
	}
	if err := r.ParseMultipartForm(32 << 20); err != nil {
		response.BadRequest(w, "文件过大或格式错误")
		return
	}
	file, header, err := r.FormFile("file")
	if err != nil {
		response.BadRequest(w, "请上传文件")
		return
	}
	defer file.Close()

	var exists bool
	err = h.pool.QueryRow(ctx,
		`SELECT EXISTS(SELECT 1 FROM knowledge_bases WHERE id = $1)`, kbID).Scan(&exists)
	if err != nil {
		// A failed lookup is a server problem, not a missing knowledge base.
		internalErr(w, err)
		return
	}
	if !exists {
		response.NotFound(w, "知识库不存在")
		return
	}

	// Unknown or missing extensions fall back to plain text. The comparison is
	// case-insensitive so that e.g. "REPORT.PDF" is not ingested as text.
	fileType := "txt"
	if dot := strings.LastIndexByte(header.Filename, '.'); dot >= 0 {
		switch ext := strings.ToLower(header.Filename[dot+1:]); ext {
		case "pdf", "docx", "txt", "md", "csv", "xlsx":
			fileType = ext
		}
	}

	// Read the file content (text files only).
	var content string
	if fileType == "txt" || fileType == "md" || fileType == "csv" {
		data, readErr := io.ReadAll(file)
		if readErr != nil {
			// Storing an empty document would hide the failure from the user.
			internalErr(w, readErr)
			return
		}
		content = string(data)
	}

	userID := mw.GetUserID(ctx)
	docID := uuid.New()

	// Compute the chunk count up front so it can be stored and returned.
	chunkCount := 0
	if content != "" {
		chunks := chunker.ChunkText(content, chunker.DefaultOptions())
		chunkCount = len(chunks)
	}

	_, err = h.pool.Exec(ctx,
		`INSERT INTO knowledge_documents (id, kb_id, name, file_type, file_size, uploader_id, indexing_status, content, char_count, chunk_count)
VALUES ($1, $2, $3, $4, $5, $6, 'processing', $7, $8, $9)`,
		docID, kbID, header.Filename, fileType, header.Size, userID, content, utf8.RuneCountInString(content), chunkCount)
	if err != nil {
		internalErr(w, err)
		return
	}

	// The counter update is best effort: the document is already stored, so a
	// failure is logged rather than failing the upload.
	if _, err := h.pool.Exec(ctx,
		`UPDATE knowledge_bases SET doc_count = doc_count + 1, updated_at = NOW() WHERE id = $1`, kbID); err != nil {
		log.Warn().Err(err).Str("kb_id", kbID.String()).Msg("knowledge base doc_count update failed")
	}

	// Chunk and embed asynchronously, on a context that is independent of the
	// request so the work is not cancelled when the response is sent.
	go h.chunkAndEmbed(context.Background(), kbID, docID, content)

	response.JSON(w, http.StatusCreated, map[string]any{
		"id":       docID.String(),
		"filename": header.Filename,
		"size":     header.Size,
		"chunks":   chunkCount,
		"status":   "processing",
	})
}
// chunkAndEmbed splits content into chunks, stores each chunk in
// knowledge_chunks (with an embedding when one can be computed) and finally
// marks the document as 'completed'. UploadDocument runs it in a goroutine;
// ReindexAll calls it synchronously.
//
// When content is empty or yields no chunks the document is marked completed
// right away. When the embedder is nil, unconfigured, or fails for a chunk,
// that chunk is stored as text only. Database failures are logged and do not
// abort the run, but the loop stops as soon as ctx is done, leaving the
// document status untouched.
//
// NOTE(review): the document is marked 'completed' with chunk_count set to the
// number of chunks produced even when some inserts failed; whether a distinct
// failure status exists in the schema cannot be told from this file.
func (h *KnowledgeHandler) chunkAndEmbed(ctx context.Context, kbID, docID uuid.UUID, content string) {
	const (
		markCompleted = `UPDATE knowledge_documents SET indexing_status = 'completed' WHERE id = $1`
		insertText    = `INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count)
VALUES ($1, $2, $3, $4, $5, $6)`
		insertEmbedded = `INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count, embedding)
VALUES ($1, $2, $3, $4, $5, $6, $7::vector)`
	)

	// exec runs one statement and logs (rather than discards) any failure.
	// It reports whether the statement succeeded.
	exec := func(action, sql string, args ...any) bool {
		if _, err := h.pool.Exec(ctx, sql, args...); err != nil {
			log.Error().Err(err).
				Str("doc_id", docID.String()).
				Str("action", action).
				Msg("knowledge indexing statement failed")
			return false
		}
		return true
	}

	if content == "" {
		exec("mark completed", markCompleted, docID)
		return
	}
	chunks := chunker.ChunkText(content, chunker.DefaultOptions())
	if len(chunks) == 0 {
		exec("mark completed", markCompleted, docID)
		return
	}

	embeddingAvailable := h.embedder != nil && h.embedder.IsConfigured()
	storedCount := 0
	successCount := 0
	for i, chunk := range chunks {
		// Every remaining call would fail once the context is done, so stop
		// instead of grinding through the rest of the document.
		if err := ctx.Err(); err != nil {
			log.Warn().Err(err).
				Str("doc_id", docID.String()).
				Int("chunk", i).
				Msg("document indexing aborted")
			return
		}

		chunkID := uuid.New()
		charCount := utf8.RuneCountInString(chunk)

		// vec stays nil when the chunk has to be stored as text only.
		var vec *string
		if embeddingAvailable {
			emb, err := h.embedder.GetEmbedding(ctx, chunk)
			if err != nil {
				log.Warn().Err(err).Str("doc_id", docID.String()).Int("chunk", i).Msg("embedding failed")
			} else {
				// Convert the []float32 into pgvector's textual format.
				s := float32SliceToVectorStr(emb)
				vec = &s
			}
		}

		if vec == nil {
			// No embedding (service missing or failed): store the text only.
			if exec("insert chunk", insertText, chunkID, kbID, docID, i, chunk, charCount) {
				storedCount++
			}
			continue
		}
		if exec("insert embedded chunk", insertEmbedded, chunkID, kbID, docID, i, chunk, charCount, *vec) {
			storedCount++
			successCount++
		}
	}

	// Update the document status.
	exec("mark completed",
		`UPDATE knowledge_documents SET indexing_status = 'completed', chunk_count = $2 WHERE id = $1`,
		docID, len(chunks))

	log.Info().
		Str("doc_id", docID.String()).
		Int("chunks", len(chunks)).
		Int("stored", storedCount).
		Int("embedded", successCount).
		Bool("embedding_available", embeddingAvailable).
		Msg("document chunked and embedded")
}
// float32SliceToVectorStr renders v in pgvector's text format, e.g.
// "[0.1,0.2,-3]". Each element is formatted with fmt's %g verb; a nil or empty
// slice yields "[]".
func float32SliceToVectorStr(v []float32) string {
	// Build into one pre-sized buffer: repeated string concatenation copies
	// the whole prefix on every element, which is quadratic for the
	// thousand-plus dimensions embeddings typically have.
	buf := make([]byte, 0, 2+12*len(v))
	buf = append(buf, '[')
	for i, f := range v {
		if i > 0 {
			buf = append(buf, ',')
		}
		buf = append(buf, fmt.Sprintf("%g", f)...)
	}
	buf = append(buf, ']')
	return string(buf)
}
// ReindexAll chunks and embeds every document that has non-empty content but a
// chunk_count of 0 (admin endpoint). The documents are processed one after
// another on the request's context, and the response reports how many
// documents were handed to chunkAndEmbed. A query or iteration failure yields
// an internal-error response before any document is processed; rows that
// cannot be scanned are logged and skipped.
func (h *KnowledgeHandler) ReindexAll(w http.ResponseWriter, r *http.Request) {
	rows, err := h.pool.Query(r.Context(),
		`SELECT kd.id, kd.kb_id, kd.content FROM knowledge_documents kd
WHERE kd.content IS NOT NULL AND kd.content != '' AND kd.chunk_count = 0`)
	if err != nil {
		internalErr(w, err)
		return
	}
	defer rows.Close()

	type docInfo struct {
		docID   uuid.UUID
		kbID    uuid.UUID
		content string
	}
	var docs []docInfo
	for rows.Next() {
		var d docInfo
		if err := rows.Scan(&d.docID, &d.kbID, &d.content); err != nil {
			// Skip the row, but leave a trace of why it was not reindexed.
			log.Warn().Err(err).Msg("reindex: skipping document row that could not be scanned")
			continue
		}
		docs = append(docs, d)
	}
	// Next returning false can also mean the iteration failed; do not report
	// a partial run as a completed reindex.
	if err := rows.Err(); err != nil {
		internalErr(w, err)
		return
	}

	// The result set is fully consumed at this point, so the documents are
	// processed from the in-memory list.
	for _, d := range docs {
		h.chunkAndEmbed(r.Context(), d.kbID, d.docID, d.content)
	}
	response.JSON(w, http.StatusOK, map[string]any{
		"message":   "重新索引完成",
		"documents": len(docs),
	})
}
// ReembedChunks computes embeddings for stored chunks that have content but no
// embedding yet (admin endpoint). At most 200 chunks are handled per call, so
// the endpoint may need to be invoked repeatedly.
//
// When the embedder is nil or unconfigured the handler responds with 200 and
// "updated": 0 without touching the database. Otherwise it responds with the
// number of chunks selected ("total") and the number whose embedding was
// stored ("updated"). Per-chunk failures are logged and skipped; a query or
// iteration failure yields an internal-error response.
func (h *KnowledgeHandler) ReembedChunks(w http.ResponseWriter, r *http.Request) {
	if h.embedder == nil || !h.embedder.IsConfigured() {
		response.JSON(w, http.StatusOK, map[string]any{
			"message": "embedding服务未配置",
			"updated": 0,
		})
		return
	}

	ctx := r.Context()
	rows, err := h.pool.Query(ctx,
		`SELECT id, content FROM knowledge_chunks WHERE embedding IS NULL AND content IS NOT NULL AND content != '' LIMIT 200`)
	if err != nil {
		internalErr(w, err)
		return
	}
	defer rows.Close()

	type chunkInfo struct {
		id      uuid.UUID
		content string
	}
	var chunks []chunkInfo
	for rows.Next() {
		var c chunkInfo
		if err := rows.Scan(&c.id, &c.content); err != nil {
			// Skip the row, but leave a trace of why it was not re-embedded.
			log.Warn().Err(err).Msg("re-embed: skipping chunk row that could not be scanned")
			continue
		}
		chunks = append(chunks, c)
	}
	// Next returning false can also mean the iteration failed.
	if err := rows.Err(); err != nil {
		internalErr(w, err)
		return
	}

	updated := 0
	for _, c := range chunks {
		// Once the request context is done every remaining call would fail,
		// so stop instead of logging one warning per remaining chunk.
		if err := ctx.Err(); err != nil {
			log.Warn().Err(err).Msg("re-embed aborted")
			break
		}
		emb, err := h.embedder.GetEmbedding(ctx, c.content)
		if err != nil {
			log.Warn().Err(err).Str("chunk_id", c.id.String()).Msg("re-embed failed")
			continue
		}
		vecStr := float32SliceToVectorStr(emb)
		_, err = h.pool.Exec(ctx,
			`UPDATE knowledge_chunks SET embedding = $1::vector WHERE id = $2`, vecStr, c.id)
		if err != nil {
			log.Warn().Err(err).Str("chunk_id", c.id.String()).Msg("update embedding failed")
			continue
		}
		updated++
	}
	response.JSON(w, http.StatusOK, map[string]any{
		"message": "向量补充完成",
		"total":   len(chunks),
		"updated": updated,
	})
}
// DeleteDocument removes the document named by the "docId" URL parameter from
// the knowledge base named by the "id" URL parameter, then decrements that
// knowledge base's doc_count (never below zero) and refreshes its updated_at.
// Malformed ids produce a bad-request response, a failed delete an
// internal-error response, and a delete that matches no row a not-found
// response. The result of the counter update is intentionally ignored.
func (h *KnowledgeHandler) DeleteDocument(w http.ResponseWriter, r *http.Request) {
	kbID, kbErr := uuid.Parse(chi.URLParam(r, "id"))
	if kbErr != nil {
		response.BadRequest(w, "无效的知识库ID")
		return
	}
	docID, docErr := uuid.Parse(chi.URLParam(r, "docId"))
	if docErr != nil {
		response.BadRequest(w, "无效的文档ID")
		return
	}

	const (
		removeDoc    = `DELETE FROM knowledge_documents WHERE id = $1 AND kb_id = $2`
		decrementCnt = `UPDATE knowledge_bases SET doc_count = GREATEST(doc_count - 1, 0), updated_at = NOW() WHERE id = $1`
	)

	result, execErr := h.pool.Exec(r.Context(), removeDoc, docID, kbID)
	if execErr != nil {
		internalErr(w, execErr)
		return
	}
	if removed := result.RowsAffected(); removed == 0 {
		response.NotFound(w, "文档不存在")
		return
	}

	// Best effort: the document is already gone, so the counter update does
	// not influence the response.
	_, _ = h.pool.Exec(r.Context(), decrementCnt, kbID)

	reply := map[string]string{"message": "已删除"}
	response.JSON(w, http.StatusOK, reply)
}