Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,97 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type JWTManager struct {
|
||||
secret []byte
|
||||
accessExpiry time.Duration
|
||||
refreshExpiry time.Duration
|
||||
}
|
||||
|
||||
func NewJWTManager(secret string, accessExpiry, refreshExpiry time.Duration) *JWTManager {
|
||||
return &JWTManager{
|
||||
secret: []byte(secret),
|
||||
accessExpiry: accessExpiry,
|
||||
refreshExpiry: refreshExpiry,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *JWTManager) GenerateTokenPair(userID uuid.UUID, email, role string) (*TokenPair, error) {
|
||||
now := time.Now()
|
||||
accessExp := now.Add(m.accessExpiry)
|
||||
|
||||
accessClaims := &Claims{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Role: role,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(accessExp),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
Subject: userID.String(),
|
||||
},
|
||||
}
|
||||
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
||||
accessStr, err := accessToken.SignedString(m.secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign access token: %w", err)
|
||||
}
|
||||
|
||||
refreshClaims := &Claims{
|
||||
UserID: userID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(m.refreshExpiry)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
Subject: userID.String(),
|
||||
},
|
||||
}
|
||||
|
||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
||||
refreshStr, err := refreshToken.SignedString(m.secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign refresh token: %w", err)
|
||||
}
|
||||
|
||||
return &TokenPair{
|
||||
AccessToken: accessStr,
|
||||
RefreshToken: refreshStr,
|
||||
ExpiresAt: accessExp.Unix(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *JWTManager) ValidateToken(tokenStr string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (any, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return m.secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package auth
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
|
||||
func HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
func CheckPassword(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package auth
|
||||
|
||||
import "context"
|
||||
|
||||
// SSOProvider defines the interface for SSO authentication providers.
|
||||
// Implementations will be added for LDAP, OAuth2, and SAML.
|
||||
type SSOProvider interface {
|
||||
// Authenticate validates credentials and returns user information.
|
||||
Authenticate(ctx context.Context, credentials map[string]string) (*SSOUser, error)
|
||||
// Name returns the provider name (e.g., "ldap", "oauth2", "saml").
|
||||
Name() string
|
||||
}
|
||||
|
||||
type SSOUser struct {
|
||||
ExternalID string
|
||||
Name string
|
||||
Email string
|
||||
Phone string
|
||||
Department string
|
||||
AvatarURL string
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
// Package chunker 提供文本智能分片服务,将长文本切分为适合向量化的片段
|
||||
package chunker
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Options 分片配置
|
||||
type Options struct {
|
||||
ChunkSize int // 每块最大字符数,默认 500
|
||||
Overlap int // 块间重叠字符数,默认 50
|
||||
Separators []string // 分隔符列表(按优先级)
|
||||
}
|
||||
|
||||
// DefaultOptions 默认分片配置
|
||||
func DefaultOptions() Options {
|
||||
return Options{
|
||||
ChunkSize: 500,
|
||||
Overlap: 50,
|
||||
Separators: []string{
|
||||
"\n\n", "\n", "。", ".", "!", "!", "?", "?", ";", ";", " ",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ChunkText 智能分片:优先按段落/句子边界切分,避免在句子中间断开
|
||||
func ChunkText(text string, opts Options) []string {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if opts.ChunkSize <= 0 {
|
||||
opts.ChunkSize = 500
|
||||
}
|
||||
if opts.Overlap < 0 {
|
||||
opts.Overlap = 0
|
||||
}
|
||||
if opts.Separators == nil {
|
||||
opts.Separators = DefaultOptions().Separators
|
||||
}
|
||||
|
||||
runeLen := utf8.RuneCountInString(text)
|
||||
if runeLen <= opts.ChunkSize {
|
||||
return []string{text}
|
||||
}
|
||||
|
||||
// 递归分片
|
||||
chunks := recursiveSplit(text, opts.Separators, opts.ChunkSize)
|
||||
|
||||
// 添加重叠
|
||||
if opts.Overlap > 0 && len(chunks) > 1 {
|
||||
chunks = addOverlap(chunks, opts.Overlap)
|
||||
}
|
||||
|
||||
// 过滤空片段
|
||||
var result []string
|
||||
for _, c := range chunks {
|
||||
c = strings.TrimSpace(c)
|
||||
if c != "" {
|
||||
result = append(result, c)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// recursiveSplit 递归分片
|
||||
func recursiveSplit(text string, separators []string, chunkSize int) []string {
|
||||
if utf8.RuneCountInString(text) <= chunkSize {
|
||||
return []string{text}
|
||||
}
|
||||
|
||||
for _, sep := range separators {
|
||||
parts := strings.Split(text, sep)
|
||||
if len(parts) <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
var result []string
|
||||
current := ""
|
||||
for _, part := range parts {
|
||||
candidate := current
|
||||
if candidate != "" {
|
||||
candidate += sep
|
||||
}
|
||||
candidate += part
|
||||
|
||||
if utf8.RuneCountInString(candidate) <= chunkSize {
|
||||
current = candidate
|
||||
} else {
|
||||
if current != "" {
|
||||
result = append(result, current)
|
||||
}
|
||||
if utf8.RuneCountInString(part) > chunkSize {
|
||||
// 继续用更细的分隔符拆分
|
||||
nextSeps := separators[1:]
|
||||
if len(nextSeps) == 0 {
|
||||
nextSeps = nil
|
||||
}
|
||||
result = append(result, recursiveSplit(part, nextSeps, chunkSize)...)
|
||||
current = ""
|
||||
} else {
|
||||
current = part
|
||||
}
|
||||
}
|
||||
}
|
||||
if current != "" {
|
||||
result = append(result, current)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// 无分隔符可用,按字符硬切
|
||||
runes := []rune(text)
|
||||
var result []string
|
||||
for i := 0; i < len(runes); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(runes) {
|
||||
end = len(runes)
|
||||
}
|
||||
result = append(result, string(runes[i:end]))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// addOverlap 给分片添加重叠区域
|
||||
func addOverlap(chunks []string, overlap int) []string {
|
||||
if len(chunks) <= 1 || overlap <= 0 {
|
||||
return chunks
|
||||
}
|
||||
|
||||
result := []string{chunks[0]}
|
||||
for i := 1; i < len(chunks); i++ {
|
||||
prevRunes := []rune(chunks[i-1])
|
||||
overlapStart := len(prevRunes) - overlap
|
||||
if overlapStart < 0 {
|
||||
overlapStart = 0
|
||||
}
|
||||
prevTail := string(prevRunes[overlapStart:])
|
||||
result = append(result, prevTail+chunks[i])
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
func NewPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) {
|
||||
config, err := pgxpool.ParseConfig(databaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse database url: %w", err)
|
||||
}
|
||||
|
||||
config.MaxConns = 20
|
||||
config.MinConns = 5
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create pool: %w", err)
|
||||
}
|
||||
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
return nil, fmt.Errorf("ping database: %w", err)
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
-- name: CreateApplication :one
|
||||
INSERT INTO applications (
|
||||
name, slug, description, long_description, icon_url,
|
||||
category_id, creator_id, dept_id,
|
||||
dify_app_id, dify_app_type, dify_api_key,
|
||||
app_config, welcome_message, suggested_prompts,
|
||||
max_tokens, temperature, status, visibility, is_template
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19
|
||||
) RETURNING *;
|
||||
|
||||
-- name: GetApplicationByID :one
|
||||
SELECT * FROM applications WHERE id = $1;
|
||||
|
||||
-- name: GetApplicationBySlug :one
|
||||
SELECT * FROM applications WHERE slug = $1;
|
||||
|
||||
-- name: UpdateApplication :one
|
||||
UPDATE applications
|
||||
SET name = COALESCE(sqlc.narg('name'), name),
|
||||
description = COALESCE(sqlc.narg('description'), description),
|
||||
long_description = COALESCE(sqlc.narg('long_description'), long_description),
|
||||
icon_url = COALESCE(sqlc.narg('icon_url'), icon_url),
|
||||
category_id = COALESCE(sqlc.narg('category_id'), category_id),
|
||||
app_config = COALESCE(sqlc.narg('app_config'), app_config),
|
||||
welcome_message = COALESCE(sqlc.narg('welcome_message'), welcome_message),
|
||||
suggested_prompts = COALESCE(sqlc.narg('suggested_prompts'), suggested_prompts),
|
||||
max_tokens = COALESCE(sqlc.narg('max_tokens'), max_tokens),
|
||||
temperature = COALESCE(sqlc.narg('temperature'), temperature),
|
||||
visibility = COALESCE(sqlc.narg('visibility'), visibility)
|
||||
WHERE id = $1
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateApplicationStatus :exec
|
||||
UPDATE applications SET status = $2 WHERE id = $1;
|
||||
|
||||
-- name: DeleteApplication :exec
|
||||
DELETE FROM applications WHERE id = $1 AND status = 'draft';
|
||||
|
||||
-- name: ListStoreApps :many
|
||||
SELECT a.*, c.name as category_name, c.slug as category_slug, u.name as creator_name
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
LEFT JOIN users u ON a.creator_id = u.id
|
||||
WHERE a.status = 'approved'
|
||||
AND a.visibility = 'public'
|
||||
AND (sqlc.narg('category_slug')::VARCHAR IS NULL OR c.slug = sqlc.narg('category_slug'))
|
||||
AND (sqlc.narg('search')::VARCHAR IS NULL
|
||||
OR to_tsvector('simple', a.name || ' ' || COALESCE(a.description, ''))
|
||||
@@ plainto_tsquery('simple', sqlc.narg('search')))
|
||||
ORDER BY
|
||||
CASE WHEN sqlc.narg('sort')::VARCHAR = 'popular' THEN a.usage_count END DESC,
|
||||
CASE WHEN sqlc.narg('sort')::VARCHAR = 'rating' THEN a.avg_rating END DESC,
|
||||
CASE WHEN sqlc.narg('sort')::VARCHAR IS NULL OR sqlc.narg('sort') = 'latest' THEN EXTRACT(EPOCH FROM a.published_at) END DESC
|
||||
LIMIT $1 OFFSET $2;
|
||||
|
||||
-- name: CountStoreApps :one
|
||||
SELECT COUNT(*) FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE a.status = 'approved'
|
||||
AND a.visibility = 'public'
|
||||
AND (sqlc.narg('category_slug')::VARCHAR IS NULL OR c.slug = sqlc.narg('category_slug'))
|
||||
AND (sqlc.narg('search')::VARCHAR IS NULL
|
||||
OR to_tsvector('simple', a.name || ' ' || COALESCE(a.description, ''))
|
||||
@@ plainto_tsquery('simple', sqlc.narg('search')));
|
||||
|
||||
-- name: ListFeaturedApps :many
|
||||
SELECT a.*, c.name as category_name, u.name as creator_name
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
LEFT JOIN users u ON a.creator_id = u.id
|
||||
WHERE a.is_featured = true AND a.status = 'approved' AND a.visibility = 'public'
|
||||
ORDER BY a.usage_count DESC
|
||||
LIMIT $1;
|
||||
|
||||
-- name: ListTopApps :many
|
||||
SELECT a.*, c.name as category_name, u.name as creator_name
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
LEFT JOIN users u ON a.creator_id = u.id
|
||||
WHERE a.status = 'approved' AND a.visibility = 'public'
|
||||
ORDER BY a.usage_count DESC
|
||||
LIMIT $1;
|
||||
|
||||
-- name: ListCreatorApps :many
|
||||
SELECT a.*, c.name as category_name
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE a.creator_id = $1
|
||||
ORDER BY a.updated_at DESC
|
||||
LIMIT $2 OFFSET $3;
|
||||
|
||||
-- name: ListTemplates :many
|
||||
SELECT a.*, c.name as category_name
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE a.is_template = true AND a.status = 'approved'
|
||||
ORDER BY a.usage_count DESC;
|
||||
|
||||
-- name: IncrementUsageCount :exec
|
||||
UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1;
|
||||
|
||||
-- name: UpdateFavoriteCount :exec
|
||||
UPDATE applications SET favorite_count = favorite_count + $2 WHERE id = $1;
|
||||
|
||||
-- name: UpdateAppRating :exec
|
||||
UPDATE applications
|
||||
SET avg_rating = $2, rating_count = $3
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: ListAllApps :many
|
||||
SELECT a.*, c.name as category_name, u.name as creator_name
|
||||
FROM applications a
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
LEFT JOIN users u ON a.creator_id = u.id
|
||||
WHERE (sqlc.narg('status')::VARCHAR IS NULL OR a.status = sqlc.narg('status'))
|
||||
ORDER BY a.created_at DESC
|
||||
LIMIT $1 OFFSET $2;
|
||||
|
||||
-- name: CountAllApps :one
|
||||
SELECT COUNT(*) FROM applications
|
||||
WHERE (sqlc.narg('status')::VARCHAR IS NULL OR status = sqlc.narg('status'));
|
||||
@@ -0,0 +1,23 @@
|
||||
-- name: CreateAuditLog :exec
|
||||
INSERT INTO audit_logs (user_id, action, resource_type, resource_id, details, ip_address, user_agent)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7);
|
||||
|
||||
-- name: ListAuditLogs :many
|
||||
SELECT al.*, u.name as user_name
|
||||
FROM audit_logs al
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
WHERE (sqlc.narg('user_id')::UUID IS NULL OR al.user_id = sqlc.narg('user_id'))
|
||||
AND (sqlc.narg('action')::VARCHAR IS NULL OR al.action = sqlc.narg('action'))
|
||||
AND (sqlc.narg('resource_type')::VARCHAR IS NULL OR al.resource_type = sqlc.narg('resource_type'))
|
||||
AND (sqlc.narg('start_time')::TIMESTAMPTZ IS NULL OR al.created_at >= sqlc.narg('start_time'))
|
||||
AND (sqlc.narg('end_time')::TIMESTAMPTZ IS NULL OR al.created_at <= sqlc.narg('end_time'))
|
||||
ORDER BY al.created_at DESC
|
||||
LIMIT $1 OFFSET $2;
|
||||
|
||||
-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*) FROM audit_logs
|
||||
WHERE (sqlc.narg('user_id')::UUID IS NULL OR user_id = sqlc.narg('user_id'))
|
||||
AND (sqlc.narg('action')::VARCHAR IS NULL OR action = sqlc.narg('action'))
|
||||
AND (sqlc.narg('resource_type')::VARCHAR IS NULL OR resource_type = sqlc.narg('resource_type'))
|
||||
AND (sqlc.narg('start_time')::TIMESTAMPTZ IS NULL OR created_at >= sqlc.narg('start_time'))
|
||||
AND (sqlc.narg('end_time')::TIMESTAMPTZ IS NULL OR created_at <= sqlc.narg('end_time'));
|
||||
@@ -0,0 +1,10 @@
|
||||
-- name: ListCategories :many
|
||||
SELECT * FROM categories
|
||||
WHERE status = 'active'
|
||||
ORDER BY sort_order ASC;
|
||||
|
||||
-- name: GetCategoryByID :one
|
||||
SELECT * FROM categories WHERE id = $1;
|
||||
|
||||
-- name: GetCategoryBySlug :one
|
||||
SELECT * FROM categories WHERE slug = $1;
|
||||
@@ -0,0 +1,37 @@
|
||||
-- name: AddFavorite :exec
|
||||
INSERT INTO app_favorites (user_id, app_id) VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING;
|
||||
|
||||
-- name: RemoveFavorite :exec
|
||||
DELETE FROM app_favorites WHERE user_id = $1 AND app_id = $2;
|
||||
|
||||
-- name: IsFavorited :one
|
||||
SELECT EXISTS(SELECT 1 FROM app_favorites WHERE user_id = $1 AND app_id = $2);
|
||||
|
||||
-- name: ListUserFavorites :many
|
||||
SELECT a.*, c.name as category_name
|
||||
FROM app_favorites f
|
||||
JOIN applications a ON f.app_id = a.id
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE f.user_id = $1
|
||||
ORDER BY f.created_at DESC
|
||||
LIMIT $2 OFFSET $3;
|
||||
|
||||
-- name: UpsertRating :one
|
||||
INSERT INTO app_ratings (app_id, user_id, score, comment)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (app_id, user_id)
|
||||
DO UPDATE SET score = EXCLUDED.score, comment = EXCLUDED.comment
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetAppAvgRating :one
|
||||
SELECT COALESCE(AVG(score)::REAL, 0) as avg_rating, COUNT(*) as rating_count
|
||||
FROM app_ratings WHERE app_id = $1;
|
||||
|
||||
-- name: ListAppRatings :many
|
||||
SELECT r.*, u.name as user_name, u.avatar_url as user_avatar
|
||||
FROM app_ratings r
|
||||
JOIN users u ON r.user_id = u.id
|
||||
WHERE r.app_id = $1
|
||||
ORDER BY r.created_at DESC
|
||||
LIMIT $2 OFFSET $3;
|
||||
@@ -0,0 +1,40 @@
|
||||
-- name: CreateReview :one
|
||||
INSERT INTO app_reviews (app_id, version, submitter_id, submit_comment)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetReviewByID :one
|
||||
SELECT * FROM app_reviews WHERE id = $1;
|
||||
|
||||
-- name: ListPendingReviews :many
|
||||
SELECT r.*, a.name as app_name, a.description as app_description,
|
||||
a.icon_url as app_icon, u.name as submitter_name
|
||||
FROM app_reviews r
|
||||
JOIN applications a ON r.app_id = a.id
|
||||
JOIN users u ON r.submitter_id = u.id
|
||||
WHERE r.status = 'pending'
|
||||
ORDER BY r.submitted_at ASC
|
||||
LIMIT $1 OFFSET $2;
|
||||
|
||||
-- name: CountPendingReviews :one
|
||||
SELECT COUNT(*) FROM app_reviews WHERE status = 'pending';
|
||||
|
||||
-- name: ApproveReview :exec
|
||||
UPDATE app_reviews
|
||||
SET status = 'approved', reviewer_id = $2, review_comment = $3, reviewed_at = NOW()
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: RejectReview :exec
|
||||
UPDATE app_reviews
|
||||
SET status = 'rejected', reviewer_id = $2, review_comment = $3, reviewed_at = NOW()
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: WithdrawReview :exec
|
||||
UPDATE app_reviews SET status = 'withdrawn' WHERE id = $1 AND status = 'pending';
|
||||
|
||||
-- name: ListAppReviews :many
|
||||
SELECT r.*, u.name as reviewer_name
|
||||
FROM app_reviews r
|
||||
LEFT JOIN users u ON r.reviewer_id = u.id
|
||||
WHERE r.app_id = $1
|
||||
ORDER BY r.created_at DESC;
|
||||
@@ -0,0 +1,34 @@
|
||||
-- name: CreateUsageLog :one
|
||||
INSERT INTO app_usage_logs (
|
||||
app_id, user_id, dept_id, conversation_id, message_count,
|
||||
prompt_tokens, completion_tokens, total_tokens, model_name,
|
||||
estimated_cost, duration_ms, is_successful, error_message, client_type
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetRecentUsedApps :many
|
||||
SELECT DISTINCT ON (a.id) a.*, c.name as category_name, l.created_at as last_used_at
|
||||
FROM app_usage_logs l
|
||||
JOIN applications a ON l.app_id = a.id
|
||||
LEFT JOIN categories c ON a.category_id = c.id
|
||||
WHERE l.user_id = $1
|
||||
ORDER BY a.id, l.created_at DESC
|
||||
LIMIT $2;
|
||||
|
||||
-- name: GetUserStats :one
|
||||
SELECT
|
||||
COUNT(*) as total_conversations,
|
||||
COALESCE(SUM(total_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(estimated_cost), 0) as total_cost
|
||||
FROM app_usage_logs
|
||||
WHERE user_id = $1
|
||||
AND created_at >= $2;
|
||||
|
||||
-- name: GetOverviewStats :one
|
||||
SELECT
|
||||
(SELECT COUNT(*) FROM users WHERE status = 'active') as total_users,
|
||||
(SELECT COUNT(*) FROM applications WHERE status = 'approved') as total_apps,
|
||||
(SELECT COUNT(DISTINCT user_id) FROM app_usage_logs WHERE created_at >= $1) as active_users,
|
||||
(SELECT COUNT(*) FROM app_usage_logs WHERE created_at >= $1) as total_conversations,
|
||||
(SELECT COALESCE(SUM(total_tokens), 0) FROM app_usage_logs WHERE created_at >= $2) as monthly_tokens,
|
||||
(SELECT COALESCE(SUM(estimated_cost), 0) FROM app_usage_logs WHERE created_at >= $2) as monthly_cost;
|
||||
@@ -0,0 +1,52 @@
|
||||
-- name: GetUserByID :one
|
||||
SELECT * FROM users WHERE id = $1;
|
||||
|
||||
-- name: GetUserByEmail :one
|
||||
SELECT * FROM users WHERE email = $1;
|
||||
|
||||
-- name: GetUserByEmployeeID :one
|
||||
SELECT * FROM users WHERE employee_id = $1;
|
||||
|
||||
-- name: CreateUser :one
|
||||
INSERT INTO users (name, email, password_hash, phone, avatar_url, role, status, sso_provider, sso_external_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateUserProfile :one
|
||||
UPDATE users
|
||||
SET name = COALESCE(sqlc.narg('name'), name),
|
||||
phone = COALESCE(sqlc.narg('phone'), phone),
|
||||
avatar_url = COALESCE(sqlc.narg('avatar_url'), avatar_url)
|
||||
WHERE id = $1
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateUserRole :exec
|
||||
UPDATE users SET role = $2 WHERE id = $1;
|
||||
|
||||
-- name: UpdateUserStatus :exec
|
||||
UPDATE users SET status = $2 WHERE id = $1;
|
||||
|
||||
-- name: UpdateUserLogin :exec
|
||||
UPDATE users
|
||||
SET last_login_at = NOW(), login_count = login_count + 1
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: ListUsers :many
|
||||
SELECT * FROM users
|
||||
WHERE (sqlc.narg('role')::VARCHAR IS NULL OR role = sqlc.narg('role'))
|
||||
AND (sqlc.narg('status')::VARCHAR IS NULL OR status = sqlc.narg('status'))
|
||||
AND (sqlc.narg('search')::VARCHAR IS NULL
|
||||
OR name ILIKE '%' || sqlc.narg('search') || '%'
|
||||
OR email ILIKE '%' || sqlc.narg('search') || '%'
|
||||
OR employee_id ILIKE '%' || sqlc.narg('search') || '%')
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $1 OFFSET $2;
|
||||
|
||||
-- name: CountUsers :one
|
||||
SELECT COUNT(*) FROM users
|
||||
WHERE (sqlc.narg('role')::VARCHAR IS NULL OR role = sqlc.narg('role'))
|
||||
AND (sqlc.narg('status')::VARCHAR IS NULL OR status = sqlc.narg('status'))
|
||||
AND (sqlc.narg('search')::VARCHAR IS NULL
|
||||
OR name ILIKE '%' || sqlc.narg('search') || '%'
|
||||
OR email ILIKE '%' || sqlc.narg('search') || '%'
|
||||
OR employee_id ILIKE '%' || sqlc.narg('search') || '%');
|
||||
@@ -0,0 +1,74 @@
|
||||
package dify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(baseURL string) *Client {
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) do(ctx context.Context, method, path, apiKey string, body any) (*http.Response, error) {
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
bodyReader = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, bodyReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execute request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
defer resp.Body.Close()
|
||||
var apiErr APIError
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
|
||||
return nil, fmt.Errorf("dify API error (status %d)", resp.StatusCode)
|
||||
}
|
||||
apiErr.Status = resp.StatusCode
|
||||
return nil, &apiErr
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) doJSON(ctx context.Context, method, path, apiKey string, body any, result any) error {
|
||||
resp, err := c.do(ctx, method, path, apiKey, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if result != nil {
|
||||
return json.NewDecoder(resp.Body).Decode(result)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package dify
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChatStream sends a chat message and returns a reader for SSE events.
|
||||
// Caller is responsible for closing the returned io.ReadCloser.
|
||||
func (c *Client) ChatStream(ctx context.Context, apiKey string, req *ChatRequest) (io.ReadCloser, error) {
|
||||
req.ResponseMode = "streaming"
|
||||
|
||||
resp, err := c.do(ctx, "POST", "/chat-messages", apiKey, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
// ChatBlocking sends a chat message and waits for the complete response.
|
||||
func (c *Client) ChatBlocking(ctx context.Context, apiKey string, req *ChatRequest) (*ChatStreamEvent, error) {
|
||||
req.ResponseMode = "blocking"
|
||||
|
||||
var result ChatStreamEvent
|
||||
if err := c.doJSON(ctx, "POST", "/chat-messages", apiKey, req, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ParseSSEStream parses a Dify SSE stream and calls handler for each event.
|
||||
func ParseSSEStream(reader io.Reader, handler func(event ChatStreamEvent) error) error {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var event ChatStreamEvent
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := handler(event); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// ListConversations returns the user's conversation list for an app.
|
||||
func (c *Client) ListConversations(ctx context.Context, apiKey, user string, limit int, firstID string) (*ConversationListResponse, error) {
|
||||
path := fmt.Sprintf("/conversations?user=%s&limit=%d", user, limit)
|
||||
if firstID != "" {
|
||||
path += "&first_id=" + firstID
|
||||
}
|
||||
|
||||
var result ConversationListResponse
|
||||
if err := c.doJSON(ctx, "GET", path, apiKey, nil, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ListMessages returns messages in a conversation.
|
||||
func (c *Client) ListMessages(ctx context.Context, apiKey, user, conversationID string, limit int, firstID string) (*MessageListResponse, error) {
|
||||
path := fmt.Sprintf("/messages?user=%s&conversation_id=%s&limit=%d", user, conversationID, limit)
|
||||
if firstID != "" {
|
||||
path += "&first_id=" + firstID
|
||||
}
|
||||
|
||||
var result MessageListResponse
|
||||
if err := c.doJSON(ctx, "GET", path, apiKey, nil, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// DeleteConversation deletes a conversation.
|
||||
func (c *Client) DeleteConversation(ctx context.Context, apiKey, user, conversationID string) error {
|
||||
body := map[string]string{"user": user}
|
||||
return c.doJSON(ctx, "DELETE", "/conversations/"+conversationID, apiKey, body, nil)
|
||||
}
|
||||
|
||||
// SubmitFeedback submits feedback for a message.
|
||||
func (c *Client) SubmitFeedback(ctx context.Context, apiKey, messageID string, req *FeedbackRequest) error {
|
||||
return c.doJSON(ctx, "POST", "/messages/"+messageID+"/feedbacks", apiKey, req, nil)
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package dify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// CreateDataset creates a new knowledge base (dataset) in Dify.
|
||||
func (c *Client) CreateDataset(ctx context.Context, apiKey string, req *DatasetCreateRequest) (*Dataset, error) {
|
||||
var result Dataset
|
||||
if err := c.doJSON(ctx, "POST", "/datasets", apiKey, req, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// DeleteDataset deletes a knowledge base.
|
||||
func (c *Client) DeleteDataset(ctx context.Context, apiKey, datasetID string) error {
|
||||
return c.doJSON(ctx, "DELETE", "/datasets/"+datasetID, apiKey, nil, nil)
|
||||
}
|
||||
|
||||
// UploadDocument uploads a file to a dataset for indexing.
|
||||
func (c *Client) UploadDocument(ctx context.Context, apiKey, datasetID string, filename string, fileReader io.Reader) (*DocumentIndexingStatus, error) {
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create form file: %w", err)
|
||||
}
|
||||
if _, err := io.Copy(part, fileReader); err != nil {
|
||||
return nil, fmt.Errorf("copy file: %w", err)
|
||||
}
|
||||
|
||||
// indexing mode
|
||||
if err := writer.WriteField("indexing_technique", "high_quality"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := writer.WriteField("process_rule", `{"mode": "automatic"}`); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
path := fmt.Sprintf("/datasets/%s/document/create_by_file", datasetID)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+path, &buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("upload failed (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Dify returns document info in response
|
||||
var result struct {
|
||||
Document DocumentIndexingStatus `json:"document"`
|
||||
}
|
||||
if err := decodeJSON(resp.Body, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result.Document, nil
|
||||
}
|
||||
|
||||
// DeleteDocument deletes a document from a dataset.
|
||||
func (c *Client) DeleteDocument(ctx context.Context, apiKey, datasetID, documentID string) error {
|
||||
path := fmt.Sprintf("/datasets/%s/documents/%s", datasetID, documentID)
|
||||
return c.doJSON(ctx, "DELETE", path, apiKey, nil, nil)
|
||||
}
|
||||
|
||||
// GetDocumentIndexingStatus checks the indexing status of a document.
|
||||
func (c *Client) GetDocumentIndexingStatus(ctx context.Context, apiKey, datasetID, batch string) (*DocumentIndexingStatus, error) {
|
||||
path := fmt.Sprintf("/datasets/%s/documents/%s/indexing-status", datasetID, batch)
|
||||
var result DocumentIndexingStatus
|
||||
if err := c.doJSON(ctx, "GET", path, apiKey, nil, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func decodeJSON(r io.Reader, v any) error {
|
||||
return json.NewDecoder(r).Decode(v)
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package dify
|
||||
|
||||
import "time"
|
||||
|
||||
// --- Request types ---
|
||||
|
||||
type ChatRequest struct {
|
||||
Query string `json:"query"`
|
||||
Inputs map[string]any `json:"inputs,omitempty"`
|
||||
ConversationID string `json:"conversation_id,omitempty"`
|
||||
User string `json:"user"`
|
||||
ResponseMode string `json:"response_mode"`
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Inputs map[string]any `json:"inputs"`
|
||||
User string `json:"user"`
|
||||
ResponseMode string `json:"response_mode"`
|
||||
}
|
||||
|
||||
type FeedbackRequest struct {
|
||||
Rating string `json:"rating"` // "like", "dislike", null
|
||||
User string `json:"user"`
|
||||
}
|
||||
|
||||
// --- Response types ---
|
||||
|
||||
type ChatStreamEvent struct {
|
||||
Event string `json:"event"`
|
||||
TaskID string `json:"task_id"`
|
||||
MessageID string `json:"message_id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
Answer string `json:"answer"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
type MessageEndMetadata struct {
|
||||
Usage TokenUsage `json:"usage"`
|
||||
}
|
||||
|
||||
type TokenUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Conversation struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Inputs any `json:"inputs"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type ConversationListResponse struct {
|
||||
Data []Conversation `json:"data"`
|
||||
HasMore bool `json:"has_more"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
Query string `json:"query"`
|
||||
Answer string `json:"answer"`
|
||||
Feedback *Feedback `json:"feedback"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Inputs map[string]any `json:"inputs"`
|
||||
}
|
||||
|
||||
type Feedback struct {
|
||||
Rating string `json:"rating"`
|
||||
}
|
||||
|
||||
type MessageListResponse struct {
|
||||
Data []Message `json:"data"`
|
||||
HasMore bool `json:"has_more"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
// --- Dataset/Knowledge types ---
|
||||
|
||||
type DatasetCreateRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
type Dataset struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
DocCount int `json:"document_count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type DocumentIndexingStatus struct {
|
||||
ID string `json:"id"`
|
||||
IndexingStatus string `json:"indexing_status"`
|
||||
ProcessingStart string `json:"processing_started_at"`
|
||||
CompletedAt string `json:"completed_at"`
|
||||
}
|
||||
|
||||
// --- Error ---
|
||||
|
||||
type APIError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
func (e *APIError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
// Package embedding 提供文本向量化服务,支持 OpenAI 兼容的 Embedding API
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config embedding 服务配置
|
||||
type Config struct {
|
||||
APIKey string // API 密钥
|
||||
BaseURL string // API 基础 URL(OpenAI 兼容格式)
|
||||
Model string // 模型名称
|
||||
Dimensions int // 向量维度
|
||||
}
|
||||
|
||||
// Client embedding 客户端
|
||||
type Client struct {
|
||||
cfg Config
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewClient 创建 embedding 客户端
|
||||
func NewClient(cfg Config) *Client {
|
||||
if cfg.BaseURL == "" {
|
||||
cfg.BaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
}
|
||||
if cfg.Model == "" {
|
||||
cfg.Model = "text-embedding-v3"
|
||||
}
|
||||
if cfg.Dimensions == 0 {
|
||||
cfg.Dimensions = 1024
|
||||
}
|
||||
return &Client{
|
||||
cfg: cfg,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// embeddingRequest OpenAI 兼容的 embedding 请求
|
||||
type embeddingRequest struct {
|
||||
Input interface{} `json:"input"`
|
||||
Model string `json:"model"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
}
|
||||
|
||||
// embeddingResponse OpenAI 兼容的 embedding 响应
|
||||
type embeddingResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
} `json:"data"`
|
||||
Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
// GetEmbedding 获取单条文本的向量嵌入
|
||||
func (c *Client) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
|
||||
if c.cfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("embedding API key not configured")
|
||||
}
|
||||
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return nil, fmt.Errorf("empty text")
|
||||
}
|
||||
// 截断过长文本
|
||||
if len([]rune(text)) > 8000 {
|
||||
text = string([]rune(text)[:8000])
|
||||
}
|
||||
|
||||
req := embeddingRequest{
|
||||
Input: text,
|
||||
Model: c.cfg.Model,
|
||||
Dimensions: c.cfg.Dimensions,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := strings.TrimRight(c.cfg.BaseURL, "/") + "/embeddings"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embedding request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("embedding error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var result embeddingResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result.Data) == 0 {
|
||||
return nil, fmt.Errorf("no embedding data returned")
|
||||
}
|
||||
return result.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
// GetEmbeddingBatch 批量获取文本向量嵌入
|
||||
func (c *Client) GetEmbeddingBatch(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
results := make([][]float32, len(texts))
|
||||
for i, text := range texts {
|
||||
emb, err := c.GetEmbedding(ctx, text)
|
||||
if err != nil {
|
||||
results[i] = nil
|
||||
continue
|
||||
}
|
||||
results[i] = emb
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// IsConfigured 检查 embedding 服务是否已配置
|
||||
func (c *Client) IsConfigured() bool {
|
||||
return c.cfg.APIKey != ""
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type AnthropicProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewAnthropicProvider(apiKey, baseURL, defaultModel string) *AnthropicProvider {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
if defaultModel == "" {
|
||||
defaultModel = "claude-sonnet-4-20250514"
|
||||
}
|
||||
return &AnthropicProvider{
|
||||
apiKey: apiKey,
|
||||
baseURL: baseURL,
|
||||
model: defaultModel,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AnthropicProvider) Name() string { return "anthropic" }
|
||||
|
||||
type anthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []anthropicMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type anthropicMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type anthropicResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
func (p *AnthropicProvider) ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
|
||||
system, msgs := extractSystemMessage(req.Messages)
|
||||
maxTokens := req.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
maxTokens = 4096
|
||||
}
|
||||
|
||||
aReq := anthropicRequest{
|
||||
Model: model,
|
||||
Messages: toAnthropicMessages(msgs),
|
||||
System: system,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: req.Temperature,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(aReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/v1/messages", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("x-api-key", p.apiKey)
|
||||
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("anthropic request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("anthropic error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var aResp anthropicResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&aResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
content := ""
|
||||
for _, c := range aResp.Content {
|
||||
if c.Type == "text" {
|
||||
content += c.Text
|
||||
}
|
||||
}
|
||||
|
||||
return &ChatResponse{
|
||||
Content: content,
|
||||
Model: aResp.Model,
|
||||
PromptTokens: aResp.Usage.InputTokens,
|
||||
CompletionTokens: aResp.Usage.OutputTokens,
|
||||
TotalTokens: aResp.Usage.InputTokens + aResp.Usage.OutputTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *AnthropicProvider) ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
|
||||
system, msgs := extractSystemMessage(req.Messages)
|
||||
maxTokens := req.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
maxTokens = 4096
|
||||
}
|
||||
|
||||
aReq := anthropicRequest{
|
||||
Model: model,
|
||||
Messages: toAnthropicMessages(msgs),
|
||||
System: system,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: req.Temperature,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(aReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/v1/messages", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("x-api-key", p.apiKey)
|
||||
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("anthropic stream request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("anthropic error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func extractSystemMessage(msgs []Message) (string, []Message) {
|
||||
system := ""
|
||||
var filtered []Message
|
||||
for _, m := range msgs {
|
||||
if m.Role == RoleSystem {
|
||||
system = m.Content
|
||||
} else {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return system, filtered
|
||||
}
|
||||
|
||||
func toAnthropicMessages(msgs []Message) []anthropicMessage {
|
||||
out := make([]anthropicMessage, len(msgs))
|
||||
for i, m := range msgs {
|
||||
out[i] = anthropicMessage{Role: string(m.Role), Content: m.Content}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Manager manages multiple LLM providers and routes requests.
|
||||
type Manager struct {
|
||||
providers map[string]Provider
|
||||
fallback string
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
providers: make(map[string]Provider),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Register(name string, provider Provider) {
|
||||
m.providers[name] = provider
|
||||
if m.fallback == "" {
|
||||
m.fallback = name
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) SetFallback(name string) {
|
||||
m.fallback = name
|
||||
}
|
||||
|
||||
func (m *Manager) GetProvider(name string) (Provider, error) {
|
||||
if p, ok := m.providers[name]; ok {
|
||||
return p, nil
|
||||
}
|
||||
if p, ok := m.providers[m.fallback]; ok {
|
||||
return p, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no provider found: %s", name)
|
||||
}
|
||||
|
||||
// Chat performs a blocking chat completion using the specified provider.
|
||||
func (m *Manager) Chat(ctx context.Context, providerName string, req *ChatRequest) (*ChatResponse, error) {
|
||||
provider, err := m.GetProvider(providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return provider.ChatCompletion(ctx, req)
|
||||
}
|
||||
|
||||
// ChatStream performs a streaming chat and returns the raw SSE body.
|
||||
func (m *Manager) ChatStream(ctx context.Context, providerName string, req *ChatRequest) (io.ReadCloser, error) {
|
||||
provider, err := m.GetProvider(providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Stream = true
|
||||
return provider.ChatStream(ctx, req)
|
||||
}
|
||||
|
||||
// StreamEvent represents a normalized SSE event for the frontend.
|
||||
type StreamEvent struct {
|
||||
Event string `json:"event"`
|
||||
Answer string `json:"answer,omitempty"`
|
||||
MessageID string `json:"message_id,omitempty"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// TransformOpenAIStream reads an OpenAI SSE stream and writes normalized events to the writer.
|
||||
func TransformOpenAIStream(reader io.Reader, write func(event StreamEvent)) error {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
||||
|
||||
var totalContent string
|
||||
var model string
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
write(StreamEvent{
|
||||
Event: "message_end",
|
||||
Usage: &Usage{
|
||||
TotalTokens: estimateTokens(totalContent),
|
||||
Model: model,
|
||||
},
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
var chunk struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
model = chunk.Model
|
||||
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||||
content := chunk.Choices[0].Delta.Content
|
||||
totalContent += content
|
||||
write(StreamEvent{
|
||||
Event: "message",
|
||||
Answer: content,
|
||||
MessageID: chunk.ID,
|
||||
})
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// TransformAnthropicStream reads an Anthropic SSE stream and writes normalized events.
|
||||
func TransformAnthropicStream(reader io.Reader, write func(event StreamEvent)) error {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
||||
|
||||
var model string
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if m, ok := msg["model"].(string); ok {
|
||||
model = m
|
||||
}
|
||||
}
|
||||
case "content_block_delta":
|
||||
if delta, ok := event["delta"].(map[string]any); ok {
|
||||
if text, ok := delta["text"].(string); ok {
|
||||
write(StreamEvent{
|
||||
Event: "message",
|
||||
Answer: text,
|
||||
})
|
||||
}
|
||||
}
|
||||
case "message_delta":
|
||||
if usage, ok := event["usage"].(map[string]any); ok {
|
||||
outputTokens := int(getFloat(usage, "output_tokens"))
|
||||
write(StreamEvent{
|
||||
Event: "message_end",
|
||||
Usage: &Usage{
|
||||
CompletionTokens: outputTokens,
|
||||
TotalTokens: outputTokens,
|
||||
Model: model,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func getFloat(m map[string]any, key string) float64 {
|
||||
if v, ok := m[key].(float64); ok {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func estimateTokens(text string) int {
|
||||
return len(text) / 4
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type OpenAIProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewOpenAIProvider(apiKey, baseURL, defaultModel string) *OpenAIProvider {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
if defaultModel == "" {
|
||||
defaultModel = "gpt-4o-mini"
|
||||
}
|
||||
return &OpenAIProvider{
|
||||
apiKey: apiKey,
|
||||
baseURL: baseURL,
|
||||
model: defaultModel,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) Name() string { return "openai" }
|
||||
|
||||
type openAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type openAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type openAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
|
||||
oaiReq := openAIRequest{
|
||||
Model: model,
|
||||
Messages: toOpenAIMessages(req.Messages),
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(oaiReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("openai error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var oaiResp openAIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
content := ""
|
||||
if len(oaiResp.Choices) > 0 {
|
||||
content = oaiResp.Choices[0].Message.Content
|
||||
}
|
||||
|
||||
return &ChatResponse{
|
||||
Content: content,
|
||||
Model: oaiResp.Model,
|
||||
PromptTokens: oaiResp.Usage.PromptTokens,
|
||||
CompletionTokens: oaiResp.Usage.CompletionTokens,
|
||||
TotalTokens: oaiResp.Usage.TotalTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
|
||||
oaiReq := openAIRequest{
|
||||
Model: model,
|
||||
Messages: toOpenAIMessages(req.Messages),
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(oaiReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai stream request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("openai error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func toOpenAIMessages(msgs []Message) []openAIMessage {
|
||||
out := make([]openAIMessage, len(msgs))
|
||||
for i, m := range msgs {
|
||||
out[i] = openAIMessage{Role: string(m.Role), Content: m.Content}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleSystem Role = "system"
|
||||
RoleUser Role = "user"
|
||||
RoleAssistant Role = "assistant"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role Role `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
ConversationID string `json:"conversation_id,omitempty"`
|
||||
}
|
||||
|
||||
type StreamDelta struct {
|
||||
Content string `json:"content"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
// Provider abstracts an LLM API (OpenAI, Anthropic, etc.)
|
||||
type Provider interface {
|
||||
ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
|
||||
ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error)
|
||||
Name() string
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
Type string `json:"type"`
|
||||
APIKey string `json:"api_key"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
}
|
||||
Reference in New Issue
Block a user