Coverage for security / sanitize.py: 96.0%

75 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2Input Sanitization & Validation 

3Prevents SQL LIKE injection, path traversal, XSS, and input abuse. 

4""" 

5 

6import ipaddress 

7import re 

8import os 

9import html 

10import logging 

11from pathlib import Path 

12from typing import Optional 

13from urllib.parse import urlparse 

14 

# Module-level logger for security events; sanitize_path() reports blocked
# path-traversal attempts through it.
logger = logging.getLogger('hevolve_security')

16 

17 

# Translation table for escape_like(): each LIKE metacharacter, and the
# backslash itself, maps to its backslash-prefixed form.
_LIKE_ESCAPES = str.maketrans({
    '\\': '\\\\',
    '%': '\\%',
    '_': '\\_',
})


def escape_like(value: str) -> str:
    """
    Neutralise SQL LIKE wildcards in a user-supplied search term.

    Every backslash, '%' and '_' in ``value`` is prefixed with a backslash so
    that, for example, a search for '%' no longer matches every row.

    NOTE(review): this presumes the query treats backslash as the LIKE escape
    character - confirm against the target database / ESCAPE clause.
    """
    # A single translate pass cannot re-escape the backslashes it inserts,
    # which is what replacing the backslash first achieved in a replace chain.
    return value.translate(_LIKE_ESCAPES)

29 

30 

def sanitize_path(user_input: str, base_dir: str) -> str:
    """
    Validate that a file name stays within base_dir.

    '..' sequences and both kinds of path separator are removed from
    ``user_input``; the remainder is joined onto the resolved ``base_dir`` and
    resolved again (following symlinks). The result must be ``base_dir``
    itself or a path beneath it.

    Returns the resolved absolute path as a string.
    Raises ValueError on a path traversal attempt.

    Usage:
        safe_path = sanitize_path(f"{prompt_id}.json", "prompts")
    """
    base = Path(base_dir).resolve()
    # Strip parent references and any path separators from the input.
    # Stripping can itself produce '..' (e.g. './.' -> '..'), so the
    # containment check below remains the real safeguard.
    cleaned = user_input.replace('..', '').replace('/', '').replace('\\', '')
    target = (base / cleaned).resolve()

    # Compare path components, not string prefixes: a plain startswith()
    # would accept '/srv/prompts_evil/x' for base '/srv/prompts', which is
    # reachable through a symlink inside base.
    try:
        target.relative_to(base)
    except ValueError:
        logger.warning("Path traversal blocked: %r escapes %s", user_input, base_dir)
        raise ValueError(f"Invalid path: {user_input}") from None

    return str(target)

49 

50 

def sanitize_html(text: str) -> str:
    """
    HTML-escape user-generated text to prevent stored XSS.

    Intended to be applied to user text before JSON serialization. The
    characters &, <, > and both quote characters become HTML entities.
    Anything that is not a ``str`` is handed back untouched.
    """
    if isinstance(text, str):
        return html.escape(text, quote=True)
    return text

59 

60 

def validate_input(
    value: str,
    max_length: int = 10000,
    min_length: int = 0,
    pattern: Optional[str] = None,
    field_name: str = 'input',
) -> str:
    """
    Check a string against length bounds and an optional regex.

    Surrounding whitespace is removed first; both the checks and the return
    value use the trimmed text. ``pattern`` is applied with ``re.match``, so
    it is anchored at the start only - supply a trailing ``$`` to constrain
    the whole value. ``field_name`` is used as the prefix of error messages.

    Returns the trimmed string.
    Raises ValueError with a descriptive message on failure.
    """
    if not isinstance(value, str):
        raise ValueError(f"{field_name} must be a string")

    trimmed = value.strip()
    size = len(trimmed)

    if size < min_length:
        raise ValueError(f"{field_name} must be at least {min_length} characters")
    if size > max_length:
        raise ValueError(f"{field_name} exceeds maximum length of {max_length}")

    # A falsy pattern (None or '') disables the character check.
    if pattern:
        if re.match(pattern, trimmed) is None:
            raise ValueError(f"{field_name} contains invalid characters")

    return trimmed

87 

88 

def validate_prompt_id(prompt_id) -> str:
    """
    Validate prompt_id is a safe integer string.

    The value is converted with ``str()`` and trimmed; it must then consist
    solely of the ASCII digits 0-9.

    Returns the trimmed digit string.
    Raises ValueError if anything else is present.
    """
    pid = str(prompt_id).strip()
    # [0-9] rather than \d: in a str pattern \d also matches non-ASCII decimal
    # digits (Arabic-Indic, fullwidth, ...), which are not wanted in an id.
    if not re.fullmatch(r'[0-9]+', pid):
        raise ValueError(f"Invalid prompt_id: must be numeric, got {pid!r}")
    return pid

95 

96 

def validate_user_id(user_id) -> str:
    """
    Validate user_id contains only ASCII letters, digits, '_' and '-'.

    The value is converted with ``str()`` and trimmed before checking.

    Returns the trimmed id.
    Raises ValueError if it is empty or contains any other character.
    """
    candidate = str(user_id).strip()
    if re.match(r'^[a-zA-Z0-9_-]+$', candidate):
        return candidate
    raise ValueError(f"Invalid user_id: must be alphanumeric, got {candidate!r}")

103 

104 

def validate_username(username: str) -> str:
    """
    Validate username format for social platform.

    Accepts 2 to 50 characters drawn from ASCII letters, digits and the
    punctuation ``_ . @ -``. Returns the trimmed username as produced by
    validate_input; raises ValueError otherwise.
    """
    allowed_chars = r'^[a-zA-Z0-9_.@-]+$'
    return validate_input(
        username,
        field_name='username',
        min_length=2,
        max_length=50,
        pattern=allowed_chars,
    )

114 

115 

def validate_password(password: str) -> str:
    """
    Validate password meets minimum requirements (8 to 128 characters).

    No character-class rule is applied. NOTE: validate_input trims
    surrounding whitespace, so the returned password can differ from the
    submitted one. Raises ValueError when the length is out of range.
    """
    limits = {'min_length': 8, 'max_length': 128}
    return validate_input(password, field_name='password', **limits)

124 

125 

def validate_search_query(query: str) -> str:
    """
    Validate a search query: 1 to 200 characters after trimming.

    Returns the trimmed query as produced by validate_input. Only length is
    checked here; no escaping is performed. Raises ValueError when the query
    is empty or too long.
    """
    checked = validate_input(
        query,
        field_name='search query',
        min_length=1,
        max_length=200,
    )
    return checked

134 

135 

def validate_post_content(content: str) -> str:
    """
    Validate post content length: 1 to 40000 characters after trimming.

    Returns the trimmed content as produced by validate_input; raises
    ValueError when it is empty or too long.
    """
    longest_post = 40000
    return validate_input(
        content,
        field_name='post content',
        min_length=1,
        max_length=longest_post,
    )

144 

145 

def validate_comment(content: str) -> str:
    """
    Validate comment content length: 1 to 10000 characters after trimming.

    Returns the trimmed comment as produced by validate_input; raises
    ValueError when it is empty or too long.
    """
    longest_comment = 10000
    return validate_input(
        content,
        field_name='comment',
        min_length=1,
        max_length=longest_comment,
    )

154 

155 

# Hosts that serve cloud instance metadata; never a legitimate fetch target.
_METADATA_HOSTS = frozenset({
    '169.254.169.254', 'metadata.google.internal',
    'metadata.internal', '100.100.100.200',
})

# Literal spellings of the local machine.
_LOCALHOST_NAMES = frozenset({'localhost', '0.0.0.0', '127.0.0.1', '::1'})


def _host_to_ip(hostname: str):
    """Return the IP address ``hostname`` denotes, or None for a domain name."""
    import socket  # local import: only this helper needs it

    try:
        addr = ipaddress.ip_address(hostname)
    except ValueError:
        # Not a canonical literal. Resolvers still accept legacy IPv4
        # spellings such as '2130706433', '127.1' or '0x7f000001', so decode
        # those the way inet_aton does instead of treating them as names.
        try:
            addr = ipaddress.IPv4Address(socket.inet_aton(hostname))
        except (OSError, ValueError):
            return None

    # Judge '::ffff:a.b.c.d' by the IPv4 address it wraps.
    mapped = getattr(addr, 'ipv4_mapped', None)
    return mapped if mapped is not None else addr


def validate_url(url: str, allow_private: bool = False) -> str:
    """Validate URL is safe for server-side requests (SSRF protection).

    Blocks: private/reserved IPs, non-http(s) schemes, cloud metadata endpoints.
    Host names are compared lower-cased and without a trailing dot, and IP
    hosts are recognised in legacy IPv4 and IPv4-mapped IPv6 spellings too.

    ``allow_private=True`` skips the private-address and localhost checks for
    internal tools; metadata endpoints stay blocked regardless.

    No DNS resolution is performed, so a domain name that resolves to an
    internal address is not detected here.

    Returns the URL with surrounding whitespace removed.
    Raises ValueError on unsafe URLs.
    """
    if not isinstance(url, str) or not url.strip():
        raise ValueError("URL must be a non-empty string")

    cleaned = url.strip()
    parsed = urlparse(cleaned)

    # Scheme check
    if parsed.scheme not in ('http', 'https'):
        raise ValueError(f"URL scheme must be http or https, got {parsed.scheme!r}")

    raw_host = parsed.hostname or ''
    # 'localhost.' names the same host as 'localhost'; normalise before
    # comparing so a trailing dot cannot slip past the name checks.
    hostname = raw_host.rstrip('.').lower()
    if not hostname:
        raise ValueError("URL must have a hostname")

    addr = _host_to_ip(hostname)

    # Block cloud metadata endpoints, by name or by any spelling of their IP
    if hostname in _METADATA_HOSTS or (addr is not None and str(addr) in _METADATA_HOSTS):
        raise ValueError("Access to cloud metadata endpoint blocked")

    # Block private/reserved IPs (unless explicitly allowed for internal tools)
    # NOTE(review): the reviewed copy of this file had lost its indentation;
    # the localhost-name check is assumed to sit inside this gate - confirm.
    if not allow_private:
        if addr is not None and (
            addr.is_private
            or addr.is_reserved
            or addr.is_loopback
            or addr.is_link_local
            or addr.is_multicast
            or addr.is_unspecified
        ):
            raise ValueError(f"URL targets private/reserved IP: {raw_host}")

        # Block localhost variants, including names under the reserved
        # '.localhost' domain
        if hostname in _LOCALHOST_NAMES or hostname.endswith('.localhost'):
            raise ValueError("URL targets localhost")

    return cleaned