Reinforcement Learning

Quick start

The minimal script rl_basic.py runs RL on the GSM8K dataset:

python -m logits_cookbook.recipes.rl_basic

This script fine-tunes Qwen3-8B on GSM8K math problems with the reward function:

\[ 1[\text{answer is correct}] + 0.1 \times (1[\text{answer is formatted correctly}] - 1) \]
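The formula above can be written as a small function. This is a minimal sketch for illustration only: `gsm8k_reward` and its two boolean arguments are hypothetical names, not the recipe's actual API, and the correctness and format checks themselves are assumed to happen elsewhere.

```python
def gsm8k_reward(is_correct: bool, is_formatted: bool) -> float:
    """Reward = 1[correct] + 0.1 * (1[formatted] - 1)."""
    correct_term = 1.0 if is_correct else 0.0
    # The format term is 0 for a well-formatted answer and -0.1 otherwise,
    # so formatting acts as a penalty and never as a bonus.
    format_term = 0.1 * ((1.0 if is_formatted else 0.0) - 1.0)
    return correct_term + format_term


# Enumerate the four possible outcomes.
for correct in (True, False):
    for formatted in (True, False):
        print(f"correct={correct} formatted={formatted} "
              f"reward={gsm8k_reward(correct, formatted):.1f}")
```

The reward therefore takes one of four values: 1.0 (correct, well formatted), 0.9 (correct, badly formatted), 0.0 (wrong, well formatted), and -0.1 (wrong, badly formatted).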

Training takes about 1 minute per iteration and reaches ~63% accuracy after 15 iterations.

Key metrics to watch

  • env/all/correct — fraction of correct answers
  • env/all/format — fraction of correctly formatted completions
  • env/all/reward/total — mean total reward
  • entropy — per-token entropy
  • ac_tokens_per_turn — mean tokens generated per turn
  • kl_sample_train_{v1,v2} — KL divergence estimators
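The name kl_sample_train suggests a comparison between the log-probabilities the sampler assigned to the generated tokens and those the training forward pass assigns to the same tokens. The sketch below shows two standard sample-based KL estimators under that reading. Both the mapping to `v1` and `v2` and the helper `kl_estimates` are assumptions made for illustration, not the cookbook's implementation.

```python
def kl_estimates(sampler_logprobs, trainer_logprobs):
    """Return two Monte Carlo estimates of KL(sampler || trainer).

    Both lists hold per-token log-probabilities of the *same* sampled tokens.
    """
    if not sampler_logprobs or len(sampler_logprobs) != len(trainer_logprobs):
        raise ValueError("need two non-empty lists of equal length")
    diffs = [s - t for s, t in zip(sampler_logprobs, trainer_logprobs)]
    # Estimator 1: mean log-ratio. Unbiased, but individual terms can be negative.
    k1 = sum(diffs) / len(diffs)
    # Estimator 2: half the mean squared log-ratio. Always non-negative, lower variance.
    k2 = sum(0.5 * d * d for d in diffs) / len(diffs)
    return k1, k2


k1, k2 = kl_estimates([-1.0, -2.0, -0.5], [-1.1, -1.9, -0.5])
print(f"k1={k1:.4f} k2={k2:.4f}")
```

Values near zero indicate that the sampler and the trainer agree on the sampled tokens; a steadily growing value would suggest the two have drifted apart.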

Minimal training loop

For a self-contained example without the environment abstractions, see rl_loop.py:

python -m logits_cookbook.recipes.rl_loop

Results are written to /tmp/logits-examples/rl-loop. Plot the reward curve:

import matplotlib.pyplot as plt
import pandas

df = pandas.read_json("/tmp/logits-examples/rl-loop/metrics.jsonl", lines=True)
plt.plot(df["reward/total"], label="reward/total")
plt.legend()
plt.show()

Training outputs

Each RL training run writes files to log_path:

| File | Format | Contents |
| --- | --- | --- |
| metrics.jsonl | JSONL | One JSON object per iteration with all scalar metrics |
| config.json | JSON | Serialized training config |
| checkpoints.jsonl | JSONL | Checkpoint metadata for resume |
| train_iteration_NNNNNN.html | HTML | Human-readable rollout report |
| train_iteration_NNNNNN_logtree.json | JSON | Machine-readable rollout trace |
| train_iteration_NNNNNN_rollout_summaries.jsonl | JSONL | Per-trajectory rewards and metrics |
| eval_<name>_iteration_NNNNNN.* | HTML/JSON/JSONL | Same as above, for eval rollouts |
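Because the per-iteration files follow a fixed naming pattern, they can be discovered programmatically. The sketch below lists the rollout summary files in a run directory, assuming NNNNNN is a zero-padded six-digit iteration number (as in the example filenames in this section). `list_rollout_summaries` is a hypothetical helper, not part of the cookbook.

```python
import re
from pathlib import Path

# Matches train_iteration_NNNNNN_... and eval_<name>_iteration_NNNNNN_... files.
_SUMMARY_FILE = re.compile(
    r"^(train|eval_.+)_iteration_(\d{6})_rollout_summaries\.jsonl$"
)


def list_rollout_summaries(log_path):
    """Return (iteration, prefix, path) tuples sorted by iteration, then prefix."""
    found = []
    for path in Path(log_path).iterdir():
        match = _SUMMARY_FILE.match(path.name)
        if match:
            found.append((int(match.group(2)), match.group(1), path))
    return sorted(found, key=lambda item: (item[0], item[1]))
```

For example, `list_rollout_summaries("/tmp/logits-examples/rl-loop")` would return one tuple per summary file written so far, with the prefix (`train` or `eval_<name>`) distinguishing training rollouts from eval rollouts.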

Parsing metrics.jsonl

import pandas as pd

df = pd.read_json("path/to/metrics.jsonl", lines=True)
df.plot(x="progress/batch", y="env/all/reward/total")

Parsing *_rollout_summaries.jsonl

import json

# One JSON object per line, one line per trajectory.
with open("train_iteration_000010_rollout_summaries.jsonl") as f:
    trajectories = [json.loads(line) for line in f]

# Each trajectory has:
# - metadata: schema_version, split, iteration, group_idx, traj_idx
# - episode totals: total_reward, final_reward, trajectory_metrics
# - steps: list of {step_idx, ob_len, ac_len, reward, episode_done, metrics}
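As one example of using these fields, the sketch below averages total_reward within each rollout group. It assumes group_idx and total_reward are top-level keys of each trajectory object, which the field list above suggests but does not state outright. `mean_reward_by_group` is a hypothetical helper.

```python
import json
from collections import defaultdict


def mean_reward_by_group(summaries_path):
    """Map each group_idx to the mean total_reward of its trajectories."""
    rewards = defaultdict(list)
    with open(summaries_path) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            traj = json.loads(line)
            # Assumed layout: both keys sit at the top level of the object.
            rewards[traj["group_idx"]].append(traj["total_reward"])
    return {group: sum(vals) / len(vals) for group, vals in sorted(rewards.items())}
```

Comparing group means shows which prompts the policy already solves reliably and which still produce mixed outcomes.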

Parsing *_logtree.json

The logtree contains full rollout transcripts including prompts, model responses, and reward breakdowns.

import json

def find_conversations(node):
    """Recursively collect every "conversation" payload under node."""
    results = []
    if isinstance(node, dict):
        if node.get("data", {}).get("type") == "conversation":
            results.append(node["data"])
        for child in node.get("children", []):
            if isinstance(child, dict):
                results.extend(find_conversations(child))
    return results

with open("eval_test_iteration_000020_logtree.json") as f:
    trace = json.load(f)

for conv in find_conversations(trace["root"]):
    for msg in conv["messages"]:
        print(f"{msg['role']}: {msg['content'][:100]}")
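For very deep traces, or nodes whose data field is missing or not a dict, an iterative traversal can be more forgiving. The sketch below is one possible variant, exercised on a tiny hand-built tree. The tree shape is inferred from the code above, and the "section" node type and the `iter_conversations` name are made up for illustration.

```python
def iter_conversations(root):
    """Yield each conversation payload in depth-first order, without recursion."""
    stack = [root]
    while stack:
        node = stack.pop()
        if not isinstance(node, dict):
            continue  # skip strings or other non-node children
        data = node.get("data")
        if isinstance(data, dict) and data.get("type") == "conversation":
            yield data
        # Push children reversed so they are popped in their original order.
        stack.extend(reversed(node.get("children") or []))


# Hypothetical miniature logtree: two conversations, one nested.
tree = {
    "data": {"type": "section"},
    "children": [
        {"data": {"type": "conversation",
                  "messages": [{"role": "user", "content": "What is 2 + 2?"}]}},
        "a plain string child",
        {"data": {"type": "section"},
         "children": [
             {"data": {"type": "conversation",
                       "messages": [{"role": "assistant", "content": "4"}]}},
         ]},
    ],
}

for conv in iter_conversations(tree):
    for msg in conv["messages"]:
        print(f"{msg['role']}: {msg['content'][:100]}")
```

Being a generator, it also lets a caller stop after the first few conversations instead of building the full list.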