mirror of
https://github.com/ksyasuda/dotfiles.git
synced 2026-03-21 06:11:27 -07:00
update skills
This commit is contained in:
876
.agents/skills/imagegen/scripts/image_gen.py
Normal file
876
.agents/skills/imagegen/scripts/image_gen.py
Normal file
@@ -0,0 +1,876 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate or edit images with the OpenAI Image API.
|
||||
|
||||
Defaults to gpt-image-1.5 and a structured prompt augmentation workflow.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
# Request defaults applied when the corresponding CLI flags are omitted.
DEFAULT_MODEL = "gpt-image-1.5"
DEFAULT_SIZE = "1024x1024"
DEFAULT_QUALITY = "auto"
DEFAULT_OUTPUT_FORMAT = "png"
DEFAULT_CONCURRENCY = 5
# Filename suffix appended to the optional downscaled copy (e.g. image-web.png).
DEFAULT_DOWNSCALE_SUFFIX = "-web"

# Whitelists used by the _validate_* helpers; a None background means the
# --background flag was not given.
ALLOWED_SIZES = {"1024x1024", "1536x1024", "1024x1536", "auto"}
ALLOWED_QUALITIES = {"low", "medium", "high", "auto"}
ALLOWED_BACKGROUNDS = {"transparent", "opaque", "auto", None}

# Hard limits: per-image upload size (50MB) and jobs per batch file.
MAX_IMAGE_BYTES = 50 * 1024 * 1024
MAX_BATCH_JOBS = 500
|
||||
|
||||
|
||||
def _die(message: str, code: int = 1) -> None:
|
||||
print(f"Error: {message}", file=sys.stderr)
|
||||
raise SystemExit(code)
|
||||
|
||||
|
||||
def _warn(message: str) -> None:
|
||||
print(f"Warning: {message}", file=sys.stderr)
|
||||
|
||||
|
||||
def _ensure_api_key(dry_run: bool) -> None:
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
print("OPENAI_API_KEY is set.", file=sys.stderr)
|
||||
return
|
||||
if dry_run:
|
||||
_warn("OPENAI_API_KEY is not set; dry-run only.")
|
||||
return
|
||||
_die("OPENAI_API_KEY is not set. Export it before running.")
|
||||
|
||||
|
||||
def _read_prompt(prompt: Optional[str], prompt_file: Optional[str]) -> str:
|
||||
if prompt and prompt_file:
|
||||
_die("Use --prompt or --prompt-file, not both.")
|
||||
if prompt_file:
|
||||
path = Path(prompt_file)
|
||||
if not path.exists():
|
||||
_die(f"Prompt file not found: {path}")
|
||||
return path.read_text(encoding="utf-8").strip()
|
||||
if prompt:
|
||||
return prompt.strip()
|
||||
_die("Missing prompt. Use --prompt or --prompt-file.")
|
||||
return "" # unreachable
|
||||
|
||||
|
||||
def _check_image_paths(paths: Iterable[str]) -> List[Path]:
    """Resolve image paths, aborting on missing files and warning on oversized ones."""
    checked: List[Path] = []
    for raw in paths:
        candidate = Path(raw)
        if not candidate.exists():
            _die(f"Image file not found: {candidate}")
        if candidate.stat().st_size > MAX_IMAGE_BYTES:
            _warn(f"Image exceeds 50MB limit: {candidate}")
        checked.append(candidate)
    return checked
|
||||
|
||||
|
||||
def _normalize_output_format(fmt: Optional[str]) -> str:
|
||||
if not fmt:
|
||||
return DEFAULT_OUTPUT_FORMAT
|
||||
fmt = fmt.lower()
|
||||
if fmt not in {"png", "jpeg", "jpg", "webp"}:
|
||||
_die("output-format must be png, jpeg, jpg, or webp.")
|
||||
return "jpeg" if fmt == "jpg" else fmt
|
||||
|
||||
|
||||
def _validate_size(size: str) -> None:
    """Abort unless *size* is one of the allowed GPT image dimension strings."""
    if size in ALLOWED_SIZES:
        return
    _die(
        "size must be one of 1024x1024, 1536x1024, 1024x1536, or auto for GPT image models."
    )
|
||||
|
||||
|
||||
def _validate_quality(quality: str) -> None:
    """Abort unless *quality* is an allowed quality keyword."""
    if quality in ALLOWED_QUALITIES:
        return
    _die("quality must be one of low, medium, high, or auto.")
