Skip to content

Commit bf067db

Browse files
authored
Enhancing URL detection (#57)
* Enhancing URL detection: if the allow list has a scheme, match it exactly
* Validating ports
* Removing redundant variable assignment
* Handle trailing slash in allow lists
* Allow scheme-less user input
* Fix port matching
* Improve port matching
1 parent 06c1018 commit bf067db

File tree

2 files changed

+604
-43
lines changed

2 files changed

+604
-43
lines changed

src/guardrails/checks/text/urls.py

Lines changed: 222 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,21 @@
2727
from typing import Any
2828
from urllib.parse import ParseResult, urlparse
2929

30-
from pydantic import BaseModel, Field
30+
from pydantic import BaseModel, Field, field_validator
3131

3232
from guardrails.registry import default_spec_registry
3333
from guardrails.spec import GuardrailSpecMetadata
3434
from guardrails.types import GuardrailResult
3535

3636
__all__ = ["urls"]
3737

38+
# Default port implied by each scheme; used as the fallback when a parsed URL
# carries no explicit port.
DEFAULT_PORTS = {
    "http": 80,
    "https": 443,
}

# Matches a leading "scheme://" prefix. The character classes are lowercase
# only, so the pattern is meant for strings that were already lower-cased.
SCHEME_PREFIX_RE = re.compile(r"^[a-z][a-z0-9+.-]*://")
44+
3845

3946
@dataclass(frozen=True, slots=True)
4047
class UrlDetectionResult:
@@ -66,9 +73,53 @@ class URLConfig(BaseModel):
6673
description="Allow subdomains of allowed domains (e.g. api.example.com if example.com is allowed)",
6774
)
6875

76+
@field_validator("allowed_schemes", mode="before")
@classmethod
def normalize_allowed_schemes(cls, value: Any) -> set[str]:
    """Reduce configured schemes to bare, lower-case identifiers.

    Accepts a single string or any iterable of strings. Each entry is
    trimmed and lower-cased, and a trailing ``://`` or ``:`` delimiter is
    dropped, so ``"https://"``, ``"HTTPS:"`` and ``" https "`` all become
    ``"https"``. Entries that end up empty are ignored.

    Args:
        value: The raw ``allowed_schemes`` input, or ``None``.

    Returns:
        The set of normalized scheme names; ``{"https"}`` when ``value`` is
        ``None``.

    Raises:
        TypeError: If any entry is not a string.
        ValueError: If no usable scheme remains after normalization.
    """
    if value is None:
        return {"https"}

    candidates = [value] if isinstance(value, str) else list(value)

    schemes: set[str] = set()
    for candidate in candidates:
        if not isinstance(candidate, str):
            raise TypeError("allowed_schemes entries must be strings")
        # Strip the "://" delimiter first, then a lone ":" if one remains.
        token = candidate.strip().lower().removesuffix("://").removesuffix(":")
        if token:
            schemes.add(token)

    if not schemes:
        raise ValueError("allowed_schemes must include at least one scheme")
    return schemes
106+
69107

70108
def _detect_urls(text: str) -> list[str]:
71-
"""Detect URLs using regex."""
109+
"""Detect URLs using regex patterns with deduplication.
110+
111+
Detects URLs with explicit schemes (http, https, ftp, data, javascript,
112+
vbscript), domain-like patterns without schemes, and IP addresses.
113+
Deduplicates to avoid returning both scheme-ful and scheme-less versions
114+
of the same URL.
115+
116+
Args:
117+
text: The text to scan for URLs.
118+
119+
Returns:
120+
List of unique URL strings found in the text, with trailing
121+
punctuation removed.
122+
"""
72123
# Pattern for cleaning trailing punctuation (] must be escaped)
73124
PUNCTUATION_CLEANUP = r"[.,;:!?)\]]+$"
74125

@@ -155,55 +206,110 @@ def _detect_urls(text: str) -> list[str]:
155206
return list(dict.fromkeys([url for url in final_urls if url]))
156207

157208

