"""Video generation. LTX-Video / CogVideoX / Wan via diffusers pipelines."""
|
|
from __future__ import annotations
|
|
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from .device import get_device, hardware_info, torch_dtype
|
|
from .memory import apply_memory_strategy
|
|
from .models import ModelSpec, find
|
|
|
|
# Project root: this file lives in <ROOT>/<package>/, so two parents up.
ROOT = Path(__file__).parent.parent
# Local cache directory for downloaded model weights.
MODELS_DIR = ROOT / "models"
# Generated videos are written here.
OUTPUTS = ROOT / "outputs"
# parents=True makes this robust if an intermediate directory is missing.
OUTPUTS.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
def _import_pipeline_for(repo: str):
|
|
"""Pick the right pipeline class. Wan needs diffusers>=0.32 with WanPipeline."""
|
|
if "LTX" in repo:
|
|
from diffusers import LTXPipeline
|
|
return LTXPipeline
|
|
if "CogVideoX" in repo:
|
|
from diffusers import CogVideoXPipeline
|
|
return CogVideoXPipeline
|
|
if "Wan" in repo:
|
|
try:
|
|
from diffusers import WanPipeline
|
|
return WanPipeline
|
|
except ImportError as e:
|
|
raise RuntimeError(
|
|
"Wan 2.1 needs diffusers>=0.32. Run: pip install -U diffusers"
|
|
) from e
|
|
if "Hunyuan" in repo:
|
|
from diffusers import HunyuanVideoPipeline
|
|
return HunyuanVideoPipeline
|
|
from diffusers import DiffusionPipeline
|
|
return DiffusionPipeline
|
|
|
|
|
|
@lru_cache(maxsize=1)
def _load_pipeline(repo: str):
    """Instantiate the pipeline for *repo*, memoizing only the latest one.

    ``maxsize=1`` means switching to a different model evicts (and frees)
    the previously loaded pipeline instead of accumulating them.
    """
    pipeline_cls = _import_pipeline_for(repo)
    pipe = pipeline_cls.from_pretrained(
        repo,
        torch_dtype=torch_dtype(),
        cache_dir=str(MODELS_DIR / "diffusers"),
    )
    apply_memory_strategy(pipe)
    return pipe
|
|
|
|
|
|
def _model_kwargs(repo: str, base: dict) -> dict:
|
|
"""Adjust kwargs per-pipeline. Some pipes don't accept width/height or have different param names."""
|
|
out = dict(base)
|
|
if "Wan" in repo:
|
|
# Wan accepts height/width but expects multiples of 16.
|
|
out["height"] = (out["height"] // 16) * 16
|
|
out["width"] = (out["width"] // 16) * 16
|
|
if "LTX" in repo:
|
|
# LTX needs 32-multiple resolutions.
|
|
out["height"] = (out["height"] // 32) * 32
|
|
out["width"] = (out["width"] // 32) * 32
|
|
if "CogVideoX" in repo:
|
|
# CogVideoX has fixed 720x480 default; clamp.
|
|
out["height"] = min(out["height"], 480)
|
|
out["width"] = min(out["width"], 720)
|
|
return out
|
|
|
|
|
|
def generate(
    prompt: str,
    negative_prompt: str = "",
    model_id: str = "ltx-video",
    width: int = 704,
    height: int = 480,
    num_frames: int = 73,
    fps: int = 24,
    steps: int = 30,
    guidance: float = 3.0,
    seed: Optional[int] = None,
) -> str:
    """Render a video clip from *prompt* and return the output file path.

    Looks up *model_id* in the model registry, loads (or reuses) the
    matching diffusers pipeline, runs it with per-family-adjusted kwargs,
    and writes an mp4 into ``OUTPUTS``.

    Raises:
        ValueError: if *model_id* is not a registered video model.
    """
    spec: ModelSpec | None = find(model_id)
    if spec is None or spec.kind != "video":
        raise ValueError(f"Unknown video model: {model_id}")

    pipe = _load_pipeline(spec.repo)

    import torch
    from diffusers.utils import export_to_video

    # Seeded generator for reproducibility; some backends reject a
    # device-bound generator, so fall back to the CPU default one.
    generator = None
    if seed is not None:
        try:
            generator = torch.Generator(device=get_device()).manual_seed(int(seed))
        except RuntimeError:
            generator = torch.Generator().manual_seed(int(seed))

    request = {
        "prompt": prompt,
        "negative_prompt": negative_prompt or None,
        "width": int(width),
        "height": int(height),
        "num_frames": int(num_frames),
        "num_inference_steps": int(steps),
        "guidance_scale": float(guidance),
        "generator": generator,
    }
    # Per-pipeline adjustments, then drop None values the pipes dislike.
    call_kwargs = {
        key: value
        for key, value in _model_kwargs(spec.repo, request).items()
        if value is not None
    }

    frames = pipe(**call_kwargs).frames[0]

    import time

    destination = OUTPUTS / f"video_{int(time.time())}.mp4"
    export_to_video(frames, str(destination), fps=int(fps))
    return str(destination)
|