173 changes: 153 additions & 20 deletions src/guardrails/checks/text/urls.py
@@ -27,14 +27,21 @@
from typing import Any
from urllib.parse import ParseResult, urlparse

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

from guardrails.registry import default_spec_registry
from guardrails.spec import GuardrailSpecMetadata
from guardrails.types import GuardrailResult

__all__ = ["urls"]

DEFAULT_PORTS = {
"http": 80,
"https": 443,

Copilot AI Nov 20, 2025

[nitpick] Limited default port coverage: The DEFAULT_PORTS dictionary only includes mappings for HTTP (80) and HTTPS (443), but the code supports additional schemes like FTP, data, javascript, vbscript, and mailto (as seen in the detection patterns and special scheme handling). FTP typically uses port 21 by default. Consider adding FTP's default port or documenting that only HTTP/HTTPS have default port handling.

Suggested change
"https": 443,
"https": 443,
"ftp": 21,

Collaborator Author

[nit]
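
For context on what this table affects: _is_url_allowed fills in an implicit port via DEFAULT_PORTS.get(scheme) before comparing against a port-pinned allow-list entry, so only schemes listed here get that normalization. A rough sketch of the consequence (assumes direct calls to the private helper and the import path below; illustration only, not part of this diff):

from urllib.parse import urlparse

from guardrails.checks.text.urls import _is_url_allowed  # assumed import path for the private helper

# https has a DEFAULT_PORTS entry, so a URL with an implicit port still matches a port-pinned entry.
_is_url_allowed(urlparse("https://example.com/"), ["example.com:443"], False)   # True

# ftp has no DEFAULT_PORTS entry, so the implicit port stays None and the pinned entry never matches.
_is_url_allowed(urlparse("ftp://example.com/file"), ["example.com:21"], False)  # False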

}

SCHEME_PREFIX_RE = re.compile(r"^[a-z][a-z0-9+.-]*://")


@dataclass(frozen=True, slots=True)
class UrlDetectionResult:
@@ -66,9 +73,54 @@ class URLConfig(BaseModel):
description="Allow subdomains of allowed domains (e.g. api.example.com if example.com is allowed)",
)

    @field_validator("allowed_schemes", mode="before")
    @classmethod
    def normalize_allowed_schemes(cls, value: Any) -> set[str]:
        """Normalize allowed schemes to bare identifiers without delimiters."""
        if value is None:
            return {"https"}

        if isinstance(value, str):
            raw_values = [value]
        else:
            raw_values = list(value)

        normalized: set[str] = set()
        for entry in raw_values:
            if not isinstance(entry, str):
                raise TypeError("allowed_schemes entries must be strings")
            cleaned = entry.strip().lower()
            if not cleaned:
                continue
            # Support inputs like "https://", "HTTPS:", or " https "
            if cleaned.endswith("://"):
                cleaned = cleaned[:-3]
            cleaned = cleaned.removesuffix(":")
            cleaned = cleaned.strip()
            if cleaned:
                normalized.add(cleaned)

        if not normalized:
            raise ValueError("allowed_schemes must include at least one scheme")

        return normalized
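
A quick usage sketch of the validator above (assumes URLConfig is importable from guardrails.checks.text.urls and that its other fields all have defaults; illustration only, not part of this diff):

from guardrails.checks.text.urls import URLConfig  # assumed import path

config = URLConfig(allowed_schemes=["HTTPS://", " http: ", "ftp"])
print(config.allowed_schemes)  # {"https", "http", "ftp"} (a set; order not guaranteed)

# Entries that normalize to nothing are rejected at validation time.
URLConfig(allowed_schemes=["   "])  # pydantic validation error: must include at least one scheme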


def _detect_urls(text: str) -> list[str]:
"""Detect URLs using regex."""
"""Detect URLs using regex patterns with deduplication.

Detects URLs with explicit schemes (http, https, ftp, data, javascript,
vbscript), domain-like patterns without schemes, and IP addresses.
Deduplicates to avoid returning both scheme-ful and scheme-less versions
of the same URL.

Args:
text: The text to scan for URLs.

Returns:
List of unique URL strings found in the text, with trailing
punctuation removed.
"""
# Pattern for cleaning trailing punctuation (] must be escaped)
PUNCTUATION_CLEANUP = r"[.,;:!?)\]]+$"

@@ -156,7 +208,20 @@ def _detect_urls(text: str) -> list[str]:


def _validate_url_security(url_string: str, config: URLConfig) -> tuple[ParseResult | None, str]:
"""Validate URL using stdlib urllib.parse."""
"""Validate URL security properties using urllib.parse.

Checks URL structure, validates the scheme is allowed, and ensures no
credentials are embedded in userinfo if block_userinfo is enabled.

Args:
url_string: The URL string to validate.
config: Configuration specifying allowed schemes and userinfo policy.

Returns:
A tuple of (parsed_url, error_reason). If validation succeeds,
parsed_url is a ParseResult and error_reason is empty. If validation
fails, parsed_url is None and error_reason describes the failure.
"""
try:
# Parse URL - preserve original scheme for validation
if "://" in url_string:
@@ -185,7 +250,7 @@ def _validate_url_security(url_string: str, config: URLConfig) -> tuple[ParseRes
        if original_scheme not in config.allowed_schemes:
            return None, f"Blocked scheme: {original_scheme}"

        if config.block_userinfo and parsed_url.username:
        if config.block_userinfo and (parsed_url.username or parsed_url.password):
            return None, "Contains userinfo (potential credential injection)"

        # Everything else (IPs, localhost, private IPs) goes through allow list logic
@@ -203,7 +268,20 @@ def _validate_url_security(url_string: str, config: URLConfig) -> tuple[ParseRes


def _is_url_allowed(parsed_url: ParseResult, allow_list: list[str], allow_subdomains: bool) -> bool:
"""Check if URL is allowed."""
"""Check if parsed URL matches any entry in the allow list.

Supports domain names, IP addresses, CIDR blocks, and full URLs with
paths/ports/query strings. Allow list entries without explicit schemes
match any scheme. Entries with schemes must match exactly.

Args:
parsed_url: The parsed URL to check.
allow_list: List of allowed URL patterns (domains, IPs, CIDR, full URLs).
allow_subdomains: If True, subdomains of allowed domains are permitted.

Returns:
True if the URL matches any allow list entry, False otherwise.
"""
if not allow_list:
return False

@@ -212,30 +290,85 @@ def _is_url_allowed(parsed_url: ParseResult, allow_list: list[str], allow_subdom
        return False

    url_host = url_host.lower()
    url_domain = url_host.replace("www.", "")

Copilot AI Nov 20, 2025

Using replace("www.", "") removes all occurrences of "www." from the hostname, not just the prefix. This could lead to unexpected behavior. For example, "www.www.example.com" would be treated as equivalent to "example.com" when matching allow list entries. Consider using removeprefix("www.") instead to only remove the "www." prefix, or document this behavior explicitly if it's intentional.

Suggested change
    url_domain = url_host.replace("www.", "")
    url_domain = url_host.removeprefix("www.")

Collaborator Author

WAI. Same comment as above
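
For anyone skimming the thread, the behavioural difference under discussion (plain Python string methods, not part of this diff):

host = "www.www.example.com"
host.replace("www.", "")     # "example.com"      - strips every occurrence
host.removeprefix("www.")    # "www.example.com"  - strips only the leading prefix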

    scheme_lower = parsed_url.scheme.lower() if parsed_url.scheme else ""
    url_port = parsed_url.port or DEFAULT_PORTS.get(scheme_lower)
    url_path = parsed_url.path or "/"
    url_query = parsed_url.query
    url_fragment = parsed_url.fragment

    try:
        url_ip = ip_address(url_host)
    except (AddressValueError, ValueError):
        url_ip = None

    for allowed_entry in allow_list:
        allowed_entry = allowed_entry.lower().strip()

        # Handle IP addresses and CIDR blocks
        has_explicit_scheme = bool(SCHEME_PREFIX_RE.match(allowed_entry))
        if has_explicit_scheme:
            parsed_allowed = urlparse(allowed_entry)
        else:
            parsed_allowed = urlparse(f"//{allowed_entry}")
        allowed_host = (parsed_allowed.hostname or "").lower()
        allowed_port = parsed_allowed.port
        allowed_path = parsed_allowed.path
        allowed_query = parsed_allowed.query
        allowed_fragment = parsed_allowed.fragment

        # Handle IP addresses and CIDR blocks (including schemes)
        try:
            ip_address(allowed_entry.split("/")[0])
            if allowed_entry == url_host or ("/" in allowed_entry and ip_address(url_host) in ip_network(allowed_entry, strict=False)):
            allowed_ip = ip_address(allowed_host)
        except (AddressValueError, ValueError):
            allowed_ip = None

        if allowed_ip is not None:
            if url_ip is None:
                continue
            if allowed_port is not None and allowed_port != url_port:
                continue
            if allowed_ip == url_ip:
                return True

            network_spec = allowed_host
            if parsed_allowed.path not in ("", "/"):
                network_spec = f"{network_spec}{parsed_allowed.path}"
            try:
                if network_spec and "/" in network_spec and url_ip in ip_network(network_spec, strict=False):
                    return True
            except (AddressValueError, ValueError):
                # Path segment might not represent a CIDR mask; ignore.
                pass
            continue
        except (AddressValueError, ValueError):
            pass

        # Handle domain matching
        allowed_domain = allowed_entry.replace("www.", "")
        url_domain = url_host.replace("www.", "")
        if not allowed_host:
            continue

        # Exact match always allowed
        if url_domain == allowed_domain:
            return True
        allowed_domain = allowed_host.replace("www.", "")

Copilot AI Nov 20, 2025

Using replace("www.", "") removes all occurrences of "www." from the hostname, not just the prefix. This could lead to unexpected behavior. For example, "www.www.example.com" would be treated as equivalent to "example.com" when matching allow list entries. Consider using removeprefix("www.") instead to only remove the "www." prefix, or document this behavior explicitly if it's intentional.

Suggested change
        allowed_domain = allowed_host.replace("www.", "")
        allowed_domain = allowed_host.removeprefix("www.")

Collaborator Author

I think this is [nit] or actually beneficial. Extra instances of www. would clearly be a typo and do not add anything malicious. Without removing the extra www. we would see more unnecessary mismatches.
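
A rough sketch of the behaviour being defended: both the URL host and the allow-list entry drop "www.", so these still line up (assumes direct calls to the private helper and the import path below; illustration only, not part of this diff):

from urllib.parse import urlparse

from guardrails.checks.text.urls import _is_url_allowed  # assumed import path for the private helper

_is_url_allowed(urlparse("https://example.com/docs"), ["www.example.com"], False)    # True
_is_url_allowed(urlparse("https://www.example.com/docs"), ["example.com"], False)    # True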


        if allowed_port is not None and allowed_port != url_port:
            continue

        host_matches = url_domain == allowed_domain or (
            allow_subdomains and url_domain.endswith(f".{allowed_domain}")
        )
        if not host_matches:
            continue

        # Path matching with segment boundary respect
        if allowed_path not in ("", "/"):
            # Ensure path matching respects segment boundaries to prevent
            # "/api" from matching "/api2" or "/api-v2"
            if url_path != allowed_path and not url_path.startswith(f"{allowed_path}/"):
                continue

        if allowed_query and allowed_query != url_query:
            continue

        if allowed_fragment and allowed_fragment != url_fragment:
            continue

        # Subdomain matching if enabled
        if allow_subdomains and url_domain.endswith(f".{allowed_domain}"):
            return True
        return True

    return False

@@ -282,7 +415,7 @@ async def urls(ctx: Any, data: str, config: URLConfig) -> GuardrailResult:
    return GuardrailResult(
        tripwire_triggered=bool(blocked),
        info={
            "guardrail_name": "URL Filter (Direct Config)",
            "guardrail_name": "URL Filter",
            "config": {
                "allowed_schemes": list(config.allowed_schemes),
                "block_userinfo": config.block_userinfo,