Files
GovAI/server/internal/handler/chat_llm.go
T
2026-06-15 23:48:37 +08:00

1171 lines
36 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/enterprise-ai-platform/server/internal/middleware"
"github.com/enterprise-ai-platform/server/internal/response"
"github.com/enterprise-ai-platform/server/pkg/embedding"
"github.com/enterprise-ai-platform/server/pkg/llm"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog/log"
)
type LLMChatHandler struct {
pool *pgxpool.Pool
manager *llm.Manager
provider string
rdb *redis.Client
workerURL string
embedder *embedding.Client
}
func NewLLMChatHandler(pool *pgxpool.Pool, manager *llm.Manager, defaultProvider string, rdb *redis.Client, workerURL string, embedder *embedding.Client) *LLMChatHandler {
return &LLMChatHandler{pool: pool, manager: manager, provider: defaultProvider, rdb: rdb, workerURL: workerURL, embedder: embedder}
}
type llmChatRequest struct {
Message string `json:"message"`
ConversationID string `json:"conversation_id,omitempty"`
}
type appCfg struct {
SystemPrompt string
Model string
Temp float64
MaxTok int
KnowledgeBaseID *string
AppType string
OrgID *string
AppName string
}
// sameOrgApp 同机构其他应用信息,用于超范围引导跳转
type sameOrgApp struct {
Name string
Slug string
}
func (h *LLMChatHandler) loadAppConfig(ctx context.Context, appID string) (*appCfg, error) {
var cfg appCfg
err := h.pool.QueryRow(ctx,
`SELECT COALESCE(app_config->>'system_prompt', ''),
COALESCE(app_config->>'model', ''),
COALESCE(temperature, 0.7),
COALESCE(max_tokens, 4096),
knowledge_base_id::text,
COALESCE(dify_app_type, ''),
org_id::text,
COALESCE(name, '')
FROM applications WHERE id = $1 AND status = 'approved'`, appID,
).Scan(&cfg.SystemPrompt, &cfg.Model, &cfg.Temp, &cfg.MaxTok, &cfg.KnowledgeBaseID, &cfg.AppType, &cfg.OrgID, &cfg.AppName)
if err != nil {
return nil, err
}
return &cfg, nil
}
// loadSameOrgApps 查询同机构内其他应用(排除当前应用),用于超范围引导
func (h *LLMChatHandler) loadSameOrgApps(ctx context.Context, orgID *string, currentAppID string) []sameOrgApp {
if orgID == nil || *orgID == "" {
return nil
}
rows, err := h.pool.Query(ctx,
`SELECT name, slug FROM applications
WHERE org_id = $1 AND id != $2 AND status = 'approved'
ORDER BY name`, *orgID, currentAppID)
if err != nil {
return nil
}
defer rows.Close()
var apps []sameOrgApp
for rows.Next() {
var a sameOrgApp
if rows.Scan(&a.Name, &a.Slug) == nil {
apps = append(apps, a)
}
}
return apps
}
// cleanQueryForSearch 去除标点符号和常见语气词,提取有效关键词
func cleanQueryForSearch(query string) []string {
// 去除中英文标点和常见语气助词
cleaned := query
for _, ch := range []string{
"", "?", "", "!", "。", ".", "", ",",
"、", "", ":", "", ";", "", "", "(", ")",
"《", "》", "【", "】", "\n", "\t",
"\u201c", "\u201d", "\u2018", "\u2019",
} {
cleaned = strings.ReplaceAll(cleaned, ch, " ")
}
cleaned = strings.TrimSpace(cleaned)
// 去除常见语气词和停用词
stopWords := []string{
"是什么", "是啥", "有哪些", "怎么样", "怎么办",
"的", "了", "在", "和", "与", "或", "等", "中", "为", "被",
"关于", "请问", "什么", "如何", "怎么", "哪些", "哪个",
}
for _, s := range stopWords {
if strings.HasSuffix(cleaned, s) && len([]rune(cleaned)) > len([]rune(s))+2 {
cleaned = strings.TrimSuffix(cleaned, s)
}
}
// 按空格拆分
tokens := strings.Fields(cleaned)
// 对中文长token按常见法律/政策术语边界拆分
var result []string
for _, tok := range tokens {
runes := []rune(tok)
if len(runes) < 2 {
continue
}
result = append(result, tok)
// 长关键词用滑动窗口拆分为2-4字的短语,提高检索召回率
if len(runes) >= 4 {
for size := 2; size <= 4 && size <= len(runes); size++ {
for i := 0; i+size <= len(runes); i++ {
sub := string(runes[i : i+size])
// 去重且过滤停用词
isDup := false
for _, r := range result {
if r == sub {
isDup = true
break
}
}
isStop := false
for _, sw := range stopWords {
if sub == sw {
isStop = true
break
}
}
if !isDup && !isStop {
result = append(result, sub)
}
}
}
}
}
// 限制关键词数量,避免查询过于复杂
if len(result) > 8 {
result = result[:8]
}
if len(result) == 0 && len([]rune(cleaned)) >= 2 {
result = append(result, cleaned)
}
return result
}
func (h *LLMChatHandler) retrieveKnowledge(ctx context.Context, kbID, query string, limit int) (string, error) {
// 混合检索策略:优先向量搜索,降级到关键词搜索
var parts []string
// 1. 尝试向量语义搜索(基于 knowledge_chunks 表)
if h.embedder != nil && h.embedder.IsConfigured() {
vectorResults := h.vectorSearch(ctx, kbID, query, limit)
if len(vectorResults) > 0 {
parts = append(parts, vectorResults...)
log.Debug().Int("vector_results", len(vectorResults)).Msg("vector search completed")
}
}
// 2. 关键词搜索补充(从 knowledge_chunks 或 knowledge_documents
keywordResults := h.keywordSearch(ctx, kbID, query, limit)
for _, kr := range keywordResults {
// 去重:检查是否已在向量结果中
duplicate := false
for _, existing := range parts {
if existing == kr {
duplicate = true
break
}
}
if !duplicate {
parts = append(parts, kr)
}
}
// 限制总结果数
if len(parts) > limit {
parts = parts[:limit]
}
if len(parts) == 0 {
return "", nil
}
return strings.Join(parts, "\n\n---\n\n"), nil
}
// vectorSearch 向量语义搜索(基于 knowledge_chunks + pgvector
func (h *LLMChatHandler) vectorSearch(ctx context.Context, kbID, query string, limit int) []string {
queryEmbedding, err := h.embedder.GetEmbedding(ctx, query)
if err != nil {
log.Warn().Err(err).Msg("query embedding failed, falling back to keyword search")
return nil
}
vecStr := float32SliceToVectorStr(queryEmbedding)
rows, err := h.pool.Query(ctx, `
SELECT kc.content, kd.name,
1 - (kc.embedding <=> $2::vector) AS similarity
FROM knowledge_chunks kc
JOIN knowledge_documents kd ON kc.doc_id = kd.id
WHERE kc.kb_id = $1
AND kc.embedding IS NOT NULL
AND 1 - (kc.embedding <=> $2::vector) > 0.3
ORDER BY kc.embedding <=> $2::vector
LIMIT $3`,
kbID, vecStr, limit)
if err != nil {
log.Warn().Err(err).Msg("vector search query failed")
return nil
}
defer rows.Close()
var results []string
for rows.Next() {
var content, docName string
var similarity float64
if err := rows.Scan(&content, &docName, &similarity); err != nil {
continue
}
trimmed := content
if len([]rune(trimmed)) > 2000 {
trimmed = string([]rune(trimmed)[:2000]) + "..."
}
results = append(results, fmt.Sprintf("【%s · 相似度%.0f%%】\n%s", docName, similarity*100, trimmed))
}
return results
}
// keywordSearch 关键词搜索(降级方案,搜索 chunks 和 documents
func (h *LLMChatHandler) keywordSearch(ctx context.Context, kbID, query string, limit int) []string {
keywords := cleanQueryForSearch(query)
if len(keywords) == 0 {
return nil
}
// 先搜索 chunks 表
var conditions []string
var args []any
args = append(args, kbID) // $1
for _, kw := range keywords {
idx := len(args) + 1
placeholder := fmt.Sprintf("$%d", idx)
args = append(args, "%"+kw+"%")
conditions = append(conditions, fmt.Sprintf("kc.content ILIKE %s", placeholder))
}
limitIdx := len(args) + 1
args = append(args, limit)
sql := fmt.Sprintf(`
SELECT kc.content, kd.name
FROM knowledge_chunks kc
JOIN knowledge_documents kd ON kc.doc_id = kd.id
WHERE kc.kb_id = $1
AND (%s)
ORDER BY kc.created_at DESC
LIMIT $%d`, strings.Join(conditions, " OR "), limitIdx)
rows, err := h.pool.Query(ctx, sql, args...)
if err == nil {
defer rows.Close()
var results []string
for rows.Next() {
var content, docName string
if err := rows.Scan(&content, &docName); err != nil {
continue
}
trimmed := content
if len([]rune(trimmed)) > 2000 {
trimmed = string([]rune(trimmed)[:2000]) + "..."
}
results = append(results, fmt.Sprintf("【%s】\n%s", docName, trimmed))
}
if len(results) > 0 {
return results
}
}
// 降级:搜索 knowledge_documents 原文(没有分片的旧数据)
args2 := []any{kbID}
var conditions2 []string
for _, kw := range keywords {
idx := len(args2) + 1
placeholder := fmt.Sprintf("$%d", idx)
args2 = append(args2, "%"+kw+"%")
conditions2 = append(conditions2, fmt.Sprintf("(name ILIKE %s OR content ILIKE %s)", placeholder, placeholder))
}
limitIdx2 := len(args2) + 1
args2 = append(args2, limit)
sql2 := fmt.Sprintf(`
SELECT name, content
FROM knowledge_documents
WHERE kb_id = $1
AND content IS NOT NULL AND content != ''
AND (%s)
ORDER BY created_at DESC
LIMIT $%d`, strings.Join(conditions2, " OR "), limitIdx2)
rows2, err := h.pool.Query(ctx, sql2, args2...)
if err != nil {
return nil
}
defer rows2.Close()
var results []string
for rows2.Next() {
var name, content string
if err := rows2.Scan(&name, &content); err != nil {
continue
}
trimmed := content
if len([]rune(trimmed)) > 3000 {
trimmed = string([]rune(trimmed)[:3000]) + "..."
}
results = append(results, fmt.Sprintf("【%s】\n%s", name, trimmed))
}
return results
}
func (h *LLMChatHandler) loadConversationHistory(ctx context.Context, appID, userID, convID string, maxTurns int) []llm.Message {
rows, err := h.pool.Query(ctx, `
SELECT user_message, COALESCE(ai_response, '')
FROM app_usage_logs
WHERE app_id = $1 AND user_id = $2 AND conversation_id = $3
ORDER BY created_at ASC`, appID, userID, convID)
if err != nil {
return nil
}
defer rows.Close()
var history []llm.Message
for rows.Next() {
var userMsg, aiResp string
if err := rows.Scan(&userMsg, &aiResp); err != nil {
continue
}
if userMsg != "" {
history = append(history, llm.Message{Role: llm.RoleUser, Content: userMsg})
}
if aiResp != "" {
history = append(history, llm.Message{Role: llm.RoleAssistant, Content: aiResp})
}
}
if maxTurns > 0 && len(history) > maxTurns*2 {
history = history[len(history)-maxTurns*2:]
}
return history
}
func (h *LLMChatHandler) buildMessages(systemPrompt, knowledgeContext string, hasKB bool, history []llm.Message, userMessage string, sameOrgApps ...[]sameOrgApp) []llm.Message {
var msgs []llm.Message
finalSystem := systemPrompt
// 注入同机构应用路由表(用于超范围引导跳转)
if len(sameOrgApps) > 0 && len(sameOrgApps[0]) > 0 {
finalSystem += "\n\n## 超范围引导(必须遵守)\n\n"
finalSystem += "当用户的问题不在本应用的处理范围内时,你必须:\n"
finalSystem += "1. 明确告知用户该问题不在本应用处理范围内\n"
finalSystem += "2. 推荐本机构内更合适的应用,使用以下格式(系统会自动渲染为可点击的跳转链接):\n"
finalSystem += " [[推荐应用:应用名称:应用slug]]\n"
finalSystem += "3. 绝不可对不属于本应用职责的问题强行生成回答\n\n"
finalSystem += "本机构可用的应用列表:\n"
for _, app := range sameOrgApps[0] {
finalSystem += fmt.Sprintf("- %sslug: %s\n", app.Name, app.Slug)
}
finalSystem += "\n推荐示例:建议使用 [[推荐应用:法律咨询助手:legal-consult]] 来处理此类问题。\n"
finalSystem += "注意:如果用户的问题不属于本机构任何应用的范围(如需要联系其他政府部门),则直接用文字说明应联系的部门,不使用上述格式。\n"
}
// 通用红线规则:适用于所有应用
finalSystem += `
## 绝对红线(所有应用必须遵守)
1. **禁止编造事实**:不得虚构任何调查结果、检查记录、走访情况、证据材料等事实性内容。所有事实描述必须且只能来自用户提供的输入内容或知识库检索结果。如果用户未提供相关事实,应明确标注「(待补充)」或提示用户补充,绝不可凭空捏造。
2. **禁止虚构法规条文**:只能引用知识库中存在的法规原文或用户明确提供的法规信息,不得杜撰法条内容、编号或文件名称。如知识库中未检索到相关法规,应注明「(建议补充相关法规依据)」。
3. **管辖权与职责范围判断(必须首先执行)**:收到用户输入后,必须先判断该问题是否属于当前应用的职责范围。判断依据为上方的功能介绍和系统定位。
- **属于职责范围**:正常处理并生成回答。
- **不属于职责范围**:必须明确告知用户该问题不在本应用处理范围内,说明原因,并推荐合适的处理渠道或机构。常见分流指引:
- 消费纠纷/合同纠纷 → 市场监管部门(12315)或法院民事诉讼
- 劳动争议 → 劳动仲裁委员会
- 刑事案件 → 公安机关
- 民事侵权 → 法院民事诉讼
- 行政复议/行政诉讼 → 对应上级行政机关或法院
- 信访事项 → 对应信访部门
- 税务问题 → 税务机关
- 医疗纠纷 → 卫健部门或医调委
- **绝不可**:对不属于本应用职责的问题强行生成专业回答,这会误导用户。
`
if hasKB {
finalSystem += `
## 来源标注规则(必须严格遵守,每一条都不可省略)
本系统已接入知识库。你的回答**必须在正文中内联标注**信息来源,使用以下**双方括号**固定格式(系统会自动将其渲染为彩色徽章):
### 内联标注格式(正文中每个引用点都必须标注)
1. 引用知识库文献时,在引用内容后紧跟:[[知识库:文献名称]]
- 如有具体条款:[[知识库:文献名称:第X条]]
- 示例:"高新技术企业可享受15%优惠税率 [[知识库:高新技术企业认定管理办法:第四条]]"
2. 来自AI模型自身知识(非知识库内容)时,在该内容后紧跟:[[AI建议]]
- 示例:"该政策于2016年首次发布 [[AI建议]]"
- **重要:任何不是直接来自知识库原文的分析、解读、建议、补充说明,都必须标注 [[AI建议]]**
### 末尾来源汇总(必须附加)
回答末尾必须附加来源汇总块(用 blockquote 格式),要求完整引用原文:
> **来源说明**
>
> **知识库引用:**
> - 【文献名称1】:摘录知识库中被引用的原文段落(保留条款编号和原始表述,不要省略或改写)
> - 【文献名称2】:摘录对应原文...
> (如未引用知识库则写"无")
>
> **AI建议:**
> - 说明哪些内容来自AI自身知识的补充分析或解读
### 标注检查清单
- 正文中是否每个引用知识库的地方都标注了 [[知识库:文献名称]]?
- 正文中是否每个AI补充分析的地方都标注了 [[AI建议]]?
- 末尾是否有完整的来源汇总块?
重要:知识库引用部分必须逐条列出被引用的原文内容,不能只写文献名称,要把知识库中的原始段落完整摘录出来,便于用户核实溯源。
`
if knowledgeContext != "" {
finalSystem += "### 知识库检索结果\n\n以下是从知识库中检索到的相关文献,请优先基于这些内容回答:\n\n" + knowledgeContext
} else {
finalSystem += "### 知识库检索结果\n\n当前知识库中未检索到与用户问题直接相关的文献。请使用AI知识回答,并在每句标注 [[AI建议]]。\n"
}
} else {
finalSystem += `
## 来源标注规则(必须严格遵守)
你的回答内容全部来自AI模型的自身知识。请在回答末尾附加:
> **来源说明**
> - 以上内容为AI建议,仅供参考,请以官方文件和专业意见为准。
`
}
// 通用专业标准和输出质量要求
finalSystem += `
## 专业标准(最佳实践者角色)
你是本领域经验最丰富的专业人员,必须遵守:
1. **专业深度**:回答必须有专业深度,不泛泛而谈,使用本领域的专业术语
2. **逻辑结构**:按"问题分析 → 依据引用 → 结论建议"层层递进
3. **明确意见**:在合规前提下给出明确的意见和建议,不模棱两可
4. **实操可行**:建议必须具体、可执行,考虑实际操作可行性
5. **风险预判**:主动识别并提示潜在风险
6. **信息不足时**:主动向用户确认缺失信息,而非自行假设
## 输出质量标准
1. **结构化输出**:使用 Markdown 标题、列表、表格组织内容,禁止输出大段无格式纯文字
2. **重要信息高亮**:关键结论、风险提示、注意事项使用 **加粗** 标注
3. **法规引用格式**:统一为「《法规名》第X条第X款」格式
4. **完整性自检**:输出前自检是否遗漏关键要素
`
if finalSystem != "" {
msgs = append(msgs, llm.Message{Role: llm.RoleSystem, Content: finalSystem})
}
msgs = append(msgs, history...)
msgs = append(msgs, llm.Message{Role: llm.RoleUser, Content: userMessage})
return msgs
}
func (h *LLMChatHandler) Chat(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var req llmChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Message == "" {
response.BadRequest(w, "消息不能为空")
return
}
cfg, err := h.loadAppConfig(r.Context(), appID)
if err != nil {
response.NotFound(w, "应用不存在或未上架")
return
}
// PPT 生成应用走专用管线
if cfg.AppType == "ppt_generator" {
h.handlePPTChat(w, r, appID, userID.String(), req.Message, req.ConversationID)
return
}
hasKB := cfg.KnowledgeBaseID != nil && *cfg.KnowledgeBaseID != ""
var knowledgeCtx string
if hasKB {
knowledgeCtx, _ = h.retrieveKnowledge(r.Context(), *cfg.KnowledgeBaseID, req.Message, 3)
}
// 加载同机构应用列表,用于超范围引导跳转
orgApps := h.loadSameOrgApps(r.Context(), cfg.OrgID, appID)
convID := req.ConversationID
isNewConv := convID == ""
if isNewConv {
convID = uuid.New().String()
}
var history []llm.Message
if !isNewConv {
history = h.loadConversationHistory(r.Context(), appID, userID.String(), convID, 10)
}
llmReq := &llm.ChatRequest{
Model: cfg.Model,
Messages: h.buildMessages(cfg.SystemPrompt, knowledgeCtx, hasKB, history, req.Message, orgApps),
Temperature: cfg.Temp,
MaxTokens: cfg.MaxTok,
Stream: true,
}
startTime := time.Now()
body, err := h.manager.ChatStream(r.Context(), h.provider, llmReq)
if err != nil {
response.Error(w, http.StatusBadGateway, 50202, "模型服务不可用: "+err.Error())
return
}
defer body.Close()
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
flusher, ok := w.(http.Flusher)
if !ok {
response.InternalError(w, "Streaming not supported")
return
}
msgID := uuid.New().String()
var totalTokens int
var modelName string
var fullResponse strings.Builder
firstEvent := map[string]string{"conversation_id": convID, "message_id": msgID}
data, _ := json.Marshal(firstEvent)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
transform := llm.TransformOpenAIStream
if h.provider == "anthropic" {
transform = llm.TransformAnthropicStream
}
_ = transform(body, func(event llm.StreamEvent) {
if event.MessageID == "" {
event.MessageID = msgID
}
if event.Answer != "" {
fullResponse.WriteString(event.Answer)
}
if event.Usage != nil {
totalTokens = event.Usage.TotalTokens
modelName = event.Usage.Model
}
data, _ := json.Marshal(event)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
})
fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
duration := time.Since(startTime).Milliseconds()
go h.recordUsage(appID, userID.String(), convID, req.Message, fullResponse.String(), totalTokens, modelName, duration)
if isNewConv {
go h.generateConversationName(appID, userID.String(), convID, req.Message)
}
}
func (h *LLMChatHandler) Completion(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var req llmChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效的请求格式")
return
}
if req.Message == "" {
response.BadRequest(w, "消息不能为空")
return
}
cfg, err := h.loadAppConfig(r.Context(), appID)
if err != nil {
response.NotFound(w, "应用不存在或未上架")
return
}
hasKB := cfg.KnowledgeBaseID != nil && *cfg.KnowledgeBaseID != ""
var knowledgeCtx string
if hasKB {
knowledgeCtx, _ = h.retrieveKnowledge(r.Context(), *cfg.KnowledgeBaseID, req.Message, 3)
}
// 加载同机构应用列表,用于超范围引导跳转
orgApps := h.loadSameOrgApps(r.Context(), cfg.OrgID, appID)
llmReq := &llm.ChatRequest{
Model: cfg.Model,
Messages: h.buildMessages(cfg.SystemPrompt, knowledgeCtx, hasKB, nil, req.Message, orgApps),
Temperature: cfg.Temp,
MaxTokens: cfg.MaxTok,
Stream: true,
}
startTime := time.Now()
convID := uuid.New().String()
body, err := h.manager.ChatStream(r.Context(), h.provider, llmReq)
if err != nil {
response.Error(w, http.StatusBadGateway, 50202, "模型服务不可用: "+err.Error())
return
}
defer body.Close()
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
flusher, ok := w.(http.Flusher)
if !ok {
response.InternalError(w, "Streaming not supported")
return
}
msgID := uuid.New().String()
var totalTokens int
var modelName string
var fullResponse strings.Builder
firstEvent := map[string]string{"conversation_id": convID, "message_id": msgID}
data, _ := json.Marshal(firstEvent)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
transform := llm.TransformOpenAIStream
if h.provider == "anthropic" {
transform = llm.TransformAnthropicStream
}
_ = transform(body, func(event llm.StreamEvent) {
if event.MessageID == "" {
event.MessageID = msgID
}
if event.Answer != "" {
fullResponse.WriteString(event.Answer)
}
if event.Usage != nil {
totalTokens = event.Usage.TotalTokens
modelName = event.Usage.Model
}
data, _ := json.Marshal(event)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
})
fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
duration := time.Since(startTime).Milliseconds()
go h.recordUsage(appID, userID.String(), convID, req.Message, fullResponse.String(), totalTokens, modelName, duration)
go h.generateConversationName(appID, userID.String(), convID, req.Message)
}
func (h *LLMChatHandler) Conversations(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
rows, err := h.pool.Query(r.Context(), `
SELECT conversation_id,
MIN(created_at) AS first_at,
MAX(created_at) AS last_at,
COUNT(*) AS msg_count,
(SELECT COALESCE(user_message, '') FROM app_usage_logs u2
WHERE u2.conversation_id = u.conversation_id AND u2.user_message != ''
ORDER BY u2.created_at LIMIT 1) AS first_msg
FROM app_usage_logs u
WHERE app_id = $1 AND user_id = $2 AND conversation_id IS NOT NULL
GROUP BY conversation_id
ORDER BY last_at DESC LIMIT 50`, appID, userID)
if err != nil {
response.InternalError(w, "查询对话列表失败")
return
}
defer rows.Close()
customNames := make(map[string]string)
nameRows, err := h.pool.Query(r.Context(),
`SELECT conversation_id, name FROM conversation_names WHERE app_id = $1 AND user_id = $2`,
appID, userID)
if err == nil {
defer nameRows.Close()
for nameRows.Next() {
var cid, n string
if nameRows.Scan(&cid, &n) == nil {
customNames[cid] = n
}
}
}
var convs []map[string]any
for rows.Next() {
var convID string
var firstAt, lastAt time.Time
var msgCount int
var firstMsg *string
if err := rows.Scan(&convID, &firstAt, &lastAt, &msgCount, &firstMsg); err != nil {
continue
}
name := "新对话"
if cn, ok := customNames[convID]; ok && cn != "" {
name = cn
} else if firstMsg != nil && *firstMsg != "" {
name = *firstMsg
runes := []rune(name)
if len(runes) > 30 {
name = string(runes[:30]) + "..."
}
}
convs = append(convs, map[string]any{
"id": convID, "name": name, "created_at": firstAt, "updated_at": lastAt, "msg_count": msgCount,
})
}
if convs == nil {
convs = []map[string]any{}
}
response.JSON(w, http.StatusOK, map[string]any{"data": convs})
}
func (h *LLMChatHandler) RenameConversation(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
convID := chi.URLParam(r, "convId")
userID := middleware.GetUserID(r.Context())
var req struct {
Name string `json:"name"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
response.BadRequest(w, "无效请求格式")
return
}
name := strings.TrimSpace(req.Name)
if name == "" {
response.BadRequest(w, "名称不能为空")
return
}
runes := []rune(name)
if len(runes) > 50 {
name = string(runes[:50])
}
_, err := h.pool.Exec(r.Context(), `
INSERT INTO conversation_names (app_id, user_id, conversation_id, name, updated_at)
VALUES ($1, $2, $3, $4, now())
ON CONFLICT (app_id, user_id, conversation_id)
DO UPDATE SET name = EXCLUDED.name, updated_at = now()`,
appID, userID, convID, name)
if err != nil {
response.InternalError(w, "重命名失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已重命名", "name": name})
}
func (h *LLMChatHandler) Messages(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
convID := chi.URLParam(r, "convId")
userID := middleware.GetUserID(r.Context())
rows, err := h.pool.Query(r.Context(), `
SELECT user_message, COALESCE(ai_response, ''), created_at
FROM app_usage_logs
WHERE app_id = $1 AND user_id = $2 AND conversation_id = $3
ORDER BY created_at ASC`, appID, userID, convID)
if err != nil {
response.InternalError(w, "查询消息失败")
return
}
defer rows.Close()
var msgs []map[string]any
for rows.Next() {
var userMsg, aiResp string
var createdAt time.Time
if err := rows.Scan(&userMsg, &aiResp, &createdAt); err != nil {
continue
}
if userMsg != "" {
msgs = append(msgs, map[string]any{
"id": fmt.Sprintf("u-%d", createdAt.UnixMilli()), "role": "user", "content": userMsg, "created_at": createdAt,
})
}
if aiResp != "" {
msgs = append(msgs, map[string]any{
"id": fmt.Sprintf("a-%d", createdAt.UnixMilli()), "role": "assistant", "content": aiResp, "created_at": createdAt,
})
}
}
if msgs == nil {
msgs = []map[string]any{}
}
response.JSON(w, http.StatusOK, map[string]any{"data": msgs})
}
func (h *LLMChatHandler) DeleteConversation(w http.ResponseWriter, r *http.Request) {
convID := chi.URLParam(r, "convId")
userID := middleware.GetUserID(r.Context())
appID := chi.URLParam(r, "id")
_, err := h.pool.Exec(r.Context(),
`DELETE FROM app_usage_logs WHERE conversation_id = $1 AND user_id = $2 AND app_id = $3`,
convID, userID, appID)
if err != nil {
response.InternalError(w, "删除对话失败")
return
}
response.JSON(w, http.StatusOK, map[string]string{"message": "已删除"})
}
func (h *LLMChatHandler) BatchDeleteConversations(w http.ResponseWriter, r *http.Request) {
appID := chi.URLParam(r, "id")
userID := middleware.GetUserID(r.Context())
var req struct {
ConversationIDs []string `json:"conversation_ids"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || len(req.ConversationIDs) == 0 {
response.BadRequest(w, "请提供要删除的对话ID列表")
return
}
ids := make([]any, len(req.ConversationIDs))
placeholders := make([]string, len(req.ConversationIDs))
for i, id := range req.ConversationIDs {
ids[i] = id
placeholders[i] = fmt.Sprintf("$%d", i+3)
}
query := fmt.Sprintf(
`DELETE FROM app_usage_logs WHERE app_id = $1 AND user_id = $2 AND conversation_id IN (%s)`,
strings.Join(placeholders, ","))
args := append([]any{appID, userID}, ids...)
result, err := h.pool.Exec(r.Context(), query, args...)
if err != nil {
response.InternalError(w, "批量删除失败")
return
}
response.JSON(w, http.StatusOK, map[string]any{
"message": "已删除",
"deleted": result.RowsAffected(),
})
}
func (h *LLMChatHandler) Feedback(w http.ResponseWriter, r *http.Request) {
response.JSON(w, http.StatusOK, map[string]string{"message": "反馈已收到"})
}
// ==================== PPT 生成聊天处理 ====================
func (h *LLMChatHandler) handlePPTChat(w http.ResponseWriter, r *http.Request, appID, userID, message, existingConvID string) {
convID := existingConvID
if convID == "" {
convID = uuid.New().String()
}
msgID := uuid.New().String()
taskID := uuid.New().String()
// 解析用户消息,提取标题和内容
title, sourceType, sourceContent := h.parsePPTMessage(message)
configJSON, _ := json.Marshal(map[string]any{
"format": "ppt169",
"page_count": 10,
"style": "general",
"language": "zh",
})
// 写入 ppt_tasks 表
_, err := h.pool.Exec(r.Context(),
`INSERT INTO ppt_tasks (id, user_id, title, source_type, source_content, config)
VALUES ($1, $2, $3, $4, $5, $6)`,
taskID, userID, title, sourceType, sourceContent, configJSON,
)
if err != nil {
response.InternalError(w, "创建 PPT 任务失败: "+err.Error())
return
}
// 推送到 Redis 队列
taskMsg, _ := json.Marshal(map[string]string{"task_id": taskID})
h.rdb.LPush(r.Context(), "ppt:tasks", taskMsg)
// 设置 SSE 流式响应
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
flusher, ok := w.(http.Flusher)
if !ok {
response.InternalError(w, "Streaming not supported")
return
}
// 发送首个事件(conversation_id + message_id
firstEvent := map[string]string{"conversation_id": convID, "message_id": msgID}
data, _ := json.Marshal(firstEvent)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
// 发送初始消息
var fullResponse strings.Builder
h.sendPPTEvent(w, flusher, &fullResponse, msgID, "📊 PPT 生成任务已创建,正在处理中...\n\n")
startTime := time.Now()
lastStatus := ""
lastProgress := 0
// 轮询任务状态
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
timeout := time.After(10 * time.Minute)
for {
select {
case <-r.Context().Done():
return
case <-timeout:
h.sendPPTEvent(w, flusher, &fullResponse, msgID, "\n\n⏱️ 任务超时,请稍后在任务列表中查看结果。")
goto done
case <-ticker.C:
status, progress, statusMsg := h.pollPPTStatus(r.Context(), taskID)
if status != lastStatus || progress != lastProgress {
lastStatus = status
lastProgress = progress
progressBar := h.formatProgress(progress)
line := fmt.Sprintf("**[%d%%]** %s %s\n", progress, progressBar, statusMsg)
h.sendPPTEvent(w, flusher, &fullResponse, msgID, line)
}
if status == "completed" {
downloadURL := fmt.Sprintf("/api/v1/ppt/tasks/%s/download", taskID)
finalMsg := fmt.Sprintf("\n\n✅ **PPT 生成完成!**\n\n📥 [点击下载 PPTX 文件](%s)\n\n> 提示:也可在「PPT 任务列表」中找到此文件。", downloadURL)
h.sendPPTEvent(w, flusher, &fullResponse, msgID, finalMsg)
goto done
}
if status == "failed" {
h.sendPPTEvent(w, flusher, &fullResponse, msgID, "\n\n❌ **PPT 生成失败**,请检查输入内容后重试。")
goto done
}
}
}
done:
fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
duration := time.Since(startTime).Milliseconds()
go h.recordUsage(appID, userID, convID, message, fullResponse.String(), 0, "ppt-generator", duration)
}
func (h *LLMChatHandler) sendPPTEvent(w http.ResponseWriter, flusher http.Flusher, fullResp *strings.Builder, msgID, text string) {
fullResp.WriteString(text)
event := map[string]any{
"event": "message",
"answer": text,
"message_id": msgID,
}
data, _ := json.Marshal(event)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
func (h *LLMChatHandler) parsePPTMessage(message string) (title, sourceType, sourceContent string) {
sourceType = "text"
sourceContent = message
// 检测 URL
if strings.HasPrefix(message, "http://") || strings.HasPrefix(message, "https://") {
lines := strings.SplitN(message, "\n", 2)
sourceType = "url"
sourceContent = strings.TrimSpace(lines[0])
if len(lines) > 1 {
title = strings.TrimSpace(lines[1])
}
if title == "" {
title = "网页内容PPT"
}
return
}
// 从文本中提取标题(取第一行或前30字)
lines := strings.SplitN(message, "\n", 2)
title = strings.TrimSpace(lines[0])
runes := []rune(title)
if len(runes) > 30 {
title = string(runes[:30])
}
if title == "" {
title = "AI生成PPT"
}
return
}
func (h *LLMChatHandler) pollPPTStatus(ctx context.Context, taskID string) (status string, progress int, statusMsg string) {
// 先查 Redis
key := "ppt:status:" + taskID
cached, err := h.rdb.HGetAll(ctx, key).Result()
if err == nil && len(cached) > 0 {
status = cached["status"]
fmt.Sscanf(cached["progress"], "%d", &progress)
statusMsg = cached["message"]
return
}
// 回退到数据库
var dbStatus string
var dbProgress int
var dbMsg *string
err = h.pool.QueryRow(ctx,
`SELECT status, progress, status_message FROM ppt_tasks WHERE id = $1`, taskID,
).Scan(&dbStatus, &dbProgress, &dbMsg)
if err != nil {
return "pending", 0, "等待处理..."
}
status = dbStatus
progress = dbProgress
if dbMsg != nil {
statusMsg = *dbMsg
}
return
}
func (h *LLMChatHandler) formatProgress(progress int) string {
filled := progress / 5
if filled > 20 {
filled = 20
}
empty := 20 - filled
return "▓" + strings.Repeat("█", filled) + strings.Repeat("░", empty) + "▓"
}
// ==================== 通用工具方法 ====================
func (h *LLMChatHandler) recordUsage(appID, userID, convID, userMessage, aiResponse string, tokens int, model string, durationMs int64) {
ctx := context.Background()
_, _ = h.pool.Exec(ctx, `
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, user_message, ai_response, total_tokens, model_name, duration_ms, client_type)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'web')`,
appID, userID, convID, userMessage, aiResponse, tokens, model, durationMs)
_, _ = h.pool.Exec(ctx,
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
}
// generateConversationName 使用LLM为新对话生成简短标题
func (h *LLMChatHandler) generateConversationName(appID, userID, convID, userMessage string) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
// 检查是否已有自定义名称
var existing string
err := h.pool.QueryRow(ctx,
`SELECT name FROM conversation_names WHERE app_id=$1 AND user_id=$2 AND conversation_id=$3`,
appID, userID, convID).Scan(&existing)
if err == nil && existing != "" {
return
}
// 截取用户消息前200字符用于生成标题
msg := userMessage
runes := []rune(msg)
if len(runes) > 200 {
msg = string(runes[:200])
}
nameReq := &llm.ChatRequest{
Model: "",
Messages: []llm.Message{
{Role: "system", Content: "请用10个字以内为以下对话内容生成一个简短标题。只输出标题文字,不要引号、标点或解释。"},
{Role: "user", Content: msg},
},
Temperature: 0.3,
MaxTokens: 30,
Stream: false,
}
result, err := h.manager.Chat(ctx, h.provider, nameReq)
if err != nil {
return
}
name := strings.TrimSpace(result.Content)
nameRunes := []rune(name)
if len(nameRunes) > 20 {
name = string(nameRunes[:20])
}
if name == "" {
return
}
_, _ = h.pool.Exec(ctx, `
INSERT INTO conversation_names (app_id, user_id, conversation_id, name, updated_at)
VALUES ($1, $2, $3, $4, now())
ON CONFLICT (app_id, user_id, conversation_id)
DO UPDATE SET name = EXCLUDED.name, updated_at = now()`,
appID, userID, convID, name)
}