from __future__ import annotations

from typing import Any

# Base reward keyed by run status. compute_reward() falls back to 0.0 for any
# status that is not listed here.
BASE_REWARDS = {
    'success': 5.0,
    'needs_clarification': 1.0,
    'tool_output_unverified': -1.5,
    'tool_failed': -3.0,
    'no_result': -2.5,
}

# Reward adjustment added per reported cap breach. compute_reward() applies
# -0.5 for any breach name that is not listed here.
CAP_BREACH_PENALTIES = {
    'daily_cap_exceeded': -1.0,
    'path_like_payload': -1.0,
}

# Only this many leading characters of each evidence item's 'output' are
# scanned for breach markers.
_OUTPUT_SCAN_CHARS = 200

# Markers searched for in the combined error/output text, in report order.
_TEXT_BREACH_MARKERS = ('daily_cap_exceeded', 'path_like_payload')


def _grounded_count(evidence_items: list[dict[str, Any]] | None) -> int:
    """Return how many evidence items carry a truthy 'grounded' value."""
    return sum(1 for item in evidence_items or [] if item.get('grounded'))


def derive_cap_breaches(
    error_text: str,
    analysis: dict[str, Any],
    evidence_items: list[dict[str, Any]],
) -> list[str]:
    """Return the cap-breach markers found in the error text, evidence and analysis.

    The search text is ``error_text`` plus each evidence item's 'error' and the
    first 200 characters of its 'output', joined with spaces and lower-cased.

    Args:
        error_text: Top-level error message; ``None`` or empty is treated as ''.
        analysis: Analysis mapping; its 'quarantine_reason' and
            'memory_quarantine_reason' values are consulted as a fallback
            source for 'daily_cap_exceeded'. ``None`` is treated as empty.
        evidence_items: Evidence dicts whose 'error' and 'output' values are
            scanned. ``None`` is treated as empty.

    Returns:
        Breach names without duplicates. Markers found in the text come first
        ('daily_cap_exceeded', then 'path_like_payload'); a
        'daily_cap_exceeded' found only via the quarantine reasons is appended
        last.
    """
    items = evidence_items or []
    fragments = [str(error_text or '')]
    fragments.extend(str(item.get('error') or '') for item in items)
    fragments.extend(
        str(item.get('output') or '')[:_OUTPUT_SCAN_CHARS] for item in items
    )
    text = ' '.join(fragments).lower()

    breaches = [marker for marker in _TEXT_BREACH_MARKERS if marker in text]

    if 'daily_cap_exceeded' not in breaches:
        analysis = analysis or {}
        # Inspect both fields independently: an unrelated, non-empty
        # 'quarantine_reason' must not hide a cap breach recorded under
        # 'memory_quarantine_reason'.
        reasons = (
            analysis.get('quarantine_reason'),
            analysis.get('memory_quarantine_reason'),
        )
        if any(
            'daily_cap_exceeded' in str(reason or '').lower()
            for reason in reasons
        ):
            breaches.append('daily_cap_exceeded')
    return breaches


def compute_reward(
    status: str,
    analysis: dict[str, Any],
    evidence_items: list[dict[str, Any]],
    final_text: str,
    *,
    cap_breaches: list[str] | None = None,
) -> float:
    """Compute the scalar reward for one run.

    The reward starts at ``BASE_REWARDS[status]`` (0.0 for an unknown status)
    and is then adjusted:

    * +0.5 per grounded evidence item, counting at most three items;
    * +0.5 when ``analysis['force_sequential']`` is truthy and the status is
      'success';
    * -0.5 when the status is 'success', no evidence is grounded and
      ``final_text`` is non-empty but shorter than 24 characters after
      stripping;
    * the ``CAP_BREACH_PENALTIES`` value for each entry of ``cap_breaches``
      (-0.5 for a name not in that table). Repeated entries are penalised
      once per occurrence.

    Args:
        status: Run status used to look up the base reward.
        analysis: Analysis mapping; only 'force_sequential' is read. ``None``
            is treated as empty.
        evidence_items: Evidence dicts; only 'grounded' is read. ``None`` is
            treated as empty.
        final_text: Final answer text used for the short-answer penalty.
        cap_breaches: Breach names to penalise; ``None`` means no breaches.

    Returns:
        The adjusted reward.
    """
    reward = BASE_REWARDS.get(status, 0.0)
    succeeded = status == 'success'

    grounded = _grounded_count(evidence_items)
    if grounded:
        reward += min(grounded, 3) * 0.5

    if (analysis or {}).get('force_sequential') and succeeded:
        reward += 0.5

    # An empty final_text is deliberately not penalised here; only a present
    # but very short answer with no grounded evidence is.
    if final_text and len(final_text.strip()) < 24 and succeeded and grounded == 0:
        reward -= 0.5

    for breach in cap_breaches or []:
        reward += CAP_BREACH_PENALTIES.get(breach, -0.5)
    return reward


def reward_row(
    status: str,
    analysis: dict[str, Any],
    evidence_items: list[dict[str, Any]],
    final_text: str,
    *,
    cap_breaches: list[str] | None = None,
) -> dict[str, Any]:
    """Build a summary row holding the reward and the counts it was based on.

    Args:
        status: Run status, forwarded to ``compute_reward``.
        analysis: Analysis mapping, forwarded to ``compute_reward``.
        evidence_items: Evidence dicts; ``None`` is treated as empty.
        final_text: Final answer text, forwarded to ``compute_reward``.
        cap_breaches: Breach names to penalise; ``None`` means no breaches.

    Returns:
        A dict with the keys 'reward', 'cap_breaches' (a fresh list, so the
        caller's list is never aliased), 'grounded_count' and
        'evidence_count'.
    """
    items = evidence_items or []
    breaches = list(cap_breaches or [])
    return {
        'reward': compute_reward(
            status, analysis, items, final_text, cap_breaches=breaches
        ),
        'cap_breaches': breaches,
        'grounded_count': _grounded_count(items),
        'evidence_count': len(items),
    }