63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
|
||
|
|
# Starting reward for a run, keyed by its final status string.
# compute_reward() looks the status up here and falls back to 0.0 for any
# status that is not listed.
BASE_REWARDS = {
    # Positive outcomes.
    'success': 5.0,
    'needs_clarification': 1.0,
    # Negative outcomes.
    'tool_output_unverified': -1.5,
    'tool_failed': -3.0,
    'no_result': -2.5,
}
|
||
|
|
|
||
|
|
# Reward adjustment added once per reported cap breach, keyed by breach code.
# compute_reward() applies -0.5 for any breach code that is not listed here.
CAP_BREACH_PENALTIES = {
    'daily_cap_exceeded': -1.0,
    'path_like_payload': -1.0,
}
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
def derive_cap_breaches(
    error_text: str | None,
    analysis: dict[str, Any] | None,
    evidence_items: list[dict[str, Any]] | None,
) -> list[str]:
    """Return the cap-breach codes mentioned in a run's error/evidence text.

    The search is a case-insensitive substring match for the literal codes
    ``daily_cap_exceeded`` and ``path_like_payload`` over:

    * ``error_text``,
    * every evidence item's ``'error'`` value, and
    * the first 200 characters of every evidence item's ``'output'`` value.

    If ``daily_cap_exceeded`` was not found there, the ``'quarantine_reason'``
    (or, when that is falsy, ``'memory_quarantine_reason'``) entry of
    ``analysis`` is searched for it as well.

    Args:
        error_text: Top-level error message; ``None``/empty is treated as ''.
        analysis: Analysis mapping; ``None`` is treated as an empty mapping.
        evidence_items: Evidence mappings; ``None`` is treated as no evidence.

    Returns:
        A list with each detected code at most once. Codes found in the text
        come first (``daily_cap_exceeded`` before ``path_like_payload``); a
        ``daily_cap_exceeded`` found only via the quarantine reason is last.
    """
    # error_text was already None-tolerant; give the other two inputs the same
    # treatment so a missing analysis/evidence payload yields "no breaches"
    # instead of an AttributeError/TypeError.
    analysis = analysis or {}
    evidence_items = evidence_items or []

    # str() on error_text mirrors the coercion applied to the item fields, so
    # a non-string error object cannot make ' '.join() raise.
    fragments = [str(error_text or '')]
    fragments.extend(str(item.get('error') or '') for item in evidence_items)
    # Only the head of each output is scanned (first 200 characters).
    fragments.extend(str(item.get('output') or '')[:200] for item in evidence_items)
    text = ' '.join(fragments).lower()

    breaches: list[str] = [
        code for code in ('daily_cap_exceeded', 'path_like_payload') if code in text
    ]

    # Fallback source for the daily-cap breach only: the quarantine reason.
    if 'daily_cap_exceeded' not in breaches:
        quarantine = analysis.get('quarantine_reason') or analysis.get('memory_quarantine_reason') or ''
        if 'daily_cap_exceeded' in str(quarantine).lower():
            breaches.append('daily_cap_exceeded')
    return breaches
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
def compute_reward(
    status: str,
    analysis: dict[str, Any],
    evidence_items: list[dict[str, Any]],
    final_text: str,
    *,
    cap_breaches: list[str] | None = None,
) -> float:
    """Score a run from its status, evidence, final answer and cap breaches.

    The score is built up as follows:

    * start from ``BASE_REWARDS[status]`` (0.0 for an unlisted status);
    * add 0.5 per evidence item with a truthy ``'grounded'`` value, counting
      at most three items;
    * add 0.5 when ``status`` is ``'success'`` and ``analysis`` has a truthy
      ``'force_sequential'`` value;
    * subtract 0.5 when ``status`` is ``'success'``, nothing is grounded, and
      ``final_text`` is non-empty but shorter than 24 characters once
      stripped;
    * add ``CAP_BREACH_PENALTIES[code]`` for every entry of ``cap_breaches``
      (-0.5 for an unlisted code).

    Args:
        status: Final status string of the run.
        analysis: Analysis mapping; only ``'force_sequential'`` is read.
        evidence_items: Evidence mappings; only ``'grounded'`` is read.
        final_text: Final answer text; a falsy value skips the length check.
        cap_breaches: Breach codes to penalise; ``None`` means none.

    Returns:
        The resulting reward.
    """
    succeeded = status == 'success'
    score = BASE_REWARDS.get(status, 0.0)

    grounded_total = len([entry for entry in evidence_items if entry.get('grounded')])
    if grounded_total > 0:
        # The grounding bonus saturates at three grounded items.
        score += min(grounded_total, 3) * 0.5

    if analysis.get('force_sequential') and succeeded:
        score += 0.5

    # An empty/None final_text is not considered "terse"; only a non-empty
    # answer that strips down to fewer than 24 characters is.
    terse_answer = bool(final_text) and len(final_text.strip()) < 24
    if terse_answer and succeeded and grounded_total == 0:
        score -= 0.5

    for code in cap_breaches or []:
        score += CAP_BREACH_PENALTIES.get(code, -0.5)
    return score
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
def reward_row(
    status: str,
    analysis: dict[str, Any],
    evidence_items: list[dict[str, Any]],
    final_text: str,
    *,
    cap_breaches: list[str] | None = None,
) -> dict[str, Any]:
    """Build a summary row holding the reward and the counts behind it.

    Args:
        status: Final status string, forwarded to ``compute_reward``.
        analysis: Analysis mapping, forwarded to ``compute_reward``.
        evidence_items: Evidence mappings, forwarded and also counted here.
        final_text: Final answer text, forwarded to ``compute_reward``.
        cap_breaches: Breach codes; ``None`` is treated as an empty list.

    Returns:
        A dict with the keys ``'reward'`` (result of ``compute_reward``),
        ``'cap_breaches'`` (a fresh list copy of the given codes),
        ``'grounded_count'`` (evidence items with a truthy ``'grounded'``
        value) and ``'evidence_count'`` (number of evidence items).
    """
    # Copy so the row never aliases the caller's list.
    breaches = list(cap_breaches) if cap_breaches else []

    row: dict[str, Any] = {}
    row['reward'] = compute_reward(
        status,
        analysis,
        evidence_items,
        final_text,
        cap_breaches=breaches,
    )
    row['cap_breaches'] = breaches
    row['grounded_count'] = len([entry for entry in evidence_items if entry.get('grounded')])
    row['evidence_count'] = len(evidence_items)
    return row
|