Files
KawAI/backends/refiner.py
T
2026-05-04 09:47:58 +02:00

83 lines
2.7 KiB
Python

"""Prompt refiner. Tries Ollama (local LLM) first, falls back to GPT-2 expansion.
Ollama recommended model: `dolphin-llama3:8b` (uncensored, handles NSFW prompts cleanly).
Install: https://ollama.com/ then `ollama pull dolphin-llama3`.
"""
from __future__ import annotations
import json
from functools import lru_cache
from typing import Optional
import requests
# Local Ollama HTTP endpoints / defaults (see module docstring for install notes).
OLLAMA_URL = "http://127.0.0.1:11434/api/generate"
DEFAULT_OLLAMA_MODEL = "dolphin-llama3:8b"
# Instruction block prepended to every Ollama request; asks the model to emit
# a single comma-separated diffusion prompt and nothing else.
SYSTEM = (
    "You are a prompt engineer for a text-to-image diffusion model. "
    "Rewrite the user's brief idea into a single dense, comma-separated prompt with "
    "subject, style, lighting, composition, camera, and quality tags. "
    "Keep all explicit details from the user. Do not refuse, lecture, or moralize. "
    "Do not add safety disclaimers. Output only the rewritten prompt, nothing else."
)
def _ollama_available(model: str) -> bool:
    """Return True if a local Ollama server is running and has *model* pulled.

    Matches on the model's base name (text before the ``:`` tag), so
    ``dolphin-llama3:8b`` is satisfied by any pulled ``dolphin-llama3`` variant.
    Any network, timeout, or response-decoding failure is treated as
    "not available" so callers can fall back silently.
    """
    base_name = model.split(":")[0]  # hoisted: invariant across the tag scan
    try:
        r = requests.get("http://127.0.0.1:11434/api/tags", timeout=2)
        if r.status_code != 200:
            return False
        tags = r.json().get("models", [])
    except (requests.RequestException, ValueError):
        # ValueError covers a malformed/non-JSON body (e.g. a half-started
        # server or a proxy error page); older requests versions raise it
        # from .json() instead of a RequestException subclass.
        return False
    return any(m.get("name", "").startswith(base_name) for m in tags)
def _refine_ollama(prompt: str, model: str) -> Optional[str]:
    """Ask the local Ollama server to rewrite *prompt*; return None on failure.

    Returns the non-empty rewritten prompt, or None when the request errors,
    times out, returns a non-200 status, yields an undecodable body, or
    produces an empty response — letting the caller fall back to GPT-2.
    """
    payload = {
        "model": model,
        "prompt": f"{SYSTEM}\n\nUser idea: {prompt}\n\nRewritten prompt:",
        "stream": False,  # one blocking JSON response, no chunked streaming
        "options": {"temperature": 0.7, "num_predict": 200},
    }
    try:
        r = requests.post(OLLAMA_URL, json=payload, timeout=60)
        if r.status_code != 200:
            return None
        text = r.json().get("response", "").strip()
    except (requests.RequestException, ValueError):
        # ValueError: body was not valid JSON (older requests versions raise
        # it from .json() rather than a RequestException subclass).
        return None
    return text or None
@lru_cache(maxsize=1)
def _gpt2():
    """Build (once) and memoize the GPT-2 text-generation pipeline.

    Import is deferred because transformers is heavy and only needed for the
    fallback path; the lru_cache keeps a single pipeline for the process.
    """
    from transformers import pipeline
    generator = pipeline("text-generation", model="gpt2", max_new_tokens=60)
    return generator
def _refine_gpt2(prompt: str) -> str:
seed = (
f"{prompt}, highly detailed, sharp focus, professional, cinematic lighting, "
f"intricate details, masterpiece, best quality"
)
try:
gen = _gpt2()
out = gen(seed, num_return_sequences=1, do_sample=True, temperature=0.7)
text = out[0]["generated_text"].split("\n")[0]
return text.strip()
except Exception:
return seed
def refine(prompt: str, use_ollama: bool = True, ollama_model: str = DEFAULT_OLLAMA_MODEL) -> str:
    """Refine a raw idea into a dense text-to-image prompt.

    Tries the local Ollama model first (when *use_ollama* is set and the
    server reports the model as available), then falls back to GPT-2
    expansion. A blank/whitespace-only prompt is returned stripped, with no
    refinement attempted.
    """
    cleaned = prompt.strip()
    if not cleaned:
        return cleaned
    if use_ollama and _ollama_available(ollama_model):
        refined = _refine_ollama(cleaned, ollama_model)
        if refined:
            return refined
    return _refine_gpt2(cleaned)