"""Video generation. LTX-Video / CogVideoX / Wan via diffusers pipelines.""" from __future__ import annotations from functools import lru_cache from pathlib import Path from typing import Optional from .device import get_device, hardware_info, torch_dtype from .memory import apply_memory_strategy from .models import ModelSpec, find ROOT = Path(__file__).parent.parent MODELS_DIR = ROOT / "models" OUTPUTS = ROOT / "outputs" OUTPUTS.mkdir(exist_ok=True) def _import_pipeline_for(repo: str): """Pick the right pipeline class. Wan needs diffusers>=0.32 with WanPipeline.""" if "LTX" in repo: from diffusers import LTXPipeline return LTXPipeline if "CogVideoX" in repo: from diffusers import CogVideoXPipeline return CogVideoXPipeline if "Wan" in repo: try: from diffusers import WanPipeline return WanPipeline except ImportError as e: raise RuntimeError( "Wan 2.1 needs diffusers>=0.32. Run: pip install -U diffusers" ) from e if "Hunyuan" in repo: from diffusers import HunyuanVideoPipeline return HunyuanVideoPipeline from diffusers import DiffusionPipeline return DiffusionPipeline @lru_cache(maxsize=1) def _load_pipeline(repo: str): dtype = torch_dtype() cache_dir = str(MODELS_DIR / "diffusers") PipelineCls = _import_pipeline_for(repo) pipe = PipelineCls.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir) apply_memory_strategy(pipe) return pipe def _model_kwargs(repo: str, base: dict) -> dict: """Adjust kwargs per-pipeline. Some pipes don't accept width/height or have different param names.""" out = dict(base) if "Wan" in repo: # Wan accepts height/width but expects multiples of 16. out["height"] = (out["height"] // 16) * 16 out["width"] = (out["width"] // 16) * 16 if "LTX" in repo: # LTX needs 32-multiple resolutions. out["height"] = (out["height"] // 32) * 32 out["width"] = (out["width"] // 32) * 32 if "CogVideoX" in repo: # CogVideoX has fixed 720x480 default; clamp. out["height"] = min(out["height"], 480) out["width"] = min(out["width"], 720) return out def generate( prompt: str, negative_prompt: str = "", model_id: str = "ltx-video", width: int = 704, height: int = 480, num_frames: int = 73, fps: int = 24, steps: int = 30, guidance: float = 3.0, seed: Optional[int] = None, ) -> str: spec: ModelSpec | None = find(model_id) if spec is None or spec.kind != "video": raise ValueError(f"Unknown video model: {model_id}") pipe = _load_pipeline(spec.repo) import torch from diffusers.utils import export_to_video generator = None if seed is not None: try: generator = torch.Generator(device=get_device()).manual_seed(int(seed)) except RuntimeError: generator = torch.Generator().manual_seed(int(seed)) base = dict( prompt=prompt, negative_prompt=negative_prompt or None, width=int(width), height=int(height), num_frames=int(num_frames), num_inference_steps=int(steps), guidance_scale=float(guidance), generator=generator, ) kwargs = _model_kwargs(spec.repo, base) kwargs = {k: v for k, v in kwargs.items() if v is not None} out = pipe(**kwargs) frames = out.frames[0] import time path = OUTPUTS / f"video_{int(time.time())}.mp4" export_to_video(frames, str(path), fps=int(fps)) return str(path)