|
||||
|
||||
|
||||
def _validate_background(background: Optional[str]) -> None:
    """Abort unless *background* is an allowed value (None means flag omitted)."""
    if background in ALLOWED_BACKGROUNDS:
        return
    _die("background must be one of transparent, opaque, or auto.")
|
||||
|
||||
|
||||
def _validate_transparency(background: Optional[str], output_format: str) -> None:
|
||||
if background == "transparent" and output_format not in {"png", "webp"}:
|
||||
_die("transparent background requires output-format png or webp.")
|
||||
|
||||
|
||||
def _validate_generate_payload(payload: Dict[str, Any]) -> None:
    """Validate n, size, quality, background, and output_compression bounds."""
    n = int(payload.get("n", 1))
    if not 1 <= n <= 10:
        _die("n must be between 1 and 10")
    _validate_size(str(payload.get("size", DEFAULT_SIZE)))
    _validate_quality(str(payload.get("quality", DEFAULT_QUALITY)))
    _validate_background(payload.get("background"))
    compression = payload.get("output_compression")
    if compression is not None and not 0 <= int(compression) <= 100:
        _die("output_compression must be between 0 and 100")
|
||||
|
||||
|
||||
def _build_output_paths(
|
||||
out: str,
|
||||
output_format: str,
|
||||
count: int,
|
||||
out_dir: Optional[str],
|
||||
) -> List[Path]:
|
||||
ext = "." + output_format
|
||||
|
||||
if out_dir:
|
||||
out_base = Path(out_dir)
|
||||
out_base.mkdir(parents=True, exist_ok=True)
|
||||
return [out_base / f"image_{i}{ext}" for i in range(1, count + 1)]
|
||||
|
||||
out_path = Path(out)
|
||||
if out_path.exists() and out_path.is_dir():
|
||||
out_path.mkdir(parents=True, exist_ok=True)
|
||||
return [out_path / f"image_{i}{ext}" for i in range(1, count + 1)]
|
||||
|
||||
if out_path.suffix == "":
|
||||
out_path = out_path.with_suffix(ext)
|
||||
elif output_format and out_path.suffix.lstrip(".").lower() != output_format:
|
||||
_warn(
|
||||
f"Output extension {out_path.suffix} does not match output-format {output_format}."
|
||||
)
|
||||
|
||||
if count == 1:
|
||||
return [out_path]
|
||||
|
||||
return [
|
||||
out_path.with_name(f"{out_path.stem}-{i}{out_path.suffix}")
|
||||
for i in range(1, count + 1)
|
||||
]
|
||||
|
||||
|
||||
def _augment_prompt(args: argparse.Namespace, prompt: str) -> str:
    """Augment *prompt* using the structured hint fields carried on *args*."""
    return _augment_prompt_fields(args.augment, prompt, _fields_from_args(args))
