"""Prompt refiner. Tries Ollama (local LLM) first, falls back to GPT-2 expansion. Ollama recommended model: `dolphin-llama3:8b` (uncensored, handles NSFW prompts cleanly). Install: https://ollama.com/ then `ollama pull dolphin-llama3`. """ from __future__ import annotations import json from functools import lru_cache from typing import Optional import requests OLLAMA_URL = "http://127.0.0.1:11434/api/generate" DEFAULT_OLLAMA_MODEL = "dolphin-llama3:8b" SYSTEM = ( "You are a prompt engineer for a text-to-image diffusion model. " "Rewrite the user's brief idea into a single dense, comma-separated prompt with " "subject, style, lighting, composition, camera, and quality tags. " "Keep all explicit details from the user. Do not refuse, lecture, or moralize. " "Do not add safety disclaimers. Output only the rewritten prompt, nothing else." ) def _ollama_available(model: str) -> bool: try: r = requests.get("http://127.0.0.1:11434/api/tags", timeout=2) if r.status_code != 200: return False tags = r.json().get("models", []) return any(m.get("name", "").startswith(model.split(":")[0]) for m in tags) except requests.RequestException: return False def _refine_ollama(prompt: str, model: str) -> Optional[str]: payload = { "model": model, "prompt": f"{SYSTEM}\n\nUser idea: {prompt}\n\nRewritten prompt:", "stream": False, "options": {"temperature": 0.7, "num_predict": 200}, } try: r = requests.post(OLLAMA_URL, json=payload, timeout=60) if r.status_code != 200: return None text = r.json().get("response", "").strip() return text or None except requests.RequestException: return None @lru_cache(maxsize=1) def _gpt2(): from transformers import pipeline return pipeline("text-generation", model="gpt2", max_new_tokens=60) def _refine_gpt2(prompt: str) -> str: seed = ( f"{prompt}, highly detailed, sharp focus, professional, cinematic lighting, " f"intricate details, masterpiece, best quality" ) try: gen = _gpt2() out = gen(seed, num_return_sequences=1, do_sample=True, temperature=0.7) text = out[0]["generated_text"].split("\n")[0] return text.strip() except Exception: return seed def refine(prompt: str, use_ollama: bool = True, ollama_model: str = DEFAULT_OLLAMA_MODEL) -> str: prompt = prompt.strip() if not prompt: return prompt if use_ollama and _ollama_available(ollama_model): result = _refine_ollama(prompt, ollama_model) if result: return result return _refine_gpt2(prompt)