127 lines
3.8 KiB
Python
127 lines
3.8 KiB
Python
"""通义万相(Wanx)图片生成客户端 — DashScope 异步任务 API"""
|
||
|
||
import time
|
||
import httpx
|
||
from pathlib import Path
|
||
from config import config
|
||
|
||
|
||
class WanxClient:
|
||
"""通义万相文生图客户端,使用 DashScope 异步任务 API"""
|
||
|
||
def __init__(self):
|
||
self.api_key = config.IMAGE_API_KEY or config.OPENAI_API_KEY
|
||
self.model = config.WANX_MODEL
|
||
self.create_url = config.DASHSCOPE_IMAGE_URL
|
||
self.task_url = config.DASHSCOPE_TASK_URL
|
||
self.client = httpx.Client(timeout=120.0)
|
||
|
||
def _headers(self) -> dict:
|
||
return {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json",
|
||
"X-DashScope-Async": "enable",
|
||
}
|
||
|
||
def generate(
|
||
self,
|
||
prompt: str,
|
||
output_dir: Path,
|
||
filename: str = "image.png",
|
||
size: str = "1024*576",
|
||
n: int = 1,
|
||
negative_prompt: str = "",
|
||
style: str = "<auto>",
|
||
) -> Path | None:
|
||
"""
|
||
生成图片并下载到本地。
|
||
|
||
Args:
|
||
prompt: 正向提示词
|
||
output_dir: 输出目录
|
||
filename: 输出文件名
|
||
size: 图片尺寸,格式 "宽*高",如 "1024*576"
|
||
n: 生成数量(1-4)
|
||
negative_prompt: 反向提示词
|
||
style: 风格,<auto> 自动
|
||
|
||
Returns:
|
||
下载后的本地文件路径,失败返回 None
|
||
"""
|
||
if not self.api_key:
|
||
raise ValueError("未配置 DashScope API Key(OPENAI_API_KEY 或 IMAGE_API_KEY)")
|
||
|
||
# Step 1: 提交任务
|
||
payload = {
|
||
"model": self.model,
|
||
"input": {
|
||
"prompt": prompt,
|
||
},
|
||
"parameters": {
|
||
"size": size,
|
||
"n": n,
|
||
"style": style,
|
||
},
|
||
}
|
||
if negative_prompt:
|
||
payload["input"]["negative_prompt"] = negative_prompt
|
||
|
||
resp = self.client.post(self.create_url, headers=self._headers(), json=payload)
|
||
resp.raise_for_status()
|
||
data = resp.json()
|
||
|
||
if data.get("code"):
|
||
raise RuntimeError(f"万相任务创建失败: {data.get('message', data)}")
|
||
|
||
task_id = data["output"]["task_id"]
|
||
|
||
# Step 2: 轮询任务状态
|
||
image_url = self._poll_task(task_id)
|
||
if not image_url:
|
||
return None
|
||
|
||
# Step 3: 下载图片
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
output_path = output_dir / filename
|
||
|
||
img_resp = self.client.get(image_url)
|
||
img_resp.raise_for_status()
|
||
output_path.write_bytes(img_resp.content)
|
||
|
||
return output_path
|
||
|
||
def _poll_task(self, task_id: str, max_wait: int = 120, interval: int = 3) -> str | None:
|
||
"""轮询任务状态,返回图片 URL"""
|
||
url = f"{self.task_url}/{task_id}"
|
||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||
|
||
elapsed = 0
|
||
while elapsed < max_wait:
|
||
resp = self.client.get(url, headers=headers)
|
||
resp.raise_for_status()
|
||
data = resp.json()
|
||
|
||
status = data.get("output", {}).get("task_status", "")
|
||
|
||
if status == "SUCCEEDED":
|
||
results = data["output"].get("results", [])
|
||
if results and results[0].get("url"):
|
||
return results[0]["url"]
|
||
return None
|
||
|
||
if status in ("FAILED", "UNKNOWN"):
|
||
msg = data["output"].get("message", "未知错误")
|
||
raise RuntimeError(f"万相任务失败 (task_id={task_id}): {msg}")
|
||
|
||
time.sleep(interval)
|
||
elapsed += interval
|
||
|
||
raise TimeoutError(f"万相任务超时 (task_id={task_id}),已等待 {max_wait}s")
|
||
|
||
def close(self):
|
||
self.client.close()
|
||
|
||
|
||
# Singleton: shared module-level client instance
|
||
wanx_client = WanxClient()  # constructed at import time: reads config and opens an httpx.Client
|