Reinforcement Learning
Quick start
We've provided a minimal script, rl_basic.py, that runs RL on the GSM8K dataset:
python -m logits_cookbook.recipes.rl_basic
This script fine-tunes Qwen3-8B on GSM8K math problems with the reward function:
\[
\mathbb{1}[\text{answer is correct}] + 0.1 \times \left(\mathbb{1}[\text{answer is formatted correctly}] - 1\right)
\]
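To make the reward concrete, here is a minimal sketch of the same formula in Python. The function name `gsm8k_reward` and its boolean arguments are hypothetical: the recipe decides correctness and formatting with its own grading code, and this only restates the arithmetic of the formula.

```python
def gsm8k_reward(is_correct: bool, is_formatted: bool) -> float:
    """Hypothetical restatement of the quick-start reward:
    1[correct] + 0.1 * (1[formatted] - 1)."""
    return float(is_correct) + 0.1 * (float(is_formatted) - 1.0)


# Enumerate all four outcomes to see the range of the reward.
for correct in (True, False):
    for formatted in (True, False):
        print(f"correct={correct!s:5} formatted={formatted!s:5} -> "
              f"{gsm8k_reward(correct, formatted):+.1f}")
```

The four outcomes are +1.0 (correct, formatted), +0.9 (correct, unformatted), 0.0 (incorrect, formatted) and -0.1 (incorrect, unformatted). The format term therefore only ever subtracts 0.1, so correctness dominates the signal while formatting acts as a small penalty.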
Training takes about 1 minute per iteration and reaches ~63% accuracy after 15 iterations.
Key metrics to watch
- env/all/correct: fraction of correct answers
- env/all/format: fraction of correctly formatted completions
- env/all/reward/total: mean total reward
- entropy: per-token entropy
- ac_tokens_per_turn: mean tokens generated per turn
- kl_sample_train_{v1,v2}: KL divergence estimators
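The entropy and KL metrics build on standard definitions. The sketch below shows one common way to compute them from per-token quantities. It assumes that entropy is the Shannon entropy (in nats) of each next-token distribution averaged over generated tokens, and that the two KL estimators resemble the common sample-based estimators k1 = log q - log p and k2 = 0.5 * (log q - log p)^2, where q is the sampling policy (which produced the tokens) and p is the training policy. Whether `kl_sample_train_v1`/`v2` map exactly onto these is an assumption, and all function names here are hypothetical.

```python
import math
from typing import Sequence


def token_entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (nats) of a single next-token distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def mean_entropy(dists: Sequence[Sequence[float]]) -> float:
    """Average per-token entropy over all generated positions."""
    return sum(token_entropy(d) for d in dists) / len(dists)


def kl_k1(sample_logprobs: Sequence[float], train_logprobs: Sequence[float]) -> float:
    """k1 estimate of KL(sampler || trainer): mean of log q - log p
    over tokens drawn from the sampler. Unbiased, but can be negative."""
    diffs = [q - p for q, p in zip(sample_logprobs, train_logprobs)]
    return sum(diffs) / len(diffs)


def kl_k2(sample_logprobs: Sequence[float], train_logprobs: Sequence[float]) -> float:
    """k2 estimate: mean of 0.5 * (log q - log p)^2. Always >= 0,
    lower variance, slightly biased."""
    diffs = [q - p for q, p in zip(sample_logprobs, train_logprobs)]
    return sum(0.5 * d * d for d in diffs) / len(diffs)
```

For example, `mean_entropy([[0.5, 0.5], [1.0]])` averages a maximally uncertain binary choice (log 2 nats) with a deterministic one (0 nats), giving log(2)/2, about 0.347 nats. Sampler/trainer KL estimates near zero indicate that the two sides compute nearly identical log-probabilities for the sampled tokens.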
Minimal training loop
For a self-contained example without the environment abstractions, see rl_loop.py:
python -m logits_cookbook.recipes.rl_loop
Results are written to /tmp/logits-examples/rl-loop. Plot the reward curve:
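import matplotlib.pyplot as plt
import pandas as pd
# metrics.jsonl holds one JSON object per iteration, so the index is the iteration
df = pd.read_json("/tmp/logits-examples/rl-loop/metrics.jsonl", lines=True)
plt.plot(df["reward/total"], label="reward/total")
plt.xlabel("iteration")
plt.legend()
plt.show()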
Training outputs¶
Each RL training run writes the following files to log_path:
| File | Format | Contents |
|---|---|---|
| metrics.jsonl | JSONL | One JSON object per iteration with all scalar metrics |
| config.json | JSON | Serialized training config |
| checkpoints.jsonl | JSONL | Checkpoint metadata for resume |
| train_iteration_NNNNNN.html | HTML | Human-readable rollout report |
| train_iteration_NNNNNN_logtree.json | JSON | Machine-readable rollout trace |
| train_iteration_NNNNNN_rollout_summaries.jsonl | JSONL | Per-trajectory rewards and metrics |
| eval_<name>_iteration_NNNNNN.* | HTML/JSON/JSONL | Same as above, for eval rollouts |
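Because the per-iteration files share a zero-padded `train_iteration_NNNNNN` (or `eval_<name>_iteration_NNNNNN`) prefix, the newest rollout files can be located by parsing that number. The following is a minimal standard-library sketch; it assumes exactly six digits, as in the table, and `latest_iteration_files` is a hypothetical helper, not part of the cookbook.

```python
import re
from pathlib import Path

# Matches e.g. "train_iteration_000010.html" or
# "eval_test_iteration_000020_logtree.json"; group 1 is the prefix.
_ITER_RE = re.compile(r"^(train|eval_.+?)_iteration_(\d{6})")


def latest_iteration_files(log_path: str, prefix: str = "train") -> list[Path]:
    """Return every file belonging to the highest iteration for `prefix`
    ("train", or "eval_<name>" for eval rollouts)."""
    by_iter: dict[int, list[Path]] = {}
    for path in Path(log_path).iterdir():
        m = _ITER_RE.match(path.name)
        if m and m.group(1) == prefix:
            by_iter.setdefault(int(m.group(2)), []).append(path)
    if not by_iter:
        return []
    return sorted(by_iter[max(by_iter)])
```

For example, `latest_iteration_files(log_path, "eval_test")` would return the HTML, logtree, and rollout-summary files of the most recent eval named `test`. Parsing the number, rather than sorting filenames directly, keeps the helper correct even if some iterations wrote only a subset of the files.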
Parsing metrics.jsonl
import pandas as pd
# One row per iteration; columns are the scalar metric names
df = pd.read_json("path/to/metrics.jsonl", lines=True)
df.plot(x="progress/batch", y="env/all/reward/total")
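If pandas is not available, or you want a quick numeric summary rather than a plot, metrics.jsonl can also be read with the standard library alone. This sketch assumes each line carries the `progress/batch` and `env/all/reward/total` keys used above; `load_metrics`, `smooth`, and `best_batch` are hypothetical helpers. A trailing moving average makes noisy RL reward curves easier to compare across runs.

```python
import json
from collections import deque


def load_metrics(path: str) -> list[dict]:
    """Read metrics.jsonl: one JSON object per training iteration."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]


def smooth(values: list[float], window: int = 5) -> list[float]:
    """Trailing moving average; the first points average fewer values."""
    buf: deque = deque(maxlen=window)
    out = []
    for v in values:
        buf.append(v)
        out.append(sum(buf) / len(buf))
    return out


def best_batch(rows: list[dict], key: str = "env/all/reward/total") -> int:
    """Return the progress/batch value of the row with the highest `key`."""
    rows_with_key = [r for r in rows if key in r]
    return max(rows_with_key, key=lambda r: r[key])["progress/batch"]
```

Typical use: `rows = load_metrics("path/to/metrics.jsonl")`, then `smooth([r["env/all/reward/total"] for r in rows])` for the curve and `best_batch(rows)` to pick a checkpoint candidate.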
Parsing *_rollout_summaries.jsonl
import json

with open("train_iteration_000010_rollout_summaries.jsonl") as f:
    # One JSON object per line, one line per trajectory
    trajectories = [json.loads(line) for line in f if line.strip()]

# Each trajectory has:
# - metadata: schema_version, split, iteration, group_idx, traj_idx
# - episode totals: total_reward, final_reward, trajectory_metrics
# - steps: list of {step_idx, ob_len, ac_len, reward, episode_done, metrics}
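A natural next step with the `trajectories` list is to aggregate per group, for example to spot prompts the policy always or never solves. Groups in which every trajectory receives the same reward are worth flagging: if rewards are centered within each group (as in GRPO-style training, an assumption about this recipe), such groups contribute zero advantage. The sketch assumes the fields listed above are top-level keys of each JSON object; if your `schema_version` nests them (for example under a `metadata` key), adjust the lookups. `group_stats` is a hypothetical helper.

```python
from collections import defaultdict
from statistics import mean


def group_stats(trajectories: list[dict]) -> dict[int, dict]:
    """Summarize rollouts per group_idx (assumes top-level keys)."""
    groups: dict = defaultdict(list)
    for traj in trajectories:
        groups[traj["group_idx"]].append(traj)

    stats = {}
    for gid, trajs in sorted(groups.items()):
        rewards = [t["total_reward"] for t in trajs]
        stats[gid] = {
            "n": len(trajs),
            "mean_reward": mean(rewards),
            # True when every trajectory in the group got the same reward
            "constant_reward": len(set(rewards)) == 1,
            # Total generated (action) tokens across all steps in the group
            "ac_tokens": sum(s["ac_len"] for t in trajs for s in t["steps"]),
        }
    return stats
```

For example, `[g for g, s in group_stats(trajectories).items() if s["constant_reward"]]` lists the groups that carried no within-group reward contrast at that iteration.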
Parsing *_logtree.json
The logtree contains full rollout transcripts, including prompts, model responses, and reward breakdowns. Conversations can sit at any depth of the tree, so collect them recursively:
import json


def find_conversations(node):
    """Recursively collect every conversation payload under a logtree node."""
    results = []
    if isinstance(node, dict):
        if node.get("data", {}).get("type") == "conversation":
            results.append(node["data"])
        # Non-dict children are ignored by the isinstance check above
        for child in node.get("children", []):
            results.extend(find_conversations(child))
    return results


with open("eval_test_iteration_000020_logtree.json") as f:
    trace = json.load(f)

for conv in find_conversations(trace["root"]):
    for msg in conv["messages"]:
        # Print the role and the first 100 characters of each message
        print(f"{msg['role']}: {msg['content'][:100]}")