package handler import ( "context" "encoding/json" "fmt" "net/http" "strings" "time" "github.com/enterprise-ai-platform/server/internal/middleware" "github.com/enterprise-ai-platform/server/internal/response" "github.com/enterprise-ai-platform/server/pkg/llm" "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" ) type DocTemplateHandler struct { pool *pgxpool.Pool manager *llm.Manager provider string } func NewDocTemplateHandler(pool *pgxpool.Pool, manager *llm.Manager, provider string) *DocTemplateHandler { return &DocTemplateHandler{pool: pool, manager: manager, provider: provider} } func (h *DocTemplateHandler) ListTemplates(w http.ResponseWriter, r *http.Request) { orgID := r.URL.Query().Get("org_id") query := `SELECT id, name, doc_type, COALESCE(description,''), COALESCE(icon,''), fields, sort_order FROM document_templates WHERE is_active = true` var args []any if orgID != "" { query += ` AND (org_id = $1 OR org_id IS NULL)` args = append(args, orgID) } query += ` ORDER BY sort_order ASC` rows, err := h.pool.Query(r.Context(), query, args...) if err != nil { response.InternalError(w, "查询模板失败") return } defer rows.Close() var templates []map[string]any for rows.Next() { var id, name, docType, desc, icon string var fields json.RawMessage var sortOrder int if err := rows.Scan(&id, &name, &docType, &desc, &icon, &fields, &sortOrder); err != nil { continue } templates = append(templates, map[string]any{ "id": id, "name": name, "doc_type": docType, "description": desc, "icon": icon, "fields": fields, "sort_order": sortOrder, }) } if templates == nil { templates = []map[string]any{} } response.JSON(w, http.StatusOK, map[string]any{"data": templates}) } func (h *DocTemplateHandler) GetTemplate(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "templateId") var name, docType, desc, icon, formatStd, promptTpl string var fields json.RawMessage var exampleOutput *string err := h.pool.QueryRow(r.Context(), ` SELECT name, doc_type, COALESCE(description,''), COALESCE(icon,''), format_standard, fields, prompt_template, example_output FROM document_templates WHERE id = $1 AND is_active = true`, id, ).Scan(&name, &docType, &desc, &icon, &formatStd, &fields, &promptTpl, &exampleOutput) if err != nil { response.NotFound(w, "模板不存在") return } result := map[string]any{ "id": id, "name": name, "doc_type": docType, "description": desc, "icon": icon, "format_standard": formatStd, "fields": fields, "prompt_template": promptTpl, } if exampleOutput != nil { result["example_output"] = *exampleOutput } response.JSON(w, http.StatusOK, result) } type generateDocRequest struct { TemplateID string `json:"template_id"` FieldData map[string]string `json:"field_data"` } func (h *DocTemplateHandler) GenerateDocument(w http.ResponseWriter, r *http.Request) { appID := chi.URLParam(r, "id") userID := middleware.GetUserID(r.Context()) var req generateDocRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { response.BadRequest(w, "无效请求格式") return } if req.TemplateID == "" { response.BadRequest(w, "请选择公文模板") return } // Load app-specific model and max_tokens var appModel string var appMaxTokens int _ = h.pool.QueryRow(r.Context(), ` SELECT COALESCE(app_config->>'model', ''), COALESCE(max_tokens, 4096) FROM applications WHERE id = $1`, appID, ).Scan(&appModel, &appMaxTokens) var promptTpl, tplName string var fields json.RawMessage err := h.pool.QueryRow(r.Context(), ` SELECT name, fields, prompt_template FROM document_templates WHERE id = $1 AND is_active = true`, req.TemplateID, ).Scan(&tplName, &fields, &promptTpl) if err != nil { response.NotFound(w, "模板不存在") return } var fieldDefs []struct { Key string `json:"key"` Label string `json:"label"` Required bool `json:"required"` } _ = json.Unmarshal(fields, &fieldDefs) for _, f := range fieldDefs { if f.Required { if val, ok := req.FieldData[f.Key]; !ok || strings.TrimSpace(val) == "" { response.BadRequest(w, fmt.Sprintf("请填写必填项:%s", f.Label)) return } } } prompt := promptTpl for k, v := range req.FieldData { prompt = strings.ReplaceAll(prompt, "{{"+k+"}}", v) } prompt = strings.ReplaceAll(prompt, "{{", "") prompt = strings.ReplaceAll(prompt, "}}", "") llmReq := &llm.ChatRequest{ Model: appModel, Messages: []llm.Message{ {Role: llm.RoleSystem, Content: "你是一个专业的政务公文写作专家,精通《党政机关公文格式》国家标准(GB/T 9704)和《党政机关公文处理工作条例》。请严格按照规范格式生成公文,确保行文庄重、严谨、准确。输出完整公文内容,使用Markdown格式排版。"}, {Role: llm.RoleUser, Content: prompt}, }, Temperature: 0.3, MaxTokens: appMaxTokens, Stream: true, } startTime := time.Now() body, err := h.manager.ChatStream(r.Context(), h.provider, llmReq) if err != nil { response.Error(w, http.StatusBadGateway, 50202, "模型服务不可用: "+err.Error()) return } defer body.Close() w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") flusher, ok := w.(http.Flusher) if !ok { response.InternalError(w, "Streaming not supported") return } convID := uuid.New().String() msgID := uuid.New().String() var fullResponse strings.Builder firstEvent := map[string]string{ "conversation_id": convID, "message_id": msgID, "template_name": tplName, } data, _ := json.Marshal(firstEvent) fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() transform := llm.TransformOpenAIStream if h.provider == "anthropic" { transform = llm.TransformAnthropicStream } var totalTokens int var modelName string _ = transform(body, func(event llm.StreamEvent) { if event.Answer != "" { fullResponse.WriteString(event.Answer) } if event.Usage != nil { totalTokens = event.Usage.TotalTokens modelName = event.Usage.Model } data, _ := json.Marshal(event) fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() }) fmt.Fprintf(w, "data: [DONE]\n\n") flusher.Flush() duration := time.Since(startTime).Milliseconds() userMsg := fmt.Sprintf("[公文生成] %s", tplName) go func() { ctx := context.Background() _, _ = h.pool.Exec(ctx, ` INSERT INTO app_usage_logs (app_id, user_id, conversation_id, user_message, ai_response, total_tokens, model_name, duration_ms, client_type) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'web')`, appID, userID, convID, userMsg, fullResponse.String(), totalTokens, modelName, duration) _, _ = h.pool.Exec(ctx, `UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID) }() }