|
||||
|
||||
|
||||
def _augment_prompt_fields(augment: bool, prompt: str, fields: Dict[str, Optional[str]]) -> str:
|
||||
if not augment:
|
||||
return prompt
|
||||
|
||||
sections: List[str] = []
|
||||
if fields.get("use_case"):
|
||||
sections.append(f"Use case: {fields['use_case']}")
|
||||
sections.append(f"Primary request: {prompt}")
|
||||
if fields.get("scene"):
|
||||
sections.append(f"Scene/background: {fields['scene']}")
|
||||
if fields.get("subject"):
|
||||
sections.append(f"Subject: {fields['subject']}")
|
||||
if fields.get("style"):
|
||||
sections.append(f"Style/medium: {fields['style']}")
|
||||
if fields.get("composition"):
|
||||
sections.append(f"Composition/framing: {fields['composition']}")
|
||||
if fields.get("lighting"):
|
||||
sections.append(f"Lighting/mood: {fields['lighting']}")
|
||||
if fields.get("palette"):
|
||||
sections.append(f"Color palette: {fields['palette']}")
|
||||
if fields.get("materials"):
|
||||
sections.append(f"Materials/textures: {fields['materials']}")
|
||||
if fields.get("text"):
|
||||
sections.append(f"Text (verbatim): \"{fields['text']}\"")
|
||||
if fields.get("constraints"):
|
||||
sections.append(f"Constraints: {fields['constraints']}")
|
||||
if fields.get("negative"):
|
||||
sections.append(f"Avoid: {fields['negative']}")
|
||||
|
||||
return "\n".join(sections)
|
||||
|
||||
|
||||
def _fields_from_args(args: argparse.Namespace) -> Dict[str, Optional[str]]:
|
||||
return {
|
||||
"use_case": getattr(args, "use_case", None),
|
||||
"scene": getattr(args, "scene", None),
|
||||
"subject": getattr(args, "subject", None),
|
||||
"style": getattr(args, "style", None),
|
||||
"composition": getattr(args, "composition", None),
|
||||
"lighting": getattr(args, "lighting", None),
|
||||
"palette": getattr(args, "palette", None),
|
||||
"materials": getattr(args, "materials", None),
|
||||
"text": getattr(args, "text", None),
|
||||
"constraints": getattr(args, "constraints", None),
|
||||
"negative": getattr(args, "negative", None),
|
||||
}
|
||||
|
||||
|
||||
def _print_request(payload: dict) -> None:
|
||||
print(json.dumps(payload, indent=2, sort_keys=True))
|
||||
|
||||
|
||||
def _decode_and_write(images: List[str], outputs: List[Path], force: bool) -> None:
|
||||
for idx, image_b64 in enumerate(images):
|
||||
if idx >= len(outputs):
|
||||
break
|
||||
out_path = outputs[idx]
|
||||
if out_path.exists() and not force:
|
||||
_die(f"Output already exists: {out_path} (use --force to overwrite)")
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path.write_bytes(base64.b64decode(image_b64))
|
||||
print(f"Wrote {out_path}")
|
||||
|
||||
|
||||
def _derive_downscale_path(path: Path, suffix: str) -> Path:
|
||||
if suffix and not suffix.startswith("-") and not suffix.startswith("_"):
|
||||
suffix = "-" + suffix
|
||||
return path.with_name(f"{path.stem}{suffix}{path.suffix}")
|
||||
|
||||
|
||||
def _downscale_image_bytes(image_bytes: bytes, *, max_dim: int, output_format: str) -> bytes:
    """Return *image_bytes* resized so its longest side is at most *max_dim*.

    Requires Pillow (aborts with an install hint when missing). Images are
    only shrunk, never enlarged — the scale factor is capped at 1.0. JPEG
    output flattens any transparency onto a white background because the
    format has no alpha channel.
    """
    try:
        from PIL import Image
    except Exception:
        _die(
            "Downscaling requires Pillow. Install with `uv pip install pillow` (then re-run)."
        )

    if max_dim < 1:
        _die("--downscale-max-dim must be >= 1")

    with Image.open(BytesIO(image_bytes)) as img:
        # Force pixel data to load before computing the target size.
        img.load()
        w, h = img.size
        scale = min(1.0, float(max_dim) / float(max(w, h)))
        # Round each dimension, but never below 1 pixel.
        target = (max(1, int(round(w * scale))), max(1, int(round(h * scale))))

        # Skip the resample entirely when the image is already small enough.
        resized = img if target == (w, h) else img.resize(target, Image.Resampling.LANCZOS)

        fmt = output_format.lower()
        if fmt == "jpg":
            fmt = "jpeg"

        if fmt == "jpeg":
            if resized.mode in ("RGBA", "LA") or ("transparency" in getattr(resized, "info", {})):
                # Composite the alpha channel over white — JPEG cannot store it.
                bg = Image.new("RGB", resized.size, (255, 255, 255))
                bg.paste(resized.convert("RGBA"), mask=resized.convert("RGBA").split()[-1])
                resized = bg
            else:
                resized = resized.convert("RGB")

        out = BytesIO()
        resized.save(out, format=fmt.upper())
        return out.getvalue()
|
||||
|
||||
|
||||
def _decode_write_and_downscale(
    images: List[str],
    outputs: List[Path],
    *,
    force: bool,
    downscale_max_dim: Optional[int],
    downscale_suffix: str,
    output_format: str,
) -> None:
    """Write each base64 image to disk, plus an optional downscaled sibling copy.

    Extra images beyond len(outputs) are ignored. Existing files abort the run
    unless *force* is set. When *downscale_max_dim* is given, a second file
    named via *downscale_suffix* is written next to each full-size output.
    """
    for raw_b64, destination in zip(images, outputs):
        if destination.exists() and not force:
            _die(f"Output already exists: {destination} (use --force to overwrite)")
        destination.parent.mkdir(parents=True, exist_ok=True)

        decoded = base64.b64decode(raw_b64)
        destination.write_bytes(decoded)
        print(f"Wrote {destination}")

        if downscale_max_dim is None:
            continue

        small_path = _derive_downscale_path(destination, downscale_suffix)
        if small_path.exists() and not force:
            _die(f"Output already exists: {small_path} (use --force to overwrite)")
        small_path.parent.mkdir(parents=True, exist_ok=True)
        small_bytes = _downscale_image_bytes(
            decoded, max_dim=downscale_max_dim, output_format=output_format
        )
        small_path.write_bytes(small_bytes)
        print(f"Wrote {small_path}")
