243 lines
6.9 KiB
Go
243 lines
6.9 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/enterprise-ai-platform/server/internal/middleware"
|
|
"github.com/enterprise-ai-platform/server/internal/response"
|
|
"github.com/enterprise-ai-platform/server/pkg/llm"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
type DocTemplateHandler struct {
|
|
pool *pgxpool.Pool
|
|
manager *llm.Manager
|
|
provider string
|
|
}
|
|
|
|
func NewDocTemplateHandler(pool *pgxpool.Pool, manager *llm.Manager, provider string) *DocTemplateHandler {
|
|
return &DocTemplateHandler{pool: pool, manager: manager, provider: provider}
|
|
}
|
|
|
|
// ListTemplates responds with every active document template, ordered by
// sort_order.
//
// When the optional "org_id" query parameter is set, the list is limited to
// templates owned by that organisation plus global ones (org_id IS NULL).
// The body is {"data": [...]}, where "data" is always a JSON array (never
// null). A failure while querying, scanning or iterating the rows produces a
// 500 instead of a silently truncated list.
func (h *DocTemplateHandler) ListTemplates(w http.ResponseWriter, r *http.Request) {
	orgID := r.URL.Query().Get("org_id")

	query := `SELECT id, name, doc_type, COALESCE(description,''), COALESCE(icon,''), fields, sort_order
		FROM document_templates
		WHERE is_active = true`
	var args []any
	if orgID != "" {
		// org_id is bound as $1, never concatenated into the SQL text.
		query += ` AND (org_id = $1 OR org_id IS NULL)`
		args = append(args, orgID)
	}
	query += ` ORDER BY sort_order ASC`

	rows, err := h.pool.Query(r.Context(), query, args...)
	if err != nil {
		response.InternalError(w, "查询模板失败")
		return
	}
	defer rows.Close()

	// Non-nil so an empty result encodes as [] rather than null.
	templates := []map[string]any{}
	for rows.Next() {
		var id, name, docType, desc, icon string
		var fields json.RawMessage
		var sortOrder int
		if err := rows.Scan(&id, &name, &docType, &desc, &icon, &fields, &sortOrder); err != nil {
			// pgx treats a Scan failure as fatal for the result set, so
			// "skipping" the row would just end the loop with a partial list.
			response.InternalError(w, "查询模板失败")
			return
		}
		templates = append(templates, map[string]any{
			"id":          id,
			"name":        name,
			"doc_type":    docType,
			"description": desc,
			"icon":        icon,
			"fields":      fields,
			"sort_order":  sortOrder,
		})
	}
	// Errors raised mid-iteration (network failure, cancelled context) are
	// only reported here; without this check they look like end-of-results.
	if err := rows.Err(); err != nil {
		response.InternalError(w, "查询模板失败")
		return
	}

	response.JSON(w, http.StatusOK, map[string]any{"data": templates})
}
|
|
|
|
// GetTemplate responds with the full definition of one active document
// template, selected by the "templateId" URL parameter.
//
// Beyond the list fields, the payload carries format_standard and
// prompt_template. example_output is included only when the column is
// non-NULL. Any Scan failure is answered with 404. That covers a missing or
// inactive row, and also a database error.
func (h *DocTemplateHandler) GetTemplate(w http.ResponseWriter, r *http.Request) {
	templateID := chi.URLParam(r, "templateId")

	var (
		name, docType, description, icon string
		formatStandard, promptTemplate   string
		fields                           json.RawMessage
		example                          *string // nil when example_output IS NULL
	)

	row := h.pool.QueryRow(r.Context(), `
		SELECT name, doc_type, COALESCE(description,''), COALESCE(icon,''),
		format_standard, fields, prompt_template, example_output
		FROM document_templates WHERE id = $1 AND is_active = true`, templateID,
	)
	if scanErr := row.Scan(&name, &docType, &description, &icon, &formatStandard, &fields, &promptTemplate, &example); scanErr != nil {
		response.NotFound(w, "模板不存在")
		return
	}

	payload := map[string]any{
		"id":              templateID,
		"name":            name,
		"doc_type":        docType,
		"description":     description,
		"icon":            icon,
		"format_standard": formatStandard,
		"fields":          fields,
		"prompt_template": promptTemplate,
	}
	if example != nil {
		payload["example_output"] = *example
	}

	response.JSON(w, http.StatusOK, payload)
}
|
|
|
|
// generateDocRequest is the JSON body accepted by GenerateDocument.
type generateDocRequest struct {
	// TemplateID selects the document_templates row to render; an empty value
	// is rejected by GenerateDocument.
	TemplateID string `json:"template_id"`

	// FieldData maps template field keys to user-entered values, which are
	// substituted into the template's {{key}} placeholders.
	FieldData map[string]string `json:"field_data"`
}
|
|
|
|
// GenerateDocument streams an LLM-generated official document for the
// application identified by the "id" URL parameter.
//
// The JSON body (generateDocRequest) names a template and supplies field
// values. The template's required fields are validated, and the values are
// rendered into its prompt with fillDocTemplatePrompt. The model output is
// relayed to the client as Server-Sent Events:
//
//   - a first event carrying conversation_id, message_id and template_name;
//   - one event per llm.StreamEvent produced by the provider-specific transform;
//   - a terminating "data: [DONE]" line.
//
// Once the stream ends, a background goroutine inserts an app_usage_logs row
// and increments applications.usage_count. Errors there are ignored, and the
// goroutine is bounded by a timeout.
func (h *DocTemplateHandler) GenerateDocument(w http.ResponseWriter, r *http.Request) {
	// Completion budget used when the application row is missing or carries no
	// usable max_tokens; matches the COALESCE default in the query below.
	const fallbackMaxTokens = 4096

	appID := chi.URLParam(r, "id")
	userID := middleware.GetUserID(r.Context())

	var req generateDocRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		response.BadRequest(w, "无效请求格式")
		return
	}
	if req.TemplateID == "" {
		response.BadRequest(w, "请选择公文模板")
		return
	}

	// Per-application model and max_tokens. The lookup is best-effort: a
	// failure keeps the model empty, as before. Scan assigns nothing on
	// failure, though, so MaxTokens must be reset explicitly instead of being
	// left at 0.
	var appModel string
	var appMaxTokens int
	if err := h.pool.QueryRow(r.Context(), `
		SELECT COALESCE(app_config->>'model', ''), COALESCE(max_tokens, 4096)
		FROM applications WHERE id = $1`, appID,
	).Scan(&appModel, &appMaxTokens); err != nil {
		appModel = ""
		appMaxTokens = fallbackMaxTokens
	}
	if appMaxTokens <= 0 {
		appMaxTokens = fallbackMaxTokens
	}

	var tplName, promptTpl string
	var fields json.RawMessage
	if err := h.pool.QueryRow(r.Context(), `
		SELECT name, fields, prompt_template
		FROM document_templates WHERE id = $1 AND is_active = true`, req.TemplateID,
	).Scan(&tplName, &fields, &promptTpl); err != nil {
		response.NotFound(w, "模板不存在")
		return
	}

	// An empty fields value means "no declared fields". A definition that
	// does not decode is a server-side data error; ignoring it would silently
	// switch off required-field validation.
	var fieldDefs []struct {
		Key      string `json:"key"`
		Label    string `json:"label"`
		Required bool   `json:"required"`
	}
	if len(fields) > 0 {
		if err := json.Unmarshal(fields, &fieldDefs); err != nil {
			response.InternalError(w, "模板字段定义无效")
			return
		}
	}
	for _, f := range fieldDefs {
		if !f.Required {
			continue
		}
		// A missing key reads as "" (also for a nil map) and fails like a blank value.
		if strings.TrimSpace(req.FieldData[f.Key]) == "" {
			response.BadRequest(w, fmt.Sprintf("请填写必填项:%s", f.Label))
			return
		}
	}

	prompt := fillDocTemplatePrompt(promptTpl, req.FieldData)

	// Check streaming support before contacting the model. An unsupported
	// writer then never opens an upstream stream only to abandon it, and the
	// error reply goes out before any SSE headers are set.
	flusher, ok := w.(http.Flusher)
	if !ok {
		response.InternalError(w, "Streaming not supported")
		return
	}

	llmReq := &llm.ChatRequest{
		Model: appModel,
		Messages: []llm.Message{
			{Role: llm.RoleSystem, Content: "你是一个专业的政务公文写作专家,精通《党政机关公文格式》国家标准(GB/T 9704)和《党政机关公文处理工作条例》。请严格按照规范格式生成公文,确保行文庄重、严谨、准确。输出完整公文内容,使用Markdown格式排版。"},
			{Role: llm.RoleUser, Content: prompt},
		},
		Temperature: 0.3,
		MaxTokens:   appMaxTokens,
		Stream:      true,
	}

	startTime := time.Now()
	body, err := h.manager.ChatStream(r.Context(), h.provider, llmReq)
	if err != nil {
		response.Error(w, http.StatusBadGateway, 50202, "模型服务不可用: "+err.Error())
		return
	}
	defer body.Close()

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")

	convID := uuid.New().String()
	msgID := uuid.New().String()

	firstEvent, _ := json.Marshal(map[string]string{
		"conversation_id": convID,
		"message_id":      msgID,
		"template_name":   tplName,
	})
	fmt.Fprintf(w, "data: %s\n\n", firstEvent)
	flusher.Flush()

	transform := llm.TransformOpenAIStream
	if h.provider == "anthropic" {
		transform = llm.TransformAnthropicStream
	}

	var fullResponse strings.Builder
	var totalTokens int
	var modelName string
	// NOTE(review): the transform's error is still discarded. A stream that
	// fails mid-way therefore ends with [DONE] like a successful one, and the
	// partial text is logged. Surfacing it needs an agreed SSE error event.
	_ = transform(body, func(event llm.StreamEvent) {
		if event.Answer != "" {
			fullResponse.WriteString(event.Answer)
		}
		if event.Usage != nil {
			totalTokens = event.Usage.TotalTokens
			modelName = event.Usage.Model
		}
		payload, _ := json.Marshal(event)
		fmt.Fprintf(w, "data: %s\n\n", payload)
		flusher.Flush()
	})

	fmt.Fprintf(w, "data: [DONE]\n\n")
	flusher.Flush()

	// Snapshot everything the logger needs before the handler returns.
	duration := time.Since(startTime).Milliseconds()
	userMsg := fmt.Sprintf("[公文生成] %s", tplName)
	aiResponse := fullResponse.String()

	// Usage accounting must outlive the request context, but it is bounded so
	// a stalled database cannot pin this goroutine forever.
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		_, _ = h.pool.Exec(ctx, `
			INSERT INTO app_usage_logs (app_id, user_id, conversation_id, user_message, ai_response, total_tokens, model_name, duration_ms, client_type)
			VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'web')`,
			appID, userID, convID, userMsg, aiResponse, totalTokens, modelName, duration)
		_, _ = h.pool.Exec(ctx,
			`UPDATE applications SET usage_count = usage_count + 1 WHERE id = $1`, appID)
	}()
}

// fillDocTemplatePrompt renders tpl by replacing each {{key}} placeholder with
// data[key] in a single left-to-right pass.
//
//   - Substituted values are inserted verbatim and never re-scanned. A value
//     that itself contains "{{other}}" therefore cannot pull in another field,
//     which repeated ReplaceAll calls allowed in random map order.
//   - Placeholders with no entry in data are removed entirely rather than
//     leaving their bare key names in the prompt.
//   - Whitespace around the key is ignored, so "{{ title }}" matches "title".
//   - A "{{" with no closing "}}" is copied through unchanged.
//
// A nil data map is allowed; every placeholder is then removed.
func fillDocTemplatePrompt(tpl string, data map[string]string) string {
	var b strings.Builder
	b.Grow(len(tpl))
	for {
		before, rest, found := strings.Cut(tpl, "{{")
		if !found {
			b.WriteString(tpl)
			break
		}
		inner, after, closed := strings.Cut(rest, "}}")
		if !closed {
			b.WriteString(tpl)
			break
		}
		b.WriteString(before)
		b.WriteString(data[strings.TrimSpace(inner)])
		tpl = after
	}
	return b.String()
}
|