// Package handler contains the HTTP handlers of the server. This file holds
// the knowledge-base endpoints: CRUD for knowledge bases, document upload and
// deletion, and the chunking / embedding pipeline that indexes documents.
package handler

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	mw "github.com/enterprise-ai-platform/server/internal/middleware"
	"github.com/enterprise-ai-platform/server/internal/response"
	"github.com/enterprise-ai-platform/server/pkg/chunker"
	"github.com/enterprise-ai-platform/server/pkg/embedding"
	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/rs/zerolog/log"
)

// internalErr writes an internal-error response carrying err's message.
//
// NOTE(review): the raw error text (which may contain SQL or schema details)
// is sent to the client. Consider logging it and returning a generic message.
func internalErr(w http.ResponseWriter, err error) {
	response.InternalError(w, err.Error())
}

// KnowledgeHandler serves the knowledge-base HTTP endpoints.
type KnowledgeHandler struct {
	pool     *pgxpool.Pool     // database connection pool used by every endpoint
	embedder *embedding.Client // optional; nil or unconfigured disables embeddings
}

// NewKnowledgeHandler returns a KnowledgeHandler backed by pool. embedder may
// be nil, in which case chunks are stored without embeddings.
func NewKnowledgeHandler(pool *pgxpool.Pool, embedder *embedding.Client) *KnowledgeHandler {
	return &KnowledgeHandler{pool: pool, embedder: embedder}
}

// knowledgeBaseRow is the JSON shape of one knowledge base in list responses.
type knowledgeBaseRow struct {
	ID          string    `json:"id"`
	Name        string    `json:"name"`
	Description string    `json:"description"`
	Visibility  string    `json:"visibility"`
	DocCount    int       `json:"document_count"`
	TotalChars  int64     `json:"total_chars"`
	Status      string    `json:"status"`
	CreatedAt   time.Time `json:"created_at"`
	UpdatedAt   time.Time `json:"updated_at"`
}

// lookupUserOrg returns the org_id (as text) stored for userID, or nil when
// the user has no organisation or the lookup fails. Failures are logged and
// otherwise ignored because callers treat the organisation as optional.
func (h *KnowledgeHandler) lookupUserOrg(ctx context.Context, userID any) *string {
	var orgID *string
	err := h.pool.QueryRow(ctx,
		`SELECT org_id::text FROM users WHERE id = $1`, userID).Scan(&orgID)
	if err != nil {
		log.Warn().Err(err).Msg("knowledge: user organisation lookup failed")
		return nil
	}
	return orgID
}

// execBestEffort runs a statement whose failure must not abort the caller.
// Any error is logged together with op, a short label for the statement.
func (h *KnowledgeHandler) execBestEffort(ctx context.Context, op, sql string, args ...any) {
	if _, err := h.pool.Exec(ctx, sql, args...); err != nil {
		log.Warn().Err(err).Str("op", op).Msg("knowledge: best-effort statement failed")
	}
}

// ListKnowledgeBases responds with the knowledge bases visible to the caller,
// newest update first. A "super_admin" sees every knowledge base; any other
// role sees the ones it owns plus those of one organisation.
func (h *KnowledgeHandler) ListKnowledgeBases(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	userID := mw.GetUserID(ctx)
	userRole := mw.GetRole(ctx)

	var (
		query string
		args  []any
	)
	if userRole == "super_admin" {
		// Super admins see every knowledge base.
		query = `SELECT id, name, COALESCE(description,''), visibility, doc_count, total_chars, status, created_at, updated_at
			FROM knowledge_bases ORDER BY updated_at DESC`
	} else {
		// Prefer the org_id sent by the frontend (set after the user switches
		// organisation); otherwise fall back to the organisation on the user row.
		//
		// NOTE(review): the org_id query parameter is not checked against the
		// caller's memberships here — confirm that upstream middleware does so.
		orgFilter := r.URL.Query().Get("org_id")
		if orgFilter == "" {
			if userOrg := h.lookupUserOrg(ctx, userID); userOrg != nil {
				orgFilter = *userOrg
			}
		}

		query = `SELECT id, name, COALESCE(description,''), visibility, doc_count, total_chars, status, created_at, updated_at
			FROM knowledge_bases WHERE (owner_id = $1`
		args = append(args, userID)
		if orgFilter != "" {
			query += ` OR org_id = $2`
			args = append(args, orgFilter)
		}
		query += `) ORDER BY updated_at DESC`
	}

	rows, err := h.pool.Query(ctx, query, args...)
	if err != nil {
		internalErr(w, err)
		return
	}
	defer rows.Close()

	// Start from an empty, non-nil slice so the JSON body is [] rather than null.
	items := []knowledgeBaseRow{}
	for rows.Next() {
		var kb knowledgeBaseRow
		if err := rows.Scan(&kb.ID, &kb.Name, &kb.Description, &kb.Visibility,
			&kb.DocCount, &kb.TotalChars, &kb.Status, &kb.CreatedAt, &kb.UpdatedAt); err != nil {
			internalErr(w, err)
			return
		}
		items = append(items, kb)
	}
	// rows.Next returns false on error as well as at end of data; without this
	// check a failed iteration would be reported as a (truncated) success.
	if err := rows.Err(); err != nil {
		internalErr(w, err)
		return
	}

	response.JSON(w, http.StatusOK, items)
}

