Files
GovAI/server/pkg/embedding/embedding.go
T
2026-06-15 23:48:37 +08:00

139 lines
3.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package embedding 提供文本向量化服务,支持 OpenAI 兼容的 Embedding API
package embedding
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// Config holds the settings for the embedding service client.
type Config struct {
	// APIKey is the API key; it is sent as the Bearer token on each request.
	APIKey string
	// BaseURL is the base URL of the API (OpenAI-compatible format).
	// NewClient substitutes a default when it is empty.
	BaseURL string
	// Model is the embedding model name.
	// NewClient substitutes a default when it is empty.
	Model string
	// Dimensions is the vector dimension requested from the API.
	// NewClient substitutes a default when it is zero.
	Dimensions int
}
// Client is the embedding client. It pairs a Config with the HTTP client
// used to send requests.
type Client struct {
	// cfg is the configuration read when building each request.
	cfg Config
	// httpClient executes the outgoing HTTP requests.
	httpClient *http.Client
}
// NewClient creates an embedding client from cfg.
//
// Unset fields are replaced with defaults: an empty BaseURL becomes the
// DashScope compatible-mode endpoint, an empty Model becomes
// "text-embedding-v3", and a zero Dimensions becomes 1024. The returned
// client sends requests through an HTTP client with a 30-second timeout.
func NewClient(cfg Config) *Client {
	const (
		fallbackBaseURL    = "https://dashscope.aliyuncs.com/compatible-mode/v1"
		fallbackModel      = "text-embedding-v3"
		fallbackDimensions = 1024
	)

	// cfg was passed by value, so filling in defaults on this copy never
	// modifies the caller's Config.
	effective := cfg
	if len(effective.BaseURL) == 0 {
		effective.BaseURL = fallbackBaseURL
	}
	if len(effective.Model) == 0 {
		effective.Model = fallbackModel
	}
	if effective.Dimensions == 0 {
		effective.Dimensions = fallbackDimensions
	}

	client := &Client{cfg: effective}
	client.httpClient = &http.Client{Timeout: 30 * time.Second}
	return client
}
// embeddingRequest is the JSON body of an OpenAI-compatible embedding request.
type embeddingRequest struct {
	// Input is the content to embed; it is serialized under "input".
	Input interface{} `json:"input"`
	// Model is the model name, serialized under "model".
	Model string `json:"model"`
	// Dimensions is the requested vector size; it is left out of the JSON
	// when zero.
	Dimensions int `json:"dimensions,omitempty"`
}
// embeddingResponse is the JSON body of an OpenAI-compatible embedding
// response.
type embeddingResponse struct {
	// Data lists the returned embeddings, each with the index reported by
	// the API.
	Data []struct {
		Embedding []float32 `json:"embedding"`
		Index     int       `json:"index"`
	} `json:"data"`
	// Usage carries the token count reported under "usage".
	Usage struct {
		TotalTokens int `json:"total_tokens"`
	} `json:"usage"`
}
// GetEmbedding returns the embedding vector for a single piece of text.
//
// The text is stripped of surrounding whitespace and, when longer than 8000
// runes, cut down to its first 8000 runes. It is then sent as a JSON POST to
// "<BaseURL>/embeddings" together with the configured model and dimensions,
// authenticated with the API key as a Bearer token.
//
// An error is returned when the API key is not configured, the trimmed text
// is empty, the request cannot be built or sent, the server replies with a
// status other than 200, the response cannot be decoded, or the response
// holds no embedding data. Underlying errors are wrapped with %w.
func (c *Client) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
	const (
		// maxInputRunes is the longest input, in runes, that is sent.
		maxInputRunes = 8000
		// maxErrorBodyBytes caps how much of a failed response's body is
		// read into the error message, so a large error page cannot be
		// buffered in full.
		maxErrorBodyBytes = 4 << 10
	)

	if c.cfg.APIKey == "" {
		return nil, fmt.Errorf("embedding API key not configured")
	}
	text = strings.TrimSpace(text)
	if text == "" {
		return nil, fmt.Errorf("empty text")
	}
	// A string never has more runes than bytes, so inputs within the byte
	// limit skip the rune conversion; longer ones are converted only once.
	if len(text) > maxInputRunes {
		if runes := []rune(text); len(runes) > maxInputRunes {
			text = string(runes[:maxInputRunes])
		}
	}

	body, err := json.Marshal(embeddingRequest{
		Input:      text,
		Model:      c.cfg.Model,
		Dimensions: c.cfg.Dimensions,
	})
	if err != nil {
		return nil, fmt.Errorf("marshal embedding request: %w", err)
	}

	// Trim trailing slashes so a BaseURL ending in "/" does not yield "//".
	endpoint := strings.TrimRight(c.cfg.BaseURL, "/") + "/embeddings"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("build embedding request: %w", err)
	}
	httpReq.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("embedding request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The body only enriches the message, so a read failure here is
		// deliberately ignored: the status code is still reported.
		errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return nil, fmt.Errorf("embedding error (status %d): %s", resp.StatusCode, string(errBody))
	}

	var result embeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("decode embedding response: %w", err)
	}
	if len(result.Data) == 0 {
		return nil, fmt.Errorf("no embedding data returned")
	}
	return result.Data[0].Embedding, nil
}
// GetEmbeddingBatch returns the embeddings for texts, one entry per input and
// in the same order.
//
// Each text is embedded with its own GetEmbedding call. A failure for an
// individual text is tolerated: its entry is left nil and processing moves on
// to the next text, so callers must nil-check every entry.
//
// When a call fails and ctx is already cancelled or past its deadline, no
// further requests are issued: the results gathered so far (remaining entries
// nil) are returned together with an error wrapping ctx.Err(). In every other
// case the returned error is nil.
func (c *Client) GetEmbeddingBatch(ctx context.Context, texts []string) ([][]float32, error) {
	results := make([][]float32, len(texts))
	for i, text := range texts {
		emb, err := c.GetEmbedding(ctx, text)
		if err != nil {
			// Once the context is done every remaining request would fail
			// too, so stop and report it instead of hiding the cancellation.
			// The nil guard keeps a nil ctx on the per-item error path
			// rather than panicking on the method call.
			if ctx != nil && ctx.Err() != nil {
				return results, fmt.Errorf("embedding batch interrupted at text %d of %d: %w", i+1, len(texts), ctx.Err())
			}
			// Per-item failures are best-effort: leave results[i] nil.
			continue
		}
		results[i] = emb
	}
	return results, nil
}
// IsConfigured reports whether the embedding service is configured, meaning
// the client holds a non-empty API key.
func (c *Client) IsConfigured() bool {
	return len(c.cfg.APIKey) > 0
}