Files
GovAI/server/pkg/llm/openai.go
T
2026-06-15 23:48:37 +08:00

168 lines
4.0 KiB
Go

package llm
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
)
// OpenAIProvider is an LLM provider backed by an OpenAI-compatible chat
// completions HTTP API.
type OpenAIProvider struct {
	apiKey  string // sent as a Bearer token on every request
	baseURL string // API root without a trailing slash, e.g. "https://api.openai.com/v1"
	model   string // default model used when a request does not name one
	// httpClient deliberately has no Client.Timeout: that timeout also covers
	// reading the response body and would cut off long-running streams.
	// Per-call deadlines are expected to come from the ctx passed to each call.
	httpClient *http.Client
}

// NewOpenAIProvider returns a provider that authenticates with apiKey.
//
// An empty baseURL selects "https://api.openai.com/v1". Trailing slashes are
// stripped so endpoint paths such as "/chat/completions" can be appended
// without producing a "//" in the URL. An empty defaultModel selects
// "gpt-4o-mini".
func NewOpenAIProvider(apiKey, baseURL, defaultModel string) *OpenAIProvider {
	baseURL = strings.TrimRight(baseURL, "/")
	if baseURL == "" {
		baseURL = "https://api.openai.com/v1"
	}
	if defaultModel == "" {
		defaultModel = "gpt-4o-mini"
	}
	return &OpenAIProvider{
		apiKey:     apiKey,
		baseURL:    baseURL,
		model:      defaultModel,
		httpClient: &http.Client{},
	}
}

// Name returns the provider identifier, "openai".
func (p *OpenAIProvider) Name() string { return "openai" }
// openAIRequest is the JSON body POSTed to the chat completions endpoint.
// Field declaration order determines the key order encoding/json emits.
type openAIRequest struct {
Model string `json:"model"`
Messages []openAIMessage `json:"messages"`
// omitempty drops a zero Temperature from the payload, so an explicit 0 can
// never be sent. NOTE(review): presumably the server then applies its own
// default temperature — confirm callers never need a literal 0 here.
Temperature float64 `json:"temperature,omitempty"`
// A zero MaxTokens is likewise omitted, so no token limit is sent.
MaxTokens int `json:"max_tokens,omitempty"`
// Stream has no omitempty, so "stream" is always present (false included).
Stream bool `json:"stream"`
}
// openAIMessage is a single chat message in the request payload.
type openAIMessage struct {
// Role is a plain string on the wire; toOpenAIMessages fills it from Message.Role.
Role string `json:"role"`
Content string `json:"content"`
}
// openAIResponse decodes the subset of a non-streaming chat completions
// response that this provider declares. encoding/json silently ignores
// response keys that have no matching field here.
type openAIResponse struct {
// ID is decoded but not currently copied into ChatResponse.
ID string `json:"id"`
Model string `json:"model"`
// Choices holds the generated alternatives; ChatCompletion reads only the first.
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
// FinishReason is decoded but not currently surfaced in ChatResponse.
FinishReason string `json:"finish_reason"`
} `json:"choices"`
// Usage carries the token counts that ChatCompletion copies into ChatResponse.
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// ChatCompletion sends a non-streaming request to {baseURL}/chat/completions
// and returns the first choice's content together with the reported token
// usage.
//
// If req.Model is empty, the provider's default model is used. A non-200
// status is returned as an error that includes up to 1 MiB of the response
// body. If the response contains no choices, Content is empty and no error is
// returned. The request is bound to ctx, so cancelling ctx aborts it.
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
	// maxErrorBody caps how much of an error response is buffered in memory
	// and echoed into the returned error.
	const maxErrorBody = 1 << 20

	model := req.Model
	if model == "" {
		model = p.model
	}
	oaiReq := openAIRequest{
		Model:       model,
		Messages:    toOpenAIMessages(req.Messages),
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		Stream:      false,
	}
	body, err := json.Marshal(oaiReq)
	if err != nil {
		return nil, fmt.Errorf("encoding openai request: %w", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("building openai request: %w", err)
	}
	httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := p.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("openai request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: if reading the body fails, the status code alone is
		// still reported.
		errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		return nil, fmt.Errorf("openai error (status %d): %s", resp.StatusCode, string(errBody))
	}

	var oaiResp openAIResponse
	if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
		return nil, fmt.Errorf("decoding openai response: %w", err)
	}
	// The decoder may stop before EOF (e.g. a trailing newline). Draining lets
	// the transport reuse the connection. A drain failure does not affect the
	// already-decoded result.
	_, _ = io.Copy(io.Discard, resp.Body)

	content := ""
	if len(oaiResp.Choices) > 0 {
		content = oaiResp.Choices[0].Message.Content
	}
	return &ChatResponse{
		Content:          content,
		Model:            oaiResp.Model,
		PromptTokens:     oaiResp.Usage.PromptTokens,
		CompletionTokens: oaiResp.Usage.CompletionTokens,
		TotalTokens:      oaiResp.Usage.TotalTokens,
	}, nil
}
// ChatStream sends a streaming ("stream": true) request to
// {baseURL}/chat/completions and returns the raw response body unparsed.
// NOTE(review): presumably this is the server-sent-event stream produced by
// the OpenAI streaming API; callers must parse it themselves.
//
// On success the caller owns the returned body and must Close it. The body
// remains tied to ctx, so cancelling ctx aborts the stream. If req.Model is
// empty, the provider's default model is used. A non-200 status closes the
// body and returns an error including up to 1 MiB of the response.
func (p *OpenAIProvider) ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error) {
	// maxErrorBody caps how much of an error response is buffered in memory
	// and echoed into the returned error.
	const maxErrorBody = 1 << 20

	model := req.Model
	if model == "" {
		model = p.model
	}
	oaiReq := openAIRequest{
		Model:       model,
		Messages:    toOpenAIMessages(req.Messages),
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		Stream:      true,
	}
	body, err := json.Marshal(oaiReq)
	if err != nil {
		return nil, fmt.Errorf("encoding openai stream request: %w", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("building openai stream request: %w", err)
	}
	httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := p.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("openai stream request failed: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		// Best effort: if reading the body fails, the status code alone is
		// still reported.
		errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		resp.Body.Close()
		return nil, fmt.Errorf("openai error (status %d): %s", resp.StatusCode, string(errBody))
	}
	// Ownership of the body passes to the caller; it is intentionally not
	// closed here.
	return resp.Body, nil
}
// toOpenAIMessages maps provider-neutral messages onto the request wire
// format, one-to-one and in the same order. Empty input yields an empty,
// non-nil slice, which encodes as "[]" rather than "null".
func toOpenAIMessages(msgs []Message) []openAIMessage {
	converted := make([]openAIMessage, 0, len(msgs))
	for _, msg := range msgs {
		converted = append(converted, openAIMessage{
			Role:    string(msg.Role),
			Content: msg.Content,
		})
	}
	return converted
}