|
||||
|
||||
|
||||
def _create_client():
    """Return a synchronous OpenAI client, aborting cleanly if the SDK is absent.

    The import is deferred so commands that never reach the API (e.g.
    --dry-run) work without the SDK installed.
    """
    try:
        from openai import OpenAI
    except ImportError:
        # Fix: drop the unused `as exc` binding; _die raises SystemExit,
        # so OpenAI is always bound past this point.
        _die("openai SDK not installed. Install with `uv pip install openai`.")
    return OpenAI()
|
||||
|
||||
|
||||
def _create_async_client():
    """Return an AsyncOpenAI client, distinguishing "SDK missing" from "SDK too old"."""
    try:
        from openai import AsyncOpenAI
    except ImportError:
        try:
            # If the base package imports but AsyncOpenAI does not, the
            # installed SDK predates async support — report an upgrade hint.
            import openai as _openai  # noqa: F401
        except ImportError:
            _die("openai SDK not installed. Install with `uv pip install openai`.")
        _die(
            "AsyncOpenAI not available in this openai SDK version. Upgrade with `uv pip install -U openai`."
        )
    return AsyncOpenAI()
|
||||
|
||||
|
||||
def _slugify(value: str) -> str:
|
||||
value = value.strip().lower()
|
||||
value = re.sub(r"[^a-z0-9]+", "-", value)
|
||||
value = re.sub(r"-{2,}", "-", value).strip("-")
|
||||
return value[:60] if value else "job"
|
||||
|
||||
|
||||
def _normalize_job(job: Any, idx: int) -> Dict[str, Any]:
|
||||
if isinstance(job, str):
|
||||
prompt = job.strip()
|
||||
if not prompt:
|
||||
_die(f"Empty prompt at job {idx}")
|
||||
return {"prompt": prompt}
|
||||
if isinstance(job, dict):
|
||||
if "prompt" not in job or not str(job["prompt"]).strip():
|
||||
_die(f"Missing prompt for job {idx}")
|
||||
return job
|
||||
_die(f"Invalid job at index {idx}: expected string or object.")
|
||||
return {} # unreachable
|
||||
|
||||
|
||||
def _read_jobs_jsonl(path: str) -> List[Dict[str, Any]]:
    """Parse a JSONL job file: one prompt string or JSON object per line.

    Blank lines and lines starting with '#' are skipped. Aborts on invalid
    JSON, an empty file, or more than MAX_BATCH_JOBS jobs.
    """
    source = Path(path)
    if not source.exists():
        _die(f"Input file not found: {source}")
    jobs: List[Dict[str, Any]] = []
    for line_no, raw_line in enumerate(source.read_text(encoding="utf-8").splitlines(), start=1):
        stripped = raw_line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        try:
            # Lines that look like JSON objects are parsed; anything else is
            # treated as a bare prompt string.
            parsed: Any = json.loads(stripped) if stripped.startswith("{") else stripped
            jobs.append(_normalize_job(parsed, idx=line_no))
        except json.JSONDecodeError as exc:
            _die(f"Invalid JSON on line {line_no}: {exc}")
    if not jobs:
        _die("No jobs found in input file.")
    if len(jobs) > MAX_BATCH_JOBS:
        _die(f"Too many jobs ({len(jobs)}). Max is {MAX_BATCH_JOBS}.")
    return jobs
