python-threatexchange/threatexchange/classifier/safeguard/gpt_classifier.py
@@ -0,0 +1,119 @@
"""Minimal client for running an OSS Safeguard policy via an OpenAI-compatible API."""

import json
import os
import re
from dataclasses import dataclass
from typing import Any

from openai import BadRequestError, OpenAI
from threatexchange.classifier.classifier import Classifier

DEFAULT_OPENAI_POLICY_MODEL = "osb-120b-ev3"


def _strip_code_fences(text: str) -> str:
text = text.strip()
text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.IGNORECASE)
text = re.sub(r"\s*```$", "", text)
return text.strip()


def _try_parse_json_object(text: str) -> dict[str, Any] | None:
text = _strip_code_fences(text)
try:
val = json.loads(text)
if isinstance(val, dict):
return val
except Exception:
pass

# Fallback: try to extract the first JSON object from a longer string.
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1 or end <= start:
return None
candidate = text[start : end + 1]
try:
val = json.loads(candidate)
if isinstance(val, dict):
return val
except Exception:
return None
return None


def _maybe_raise_helpful_model_error(exc: BadRequestError, *, model: str) -> None:
body = getattr(exc, "body", None)
if not isinstance(body, dict):
return
err = body.get("error")
if not isinstance(err, dict):
return

code = err.get("code")
message = err.get("message")
if code != "model_not_found":
return

msg = str(message) if message else f"Model not found: {model!r}"
raise RuntimeError(
f"{msg}\n\n"
f"This repo assumes the hackathon-provided API model {DEFAULT_OPENAI_POLICY_MODEL!r}.\n"
"If you still see this error, confirm you have access to that model in your OpenAI project/org."
) from exc


@dataclass(frozen=True)
class GPTClassifier(Classifier):
client: OpenAI
model: str

@classmethod
def from_env(cls) -> "GPTClassifier":
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("Missing OPENAI_API_KEY")

model = DEFAULT_OPENAI_POLICY_MODEL

client = OpenAI(
api_key=api_key,
organization=os.getenv("OPENAI_ORG_ID") or None,
project=os.getenv("OPENAI_PROJECT_ID") or None,
)
return cls(client=client, model=model) # type: ignore[arg-type]

def get_content_types(self) -> str:
return "text"

def classify(self, *, content: str, policy: str) -> dict[str, Any]:
"""
Returns:
{
"raw_text": "...",
"parsed": { ... } | null
}
"""
raw_text = self._classify_via_responses(content=content, policy=policy)

return {
"raw_text": raw_text,
"parsed": _try_parse_json_object(raw_text),
}

def _classify_via_responses(self, *, content: str, policy: str) -> str:
kwargs: dict[str, Any] = {
"model": self.model,
# Harmony-style message roles: policy in developer message, content in user message.
"input": [
{"role": "developer", "content": policy},
{"role": "user", "content": content},
],
}
try:
response = self.client.responses.create(**kwargs)
except BadRequestError as exc:
_maybe_raise_helpful_model_error(exc, model=self.model)
raise
return (response.output_text or "").strip()
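A rough usage sketch of the classifier above, for orientation: it assumes `OPENAI_API_KEY` is set in the environment and that the policy file added below is readable at its default relative path; the input string is a placeholder.

```python
# Illustrative usage sketch (assumes OPENAI_API_KEY is set and the policy file exists).
from pathlib import Path

from threatexchange.classifier.safeguard.gpt_classifier import GPTClassifier

policy_text = Path(
    "threatexchange/classifier/safeguard/policy/basic_policy.md"
).read_text(encoding="utf-8")

classifier = GPTClassifier.from_env()
result = classifier.classify(content="example text to moderate", policy=policy_text)

print(result["raw_text"])  # raw model output
print(result["parsed"])    # dict if the output parsed as a JSON object, else None
```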
python-threatexchange/threatexchange/classifier/safeguard/policy/basic_policy.md
@@ -0,0 +1,38 @@
# Basic OSS Safeguard policy (hackathon starter)

Reasoning: medium

You are a content safety classifier. Apply this policy to the user-provided text and output a single JSON object.

## Output schema (JSON only)

Return **only** a JSON object with:

- `label`: one of `"ALLOW"`, `"REVIEW"`, `"BLOCK"`
- `action`: one of `"allow"`, `"route"`, `"block"`
- `category`: one of `"none"`, `"self_harm"`, `"violence"`, `"sexual"`, `"hate_or_harassment"`, `"illicit"`, `"other"`
- `confidence`: number between `0` and `1`
- `rationale`: short, non-sensitive explanation (do not quote the policy)

## Policy rules (high level)

### BLOCK
Use `label="BLOCK"` and `action="block"` if the content includes:
- instructions or intent to do violence or severe wrongdoing
- explicit sexual content involving minors
- credible self-harm intent or encouragement

### REVIEW
Use `label="REVIEW"` and `action="route"` if the content includes:
- ambiguous self-harm ideation without intent
- hateful or harassing language
- requests for wrongdoing (e.g., hacking) without concrete instructions
- sexual content that is suggestive but not explicit

### ALLOW
Use `label="ALLOW"` and `action="allow"` otherwise.

## Reminders

- Be conservative: if uncertain, choose `REVIEW`.
- Return JSON only (no markdown, no backticks).
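For illustration, a response that follows this schema (and therefore the `parsed` value returned by `GPTClassifier.classify`) might look like the following Python dict; the values are hypothetical.

```python
# Hypothetical, schema-conforming classification result (illustrative values only).
example_parsed = {
    "label": "REVIEW",
    "action": "route",
    "category": "hate_or_harassment",
    "confidence": 0.72,
    "rationale": "Harassing language aimed at an individual, with no explicit threat.",
}
```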
python-threatexchange/threatexchange/cli/classify_cmd.py (24 changes: 23 additions & 1 deletion)
@@ -6,7 +6,9 @@
 import argparse
 import os
 import sys
+import json
 import typing as t
+from pathlib import Path

 from threatexchange.cli.cli_config import CLISettings
 from threatexchange.cli.exceptions import CommandError
@@ -15,6 +17,7 @@
     OpenAIModerationClassifier,
     MissingAPIKeyError,
 )
+from threatexchange.classifier.safeguard.gpt_classifier import GPTClassifier


 class ClassifyCommand(command_base.CommandWithSubcommands):
@@ -52,6 +55,13 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> None:
         ap.add_argument(
             "-a", "--show-all", action="store_true", help="show all categories"
         )
+        ap.add_argument("-s", "--safeguard", action="store_true", help="use the OSS Safeguard policy classifier")
+        ap.add_argument(
+            "-p",
+            "--policy",
+            type=Path,
+            default=(Path("threatexchange/classifier/safeguard/policy/basic_policy.md")),
+            help="Path to policy file (default: threatexchange/classifier/safeguard/policy/basic_policy.md)",
+        )
         ap.add_argument(
             "-m",
             "--model",
@@ -65,11 +75,15 @@ def __init__(
         mod_api: bool = False,
         show_all: bool = False,
         model: str = "omni-moderation-latest",
+        safeguard: bool = False,
+        policy: Path = Path("threatexchange/classifier/safeguard/policy/basic_policy.md"),
     ):
         self.input = input
         self.mod_api = mod_api
         self.show_all = show_all
         self.model = model
+        self.safeguard = safeguard
+        self.policy = policy

     def execute(self, settings: CLISettings) -> None:
         # Resolve text input
@@ -84,7 +98,15 @@ def execute(self, settings: CLISettings) -> None:
         # Default to mod-api if no API flag specified
         # (currently only mod-api is supported)
         try:
-            classifier = OpenAIModerationClassifier(model=self.model)
+            if self.safeguard:
+                classifier = GPTClassifier.from_env()
+                gpt_result = classifier.classify(
+                    content=self.input, policy=self.policy.read_text(encoding="utf-8")
+                )
+                print(json.dumps(gpt_result, indent=2, ensure_ascii=False, sort_keys=True))
+                return
+            else:
+                classifier = OpenAIModerationClassifier(model=self.model)
         except MissingAPIKeyError as e:
             raise CommandError.user(str(e)) from e
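With these flags wired up, a safeguard run from the CLI would look roughly like `threatexchange classify --safeguard "some text to check"`, optionally with `-p <path>` to point at a different policy file; the exact command name and the positional form of the text input depend on how the CLI and `ClassifyCommand` resolve them, so treat the invocation as a sketch. When `--safeguard` is set, the command prints the JSON object returned by `GPTClassifier.classify` and returns before reaching the moderation-API path.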
