97 lines
2.2 KiB
Go
97 lines
2.2 KiB
Go
// 批量为 knowledge_chunks 生成 embedding 向量
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/enterprise-ai-platform/server/internal/config"
|
|
"github.com/enterprise-ai-platform/server/pkg/embedding"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
func main() {
|
|
cfg := config.Load()
|
|
|
|
ctx := context.Background()
|
|
pool, err := pgxpool.New(ctx, cfg.Database.URL)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "数据库连接失败: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer pool.Close()
|
|
|
|
client := embedding.NewClient(embedding.Config{
|
|
APIKey: cfg.Embedding.APIKey,
|
|
BaseURL: cfg.Embedding.BaseURL,
|
|
Model: cfg.Embedding.Model,
|
|
Dimensions: cfg.Embedding.Dimensions,
|
|
})
|
|
|
|
if !client.IsConfigured() {
|
|
fmt.Fprintln(os.Stderr, "EMBEDDING_API_KEY 未配置")
|
|
os.Exit(1)
|
|
}
|
|
|
|
// 查询所有没有 embedding 的 chunks
|
|
rows, err := pool.Query(ctx,
|
|
`SELECT id, content FROM knowledge_chunks WHERE embedding IS NULL ORDER BY created_at`)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "查询失败: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer rows.Close()
|
|
|
|
type chunk struct {
|
|
id string
|
|
content string
|
|
}
|
|
var chunks []chunk
|
|
for rows.Next() {
|
|
var c chunk
|
|
if err := rows.Scan(&c.id, &c.content); err != nil {
|
|
continue
|
|
}
|
|
chunks = append(chunks, c)
|
|
}
|
|
|
|
fmt.Printf("共 %d 个 chunks 需要生成 embedding\n", len(chunks))
|
|
|
|
success := 0
|
|
for i, c := range chunks {
|
|
emb, err := client.GetEmbedding(ctx, c.content)
|
|
if err != nil {
|
|
fmt.Printf("[%d/%d] ❌ %s: %v\n", i+1, len(chunks), c.id[:8], err)
|
|
time.Sleep(500 * time.Millisecond)
|
|
continue
|
|
}
|
|
|
|
// 转为 pgvector 格式
|
|
vecStr := "["
|
|
for j, f := range emb {
|
|
if j > 0 {
|
|
vecStr += ","
|
|
}
|
|
vecStr += fmt.Sprintf("%g", f)
|
|
}
|
|
vecStr += "]"
|
|
|
|
_, err = pool.Exec(ctx,
|
|
`UPDATE knowledge_chunks SET embedding = $2::vector WHERE id = $1`,
|
|
c.id, vecStr)
|
|
if err != nil {
|
|
fmt.Printf("[%d/%d] ❌ 写入失败 %s: %v\n", i+1, len(chunks), c.id[:8], err)
|
|
} else {
|
|
success++
|
|
fmt.Printf("[%d/%d] ✅ %s (dim=%d)\n", i+1, len(chunks), c.id[:8], len(emb))
|
|
}
|
|
|
|
// 避免 API 限流
|
|
time.Sleep(200 * time.Millisecond)
|
|
}
|
|
|
|
fmt.Printf("\n完成!成功: %d/%d\n", success, len(chunks))
|
|
}
|