KawAI/backends/safety.py

"""Safety filter: block CSAM only. All other content (NSFW, gore, violence) allowed.
Layers:
1. Prompt keyword gate: rejects child-term + sexual-term combinations.
2. Output check: NudeNet detects nudity. If nudity present, run face detection
(MTCNN) + age classifier (ViT) on every face. Block if any face is classified
as a minor with high confidence.
This is best-effort. The user is legally responsible for use of generated content.
"""
from __future__ import annotations

import os
import re
import tempfile
from dataclasses import dataclass
from functools import lru_cache
from typing import Iterable

from PIL import Image

# --- prompt keyword gate ---------------------------------------------------
_CHILD_TERMS = [
    r"\b(child|children|kid|kids|minor|underage|under-?age|preteen|pre-?teen)\b",
    r"\b(toddler|infant|baby|babies)\b",
    # Ages 0-17 (optional leading zero) followed by "yo", "y/o", "year old",
    # "year-old", or "years old".
    r"\b(0?[0-9]|1[0-7])\s*(yo|y/o|years?[- ]?old)\b",
    r"\bloli(con)?\b",
    r"\bshota(con)?\b",
    r"\bcp\b",
]

_SEXUAL_TERMS = [
    r"\b(nude|naked|nsfw|porn|sex|sexual|sexy|erotic|explicit)\b",
    r"\b(penis|vagina|breast|nipple|genital|cum|orgasm)\b",
    r"\b(intercourse|fellatio|cunnilingus|masturbat)\w*\b",
    r"\b(rape|molest)\w*\b",
]
_CHILD_RE = re.compile("|".join(_CHILD_TERMS), re.IGNORECASE)
_SEX_RE = re.compile("|".join(_SEXUAL_TERMS), re.IGNORECASE)

@dataclass
class SafetyResult:
    allowed: bool
    reason: str = ""


def check_prompt(prompt: str) -> SafetyResult:
    if _CHILD_RE.search(prompt) and _SEX_RE.search(prompt):
        return SafetyResult(False, "Prompt blocked: combines minor and sexual terms (CSAM gate).")
    return SafetyResult(True)
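
# Illustrative behaviour of the gate (hypothetical prompts). Note the AND:
# either term class alone passes; only the combination is blocked.
#   check_prompt("a child playing in the park").allowed    -> True
#   check_prompt("nude portrait, studio lighting").allowed -> True
#   check_prompt("nude child").allowed                     -> False
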
# --- nudity detection ------------------------------------------------------
_NUDITY_LABELS = {
    "FEMALE_BREAST_EXPOSED",
    "FEMALE_GENITALIA_EXPOSED",
    "MALE_GENITALIA_EXPOSED",
    "BUTTOCKS_EXPOSED",
    "ANUS_EXPOSED",
}
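
# The heavy dependencies (nudenet, facenet_pytorch, transformers) are imported
# lazily inside lru_cache(maxsize=1) factories so this module still imports
# when they are missing, and each model is constructed at most once per process.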

@lru_cache(maxsize=1)
def _nudenet():
    try:
        from nudenet import NudeDetector

        return NudeDetector()
    except Exception:
        return None
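
# NudeDetector.detect returns a list of dicts; in recent nudenet releases each
# entry looks roughly like {"class": "FEMALE_BREAST_EXPOSED", "score": 0.87,
# "box": [...]} (older builds used "label"), hence the fallback below.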

def _has_nudity(detections: Iterable[dict]) -> bool:
    for d in detections:
        label = d.get("class") or d.get("label") or ""
        score = float(d.get("score", 0.0))
        if label in _NUDITY_LABELS and score >= 0.5:
            return True
    return False

# --- face detection + age classification -----------------------------------
@lru_cache(maxsize=1)
def _mtcnn():
    try:
        from facenet_pytorch import MTCNN
        import torch

        device = "cuda" if torch.cuda.is_available() else "cpu"
        return MTCNN(keep_all=True, device=device, post_process=False, min_face_size=40)
    except Exception:
        return None

@lru_cache(maxsize=1)
def _age_classifier():
    try:
        from transformers import pipeline

        return pipeline(
            "image-classification",
            model="nateraw/vit-age-classifier",
            top_k=3,
        )
    except Exception:
        return None

# Buckets reported by nateraw/vit-age-classifier.
# Conservative minor set. "10-19" includes some adults — treat as minor only on
# strong confidence to limit false positives on young-looking adults.
_HARD_MINOR = {"0-2", "3-9"}
_SOFT_MINOR = {"10-19"}

def _faces(img: Image.Image) -> list[Image.Image]:
    mtcnn = _mtcnn()
    if mtcnn is None:
        return []
    try:
        boxes, probs = mtcnn.detect(img)
    except Exception:
        return []
    if boxes is None:
        return []
    if probs is None:
        probs = [None] * len(boxes)
    out = []
    for box, prob in zip(boxes, probs):
        # Skip low-confidence detections and crops too small for the age model.
        if prob is None or float(prob) < 0.9:
            continue
        x1, y1, x2, y2 = [int(max(0, v)) for v in box]
        if x2 - x1 < 30 or y2 - y1 < 30:
            continue
        out.append(img.crop((x1, y1, x2, y2)))
    return out

def _is_minor_face(face_img: Image.Image) -> tuple[bool, str]:
    clf = _age_classifier()
    if clf is None:
        return False, ""
    try:
        preds = clf(face_img)
    except Exception:
        return False, ""
    # preds is list[dict(label, score)] sorted by score desc.
    top = preds[0] if preds else None
    if not top:
        return False, ""
    label = top["label"]
    score = float(top["score"])
    if label in _HARD_MINOR and score >= 0.55:
        return True, f"minor face detected ({label}, conf={score:.2f})"
    if label in _SOFT_MINOR and score >= 0.85:
        return True, f"likely minor face ({label}, conf={score:.2f})"
    return False, ""
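
# Example classifier output for one crop (illustrative scores, not real data):
#   [{"label": "20-29", "score": 0.61}, {"label": "10-19", "score": 0.22},
#    {"label": "30-39", "score": 0.09}]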

def check_image(img: Image.Image) -> SafetyResult:
    """Block if (nudity present) AND (any face classified as minor)."""
    det = _nudenet()
    if det is None:
        # No nudity detector available; fall through. Prompt gate is primary defense.
        return SafetyResult(True)
    tmp_path = None
    try:
        # NudeDetector takes a file path, so write a temp PNG and remove it after.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            img.save(f, format="PNG")
            tmp_path = f.name
        results = det.detect(tmp_path)
    except Exception:
        return SafetyResult(True)
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
    if not _has_nudity(results):
        return SafetyResult(True)
    for face in _faces(img):
        is_minor, reason = _is_minor_face(face)
        if is_minor:
            return SafetyResult(
                False,
                f"Output blocked: nudity + {reason}. Image discarded.",
            )
    return SafetyResult(True)
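

if __name__ == "__main__":
    # Minimal manual smoke test (hypothetical invocation; the image path is up
    # to you): python -m KawAI.backends.safety some_image.png
    import sys

    print(check_prompt("nude child"))     # expected: blocked by the prompt gate
    print(check_prompt("nude portrait"))  # expected: allowed
    if len(sys.argv) > 1:
        print(check_image(Image.open(sys.argv[1]).convert("RGB")))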