Initial commit: GovAI 政务AI平台
This commit is contained in:
@@ -0,0 +1,126 @@
|
||||
"""通义万相(Wanx)图片生成客户端 — DashScope 异步任务 API"""
|
||||
|
||||
import time
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from config import config
|
||||
|
||||
|
||||
class WanxClient:
|
||||
"""通义万相文生图客户端,使用 DashScope 异步任务 API"""
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = config.IMAGE_API_KEY or config.OPENAI_API_KEY
|
||||
self.model = config.WANX_MODEL
|
||||
self.create_url = config.DASHSCOPE_IMAGE_URL
|
||||
self.task_url = config.DASHSCOPE_TASK_URL
|
||||
self.client = httpx.Client(timeout=120.0)
|
||||
|
||||
def _headers(self) -> dict:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"X-DashScope-Async": "enable",
|
||||
}
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
output_dir: Path,
|
||||
filename: str = "image.png",
|
||||
size: str = "1024*576",
|
||||
n: int = 1,
|
||||
negative_prompt: str = "",
|
||||
style: str = "<auto>",
|
||||
) -> Path | None:
|
||||
"""
|
||||
生成图片并下载到本地。
|
||||
|
||||
Args:
|
||||
prompt: 正向提示词
|
||||
output_dir: 输出目录
|
||||
filename: 输出文件名
|
||||
size: 图片尺寸,格式 "宽*高",如 "1024*576"
|
||||
n: 生成数量(1-4)
|
||||
negative_prompt: 反向提示词
|
||||
style: 风格,<auto> 自动
|
||||
|
||||
Returns:
|
||||
下载后的本地文件路径,失败返回 None
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("未配置 DashScope API Key(OPENAI_API_KEY 或 IMAGE_API_KEY)")
|
||||
|
||||
# Step 1: 提交任务
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": {
|
||||
"prompt": prompt,
|
||||
},
|
||||
"parameters": {
|
||||
"size": size,
|
||||
"n": n,
|
||||
"style": style,
|
||||
},
|
||||
}
|
||||
if negative_prompt:
|
||||
payload["input"]["negative_prompt"] = negative_prompt
|
||||
|
||||
resp = self.client.post(self.create_url, headers=self._headers(), json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if data.get("code"):
|
||||
raise RuntimeError(f"万相任务创建失败: {data.get('message', data)}")
|
||||
|
||||
task_id = data["output"]["task_id"]
|
||||
|
||||
# Step 2: 轮询任务状态
|
||||
image_url = self._poll_task(task_id)
|
||||
if not image_url:
|
||||
return None
|
||||
|
||||
# Step 3: 下载图片
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = output_dir / filename
|
||||
|
||||
img_resp = self.client.get(image_url)
|
||||
img_resp.raise_for_status()
|
||||
output_path.write_bytes(img_resp.content)
|
||||
|
||||
return output_path
|
||||
|
||||
def _poll_task(self, task_id: str, max_wait: int = 120, interval: int = 3) -> str | None:
|
||||
"""轮询任务状态,返回图片 URL"""
|
||||
url = f"{self.task_url}/{task_id}"
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
elapsed = 0
|
||||
while elapsed < max_wait:
|
||||
resp = self.client.get(url, headers=headers)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
status = data.get("output", {}).get("task_status", "")
|
||||
|
||||
if status == "SUCCEEDED":
|
||||
results = data["output"].get("results", [])
|
||||
if results and results[0].get("url"):
|
||||
return results[0]["url"]
|
||||
return None
|
||||
|
||||
if status in ("FAILED", "UNKNOWN"):
|
||||
msg = data["output"].get("message", "未知错误")
|
||||
raise RuntimeError(f"万相任务失败 (task_id={task_id}): {msg}")
|
||||
|
||||
time.sleep(interval)
|
||||
elapsed += interval
|
||||
|
||||
raise TimeoutError(f"万相任务超时 (task_id={task_id}),已等待 {max_wait}s")
|
||||
|
||||
def close(self):
|
||||
self.client.close()
|
||||
|
||||
|
||||
# 单例
|
||||
wanx_client = WanxClient()
|
||||
Reference in New Issue
Block a user