523 lines
14 KiB
Go
523 lines
14 KiB
Go
package handler
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/enterprise-ai-platform/server/pkg/chunker"
	"github.com/enterprise-ai-platform/server/pkg/embedding"
	"github.com/rs/zerolog/log"

	mw "github.com/enterprise-ai-platform/server/internal/middleware"
	"github.com/enterprise-ai-platform/server/internal/response"
	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"
)
|
|
|
|
func internalErr(w http.ResponseWriter, err error) {
|
|
response.InternalError(w, err.Error())
|
|
}
|
|
|
|
type KnowledgeHandler struct {
|
|
pool *pgxpool.Pool
|
|
embedder *embedding.Client
|
|
}
|
|
|
|
func NewKnowledgeHandler(pool *pgxpool.Pool, embedder *embedding.Client) *KnowledgeHandler {
|
|
return &KnowledgeHandler{pool: pool, embedder: embedder}
|
|
}
|
|
|
|
// knowledgeBaseRow is the JSON shape of one knowledge base as returned by
// ListKnowledgeBases. Fields are filled, in order, from the columns selected
// from the knowledge_bases table.
type knowledgeBaseRow struct {
	ID          string    `json:"id"`
	Name        string    `json:"name"`
	Description string    `json:"description"` // NULL descriptions are read as ""
	Visibility  string    `json:"visibility"`
	DocCount    int       `json:"document_count"` // doc_count column
	TotalChars  int64     `json:"total_chars"`
	Status      string    `json:"status"`
	CreatedAt   time.Time `json:"created_at"`
	UpdatedAt   time.Time `json:"updated_at"`
}
|
|
|
|
func (h *KnowledgeHandler) ListKnowledgeBases(w http.ResponseWriter, r *http.Request) {
|
|
userID := mw.GetUserID(r.Context())
|
|
userRole := mw.GetRole(r.Context())
|
|
|
|
var query string
|
|
var args []any
|
|
|
|
if userRole == "super_admin" {
|
|
// 超级管理员查看全部知识库
|
|
query = `SELECT id, name, COALESCE(description,''), visibility, doc_count, total_chars, status, created_at, updated_at
|
|
FROM knowledge_bases ORDER BY updated_at DESC`
|
|
args = []any{}
|
|
} else {
|
|
// 优先使用前端传入的 org_id 参数(切换机构后),否则从用户表获取
|
|
orgFilter := r.URL.Query().Get("org_id")
|
|
if orgFilter == "" {
|
|
var userOrg *string
|
|
_ = h.pool.QueryRow(r.Context(), `SELECT org_id::text FROM users WHERE id = $1`, userID).Scan(&userOrg)
|
|
if userOrg != nil {
|
|
orgFilter = *userOrg
|
|
}
|
|
}
|
|
|
|
query = `SELECT id, name, COALESCE(description,''), visibility, doc_count, total_chars, status, created_at, updated_at
|
|
FROM knowledge_bases WHERE (owner_id = $1`
|
|
args = []any{userID}
|
|
if orgFilter != "" {
|
|
query += ` OR org_id = $2`
|
|
args = append(args, orgFilter)
|
|
}
|
|
query += `) ORDER BY updated_at DESC`
|
|
}
|
|
rows, err := h.pool.Query(r.Context(), query, args...)
|
|
if err != nil {
|
|
internalErr(w, err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var items []knowledgeBaseRow
|
|
for rows.Next() {
|
|
var kb knowledgeBaseRow
|
|
if err := rows.Scan(&kb.ID, &kb.Name, &kb.Description, &kb.Visibility, &kb.DocCount, &kb.TotalChars, &kb.Status, &kb.CreatedAt, &kb.UpdatedAt); err != nil {
|
|
internalErr(w, err)
|
|
return
|
|
}
|
|
items = append(items, kb)
|
|
}
|
|
if items == nil {
|
|
items = []knowledgeBaseRow{}
|
|
}
|
|
response.JSON(w, http.StatusOK, items)
|
|
}
|
|
|
|
func (h *KnowledgeHandler) CreateKnowledgeBase(w http.ResponseWriter, r *http.Request) {
|
|
userID := mw.GetUserID(r.Context())
|
|
|
|
var body struct {
|
|
Name string `json:"name"`
|
|
Description string `json:"description"`
|
|
Visibility string `json:"visibility"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
response.BadRequest(w, "无效的请求体")
|
|
return
|
|
}
|
|
if body.Name == "" {
|
|
response.BadRequest(w, "名称不能为空")
|
|
return
|
|
}
|
|
if body.Visibility == "" {
|
|
body.Visibility = "private"
|
|
}
|
|
|
|
// 获取用户所属机构
|
|
var userOrgID *string
|
|
_ = h.pool.QueryRow(r.Context(), `SELECT org_id::text FROM users WHERE id = $1`, userID).Scan(&userOrgID)
|
|
|
|
id := uuid.New()
|
|
_, err := h.pool.Exec(r.Context(),
|
|
`INSERT INTO knowledge_bases (id, name, description, owner_id, visibility, org_id)
|
|
VALUES ($1, $2, $3, $4, $5, $6)`,
|
|
id, body.Name, body.Description, userID, body.Visibility, userOrgID)
|
|
if err != nil {
|
|
internalErr(w, err)
|
|
return
|
|
}
|
|
|
|
response.JSON(w, http.StatusCreated, map[string]any{
|
|
"id": id.String(),
|
|
"name": body.Name,
|
|
"description": body.Description,
|
|
"visibility": body.Visibility,
|
|
})
|
|
}
|
|
|
|
func (h *KnowledgeHandler) UpdateKnowledgeBase(w http.ResponseWriter, r *http.Request) {
|
|
id, err := uuid.Parse(chi.URLParam(r, "id"))
|
|
if err != nil {
|
|
response.BadRequest(w, "无效的ID")
|
|
return
|
|
}
|
|
userID := mw.GetUserID(r.Context())
|
|
|
|
var body struct {
|
|
Name string `json:"name"`
|
|
Description string `json:"description"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
response.BadRequest(w, "无效的请求体")
|
|
return
|
|
}
|
|
|
|
tag, err := h.pool.Exec(r.Context(),
|
|
`UPDATE knowledge_bases SET name = COALESCE(NULLIF($1,''), name),
|
|
description = $2, updated_at = NOW()
|
|
WHERE id = $3 AND owner_id = $4`,
|
|
body.Name, body.Description, id, userID)
|
|
if err != nil {
|
|
internalErr(w, err)
|
|
return
|
|
}
|
|
if tag.RowsAffected() == 0 {
|
|
response.NotFound(w, "知识库不存在")
|
|
return
|
|
}
|
|
response.JSON(w, http.StatusOK, map[string]string{"message": "已更新"})
|
|
}
|
|
|
|
func (h *KnowledgeHandler) DeleteKnowledgeBase(w http.ResponseWriter, r *http.Request) {
|
|
id, err := uuid.Parse(chi.URLParam(r, "id"))
|
|
if err != nil {
|
|
response.BadRequest(w, "无效的ID")
|
|
return
|
|
}
|
|
userID := mw.GetUserID(r.Context())
|
|
|
|
tag, err := h.pool.Exec(r.Context(),
|
|
`DELETE FROM knowledge_bases WHERE id = $1 AND owner_id = $2`, id, userID)
|
|
if err != nil {
|
|
internalErr(w, err)
|
|
return
|
|
}
|
|
if tag.RowsAffected() == 0 {
|
|
response.NotFound(w, "知识库不存在")
|
|
return
|
|
}
|
|
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
|
|
}
|
|
|
|
// documentRow is the JSON shape of one document as returned by
// ListDocuments. Fields are filled, in order, from the columns selected from
// the knowledge_documents table.
type documentRow struct {
	ID             string    `json:"id"`
	Name           string    `json:"filename"`  // name column
	FileType       string    `json:"file_type"` // NULL file types are read as ""
	FileSize       int64     `json:"file_size"`
	IndexingStatus string    `json:"status"` // indexing_status column
	CreatedAt      time.Time `json:"created_at"`
}
|
|
|
|
func (h *KnowledgeHandler) ListDocuments(w http.ResponseWriter, r *http.Request) {
|
|
kbID, err := uuid.Parse(chi.URLParam(r, "id"))
|
|
if err != nil {
|
|
response.BadRequest(w, "无效的ID")
|
|
return
|
|
}
|
|
|
|
rows, err := h.pool.Query(r.Context(),
|
|
`SELECT id, name, COALESCE(file_type,''), file_size, indexing_status, created_at
|
|
FROM knowledge_documents WHERE kb_id = $1 ORDER BY created_at DESC`, kbID)
|
|
if err != nil {
|
|
internalErr(w, err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var docs []documentRow
|
|
for rows.Next() {
|
|
var d documentRow
|
|
if err := rows.Scan(&d.ID, &d.Name, &d.FileType, &d.FileSize, &d.IndexingStatus, &d.CreatedAt); err != nil {
|
|
internalErr(w, err)
|
|
return
|
|
}
|
|
docs = append(docs, d)
|
|
}
|
|
if docs == nil {
|
|
docs = []documentRow{}
|
|
}
|
|
response.JSON(w, http.StatusOK, docs)
|
|
}
|
|
|
|
func (h *KnowledgeHandler) UploadDocument(w http.ResponseWriter, r *http.Request) {
|
|
kbID, err := uuid.Parse(chi.URLParam(r, "id"))
|
|
if err != nil {
|
|
response.BadRequest(w, "无效的ID")
|
|
return
|
|
}
|
|
|
|
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
|
response.BadRequest(w, "文件过大或格式错误")
|
|
return
|
|
}
|
|
|
|
file, header, err := r.FormFile("file")
|
|
if err != nil {
|
|
response.BadRequest(w, "请上传文件")
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
var exists bool
|
|
err = h.pool.QueryRow(r.Context(),
|
|
`SELECT EXISTS(SELECT 1 FROM knowledge_bases WHERE id = $1)`, kbID).Scan(&exists)
|
|
if err != nil || !exists {
|
|
response.NotFound(w, "知识库不存在")
|
|
return
|
|
}
|
|
|
|
fileType := ""
|
|
ext := ""
|
|
if dot := len(header.Filename) - 1; dot > 0 {
|
|
for i := dot; i >= 0; i-- {
|
|
if header.Filename[i] == '.' {
|
|
ext = header.Filename[i+1:]
|
|
break
|
|
}
|
|
}
|
|
}
|
|
switch ext {
|
|
case "pdf":
|
|
fileType = "pdf"
|
|
case "docx":
|
|
fileType = "docx"
|
|
case "txt":
|
|
fileType = "txt"
|
|
case "md":
|
|
fileType = "md"
|
|
case "csv":
|
|
fileType = "csv"
|
|
case "xlsx":
|
|
fileType = "xlsx"
|
|
default:
|
|
fileType = "txt"
|
|
}
|
|
|
|
// 读取文件内容(文本文件)
|
|
var content string
|
|
if fileType == "txt" || fileType == "md" || fileType == "csv" {
|
|
data, err := io.ReadAll(file)
|
|
if err == nil {
|
|
content = string(data)
|
|
}
|
|
}
|
|
|
|
userID := mw.GetUserID(r.Context())
|
|
docID := uuid.New()
|
|
|
|
// 计算分片数
|
|
chunkCount := 0
|
|
if content != "" {
|
|
chunks := chunker.ChunkText(content, chunker.DefaultOptions())
|
|
chunkCount = len(chunks)
|
|
}
|
|
|
|
_, err = h.pool.Exec(r.Context(),
|
|
`INSERT INTO knowledge_documents (id, kb_id, name, file_type, file_size, uploader_id, indexing_status, content, char_count, chunk_count)
|
|
VALUES ($1, $2, $3, $4, $5, $6, 'processing', $7, $8, $9)`,
|
|
docID, kbID, header.Filename, fileType, header.Size, userID, content, utf8.RuneCountInString(content), chunkCount)
|
|
if err != nil {
|
|
internalErr(w, err)
|
|
return
|
|
}
|
|
|
|
_, _ = h.pool.Exec(r.Context(),
|
|
`UPDATE knowledge_bases SET doc_count = doc_count + 1, updated_at = NOW() WHERE id = $1`, kbID)
|
|
|
|
// 异步执行分片和向量化
|
|
go h.chunkAndEmbed(context.Background(), kbID, docID, content)
|
|
|
|
response.JSON(w, http.StatusCreated, map[string]any{
|
|
"id": docID.String(),
|
|
"filename": header.Filename,
|
|
"size": header.Size,
|
|
"chunks": chunkCount,
|
|
"status": "processing",
|
|
})
|
|
}
|
|
|
|
// chunkAndEmbed 对文档内容执行分片和向量化(异步)
|
|
func (h *KnowledgeHandler) chunkAndEmbed(ctx context.Context, kbID, docID uuid.UUID, content string) {
|
|
if content == "" {
|
|
h.pool.Exec(ctx, `UPDATE knowledge_documents SET indexing_status = 'completed' WHERE id = $1`, docID)
|
|
return
|
|
}
|
|
|
|
chunks := chunker.ChunkText(content, chunker.DefaultOptions())
|
|
if len(chunks) == 0 {
|
|
h.pool.Exec(ctx, `UPDATE knowledge_documents SET indexing_status = 'completed' WHERE id = $1`, docID)
|
|
return
|
|
}
|
|
|
|
embeddingAvailable := h.embedder != nil && h.embedder.IsConfigured()
|
|
successCount := 0
|
|
|
|
for i, chunk := range chunks {
|
|
chunkID := uuid.New()
|
|
charCount := utf8.RuneCountInString(chunk)
|
|
|
|
if embeddingAvailable {
|
|
emb, err := h.embedder.GetEmbedding(ctx, chunk)
|
|
if err != nil {
|
|
log.Warn().Err(err).Str("doc_id", docID.String()).Int("chunk", i).Msg("embedding failed")
|
|
// 无 embedding 也插入 chunk
|
|
h.pool.Exec(ctx,
|
|
`INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count)
|
|
VALUES ($1, $2, $3, $4, $5, $6)`,
|
|
chunkID, kbID, docID, i, chunk, charCount)
|
|
} else {
|
|
// 将 []float32 转为 pgvector 格式字符串
|
|
vecStr := float32SliceToVectorStr(emb)
|
|
h.pool.Exec(ctx,
|
|
`INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count, embedding)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7::vector)`,
|
|
chunkID, kbID, docID, i, chunk, charCount, vecStr)
|
|
successCount++
|
|
}
|
|
} else {
|
|
// 没有 embedding 服务,只存储文本分片
|
|
h.pool.Exec(ctx,
|
|
`INSERT INTO knowledge_chunks (id, kb_id, doc_id, chunk_index, content, char_count)
|
|
VALUES ($1, $2, $3, $4, $5, $6)`,
|
|
chunkID, kbID, docID, i, chunk, charCount)
|
|
}
|
|
}
|
|
|
|
// 更新文档状态
|
|
h.pool.Exec(ctx, `UPDATE knowledge_documents SET indexing_status = 'completed', chunk_count = $2 WHERE id = $1`, docID, len(chunks))
|
|
|
|
log.Info().
|
|
Str("doc_id", docID.String()).
|
|
Int("chunks", len(chunks)).
|
|
Int("embedded", successCount).
|
|
Bool("embedding_available", embeddingAvailable).
|
|
Msg("document chunked and embedded")
|
|
}
|
|
|
|
// float32SliceToVectorStr renders v in pgvector's text format, e.g.
// "[0.1,0.2,0.3]". Each element is formatted with %g; a nil or empty slice
// yields "[]".
func float32SliceToVectorStr(v []float32) string {
	// A builder keeps this linear in len(v); repeated string += re-copies
	// the whole prefix for every element.
	var b strings.Builder
	b.Grow(2 + 12*len(v)) // rough size hint only; the builder grows as needed
	b.WriteByte('[')
	for i, f := range v {
		if i > 0 {
			b.WriteByte(',')
		}
		fmt.Fprintf(&b, "%g", f)
	}
	b.WriteByte(']')
	return b.String()
}
|
|
|
|
// ReindexAll 对所有未分片的文档执行分片和向量化(管理端点)
|
|
func (h *KnowledgeHandler) ReindexAll(w http.ResponseWriter, r *http.Request) {
|
|
rows, err := h.pool.Query(r.Context(),
|
|
`SELECT kd.id, kd.kb_id, kd.content FROM knowledge_documents kd
|
|
WHERE kd.content IS NOT NULL AND kd.content != '' AND kd.chunk_count = 0`)
|
|
if err != nil {
|
|
internalErr(w, err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
type docInfo struct {
|
|
docID uuid.UUID
|
|
kbID uuid.UUID
|
|
content string
|
|
}
|
|
var docs []docInfo
|
|
for rows.Next() {
|
|
var d docInfo
|
|
if err := rows.Scan(&d.docID, &d.kbID, &d.content); err != nil {
|
|
continue
|
|
}
|
|
docs = append(docs, d)
|
|
}
|
|
|
|
for _, d := range docs {
|
|
h.chunkAndEmbed(r.Context(), d.kbID, d.docID, d.content)
|
|
}
|
|
|
|
response.JSON(w, http.StatusOK, map[string]any{
|
|
"message": "重新索引完成",
|
|
"documents": len(docs),
|
|
})
|
|
}
|
|
|
|
// ReembedChunks 为已有分片但缺失embedding的chunks补充向量化(管理端点)
|
|
func (h *KnowledgeHandler) ReembedChunks(w http.ResponseWriter, r *http.Request) {
|
|
if h.embedder == nil || !h.embedder.IsConfigured() {
|
|
response.JSON(w, http.StatusOK, map[string]any{
|
|
"message": "embedding服务未配置",
|
|
"updated": 0,
|
|
})
|
|
return
|
|
}
|
|
|
|
rows, err := h.pool.Query(r.Context(),
|
|
`SELECT id, content FROM knowledge_chunks WHERE embedding IS NULL AND content IS NOT NULL AND content != '' LIMIT 200`)
|
|
if err != nil {
|
|
internalErr(w, err)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
type chunkInfo struct {
|
|
id uuid.UUID
|
|
content string
|
|
}
|
|
var chunks []chunkInfo
|
|
for rows.Next() {
|
|
var c chunkInfo
|
|
if err := rows.Scan(&c.id, &c.content); err != nil {
|
|
continue
|
|
}
|
|
chunks = append(chunks, c)
|
|
}
|
|
|
|
updated := 0
|
|
for _, c := range chunks {
|
|
emb, err := h.embedder.GetEmbedding(r.Context(), c.content)
|
|
if err != nil {
|
|
log.Warn().Err(err).Str("chunk_id", c.id.String()).Msg("re-embed failed")
|
|
continue
|
|
}
|
|
vecStr := float32SliceToVectorStr(emb)
|
|
_, err = h.pool.Exec(r.Context(),
|
|
`UPDATE knowledge_chunks SET embedding = $1::vector WHERE id = $2`, vecStr, c.id)
|
|
if err != nil {
|
|
log.Warn().Err(err).Str("chunk_id", c.id.String()).Msg("update embedding failed")
|
|
continue
|
|
}
|
|
updated++
|
|
}
|
|
|
|
response.JSON(w, http.StatusOK, map[string]any{
|
|
"message": "向量补充完成",
|
|
"total": len(chunks),
|
|
"updated": updated,
|
|
})
|
|
}
|
|
|
|
func (h *KnowledgeHandler) DeleteDocument(w http.ResponseWriter, r *http.Request) {
|
|
kbID, err := uuid.Parse(chi.URLParam(r, "id"))
|
|
if err != nil {
|
|
response.BadRequest(w, "无效的知识库ID")
|
|
return
|
|
}
|
|
docID, err := uuid.Parse(chi.URLParam(r, "docId"))
|
|
if err != nil {
|
|
response.BadRequest(w, "无效的文档ID")
|
|
return
|
|
}
|
|
|
|
tag, err := h.pool.Exec(r.Context(), `DELETE FROM knowledge_documents WHERE id = $1 AND kb_id = $2`, docID, kbID)
|
|
if err != nil {
|
|
internalErr(w, err)
|
|
return
|
|
}
|
|
if tag.RowsAffected() == 0 {
|
|
response.NotFound(w, "文档不存在")
|
|
return
|
|
}
|
|
|
|
_, _ = h.pool.Exec(r.Context(),
|
|
`UPDATE knowledge_bases SET doc_count = GREATEST(doc_count - 1, 0), updated_at = NOW() WHERE id = $1`, kbID)
|
|
|
|
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
|
|
}
|