158-
def _validate_url_security(url_string: str, config: URLConfig) -> tuple[ParseResult | None, str]:
159-
"""Validate URL using stdlib urllib.parse."""
209+
def _validate_url_security(url_string: str, config: URLConfig) -> tuple[ParseResult | None, str, bool]:
210+
"""Validate URL security properties using urllib.parse.
211+
212+
Checks URL structure, validates the scheme is allowed, and ensures no
213+
credentials are embedded in userinfo if block_userinfo is enabled.
214+
215+
Args:
216+
url_string: The URL string to validate.
217+
config: Configuration specifying allowed schemes and userinfo policy.
218+
219+
Returns:
220+
A tuple of (parsed_url, error_reason, had_explicit_scheme). If validation
221+
succeeds, parsed_url is a ParseResult, error_reason is empty, and
222+
had_explicit_scheme indicates if the original URL included a scheme.
223+
If validation fails, parsed_url is None and error_reason describes the failure.
224+
"""
160225
try:
161-
# Parse URL - preserve original scheme for validation
226+
# Parse URL - track whether scheme was explicit
227+
has_explicit_scheme = False
162228
if "://" in url_string:
163229
# Standard URL with double-slash scheme (http://, https://, ftp://, etc.)
164230
parsed_url = urlparse(url_string)
165231
original_scheme = parsed_url.scheme
232+
has_explicit_scheme = True
166233
elif ":" in url_string and url_string.split(":", 1)[0] in {"data", "javascript", "vbscript", "mailto"}:
167234
# Special single-colon schemes
168235
parsed_url = urlparse(url_string)
169236
original_scheme = parsed_url.scheme
237+
has_explicit_scheme = True
170238
else:
171-
# Add http scheme for parsing, but remember this is a default
239+
# Add http scheme for parsing only (user didn't specify a scheme)
172240
parsed_url = urlparse(f"http://{url_string}")
173-
original_scheme = "http" # Default scheme for scheme-less URLs
241+
original_scheme = None # No explicit scheme
242+
has_explicit_scheme = False
174243

175244
# Basic validation: must have scheme and netloc (except for special schemes)
176245
if not parsed_url.scheme:
177-
return None, "Invalid URL format"
246+
return None, "Invalid URL format", False
178247

179248
# Special schemes like data: and javascript: don't need netloc
180249
special_schemes = {"data", "javascript", "vbscript", "mailto"}
181-
if original_scheme not in special_schemes and not parsed_url.netloc:
182-
return None, "Invalid URL format"
250+
if parsed_url.scheme not in special_schemes and not parsed_url.netloc:
251+
return None, "Invalid URL format", False
183252

184-
# Security validations - use original scheme
185-
if original_scheme not in config.allowed_schemes:
186-
return None, f"Blocked scheme: {original_scheme}"
253+
# Security validations - only validate scheme if it was explicitly provided
254+
if has_explicit_scheme and original_scheme not in config.allowed_schemes:
255+
return None, f"Blocked scheme: {original_scheme}", has_explicit_scheme
187256

188-
if config.block_userinfo and parsed_url.username:
189-
return None, "Contains userinfo (potential credential injection)"
257+
if config.block_userinfo and (parsed_url.username or parsed_url.password):
258+
return None, "Contains userinfo (potential credential injection)", has_explicit_scheme
190259

191260
# Everything else (IPs, localhost, private IPs) goes through allow list logic
192-
return parsed_url, ""
261+
return parsed_url, "", has_explicit_scheme
193262

