198 lines
4.6 KiB
Go
198 lines
4.6 KiB
Go
package llm
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
)
|
|
|
|
// Manager manages multiple LLM providers and routes requests.
|
|
type Manager struct {
|
|
providers map[string]Provider
|
|
fallback string
|
|
}
|
|
|
|
func NewManager() *Manager {
|
|
return &Manager{
|
|
providers: make(map[string]Provider),
|
|
}
|
|
}
|
|
|
|
func (m *Manager) Register(name string, provider Provider) {
|
|
m.providers[name] = provider
|
|
if m.fallback == "" {
|
|
m.fallback = name
|
|
}
|
|
}
|
|
|
|
func (m *Manager) SetFallback(name string) {
|
|
m.fallback = name
|
|
}
|
|
|
|
func (m *Manager) GetProvider(name string) (Provider, error) {
|
|
if p, ok := m.providers[name]; ok {
|
|
return p, nil
|
|
}
|
|
if p, ok := m.providers[m.fallback]; ok {
|
|
return p, nil
|
|
}
|
|
return nil, fmt.Errorf("no provider found: %s", name)
|
|
}
|
|
|
|
// Chat performs a blocking chat completion using the specified provider.
|
|
func (m *Manager) Chat(ctx context.Context, providerName string, req *ChatRequest) (*ChatResponse, error) {
|
|
provider, err := m.GetProvider(providerName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return provider.ChatCompletion(ctx, req)
|
|
}
|
|
|
|
// ChatStream performs a streaming chat and returns the raw SSE body.
|
|
func (m *Manager) ChatStream(ctx context.Context, providerName string, req *ChatRequest) (io.ReadCloser, error) {
|
|
provider, err := m.GetProvider(providerName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Stream = true
|
|
return provider.ChatStream(ctx, req)
|
|
}
|
|
|
|
// StreamEvent is the provider-neutral event written to the frontend while a
// completion streams.
//
// The stream transforms in this package emit Event "message" for an
// incremental text chunk (text in Answer) and "message_end" for the final
// event (token accounting in Usage). MessageID is filled only when the
// upstream stream supplies an ID. Answer, MessageID and Usage are omitted
// from the JSON when empty or nil.
//
// Field order determines the JSON key order on the wire.
type StreamEvent struct {
	Event     string `json:"event"`
	Answer    string `json:"answer,omitempty"`
	MessageID string `json:"message_id,omitempty"`
	Usage     *Usage `json:"usage,omitempty"`
}
|
|
|
|
// Usage is the token accounting attached to a "message_end" StreamEvent,
// together with the model name reported by the upstream stream (empty when
// none was seen).
//
// Counts a provider does not report stay 0, and depending on the stream the
// numbers may be estimates rather than provider-reported values (the OpenAI
// stream transform may estimate TotalTokens from the streamed text length).
// All fields are always serialized; there is no omitempty.
type Usage struct {
	PromptTokens     int    `json:"prompt_tokens"`
	CompletionTokens int    `json:"completion_tokens"`
	TotalTokens      int    `json:"total_tokens"`
	Model            string `json:"model"`
}
|
|
|
|
// TransformOpenAIStream reads an OpenAI SSE stream and writes normalized events to the writer.
|
|
func TransformOpenAIStream(reader io.Reader, write func(event StreamEvent)) error {
|
|
scanner := bufio.NewScanner(reader)
|
|
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
|
|
|
var totalContent string
|
|
var model string
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
write(StreamEvent{
|
|
Event: "message_end",
|
|
Usage: &Usage{
|
|
TotalTokens: estimateTokens(totalContent),
|
|
Model: model,
|
|
},
|
|
})
|
|
break
|
|
}
|
|
|
|
var chunk struct {
|
|
ID string `json:"id"`
|
|
Model string `json:"model"`
|
|
Choices []struct {
|
|
Delta struct {
|
|
Content string `json:"content"`
|
|
} `json:"delta"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
} `json:"choices"`
|
|
}
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
continue
|
|
}
|
|
|
|
model = chunk.Model
|
|
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
|
content := chunk.Choices[0].Delta.Content
|
|
totalContent += content
|
|
write(StreamEvent{
|
|
Event: "message",
|
|
Answer: content,
|
|
MessageID: chunk.ID,
|
|
})
|
|
}
|
|
}
|
|
return scanner.Err()
|
|
}
|
|
|
|
// TransformAnthropicStream reads an Anthropic SSE stream and writes normalized events.
|
|
func TransformAnthropicStream(reader io.Reader, write func(event StreamEvent)) error {
|
|
scanner := bufio.NewScanner(reader)
|
|
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
|
|
|
var model string
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
if strings.HasPrefix(line, "event: ") {
|
|
continue
|
|
}
|
|
continue
|
|
}
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
|
|
var event map[string]any
|
|
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
|
continue
|
|
}
|
|
|
|
eventType, _ := event["type"].(string)
|
|
switch eventType {
|
|
case "message_start":
|
|
if msg, ok := event["message"].(map[string]any); ok {
|
|
if m, ok := msg["model"].(string); ok {
|
|
model = m
|
|
}
|
|
}
|
|
case "content_block_delta":
|
|
if delta, ok := event["delta"].(map[string]any); ok {
|
|
if text, ok := delta["text"].(string); ok {
|
|
write(StreamEvent{
|
|
Event: "message",
|
|
Answer: text,
|
|
})
|
|
}
|
|
}
|
|
case "message_delta":
|
|
if usage, ok := event["usage"].(map[string]any); ok {
|
|
outputTokens := int(getFloat(usage, "output_tokens"))
|
|
write(StreamEvent{
|
|
Event: "message_end",
|
|
Usage: &Usage{
|
|
CompletionTokens: outputTokens,
|
|
TotalTokens: outputTokens,
|
|
Model: model,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
}
|
|
return scanner.Err()
|
|
}
|
|
|
|
// getFloat returns m[key] if it holds a float64 (the type encoding/json uses
// for JSON numbers decoded into any), and 0 if the key is absent, the map is
// nil, or the value has any other type.
func getFloat(m map[string]any, key string) float64 {
	switch v := m[key].(type) {
	case float64:
		return v
	default:
		return 0
	}
}
|
|
|
|
// estimateTokens returns a rough token count for text at a fixed ratio of
// four bytes per token. len counts UTF-8 bytes, not runes, and the result
// rounds down, so text shorter than four bytes yields 0.
func estimateTokens(text string) int {
	const bytesPerToken = 4
	return len(text) / bytesPerToken
}
|