Initial commit
This commit is contained in:
@@ -0,0 +1,118 @@
|
||||
"""Video generation. LTX-Video / CogVideoX / Wan via diffusers pipelines."""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .device import get_device, hardware_info, torch_dtype
|
||||
from .memory import apply_memory_strategy
|
||||
from .models import ModelSpec, find
|
||||
|
||||
# Repository root: one directory above the package containing this module.
ROOT = Path(__file__).parent.parent
# Where downloaded model weights are cached (passed as cache_dir when loading).
MODELS_DIR = ROOT / "models"
# Rendered videos are written here.
OUTPUTS = ROOT / "outputs"
# Created eagerly at import time so generate() can write without checking.
OUTPUTS.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
def _import_pipeline_for(repo: str):
    """Resolve the diffusers pipeline class that matches *repo*.

    Dispatch is by substring of the repo id. Wan is special-cased because
    WanPipeline only exists in diffusers>=0.32; an unrecognized repo falls
    back to the generic DiffusionPipeline. Imports are deferred so the
    (heavy) diffusers module is only loaded when a pipeline is needed.
    """
    if "LTX" in repo:
        from diffusers import LTXPipeline

        return LTXPipeline

    if "CogVideoX" in repo:
        from diffusers import CogVideoXPipeline

        return CogVideoXPipeline

    if "Wan" in repo:
        try:
            from diffusers import WanPipeline
        except ImportError as exc:
            # Older diffusers has no WanPipeline; surface an actionable hint.
            raise RuntimeError(
                "Wan 2.1 needs diffusers>=0.32. Run: pip install -U diffusers"
            ) from exc
        return WanPipeline

    if "Hunyuan" in repo:
        from diffusers import HunyuanVideoPipeline

        return HunyuanVideoPipeline

    # Unknown family: let diffusers auto-detect from the repo's model_index.
    from diffusers import DiffusionPipeline

    return DiffusionPipeline
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _load_pipeline(repo: str):
    """Load and memoize the pipeline for *repo*.

    maxsize=1 keeps only the most recent pipeline alive, so switching
    models releases the previous one's weights. Weights are cached under
    MODELS_DIR, and the configured memory strategy (offload, slicing, ...)
    is applied before the pipeline is returned.
    """
    pipeline_cls = _import_pipeline_for(repo)
    pipeline = pipeline_cls.from_pretrained(
        repo,
        torch_dtype=torch_dtype(),
        cache_dir=str(MODELS_DIR / "diffusers"),
    )
    apply_memory_strategy(pipeline)
    return pipeline
|
||||
|
||||
|
||||
def _model_kwargs(repo: str, base: dict) -> dict:
|
||||
"""Adjust kwargs per-pipeline. Some pipes don't accept width/height or have different param names."""
|
||||
out = dict(base)
|
||||
if "Wan" in repo:
|
||||
# Wan accepts height/width but expects multiples of 16.
|
||||
out["height"] = (out["height"] // 16) * 16
|
||||
out["width"] = (out["width"] // 16) * 16
|
||||
if "LTX" in repo:
|
||||
# LTX needs 32-multiple resolutions.
|
||||
out["height"] = (out["height"] // 32) * 32
|
||||
out["width"] = (out["width"] // 32) * 32
|
||||
if "CogVideoX" in repo:
|
||||
# CogVideoX has fixed 720x480 default; clamp.
|
||||
out["height"] = min(out["height"], 480)
|
||||
out["width"] = min(out["width"], 720)
|
||||
return out
|
||||
|
||||
|
||||
def generate(
    prompt: str,
    negative_prompt: str = "",
    model_id: str = "ltx-video",
    width: int = 704,
    height: int = 480,
    num_frames: int = 73,
    fps: int = 24,
    steps: int = 30,
    guidance: float = 3.0,
    seed: Optional[int] = None,
) -> str:
    """Render a video for *prompt* and return the path of the saved .mp4.

    Resolves *model_id* to a registered video model, loads (or reuses) its
    pipeline, runs generation, and exports the frames under OUTPUTS with a
    timestamped filename.

    Raises:
        ValueError: if *model_id* is unknown or not a video model.
    """
    spec: ModelSpec | None = find(model_id)
    if spec is None or spec.kind != "video":
        raise ValueError(f"Unknown video model: {model_id}")

    pipe = _load_pipeline(spec.repo)

    # Deferred imports: torch/diffusers are heavy and only needed here.
    import time

    import torch
    from diffusers.utils import export_to_video

    rng = None
    if seed is not None:
        try:
            rng = torch.Generator(device=get_device()).manual_seed(int(seed))
        except RuntimeError:
            # Some devices reject generators; fall back to a CPU generator.
            rng = torch.Generator().manual_seed(int(seed))

    call_args = _model_kwargs(
        spec.repo,
        dict(
            prompt=prompt,
            negative_prompt=negative_prompt or None,
            width=int(width),
            height=int(height),
            num_frames=int(num_frames),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=rng,
        ),
    )
    # Drop unset options (e.g. no negative prompt, no generator) so the
    # pipeline's own defaults apply.
    call_args = {key: value for key, value in call_args.items() if value is not None}

    frames = pipe(**call_args).frames[0]

    destination = OUTPUTS / f"video_{int(time.time())}.mp4"
    export_to_video(frames, str(destination), fps=int(fps))
    return str(destination)
|
||||
Reference in New Issue
Block a user