194263
except (ValueError, UnicodeError, AttributeError) as e:
195264
# Common URL parsing errors:
196265
# - ValueError: Invalid URL structure, invalid port, etc.
197266
# - UnicodeError: Invalid encoding in URL
198267
# - AttributeError: Unexpected URL structure
199-
return None, f"Invalid URL format: {str(e)}"
268+
return None, f"Invalid URL format: {str(e)}", False
200269
except Exception as e:
201270
# Catch any unexpected errors but provide debugging info
202-
return None, f"URL parsing error: {type(e).__name__}: {str(e)}"
271+
return None, f"URL parsing error: {type(e).__name__}: {str(e)}", False
272+
273+
274+
def _safe_get_port(parsed: ParseResult, scheme: str) -> int | None:
275+
"""Safely extract port from ParseResult, handling malformed ports.
276+
277+
Args:
278+
parsed: The parsed URL.
279+
scheme: The URL scheme (for default port lookup).
280+
281+
Returns:
282+
The port number, the default port for the scheme, or None if invalid.
283+
"""
284+
try:
285+
return parsed.port or DEFAULT_PORTS.get(scheme.lower())
286+
except ValueError:
287+
# Port is out of range (0-65535) or malformed
288+
return None
289+
290+
291+
def _is_url_allowed(
292+
parsed_url: ParseResult,
293+
allow_list: list[str],
294+
allow_subdomains: bool,
295+
url_had_explicit_scheme: bool,
296+
) -> bool:
297+
"""Check if parsed URL matches any entry in the allow list.
203298
299+
Supports domain names, IP addresses, CIDR blocks, and full URLs with
300+
paths/ports/query strings. Allow list entries without explicit schemes
301+
match any scheme. Entries with schemes must match exactly against URLs
302+
with explicit schemes, but match any scheme-less URL.
204303
205-
def _is_url_allowed(parsed_url: ParseResult, allow_list: list[str], allow_subdomains: bool) -> bool:
206-
"""Check if URL is allowed."""
304+
Args:
305+
parsed_url: The parsed URL to check.
306+
allow_list: List of allowed URL patterns (domains, IPs, CIDR, full URLs).
307+
allow_subdomains: If True, subdomains of allowed domains are permitted.
308+
url_had_explicit_scheme: Whether the original URL included an explicit scheme.
309+
310+
Returns:
311+
True if the URL matches any allow list entry, False otherwise.
312+
"""
207313
if not allow_list:
208314
return False
209315

