Files
2026-06-15 23:48:37 +08:00

127 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""通义万相(Wanx)图片生成客户端 — DashScope 异步任务 API"""
import time
import httpx
from pathlib import Path
from config import config
class WanxClient:
"""通义万相文生图客户端,使用 DashScope 异步任务 API"""
def __init__(self):
self.api_key = config.IMAGE_API_KEY or config.OPENAI_API_KEY
self.model = config.WANX_MODEL
self.create_url = config.DASHSCOPE_IMAGE_URL
self.task_url = config.DASHSCOPE_TASK_URL
self.client = httpx.Client(timeout=120.0)
def _headers(self) -> dict:
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"X-DashScope-Async": "enable",
}
def generate(
self,
prompt: str,
output_dir: Path,
filename: str = "image.png",
size: str = "1024*576",
n: int = 1,
negative_prompt: str = "",
style: str = "<auto>",
) -> Path | None:
"""
生成图片并下载到本地。
Args:
prompt: 正向提示词
output_dir: 输出目录
filename: 输出文件名
size: 图片尺寸,格式 "宽*高",如 "1024*576"
n: 生成数量(1-4)
negative_prompt: 反向提示词
style: 风格,<auto> 自动
Returns:
下载后的本地文件路径,失败返回 None
"""
if not self.api_key:
raise ValueError("未配置 DashScope API KeyOPENAI_API_KEY 或 IMAGE_API_KEY")
# Step 1: 提交任务
payload = {
"model": self.model,
"input": {
"prompt": prompt,
},
"parameters": {
"size": size,
"n": n,
"style": style,
},
}
if negative_prompt:
payload["input"]["negative_prompt"] = negative_prompt
resp = self.client.post(self.create_url, headers=self._headers(), json=payload)
resp.raise_for_status()
data = resp.json()
if data.get("code"):
raise RuntimeError(f"万相任务创建失败: {data.get('message', data)}")
task_id = data["output"]["task_id"]
# Step 2: 轮询任务状态
image_url = self._poll_task(task_id)
if not image_url:
return None
# Step 3: 下载图片
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / filename
img_resp = self.client.get(image_url)
img_resp.raise_for_status()
output_path.write_bytes(img_resp.content)
return output_path
def _poll_task(self, task_id: str, max_wait: int = 120, interval: int = 3) -> str | None:
"""轮询任务状态,返回图片 URL"""
url = f"{self.task_url}/{task_id}"
headers = {"Authorization": f"Bearer {self.api_key}"}
elapsed = 0
while elapsed < max_wait:
resp = self.client.get(url, headers=headers)
resp.raise_for_status()
data = resp.json()
status = data.get("output", {}).get("task_status", "")
if status == "SUCCEEDED":
results = data["output"].get("results", [])
if results and results[0].get("url"):
return results[0]["url"]
return None
if status in ("FAILED", "UNKNOWN"):
msg = data["output"].get("message", "未知错误")
raise RuntimeError(f"万相任务失败 (task_id={task_id}): {msg}")
time.sleep(interval)
elapsed += interval
raise TimeoutError(f"万相任务超时 (task_id={task_id}),已等待 {max_wait}s")
def close(self):
self.client.close()
# Shared module-level instance; importing this module constructs it.
wanx_client: WanxClient = WanxClient()