Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,196 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type AnthropicProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewAnthropicProvider(apiKey, baseURL, defaultModel string) *AnthropicProvider {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
if defaultModel == "" {
|
||||
defaultModel = "claude-sonnet-4-20250514"
|
||||
}
|
||||
return &AnthropicProvider{
|
||||
apiKey: apiKey,
|
||||
baseURL: baseURL,
|
||||
model: defaultModel,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AnthropicProvider) Name() string { return "anthropic" }
|
||||
|
||||
type anthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []anthropicMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type anthropicMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type anthropicResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
func (p *AnthropicProvider) ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
|
||||
system, msgs := extractSystemMessage(req.Messages)
|
||||
maxTokens := req.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
maxTokens = 4096
|
||||
}
|
||||
|
||||
aReq := anthropicRequest{
|
||||
Model: model,
|
||||
Messages: toAnthropicMessages(msgs),
|
||||
System: system,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: req.Temperature,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(aReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/v1/messages", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("x-api-key", p.apiKey)
|
||||
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("anthropic request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("anthropic error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var aResp anthropicResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&aResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
content := ""
|
||||
for _, c := range aResp.Content {
|
||||
if c.Type == "text" {
|
||||
content += c.Text
|
||||
}
|
||||
}
|
||||
|
||||
return &ChatResponse{
|
||||
Content: content,
|
||||
Model: aResp.Model,
|
||||
PromptTokens: aResp.Usage.InputTokens,
|
||||
CompletionTokens: aResp.Usage.OutputTokens,
|
||||
TotalTokens: aResp.Usage.InputTokens + aResp.Usage.OutputTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *AnthropicProvider) ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
|
||||
system, msgs := extractSystemMessage(req.Messages)
|
||||
maxTokens := req.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
maxTokens = 4096
|
||||
}
|
||||
|
||||
aReq := anthropicRequest{
|
||||
Model: model,
|
||||
Messages: toAnthropicMessages(msgs),
|
||||
System: system,
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: req.Temperature,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(aReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/v1/messages", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("x-api-key", p.apiKey)
|
||||
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("anthropic stream request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("anthropic error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func extractSystemMessage(msgs []Message) (string, []Message) {
|
||||
system := ""
|
||||
var filtered []Message
|
||||
for _, m := range msgs {
|
||||
if m.Role == RoleSystem {
|
||||
system = m.Content
|
||||
} else {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return system, filtered
|
||||
}
|
||||
|
||||
func toAnthropicMessages(msgs []Message) []anthropicMessage {
|
||||
out := make([]anthropicMessage, len(msgs))
|
||||
for i, m := range msgs {
|
||||
out[i] = anthropicMessage{Role: string(m.Role), Content: m.Content}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Manager manages multiple LLM providers and routes requests.
|
||||
type Manager struct {
|
||||
providers map[string]Provider
|
||||
fallback string
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
providers: make(map[string]Provider),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Register(name string, provider Provider) {
|
||||
m.providers[name] = provider
|
||||
if m.fallback == "" {
|
||||
m.fallback = name
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) SetFallback(name string) {
|
||||
m.fallback = name
|
||||
}
|
||||
|
||||
func (m *Manager) GetProvider(name string) (Provider, error) {
|
||||
if p, ok := m.providers[name]; ok {
|
||||
return p, nil
|
||||
}
|
||||
if p, ok := m.providers[m.fallback]; ok {
|
||||
return p, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no provider found: %s", name)
|
||||
}
|
||||
|
||||
// Chat performs a blocking chat completion using the specified provider.
|
||||
func (m *Manager) Chat(ctx context.Context, providerName string, req *ChatRequest) (*ChatResponse, error) {
|
||||
provider, err := m.GetProvider(providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return provider.ChatCompletion(ctx, req)
|
||||
}
|
||||
|
||||
// ChatStream performs a streaming chat and returns the raw SSE body.
|
||||
func (m *Manager) ChatStream(ctx context.Context, providerName string, req *ChatRequest) (io.ReadCloser, error) {
|
||||
provider, err := m.GetProvider(providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Stream = true
|
||||
return provider.ChatStream(ctx, req)
|
||||
}
|
||||
|
||||
// StreamEvent represents a normalized SSE event for the frontend.
|
||||
type StreamEvent struct {
|
||||
Event string `json:"event"`
|
||||
Answer string `json:"answer,omitempty"`
|
||||
MessageID string `json:"message_id,omitempty"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// TransformOpenAIStream reads an OpenAI SSE stream and writes normalized events to the writer.
|
||||
func TransformOpenAIStream(reader io.Reader, write func(event StreamEvent)) error {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
||||
|
||||
var totalContent string
|
||||
var model string
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
write(StreamEvent{
|
||||
Event: "message_end",
|
||||
Usage: &Usage{
|
||||
TotalTokens: estimateTokens(totalContent),
|
||||
Model: model,
|
||||
},
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
var chunk struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
model = chunk.Model
|
||||
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
||||
content := chunk.Choices[0].Delta.Content
|
||||
totalContent += content
|
||||
write(StreamEvent{
|
||||
Event: "message",
|
||||
Answer: content,
|
||||
MessageID: chunk.ID,
|
||||
})
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// TransformAnthropicStream reads an Anthropic SSE stream and writes normalized events.
|
||||
func TransformAnthropicStream(reader io.Reader, write func(event StreamEvent)) error {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
||||
|
||||
var model string
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if m, ok := msg["model"].(string); ok {
|
||||
model = m
|
||||
}
|
||||
}
|
||||
case "content_block_delta":
|
||||
if delta, ok := event["delta"].(map[string]any); ok {
|
||||
if text, ok := delta["text"].(string); ok {
|
||||
write(StreamEvent{
|
||||
Event: "message",
|
||||
Answer: text,
|
||||
})
|
||||
}
|
||||
}
|
||||
case "message_delta":
|
||||
if usage, ok := event["usage"].(map[string]any); ok {
|
||||
outputTokens := int(getFloat(usage, "output_tokens"))
|
||||
write(StreamEvent{
|
||||
Event: "message_end",
|
||||
Usage: &Usage{
|
||||
CompletionTokens: outputTokens,
|
||||
TotalTokens: outputTokens,
|
||||
Model: model,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func getFloat(m map[string]any, key string) float64 {
|
||||
if v, ok := m[key].(float64); ok {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func estimateTokens(text string) int {
|
||||
return len(text) / 4
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type OpenAIProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewOpenAIProvider(apiKey, baseURL, defaultModel string) *OpenAIProvider {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
if defaultModel == "" {
|
||||
defaultModel = "gpt-4o-mini"
|
||||
}
|
||||
return &OpenAIProvider{
|
||||
apiKey: apiKey,
|
||||
baseURL: baseURL,
|
||||
model: defaultModel,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) Name() string { return "openai" }
|
||||
|
||||
type openAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type openAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type openAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
|
||||
oaiReq := openAIRequest{
|
||||
Model: model,
|
||||
Messages: toOpenAIMessages(req.Messages),
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(oaiReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("openai error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var oaiResp openAIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
content := ""
|
||||
if len(oaiResp.Choices) > 0 {
|
||||
content = oaiResp.Choices[0].Message.Content
|
||||
}
|
||||
|
||||
return &ChatResponse{
|
||||
Content: content,
|
||||
Model: oaiResp.Model,
|
||||
PromptTokens: oaiResp.Usage.PromptTokens,
|
||||
CompletionTokens: oaiResp.Usage.CompletionTokens,
|
||||
TotalTokens: oaiResp.Usage.TotalTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
|
||||
oaiReq := openAIRequest{
|
||||
Model: model,
|
||||
Messages: toOpenAIMessages(req.Messages),
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(oaiReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai stream request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("openai error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func toOpenAIMessages(msgs []Message) []openAIMessage {
|
||||
out := make([]openAIMessage, len(msgs))
|
||||
for i, m := range msgs {
|
||||
out[i] = openAIMessage{Role: string(m.Role), Content: m.Content}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleSystem Role = "system"
|
||||
RoleUser Role = "user"
|
||||
RoleAssistant Role = "assistant"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role Role `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
ConversationID string `json:"conversation_id,omitempty"`
|
||||
}
|
||||
|
||||
type StreamDelta struct {
|
||||
Content string `json:"content"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
// Provider abstracts an LLM API (OpenAI, Anthropic, etc.)
|
||||
type Provider interface {
|
||||
ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
|
||||
ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error)
|
||||
Name() string
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
Type string `json:"type"`
|
||||
APIKey string `json:"api_key"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
}
|
||||
Reference in New Issue
Block a user