Benchmarks & Competition
SPS Benchmarks
Measure environment throughput in steps per second (SPS) before and after algorithmic changes.
# Legacy Python env
python benchmarks/sps_baseline.py
# JAX vectorized (batches: 1, 16, 128, 1024)
python benchmarks/sps_jax_vectorized.py
Results are written to benchmarks/results/sps_baseline.json and benchmarks/results/sps_jax_vectorized.json, respectively.
Expected throughput on CPU:
| Backend | Batch size | Steps/sec |
|---|---|---|
| Legacy | 1 | ~100–400 |
| JAX | 1 | ~500–2,000 |
| JAX | 1,024 | ~500,000–1,500,000 |
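To make the steps/sec figure concrete, here is a minimal timing sketch. It is not the code in the benchmark scripts: `measure_sps` and `step_fn` are hypothetical names, and `step_fn` stands in for whatever call advances the environment (or a batch of environments) by one step.

```python
import time
from typing import Callable


def measure_sps(step_fn: Callable[[], None], n_steps: int = 10_000, batch_size: int = 1) -> float:
    """Return environment steps per second for `step_fn`.

    `step_fn` is a hypothetical zero-argument callable that advances
    `batch_size` environments by one step each.
    """
    # Warm up so one-off costs (e.g. JIT compilation) are not timed.
    for _ in range(10):
        step_fn()

    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    elapsed = time.perf_counter() - start

    # Each call counts as `batch_size` environment steps.
    return n_steps * batch_size / max(elapsed, 1e-12)


if __name__ == "__main__":
    counter = {"steps": 0}

    def dummy_step() -> None:
        counter["steps"] += 1  # stand-in for a real env.step(...)

    print(f"{measure_sps(dummy_step, n_steps=1_000):.0f} steps/sec")
```

For a JAX-backed step, the callable would likely need to wait on the result (for example by calling `.block_until_ready()` on a returned array), since JAX dispatches work asynchronously and the loop would otherwise time only the dispatch.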
PPO Rollout Benchmark
Measures wall-clock time for a full PPO rollout with Stable-Baselines3 (SB3).
Competition Eval Harness
benchmarks/competition_eval.py runs submitted agents across scenarios and seeds, computes a composite score, and writes a ranked leaderboard to benchmarks/results/leaderboard.json.
Implementing an agent
import numpy as np
from benchmarks.competition_eval import Agent, SubmissionResult, evaluate, submit_to_leaderboard


class MyBlueAgent:
    def reset(self) -> None:
        pass

    def act(self, obs: dict, agent_id: str) -> np.ndarray:
        # The first 32 mask entries cover action types; the rest cover targets.
        mask = obs["action_mask"]
        valid_types = np.where(mask[:32])[0]
        valid_targets = np.where(mask[32:])[0]
        # Pick the first valid action type and the first valid target.
        return np.array([valid_types[0], valid_targets[0]], dtype=np.int64)


# RandomAgent must also be in scope; its import path is not shown here.
sub = SubmissionResult(name="my_blue", team="blue")
evaluate(sub, red_agent=RandomAgent(), blue_agent=MyBlueAgent(), scenarios=["ransomware"], seeds=list(range(10)))
submit_to_leaderboard(sub)
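The agent above always picks the first valid entry and would raise an `IndexError` if either half of the mask were empty. As a hedged variant, the sketch below samples uniformly among valid entries and falls back to index 0 when nothing is valid. The class name `RandomValidBlueAgent`, the `seed` parameter, and the fallback behaviour are assumptions for illustration; only the 32-entry split comes from the example above.

```python
import numpy as np

N_ACTION_TYPES = 32  # split point taken from the example above


class RandomValidBlueAgent:
    """Hypothetical agent that samples uniformly among unmasked actions."""

    def __init__(self, seed: int = 0) -> None:
        self._rng = np.random.default_rng(seed)

    def reset(self) -> None:
        pass  # no per-episode state to clear

    def act(self, obs: dict, agent_id: str) -> np.ndarray:
        mask = np.asarray(obs["action_mask"]).astype(bool)
        valid_types = np.flatnonzero(mask[:N_ACTION_TYPES])
        valid_targets = np.flatnonzero(mask[N_ACTION_TYPES:])
        # Assumption: index 0 is an acceptable fallback when nothing is valid.
        action_type = self._rng.choice(valid_types) if valid_types.size else 0
        target = self._rng.choice(valid_targets) if valid_targets.size else 0
        return np.array([action_type, target], dtype=np.int64)
```

Seeding the generator keeps evaluation runs reproducible across the harness seeds, which matters when comparing leaderboard entries.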
Scoring
Blue score = SLA_uptime × 50 − compromised_hosts × 2 − MTTC × 0.1 + blue_reward × 0.1
Red score = compromised_hosts × 2 + exfiltrated_data × 0.01 − SLA_uptime × 10 + red_reward × 0.1
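The two formulas translate directly into code. The sketch below is a transcription for illustration, not the harness's own implementation; the function and parameter names are hypothetical, and the input values are made up. Whether `sla_uptime` is a fraction or a percentage is not stated here, so the example assumes a fraction.

```python
def blue_score(sla_uptime: float, compromised_hosts: float, mttc: float, blue_reward: float) -> float:
    """Blue composite score, transcribed from the formula above."""
    return sla_uptime * 50 - compromised_hosts * 2 - mttc * 0.1 + blue_reward * 0.1


def red_score(sla_uptime: float, compromised_hosts: float, exfiltrated_data: float, red_reward: float) -> float:
    """Red composite score, transcribed from the formula above."""
    return compromised_hosts * 2 + exfiltrated_data * 0.01 - sla_uptime * 10 + red_reward * 0.1


if __name__ == "__main__":
    # Illustrative inputs only; sla_uptime is assumed to be a fraction in [0, 1].
    print(blue_score(sla_uptime=0.9, compromised_hosts=3, mttc=20, blue_reward=15))
    print(red_score(sla_uptime=0.9, compromised_hosts=3, exfiltrated_data=500, red_reward=12))
```

With these inputs the blue score works out to 45 − 6 − 2 + 1.5 = 38.5 and the red score to 6 + 5 − 9 + 1.2 = 3.2, up to floating-point rounding. Note that SLA uptime carries weight 50 for blue but only 10 for red, so uptime dominates the blue score far more than it penalizes red.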
CLI
# Evaluate random baseline and update leaderboard
python benchmarks/competition_eval.py --name my_agent --team blue --episodes 10
# Print leaderboard
python benchmarks/competition_eval.py --leaderboard
Metrics in episode info
Each step returns per-agent info keys: compromised_hosts, isolated_hosts, SLA_Uptime_Percentage, MTTC, Total_Exfiltrated_Data, false_positives, agent_energy.
The __curriculum__ key is present when using CurriculumWrapper and contains phase, phase_index, mean_reward, window_fill, phase_advanced.
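A small helper can pull these keys out of a step's info for logging. The sketch below assumes `info` is a dict mapping each agent ID to its own metrics dict, with `__curriculum__` as a top-level key alongside the agents; that layout is an assumption, as are the helper names `extract_metrics` and `curriculum_status`.

```python
from typing import Any, Mapping, Optional

# Key names as listed above.
METRIC_KEYS = (
    "compromised_hosts",
    "isolated_hosts",
    "SLA_Uptime_Percentage",
    "MTTC",
    "Total_Exfiltrated_Data",
    "false_positives",
    "agent_energy",
)


def extract_metrics(agent_info: Mapping[str, Any]) -> dict:
    """Return only the documented metric keys that are present."""
    return {key: agent_info[key] for key in METRIC_KEYS if key in agent_info}


def curriculum_status(info: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
    """Return the curriculum dict, or None when CurriculumWrapper is not in use."""
    # Assumption: "__curriculum__" sits at the top level of `info`.
    return info.get("__curriculum__")


if __name__ == "__main__":
    # Hand-built example; real values come from the environment.
    info = {
        "blue_0": {"compromised_hosts": 2, "MTTC": 14.0, "extra": "ignored"},
        "__curriculum__": {"phase": "easy", "phase_index": 0, "phase_advanced": False},
    }
    print(extract_metrics(info["blue_0"]))
    print(curriculum_status(info))
```

Filtering on a fixed key list keeps logs stable if the environment later adds extra info fields.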