// CreateKnowledgeBase creates a knowledge base owned by the caller from a
// JSON body {name, description, visibility}. name is required; visibility
// defaults to "private". The caller's organisation, if any, is attached.
func (h *KnowledgeHandler) CreateKnowledgeBase(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	userID := mw.GetUserID(ctx)

	var body struct {
		Name        string `json:"name"`
		Description string `json:"description"`
		Visibility  string `json:"visibility"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		response.BadRequest(w, "无效的请求体")
		return
	}
	if body.Name == "" {
		response.BadRequest(w, "名称不能为空")
		return
	}
	if body.Visibility == "" {
		body.Visibility = "private"
	}

	// Organisation the user belongs to; nil is stored as NULL.
	userOrgID := h.lookupUserOrg(ctx, userID)

	id := uuid.New()
	_, err := h.pool.Exec(ctx,
		`INSERT INTO knowledge_bases (id, name, description, owner_id, visibility, org_id)
		 VALUES ($1, $2, $3, $4, $5, $6)`,
		id, body.Name, body.Description, userID, body.Visibility, userOrgID)
	if err != nil {
		internalErr(w, err)
		return
	}

	response.JSON(w, http.StatusCreated, map[string]any{
		"id":          id.String(),
		"name":        body.Name,
		"description": body.Description,
		"visibility":  body.Visibility,
	})
}

// UpdateKnowledgeBase updates the name and description of a knowledge base
// owned by the caller. An empty name keeps the stored name; the description is
// always overwritten with the value from the body. Responds 404 when no row
// matches both the id and the caller as owner.
func (h *KnowledgeHandler) UpdateKnowledgeBase(w http.ResponseWriter, r *http.Request) {
	id, err := uuid.Parse(chi.URLParam(r, "id"))
	if err != nil {
		response.BadRequest(w, "无效的ID")
		return
	}
	userID := mw.GetUserID(r.Context())

	var body struct {
		Name        string `json:"name"`
		Description string `json:"description"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		response.BadRequest(w, "无效的请求体")
		return
	}

	tag, err := h.pool.Exec(r.Context(),
		`UPDATE knowledge_bases
		 SET name = COALESCE(NULLIF($1,''), name), description = $2, updated_at = NOW()
		 WHERE id = $3 AND owner_id = $4`,
		body.Name, body.Description, id, userID)
	if err != nil {
		internalErr(w, err)
		return
	}
	if tag.RowsAffected() == 0 {
		response.NotFound(w, "知识库不存在")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
}

// DeleteKnowledgeBase deletes a knowledge base owned by the caller. Responds
// 404 when no row matches both the id and the caller as owner.
func (h *KnowledgeHandler) DeleteKnowledgeBase(w http.ResponseWriter, r *http.Request) {
	id, err := uuid.Parse(chi.URLParam(r, "id"))
	if err != nil {
		response.BadRequest(w, "无效的ID")
		return
	}
	userID := mw.GetUserID(r.Context())

	tag, err := h.pool.Exec(r.Context(),
		`DELETE FROM knowledge_bases WHERE id = $1 AND owner_id = $2`, id, userID)
	if err != nil {
		internalErr(w, err)
		return
	}
	if tag.RowsAffected() == 0 {
		response.NotFound(w, "知识库不存在")
		return
	}
	response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
}

// documentRow is the JSON shape of one document in list responses.
type documentRow struct {
	ID             string    `json:"id"`
	Name           string    `json:"filename"`
	FileType       string    `json:"file_type"`
	FileSize       int64     `json:"file_size"`
	IndexingStatus string    `json:"status"`
	CreatedAt      time.Time `json:"created_at"`
}

// ListDocuments responds with the documents of the knowledge base named by
// the "id" URL parameter, newest first.
func (h *KnowledgeHandler) ListDocuments(w http.ResponseWriter, r *http.Request) {
	kbID, err := uuid.Parse(chi.URLParam(r, "id"))
	if err != nil {
		response.BadRequest(w, "无效的ID")
		return
	}

	rows, err := h.pool.Query(r.Context(),
		`SELECT id, name, COALESCE(file_type,''), file_size, indexing_status, created_at
		 FROM knowledge_documents WHERE kb_id = $1 ORDER BY created_at DESC`, kbID)
	if err != nil {
		internalErr(w, err)
		return
	}
	defer rows.Close()

	// Non-nil so an empty result encodes as [] rather than null.
	docs := []documentRow{}
	for rows.Next() {
		var d documentRow
		if err := rows.Scan(&d.ID, &d.Name, &d.FileType, &d.FileSize,
			&d.IndexingStatus, &d.CreatedAt); err != nil {
			internalErr(w, err)
			return
		}
		docs = append(docs, d)
	}
	if err := rows.Err(); err != nil {
		internalErr(w, err)
		return
	}

	response.JSON(w, http.StatusOK, docs)
}

// detectFileType maps a filename to one of the supported file types ("pdf",
// "docx", "txt", "md", "csv", "xlsx") using the text after the last dot,
// compared case-insensitively. Unknown or missing extensions map to "txt".
func detectFileType(filename string) string {
	ext := ""
	if i := strings.LastIndexByte(filename, '.'); i >= 0 {
		ext = strings.ToLower(filename[i+1:])
	}
	switch ext {
	case "pdf", "docx", "txt", "md", "csv", "xlsx":
		return ext
	default:
		return "txt"
	}
}

// UploadDocument stores the multipart file field "file" as a document of the
// knowledge base named by the "id" URL parameter. Files typed txt, md or csv
// have their content read and stored; other types are stored without content.
// Chunking and embedding run in a background goroutine after the response.
func (h *KnowledgeHandler) UploadDocument(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	kbID, err := uuid.Parse(chi.URLParam(r, "id"))
	if err != nil {
		response.BadRequest(w, "无效的ID")
		return
	}

	if err := r.ParseMultipartForm(32 << 20); err != nil {
		response.BadRequest(w, "文件过大或格式错误")
		return
	}
	file, header, err := r.FormFile("file")
	if err != nil {
		response.BadRequest(w, "请上传文件")
		return
	}
	defer file.Close()

	// NOTE(review): only existence is checked, not ownership or organisation
	// membership — confirm access control is enforced elsewhere.
	var exists bool
	err = h.pool.QueryRow(ctx,
		`SELECT EXISTS(SELECT 1 FROM knowledge_bases WHERE id = $1)`, kbID).Scan(&exists)
	if err != nil {
		// A database failure is not "not found"; report it as such.
		internalErr(w, err)
		return
	}
	if !exists {
		response.NotFound(w, "知识库不存在")
		return
	}

	fileType := detectFileType(header.Filename)

	// Read the content of plain-text files only.
	var content string
	switch fileType {
	case "txt", "md", "csv":
		data, err := io.ReadAll(file)
		if err != nil {
			// Do not store a document whose content could not be read.
			internalErr(w, err)
			return
		}
		content = string(data)
	}

	userID := mw.GetUserID(ctx)
	docID := uuid.New()

	// Number of chunks, reported in the response and stored with the document.
	chunkCount := 0
	if content != "" {
		chunkCount = len(chunker.ChunkText(content, chunker.DefaultOptions()))
	}

	_, err = h.pool.Exec(ctx,
		`INSERT INTO knowledge_documents (id, kb_id, name, file_type, file_size, uploader_id, indexing_status, content, char_count, chunk_count)
		 VALUES ($1, $2, $3, $4, $5, $6, 'processing', $7, $8, $9)`,
		docID, kbID, header.Filename, fileType, header.Size, userID,
		content, utf8.RuneCountInString(content), chunkCount)
	if err != nil {
		internalErr(w, err)
		return
	}

	h.execBestEffort(ctx, "increment doc_count",
		`UPDATE knowledge_bases SET doc_count = doc_count + 1, updated_at = NOW() WHERE id = $1`, kbID)

	// Chunk and embed asynchronously. A background context is used because the
	// request context is cancelled as soon as this handler returns.
	go h.chunkAndEmbed(context.Background(), kbID, docID, content)

	response.JSON(w, http.StatusCreated, map[string]any{
		"id":       docID.String(),
		"filename": header.Filename,
		"size":     header.Size,
		"chunks":   chunkCount,
		"status":   "processing",
	})
}

// insertChunk stores one chunk of a document. When vec is non-empty it is
// stored as the chunk's embedding (pgvector text form); otherwise the chunk is
// stored without an embedding.
func (h *KnowledgeHandler) insertChunk(ctx context.Context, kbID, docID uuid.UUID, index int, text, vec string) error {
	var err error
	if vec != "" {
		_, err = h.pool.Exec(ctx,
			`INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count, embedding)
			 VALUES ($1, $2, $3, $4, $5, $6, $7::vector)`,
			uuid.New(), kbID, docID, index, text, utf8.RuneCountInString(text), vec)
	} else {
		_, err = h.pool.Exec(ctx,
			`INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count)
			 VALUES ($1, $2, $3, $4, $5, $6)`,
			uuid.New(), kbID, docID, index, text, utf8.RuneCountInString(text))
	}
	if err != nil {
		return fmt.Errorf("insert chunk %d of document %s: %w", index, docID, err)
	}
	return nil
}

// chunkAndEmbed splits content into chunks, stores each chunk (with an
// embedding when the embedding service is configured and succeeds) and then
// marks the document as completed. It is run both in the background after an
// upload and synchronously from ReindexAll.
//
// Failures are logged rather than returned. The document is marked
// 'completed' even when some chunks failed to store, because the set of valid
// indexing_status values is not visible from this file.
func (h *KnowledgeHandler) chunkAndEmbed(ctx context.Context, kbID, docID uuid.UUID, content string) {
	const markCompleted = `UPDATE knowledge_documents SET indexing_status = 'completed' WHERE id = $1`

	if content == "" {
		h.execBestEffort(ctx, "mark document completed", markCompleted, docID)
		return
	}
	chunks := chunker.ChunkText(content, chunker.DefaultOptions())
	if len(chunks) == 0 {
		h.execBestEffort(ctx, "mark document completed", markCompleted, docID)
		return
	}

	embeddingAvailable := h.embedder != nil && h.embedder.IsConfigured()
	successCount, failedCount := 0, 0

	for i, chunk := range chunks {
		// vec stays empty when there is no embedding service or the call
		// fails; the chunk text is still stored in that case.
		vec := ""
		if embeddingAvailable {
			emb, err := h.embedder.GetEmbedding(ctx, chunk)
			switch {
			case err != nil:
				log.Warn().Err(err).Str("doc_id", docID.String()).Int("chunk", i).Msg("embedding failed")
			case len(emb) > 0:
				// Convert the []float32 to the pgvector text format.
				vec = float32SliceToVectorStr(emb)
			}
		}

		if err := h.insertChunk(ctx, kbID, docID, i, chunk, vec); err != nil {
			log.Warn().Err(err).Str("doc_id", docID.String()).Int("chunk", i).Msg("chunk insert failed")
			failedCount++
			continue
		}
		if vec != "" {
			successCount++
		}
	}

	// Update the document status.
	h.execBestEffort(ctx, "mark document completed",
		`UPDATE knowledge_documents SET indexing_status = 'completed', chunk_count = $2 WHERE id = $1`,
		docID, len(chunks))

	log.Info().
		Str("doc_id", docID.String()).
		Int("chunks", len(chunks)).
		Int("embedded", successCount).
		Int("failed", failedCount).
		Bool("embedding_available", embeddingAvailable).
		Msg("document chunked and embedded")
}

// float32SliceToVectorStr renders v in pgvector text format, "[0.1,0.2,...]".
// Each element uses the shortest representation that round-trips as a
// float32, which is the same text fmt's %g verb produces for a float32.
func float32SliceToVectorStr(v []float32) string {
	// 12 bytes per element is a rough estimate that avoids most regrowth.
	buf := make([]byte, 0, 2+12*len(v))
	buf = append(buf, '[')
	for i, f := range v {
		if i > 0 {
			buf = append(buf, ',')
		}
		buf = strconv.AppendFloat(buf, float64(f), 'g', -1, 32)
	}
	buf = append(buf, ']')
	return string(buf)
}

// ReindexAll chunks and embeds every document that has content but a
// chunk_count of zero (admin endpoint). Documents are processed one after
// another within the request.
func (h *KnowledgeHandler) ReindexAll(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	rows, err := h.pool.Query(ctx,
		`SELECT kd.id, kd.kb_id, kd.content FROM knowledge_documents kd
		 WHERE kd.content IS NOT NULL AND kd.content != '' AND kd.chunk_count = 0`)
	if err != nil {
		internalErr(w, err)
		return
	}
	defer rows.Close()

	type docInfo struct {
		docID   uuid.UUID
		kbID    uuid.UUID
		content string
	}
	var docs []docInfo
	for rows.Next() {
		var d docInfo
		if err := rows.Scan(&d.docID, &d.kbID, &d.content); err != nil {
			// Skip the unreadable row but leave a trace of it.
			log.Warn().Err(err).Msg("reindex: scanning document row failed")
			continue
		}
		docs = append(docs, d)
	}
	if err := rows.Err(); err != nil {
		internalErr(w, err)
		return
	}

	for _, d := range docs {
		h.chunkAndEmbed(ctx, d.kbID, d.docID, d.content)
	}

	response.JSON(w, http.StatusOK, map[string]any{
		"message":   "重新索引完成",
		"documents": len(docs),
	})
}

// ReembedChunks computes embeddings for stored chunks that have content but
// no embedding (admin endpoint). At most 200 chunks are handled per call.
// When the embedding service is not configured it responds 200 with
// "updated": 0 and does nothing.
func (h *KnowledgeHandler) ReembedChunks(w http.ResponseWriter, r *http.Request) {
	if h.embedder == nil || !h.embedder.IsConfigured() {
		response.JSON(w, http.StatusOK, map[string]any{
			"message": "embedding服务未配置",
			"updated": 0,
		})
		return
	}
	ctx := r.Context()

	rows, err := h.pool.Query(ctx,
		`SELECT id, content FROM knowledge_chunks
		 WHERE embedding IS NULL AND content IS NOT NULL AND content != '' LIMIT 200`)
	if err != nil {
		internalErr(w, err)
		return
	}
	defer rows.Close()

	type chunkInfo struct {
		id      uuid.UUID
		content string
	}
	var chunks []chunkInfo
	for rows.Next() {
		var c chunkInfo
		if err := rows.Scan(&c.id, &c.content); err != nil {
			log.Warn().Err(err).Msg("re-embed: scanning chunk row failed")
			continue
		}
		chunks = append(chunks, c)
	}
	if err := rows.Err(); err != nil {
		internalErr(w, err)
		return
	}

	updated := 0
	for _, c := range chunks {
		emb, err := h.embedder.GetEmbedding(ctx, c.content)
		if err != nil {
			log.Warn().Err(err).Str("chunk_id", c.id.String()).Msg("re-embed failed")
			continue
		}
		_, err = h.pool.Exec(ctx,
			`UPDATE knowledge_chunks SET embedding = $1::vector WHERE id = $2`,
			float32SliceToVectorStr(emb), c.id)
		if err != nil {
			log.Warn().Err(err).Str("chunk_id", c.id.String()).Msg("update embedding failed")
			continue
		}
		updated++
	}

	response.JSON(w, http.StatusOK, map[string]any{
		"message": "向量补充完成",
		"total":   len(chunks),
		"updated": updated,
	})
}

// DeleteDocument deletes the document named by the "docId" URL parameter
// from the knowledge base named by "id" and decrements that knowledge base's
// doc_count (never below zero). Responds 404 when no such document exists in
// that knowledge base.
func (h *KnowledgeHandler) DeleteDocument(w http.ResponseWriter, r *http.Request) {
	kbID, err := uuid.Parse(chi.URLParam(r, "id"))
	if err != nil {
		response.BadRequest(w, "无效的知识库ID")
		return
	}
	docID, err := uuid.Parse(chi.URLParam(r, "docId"))
	if err != nil {
		response.BadRequest(w, "无效的文档ID")
		return
	}

	tag, err := h.pool.Exec(r.Context(),
		`DELETE FROM knowledge_documents WHERE id = $1 AND kb_id = $2`, docID, kbID)
	if err != nil {
		internalErr(w, err)
		return
	}
	if tag.RowsAffected() == 0 {
		response.NotFound(w, "文档不存在")
		return
	}

	h.execBestEffort(r.Context(), "decrement doc_count",
		`UPDATE knowledge_bases SET doc_count = GREATEST(doc_count - 1, 0), updated_at = NOW() WHERE id = $1`, kbID)

	response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
}