Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,167 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type OpenAIProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewOpenAIProvider(apiKey, baseURL, defaultModel string) *OpenAIProvider {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
if defaultModel == "" {
|
||||
defaultModel = "gpt-4o-mini"
|
||||
}
|
||||
return &OpenAIProvider{
|
||||
apiKey: apiKey,
|
||||
baseURL: baseURL,
|
||||
model: defaultModel,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) Name() string { return "openai" }
|
||||
|
||||
type openAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type openAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type openAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
|
||||
oaiReq := openAIRequest{
|
||||
Model: model,
|
||||
Messages: toOpenAIMessages(req.Messages),
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(oaiReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("openai error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
var oaiResp openAIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
content := ""
|
||||
if len(oaiResp.Choices) > 0 {
|
||||
content = oaiResp.Choices[0].Message.Content
|
||||
}
|
||||
|
||||
return &ChatResponse{
|
||||
Content: content,
|
||||
Model: oaiResp.Model,
|
||||
PromptTokens: oaiResp.Usage.PromptTokens,
|
||||
CompletionTokens: oaiResp.Usage.CompletionTokens,
|
||||
TotalTokens: oaiResp.Usage.TotalTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatStream(ctx context.Context, req *ChatRequest) (io.ReadCloser, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.model
|
||||
}
|
||||
|
||||
oaiReq := openAIRequest{
|
||||
Model: model,
|
||||
Messages: toOpenAIMessages(req.Messages),
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(oaiReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai stream request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("openai error (status %d): %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func toOpenAIMessages(msgs []Message) []openAIMessage {
|
||||
out := make([]openAIMessage, len(msgs))
|
||||
for i, m := range msgs {
|
||||
out[i] = openAIMessage{Role: string(m.Role), Content: m.Content}
|
||||
}
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user