Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,305 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/enterprise-ai-platform/server/internal/middleware"
|
||||
"github.com/enterprise-ai-platform/server/internal/response"
|
||||
"github.com/enterprise-ai-platform/server/pkg/llm"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type AnalysisTemplateHandler struct {
|
||||
pool *pgxpool.Pool
|
||||
manager *llm.Manager
|
||||
provider string
|
||||
}
|
||||
|
||||
func NewAnalysisTemplateHandler(pool *pgxpool.Pool, manager *llm.Manager, provider string) *AnalysisTemplateHandler {
|
||||
return &AnalysisTemplateHandler{pool: pool, manager: manager, provider: provider}
|
||||
}
|
||||
|
||||
func (h *AnalysisTemplateHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := r.URL.Query().Get("org_id")
|
||||
query := `SELECT id, name, report_type, COALESCE(description,''), COALESCE(icon,''), steps, sort_order
|
||||
FROM analysis_report_templates
|
||||
WHERE is_active = true`
|
||||
var args []any
|
||||
if orgID != "" {
|
||||
query += ` AND (org_id = $1 OR org_id IS NULL)`
|
||||
args = append(args, orgID)
|
||||
}
|
||||
query += ` ORDER BY sort_order ASC`
|
||||
rows, err := h.pool.Query(r.Context(), query, args...)
|
||||
if err != nil {
|
||||
response.InternalError(w, "查询模板失败")
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var templates []map[string]any
|
||||
for rows.Next() {
|
||||
var id, name, reportType, desc, icon string
|
||||
var steps json.RawMessage
|
||||
var sortOrder int
|
||||
if err := rows.Scan(&id, &name, &reportType, &desc, &icon, &steps, &sortOrder); err != nil {
|
||||
continue
|
||||
}
|
||||
templates = append(templates, map[string]any{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"report_type": reportType,
|
||||
"description": desc,
|
||||
"icon": icon,
|
||||
"steps": steps,
|
||||
"sort_order": sortOrder,
|
||||
})
|
||||
}
|
||||
if templates == nil {
|
||||
templates = []map[string]any{}
|
||||
}
|
||||
response.JSON(w, http.StatusOK, map[string]any{"data": templates})
|
||||
}
|
||||
|
||||
func (h *AnalysisTemplateHandler) GetTemplate(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "templateId")
|
||||
|
||||
var name, reportType, desc, icon, promptTpl string
|
||||
var steps, outputSections json.RawMessage
|
||||
|
||||
err := h.pool.QueryRow(r.Context(), `
|
||||
SELECT name, report_type, COALESCE(description,''), COALESCE(icon,''),
|
||||
steps, prompt_template, COALESCE(output_sections, '[]')
|
||||
FROM analysis_report_templates WHERE id = $1 AND is_active = true`, id,
|
||||
).Scan(&name, &reportType, &desc, &icon, &steps, &promptTpl, &outputSections)
|
||||
if err != nil {
|
||||
response.NotFound(w, "模板不存在")
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, http.StatusOK, map[string]any{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"report_type": reportType,
|
||||
"description": desc,
|
||||
"icon": icon,
|
||||
"steps": steps,
|
||||
"prompt_template": promptTpl,
|
||||
"output_sections": outputSections,
|
||||
})
|
||||
}
|
||||
|
||||
type generateAnalysisRequest struct {
|
||||
TemplateID string `json:"template_id"`
|
||||
FieldData map[string]string `json:"field_data"`
|
||||
}
|
||||
|
||||
func containsAny(list []string, targets ...string) bool {
|
||||
for _, item := range list {
|
||||
for _, t := range targets {
|
||||
if item == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *AnalysisTemplateHandler) queryEconomicData(ctx context.Context, districts, timeRange string) string {
|
||||
years := []int{2025}
|
||||
switch timeRange {
|
||||
case "2024-2025年两年对比", "2024_2025":
|
||||
years = []int{2024, 2025}
|
||||
case "2023-2025年三年趋势", "2023_2025":
|
||||
years = []int{2023, 2024, 2025}
|
||||
case "2024年全年", "2024_full":
|
||||
years = []int{2024}
|
||||
}
|
||||
|
||||
districtList := strings.Split(districts, ",")
|
||||
for i := range districtList {
|
||||
districtList[i] = strings.TrimSpace(districtList[i])
|
||||
}
|
||||
|
||||
rows, err := h.pool.Query(ctx, `
|
||||
SELECT district_name, year,
|
||||
COALESCE(gdp,0), COALESCE(gdp_growth,0), COALESCE(gdp_per_capita,0),
|
||||
COALESCE(fiscal_revenue,0), COALESCE(fiscal_revenue_growth,0),
|
||||
COALESCE(fixed_investment,0), COALESCE(fixed_investment_growth,0),
|
||||
COALESCE(retail_sales,0), COALESCE(retail_sales_growth,0),
|
||||
COALESCE(industrial_output,0), COALESCE(industrial_output_growth,0),
|
||||
COALESCE(tertiary_ratio,0),
|
||||
COALESCE(import_export,0), COALESCE(import_export_growth,0),
|
||||
COALESCE(actual_fdi,0),
|
||||
COALESCE(tech_expenditure,0), COALESCE(tech_expenditure_ratio,0),
|
||||
COALESCE(population,0), COALESCE(urban_income,0), COALESCE(rural_income,0)
|
||||
FROM regional_economic_data
|
||||
WHERE year = ANY($1)
|
||||
AND (cardinality($2::text[]) = 0 OR district_name = ANY($2))
|
||||
ORDER BY year ASC, gdp DESC`, years, districtList)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("\n\n## 【数据库真实数据】\n\n")
|
||||
sb.WriteString("| 区县 | 年份 | GDP(亿元) | GDP增速(%) | 人均GDP(元) | 财政收入(亿元) | 财政增速(%) | 固投(亿元) | 固投增速(%) | 社零(亿元) | 社零增速(%) | 规上工业(亿元) | 工业增速(%) | 三产占比(%) | 进出口(亿元) | 进出口增速(%) | 利用外资(亿美元) | R&D投入(亿元) | R&D/GDP(%) | 人口(万) | 城镇收入(元) | 农村收入(元) |\n")
|
||||
sb.WriteString("|------|------|-----------|-----------|------------|-------------|-----------|----------|-----------|----------|-----------|-------------|-----------|-----------|-----------|------------|-------------|------------|----------|----------|-----------|----------|\n")
|
||||
|
||||
for rows.Next() {
|
||||
var name string
|
||||
var year int
|
||||
var gdp, gdpG, gdpPC, fr, frG, fi, fiG, rs, rsG, io, ioG, tr, ie, ieG, fdi, te, teR, pop, ui, ri float64
|
||||
if err := rows.Scan(&name, &year, &gdp, &gdpG, &gdpPC, &fr, &frG, &fi, &fiG, &rs, &rsG, &io, &ioG, &tr, &ie, &ieG, &fdi, &te, &teR, &pop, &ui, &ri); err != nil {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("| %s | %d | %.1f | %.1f | %.0f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.1f | %.2f | %.1f | %.2f | %.1f | %.0f | %.0f |\n",
|
||||
name, year, gdp, gdpG, gdpPC, fr, frG, fi, fiG, rs, rsG, io, ioG, tr, ie, ieG, fdi, te, teR, pop, ui, ri))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (h *AnalysisTemplateHandler) GenerateReport(w http.ResponseWriter, r *http.Request) {
|
||||
appID := chi.URLParam(r, "id")
|
||||
userID := middleware.GetUserID(r.Context())
|
||||
|
||||
var req generateAnalysisRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
response.BadRequest(w, "无效请求格式")
|
||||
return
|
||||
}
|
||||
if req.TemplateID == "" {
|
||||
response.BadRequest(w, "请选择报告模板")
|
||||
return
|
||||
}
|
||||
|
||||
var appModel string
|
||||
var appMaxTokens int
|
||||
var appConfigRaw json.RawMessage
|
||||
_ = h.pool.QueryRow(r.Context(), `
|
||||
SELECT COALESCE(app_config->>'model', ''), COALESCE(max_tokens, 8192), COALESCE(app_config, '{}')
|
||||
FROM applications WHERE id = $1`, appID,
|
||||
).Scan(&appModel, &appMaxTokens, &appConfigRaw)
|
||||
|
||||
var appCfg struct {
|
||||
Tools []string `json:"tools"`
|
||||
DataSources []string `json:"data_sources"`
|
||||
TemplateSet string `json:"template_set"`
|
||||
}
|
||||
_ = json.Unmarshal(appConfigRaw, &appCfg)
|
||||
|
||||
var promptTpl, tplName, reportType string
|
||||
err := h.pool.QueryRow(r.Context(), `
|
||||
SELECT name, prompt_template, report_type
|
||||
FROM analysis_report_templates WHERE id = $1 AND is_active = true`, req.TemplateID,
|
||||
).Scan(&tplName, &promptTpl, &reportType)
|
||||
if err != nil {
|
||||
response.NotFound(w, "模板不存在")
|
||||
return
|
||||
}
|
||||
|
||||
prompt := promptTpl
|
||||
for k, v := range req.FieldData {
|
||||
prompt = strings.ReplaceAll(prompt, "{{"+k+"}}", v)
|
||||
}
|
||||
prompt = strings.ReplaceAll(prompt, "{{", "")
|
||||
prompt = strings.ReplaceAll(prompt, "}}", "")
|
||||
|
||||
hasDataTool := containsAny(appCfg.Tools, "economic_data_query", "data_analysis")
|
||||
hasDataSource := containsAny(appCfg.DataSources, "regional_economic_data", "statistical_yearbook")
|
||||
useEconomicData := hasDataTool || hasDataSource || reportType == "economic_comparison" || reportType == "trend_analysis"
|
||||
|
||||
if useEconomicData {
|
||||
dataTable := h.queryEconomicData(r.Context(), req.FieldData["districts"], req.FieldData["time_range"])
|
||||
if dataTable != "" {
|
||||
prompt += "\n\n⚠️ 以下是从数据库中检索到的真实经济数据,请基于这些真实数据进行分析,不要编造数据:" + dataTable
|
||||
}
|
||||
}
|
||||
|
||||
llmReq := &llm.ChatRequest{
|
||||
Model: appModel,
|
||||
Messages: []llm.Message{
|
||||
{Role: llm.RoleSystem, Content: "你是一位资深的政务数据分析和研判专家。要求:1)基于提供的真实数据库数据进行分析 2)报告必须完整输出到最后一个章节,不能中途截断 3)保持精炼,每个章节500字以内 4)使用Markdown格式排版 5)表格数据必须引用真实数据 6)报告输出完毕后,必须在最后一行单独输出「---\\n\\n**【完稿】**」作为完成标记"},
|
||||
{Role: llm.RoleUser, Content: prompt},
|
||||
},
|
||||
Temperature: 0.4,
|
||||
MaxTokens: appMaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
body, err := h.manager.ChatStream(r.Context(), h.provider, llmReq)
|
||||
if err != nil {
|
||||
response.Error(w, http.StatusBadGateway, 50202, "模型服务不可用: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
response.InternalError(w, "Streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
convID := uuid.New().String()
|
||||
msgID := uuid.New().String()
|
||||
var fullResponse strings.Builder
|
||||
|
||||
firstEvent := map[string]string{
|
||||
"conversation_id": convID,
|
||||
"message_id": msgID,
|
||||
"template_name": tplName,
|
||||
}
|
||||
data, _ := json.Marshal(firstEvent)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
|
||||
transform := llm.TransformOpenAIStream
|
||||
if h.provider == "anthropic" {
|
||||
transform = llm.TransformAnthropicStream
|
||||
}
|
||||
|
||||
var totalTokens int
|
||||
var modelName string
|
||||
|
||||
_ = transform(body, func(event llm.StreamEvent) {
|
||||
if event.Answer != "" {
|
||||
fullResponse.WriteString(event.Answer)
|
||||
}
|
||||
if event.Usage != nil {
|
||||
totalTokens = event.Usage.TotalTokens
|
||||
modelName = event.Usage.Model
|
||||
}
|
||||
data, _ := json.Marshal(event)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
})
|
||||
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
|
||||
duration := time.Since(startTime).Milliseconds()
|
||||
userMsg := fmt.Sprintf("[研判报告] %s", tplName)
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
_, _ = h.pool.Exec(ctx, `
|
||||
INSERT INTO app_usage_logs (app_id, user_id, conversation_id, user_message, ai_response, total_tokens, model_name, duration_ms, client_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'web')`,
|
||||
appID, userID, convID, userMsg, fullResponse.String(), totalTokens, modelName, duration)
|
||||
_, _ = h.pool.Exec(ctx,
|
||||
`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
|
||||
}()
|
||||
}
|
||||
Reference in New Issue
Block a user