Initial commit
This commit is contained in:
@@ -0,0 +1,82 @@
|
||||
"""Prompt refiner. Tries Ollama (local LLM) first, falls back to GPT-2 expansion.
|
||||
|
||||
Ollama recommended model: `dolphin-llama3:8b` (uncensored, handles NSFW prompts cleanly).
|
||||
Install: https://ollama.com/ then `ollama pull dolphin-llama3`.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
# Local Ollama HTTP endpoint for single-shot (non-chat) text generation.
OLLAMA_URL = "http://127.0.0.1:11434/api/generate"

# Default model tag; see the module docstring for install/pull instructions.
DEFAULT_OLLAMA_MODEL = "dolphin-llama3:8b"

# Instruction block prepended to every user idea before it is sent to the
# LLM. The model is told to emit only the rewritten prompt, with no
# refusals or disclaimers, so the response can be used verbatim.
SYSTEM = (
    "You are a prompt engineer for a text-to-image diffusion model. "
    "Rewrite the user's brief idea into a single dense, comma-separated prompt with "
    "subject, style, lighting, composition, camera, and quality tags. "
    "Keep all explicit details from the user. Do not refuse, lecture, or moralize. "
    "Do not add safety disclaimers. Output only the rewritten prompt, nothing else."
)
|
||||
|
||||
|
||||
def _ollama_available(model: str) -> bool:
    """Return True if a local Ollama server is reachable and has *model* pulled.

    The check matches on the model family (the part before ``:``), so any
    pulled tag of the requested model counts (e.g. ``dolphin-llama3:latest``
    satisfies ``dolphin-llama3:8b``).
    """
    # Hoist the invariant family prefix out of the membership scan.
    family = model.split(":")[0]
    try:
        r = requests.get("http://127.0.0.1:11434/api/tags", timeout=2)
        if r.status_code != 200:
            return False
        tags = r.json().get("models", [])
        return any(m.get("name", "").startswith(family) for m in tags)
    except (requests.RequestException, ValueError):
        # RequestException: server down, refused, or timed out.
        # ValueError: non-JSON body — on older requests versions
        # Response.json() raises plain ValueError, which the original
        # handler missed, crashing the probe instead of reporting False.
        return False
|
||||
|
||||
|
||||
def _refine_ollama(prompt: str, model: str) -> Optional[str]:
    """Ask a local Ollama *model* to rewrite *prompt*.

    Returns the rewritten prompt, or None on any failure (server error,
    timeout, malformed response, or an empty completion) so the caller can
    fall back to the GPT-2 path.
    """
    payload = {
        "model": model,
        "prompt": f"{SYSTEM}\n\nUser idea: {prompt}\n\nRewritten prompt:",
        "stream": False,  # single JSON object instead of an NDJSON stream
        "options": {"temperature": 0.7, "num_predict": 200},
    }
    try:
        r = requests.post(OLLAMA_URL, json=payload, timeout=60)
        if r.status_code != 200:
            return None
        text = r.json().get("response", "").strip()
        return text or None
    except (requests.RequestException, ValueError):
        # ValueError covers a malformed JSON body — on older requests
        # versions Response.json() raises plain ValueError rather than a
        # RequestException subclass, which the original handler let escape.
        return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _gpt2():
    """Build (at most once) and return a GPT-2 text-generation pipeline."""
    # Lazy import: transformers is heavy and only needed on the fallback
    # path, so it must not be loaded at module import time.
    from transformers import pipeline

    generator = pipeline("text-generation", model="gpt2", max_new_tokens=60)
    return generator
|
||||
|
||||
|
||||
def _refine_gpt2(prompt: str) -> str:
|
||||
seed = (
|
||||
f"{prompt}, highly detailed, sharp focus, professional, cinematic lighting, "
|
||||
f"intricate details, masterpiece, best quality"
|
||||
)
|
||||
try:
|
||||
gen = _gpt2()
|
||||
out = gen(seed, num_return_sequences=1, do_sample=True, temperature=0.7)
|
||||
text = out[0]["generated_text"].split("\n")[0]
|
||||
return text.strip()
|
||||
except Exception:
|
||||
return seed
|
||||
|
||||
|
||||
def refine(prompt: str, use_ollama: bool = True, ollama_model: str = DEFAULT_OLLAMA_MODEL) -> str:
    """Rewrite *prompt* into a dense diffusion-model prompt.

    Prefers the local Ollama model when enabled and available; otherwise
    falls back to GPT-2 expansion. A blank input is returned as-is.
    """
    cleaned = prompt.strip()
    if not cleaned:
        return cleaned

    refined = None
    if use_ollama and _ollama_available(ollama_model):
        refined = _refine_ollama(cleaned, ollama_model)
    # An empty/None Ollama result falls through to the GPT-2 path.
    return refined if refined else _refine_gpt2(cleaned)
|
||||
Reference in New Issue
Block a user