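"""Tongyi Wanxiang (Wanx) image-generation client, using the DashScope async task API."""

import time
from pathlib import Path

import httpx

from config import config


class WanxClient:
    """Wanx text-to-image client built on the DashScope async task API."""

    def __init__(self):
        # Prefer the dedicated image key; fall back to the general DashScope key.
        self.api_key = config.IMAGE_API_KEY or config.OPENAI_API_KEY
        self.model = config.WANX_MODEL
        self.create_url = config.DASHSCOPE_IMAGE_URL
        self.task_url = config.DASHSCOPE_TASK_URL
        self.client = httpx.Client(timeout=120.0)

    def _headers(self) -> dict:
        # Headers for task creation; X-DashScope-Async makes the call return a task_id
        # instead of blocking until the image is ready.
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "X-DashScope-Async": "enable",
        }

    def generate(
        self,
        prompt: str,
        output_dir: Path,
        filename: str = "image.png",
        size: str = "1024*576",
        n: int = 1,
        negative_prompt: str = "",
        style: str = "",
    ) -> Path | None:
        """
        Generate an image and download it locally.

        Args:
            prompt: Positive prompt.
            output_dir: Output directory.
            filename: Output file name.
            size: Image size as "width*height", e.g. "1024*576".
            n: Number of images to generate (1-4); only the first result is downloaded.
            negative_prompt: Negative prompt.
            style: Style; when empty, the parameter is omitted and the service
                default (auto) applies.

        Returns:
            Path of the downloaded local file, or None if the task succeeded
            without returning an image URL.

        Raises:
            ValueError: No API key is configured.
            RuntimeError: Task creation or execution failed.
            TimeoutError: The task did not finish within the polling window.
            httpx.HTTPStatusError: An HTTP request returned an error status.
        """
        if not self.api_key:
            raise ValueError("DashScope API key not configured (set IMAGE_API_KEY or OPENAI_API_KEY)")

        # Step 1: submit the task
        parameters = {"size": size, "n": n}
        # Only send style when one was given, rather than sending an empty string.
        if style:
            parameters["style"] = style
        payload = {
            "model": self.model,
            "input": {"prompt": prompt},
            "parameters": parameters,
        }
        if negative_prompt:
            payload["input"]["negative_prompt"] = negative_prompt

        resp = self.client.post(self.create_url, headers=self._headers(), json=payload)
        resp.raise_for_status()
        data = resp.json()
        # Defensive check: treat a non-empty "code" field in the body as an API error.
        if data.get("code"):
            raise RuntimeError(f"Wanx task creation failed: {data.get('message', data)}")
        task_id = data["output"]["task_id"]

        # Step 2: poll the task status
        image_url = self._poll_task(task_id)
        if not image_url:
            return None

        # Step 3: download the image
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / filename
        img_resp = self.client.get(image_url)
        img_resp.raise_for_status()
        output_path.write_bytes(img_resp.content)
        return output_path

    def _poll_task(self, task_id: str, max_wait: int = 120, interval: int = 3) -> str | None:
        """Poll the task status and return the first image URL (None if the task succeeded without one)."""
        url = f"{self.task_url}/{task_id}"
        # Task queries are plain GETs; the async header is only needed on creation.
        headers = {"Authorization": f"Bearer {self.api_key}"}
        # Measure wall-clock time so slow requests also count toward max_wait.
        deadline = time.monotonic() + max_wait
        while time.monotonic() < deadline:
            resp = self.client.get(url, headers=headers)
            resp.raise_for_status()
            data = resp.json()
            output = data.get("output", {})
            status = output.get("task_status", "")
            if status == "SUCCEEDED":
                results = output.get("results", [])
                if results and results[0].get("url"):
                    return results[0]["url"]
                return None
            if status in ("FAILED", "CANCELED", "UNKNOWN"):
                msg = output.get("message", "unknown error")
                raise RuntimeError(f"Wanx task failed (task_id={task_id}): {msg}")
            # Still PENDING or RUNNING: wait, then poll again.
            time.sleep(interval)
        raise TimeoutError(f"Wanx task timed out (task_id={task_id}) after waiting {max_wait}s")

    def close(self):
        """Close the underlying HTTP connection pool."""
        self.client.close()


# Module-level singleton
wanx_client = WanxClient()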
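

# ---------------------------------------------------------------------------
# Sketch (not part of the original client): generate() downloads only
# results[0], even when n > 1. The hypothetical helper below shows how every
# image URL could be collected from a SUCCEEDED task-query response. It assumes
# the same response shape that _poll_task reads (output.results[*].url), and
# it assumes that a result entry without a non-empty "url" (for example, one
# image in the batch that failed) should simply be skipped.
# ---------------------------------------------------------------------------
def extract_result_urls(task_data: dict) -> list[str]:
    """Return every image URL in a task-query response (hypothetical helper)."""
    results = task_data.get("output", {}).get("results", [])
    # Keep only entries that actually carry a non-empty URL, preserving order.
    return [r["url"] for r in results if r.get("url")]


# Example with an illustrative payload:
# extract_result_urls({"output": {"results": [{"url": "https://a"}, {"code": "Err"}]}})
# returns ["https://a"]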