81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
"""具体 LLM Provider 实现:DashScope(公网千问,仅 dev)、vLLM(本地,prod)。
|
|
|
|
两者均走 OpenAI 兼容的 /chat/completions 协议。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import httpx
|
|
|
|
from app.llm.base import ChatMessage, LLMProvider, LLMResponse
|
|
|
|
|
|
class DashScopeProvider(LLMProvider):
    """Public-cloud Qwen reached through DashScope's OpenAI-compatible mode.

    Intended for development and testing only, and only with desensitized or
    sample (fake) data, because every request is sent to a public endpoint.
    """

    name = "dashscope"
    egress = True  # traffic goes over the public internet, i.e. leaves the domain

    _BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"

    def __init__(self, api_key: str, model: str, timeout: float = 30.0) -> None:
        """Remember the API key, model name and HTTP timeout; performs no I/O.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            model: Model identifier placed in the request payload.
            timeout: Timeout handed to ``httpx.Client`` for each request.
        """
        self._api_key, self._model, self._timeout = api_key, model, timeout

    def chat(self, messages: list[ChatMessage], **kwargs) -> LLMResponse:
        """POST the conversation to ``/chat/completions`` and wrap the reply.

        Args:
            messages: Conversation turns; each item's ``role`` and ``content``
                are copied into the request body.
            **kwargs: Extra top-level payload fields. They are merged last, so
                a ``model`` or ``messages`` key here replaces the default one
                in the request (the returned ``LLMResponse.model`` still
                reports the model configured on this provider).

        Returns:
            An ``LLMResponse`` holding the first choice's message content and
            the full decoded JSON body in ``raw``.

        Raises:
            httpx.HTTPStatusError: Via ``raise_for_status()`` on an error status.
            KeyError, IndexError: If the decoded body does not contain
                ``choices[0].message.content``.
        """
        wire_messages = []
        for msg in messages:
            wire_messages.append({"role": msg.role, "content": msg.content})

        body = {"model": self._model, "messages": wire_messages}
        body.update(kwargs)  # caller-supplied fields win over the defaults

        auth = {"Authorization": f"Bearer {self._api_key}"}
        endpoint = f"{self._BASE_URL}/chat/completions"

        # A short-lived client per call: opened, used once, closed by the `with`.
        with httpx.Client(timeout=self._timeout) as http:
            reply = http.post(endpoint, json=body, headers=auth)
            reply.raise_for_status()
            decoded = reply.json()

        answer = decoded["choices"][0]["message"]["content"]
        return LLMResponse(
            content=answer,
            model=self._model,
            provider=self.name,
            egress=True,
            raw=decoded,
        )

    def health(self) -> bool:
        """Return True when an API key is configured; no network call is made."""
        return True if self._api_key else False
|
|
|
|
|
|
class VllmProvider(LLMProvider):
    """Local vLLM server (OpenAI-compatible). Used in production; data stays in-domain."""

    name = "vllm"
    egress = False

    def __init__(self, base_url: str, model: str, timeout: float = 60.0) -> None:
        """Remember the server location, model name and HTTP timeout; performs no I/O.

        Args:
            base_url: Root of the OpenAI-compatible API. Trailing slashes are
                stripped so endpoint paths can be appended with a single ``/``.
            model: Model identifier placed in the request payload.
            timeout: Timeout handed to ``httpx.Client`` for chat requests.
        """
        self._base_url = base_url.rstrip("/")
        self._model = model
        self._timeout = timeout

    def chat(self, messages: list[ChatMessage], **kwargs) -> LLMResponse:
        """POST the conversation to ``/chat/completions`` and wrap the reply.

        Args:
            messages: Conversation turns; each item's ``role`` and ``content``
                are copied into the request body.
            **kwargs: Extra top-level payload fields. They are merged last, so
                a ``model`` or ``messages`` key here replaces the default one
                in the request (the returned ``LLMResponse.model`` still
                reports the model configured on this provider).

        Returns:
            An ``LLMResponse`` holding the first choice's message content and
            the full decoded JSON body in ``raw``.

        Raises:
            httpx.HTTPStatusError: Via ``raise_for_status()`` on an error status.
            KeyError, IndexError: If the decoded body does not contain
                ``choices[0].message.content``.
        """
        payload = {
            "model": self._model,
            "messages": [{"role": m.role, "content": m.content} for m in messages],
            **kwargs,
        }
        with httpx.Client(timeout=self._timeout) as client:
            resp = client.post(f"{self._base_url}/chat/completions", json=payload)
            resp.raise_for_status()
            data = resp.json()
        content = data["choices"][0]["message"]["content"]
        return LLMResponse(
            content=content, model=self._model, provider=self.name, egress=False, raw=data
        )

    def health(self) -> bool:
        """Probe ``GET {base_url}/models`` and report whether it answers 200.

        Returns:
            True only for an HTTP 200 response. False for any other status, for
            transport-level failures (connection refused, timeout, ...) and for
            a ``base_url`` that httpx rejects as malformed.
        """
        try:
            # Fixed short timeout so a hung server cannot stall the health check.
            with httpx.Client(timeout=5.0) as client:
                resp = client.get(f"{self._base_url}/models")
        except (httpx.HTTPError, httpx.InvalidURL):
            # httpx.InvalidURL is not a subclass of httpx.HTTPError; without it
            # a misconfigured base_url would make this bool probe raise.
            return False
        return resp.status_code == 200
|