Coverage for core/platform/pr_guardian.py: 91.7% (206 statements)
1"""
2PR Guardian — Autonomous code quality enforcement for HART OS.
4AST-based static analysis providing cyclomatic complexity, function length,
5nesting depth, and import analysis. Zero external dependencies — stdlib only.
7Integrates with PRReviewService to provide enhanced reviews that go beyond
8simple LOC counting.
10Usage:
11 from core.platform.pr_guardian import PRGuardian, CodeMetrics
13 # Analyze a single file
14 metrics = CodeMetrics.analyze(source)
15 violations = PRGuardian.check_thresholds(metrics)
17 # Full PR analysis
18 report = PRGuardian.analyze_diff(diff_text, changed_files)
19 comment = PRGuardian.generate_review_comment(report)
20"""

import ast
import re
from typing import Any, Dict, List

# ─── Thresholds (frozen by convention) ────────────────────────────

MAX_CYCLOMATIC_COMPLEXITY = 15
MAX_FUNCTION_LENGTH = 100
MAX_NESTING_DEPTH = 5
MAX_FILE_LENGTH = 1000
BLOCKED_IMPORTS = frozenset({
    'subprocess', 'ctypes', 'multiprocessing',
    'pickle', 'shelve', 'marshal',
})

# PR checklist keys
_CHECKLIST_KEYS = [
    'tests_added', 'docs_updated', 'no_protected_files',
    'manifest_validated', 'sandbox_passes',
]


# ─── CodeMetrics ─────────────────────────────────────────────────

class CodeMetrics:
    """AST-based code quality metrics. All static methods, stdlib only."""

    @staticmethod
    def cyclomatic_complexity(source: str) -> List[Dict[str, Any]]:
        """Compute cyclomatic complexity per function/method.

        CC = 1 + number of decision points (if/elif, for, while, except,
        with, assert, ternary IfExp, plus one per extra operand in an
        `and`/`or` chain).

        Returns list of {name, line, complexity}.
        """
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return []

        results = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                cc = 1 + _count_decisions(node)
                results.append({
                    'name': node.name,
                    'line': node.lineno,
                    'complexity': cc,
                })
        return results
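
    # Worked sketch (hypothetical source string, not from this repo):
    # one `if` plus one `and` gives two decision points, so CC = 3.
    #
    #   src = "def f(x, y):\n    if x and y:\n        return 1\n    return 0\n"
    #   CodeMetrics.cyclomatic_complexity(src)
    #   # -> [{'name': 'f', 'line': 1, 'complexity': 3}]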

    @staticmethod
    def function_lengths(source: str) -> List[Dict[str, Any]]:
        """Compute line count per function/method.

        Returns list of {name, line, length}.
        """
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return []

        results = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                length = _function_line_count(node)
                results.append({
                    'name': node.name,
                    'line': node.lineno,
                    'length': length,
                })
        return results

    @staticmethod
    def nesting_depth(source: str) -> List[Dict[str, Any]]:
        """Compute maximum nesting depth per function/method.

        Nesting: if/for/while/with/try inside each other.

        Returns list of {name, line, max_depth}.
        """
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return []

        results = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                depth = _max_nesting(node)
                results.append({
                    'name': node.name,
                    'line': node.lineno,
                    'max_depth': depth,
                })
        return results
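
    # Illustration (hypothetical input): a `for` inside an `if` in the
    # function body gives max_depth == 2; the `def` itself is not counted.
    #
    #   src = (
    #       "def g(xs):\n"
    #       "    if xs:\n"
    #       "        for x in xs:\n"
    #       "            print(x)\n"
    #   )
    #   CodeMetrics.nesting_depth(src)
    #   # -> [{'name': 'g', 'line': 1, 'max_depth': 2}]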

    @staticmethod
    def import_analysis(source: str) -> Dict[str, Any]:
        """Analyze imports in a source file.

        Returns {total, stdlib_count, blocked_imports, all_imports}.
        """
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return {'total': 0, 'stdlib_count': 0,
                    'blocked_imports': [], 'all_imports': []}

        all_imports = []
        blocked = []

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    mod = alias.name.split('.')[0]
                    all_imports.append(mod)
                    if mod in BLOCKED_IMPORTS:
                        blocked.append(mod)
            elif isinstance(node, ast.ImportFrom):
                if node.module:
                    mod = node.module.split('.')[0]
                    all_imports.append(mod)
                    if mod in BLOCKED_IMPORTS:
                        blocked.append(mod)

        return {
            'total': len(all_imports),
            'stdlib_count': sum(1 for m in all_imports if _is_stdlib(m)),
            'blocked_imports': list(set(blocked)),
            'all_imports': list(set(all_imports)),
        }
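
    # Sketch (hypothetical source): `pickle` is stdlib *and* blocked, so it
    # counts toward both. The de-duplicated lists come from sets, so their
    # order is not guaranteed.
    #
    #   src = "import os\nimport pickle\nfrom collections import deque\n"
    #   CodeMetrics.import_analysis(src)
    #   # -> {'total': 3, 'stdlib_count': 3,
    #   #     'blocked_imports': ['pickle'],
    #   #     'all_imports': ['os', 'pickle', 'collections']}  # any order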

    @staticmethod
    def analyze(source: str) -> Dict[str, Any]:
        """Run all metrics on a source string.

        Returns combined dict of all metric results.
        """
        return {
            'cyclomatic_complexity': CodeMetrics.cyclomatic_complexity(source),
            'function_lengths': CodeMetrics.function_lengths(source),
            'nesting_depth': CodeMetrics.nesting_depth(source),
            'import_analysis': CodeMetrics.import_analysis(source),
            'total_lines': len(source.splitlines()),
        }
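
# End-to-end sketch (hypothetical path): run every metric at once and pick
# results out of the combined dict; keys match the method names above.
#
#   source = open('some_module.py').read()
#   report = CodeMetrics.analyze(source)
#   report['total_lines'], report['import_analysis']['blocked_imports']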


# ─── PRGuardian ──────────────────────────────────────────────────

class PRGuardian:
    """Autonomous PR review with code quality enforcement.

    Analyzes changed files, checks thresholds, generates review comments.
    """

    @staticmethod
    def analyze_file_source(source: str,
                            filename: str = '') -> Dict[str, Any]:
        """Analyze a single file's source code.

        Returns metrics + violations dict.
        """
        metrics = CodeMetrics.analyze(source)
        violations = PRGuardian.check_thresholds(metrics, filename)
        return {
            'filename': filename,
            'metrics': metrics,
            'violations': violations,
            'passed': len(violations) == 0,
        }

    @staticmethod
    def analyze_diff(diff_text: str,
                     changed_files: List[Dict[str, str]]) -> Dict[str, Any]:
        """Analyze a full PR diff.

        Args:
            diff_text: Raw unified diff text.
            changed_files: List of {filename, source} dicts for each file.

        Returns:
            Structured report with per-file analysis and overall verdict.
        """
        file_reports = []
        all_violations = []

        for cf in changed_files:
            filename = cf.get('filename', '')
            source = cf.get('source', '')
            if not source or not filename.endswith('.py'):
                continue

            report = PRGuardian.analyze_file_source(source, filename)
            file_reports.append(report)
            all_violations.extend(report['violations'])

        # Diff-level stats: count added/removed lines, excluding the
        # +++/--- file-header lines
        additions = 0
        deletions = 0
        if diff_text:
            for line in diff_text.splitlines():
                if line.startswith('+') and not line.startswith('+++'):
                    additions += 1
                elif line.startswith('-') and not line.startswith('---'):
                    deletions += 1

        overall_passed = len(all_violations) == 0

        result = {
            'files_analyzed': len(file_reports),
            'file_reports': file_reports,
            'total_violations': len(all_violations),
            'all_violations': all_violations,
            'diff_stats': {
                'additions': additions,
                'deletions': deletions,
            },
            'passed': overall_passed,
        }

        # Emit event (non-blocking, best-effort)
        try:
            from core.platform.events import emit_event
            emit_event('pr_review.analysis_complete', {
                'files_analyzed': len(file_reports),
                'passed': overall_passed,
                'violation_count': len(all_violations),
            })
        except Exception:
            pass

        # Audit-log error-severity violations
        if not overall_passed:
            try:
                from security.immutable_audit_log import get_audit_log
                errors = [v for v in all_violations
                          if v.get('severity') == 'error']
                if errors:
                    get_audit_log().log_event(
                        'code_review', 'pr_guardian',
                        f'{len(errors)} error-severity violations',
                        detail={'violations': errors[:10]})
            except Exception:
                pass

        return result
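
    # Call sketch (hypothetical diff and payload). Only `.py` entries with
    # non-empty source are analyzed; the +++/--- header lines are excluded
    # from the add/delete counts.
    #
    #   diff = "--- a/foo.py\n+++ b/foo.py\n+x = 1\n-y = 2\n"
    #   report = PRGuardian.analyze_diff(
    #       diff, [{'filename': 'foo.py', 'source': 'x = 1\n'}])
    #   report['diff_stats']  # -> {'additions': 1, 'deletions': 1}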

    @staticmethod
    def check_thresholds(metrics: Dict[str, Any],
                         filename: str = '') -> List[Dict[str, str]]:
        """Check metrics against quality thresholds.

        Returns list of violation dicts: {rule, message, severity}.
        """
        violations = []

        # Cyclomatic complexity
        for func in metrics.get('cyclomatic_complexity', []):
            if func['complexity'] > MAX_CYCLOMATIC_COMPLEXITY:
                violations.append({
                    'rule': 'cyclomatic_complexity',
                    'message': (
                        f"{filename}:{func['line']} "
                        f"'{func['name']}' has CC={func['complexity']} "
                        f"(max {MAX_CYCLOMATIC_COMPLEXITY})"),
                    'severity': 'error',
                })

        # Function length
        for func in metrics.get('function_lengths', []):
            if func['length'] > MAX_FUNCTION_LENGTH:
                violations.append({
                    'rule': 'function_length',
                    'message': (
                        f"{filename}:{func['line']} "
                        f"'{func['name']}' is {func['length']} lines "
                        f"(max {MAX_FUNCTION_LENGTH})"),
                    'severity': 'warning',
                })

        # Nesting depth
        for func in metrics.get('nesting_depth', []):
            if func['max_depth'] > MAX_NESTING_DEPTH:
                violations.append({
                    'rule': 'nesting_depth',
                    'message': (
                        f"{filename}:{func['line']} "
                        f"'{func['name']}' has depth={func['max_depth']} "
                        f"(max {MAX_NESTING_DEPTH})"),
                    'severity': 'warning',
                })

        # Blocked imports
        blocked = metrics.get('import_analysis', {}).get('blocked_imports', [])
        for mod in blocked:
            violations.append({
                'rule': 'blocked_import',
                'message': (
                    f"{filename}: blocked import '{mod}' "
                    f"(security risk)"),
                'severity': 'error',
            })

        # File too long
        total = metrics.get('total_lines', 0)
        if total > MAX_FILE_LENGTH:
            violations.append({
                'rule': 'file_length',
                'message': (
                    f"{filename}: {total} lines "
                    f"(max {MAX_FILE_LENGTH})"),
                'severity': 'warning',
            })

        return violations
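
    # Threshold sketch (hypothetical metrics dict): one function over the CC
    # ceiling yields a single error-severity violation; absent metric keys
    # simply contribute nothing.
    #
    #   metrics = {'cyclomatic_complexity':
    #                  [{'name': 'big', 'line': 10, 'complexity': 22}],
    #              'total_lines': 120}
    #   PRGuardian.check_thresholds(metrics, 'mod.py')
    #   # -> [{'rule': 'cyclomatic_complexity',
    #   #      'message': "mod.py:10 'big' has CC=22 (max 15)",
    #   #      'severity': 'error'}]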

    @staticmethod
    def generate_review_comment(analysis: Dict[str, Any]) -> str:
        """Generate a human/agent-readable review comment.

        Tries ModelBusService for AI-enhanced summary, falls back to template.
        """
        passed = analysis.get('passed', False)
        violations = analysis.get('all_violations', [])
        files = analysis.get('files_analyzed', 0)

        # Try AI-enhanced summary
        ai_summary = ''
        if violations:
            try:
                from integrations.agent_engine.model_bus_service import (
                    get_model_bus_service,
                )
                bus = get_model_bus_service()
                if bus:
                    prompt = (
                        f"Summarize these code review violations in 2-3 "
                        f"sentences for a developer:\n"
                        f"{violations[:10]}")
                    result = bus.infer(prompt)
                    if result and 'response' in result:
                        ai_summary = result['response']
            except Exception:
                pass

        # Template
        lines = []
        lines.append('## HART PR Guardian Review\n')

        if passed:
            lines.append(f'All {files} files pass quality checks.\n')
        else:
            lines.append(
                f'Found **{len(violations)} violation(s)** '
                f'across {files} file(s).\n')

        if ai_summary:
            lines.append(f'### Summary\n{ai_summary}\n')

        # Group by severity
        errors = [v for v in violations if v.get('severity') == 'error']
        warnings = [v for v in violations if v.get('severity') == 'warning']

        if errors:
            lines.append('### Errors (must fix)')
            for v in errors:
                lines.append(f"- **{v['rule']}**: {v['message']}")

        if warnings:
            lines.append('\n### Warnings')
            for v in warnings:
                lines.append(f"- **{v['rule']}**: {v['message']}")

        lines.append('\n### Thresholds')
        lines.append(
            f'- Cyclomatic Complexity: <= {MAX_CYCLOMATIC_COMPLEXITY}')
        lines.append(
            f'- Function Length: <= {MAX_FUNCTION_LENGTH} lines')
        lines.append(
            f'- Nesting Depth: <= {MAX_NESTING_DEPTH}')
        lines.append(
            f'- Blocked Imports: {sorted(BLOCKED_IMPORTS)}')

        lines.append('\n*Automated by HART PR Guardian*')

        return '\n'.join(lines)
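
    # Usage sketch: feed the analyze_diff() report straight in. When no
    # ModelBusService is importable, the comment falls back to the plain
    # template and still lists every violation.
    #
    #   comment = PRGuardian.generate_review_comment(report)
    #   comment.startswith('## HART PR Guardian Review')  # -> True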

    @staticmethod
    def check_pr_checklist(pr_body: str) -> Dict[str, bool]:
        """Parse a PR body for checklist items.

        Looks for markdown checkboxes like:
            - [x] Tests added
            - [ ] Docs updated

        Returns dict of checklist key -> checked status.
        """
        result = {k: False for k in _CHECKLIST_KEYS}

        if not pr_body:
            return result

        body_lower = pr_body.lower()

        # Match checked checkboxes: [x] or [X] (body is lowercased above,
        # so matching a literal lowercase x covers both)
        checked = set()
        for match in re.finditer(r'\[x\]\s*(.+)', body_lower):
            checked.add(match.group(1).strip())

        # Map known phrases to checklist keys
        phrase_map = {
            'tests_added': ['tests added', 'test added', 'tests included'],
            'docs_updated': ['docs updated', 'documentation updated',
                             'docs included'],
            'no_protected_files': ['no protected file',
                                   'protected files unchanged'],
            'manifest_validated': ['manifest valid', 'manifest validated'],
            'sandbox_passes': ['sandbox pass', 'sandbox check'],
        }

        for key, phrases in phrase_map.items():
            for phrase in phrases:
                if any(phrase in item for item in checked):
                    result[key] = True
                    break

        return result
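
# Checklist sketch (hypothetical PR body): only checked boxes whose text
# contains a known phrase flip their key to True.
#
#   body = "- [x] Tests added\n- [ ] Docs updated\n- [x] Manifest validated"
#   PRGuardian.check_pr_checklist(body)
#   # -> {'tests_added': True, 'docs_updated': False,
#   #     'no_protected_files': False, 'manifest_validated': True,
#   #     'sandbox_passes': False}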


# ─── AST Helpers (module-private) ────────────────────────────────

_DECISION_NODES = (
    ast.If, ast.IfExp,
    ast.For, ast.AsyncFor,
    ast.While,
    ast.ExceptHandler,
    ast.With, ast.AsyncWith,
    ast.Assert,
)


def _count_decisions(node: ast.AST) -> int:
    """Count decision points in an AST subtree.

    Recurses manually instead of using ast.walk so that nested function
    bodies are pruned entirely; they get their own CC. (ast.walk visits
    all descendants, so skipping only the nested def node would still
    count the decisions inside it toward the outer function.)
    """
    count = 0
    for child in ast.iter_child_nodes(node):
        # Skip nested functions and everything inside them
        if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if isinstance(child, _DECISION_NODES):
            count += 1
        elif isinstance(child, ast.BoolOp):
            # Each `and`/`or` adds (num_values - 1)
            count += len(child.values) - 1
        count += _count_decisions(child)
    return count
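
# Worked check (hypothetical expression): `a and b or c` parses as
# Or(And(a, b), c), two BoolOps with two operands each, so it adds
# (2-1) + (2-1) = 2 decision points.
#
#   func = ast.parse("def f():\n    return a and b or c").body[0]
#   _count_decisions(func)  # -> 2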

def _function_line_count(node: ast.AST) -> int:
    """Compute approximate line count for a function node."""
    if not hasattr(node, 'end_lineno') or node.end_lineno is None:
        # Fallback for Python < 3.8 (no end_lineno attribute)
        lines = set()
        for child in ast.walk(node):
            if hasattr(child, 'lineno'):
                lines.add(child.lineno)
        return len(lines) if lines else 1
    return node.end_lineno - node.lineno + 1


_NESTING_NODES = (ast.If, ast.For, ast.AsyncFor, ast.While,
                  ast.With, ast.AsyncWith, ast.Try)


def _max_nesting(func_node: ast.AST) -> int:
    """Compute max nesting depth within a function."""
    return _nesting_depth_recursive(func_node, 0, is_root=True)


def _nesting_depth_recursive(node: ast.AST, current: int,
                             is_root: bool = False) -> int:
    """Recursively compute nesting depth."""
    if isinstance(node, _NESTING_NODES) and not is_root:
        current += 1

    # Skip nested functions (they get their own depth count)
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and not is_root:
        return current

    max_depth = current
    for child in ast.iter_child_nodes(node):
        child_depth = _nesting_depth_recursive(child, current)
        if child_depth > max_depth:
            max_depth = child_depth

    return max_depth
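
# Edge-case note (hypothetical snippet): `try` counts one nesting level but
# the `except` handler itself does not add another, so an `if` inside the
# handler still sits at depth 2.
#
#   src = ("def h(x):\n"
#          "    try:\n"
#          "        pass\n"
#          "    except Exception:\n"
#          "        if x:\n"
#          "            pass\n")
#   CodeMetrics.nesting_depth(src)
#   # -> [{'name': 'h', 'line': 1, 'max_depth': 2}]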

# Common stdlib top-level modules (subset for quick classification)
_STDLIB_MODULES = frozenset({
    'abc', 'ast', 'asyncio', 'collections', 'contextlib', 'copy',
    'csv', 'datetime', 'enum', 'functools', 'hashlib', 'io',
    'itertools', 'json', 'logging', 'math', 'os', 'pathlib',
    'pickle', 're', 'shutil', 'socket', 'sqlite3', 'string',
    'struct', 'subprocess', 'sys', 'tempfile', 'textwrap',
    'threading', 'time', 'traceback', 'typing', 'unittest',
    'urllib', 'uuid', 'warnings', 'xml', 'zipfile',
})


def _is_stdlib(module_name: str) -> bool:
    """Quick check if a module is likely stdlib."""
    return module_name in _STDLIB_MODULES
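
# Smoke-test hook (a minimal sketch, not part of the original module): run
# the guardian against its own source when executed directly. On Python
# 3.10+, the hand-rolled _STDLIB_MODULES set could instead defer to
# sys.stdlib_module_names.
if __name__ == '__main__':
    import pathlib

    _src = pathlib.Path(__file__).read_text()
    _rpt = PRGuardian.analyze_file_source(_src, __file__)
    print('passed:', _rpt['passed'], '| violations:', len(_rpt['violations']))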