32 lines
1.2 KiB
Python
32 lines
1.2 KiB
Python
"""LLM Provider 工厂:按配置创建 provider,并执行数据零出域红线校验。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from app.config import EGRESS_PROVIDERS, LLMProviderName, Settings, get_settings
|
|
from app.llm.base import LLMProvider
|
|
from app.llm.providers import DashScopeProvider, VllmProvider
|
|
|
|
|
|
class EgressPolicyError(RuntimeError):
    """Signal a breach of the zero-data-egress red line.

    Raised when the active configuration would let data leave the domain,
    e.g. when a public-network LLM provider is selected in a prod environment.
    """
|
|
|
|
|
|
def get_llm_provider(settings: Settings | None = None) -> LLMProvider:
    """Create the configured LLM provider, enforcing the zero-data-egress red line.

    Args:
        settings: Application settings to read from. When ``None``, the
            result of ``get_settings()`` is used instead.

    Returns:
        A ``DashScopeProvider`` when ``settings.llm_provider`` is
        ``LLMProviderName.dashscope``, or a ``VllmProvider`` when it is
        ``LLMProviderName.vllm``.

    Raises:
        EgressPolicyError: If ``settings.is_prod`` is true and the configured
            provider is listed in ``EGRESS_PROVIDERS`` (public-network
            providers).
        ValueError: If ``settings.llm_provider`` matches no known provider.
    """
    # Explicit None check: fall back to the global settings only when the
    # caller passed nothing, not whenever the object happens to be falsy.
    if settings is None:
        settings = get_settings()

    provider = settings.llm_provider

    # Red line: prod must not use a public-network provider. The check runs
    # before any provider object is constructed.
    if settings.is_prod and provider in EGRESS_PROVIDERS:
        raise EgressPolicyError(
            "数据零出域红线违规:prod 环境禁止使用公网 LLM Provider "
            f"'{provider.value}'。"
        )

    if provider == LLMProviderName.dashscope:
        return DashScopeProvider(
            api_key=settings.dashscope_api_key, model=settings.dashscope_model
        )
    if provider == LLMProviderName.vllm:
        return VllmProvider(base_url=settings.vllm_base_url, model=settings.vllm_model)

    raise ValueError(f"未知的 LLM Provider: {provider}")
|