package handler import ( "context" "encoding/json" "fmt" "net/http" "strings" "time" "github.com/enterprise-ai-platform/server/internal/middleware" "github.com/enterprise-ai-platform/server/internal/response" "github.com/enterprise-ai-platform/server/pkg/llm" "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" ) type AnalysisTemplateHandler struct { pool *pgxpool.Pool manager *llm.Manager provider string } func NewAnalysisTemplateHandler(pool *pgxpool.Pool, manager *llm.Manager, provider string) *AnalysisTemplateHandler { return &AnalysisTemplateHandler{pool: pool, manager: manager, provider: provider} } func (h *AnalysisTemplateHandler) ListTemplates(w http.ResponseWriter, r *http.Request) { orgID := r.URL.Query().Get("org_id") query := `SELECT id, name, report_type, COALESCE(description,''), COALESCE(icon,''), steps, sort_order FROM analysis_report_templates WHERE is_active = true` var args []any if orgID != "" { query += ` AND (org_id = $1 OR org_id IS NULL)` args = append(args, orgID) } query += ` ORDER BY sort_order ASC` rows, err := h.pool.Query(r.Context(), query, args...) if err != nil { response.InternalError(w, "查询模板失败") return } defer rows.Close() var templates []map[string]any for rows.Next() { var id, name, reportType, desc, icon string var steps json.RawMessage var sortOrder int if err := rows.Scan(&id, &name, &reportType, &desc, &icon, &steps, &sortOrder); err != nil { continue } templates = append(templates, map[string]any{ "id": id, "name": name, "report_type": reportType, "description": desc, "icon": icon, "steps": steps, "sort_order": sortOrder, }) } if templates == nil { templates = []map[string]any{} } response.JSON(w, http.StatusOK, map[string]any{"data": templates}) } func (h *AnalysisTemplateHandler) GetTemplate(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "templateId") var name, reportType, desc, icon, promptTpl string var steps, outputSections json.RawMessage err := h.pool.QueryRow(r.Context(), ` SELECT name, report_type, COALESCE(description,''), COALESCE(icon,''), steps, prompt_template, COALESCE(output_sections, '[]') FROM analysis_report_templates WHERE id = $1 AND is_active = true`, id, ).Scan(&name, &reportType, &desc, &icon, &steps, &promptTpl, &outputSections) if err != nil { response.NotFound(w, "模板不存在") return } response.JSON(w, http.StatusOK, map[string]any{ "id": id, "name": name, "report_type": reportType, "description": desc, "icon": icon, "steps": steps, "prompt_template": promptTpl, "output_sections": outputSections, }) } type generateAnalysisRequest struct { TemplateID string `json:"template_id"` FieldData map[string]string `json:"field_data"` } func containsAny(list []string, targets ...string) bool { for _, item := range list { for _, t := range targets { if item == t { return true } } } return false } func (h *AnalysisTemplateHandler) queryEconomicData(ctx context.Context, districts, timeRange string) string { years := []int{2025} switch timeRange { case "2024-2025年两年对比", "2024_2025": years = []int{2024, 2025} case "2023-2025年三年趋势", "2023_2025": years = []int{2023, 2024, 2025} case "2024年全年", "2024_full": years = []int{2024} } districtList := strings.Split(districts, ",") for i := range districtList { districtList[i] = strings.TrimSpace(districtList[i]) } rows, err := h.pool.Query(ctx, ` SELECT district_name, year, COALESCE(gdp,0), COALESCE(gdp_growth,0), COALESCE(gdp_per_capita,0), COALESCE(fiscal_revenue,0), COALESCE(fiscal_revenue_growth,0), COALESCE(fixed_investment,0), COALESCE(fixed_investment_growth,0), COALESCE(retail_sales,0), COALESCE(retail_sales_growth,0), COALESCE(industrial_output,0), COALESCE(industrial_output_growth,0), COALESCE(tertiary_ratio,0), COALESCE(import_export,0), COALESCE(import_export_growth,0), COALESCE(actual_fdi,0), COALESCE(tech_expenditure,0), COALESCE(tech_expenditure_ratio,0), COALESCE(population,0), COALESCE(urban_income,0), COALESCE(rural_income,0) FROM regional_economic_data WHERE year = ANY($1) AND (cardinality($2::text[]) = 0 OR district_name = ANY($2)) ORDER BY year ASC, gdp DESC`, years, districtList) if err != nil { return "" } defer rows.Close() var sb strings.Builder sb.WriteString("\n\n## 【数据库真实数据】\n\n") sb.WriteString("| 区县 | 年份 | GDP(亿元) | GDP增速(%) | 人均GDP(元) | 财政收入(亿元) | 财政增速(%) | 固投(亿元) | 固投增速(%) | 社零(亿元) | 社零增速(%) | 规上工业(亿元) | 工业增速(%) | 三产占比(%) | 进出口(亿元) | 进出口增速(%) | 利用外资(亿美元) | R&D投入(亿元) | R&D/GDP(%) | 人口(万) | 城镇收入(元) | 农村收入(元) |\n") sb.WriteString("|------|------|-----------|-----------|------------|-------------|-----------|----------|-----------|----------|-----------|-------------|-----------|-----------|-----------|------------|-------------|------------|----------|----------|-----------|----------|\n") for rows.Next() { var name string var year int var gdp, gdpG, gdpPC, fr, frG, fi, fiG, rs, rsG, io, ioG, tr, ie, ieG, fdi, te, teR, pop, ui, ri float64 if err := rows.Scan(&name, &year, &gdp, &gdpG, &gdpPC, &fr, &frG, &fi, &fiG, &rs, &rsG, &io, &ioG, &tr, &ie, &ieG, &fdi, &te, &teR, &pop, &ui, &ri); err != nil { continue } sb.WriteString(fmt.Sprintf("| %s | %d | %.1f | %.1f | %.0f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.2f | %.1f | %.2f | %.1f | %.0f | %.0f |\n", name, year, gdp, gdpG, gdpPC, fr, frG, fi, fiG, rs, rsG, io, ioG, tr, ie, ieG, fdi, te, teR, pop, ui, ri)) } return sb.String() } func (h *AnalysisTemplateHandler) GenerateReport(w http.ResponseWriter, r *http.Request) { appID := chi.URLParam(r, "id") userID := middleware.GetUserID(r.Context()) var req generateAnalysisRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { response.BadRequest(w, "无效请求格式") return } if req.TemplateID == "" { response.BadRequest(w, "请选择报告模板") return } var appModel string var appMaxTokens int var appConfigRaw json.RawMessage _ = h.pool.QueryRow(r.Context(), ` SELECT COALESCE(app_config->>'model', ''), COALESCE(max_tokens, 8192), COALESCE(app_config, '{}') FROM applications WHERE id = $1`, appID, ).Scan(&appModel, &appMaxTokens, &appConfigRaw) var appCfg struct { Tools []string `json:"tools"` DataSources []string `json:"data_sources"` TemplateSet string `json:"template_set"` } _ = json.Unmarshal(appConfigRaw, &appCfg) var promptTpl, tplName, reportType string err := h.pool.QueryRow(r.Context(), ` SELECT name, prompt_template, report_type FROM analysis_report_templates WHERE id = $1 AND is_active = true`, req.TemplateID, ).Scan(&tplName, &promptTpl, &reportType) if err != nil { response.NotFound(w, "模板不存在") return } prompt := promptTpl for k, v := range req.FieldData { prompt = strings.ReplaceAll(prompt, "{{"+k+"}}", v) } prompt = strings.ReplaceAll(prompt, "{{", "") prompt = strings.ReplaceAll(prompt, "}}", "") hasDataTool := containsAny(appCfg.Tools, "economic_data_query", "data_analysis") hasDataSource := containsAny(appCfg.DataSources, "regional_economic_data", "statistical_yearbook") useEconomicData := hasDataTool || hasDataSource || reportType == "economic_comparison" || reportType == "trend_analysis" if useEconomicData { dataTable := h.queryEconomicData(r.Context(), req.FieldData["districts"], req.FieldData["time_range"]) if dataTable != "" { prompt += "\n\n⚠️ 以下是从数据库中检索到的真实经济数据,请基于这些真实数据进行分析,不要编造数据:" + dataTable } } llmReq := &llm.ChatRequest{ Model: appModel, Messages: []llm.Message{ {Role: llm.RoleSystem, Content: "你是一位资深的政务数据分析和研判专家。要求:1)基于提供的真实数据库数据进行分析 2)报告必须完整输出到最后一个章节,不能中途截断 3)保持精炼,每个章节500字以内 4)使用Markdown格式排版 5)表格数据必须引用真实数据 6)报告输出完毕后,必须在最后一行单独输出「---\\n\\n**【完稿】**」作为完成标记"}, {Role: llm.RoleUser, Content: prompt}, }, Temperature: 0.4, MaxTokens: appMaxTokens, Stream: true, } startTime := time.Now() body, err := h.manager.ChatStream(r.Context(), h.provider, llmReq) if err != nil { response.Error(w, http.StatusBadGateway, 50202, "模型服务不可用: "+err.Error()) return } defer body.Close() w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") flusher, ok := w.(http.Flusher) if !ok { response.InternalError(w, "Streaming not supported") return } convID := uuid.New().String() msgID := uuid.New().String() var fullResponse strings.Builder firstEvent := map[string]string{ "conversation_id": convID, "message_id": msgID, "template_name": tplName, } data, _ := json.Marshal(firstEvent) fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() transform := llm.TransformOpenAIStream if h.provider == "anthropic" { transform = llm.TransformAnthropicStream } var totalTokens int var modelName string _ = transform(body, func(event llm.StreamEvent) { if event.Answer != "" { fullResponse.WriteString(event.Answer) } if event.Usage != nil { totalTokens = event.Usage.TotalTokens modelName = event.Usage.Model } data, _ := json.Marshal(event) fmt.Fprintf(w, "data: %s\n\n", data) flusher.Flush() }) fmt.Fprintf(w, "data: [DONE]\n\n") flusher.Flush() duration := time.Since(startTime).Milliseconds() userMsg := fmt.Sprintf("[研判报告] %s", tplName) go func() { ctx := context.Background() _, _ = h.pool.Exec(ctx, ` INSERT INTO app_usage_logs (app_id, user_id, conversation_id, user_message, ai_response, total_tokens, model_name, duration_ms, client_type) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'web')`, appID, userID, convID, userMsg, fullResponse.String(), totalTokens, modelName, duration) _, _ = h.pool.Exec(ctx, `UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID) }() }