|
||||
|
||||
|
||||
def _merge_non_null(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
|
||||
merged = dict(dst)
|
||||
for k, v in src.items():
|
||||
if v is not None:
|
||||
merged[k] = v
|
||||
return merged
|
||||
|
||||
|
||||
def _job_output_paths(
|
||||
*,
|
||||
out_dir: Path,
|
||||
output_format: str,
|
||||
idx: int,
|
||||
prompt: str,
|
||||
n: int,
|
||||
explicit_out: Optional[str],
|
||||
) -> List[Path]:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
ext = "." + output_format
|
||||
|
||||
if explicit_out:
|
||||
base = Path(explicit_out)
|
||||
if base.suffix == "":
|
||||
base = base.with_suffix(ext)
|
||||
elif base.suffix.lstrip(".").lower() != output_format:
|
||||
_warn(
|
||||
f"Job {idx}: output extension {base.suffix} does not match output-format {output_format}."
|
||||
)
|
||||
base = out_dir / base.name
|
||||
else:
|
||||
slug = _slugify(prompt[:80])
|
||||
base = out_dir / f"{idx:03d}-{slug}{ext}"
|
||||
|
||||
if n == 1:
|
||||
return [base]
|
||||
return [
|
||||
base.with_name(f"{base.stem}-{i}{base.suffix}")
|
||||
for i in range(1, n + 1)
|
||||
]
|
||||
|
||||
|
||||
def _extract_retry_after_seconds(exc: Exception) -> Optional[float]:
|
||||
# Best-effort: openai SDK errors vary by version. Prefer a conservative fallback.
|
||||
for attr in ("retry_after", "retry_after_seconds"):
|
||||
val = getattr(exc, attr, None)
|
||||
if isinstance(val, (int, float)) and val >= 0:
|
||||
return float(val)
|
||||
msg = str(exc)
|
||||
m = re.search(r"retry[- ]after[:= ]+([0-9]+(?:\\.[0-9]+)?)", msg, re.IGNORECASE)
|
||||
if m:
|
||||
try:
|
||||
return float(m.group(1))
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _is_rate_limit_error(exc: Exception) -> bool:
|
||||
name = exc.__class__.__name__.lower()
|
||||
if "ratelimit" in name or "rate_limit" in name:
|
||||
return True
|
||||
msg = str(exc).lower()
|
||||
return "429" in msg or "rate limit" in msg or "too many requests" in msg
|
||||
|
||||
|
||||
def _is_transient_error(exc: Exception) -> bool:
    """Heuristic: is *exc* worth retrying (rate limit, timeout, flaky network)?"""
    if _is_rate_limit_error(exc):
        return True
    class_name = exc.__class__.__name__.lower()
    if any(marker in class_name for marker in ("timeout", "timedout", "tempor")):
        return True
    text = str(exc).lower()
    return any(marker in text for marker in ("timeout", "timed out", "connection reset"))
|
||||
|
||||
|
||||
async def _generate_one_with_retries(
    client: Any,
    payload: Dict[str, Any],
    *,
    attempts: int,
    job_label: str,
) -> Any:
    """Call client.images.generate, retrying transient failures with backoff.

    Non-transient errors propagate immediately, as does the error from the
    final attempt. The sleep honors a server-provided retry-after value when
    one can be extracted, otherwise exponential backoff capped at 60s.
    """
    last_exc: Optional[Exception] = None
    for attempt in range(1, attempts + 1):
        try:
            return await client.images.generate(**payload)
        except Exception as exc:
            last_exc = exc
            if not _is_transient_error(exc):
                raise
            if attempt == attempts:
                # Out of attempts: surface the last transient error.
                raise
            sleep_s = _extract_retry_after_seconds(exc)
            if sleep_s is None:
                sleep_s = min(60.0, 2.0**attempt)
            print(
                f"{job_label} attempt {attempt}/{attempts} failed ({exc.__class__.__name__}); retrying in {sleep_s:.1f}s",
                file=sys.stderr,
            )
            await asyncio.sleep(sleep_s)
    # Defensive: the loop always returns or raises, but keep a fallback.
    raise last_exc or RuntimeError("unknown error")
|
||||
|
||||
|
||||
async def _run_generate_batch(args: argparse.Namespace) -> int:
    """Run every job in the JSONL file concurrently; return 1 if any failed.

    In dry-run mode the per-job request payloads and output paths are printed
    instead of calling the API. Per-job keys override the CLI-level defaults
    (both flat keys and a nested "fields" object are honored).
    """
    jobs = _read_jobs_jsonl(args.input)
    out_dir = Path(args.out_dir)

    # CLI-level defaults; individual jobs may override any of these keys.
    base_fields = _fields_from_args(args)
    base_payload = {
        "model": args.model,
        "n": args.n,
        "size": args.size,
        "quality": args.quality,
        "background": args.background,
        "output_format": args.output_format,
        "output_compression": args.output_compression,
        "moderation": args.moderation,
    }

    if args.dry_run:
        for i, job in enumerate(jobs, start=1):
            prompt = str(job["prompt"]).strip()
            fields = _merge_non_null(base_fields, job.get("fields", {}))
            # Allow flat job keys as well (use_case, scene, etc.)
            fields = _merge_non_null(fields, {k: job.get(k) for k in base_fields.keys()})
            augmented = _augment_prompt_fields(args.augment, prompt, fields)

            job_payload = dict(base_payload)
            job_payload["prompt"] = augmented
            job_payload = _merge_non_null(job_payload, {k: job.get(k) for k in base_payload.keys()})
            job_payload = {k: v for k, v in job_payload.items() if v is not None}

            _validate_generate_payload(job_payload)
            effective_output_format = _normalize_output_format(job_payload.get("output_format"))
            _validate_transparency(job_payload.get("background"), effective_output_format)
            if "output_format" in job_payload:
                job_payload["output_format"] = effective_output_format

            n = int(job_payload.get("n", 1))
            outputs = _job_output_paths(
                out_dir=out_dir,
                output_format=effective_output_format,
                idx=i,
                prompt=prompt,
                n=n,
                explicit_out=job.get("out"),
            )
            downscaled = None
            if args.downscale_max_dim is not None:
                downscaled = [
                    str(_derive_downscale_path(p, args.downscale_suffix)) for p in outputs
                ]
            _print_request(
                {
                    "endpoint": "/v1/images/generations",
                    "job": i,
                    "outputs": [str(p) for p in outputs],
                    "outputs_downscaled": downscaled,
                    **job_payload,
                }
            )
        return 0

    client = _create_async_client()
    # The semaphore caps how many API calls are in flight at once.
    sem = asyncio.Semaphore(args.concurrency)

    any_failed = False

    async def run_job(i: int, job: Dict[str, Any]) -> Tuple[int, Optional[str]]:
        """Build, validate, send, and persist one job; returns (index, error)."""
        nonlocal any_failed
        prompt = str(job["prompt"]).strip()
        job_label = f"[job {i}/{len(jobs)}]"

        fields = _merge_non_null(base_fields, job.get("fields", {}))
        fields = _merge_non_null(fields, {k: job.get(k) for k in base_fields.keys()})
        augmented = _augment_prompt_fields(args.augment, prompt, fields)

        payload = dict(base_payload)
        payload["prompt"] = augmented
        payload = _merge_non_null(payload, {k: job.get(k) for k in base_payload.keys()})
        payload = {k: v for k, v in payload.items() if v is not None}

        n = int(payload.get("n", 1))
        _validate_generate_payload(payload)
        effective_output_format = _normalize_output_format(payload.get("output_format"))
        _validate_transparency(payload.get("background"), effective_output_format)
        if "output_format" in payload:
            payload["output_format"] = effective_output_format
        outputs = _job_output_paths(
            out_dir=out_dir,
            output_format=effective_output_format,
            idx=i,
            prompt=prompt,
            n=n,
            explicit_out=job.get("out"),
        )
        try:
            async with sem:
                print(f"{job_label} starting", file=sys.stderr)
                started = time.time()
                result = await _generate_one_with_retries(
                    client,
                    payload,
                    attempts=args.max_attempts,
                    job_label=job_label,
                )
                elapsed = time.time() - started
                print(f"{job_label} completed in {elapsed:.1f}s", file=sys.stderr)
                images = [item.b64_json for item in result.data]
                _decode_write_and_downscale(
                    images,
                    outputs,
                    force=args.force,
                    downscale_max_dim=args.downscale_max_dim,
                    downscale_suffix=args.downscale_suffix,
                    output_format=effective_output_format,
                )
                return i, None
        except Exception as exc:
            any_failed = True
            print(f"{job_label} failed: {exc}", file=sys.stderr)
            if args.fail_fast:
                raise
            return i, str(exc)

    tasks = [asyncio.create_task(run_job(i, job)) for i, job in enumerate(jobs, start=1)]

    try:
        await asyncio.gather(*tasks)
    except Exception:
        # fail-fast path: stop the remaining jobs before re-raising.
        for t in tasks:
            if not t.done():
                t.cancel()
        raise

    return 1 if any_failed else 0
|
||||
|
||||
|
||||
def _generate_batch(args: argparse.Namespace) -> None:
    """Run the async batch driver and propagate a nonzero exit code."""
    status = asyncio.run(_run_generate_batch(args))
    if status:
        raise SystemExit(status)
|
||||
|
||||
|
||||
def _generate(args: argparse.Namespace) -> None:
    """Handle the `generate` subcommand: build the payload, call the API, write files."""
    prompt = _read_prompt(args.prompt, args.prompt_file)
    prompt = _augment_prompt(args, prompt)

    payload = {
        "model": args.model,
        "prompt": prompt,
        "n": args.n,
        "size": args.size,
        "quality": args.quality,
        "background": args.background,
        "output_format": args.output_format,
        "output_compression": args.output_compression,
        "moderation": args.moderation,
    }
    # Drop unset options so the API request carries only explicit values.
    payload = {k: v for k, v in payload.items() if v is not None}

    output_format = _normalize_output_format(args.output_format)
    _validate_transparency(args.background, output_format)
    if "output_format" in payload:
        # Send the normalized form (e.g. jpg -> jpeg).
        payload["output_format"] = output_format
    output_paths = _build_output_paths(args.out, output_format, args.n, args.out_dir)

    if args.dry_run:
        # Print the request that would be sent, without calling the API.
        _print_request({"endpoint": "/v1/images/generations", **payload})
        return

    print(
        "Calling Image API (generation). This can take up to a couple of minutes.",
        file=sys.stderr,
    )
    started = time.time()
    client = _create_client()
    result = client.images.generate(**payload)
    elapsed = time.time() - started
    print(f"Generation completed in {elapsed:.1f}s.", file=sys.stderr)

    images = [item.b64_json for item in result.data]
    _decode_write_and_downscale(
        images,
        output_paths,
        force=args.force,
        downscale_max_dim=args.downscale_max_dim,
        downscale_suffix=args.downscale_suffix,
        output_format=output_format,
    )
|
||||
|
||||
|
||||
def _edit(args: argparse.Namespace) -> None:
    """Handle the `edit` subcommand: validate inputs, call images.edit, write results."""
    prompt = _read_prompt(args.prompt, args.prompt_file)
    prompt = _augment_prompt(args, prompt)

    image_paths = _check_image_paths(args.image)
    mask_path = Path(args.mask) if args.mask else None
    if mask_path:
        if not mask_path.exists():
            _die(f"Mask file not found: {mask_path}")
        if mask_path.suffix.lower() != ".png":
            _warn(f"Mask should be a PNG with an alpha channel: {mask_path}")
        if mask_path.stat().st_size > MAX_IMAGE_BYTES:
            _warn(f"Mask exceeds 50MB limit: {mask_path}")

    payload = {
        "model": args.model,
        "prompt": prompt,
        "n": args.n,
        "size": args.size,
        "quality": args.quality,
        "background": args.background,
        "output_format": args.output_format,
        "output_compression": args.output_compression,
        "input_fidelity": args.input_fidelity,
        "moderation": args.moderation,
    }
    # Drop unset options so the API request carries only explicit values.
    payload = {k: v for k, v in payload.items() if v is not None}

    output_format = _normalize_output_format(args.output_format)
    _validate_transparency(args.background, output_format)
    if "output_format" in payload:
        payload["output_format"] = output_format
    output_paths = _build_output_paths(args.out, output_format, args.n, args.out_dir)

    if args.dry_run:
        # Preview shows file names rather than open handles.
        payload_preview = dict(payload)
        payload_preview["image"] = [str(p) for p in image_paths]
        if mask_path:
            payload_preview["mask"] = str(mask_path)
        _print_request({"endpoint": "/v1/images/edits", **payload_preview})
        return

    print(
        f"Calling Image API (edit) with {len(image_paths)} image(s).",
        file=sys.stderr,
    )
    started = time.time()
    client = _create_client()

    with _open_files(image_paths) as image_files, _open_mask(mask_path) as mask_file:
        request = dict(payload)
        # A single image is passed as one handle, multiple as a list.
        request["image"] = image_files if len(image_files) > 1 else image_files[0]
        if mask_file is not None:
            request["mask"] = mask_file
        result = client.images.edit(**request)

    elapsed = time.time() - started
    print(f"Edit completed in {elapsed:.1f}s.", file=sys.stderr)
    images = [item.b64_json for item in result.data]
    _decode_write_and_downscale(
        images,
        output_paths,
        force=args.force,
        downscale_max_dim=args.downscale_max_dim,
        downscale_suffix=args.downscale_suffix,
        output_format=output_format,
    )
|
||||
|
||||
|
||||
def _open_files(paths: List[Path]):
    """Return a context manager yielding opened binary handles for *paths*."""
    return _FileBundle(paths)
|
||||
|
||||
|
||||
def _open_mask(mask_path: Optional[Path]):
    """Return a context manager for the optional mask: yields a handle or None."""
    return _SingleFile(mask_path) if mask_path is not None else _NullContext()
|
||||
|
||||
|
||||
class _NullContext:
    """No-op context manager that yields None (used when no mask is given)."""

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc, tb):
        # Returning False propagates any exception raised in the body.
        return False
