Coverage for core/platform/pr_guardian.py: 91.7%

206 statements  

coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2PR Guardian — Autonomous code quality enforcement for HART OS. 

3 

4AST-based static analysis providing cyclomatic complexity, function length, 

5nesting depth, and import analysis. Zero external dependencies — stdlib only. 

6 

7Integrates with PRReviewService to provide enhanced reviews that go beyond 

8simple LOC counting. 

9 

10Usage: 

11 from core.platform.pr_guardian import PRGuardian, CodeMetrics 

12 

13 # Analyze a single file 

14 metrics = CodeMetrics.analyze(source) 

15 violations = PRGuardian.check_thresholds(metrics) 

16 

17 # Full PR analysis 

18 report = PRGuardian.analyze_diff(diff_text, changed_files) 

19 comment = PRGuardian.generate_review_comment(report) 

20""" 

import ast
import re
from typing import Any, Dict, List

# ─── Thresholds (frozen by convention) ────────────────────────────

MAX_CYCLOMATIC_COMPLEXITY = 15
MAX_FUNCTION_LENGTH = 100
MAX_NESTING_DEPTH = 5
MAX_FILE_LENGTH = 1000
BLOCKED_IMPORTS = frozenset({
    'subprocess', 'ctypes', 'multiprocessing',
    'pickle', 'shelve', 'marshal',
})

# PR checklist keys
_CHECKLIST_KEYS = [
    'tests_added', 'docs_updated', 'no_protected_files',
    'manifest_validated', 'sandbox_passes',
]


# ─── CodeMetrics ─────────────────────────────────────────────────

class CodeMetrics:
    """AST-based code quality metrics. All static methods, stdlib only."""

    @staticmethod
    def cyclomatic_complexity(source: str) -> List[Dict[str, Any]]:
        """Compute cyclomatic complexity per function/method.

        CC = 1 + number of decision points (if, elif, for, while, and, or,
        except, with, assert, ternary IfExp, boolean ops).

        Returns list of {name, line, complexity}.
        """
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return []

        results = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                cc = 1 + _count_decisions(node)
                results.append({
                    'name': node.name,
                    'line': node.lineno,
                    'complexity': cc,
                })
        return results
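
    # Worked example (illustrative, not from the original listing): one
    # `if` plus one `and` gives two decision points, so CC = 1 + 2 = 3.
    #
    #   >>> src = "def f(x):\n    if x > 0 and x < 10:\n        return 1\n"
    #   >>> CodeMetrics.cyclomatic_complexity(src)
    #   [{'name': 'f', 'line': 1, 'complexity': 3}]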

    @staticmethod
    def function_lengths(source: str) -> List[Dict[str, Any]]:
        """Compute line count per function/method.

        Returns list of {name, line, length}.
        """
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return []

        results = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                length = _function_line_count(node)
                results.append({
                    'name': node.name,
                    'line': node.lineno,
                    'length': length,
                })
        return results

    @staticmethod
    def nesting_depth(source: str) -> List[Dict[str, Any]]:
        """Compute maximum nesting depth per function/method.

        Nesting: if/for/while/with/try inside each other.

        Returns list of {name, line, max_depth}.
        """
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return []

        results = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                depth = _max_nesting(node)
                results.append({
                    'name': node.name,
                    'line': node.lineno,
                    'max_depth': depth,
                })
        return results
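
    # Worked example (illustrative): the `if` nested inside the `for`
    # gives a maximum depth of 2.
    #
    #   >>> src = "def g(xs):\n    for x in xs:\n        if x:\n            x()\n"
    #   >>> CodeMetrics.nesting_depth(src)
    #   [{'name': 'g', 'line': 1, 'max_depth': 2}]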

    @staticmethod
    def import_analysis(source: str) -> Dict[str, Any]:
        """Analyze imports in a source file.

        Returns {total, stdlib_count, blocked_imports, all_imports}.
        """
        try:
            tree = ast.parse(source)
        except SyntaxError:
            return {'total': 0, 'stdlib_count': 0,
                    'blocked_imports': [], 'all_imports': []}

        all_imports = []
        blocked = []

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    mod = alias.name.split('.')[0]
                    all_imports.append(mod)
                    if mod in BLOCKED_IMPORTS:
                        blocked.append(mod)
            elif isinstance(node, ast.ImportFrom):
                if node.module:
                    mod = node.module.split('.')[0]
                    all_imports.append(mod)
                    if mod in BLOCKED_IMPORTS:
                        blocked.append(mod)

        return {
            'total': len(all_imports),
            'stdlib_count': sum(1 for m in all_imports if _is_stdlib(m)),
            'blocked_imports': list(set(blocked)),
            'all_imports': list(set(all_imports)),
        }
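
    # Worked example (illustrative): `subprocess` is in BLOCKED_IMPORTS,
    # and all three top-level modules here are classified as stdlib.
    #
    #   >>> src = "import os\nimport subprocess\nfrom collections import Counter\n"
    #   >>> info = CodeMetrics.import_analysis(src)
    #   >>> info['total'], info['stdlib_count'], info['blocked_imports']
    #   (3, 3, ['subprocess'])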

    @staticmethod
    def analyze(source: str) -> Dict[str, Any]:
        """Run all metrics on a source string.

        Returns combined dict of all metric results.
        """
        return {
            'cyclomatic_complexity': CodeMetrics.cyclomatic_complexity(source),
            'function_lengths': CodeMetrics.function_lengths(source),
            'nesting_depth': CodeMetrics.nesting_depth(source),
            'import_analysis': CodeMetrics.import_analysis(source),
            'total_lines': len(source.splitlines()),
        }
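
    # Shape sketch (illustrative): the combined dict exposes exactly the
    # five keys built above.
    #
    #   >>> sorted(CodeMetrics.analyze("x = 1\n"))
    #   ['cyclomatic_complexity', 'function_lengths', 'import_analysis',
    #    'nesting_depth', 'total_lines']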


# ─── PRGuardian ──────────────────────────────────────────────────

class PRGuardian:
    """Autonomous PR review with code quality enforcement.

    Analyzes changed files, checks thresholds, generates review comments.
    """

    @staticmethod
    def analyze_file_source(source: str,
                            filename: str = '') -> Dict[str, Any]:
        """Analyze a single file's source code.

        Returns metrics + violations dict.
        """
        metrics = CodeMetrics.analyze(source)
        violations = PRGuardian.check_thresholds(metrics, filename)
        return {
            'filename': filename,
            'metrics': metrics,
            'violations': violations,
            'passed': len(violations) == 0,
        }

    @staticmethod
    def analyze_diff(diff_text: str,
                     changed_files: List[Dict[str, str]]) -> Dict[str, Any]:
        """Analyze a full PR diff.

        Args:
            diff_text: Raw unified diff text.
            changed_files: List of {filename, source} dicts for each file.

        Returns:
            Structured report with per-file analysis and overall verdict.
        """
        file_reports = []
        all_violations = []

        for cf in changed_files:
            filename = cf.get('filename', '')
            source = cf.get('source', '')
            if not source or not filename.endswith('.py'):
                continue

            report = PRGuardian.analyze_file_source(source, filename)
            file_reports.append(report)
            all_violations.extend(report['violations'])

        # Diff-level stats — count lines starting with +/-
        additions = 0
        deletions = 0
        if diff_text:
            for line in diff_text.splitlines():
                if line.startswith('+') and not line.startswith('+++'):
                    additions += 1
                elif line.startswith('-') and not line.startswith('---'):
                    deletions += 1

        overall_passed = len(all_violations) == 0

        result = {
            'files_analyzed': len(file_reports),
            'file_reports': file_reports,
            'total_violations': len(all_violations),
            'all_violations': all_violations,
            'diff_stats': {
                'additions': additions,
                'deletions': deletions,
            },
            'passed': overall_passed,
        }

        # Emit event (non-blocking, best-effort)
        try:
            from core.platform.events import emit_event
            emit_event('pr_review.analysis_complete', {
                'files_analyzed': len(file_reports),
                'passed': overall_passed,
                'violation_count': len(all_violations),
            })
        except Exception:
            pass

        # Audit log error-severity violations
        if not overall_passed:
            try:
                from security.immutable_audit_log import get_audit_log
                errors = [v for v in all_violations
                          if v.get('severity') == 'error']
                if errors:
                    get_audit_log().log_event(
                        'code_review', 'pr_guardian',
                        f'{len(errors)} error-severity violations',
                        detail={'violations': errors[:10]})
            except Exception:
                pass

        return result
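
    # Usage sketch (illustrative; the diff and file contents are invented).
    # The event/audit hooks are best-effort, so the call works even when
    # those modules are unavailable. Note the +++ header line is skipped.
    #
    #   >>> diff = "+x = 1\n-y = 2\n+++ b/mod.py\n"
    #   >>> files = [{'filename': 'mod.py', 'source': 'x = 1\n'}]
    #   >>> report = PRGuardian.analyze_diff(diff, files)
    #   >>> report['files_analyzed'], report['diff_stats'], report['passed']
    #   (1, {'additions': 1, 'deletions': 1}, True)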

    @staticmethod
    def check_thresholds(metrics: Dict[str, Any],
                         filename: str = '') -> List[Dict[str, str]]:
        """Check metrics against quality thresholds.

        Returns list of violation dicts: {rule, message, severity}.
        """
        violations = []

        # Cyclomatic complexity
        for func in metrics.get('cyclomatic_complexity', []):
            if func['complexity'] > MAX_CYCLOMATIC_COMPLEXITY:
                violations.append({
                    'rule': 'cyclomatic_complexity',
                    'message': (
                        f"{filename}:{func['line']} "
                        f"'{func['name']}' has CC={func['complexity']} "
                        f"(max {MAX_CYCLOMATIC_COMPLEXITY})"),
                    'severity': 'error',
                })

        # Function length
        for func in metrics.get('function_lengths', []):
            if func['length'] > MAX_FUNCTION_LENGTH:
                violations.append({
                    'rule': 'function_length',
                    'message': (
                        f"{filename}:{func['line']} "
                        f"'{func['name']}' is {func['length']} lines "
                        f"(max {MAX_FUNCTION_LENGTH})"),
                    'severity': 'warning',
                })

        # Nesting depth
        for func in metrics.get('nesting_depth', []):
            if func['max_depth'] > MAX_NESTING_DEPTH:
                violations.append({
                    'rule': 'nesting_depth',
                    'message': (
                        f"{filename}:{func['line']} "
                        f"'{func['name']}' has depth={func['max_depth']} "
                        f"(max {MAX_NESTING_DEPTH})"),
                    'severity': 'warning',
                })

        # Blocked imports
        blocked = metrics.get('import_analysis', {}).get('blocked_imports', [])
        for mod in blocked:
            violations.append({
                'rule': 'blocked_import',
                'message': (
                    f"{filename}: blocked import '{mod}' "
                    "(security risk)"),
                'severity': 'error',
            })

        # File too long
        total = metrics.get('total_lines', 0)
        if total > MAX_FILE_LENGTH:
            violations.append({
                'rule': 'file_length',
                'message': (
                    f"{filename}: {total} lines "
                    f"(max {MAX_FILE_LENGTH})"),
                'severity': 'warning',
            })

        return violations
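
    # Worked example (illustrative; result reflowed for readability): a
    # hand-built metrics dict with one over-threshold function yields a
    # single error-severity violation. Missing metric keys default to empty.
    #
    #   >>> m = {'cyclomatic_complexity': [
    #   ...     {'name': 'f', 'line': 1, 'complexity': 20}]}
    #   >>> PRGuardian.check_thresholds(m, 'demo.py')
    #   [{'rule': 'cyclomatic_complexity',
    #     'message': "demo.py:1 'f' has CC=20 (max 15)",
    #     'severity': 'error'}]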

    @staticmethod
    def generate_review_comment(analysis: Dict[str, Any]) -> str:
        """Generate a human/agent-readable review comment.

        Tries ModelBusService for AI-enhanced summary, falls back to template.
        """
        passed = analysis.get('passed', False)
        violations = analysis.get('all_violations', [])
        files = analysis.get('files_analyzed', 0)

        # Try AI-enhanced summary
        ai_summary = ''
        if violations:
            try:
                from integrations.agent_engine.model_bus_service import (
                    get_model_bus_service,
                )
                bus = get_model_bus_service()
                if bus:
                    prompt = (
                        f"Summarize these code review violations in 2-3 "
                        f"sentences for a developer:\n"
                        f"{violations[:10]}")
                    result = bus.infer(prompt)
                    if result and 'response' in result:
                        ai_summary = result['response']
            except Exception:
                pass

        # Template
        lines = []
        lines.append('## HART PR Guardian Review\n')

        if passed:
            lines.append(f'All {files} files pass quality checks.\n')
        else:
            lines.append(
                f'Found **{len(violations)} violation(s)** '
                f'across {files} file(s).\n')

        if ai_summary:
            lines.append(f'### Summary\n{ai_summary}\n')

        # Group by severity
        errors = [v for v in violations if v.get('severity') == 'error']
        warnings = [v for v in violations if v.get('severity') == 'warning']

        if errors:
            lines.append('### Errors (must fix)')
            for v in errors:
                lines.append(f"- **{v['rule']}**: {v['message']}")

        if warnings:
            lines.append('\n### Warnings')
            for v in warnings:
                lines.append(f"- **{v['rule']}**: {v['message']}")

        lines.append('\n### Thresholds')
        lines.append(f'- Cyclomatic Complexity: <= {MAX_CYCLOMATIC_COMPLEXITY}')
        lines.append(f'- Function Length: <= {MAX_FUNCTION_LENGTH} lines')
        lines.append(f'- Nesting Depth: <= {MAX_NESTING_DEPTH}')
        lines.append(f'- Blocked Imports: {sorted(BLOCKED_IMPORTS)}')

        lines.append('\n*Automated by HART PR Guardian*')

        return '\n'.join(lines)
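
    # Output sketch (illustrative, abridged) for a failing analysis when
    # the model-bus import is unavailable and the template path is taken:
    #
    #   ## HART PR Guardian Review
    #
    #   Found **1 violation(s)** across 1 file(s).
    #
    #   ### Errors (must fix)
    #   - **blocked_import**: mod.py: blocked import 'pickle' (security risk)
    #   ...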

    @staticmethod
    def check_pr_checklist(pr_body: str) -> Dict[str, bool]:
        """Parse a PR body for checklist items.

        Looks for markdown checkboxes like:
            - [x] Tests added
            - [ ] Docs updated

        Returns dict of checklist key -> checked status.
        """
        result = {k: False for k in _CHECKLIST_KEYS}

        if not pr_body:
            return result

        body_lower = pr_body.lower()

        # Match checked checkboxes; the body is lowercased, so the
        # lowercase pattern covers both [x] and [X]
        checked = set()
        for match in re.finditer(r'\[x\]\s*(.+)', body_lower):
            checked.add(match.group(1).strip())

        # Map known phrases to checklist keys
        phrase_map = {
            'tests_added': ['tests added', 'test added', 'tests included'],
            'docs_updated': ['docs updated', 'documentation updated',
                             'docs included'],
            'no_protected_files': ['no protected file',
                                   'protected files unchanged'],
            'manifest_validated': ['manifest valid', 'manifest validated'],
            'sandbox_passes': ['sandbox pass', 'sandbox check'],
        }

        for key, phrases in phrase_map.items():
            for phrase in phrases:
                if any(phrase in item for item in checked):
                    result[key] = True
                    break

        return result
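
    # Worked example (illustrative): only the checked boxes flip to True.
    #
    #   >>> body = "- [x] Tests added\n- [ ] Docs updated\n- [X] Manifest validated\n"
    #   >>> checks = PRGuardian.check_pr_checklist(body)
    #   >>> checks['tests_added'], checks['manifest_validated'], checks['docs_updated']
    #   (True, True, False)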


# ─── AST Helpers (module-private) ────────────────────────────────

_DECISION_NODES = (
    ast.If, ast.IfExp,
    ast.For, ast.AsyncFor,
    ast.While,
    ast.ExceptHandler,
    ast.With, ast.AsyncWith,
    ast.Assert,
)


def _count_decisions(node: ast.AST) -> int:
    """Count decision points in an AST subtree.

    Recurses manually (rather than via ast.walk) so that nested functions,
    which get their own CC, are excluded along with their entire bodies.
    """
    count = 0
    for child in ast.iter_child_nodes(node):
        # Skip nested functions entirely; ast.walk would still descend
        # into their bodies and inflate the enclosing function's CC
        if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if isinstance(child, _DECISION_NODES):
            count += 1
        elif isinstance(child, ast.BoolOp):
            # Each `and`/`or` adds (num_values - 1)
            count += len(child.values) - 1
        count += _count_decisions(child)
    return count
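

# Worked example (illustrative): the outer `if` counts toward f, while the
# nested function's `while` is scored separately.
#
#   >>> fn = ast.parse(
#   ...     "def f():\n"
#   ...     "    if True:\n"
#   ...     "        def g():\n"
#   ...     "            while True:\n"
#   ...     "                pass\n").body[0]
#   >>> _count_decisions(fn)
#   1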


def _function_line_count(node: ast.AST) -> int:
    """Compute approximate line count for a function node."""
    if not hasattr(node, 'end_lineno') or node.end_lineno is None:
        # Fallback for Python < 3.8
        lines = set()
        for child in ast.walk(node):
            if hasattr(child, 'lineno'):
                lines.add(child.lineno)
        return len(lines) if lines else 1
    return node.end_lineno - node.lineno + 1
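

# Worked example (illustrative): a def spanning lines 1-3 counts 3 lines.
#
#   >>> fn = ast.parse("def f():\n    x = 1\n    return x\n").body[0]
#   >>> _function_line_count(fn)
#   3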


_NESTING_NODES = (ast.If, ast.For, ast.AsyncFor, ast.While,
                  ast.With, ast.AsyncWith, ast.Try)


def _max_nesting(func_node: ast.AST) -> int:
    """Compute max nesting depth within a function."""
    return _nesting_depth_recursive(func_node, 0, is_root=True)


def _nesting_depth_recursive(node: ast.AST, current: int,
                             is_root: bool = False) -> int:
    """Recursively compute nesting depth."""
    if isinstance(node, _NESTING_NODES) and not is_root:
        current += 1

    # Skip nested functions (they get their own depth count)
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and not is_root:
        return current

    max_depth = current
    for child in ast.iter_child_nodes(node):
        child_depth = _nesting_depth_recursive(child, current)
        if child_depth > max_depth:
            max_depth = child_depth

    return max_depth
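

# Worked example (illustrative): the try/except counts one level via
# ast.Try, and the `with` inside the handler adds a second.
#
#   >>> fn = ast.parse(
#   ...     "def f():\n"
#   ...     "    try:\n"
#   ...     "        pass\n"
#   ...     "    except ValueError:\n"
#   ...     "        with open('x') as fh:\n"
#   ...     "            fh.read()\n").body[0]
#   >>> _max_nesting(fn)
#   2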


# Common stdlib top-level modules (subset for quick classification)
_STDLIB_MODULES = frozenset({
    'abc', 'ast', 'asyncio', 'collections', 'contextlib', 'copy',
    'csv', 'datetime', 'enum', 'functools', 'hashlib', 'io',
    'itertools', 'json', 'logging', 'math', 'os', 'pathlib',
    'pickle', 're', 'shutil', 'socket', 'sqlite3', 'string',
    'struct', 'subprocess', 'sys', 'tempfile', 'textwrap',
    'threading', 'time', 'traceback', 'typing', 'unittest',
    'urllib', 'uuid', 'warnings', 'xml', 'zipfile',
})


def _is_stdlib(module_name: str) -> bool:
    """Quick check if a module is likely stdlib."""
    return module_name in _STDLIB_MODULES

537 return module_name in _STDLIB_MODULES