"""Prompt refiner. Tries Ollama (local LLM) first, falls back to GPT-2 expansion.

Ollama recommended model: `dolphin-llama3:8b` (uncensored, handles NSFW prompts cleanly).
Install: https://ollama.com/ then `ollama pull dolphin-llama3`.
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from functools import lru_cache
|
|
from typing import Optional
|
|
|
|
import requests
|
|
|
|
# Ollama's one-shot (non-chat) generation endpoint on the default local port.
OLLAMA_URL = "http://127.0.0.1:11434/api/generate"

# Model requested when the caller does not pick one; see module docstring.
DEFAULT_OLLAMA_MODEL = "dolphin-llama3:8b"

# Instruction prepended to every Ollama request.  Steers the model to emit a
# single dense comma-separated diffusion prompt and nothing else (no chatter,
# no disclaimers), so the response can be used verbatim as the image prompt.
SYSTEM = (
    "You are a prompt engineer for a text-to-image diffusion model. "
    "Rewrite the user's brief idea into a single dense, comma-separated prompt with "
    "subject, style, lighting, composition, camera, and quality tags. "
    "Keep all explicit details from the user. Do not refuse, lecture, or moralize. "
    "Do not add safety disclaimers. Output only the rewritten prompt, nothing else."
)
|
|
|
|
|
|
def _ollama_available(model: str) -> bool:
    """Return True when a local Ollama server is reachable and *model* is pulled.

    Matches on the model's base name (text before the ``:`` tag), so any
    installed tag of the requested model counts as available.
    """
    base_name = model.split(":")[0]
    try:
        resp = requests.get("http://127.0.0.1:11434/api/tags", timeout=2)
        if resp.status_code != 200:
            return False
        installed = resp.json().get("models", [])
        return any(entry.get("name", "").startswith(base_name) for entry in installed)
    except requests.RequestException:
        # Server down, unreachable, or returned garbage: treat as unavailable.
        return False
|
|
|
|
|
|
def _refine_ollama(prompt: str, model: str) -> Optional[str]:
    """Ask the local Ollama server to rewrite *prompt*.

    Returns the rewritten prompt, or None on any HTTP/network failure or an
    empty model response (caller then falls back to GPT-2).
    """
    request_body = {
        "model": model,
        "prompt": f"{SYSTEM}\n\nUser idea: {prompt}\n\nRewritten prompt:",
        "stream": False,
        # Modest temperature; cap output so the reply stays a single prompt.
        "options": {"temperature": 0.7, "num_predict": 200},
    }
    try:
        resp = requests.post(OLLAMA_URL, json=request_body, timeout=60)
        if resp.status_code != 200:
            return None
        rewritten = resp.json().get("response", "").strip()
        return rewritten if rewritten else None
    except requests.RequestException:
        return None
|
|
|
|
|
|
@lru_cache(maxsize=1)
def _gpt2():
    """Build (once, lazily) the GPT-2 text-generation pipeline used as fallback."""
    # Deferred import: transformers is heavy and only needed when Ollama is absent.
    from transformers import pipeline

    generator = pipeline("text-generation", model="gpt2", max_new_tokens=60)
    return generator
|
|
|
|
|
|
def _refine_gpt2(prompt: str) -> str:
|
|
seed = (
|
|
f"{prompt}, highly detailed, sharp focus, professional, cinematic lighting, "
|
|
f"intricate details, masterpiece, best quality"
|
|
)
|
|
try:
|
|
gen = _gpt2()
|
|
out = gen(seed, num_return_sequences=1, do_sample=True, temperature=0.7)
|
|
text = out[0]["generated_text"].split("\n")[0]
|
|
return text.strip()
|
|
except Exception:
|
|
return seed
|
|
|
|
|
|
def refine(prompt: str, use_ollama: bool = True, ollama_model: str = DEFAULT_OLLAMA_MODEL) -> str:
    """Rewrite *prompt* into a dense text-to-image prompt.

    Prefers a local Ollama model when one is reachable and *use_ollama* is
    True; otherwise (or when Ollama returns nothing) falls back to GPT-2
    expansion.  Empty/whitespace-only input is returned stripped, untouched.
    """
    cleaned = prompt.strip()
    if not cleaned:
        return cleaned
    if use_ollama and _ollama_available(ollama_model):
        rewritten = _refine_ollama(cleaned, ollama_model)
        if rewritten:
            return rewritten
    return _refine_gpt2(cleaned)
|