Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,138 @@
|
||||
// Package embedding 提供文本向量化服务,支持 OpenAI 兼容的 Embedding API
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config embedding 服务配置
|
||||
type Config struct {
|
||||
APIKey string // API 密钥
|
||||
BaseURL string // API 基础 URL(OpenAI 兼容格式)
|
||||
Model string // 模型名称
|
||||
Dimensions int // 向量维度
|
||||
}
|
||||
|
||||
// Client embedding 客户端
|
||||
type Client struct {
|
||||
cfg Config
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewClient 创建 embedding 客户端
|
||||
func NewClient(cfg Config) *Client {
|
||||
if cfg.BaseURL == "" {
|
||||
cfg.BaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
if cfg.Model == "" {
|
||||
cfg.Model = "text-embedding-v3"
|
||||
}
|
||||
if cfg.Dimensions == 0 {
|
||||
cfg.Dimensions = 1024
|
||||
}
|
||||
return &Client{
|
||||
cfg: cfg,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// embeddingRequest OpenAI 兼容的 embedding 请求
|
||||
type embeddingRequest struct {
|
||||
Input interface{} `json:"input"`
|
||||
Model string `json:"model"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
}
|
||||
|
||||
// embeddingResponse OpenAI 兼容的 embedding 响应
|
||||
type embeddingResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
} `json:"data"`
|
||||
Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
// GetEmbedding 获取单条文本的向量嵌入
|
||||
func (c *Client) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
|
||||
if c.cfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("embedding API key not configured")
|
||||
}
|
||||
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return nil, fmt.Errorf("empty text")
|
||||
}
|
||||
// 截断过长文本
|
||||
if len([]rune(text)) > 8000 {
|
||||
text = string([]rune(text)[:8000])
|
||||
}
|
||||
|
||||
req := embeddingRequest{
|
||||
Input: text,
|
||||
Model: c.cfg.Model,
|
||||
Dimensions: c.cfg.Dimensions,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := strings.TrimRight(c.cfg.BaseURL, "/") + "/embeddings"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embedding request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("embedding error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var result embeddingResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result.Data) == 0 {
|
||||
return nil, fmt.Errorf("no embedding data returned")
|
||||
}
|
||||
return result.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
// GetEmbeddingBatch 批量获取文本向量嵌入
|
||||
func (c *Client) GetEmbeddingBatch(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
results := make([][]float32, len(texts))
|
||||
for i, text := range texts {
|
||||
emb, err := c.GetEmbedding(ctx, text)
|
||||
if err != nil {
|
||||
results[i] = nil
|
||||
continue
|
||||
}
|
||||
results[i] = emb
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// IsConfigured 检查 embedding 服务是否已配置
|
||||
func (c *Client) IsConfigured() bool {
|
||||
return c.cfg.APIKey != ""
|
||||
}
|
||||
Reference in New Issue
Block a user