|
||||
|
||||
|
||||
class _SingleFile:
|
||||
def __init__(self, path: Path):
|
||||
self._path = path
|
||||
self._handle = None
|
||||
|
||||
def __enter__(self):
|
||||
self._handle = self._path.open("rb")
|
||||
return self._handle
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
if self._handle:
|
||||
try:
|
||||
self._handle.close()
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
class _FileBundle:
|
||||
def __init__(self, paths: List[Path]):
|
||||
self._paths = paths
|
||||
self._handles: List[object] = []
|
||||
|
||||
def __enter__(self):
|
||||
self._handles = [p.open("rb") for p in self._paths]
|
||||
return self._handles
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
for handle in self._handles:
|
||||
try:
|
||||
handle.close()
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def _add_shared_args(parser: argparse.ArgumentParser) -> None:
    """Register the options shared by the generate, generate-batch, and edit subcommands."""
    parser.add_argument("--model", default=DEFAULT_MODEL)
    parser.add_argument("--prompt")
    parser.add_argument("--prompt-file")
    parser.add_argument("--n", type=int, default=1)
    parser.add_argument("--size", default=DEFAULT_SIZE)
    parser.add_argument("--quality", default=DEFAULT_QUALITY)
    parser.add_argument("--background")
    parser.add_argument("--output-format")
    parser.add_argument("--output-compression", type=int)
    parser.add_argument("--moderation")
    parser.add_argument("--out", default="output.png")
    parser.add_argument("--out-dir")
    parser.add_argument("--force", action="store_true")
    parser.add_argument("--dry-run", action="store_true")
    # --augment / --no-augment toggle one flag; augmentation is on by default.
    parser.add_argument("--augment", dest="augment", action="store_true")
    parser.add_argument("--no-augment", dest="augment", action="store_false")
    parser.set_defaults(augment=True)

    # Prompt augmentation hints (all plain optional strings).
    for hint_flag in (
        "--use-case",
        "--scene",
        "--subject",
        "--style",
        "--composition",
        "--lighting",
        "--palette",
        "--materials",
        "--text",
        "--constraints",
        "--negative",
    ):
        parser.add_argument(hint_flag)

    # Optional post-processing: also write a downscaled copy for fast web loading.
    parser.add_argument("--downscale-max-dim", type=int)
    parser.add_argument("--downscale-suffix", default=DEFAULT_DOWNSCALE_SUFFIX)
