139 lines
3.3 KiB
Go
139 lines
3.3 KiB
Go
// Package embedding 提供文本向量化服务,支持 OpenAI 兼容的 Embedding API
|
||
package embedding
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// Config holds the settings of the embedding service client.
//
// NewClient substitutes defaults for BaseURL, Model and Dimensions when they
// are left at their zero values; APIKey has no default.
type Config struct {
	// APIKey is the API key, sent as a Bearer token with each request.
	APIKey string
	// BaseURL is the API base URL (OpenAI-compatible format).
	BaseURL string
	// Model is the name of the embedding model.
	Model string
	// Dimensions is the requested size of the embedding vector.
	Dimensions int
}
|
||
|
||
// Client embedding 客户端
|
||
type Client struct {
|
||
cfg Config
|
||
httpClient *http.Client
|
||
}
|
||
|
||
// NewClient 创建 embedding 客户端
|
||
func NewClient(cfg Config) *Client {
|
||
if cfg.BaseURL == "" {
|
||
cfg.BaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||
}
|
||
if cfg.Model == "" {
|
||
cfg.Model = "text-embedding-v3"
|
||
}
|
||
if cfg.Dimensions == 0 {
|
||
cfg.Dimensions = 1024
|
||
}
|
||
return &Client{
|
||
cfg: cfg,
|
||
httpClient: &http.Client{
|
||
Timeout: 30 * time.Second,
|
||
},
|
||
}
|
||
}
|
||
|
||
// embeddingRequest is the JSON body of an OpenAI-compatible embedding request.
type embeddingRequest struct {
	// Input is the content to embed. It is loosely typed; GetEmbedding
	// passes a single string.
	Input interface{} `json:"input"`
	// Model is the name of the embedding model.
	Model string `json:"model"`
	// Dimensions is the requested vector size; it is omitted from the JSON
	// when zero.
	Dimensions int `json:"dimensions,omitempty"`
}
|
||
|
||
// embeddingResponse is the JSON body of an OpenAI-compatible embedding
// response. Only the fields declared here are decoded.
type embeddingResponse struct {
	// Data holds one entry per embedded input.
	Data []struct {
		// Embedding is the embedding vector.
		Embedding []float32 `json:"embedding"`
		// Index is the "index" value reported for this entry.
		Index int `json:"index"`
	} `json:"data"`
	// Usage carries the token accounting reported by the service.
	Usage struct {
		TotalTokens int `json:"total_tokens"`
	} `json:"usage"`
}
|
||
|
||
// GetEmbedding 获取单条文本的向量嵌入
|
||
func (c *Client) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
|
||
if c.cfg.APIKey == "" {
|
||
return nil, fmt.Errorf("embedding API key not configured")
|
||
}
|
||
|
||
text = strings.TrimSpace(text)
|
||
if text == "" {
|
||
return nil, fmt.Errorf("empty text")
|
||
}
|
||
// 截断过长文本
|
||
if len([]rune(text)) > 8000 {
|
||
text = string([]rune(text)[:8000])
|
||
}
|
||
|
||
req := embeddingRequest{
|
||
Input: text,
|
||
Model: c.cfg.Model,
|
||
Dimensions: c.cfg.Dimensions,
|
||
}
|
||
|
||
body, err := json.Marshal(req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
url := strings.TrimRight(c.cfg.BaseURL, "/") + "/embeddings"
|
||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
httpReq.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := c.httpClient.Do(httpReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("embedding request failed: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
errBody, _ := io.ReadAll(resp.Body)
|
||
return nil, fmt.Errorf("embedding error (status %d): %s", resp.StatusCode, string(errBody))
|
||
}
|
||
|
||
var result embeddingResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
return nil, err
|
||
}
|
||
if len(result.Data) == 0 {
|
||
return nil, fmt.Errorf("no embedding data returned")
|
||
}
|
||
return result.Data[0].Embedding, nil
|
||
}
|
||
|
||
// GetEmbeddingBatch 批量获取文本向量嵌入
|
||
func (c *Client) GetEmbeddingBatch(ctx context.Context, texts []string) ([][]float32, error) {
|
||
results := make([][]float32, len(texts))
|
||
for i, text := range texts {
|
||
emb, err := c.GetEmbedding(ctx, text)
|
||
if err != nil {
|
||
results[i] = nil
|
||
continue
|
||
}
|
||
results[i] = emb
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
// IsConfigured 检查 embedding 服务是否已配置
|
||
func (c *Client) IsConfigured() bool {
|
||
return c.cfg.APIKey != ""
|
||
}
|