Files
KawAI/backends/video_ltx.py
T
2026-05-04 09:47:58 +02:00

119 lines
3.6 KiB
Python

"""Video generation. LTX-Video / CogVideoX / Wan via diffusers pipelines."""
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
from typing import Optional
from .device import get_device, hardware_info, torch_dtype
from .memory import apply_memory_strategy
from .models import ModelSpec, find
ROOT = Path(__file__).parent.parent
MODELS_DIR = ROOT / "models"
OUTPUTS = ROOT / "outputs"
OUTPUTS.mkdir(exist_ok=True)
def _import_pipeline_for(repo: str):
"""Pick the right pipeline class. Wan needs diffusers>=0.32 with WanPipeline."""
if "LTX" in repo:
from diffusers import LTXPipeline
return LTXPipeline
if "CogVideoX" in repo:
from diffusers import CogVideoXPipeline
return CogVideoXPipeline
if "Wan" in repo:
try:
from diffusers import WanPipeline
return WanPipeline
except ImportError as e:
raise RuntimeError(
"Wan 2.1 needs diffusers>=0.32. Run: pip install -U diffusers"
) from e
if "Hunyuan" in repo:
from diffusers import HunyuanVideoPipeline
return HunyuanVideoPipeline
from diffusers import DiffusionPipeline
return DiffusionPipeline
@lru_cache(maxsize=1)
def _load_pipeline(repo: str):
dtype = torch_dtype()
cache_dir = str(MODELS_DIR / "diffusers")
PipelineCls = _import_pipeline_for(repo)
pipe = PipelineCls.from_pretrained(repo, torch_dtype=dtype, cache_dir=cache_dir)
apply_memory_strategy(pipe)
return pipe
def _model_kwargs(repo: str, base: dict) -> dict:
"""Adjust kwargs per-pipeline. Some pipes don't accept width/height or have different param names."""
out = dict(base)
if "Wan" in repo:
# Wan accepts height/width but expects multiples of 16.
out["height"] = (out["height"] // 16) * 16
out["width"] = (out["width"] // 16) * 16
if "LTX" in repo:
# LTX needs 32-multiple resolutions.
out["height"] = (out["height"] // 32) * 32
out["width"] = (out["width"] // 32) * 32
if "CogVideoX" in repo:
# CogVideoX has fixed 720x480 default; clamp.
out["height"] = min(out["height"], 480)
out["width"] = min(out["width"], 720)
return out
def generate(
prompt: str,
negative_prompt: str = "",
model_id: str = "ltx-video",
width: int = 704,
height: int = 480,
num_frames: int = 73,
fps: int = 24,
steps: int = 30,
guidance: float = 3.0,
seed: Optional[int] = None,
) -> str:
spec: ModelSpec | None = find(model_id)
if spec is None or spec.kind != "video":
raise ValueError(f"Unknown video model: {model_id}")
pipe = _load_pipeline(spec.repo)
import torch
from diffusers.utils import export_to_video
generator = None
if seed is not None:
try:
generator = torch.Generator(device=get_device()).manual_seed(int(seed))
except RuntimeError:
generator = torch.Generator().manual_seed(int(seed))
base = dict(
prompt=prompt,
negative_prompt=negative_prompt or None,
width=int(width),
height=int(height),
num_frames=int(num_frames),
num_inference_steps=int(steps),
guidance_scale=float(guidance),
generator=generator,
)
kwargs = _model_kwargs(spec.repo, base)
kwargs = {k: v for k, v in kwargs.items() if v is not None}
out = pipe(**kwargs)
frames = out.frames[0]
import time
path = OUTPUTS / f"video_{int(time.time())}.mp4"
export_to_video(frames, str(path), fps=int(fps))
return str(path)