@@ -212,30 +318,109 @@ def _is_url_allowed(parsed_url: ParseResult, allow_list: list[str], allow_subdom
212318
return False
213319

214320
url_host = url_host.lower()
321+
url_domain = url_host.replace("www.", "")
322+
scheme_lower = parsed_url.scheme.lower() if parsed_url.scheme else ""
323+
# Safely get port (rejects malformed ports)
324+
url_port = _safe_get_port(parsed_url, scheme_lower)
325+
# Early rejection of malformed ports
326+
try:
327+
_ = parsed_url.port # This will raise ValueError for malformed ports
328+
except ValueError:
329+
return False
330+
url_path = parsed_url.path or "/"
331+
url_query = parsed_url.query
332+
url_fragment = parsed_url.fragment
333+
334+
try:
335+
url_ip = ip_address(url_host)
336+
except (AddressValueError, ValueError):
337+
url_ip = None
215338

216339
for allowed_entry in allow_list:
217340
allowed_entry = allowed_entry.lower().strip()
218341

219-
# Handle IP addresses and CIDR blocks
342+
has_explicit_scheme = bool(SCHEME_PREFIX_RE.match(allowed_entry))
343+
if has_explicit_scheme:
344+
parsed_allowed = urlparse(allowed_entry)
345+
else:
346+
parsed_allowed = urlparse(f"//{allowed_entry}")
347+
allowed_host = (parsed_allowed.hostname or "").lower()
348+
allowed_scheme = parsed_allowed.scheme.lower() if parsed_allowed.scheme else ""
349+
# Check if port was explicitly specified (safely)
350+
try:
351+
allowed_port_explicit = parsed_allowed.port
352+
except ValueError:
353+
allowed_port_explicit = None
354+
allowed_port = _safe_get_port(parsed_allowed, allowed_scheme)
355+
allowed_path = parsed_allowed.path
356+
allowed_query = parsed_allowed.query
357+
allowed_fragment = parsed_allowed.fragment
358+
359+
# Handle IP addresses and CIDR blocks (including schemes)
220360
try:
221-
ip_address(allowed_entry.split("/")[0])
222-
if allowed_entry == url_host or ("/" in allowed_entry and ip_address(url_host) in ip_network(allowed_entry, strict=False)):
361+
allowed_ip = ip_address(allowed_host)
362+
except (AddressValueError, ValueError):
363+
allowed_ip = None
364+
365+
if allowed_ip is not None:
366+
if url_ip is None:
367+
continue
368+
# Scheme matching for IPs: if both allow list and URL have explicit schemes, they must match
369+
if has_explicit_scheme and url_had_explicit_scheme and allowed_scheme and allowed_scheme != scheme_lower:
370+
continue
371+
# Port matching: enforce if allow list has explicit port
372+
if allowed_port_explicit is not None and allowed_port != url_port:
373+
continue
374+
if allowed_ip == url_ip:
223375
return True
376+
377+
network_spec = allowed_host
378+
if parsed_allowed.path not in ("", "/"):
379+
network_spec = f"{network_spec}{parsed_allowed.path}"
380+
try:
381+
if network_spec and "/" in network_spec and url_ip in ip_network(network_spec, strict=False):
382+
return True
383+
except (AddressValueError, ValueError):
384+
# Path segment might not represent a CIDR mask; ignore.
385+
pass
386+
continue
387+
388+
if not allowed_host:
224389
continue
225-
except (AddressValueError, ValueError):
226-
pass
227390

228-
# Handle domain matching
229-
allowed_domain = allowed_entry.replace("www.", "")
230-
url_domain = url_host.replace("www.", "")
391+
allowed_domain = allowed_host.replace("www.", "")
231392

232-
# Exact match always allowed
233-
if url_domain == allowed_domain:
234-
return True
393+
# Port matching: enforce if allow list has explicit port
394+
if allowed_port_explicit is not None and allowed_port != url_port:
395+
continue
396+
397+
host_matches = url_domain == allowed_domain or (
398+
allow_subdomains and url_domain.endswith(f".{allowed_domain}")
399+
)
400+
if not host_matches:
401+
continue
402+
403+
# Scheme matching: if both allow list and URL have explicit schemes, they must match
404+
if has_explicit_scheme and url_had_explicit_scheme and allowed_scheme and allowed_scheme != scheme_lower:
405+
continue
406+
407+
# Path matching with segment boundary respect
408+
if allowed_path not in ("", "/"):
409+
# Normalize trailing slashes to prevent issues with entries like "/api/"
410+
# which should match "/api/users" but would fail with double-slash check
411+
normalized_allowed_path = allowed_path.rstrip("/")
412+
# Ensure path matching respects segment boundaries to prevent
413+
# "/api" from matching "/api2" or "/api-v2"
414+
if url_path != allowed_path and url_path != normalized_allowed_path and not url_path.startswith(f"{normalized_allowed_path}/"):
415+
continue
416+
417+
if allowed_query and allowed_query != url_query:
418+
continue
419+
420+
if allowed_fragment and allowed_fragment != url_fragment:
421+
continue
235422

236-
# Subdomain matching if enabled
237-
if allow_subdomains and url_domain.endswith(f".{allowed_domain}"):
238-
return True
423+
return True
239424

240425
return False
241426

@@ -258,7 +443,7 @@ async def urls(ctx: Any, data: str, config: URLConfig) -> GuardrailResult:
258443

259444
for url_string in detected_urls:
260445
# Validate URL with security checks
261-
parsed_url, error_reason = _validate_url_security(url_string, config)
446+
parsed_url, error_reason, url_had_explicit_scheme = _validate_url_security(url_string, config)
262447

263448
if parsed_url is None:
264449
blocked.append(url_string)
@@ -273,7 +458,7 @@ async def urls(ctx: Any, data: str, config: URLConfig) -> GuardrailResult:
273458
# For hostless schemes, only scheme permission matters (no allow list needed)
274459
# They were already validated for scheme permission in _validate_url_security
275460
allowed.append(url_string)
276-
elif _is_url_allowed(parsed_url, config.url_allow_list, config.allow_subdomains):
461+
elif _is_url_allowed(parsed_url, config.url_allow_list, config.allow_subdomains, url_had_explicit_scheme):
277462
allowed.append(url_string)
278463
else:
279464
blocked.append(url_string)
@@ -282,7 +467,7 @@ async def urls(ctx: Any, data: str, config: URLConfig) -> GuardrailResult:
282467
return GuardrailResult(
283468
tripwire_triggered=bool(blocked),
284469
info={
285-
"guardrail_name": "URL Filter (Direct Config)",
470+
"guardrail_name": "URL Filter",
286471
"config": {
287472
"allowed_schemes": list(config.allowed_schemes),
288473
"block_userinfo": config.block_userinfo,

0 commit comments

Comments
 (0)