Initial commit: GovAI 政务AI平台

This commit is contained in:
freedakgmail
2026-06-15 23:48:37 +08:00
commit 0f490f72a9
245 changed files with 51669 additions and 0 deletions
+97
View File
@@ -0,0 +1,97 @@
package auth
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
type TokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt int64 `json:"expires_at"`
}
type Claims struct {
UserID uuid.UUID `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
jwt.RegisteredClaims
}
type JWTManager struct {
secret []byte
accessExpiry time.Duration
refreshExpiry time.Duration
}
func NewJWTManager(secret string, accessExpiry, refreshExpiry time.Duration) *JWTManager {
return &JWTManager{
secret: []byte(secret),
accessExpiry: accessExpiry,
refreshExpiry: refreshExpiry,
}
}
func (m *JWTManager) GenerateTokenPair(userID uuid.UUID, email, role string) (*TokenPair, error) {
now := time.Now()
accessExp := now.Add(m.accessExpiry)
accessClaims := &Claims{
UserID: userID,
Email: email,
Role: role,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(accessExp),
IssuedAt: jwt.NewNumericDate(now),
Subject: userID.String(),
},
}
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
accessStr, err := accessToken.SignedString(m.secret)
if err != nil {
return nil, fmt.Errorf("sign access token: %w", err)
}
refreshClaims := &Claims{
UserID: userID,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(m.refreshExpiry)),
IssuedAt: jwt.NewNumericDate(now),
Subject: userID.String(),
},
}
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
refreshStr, err := refreshToken.SignedString(m.secret)
if err != nil {
return nil, fmt.Errorf("sign refresh token: %w", err)
}
return &TokenPair{
AccessToken: accessStr,
RefreshToken: refreshStr,
ExpiresAt: accessExp.Unix(),
}, nil
}
func (m *JWTManager) ValidateToken(tokenStr string) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return m.secret, nil
})
if err != nil {
return nil, fmt.Errorf("parse token: %w", err)
}
claims, ok := token.Claims.(*Claims)
if !ok || !token.Valid {
return nil, fmt.Errorf("invalid token claims")
}
return claims, nil
}
+13
View File
@@ -0,0 +1,13 @@
package auth
import "golang.org/x/crypto/bcrypt"
func HashPassword(password string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(bytes), err
}
func CheckPassword(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
+21
View File
@@ -0,0 +1,21 @@
package auth
import "context"
// SSOProvider defines the interface for SSO authentication providers.
// Implementations will be added for LDAP, OAuth2, and SAML.
type SSOProvider interface {
// Authenticate validates credentials and returns user information.
Authenticate(ctx context.Context, credentials map[string]string) (*SSOUser, error)
// Name returns the provider name (e.g., "ldap", "oauth2", "saml").
Name() string
}
type SSOUser struct {
ExternalID string
Name string
Email string
Phone string
Department string
AvatarURL string
}
+144
View File
@@ -0,0 +1,144 @@
// Package chunker 提供文本智能分片服务,将长文本切分为适合向量化的片段
package chunker
import (
"strings"
"unicode/utf8"
)
// Options 分片配置
type Options struct {
ChunkSize int // 每块最大字符数,默认 500
Overlap int // 块间重叠字符数,默认 50
Separators []string // 分隔符列表(按优先级)
}
// DefaultOptions 默认分片配置
func DefaultOptions() Options {
return Options{
ChunkSize: 500,
Overlap: 50,
Separators: []string{
"\n\n", "\n", "。", ".", "", "!", "", "?", "", ";", " ",
},
}
}
// ChunkText 智能分片:优先按段落/句子边界切分,避免在句子中间断开
func ChunkText(text string, opts Options) []string {
text = strings.TrimSpace(text)
if text == "" {
return nil
}
if opts.ChunkSize <= 0 {
opts.ChunkSize = 500
}
if opts.Overlap < 0 {
opts.Overlap = 0
}
if opts.Separators == nil {
opts.Separators = DefaultOptions().Separators
}
runeLen := utf8.RuneCountInString(text)
if runeLen <= opts.ChunkSize {
return []string{text}
}
// 递归分片
chunks := recursiveSplit(text, opts.Separators, opts.ChunkSize)
// 添加重叠
if opts.Overlap > 0 && len(chunks) > 1 {
chunks = addOverlap(chunks, opts.Overlap)
}
// 过滤空片段
var result []string
for _, c := range chunks {
c = strings.TrimSpace(c)
if c != "" {
result = append(result, c)
}
}
return result
}
// recursiveSplit 递归分片
func recursiveSplit(text string, separators []string, chunkSize int) []string {
if utf8.RuneCountInString(text) <= chunkSize {
return []string{text}
}
for _, sep := range separators {
parts := strings.Split(text, sep)
if len(parts) <= 1 {
continue
}
var result []string
current := ""
for _, part := range parts {
candidate := current
if candidate != "" {
candidate += sep
}
candidate += part
if utf8.RuneCountInString(candidate) <= chunkSize {
current = candidate
} else {
if current != "" {
result = append(result, current)
}
if utf8.RuneCountInString(part) > chunkSize {
// 继续用更细的分隔符拆分
nextSeps := separators[1:]
if len(nextSeps) == 0 {
nextSeps = nil
}
result = append(result, recursiveSplit(part, nextSeps, chunkSize)...)
current = ""
} else {
current = part
}
}
}
if current != "" {
result = append(result, current)
}
return result
}
// 无分隔符可用,按字符硬切
runes := []rune(text)
var result []string
for i := 0; i < len(runes); i += chunkSize {
end := i + chunkSize
if end > len(runes) {
end = len(runes)
}
result = append(result, string(runes[i:end]))
}
return result
}
// addOverlap 给分片添加重叠区域
func addOverlap(chunks []string, overlap int) []string {
if len(chunks) <= 1 || overlap <= 0 {
return chunks
}
result := []string{chunks[0]}
for i := 1; i < len(chunks); i++ {
prevRunes := []rune(chunks[i-1])
overlapStart := len(prevRunes) - overlap
if overlapStart < 0 {
overlapStart = 0
}
prevTail := string(prevRunes[overlapStart:])
result = append(result, prevTail+chunks[i])
}
return result
}
+29
View File
@@ -0,0 +1,29 @@
package db
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgxpool"
)
func NewPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) {
config, err := pgxpool.ParseConfig(databaseURL)
if err != nil {
return nil, fmt.Errorf("parse database url: %w", err)
}
config.MaxConns = 20
config.MinConns = 5
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, fmt.Errorf("create pool: %w", err)
}
if err := pool.Ping(ctx); err != nil {
return nil, fmt.Errorf("ping database: %w", err)
}
return pool, nil
}
+122
View File
@@ -0,0 +1,122 @@
-- name: CreateApplication :one
INSERT INTO applications (
name, slug, description, long_description, icon_url,
category_id, creator_id, dept_id,
dify_app_id, dify_app_type, dify_api_key,
app_config, welcome_message, suggested_prompts,
max_tokens, temperature, status, visibility, is_template
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19
) RETURNING *;
-- name: GetApplicationByID :one
SELECT * FROM applications WHERE id = $1;
-- name: GetApplicationBySlug :one
SELECT * FROM applications WHERE slug = $1;
-- name: UpdateApplication :one
UPDATE applications
SET name = COALESCE(sqlc.narg('name'), name),
description = COALESCE(sqlc.narg('description'), description),
long_description = COALESCE(sqlc.narg('long_description'), long_description),
icon_url = COALESCE(sqlc.narg('icon_url'), icon_url),
category_id = COALESCE(sqlc.narg('category_id'), category_id),
app_config = COALESCE(sqlc.narg('app_config'), app_config),
welcome_message = COALESCE(sqlc.narg('welcome_message'), welcome_message),
suggested_prompts = COALESCE(sqlc.narg('suggested_prompts'), suggested_prompts),
max_tokens = COALESCE(sqlc.narg('max_tokens'), max_tokens),
temperature = COALESCE(sqlc.narg('temperature'), temperature),
visibility = COALESCE(sqlc.narg('visibility'), visibility)
WHERE id = $1
RETURNING *;
-- name: UpdateApplicationStatus :exec
UPDATE applications SET status = $2 WHERE id = $1;
-- name: DeleteApplication :exec
DELETE FROM applications WHERE id = $1 AND status = 'draft';
-- name: ListStoreApps :many
SELECT a.*, c.name as category_name, c.slug as category_slug, u.name as creator_name
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
LEFT JOIN users u ON a.creator_id = u.id
WHERE a.status = 'approved'
AND a.visibility = 'public'
AND (sqlc.narg('category_slug')::VARCHAR IS NULL OR c.slug = sqlc.narg('category_slug'))
AND (sqlc.narg('search')::VARCHAR IS NULL
OR to_tsvector('simple', a.name || ' ' || COALESCE(a.description, ''))
@@ plainto_tsquery('simple', sqlc.narg('search')))
ORDER BY
CASE WHEN sqlc.narg('sort')::VARCHAR = 'popular' THEN a.usage_count END DESC,
CASE WHEN sqlc.narg('sort')::VARCHAR = 'rating' THEN a.avg_rating END DESC,
CASE WHEN sqlc.narg('sort')::VARCHAR IS NULL OR sqlc.narg('sort') = 'latest' THEN EXTRACT(EPOCH FROM a.published_at) END DESC
LIMIT $1 OFFSET $2;
-- name: CountStoreApps :one
SELECT COUNT(*) FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
WHERE a.status = 'approved'
AND a.visibility = 'public'
AND (sqlc.narg('category_slug')::VARCHAR IS NULL OR c.slug = sqlc.narg('category_slug'))
AND (sqlc.narg('search')::VARCHAR IS NULL
OR to_tsvector('simple', a.name || ' ' || COALESCE(a.description, ''))
@@ plainto_tsquery('simple', sqlc.narg('search')));
-- name: ListFeaturedApps :many
SELECT a.*, c.name as category_name, u.name as creator_name
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
LEFT JOIN users u ON a.creator_id = u.id
WHERE a.is_featured = true AND a.status = 'approved' AND a.visibility = 'public'
ORDER BY a.usage_count DESC
LIMIT $1;
-- name: ListTopApps :many
SELECT a.*, c.name as category_name, u.name as creator_name
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
LEFT JOIN users u ON a.creator_id = u.id
WHERE a.status = 'approved' AND a.visibility = 'public'
ORDER BY a.usage_count DESC
LIMIT $1;
-- name: ListCreatorApps :many
SELECT a.*, c.name as category_name
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
WHERE a.creator_id = $1
ORDER BY a.updated_at DESC
LIMIT $2 OFFSET $3;
-- name: ListTemplates :many
SELECT a.*, c.name as category_name
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
WHERE a.is_template = true AND a.status = 'approved'
ORDER BY a.usage_count DESC;
-- name: IncrementUsageCount :exec
UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1;
-- name: UpdateFavoriteCount :exec
UPDATE applications SET favorite_count = favorite_count + $2 WHERE id = $1;
-- name: UpdateAppRating :exec
UPDATE applications
SET avg_rating = $2, rating_count = $3
WHERE id = $1;
-- name: ListAllApps :many
SELECT a.*, c.name as category_name, u.name as creator_name
FROM applications a
LEFT JOIN categories c ON a.category_id = c.id
LEFT JOIN users u ON a.creator_id = u.id
WHERE (sqlc.narg('status')::VARCHAR IS NULL OR a.status = sqlc.narg('status'))
ORDER BY a.created_at DESC
LIMIT $1 OFFSET $2;
-- name: CountAllApps :one
SELECT COUNT(*) FROM applications
WHERE (sqlc.narg('status')::VARCHAR IS NULL OR status = sqlc.narg('status'));
+23
View File
@@ -0,0 +1,23 @@
-- name: CreateAuditLog :exec
INSERT INTO audit_logs (user_id, action, resource_type, resource_id, details, ip_address, user_agent)
VALUES ($1, $2, $3, $4, $5, $6, $7);
-- name: ListAuditLogs :many
SELECT al.*, u.name as user_name
FROM audit_logs al
LEFT JOIN users u ON al.user_id = u.id
WHERE (sqlc.narg('user_id')::UUID IS NULL OR al.user_id = sqlc.narg('user_id'))
AND (sqlc.narg('action')::VARCHAR IS NULL OR al.action = sqlc.narg('action'))
AND (sqlc.narg('resource_type')::VARCHAR IS NULL OR al.resource_type = sqlc.narg('resource_type'))
AND (sqlc.narg('start_time')::TIMESTAMPTZ IS NULL OR al.created_at >= sqlc.narg('start_time'))
AND (sqlc.narg('end_time')::TIMESTAMPTZ IS NULL OR al.created_at <= sqlc.narg('end_time'))
ORDER BY al.created_at DESC
LIMIT $1 OFFSET $2;
-- name: CountAuditLogs :one
SELECT COUNT(*) FROM audit_logs
WHERE (sqlc.narg('user_id')::UUID IS NULL OR user_id = sqlc.narg('user_id'))
AND (sqlc.narg('action')::VARCHAR IS NULL OR action = sqlc.narg('action'))
AND (sqlc.narg('resource_type')::VARCHAR IS NULL OR resource_type = sqlc.narg('resource_type'))
AND (sqlc.narg('start_time')::TIMESTAMPTZ IS NULL OR created_at >= sqlc.narg('start_time'))
AND (sqlc.narg('end_time')::TIMESTAMPTZ IS NULL OR created_at <= sqlc.narg('end_time'));
+10
View File
@@ -0,0 +1,10 @@
-- name: ListCategories :many
SELECT * FROM categories
WHERE status = 'active'
ORDER BY sort_order ASC;
-- name: GetCategoryByID :one
SELECT * FROM categories WHERE id = $1;
-- name: GetCategoryBySlug :one
SELECT * FROM categories WHERE slug = $1;
@@ -0,0 +1,37 @@
-- name: AddFavorite :exec
INSERT INTO app_favorites (user_id, app_id) VALUES ($1, $2)
ON CONFLICT DO NOTHING;
-- name: RemoveFavorite :exec
DELETE FROM app_favorites WHERE user_id = $1 AND app_id = $2;
-- name: IsFavorited :one
SELECT EXISTS(SELECT 1 FROM app_favorites WHERE user_id = $1 AND app_id = $2);
-- name: ListUserFavorites :many
SELECT a.*, c.name as category_name
FROM app_favorites f
JOIN applications a ON f.app_id = a.id
LEFT JOIN categories c ON a.category_id = c.id
WHERE f.user_id = $1
ORDER BY f.created_at DESC
LIMIT $2 OFFSET $3;
-- name: UpsertRating :one
INSERT INTO app_ratings (app_id, user_id, score, comment)
VALUES ($1, $2, $3, $4)
ON CONFLICT (app_id, user_id)
DO UPDATE SET score = EXCLUDED.score, comment = EXCLUDED.comment
RETURNING *;
-- name: GetAppAvgRating :one
SELECT COALESCE(AVG(score)::REAL, 0) as avg_rating, COUNT(*) as rating_count
FROM app_ratings WHERE app_id = $1;
-- name: ListAppRatings :many
SELECT r.*, u.name as user_name, u.avatar_url as user_avatar
FROM app_ratings r
JOIN users u ON r.user_id = u.id
WHERE r.app_id = $1
ORDER BY r.created_at DESC
LIMIT $2 OFFSET $3;
+40
View File
@@ -0,0 +1,40 @@
-- name: CreateReview :one
INSERT INTO app_reviews (app_id, version, submitter_id, submit_comment)
VALUES ($1, $2, $3, $4)
RETURNING *;
-- name: GetReviewByID :one
SELECT * FROM app_reviews WHERE id = $1;
-- name: ListPendingReviews :many
SELECT r.*, a.name as app_name, a.description as app_description,
a.icon_url as app_icon, u.name as submitter_name
FROM app_reviews r
JOIN applications a ON r.app_id = a.id
JOIN users u ON r.submitter_id = u.id
WHERE r.status = 'pending'
ORDER BY r.submitted_at ASC
LIMIT $1 OFFSET $2;
-- name: CountPendingReviews :one
SELECT COUNT(*) FROM app_reviews WHERE status = 'pending';
-- name: ApproveReview :exec
UPDATE app_reviews
SET status = 'approved', reviewer_id = $2, review_comment = $3, reviewed_at = NOW()
WHERE id = $1;
-- name: RejectReview :exec
UPDATE app_reviews
SET status = 'rejected', reviewer_id = $2, review_comment = $3, reviewed_at = NOW()
WHERE id = $1;
-- name: WithdrawReview :exec
UPDATE app_reviews SET status = 'withdrawn' WHERE id = $1 AND status = 'pending';
-- name: ListAppReviews :many
SELECT r.*, u.name as reviewer_name
FROM app_reviews r
LEFT JOIN users u ON r.reviewer_id = u.id
WHERE r.app_id = $1
ORDER BY r.created_at DESC;
+34
View File
@@ -0,0 +1,34 @@
-- name: CreateUsageLog :one
INSERT INTO app_usage_logs (
app_id, user_id, dept_id, conversation_id, message_count,
prompt_tokens, completion_tokens, total_tokens, model_name,
estimated_cost, duration_ms, is_successful, error_message, client_type
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
RETURNING *;
-- name: GetRecentUsedApps :many
SELECT DISTINCT ON (a.id) a.*, c.name as category_name, l.created_at as last_used_at
FROM app_usage_logs l
JOIN applications a ON l.app_id = a.id
LEFT JOIN categories c ON a.category_id = c.id
WHERE l.user_id = $1
ORDER BY a.id, l.created_at DESC
LIMIT $2;
-- name: GetUserStats :one
SELECT
COUNT(*) as total_conversations,
COALESCE(SUM(total_tokens), 0) as total_tokens,
COALESCE(SUM(estimated_cost), 0) as total_cost
FROM app_usage_logs
WHERE user_id = $1
AND created_at >= $2;
-- name: GetOverviewStats :one
SELECT
(SELECT COUNT(*) FROM users WHERE status = 'active') as total_users,
(SELECT COUNT(*) FROM applications WHERE status = 'approved') as total_apps,
(SELECT COUNT(DISTINCT user_id) FROM app_usage_logs WHERE created_at >= $1) as active_users,
(SELECT COUNT(*) FROM app_usage_logs WHERE created_at >= $1) as total_conversations,
(SELECT COALESCE(SUM(total_tokens), 0) FROM app_usage_logs WHERE created_at >= $2) as monthly_tokens,
(SELECT COALESCE(SUM(estimated_cost), 0) FROM app_usage_logs WHERE created_at >= $2) as monthly_cost;
+52
View File
@@ -0,0 +1,52 @@
-- name: GetUserByID :one
SELECT * FROM users WHERE id = $1;
-- name: GetUserByEmail :one
SELECT * FROM users WHERE email = $1;
-- name: GetUserByEmployeeID :one
SELECT * FROM users WHERE employee_id = $1;
-- name: CreateUser :one
INSERT INTO users (name, email, password_hash, phone, avatar_url, role, status, sso_provider, sso_external_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING *;
-- name: UpdateUserProfile :one
UPDATE users
SET name = COALESCE(sqlc.narg('name'), name),
phone = COALESCE(sqlc.narg('phone'), phone),
avatar_url = COALESCE(sqlc.narg('avatar_url'), avatar_url)
WHERE id = $1
RETURNING *;
-- name: UpdateUserRole :exec
UPDATE users SET role = $2 WHERE id = $1;
-- name: UpdateUserStatus :exec
UPDATE users SET status = $2 WHERE id = $1;
-- name: UpdateUserLogin :exec
UPDATE users
SET last_login_at = NOW(), login_count = login_count + 1
WHERE id = $1;
-- name: ListUsers :many
SELECT * FROM users
WHERE (sqlc.narg('role')::VARCHAR IS NULL OR role = sqlc.narg('role'))
AND (sqlc.narg('status')::VARCHAR IS NULL OR status = sqlc.narg('status'))
AND (sqlc.narg('search')::VARCHAR IS NULL
OR name ILIKE '%' || sqlc.narg('search') || '%'
OR email ILIKE '%' || sqlc.narg('search') || '%'
OR employee_id ILIKE '%' || sqlc.narg('search') || '%')
ORDER BY created_at DESC
LIMIT $1 OFFSET $2;
-- name: CountUsers :one
SELECT COUNT(*) FROM users
WHERE (sqlc.narg('role')::VARCHAR IS NULL OR role = sqlc.narg('role'))
AND (sqlc.narg('status')::VARCHAR IS NULL OR status = sqlc.narg('status'))
AND (sqlc.narg('search')::VARCHAR IS NULL
OR name ILIKE '%' || sqlc.narg('search') || '%'
OR email ILIKE '%' || sqlc.narg('search') || '%'
OR employee_id ILIKE '%' || sqlc.narg('search') || '%');
+74
View File
@@ -0,0 +1,74 @@
package dify
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
type Client struct {
baseURL string
httpClient *http.Client
}
func NewClient(baseURL string) *Client {
return &Client{
baseURL: baseURL,
httpClient: &http.Client{
Timeout: 120 * time.Second,
},
}
}
func (c *Client) do(ctx context.Context, method, path, apiKey string, body any) (*http.Response, error) {
var bodyReader io.Reader
if body != nil {
data, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
bodyReader = bytes.NewReader(data)
}
req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, bodyReader)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("execute request: %w", err)
}
if resp.StatusCode >= 400 {
defer resp.Body.Close()
var apiErr APIError
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
return nil, fmt.Errorf("dify API error (status %d)", resp.StatusCode)
}
apiErr.Status = resp.StatusCode
return nil, &apiErr
}
return resp, nil
}
func (c *Client) doJSON(ctx context.Context, method, path, apiKey string, body any, result any) error {
resp, err := c.do(ctx, method, path, apiKey, body)
if err != nil {
return err
}
defer resp.Body.Close()
if result != nil {
return json.NewDecoder(resp.Body).Decode(result)
}
return nil
}
+102
View File
@@ -0,0 +1,102 @@
package dify
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"strings"
)
// ChatStream sends a chat message and returns a reader for SSE events.
// Caller is responsible for closing the returned io.ReadCloser.
func (c *Client) ChatStream(ctx context.Context, apiKey string, req *ChatRequest) (io.ReadCloser, error) {
req.ResponseMode = "streaming"
resp, err := c.do(ctx, "POST", "/chat-messages", apiKey, req)
if err != nil {
return nil, err
}
return resp.Body, nil
}
// ChatBlocking sends a chat message and waits for the complete response.
func (c *Client) ChatBlocking(ctx context.Context, apiKey string, req *ChatRequest) (*ChatStreamEvent, error) {
req.ResponseMode = "blocking"
var result ChatStreamEvent
if err := c.doJSON(ctx, "POST", "/chat-messages", apiKey, req, &result); err != nil {
return nil, err
}
return &result, nil
}
// ParseSSEStream parses a Dify SSE stream and calls handler for each event.
func ParseSSEStream(reader io.Reader, handler func(event ChatStreamEvent) error) error {
scanner := bufio.NewScanner(reader)
scanner.Buffer(make([]byte, 64*1024), 256*1024)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
break
}
var event ChatStreamEvent
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue
}
if err := handler(event); err != nil {
return err
}
}
return scanner.Err()
}
// ListConversations returns the user's conversation list for an app.
func (c *Client) ListConversations(ctx context.Context, apiKey, user string, limit int, firstID string) (*ConversationListResponse, error) {
path := fmt.Sprintf("/conversations?user=%s&limit=%d", user, limit)
if firstID != "" {
path += "&first_id=" + firstID
}
var result ConversationListResponse
if err := c.doJSON(ctx, "GET", path, apiKey, nil, &result); err != nil {
return nil, err
}
return &result, nil
}
// ListMessages returns messages in a conversation.
func (c *Client) ListMessages(ctx context.Context, apiKey, user, conversationID string, limit int, firstID string) (*MessageListResponse, error) {
path := fmt.Sprintf("/messages?user=%s&conversation_id=%s&limit=%d", user, conversationID, limit)
if firstID != "" {
path += "&first_id=" + firstID
}
var result MessageListResponse
if err := c.doJSON(ctx, "GET", path, apiKey, nil, &result); err != nil {
return nil, err
}
return &result, nil
}
// DeleteConversation deletes a conversation.
func (c *Client) DeleteConversation(ctx context.Context, apiKey, user, conversationID string) error {
body := map[string]string{"user": user}
return c.doJSON(ctx, "DELETE", "/conversations/"+conversationID, apiKey, body, nil)
}
// SubmitFeedback submits feedback for a message.
func (c *Client) SubmitFeedback(ctx context.Context, apiKey, messageID string, req *FeedbackRequest) error {
return c.doJSON(ctx, "POST", "/messages/"+messageID+"/feedbacks", apiKey, req, nil)
}
+96
View File
@@ -0,0 +1,96 @@
package dify
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
)
// CreateDataset creates a new knowledge base (dataset) in Dify.
func (c *Client) CreateDataset(ctx context.Context, apiKey string, req *DatasetCreateRequest) (*Dataset, error) {
var result Dataset
if err := c.doJSON(ctx, "POST", "/datasets", apiKey, req, &result); err != nil {
return nil, err
}
return &result, nil
}
// DeleteDataset deletes a knowledge base.
func (c *Client) DeleteDataset(ctx context.Context, apiKey, datasetID string) error {
return c.doJSON(ctx, "DELETE", "/datasets/"+datasetID, apiKey, nil, nil)
}
// UploadDocument uploads a file to a dataset for indexing.
func (c *Client) UploadDocument(ctx context.Context, apiKey, datasetID string, filename string, fileReader io.Reader) (*DocumentIndexingStatus, error) {
var buf bytes.Buffer
writer := multipart.NewWriter(&buf)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return nil, fmt.Errorf("create form file: %w", err)
}
if _, err := io.Copy(part, fileReader); err != nil {
return nil, fmt.Errorf("copy file: %w", err)
}
// indexing mode
if err := writer.WriteField("indexing_technique", "high_quality"); err != nil {
return nil, err
}
if err := writer.WriteField("process_rule", `{"mode": "automatic"}`); err != nil {
return nil, err
}
writer.Close()
path := fmt.Sprintf("/datasets/%s/document/create_by_file", datasetID)
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+path, &buf)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("Content-Type", writer.FormDataContentType())
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("upload failed (status %d): %s", resp.StatusCode, string(body))
}
// Dify returns document info in response
var result struct {
Document DocumentIndexingStatus `json:"document"`
}
if err := decodeJSON(resp.Body, &result); err != nil {
return nil, err
}
return &result.Document, nil
}
// DeleteDocument deletes a document from a dataset.
func (c *Client) DeleteDocument(ctx context.Context, apiKey, datasetID, documentID string) error {
path := fmt.Sprintf("/datasets/%s/documents/%s", datasetID, documentID)
return c.doJSON(ctx, "DELETE", path, apiKey, nil, nil)
}
// GetDocumentIndexingStatus checks the indexing status of a document.
func (c *Client) GetDocumentIndexingStatus(ctx context.Context, apiKey, datasetID, batch string) (*DocumentIndexingStatus, error) {
path := fmt.Sprintf("/datasets/%s/documents/%s/indexing-status", datasetID, batch)
var result DocumentIndexingStatus
if err := c.doJSON(ctx, "GET", path, apiKey, nil, &result); err != nil {
return nil, err
}
return &result, nil
}
func decodeJSON(r io.Reader, v any) error {
return json.NewDecoder(r).Decode(v)
}
+115
View File
@@ -0,0 +1,115 @@
package dify
import "time"
// --- Request types ---
type ChatRequest struct {
Query string `json:"query"`
Inputs map[string]any `json:"inputs,omitempty"`
ConversationID string `json:"conversation_id,omitempty"`
User string `json:"user"`
ResponseMode string `json:"response_mode"`
}
type CompletionRequest struct {
Inputs map[string]any `json:"inputs"`
User string `json:"user"`
ResponseMode string `json:"response_mode"`
}
type FeedbackRequest struct {
Rating string `json:"rating"` // "like", "dislike", null
User string `json:"user"`
}
// --- Response types ---
type ChatStreamEvent struct {
Event string `json:"event"`
TaskID string `json:"task_id"`
MessageID string `json:"message_id"`
ConversationID string `json:"conversation_id"`
Answer string `json:"answer"`
Metadata map[string]any `json:"metadata,omitempty"`
CreatedAt int64 `json:"created_at"`
}
type MessageEndMetadata struct {
Usage TokenUsage `json:"usage"`
}
type TokenUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Conversation struct {
ID string `json:"id"`
Name string `json:"name"`
Inputs any `json:"inputs"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type ConversationListResponse struct {
Data []Conversation `json:"data"`
HasMore bool `json:"has_more"`
Limit int `json:"limit"`
}
type Message struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
Query string `json:"query"`
Answer string `json:"answer"`
Feedback *Feedback `json:"feedback"`
CreatedAt int64 `json:"created_at"`
Inputs map[string]any `json:"inputs"`
}
type Feedback struct {
Rating string `json:"rating"`
}
type MessageListResponse struct {
Data []Message `json:"data"`
HasMore bool `json:"has_more"`
Limit int `json:"limit"`
}
// --- Dataset/Knowledge types ---
type DatasetCreateRequest struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
}
type Dataset struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
DocCount int `json:"document_count"`
CreatedAt time.Time `json:"created_at"`
}
type DocumentIndexingStatus struct {
ID string `json:"id"`
IndexingStatus string `json:"indexing_status"`
ProcessingStart string `json:"processing_started_at"`
CompletedAt string `json:"completed_at"`
}
// --- Error ---
type APIError struct {
Code string `json:"code"`
Message string `json:"message"`
Status int `json:"status"`
}
func (e *APIError) Error() string {
return e.Message
}
+138
View File
@@ -0,0 +1,138 @@
// Package embedding 提供文本向量化服务,支持 OpenAI 兼容的 Embedding API
package embedding
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// Config embedding 服务配置
type Config struct {
APIKey string // API 密钥
BaseURL string // API 基础 URLOpenAI 兼容格式)
Model string // 模型名称
Dimensions int // 向量维度
}
// Client embedding 客户端
type Client struct {
cfg Config
httpClient *http.Client
}
// NewClient 创建 embedding 客户端
func NewClient(cfg Config) *Client {
if cfg.BaseURL == "" {
cfg.BaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
}
if cfg.Model == "" {
cfg.Model = "text-embedding-v3"
}
if cfg.Dimensions == 0 {
cfg.Dimensions = 1024
}
return &Client{
cfg: cfg,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// embeddingRequest OpenAI 兼容的 embedding 请求
type embeddingRequest struct {
Input interface{} `json:"input"`
Model string `json:"model"`
Dimensions int `json:"dimensions,omitempty"`
}
// embeddingResponse OpenAI 兼容的 embedding 响应
type embeddingResponse struct {
Data []struct {
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
} `json:"data"`
Usage struct {
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// GetEmbedding 获取单条文本的向量嵌入
func (c *Client) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
if c.cfg.APIKey == "" {
return nil, fmt.Errorf("embedding API key not configured")
}
text = strings.TrimSpace(text)
if text == "" {
return nil, fmt.Errorf("empty text")
}
// 截断过长文本
if len([]rune(text)) > 8000 {
text = string([]rune(text)[:8000])
}
req := embeddingRequest{
Input: text,
Model: c.cfg.Model,
Dimensions: c.cfg.Dimensions,
}
body, err := json.Marshal(req)
if err != nil {
return nil, err
}
url := strings.TrimRight(c.cfg.BaseURL, "/") + "/embeddings"
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("embedding request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
errBody, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("embedding error (status %d): %s", resp.StatusCode, string(errBody))
}
var result embeddingResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}
if len(result.Data) == 0 {
return nil, fmt.Errorf("no embedding data returned")
}
return result.Data[0].Embedding, nil
}
// GetEmbeddingBatch 批量获取文本向量嵌入
func (c *Client) GetEmbeddingBatch(ctx context.Context, texts []string) ([][]float32, error) {
results := make([][]float32, len(texts))
for i, text := range texts {
emb, err := c.GetEmbedding(ctx, text)
if err != nil {
results[i] = nil
continue
}
results[i] = emb
}
return results, nil
}
// IsConfigured 检查 embedding 服务是否已配置
func (c *Client) IsConfigured() bool {
return c.cfg.APIKey != ""
}
+196
View File
@@ -0,0 +1,196 @@
package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
)
type AnthropicProvider struct {
apiKey string
baseURL string
model string
httpClient *http.Client
}
func NewAnthropicProvider(apiKey, baseURL, defaultModel string) *AnthropicProvider {
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
if defaultModel == "" {
defaultModel = "claude-sonnet-4-20250514"
}
return &AnthropicProvider{
apiKey: apiKey,
baseURL: baseURL,
model: defaultModel,
httpClient: &http.Client{},
}
}
func (p *AnthropicProvider) Name() string { return "anthropic" }
type anthropicRequest struct {
Model string `json:"model"`
Messages []anthropicMessage `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature,omitempty"`
Stream bool `json:"stream"`
}
type anthropicMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type anthropicResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
}
func (p *AnthropicProvider) ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
model := req.Model
if model == "" {
model = p.model
}
system, msgs := extractSystemMessage(req.Messages)
maxTokens := req.MaxTokens
if maxTokens == 0 {
maxTokens = 4096
}
aReq := anthropicRequest{
Model: model,
Messages: toAnthropicMessages(msgs),
System: system,
MaxTokens: maxTokens,
Temperature: req.Temperature,
Stream: false,
}
body, err := json.Marshal(aReq)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/v1/messages", bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("x-api-key", p.apiKey)
httpReq.Header.Set("anthropic-version", "2023-06-01")
httpReq.Header.Set("Content-Type", "application/json")
resp, err := p.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("anthropic request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
errBody, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("anthropic error (status %d): %s", resp.StatusCode, string(errBody))
}
var aResp anthropicResponse
if err := json.NewDecoder(resp.Body).Decode(&aResp); err != nil {
return nil, err
}
content := ""
for _, c := range aResp.Content {
if c.Type == "text" {
content += c.Text
}
}
return &ChatResponse{
Content: content,
Model: aResp.Model,
PromptTokens: aResp.Usage.InputTokens,
CompletionTokens: aResp.Usage.OutputTokens,
TotalTokens: aResp.Usage.InputTokens + aResp.Usage.OutputTokens,
}, nil
}
func (p *AnthropicProvider) ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error) {
model := req.Model
if model == "" {
model = p.model
}
system, msgs := extractSystemMessage(req.Messages)
maxTokens := req.MaxTokens
if maxTokens == 0 {
maxTokens = 4096
}
aReq := anthropicRequest{
Model: model,
Messages: toAnthropicMessages(msgs),
System: system,
MaxTokens: maxTokens,
Temperature: req.Temperature,
Stream: true,
}
body, err := json.Marshal(aReq)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/v1/messages", bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("x-api-key", p.apiKey)
httpReq.Header.Set("anthropic-version", "2023-06-01")
httpReq.Header.Set("Content-Type", "application/json")
resp, err := p.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("anthropic stream request failed: %w", err)
}
if resp.StatusCode != http.StatusOK {
errBody, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("anthropic error (status %d): %s", resp.StatusCode, string(errBody))
}
return resp.Body, nil
}
func extractSystemMessage(msgs []Message) (string, []Message) {
system := ""
var filtered []Message
for _, m := range msgs {
if m.Role == RoleSystem {
system = m.Content
} else {
filtered = append(filtered, m)
}
}
return system, filtered
}
func toAnthropicMessages(msgs []Message) []anthropicMessage {
out := make([]anthropicMessage, len(msgs))
for i, m := range msgs {
out[i] = anthropicMessage{Role: string(m.Role), Content: m.Content}
}
return out
}
+197
View File
@@ -0,0 +1,197 @@
package llm
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"strings"
)
// Manager manages multiple LLM providers and routes requests.
type Manager struct {
providers map[string]Provider
fallback string
}
func NewManager() *Manager {
return &Manager{
providers: make(map[string]Provider),
}
}
func (m *Manager) Register(name string, provider Provider) {
m.providers[name] = provider
if m.fallback == "" {
m.fallback = name
}
}
func (m *Manager) SetFallback(name string) {
m.fallback = name
}
func (m *Manager) GetProvider(name string) (Provider, error) {
if p, ok := m.providers[name]; ok {
return p, nil
}
if p, ok := m.providers[m.fallback]; ok {
return p, nil
}
return nil, fmt.Errorf("no provider found: %s", name)
}
// Chat performs a blocking chat completion using the specified provider.
func (m *Manager) Chat(ctx context.Context, providerName string, req *ChatRequest) (*ChatResponse, error) {
provider, err := m.GetProvider(providerName)
if err != nil {
return nil, err
}
return provider.ChatCompletion(ctx, req)
}
// ChatStream performs a streaming chat and returns the raw SSE body.
func (m *Manager) ChatStream(ctx context.Context, providerName string, req *ChatRequest) (io.ReadCloser, error) {
provider, err := m.GetProvider(providerName)
if err != nil {
return nil, err
}
req.Stream = true
return provider.ChatStream(ctx, req)
}
// StreamEvent represents a normalized SSE event for the frontend.
type StreamEvent struct {
Event string `json:"event"`
Answer string `json:"answer,omitempty"`
MessageID string `json:"message_id,omitempty"`
Usage *Usage `json:"usage,omitempty"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
Model string `json:"model"`
}
// TransformOpenAIStream reads an OpenAI SSE stream and writes normalized events to the writer.
func TransformOpenAIStream(reader io.Reader, write func(event StreamEvent)) error {
scanner := bufio.NewScanner(reader)
scanner.Buffer(make([]byte, 64*1024), 256*1024)
var totalContent string
var model string
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
write(StreamEvent{
Event: "message_end",
Usage: &Usage{
TotalTokens: estimateTokens(totalContent),
Model: model,
},
})
break
}
var chunk struct {
ID string `json:"id"`
Model string `json:"model"`
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
} `json:"choices"`
}
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
model = chunk.Model
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
content := chunk.Choices[0].Delta.Content
totalContent += content
write(StreamEvent{
Event: "message",
Answer: content,
MessageID: chunk.ID,
})
}
}
return scanner.Err()
}
// TransformAnthropicStream reads an Anthropic SSE stream and writes normalized events.
func TransformAnthropicStream(reader io.Reader, write func(event StreamEvent)) error {
scanner := bufio.NewScanner(reader)
scanner.Buffer(make([]byte, 64*1024), 256*1024)
var model string
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
if strings.HasPrefix(line, "event: ") {
continue
}
continue
}
data := strings.TrimPrefix(line, "data: ")
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue
}
eventType, _ := event["type"].(string)
switch eventType {
case "message_start":
if msg, ok := event["message"].(map[string]any); ok {
if m, ok := msg["model"].(string); ok {
model = m
}
}
case "content_block_delta":
if delta, ok := event["delta"].(map[string]any); ok {
if text, ok := delta["text"].(string); ok {
write(StreamEvent{
Event: "message",
Answer: text,
})
}
}
case "message_delta":
if usage, ok := event["usage"].(map[string]any); ok {
outputTokens := int(getFloat(usage, "output_tokens"))
write(StreamEvent{
Event: "message_end",
Usage: &Usage{
CompletionTokens: outputTokens,
TotalTokens: outputTokens,
Model: model,
},
})
}
}
}
return scanner.Err()
}
func getFloat(m map[string]any, key string) float64 {
if v, ok := m[key].(float64); ok {
return v
}
return 0
}
func estimateTokens(text string) int {
return len(text) / 4
}
+167
View File
@@ -0,0 +1,167 @@
package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
)
type OpenAIProvider struct {
apiKey string
baseURL string
model string
httpClient *http.Client
}
func NewOpenAIProvider(apiKey, baseURL, defaultModel string) *OpenAIProvider {
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
if defaultModel == "" {
defaultModel = "gpt-4o-mini"
}
return &OpenAIProvider{
apiKey: apiKey,
baseURL: baseURL,
model: defaultModel,
httpClient: &http.Client{},
}
}
func (p *OpenAIProvider) Name() string { return "openai" }
type openAIRequest struct {
Model string `json:"model"`
Messages []openAIMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream"`
}
type openAIMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type openAIResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
model := req.Model
if model == "" {
model = p.model
}
oaiReq := openAIRequest{
Model: model,
Messages: toOpenAIMessages(req.Messages),
Temperature: req.Temperature,
MaxTokens: req.MaxTokens,
Stream: false,
}
body, err := json.Marshal(oaiReq)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/chat/completions", bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
httpReq.Header.Set("Content-Type", "application/json")
resp, err := p.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("openai request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
errBody, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("openai error (status %d): %s", resp.StatusCode, string(errBody))
}
var oaiResp openAIResponse
if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
return nil, err
}
content := ""
if len(oaiResp.Choices) > 0 {
content = oaiResp.Choices[0].Message.Content
}
return &ChatResponse{
Content: content,
Model: oaiResp.Model,
PromptTokens: oaiResp.Usage.PromptTokens,
CompletionTokens: oaiResp.Usage.CompletionTokens,
TotalTokens: oaiResp.Usage.TotalTokens,
}, nil
}
func (p *OpenAIProvider) ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error) {
model := req.Model
if model == "" {
model = p.model
}
oaiReq := openAIRequest{
Model: model,
Messages: toOpenAIMessages(req.Messages),
Temperature: req.Temperature,
MaxTokens: req.MaxTokens,
Stream: true,
}
body, err := json.Marshal(oaiReq)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/chat/completions", bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
httpReq.Header.Set("Content-Type", "application/json")
resp, err := p.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("openai stream request failed: %w", err)
}
if resp.StatusCode != http.StatusOK {
errBody, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("openai error (status %d): %s", resp.StatusCode, string(errBody))
}
return resp.Body, nil
}
func toOpenAIMessages(msgs []Message) []openAIMessage {
out := make([]openAIMessage, len(msgs))
for i, m := range msgs {
out[i] = openAIMessage{Role: string(m.Role), Content: m.Content}
}
return out
}
+55
View File
@@ -0,0 +1,55 @@
package llm
import (
"context"
"io"
)
type Role string
const (
RoleSystem Role = "system"
RoleUser Role = "user"
RoleAssistant Role = "assistant"
)
type Message struct {
Role Role `json:"role"`
Content string `json:"content"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream"`
}
type ChatResponse struct {
Content string `json:"content"`
Model string `json:"model"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
ConversationID string `json:"conversation_id,omitempty"`
}
type StreamDelta struct {
Content string `json:"content"`
FinishReason string `json:"finish_reason,omitempty"`
}
// Provider abstracts an LLM API (OpenAI, Anthropic, etc.)
type Provider interface {
ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error)
Name() string
}
type ProviderConfig struct {
Type string `json:"type"`
APIKey string `json:"api_key"`
BaseURL string `json:"base_url,omitempty"`
Model string `json:"model,omitempty"`
}