|
||||
|
||||
|
||||
def main() -> int:
    """CLI entry point: build the parser, validate shared flags, dispatch a subcommand."""
    parser = argparse.ArgumentParser(description="Generate or edit images via the Image API")
    subparsers = parser.add_subparsers(dest="command", required=True)

    gen_parser = subparsers.add_parser("generate", help="Create a new image")
    _add_shared_args(gen_parser)
    gen_parser.set_defaults(func=_generate)

    batch_parser = subparsers.add_parser(
        "generate-batch",
        help="Generate multiple prompts concurrently (JSONL input)",
    )
    _add_shared_args(batch_parser)
    batch_parser.add_argument("--input", required=True, help="Path to JSONL file (one job per line)")
    batch_parser.add_argument("--concurrency", type=int, default=DEFAULT_CONCURRENCY)
    batch_parser.add_argument("--max-attempts", type=int, default=3)
    batch_parser.add_argument("--fail-fast", action="store_true")
    batch_parser.set_defaults(func=_generate_batch)

    edit_parser = subparsers.add_parser("edit", help="Edit an existing image")
    _add_shared_args(edit_parser)
    edit_parser.add_argument("--image", action="append", required=True)
    edit_parser.add_argument("--mask")
    edit_parser.add_argument("--input-fidelity")
    edit_parser.set_defaults(func=_edit)

    args = parser.parse_args()
    # Numeric bounds; getattr defaults cover subcommands that lack a flag.
    if args.n < 1 or args.n > 10:
        _die("--n must be between 1 and 10")
    if getattr(args, "concurrency", 1) < 1 or getattr(args, "concurrency", 1) > 25:
        _die("--concurrency must be between 1 and 25")
    if getattr(args, "max_attempts", 3) < 1 or getattr(args, "max_attempts", 3) > 10:
        _die("--max-attempts must be between 1 and 10")
    if args.output_compression is not None and not (0 <= args.output_compression <= 100):
        _die("--output-compression must be between 0 and 100")
    if args.command == "generate-batch" and not args.out_dir:
        _die("generate-batch requires --out-dir")
    if getattr(args, "downscale_max_dim", None) is not None and args.downscale_max_dim < 1:
        _die("--downscale-max-dim must be >= 1")

    _validate_size(args.size)
    _validate_quality(args.quality)
    _validate_background(args.background)
    # Fails fast unless this is a dry run (which only warns).
    _ensure_api_key(args.dry_run)

    args.func(args)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: exit with main()'s status code.
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user