openclaw-intelligence-core-.../syncpatch/reward_signals.py

63 lines
2.2 KiB
Python
Raw Normal View History

2026-03-21 07:34:09 +00:00
from __future__ import annotations
from typing import Any
# Starting reward for each run status; compute_reward() looks the status up
# here and falls back to 0.0 for any status that is not listed.
BASE_REWARDS: dict[str, float] = {
    'success': 5.0,
    'needs_clarification': 1.0,
    'tool_output_unverified': -1.5,
    'tool_failed': -3.0,
    'no_result': -2.5,
}
# Reward adjustment (negative = penalty) added once per reported cap breach.
# compute_reward() applies -0.5 for any breach name that is not listed here.
CAP_BREACH_PENALTIES: dict[str, float] = {
    'daily_cap_exceeded': -1.0,
    'path_like_payload': -1.0,
}
def derive_cap_breaches(
    error_text: str,
    analysis: dict[str, Any],
    evidence_items: list[dict[str, Any]],
) -> list[str]:
    """Return the cap-breach markers mentioned by a run's errors and evidence.

    The search text is built from ``error_text``, every evidence item's
    ``'error'`` value and the first 200 characters of every evidence item's
    ``'output'`` value, and is matched case-insensitively.

    Args:
        error_text: Top-level error message; ``None`` or empty is treated as
            no text.
        analysis: Analysis mapping. Its ``'quarantine_reason'`` and
            ``'memory_quarantine_reason'`` values are consulted as a fallback
            source for ``'daily_cap_exceeded'`` only. ``None`` is treated as
            an empty mapping.
        evidence_items: Evidence mappings; ``None`` is treated as an empty
            list.

    Returns:
        A list without duplicates containing ``'daily_cap_exceeded'`` and/or
        ``'path_like_payload'``. Markers found in the search text come first
        (in that order); a ``'daily_cap_exceeded'`` found only in the
        quarantine fields is appended last.
    """
    items = evidence_items or []
    haystack = ' '.join(
        [str(error_text or '')]
        + [str(item.get('error') or '') for item in items]
        # Only the head of each output is scanned, keeping the text bounded.
        + [str(item.get('output') or '')[:200] for item in items]
    ).lower()

    breaches = [
        marker
        for marker in ('daily_cap_exceeded', 'path_like_payload')
        if marker in haystack
    ]

    if 'daily_cap_exceeded' not in breaches:
        meta = analysis or {}
        # Inspect BOTH quarantine fields. Chaining them with `or` would skip
        # the second field whenever the first holds any non-empty value, even
        # one that does not mention the cap.
        reasons = (
            meta.get('quarantine_reason'),
            meta.get('memory_quarantine_reason'),
        )
        if any('daily_cap_exceeded' in str(reason or '').lower() for reason in reasons):
            breaches.append('daily_cap_exceeded')
    return breaches
def compute_reward(
    status: str,
    analysis: dict[str, Any],
    evidence_items: list[dict[str, Any]],
    final_text: str,
    *,
    cap_breaches: list[str] | None = None,
) -> float:
    """Score one run from its status, evidence, final text and cap breaches.

    The score starts at ``BASE_REWARDS[status]`` (0.0 for an unlisted status)
    and is then adjusted:

    * +0.5 per evidence item with a truthy ``'grounded'`` value, counting at
      most three items;
    * +0.5 when ``analysis['force_sequential']`` is truthy and the status is
      ``'success'``;
    * -0.5 when the status is ``'success'``, no evidence item is grounded and
      the non-empty ``final_text`` is shorter than 24 characters once
      stripped;
    * for each entry of ``cap_breaches``, the matching
      ``CAP_BREACH_PENALTIES`` value, or -0.5 for an unlisted breach name.

    Args:
        status: Run status used as the ``BASE_REWARDS`` key.
        analysis: Analysis mapping; only ``'force_sequential'`` is read.
        evidence_items: Evidence mappings; only ``'grounded'`` is read.
        final_text: Final answer text of the run.
        cap_breaches: Breach names to penalise; ``None`` means no breaches.

    Returns:
        The accumulated reward.
    """
    succeeded = status == 'success'
    grounded_hits = len([entry for entry in evidence_items if entry.get('grounded')])

    score = BASE_REWARDS.get(status, 0.0)

    # Grounding bonus is capped at three items.
    if grounded_hits > 0:
        score += 0.5 * min(grounded_hits, 3)

    if analysis.get('force_sequential') and succeeded:
        score += 0.5

    # NOTE(review): an empty final_text skips this penalty while a
    # whitespace-only one receives it -- confirm that this is intended.
    if final_text and len(final_text.strip()) < 24 and succeeded and not grounded_hits:
        score -= 0.5

    penalised = cap_breaches if cap_breaches else []
    for name in penalised:
        score += CAP_BREACH_PENALTIES.get(name, -0.5)

    return score
def reward_row(
    status: str,
    analysis: dict[str, Any],
    evidence_items: list[dict[str, Any]],
    final_text: str,
    *,
    cap_breaches: list[str] | None = None,
) -> dict[str, Any]:
    """Build a record holding a run's reward and the counts behind it.

    Args:
        status: Run status, forwarded to ``compute_reward``.
        analysis: Analysis mapping, forwarded to ``compute_reward``.
        evidence_items: Evidence mappings, forwarded and also counted here.
        final_text: Final answer text, forwarded to ``compute_reward``.
        cap_breaches: Breach names; ``None`` means no breaches.

    Returns:
        A dict with the keys ``'reward'`` (result of ``compute_reward``),
        ``'cap_breaches'`` (a fresh list copy of the given breaches),
        ``'grounded_count'`` (evidence items with a truthy ``'grounded'``
        value) and ``'evidence_count'`` (``len(evidence_items)``).
    """
    # Copy so the returned row never aliases the caller's list.
    breaches = list(cap_breaches) if cap_breaches else []

    score = compute_reward(status, analysis, evidence_items, final_text, cap_breaches=breaches)
    grounded_total = len([entry for entry in evidence_items if entry.get('grounded')])
    evidence_total = len(evidence_items)

    return {
        'reward': score,
        'cap_breaches': breaches,
        'grounded_count': grounded_total,
        'evidence_count': evidence_total,
    }