197 lines
4.6 KiB
Go
197 lines
4.6 KiB
Go
package llm
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
)
|
|
|
|
// AnthropicProvider is a client for the Anthropic Messages API. Build one
// with NewAnthropicProvider, which fills in the default endpoint and model.
type AnthropicProvider struct {
	// apiKey is sent as the "x-api-key" header; baseURL is the origin that
	// "/v1/messages" is appended to for every call.
	apiKey, baseURL string
	// model is the fallback used when a ChatRequest leaves Model empty.
	model string
	// httpClient executes every outgoing request.
	httpClient *http.Client
}
|
|
|
|
func NewAnthropicProvider(apiKey, baseURL, defaultModel string) *AnthropicProvider {
|
|
if baseURL == "" {
|
|
baseURL = "https://api.anthropic.com"
|
|
}
|
|
if defaultModel == "" {
|
|
defaultModel = "claude-sonnet-4-20250514"
|
|
}
|
|
return &AnthropicProvider{
|
|
apiKey: apiKey,
|
|
baseURL: baseURL,
|
|
model: defaultModel,
|
|
httpClient: &http.Client{},
|
|
}
|
|
}
|
|
|
|
// Name reports the fixed identifier of this provider, "anthropic".
func (p *AnthropicProvider) Name() string {
	return "anthropic"
}
|
|
|
|
type anthropicRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []anthropicMessage `json:"messages"`
|
|
System string `json:"system,omitempty"`
|
|
MaxTokens int `json:"max_tokens"`
|
|
Temperature float64 `json:"temperature,omitempty"`
|
|
Stream bool `json:"stream"`
|
|
}
|
|
|
|
// anthropicMessage is one entry of the request's "messages" array: a role
// string paired with plain-text content.
type anthropicMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
|
|
|
// anthropicResponse is the subset of a non-streaming /v1/messages reply that
// this package decodes. Response keys without a matching field are ignored by
// encoding/json. Content is a list of typed blocks; ChatCompletion keeps only
// the ones whose Type is "text".
type anthropicResponse struct {
	ID    string `json:"id"`
	Model string `json:"model"`

	Content []struct {
		Type string `json:"type"`
		Text string `json:"text"`
	} `json:"content"`

	// Usage reports token counts for the prompt (input) and reply (output).
	Usage struct {
		InputTokens  int `json:"input_tokens"`
		OutputTokens int `json:"output_tokens"`
	} `json:"usage"`
}
|
|
|
|
func (p *AnthropicProvider) ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
|
model := req.Model
|
|
if model == "" {
|
|
model = p.model
|
|
}
|
|
|
|
system, msgs := extractSystemMessage(req.Messages)
|
|
maxTokens := req.MaxTokens
|
|
if maxTokens == 0 {
|
|
maxTokens = 4096
|
|
}
|
|
|
|
aReq := anthropicRequest{
|
|
Model: model,
|
|
Messages: toAnthropicMessages(msgs),
|
|
System: system,
|
|
MaxTokens: maxTokens,
|
|
Temperature: req.Temperature,
|
|
Stream: false,
|
|
}
|
|
|
|
body, err := json.Marshal(aReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/v1/messages", bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
httpReq.Header.Set("x-api-key", p.apiKey)
|
|
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := p.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("anthropic request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
errBody, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("anthropic error (status %d): %s", resp.StatusCode, string(errBody))
|
|
}
|
|
|
|
var aResp anthropicResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&aResp); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
content := ""
|
|
for _, c := range aResp.Content {
|
|
if c.Type == "text" {
|
|
content += c.Text
|
|
}
|
|
}
|
|
|
|
return &ChatResponse{
|
|
Content: content,
|
|
Model: aResp.Model,
|
|
PromptTokens: aResp.Usage.InputTokens,
|
|
CompletionTokens: aResp.Usage.OutputTokens,
|
|
TotalTokens: aResp.Usage.InputTokens + aResp.Usage.OutputTokens,
|
|
}, nil
|
|
}
|
|
|
|
func (p *AnthropicProvider) ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error) {
|
|
model := req.Model
|
|
if model == "" {
|
|
model = p.model
|
|
}
|
|
|
|
system, msgs := extractSystemMessage(req.Messages)
|
|
maxTokens := req.MaxTokens
|
|
if maxTokens == 0 {
|
|
maxTokens = 4096
|
|
}
|
|
|
|
aReq := anthropicRequest{
|
|
Model: model,
|
|
Messages: toAnthropicMessages(msgs),
|
|
System: system,
|
|
MaxTokens: maxTokens,
|
|
Temperature: req.Temperature,
|
|
Stream: true,
|
|
}
|
|
|
|
body, err := json.Marshal(aReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/v1/messages", bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
httpReq.Header.Set("x-api-key", p.apiKey)
|
|
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := p.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("anthropic stream request failed: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
errBody, _ := io.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
return nil, fmt.Errorf("anthropic error (status %d): %s", resp.StatusCode, string(errBody))
|
|
}
|
|
|
|
return resp.Body, nil
|
|
}
|
|
|
|
func extractSystemMessage(msgs []Message) (string, []Message) {
|
|
system := ""
|
|
var filtered []Message
|
|
for _, m := range msgs {
|
|
if m.Role == RoleSystem {
|
|
system = m.Content
|
|
} else {
|
|
filtered = append(filtered, m)
|
|
}
|
|
}
|
|
return system, filtered
|
|
}
|
|
|
|
func toAnthropicMessages(msgs []Message) []anthropicMessage {
|
|
out := make([]anthropicMessage, len(msgs))
|
|
for i, m := range msgs {
|
|
out[i] = anthropicMessage{Role: string(m.Role), Content: m.Content}
|
|
}
|
|
return out
|
|
}
|