Coverage for security/sanitize.py: 96.0%
75 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Input Sanitization & Validation
Prevents SQL LIKE injection, path traversal, XSS, SSRF, and input abuse.
4"""
import html
import ipaddress
import logging
import os
import re
import socket
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
# Shared security logger; sanitize_path() reports blocked traversal attempts here.
logger = logging.getLogger('hevolve_security')
def escape_like(value: str) -> str:
    """Neutralise SQL LIKE wildcards in *value*.

    Every backslash is doubled, and each '%' and '_' is prefixed with a
    backslash, so a user searching for '%' matches a literal percent sign
    instead of everything. The mapping is applied in one pass, so escapes
    introduced for '%' / '_' are never themselves re-escaped.

    NOTE(review): backslash is the escape character produced here; make sure
    the query declares it (an ESCAPE clause naming the backslash) on backends
    that do not use it as the default LIKE escape.
    """
    escape_table = str.maketrans({'\\': '\\\\', '%': '\\%', '_': '\\_'})
    return value.translate(escape_table)
def sanitize_path(user_input: str, base_dir: str) -> str:
    """
    Validate a file path stays within base_dir.
    Raises ValueError on path traversal attempt.

    '..' sequences and path separators ('/' and '\\') are stripped from
    user_input rather than rejected, so "../etc/passwd" becomes the file
    "etcpasswd" inside base_dir. The joined path is then resolved (symlinks
    followed) and must be base_dir itself or lie underneath it.

    Usage:
        safe_path = sanitize_path(f"{prompt_id}.json", "prompts")

    Returns:
        The resolved absolute path as a string.

    Raises:
        ValueError: if the resolved path escapes base_dir (for example via a
            symlink inside base_dir, or input such as "./." that collapses
            to ".." once separators are stripped).
    """
    base = Path(base_dir).resolve()
    # Strip any path separators from the input. The result can still resolve
    # outside base ('./.' -> '..', or a symlink), which is checked below.
    cleaned = user_input.replace('..', '').replace('/', '').replace('\\', '')
    target = (base / cleaned).resolve()

    # Compare path components, not string prefixes: a prefix test would accept
    # a sibling such as "<base>_evil/secret" reached through a symlink.
    if target != base and base not in target.parents:
        logger.warning("Path traversal blocked: %r escapes %s", user_input, base_dir)
        raise ValueError(f"Invalid path: {user_input}")

    return str(target)
def sanitize_html(text: str) -> str:
    """Entity-escape user text to prevent stored XSS.

    Strings have &, <, >, double and single quotes replaced with HTML
    entities (html.escape with quote=True). Any non-string value is passed
    through untouched. Apply to user-generated text before JSON serialization.
    """
    if isinstance(text, str):
        return html.escape(text, quote=True)
    return text
def validate_input(
    value: str,
    max_length: int = 10000,
    min_length: int = 0,
    pattern: Optional[str] = None,
    field_name: str = 'input',
) -> str:
    """
    Validate input string against length and pattern constraints.

    Surrounding whitespace is stripped first; all checks apply to, and the
    function returns, the stripped string.

    Args:
        value: Candidate string.
        max_length: Maximum length after stripping (inclusive).
        min_length: Minimum length after stripping (inclusive).
        pattern: Optional regular expression that the ENTIRE stripped value
            must match. Falsy values ('' / None) disable the check.
        field_name: Label used in error messages.

    Returns:
        The stripped value.

    Raises:
        ValueError: with a descriptive message if value is not a str, is too
            short or too long, or does not match pattern.
    """
    if not isinstance(value, str):
        raise ValueError(f"{field_name} must be a string")

    value = value.strip()

    if len(value) < min_length:
        raise ValueError(f"{field_name} must be at least {min_length} characters")

    if len(value) > max_length:
        raise ValueError(f"{field_name} exceeds maximum length of {max_length}")

    # fullmatch, not match: re.match only anchors at the start, so an
    # unanchored pattern such as r'[a-z]+' would accept "abc;DROP TABLE".
    if pattern and re.fullmatch(pattern, value) is None:
        raise ValueError(f"{field_name} contains invalid characters")

    return value
def validate_prompt_id(prompt_id) -> str:
    """Validate prompt_id is a safe integer string.

    prompt_id may be any object; its str() form, stripped of surrounding
    whitespace, must consist solely of ASCII digits 0-9 (e.g. 42 -> "42").

    Returns:
        The stripped digit string.

    Raises:
        ValueError: if the string is empty or contains anything but 0-9.
    """
    pid = str(prompt_id).strip()
    # [0-9] rather than \d: \d also matches non-ASCII decimal digits
    # (Arabic-Indic '٣', fullwidth '１', ...), which would pass through
    # unchanged and name a different file when interpolated into e.g.
    # f"{prompt_id}.json".
    if re.fullmatch(r'[0-9]+', pid) is None:
        raise ValueError(f"Invalid prompt_id: must be numeric, got {pid!r}")
    return pid
def validate_user_id(user_id) -> str:
    """Return user_id as a stripped string of ASCII letters, digits, '_' or '-'.

    Any object is accepted and converted with str() first.

    Raises:
        ValueError: if the stripped string is empty or has other characters.
    """
    candidate = str(user_id).strip()
    matched = re.match(r'^[a-zA-Z0-9_-]+$', candidate)
    if matched is None:
        raise ValueError(f"Invalid user_id: must be alphanumeric, got {candidate!r}")
    return candidate
def validate_username(username: str) -> str:
    """Check a social-platform username and return it stripped.

    Delegates to validate_input: after stripping surrounding whitespace the
    username must be 2-50 characters drawn only from ASCII letters, digits,
    '_', '.', '@' and '-'. Raises ValueError otherwise.
    """
    return validate_input(
        username,
        field_name='username',
        pattern=r'^[a-zA-Z0-9_.@-]+$',
        min_length=2,
        max_length=50,
    )
def validate_password(password: str) -> str:
    """Check a password is a string of 8-128 characters; ValueError otherwise.

    NOTE(review): validate_input strips leading/trailing whitespace and
    returns the stripped string, so "  hunter22  " comes back as "hunter22"
    and the 8-character minimum is measured after stripping. Confirm that
    registration and login normalise passwords the same way before hashing
    or comparing.
    """
    return validate_input(
        password,
        field_name='password',
        min_length=8,
        max_length=128,
    )
def validate_search_query(query: str) -> str:
    """Check a search query and return it stripped of surrounding whitespace.

    Only length is enforced (1-200 characters after stripping; ValueError
    otherwise). No characters are filtered or escaped, so wildcard-escape the
    result (e.g. with escape_like) before embedding it in a LIKE pattern.
    """
    return validate_input(
        query,
        field_name='search query',
        min_length=1,
        max_length=200,
    )
def validate_post_content(content: str) -> str:
    """Check post body length and return it stripped.

    After stripping surrounding whitespace the content must be between 1 and
    40000 characters; ValueError otherwise. Content is not HTML-escaped here.
    """
    return validate_input(
        content,
        field_name='post content',
        min_length=1,
        max_length=40000,
    )
def validate_comment(content: str) -> str:
    """Check comment length and return it stripped.

    After stripping surrounding whitespace the comment must be between 1 and
    10000 characters; ValueError otherwise. Content is not HTML-escaped here.
    """
    return validate_input(
        content,
        field_name='comment',
        min_length=1,
        max_length=10000,
    )
# Cloud instance-metadata hosts; blocked even when allow_private=True.
_METADATA_HOSTS = frozenset({
    '169.254.169.254',
    'metadata.google.internal',
    'metadata.internal',
    '100.100.100.200',
    'fd00:ec2::254',  # AWS IMDS IPv6 endpoint
})


def _parse_ip_literal(host: str):
    """Return the IP address that *host* spells, or None for a DNS name.

    Besides canonical IPv4/IPv6 text this accepts the legacy IPv4 spellings
    understood by socket.inet_aton ("127.1", "2130706433", "0x7f000001",
    "0177.0.0.1"). HTTP clients that go through the system resolver typically
    connect to those as IP addresses, so they must be checked as IPs too.
    """
    try:
        # Drop an IPv6 zone id ("fe80::1%eth0"); older ipaddress versions reject it.
        return ipaddress.ip_address(host.split('%', 1)[0])
    except ValueError:
        pass
    try:
        return ipaddress.IPv4Address(socket.inet_aton(host))
    except (OSError, ValueError):
        return None


def _is_non_public(addr) -> bool:
    """True if *addr* is not a globally routable unicast address."""
    return (
        addr.is_private or addr.is_reserved or addr.is_loopback
        or addr.is_link_local or addr.is_multicast or addr.is_unspecified
        # Catches e.g. 100.64.0.0/10 shared address space, which is_private misses.
        or not addr.is_global
    )


def validate_url(url: str, allow_private: bool = False) -> str:
    """Validate URL is safe for server-side requests (SSRF protection).

    Blocks: non-http(s) schemes, missing hostnames, and cloud metadata
    endpoints (also in alternate IP encodings and IPv4-mapped IPv6 form).
    Unless allow_private is True, it also blocks non-public IP literals
    (private, reserved, loopback, link-local, multicast, unspecified, shared)
    and localhost names, including "*.localhost".

    Hostnames differing only by a trailing root dot ("localhost.") are
    treated as the same host.

    Limitation: DNS names are NOT resolved here, so a domain that resolves
    (or later rebinds) to an internal address passes; enforce that at connect
    time if it matters.

    Returns:
        The URL with surrounding whitespace stripped.

    Raises:
        ValueError: on unsafe or malformed URLs.
    """
    if not isinstance(url, str) or not url.strip():
        raise ValueError("URL must be a non-empty string")

    cleaned = url.strip()
    parsed = urlparse(cleaned)

    # Scheme check (urlparse lower-cases the scheme)
    if parsed.scheme not in ('http', 'https'):
        raise ValueError(f"URL scheme must be http or https, got {parsed.scheme!r}")

    # .hostname is lower-cased with IPv6 brackets removed; a trailing root dot
    # names the same host, so drop it before comparing against block lists.
    host = (parsed.hostname or '').rstrip('.')
    if not host:
        raise ValueError("URL must have a hostname")

    addr = _parse_ip_literal(host)
    # Also examine the IPv4 address embedded in an IPv4-mapped IPv6 literal.
    mapped = getattr(addr, 'ipv4_mapped', None)
    candidates = [a for a in (addr, mapped) if a is not None]

    # Block cloud metadata endpoints
    if host in _METADATA_HOSTS or any(str(a) in _METADATA_HOSTS for a in candidates):
        raise ValueError("Access to cloud metadata endpoint blocked")

    # Block private/reserved IPs and localhost (unless explicitly allowed for internal tools)
    if not allow_private:
        if any(_is_non_public(a) for a in candidates):
            raise ValueError(f"URL targets private/reserved IP: {parsed.hostname}")
        if host in ('localhost', '0.0.0.0', '127.0.0.1', '::1') or host.endswith('.localhost'):
            raise ValueError("URL targets localhost")

    return cleaned