Large reasoning models often suffer from "overthinking" - excessive reasoning that wastes computational resources. ARM introduces adaptive reasoning formats and multiple modes to optimize token usage while maintaining performance.
Supports four reasoning formats: Direct Answer, Short CoT, Code, and Long CoT, enabling optimal format selection for each task.
Three modes: Adaptive (automatic), Instruction-Guided (explicit), and Consensus-Guided (aggregated) for flexible reasoning control.
Optimizes token usage through dynamic format selection, a step toward models that choose their own reasoning strategy without human intervention.
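The three modes can be pictured as a simple dispatcher. The sketch below illustrates the consensus-guided idea: query the three efficient formats and fall back to Long CoT when they disagree. The unanimity rule, the `generate` helper, and its canned answers are illustrative assumptions, not the paper's exact implementation:

```python
from collections import Counter

EFFICIENT_FORMATS = ["direct", "short_cot", "code"]

def generate(question, fmt):
    """Hypothetical stand-in for a model call: returns (answer, token_cost)."""
    canned = {
        "direct": ("42", 10),
        "short_cot": ("42", 40),
        "code": ("41", 150),
        "long_cot": ("42", 400),
    }
    return canned[fmt]

def consensus_guided(question):
    """Consensus-guided mode (sketch): accept the efficient formats' answer
    only when they fully agree; otherwise fall back to Long CoT."""
    answers = [generate(question, fmt)[0] for fmt in EFFICIENT_FORMATS]
    top, votes = Counter(answers).most_common(1)[0]
    if votes == len(EFFICIENT_FORMATS):  # full consensus among efficient formats
        return top
    return generate(question, "long_cot")[0]  # disagreement: use the strongest format
```

Instruction-guided mode would instead call `generate` with a single, explicitly requested format.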
We propose Adaptive Group Relative Policy Optimization (Ada-GRPO), an adaptation of GRPO that addresses the format collapse issue in traditional GRPO.
Traditional GRPO:
1. Definition: For each question \( q \), samples a group of outputs \( O = \{ o_1, o_2, \ldots, o_G \} \), where \( G \) is the group size. Each output \( o_i \) receives a binary reward \( r_i \) based on prediction accuracy.
2. Problem: Solely optimizes for accuracy, leading to overuse of the highest-accuracy format (e.g., Long CoT) and reduced exploration of more efficient alternatives. This phenomenon is called Format Collapse.
Ada-GRPO Solution:
Core Idea: Introduces a format diversity scaling factor \( \alpha_i(t) \) that consists of two components: an amplification term, which boosts rewards for formats sampled rarely within the group, and a decay term, which gradually shifts emphasis back toward accuracy as training progresses. The underlying GRPO quantities are computed as follows:
1. Sampling: for each question \( q \), sample a group of outputs \( O = \{ o_1, o_2, \ldots, o_G \} \), where \( G \) is the group size.
2. Reward: compute a binary reward for each output, \( r_i = \mathbb{1}(\text{pred}_i = \text{gt}) \), where \( \mathbb{1} \) is the indicator function, \( \text{pred}_i \) is the model's prediction in \( o_i \), and \( \text{gt} \) is the ground truth.
3. Advantage: compute the group-relative advantage used for policy optimization, \( \hat{A}_i = \frac{r_i - \text{mean}(r)}{\text{std}(r)} \), where \( r = \{ r_1, \ldots, r_G \} \) is the group of rewards.
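The steps above, combined with a format-diversity rescaling in the spirit of Ada-GRPO, can be sketched as follows. This is an illustrative sketch, not the paper's exact formulation: the scaling rule (amplify a format in proportion to how rarely it appears in the group) and the `alpha_decay` interpolation back toward plain GRPO are assumptions:

```python
from collections import Counter
from statistics import mean, pstdev

def ada_grpo_advantages(outputs, gt, alpha_decay=1.0):
    """Group-relative advantages with an assumed format-diversity scaling.

    outputs: list of (prediction, format_name) pairs for one question.
    gt: ground-truth answer.
    alpha_decay: assumed schedule value decaying from 1 toward 0 over
    training; at 0 the computation reduces to plain GRPO.
    """
    G = len(outputs)
    fmt_counts = Counter(fmt for _, fmt in outputs)
    rewards = []
    for pred, fmt in outputs:
        r = 1.0 if pred == gt else 0.0  # binary accuracy reward r_i
        # Boost rewards of formats that are rare within the group.
        alpha = 1.0 + alpha_decay * (G / fmt_counts[fmt] - 1.0)
        rewards.append(alpha * r)       # reshaped reward alpha_i * r_i
    mu, sigma = mean(rewards), pstdev(rewards)
    if sigma == 0:                      # all rewards equal: no learning signal
        return [0.0] * G
    return [(r - mu) / sigma for r in rewards]
```

Note how a correct answer in a rarely sampled format receives a larger advantage than a correct answer in the dominant format, which is what counteracts format collapse.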
Key results:
- <1% performance loss
- >30% average token savings
- >70% maximum token savings
- 2× training speedup
Performance of various models across evaluation datasets. "Tokens" refers to the token cost for each model on each dataset. For each model, k=1 corresponds to pass@1, and k=8 corresponds to maj@8. When k=8, the token cost is averaged over a single output. † denotes in-domain tasks, while ‡ denotes out-of-domain tasks.
(Datasets grouped by difficulty: Easy = CSQA, OBQA; Medium = GSM8K, MATH, SVAMP; Hard = BBH, AIME'25.)

Models | k | CSQA† Acc. | OBQA‡ Acc. | GSM8K† Acc. | MATH† Acc. | SVAMP‡ Acc. | BBH‡ Acc. | AIME'25‡ Acc. | Avg. Acc. (↑) | CSQA† Tok. | OBQA‡ Tok. | GSM8K† Tok. | MATH† Tok. | SVAMP‡ Tok. | BBH‡ Tok. | AIME'25‡ Tok. | Avg. Tok. (↓)
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
GPT-4o | 1 | 85.9 | 94.2 | 95.9 | 75.9 | 91.3 | 84.7 | 10.0 | 76.8 | 192 | 165 | 287 | 663 | 156 | 278 | 984 | 389 |
o1-preview | 1 | 85.5 | 95.6 | 94.2 | 92.6 | 92.7 | 91.8 | 40.0 | 84.6 | 573 | 492 | 456 | 1863 | 489 | 940 | 7919 | 1819 |
o4-mini-high | 1 | 84.7 | 96.0 | 96.9 | 97.7 | 94.0 | 92.2 | 96.7 | 94.0 | 502 | 289 | 339 | 1332 | 301 | 755 | 9850 | 1910 |
DeepSeek-V3 | 1 | 82.4 | 96.0 | 96.5 | 91.8 | 93.7 | 85.8 | 36.7 | 83.3 | 231 | 213 | 236 | 887 | 160 | 400 | 2992 | 732 |
DeepSeek-R1 | 1 | 83.3 | 94.8 | 96.4 | 97.1 | 96.0 | 85.0 | 70.0 | 88.9 | 918 | 736 | 664 | 2339 | 589 | 1030 | 9609 | 2270 |
DS-R1-Distill-1.5B | 1 | 47.6 | 48.6 | 79.4 | 84.6 | 86.7 | 53.5 | 20.0 | 60.1 | 987 | 1540 | 841 | 3875 | 606 | 3005 | 13118 | 3425 |
DS-R1-Distill-7B | 1 | 64.9 | 77.4 | 90.0 | 93.6 | 90.3 | 72.1 | 40.0 | 75.5 | 792 | 928 | 574 | 3093 | 315 | 1448 | 12427 | 2797 |
DS-R1-Distill-14B | 1 | 80.6 | 93.2 | 94.0 | 95.5 | 92.7 | 80.4 | 50.0 | 83.8 | 816 | 750 | 825 | 2682 | 726 | 1292 | 11004 | 2585 |
DS-R1-Distill-32B | 1 | 83.2 | 94.6 | 93.5 | 93.0 | 92.0 | 86.3 | 56.7 | 85.6 | 674 | 698 | 438 | 2161 | 283 | 999 | 11276 | 2361 |
Qwen2.5-3B | 1 | 66.5 | 65.8 | 66.9 | 37.7 | 71.3 | 38.4 | 0 | 49.5 | 97 | 120 | 150 | 419 | 76 | 232 | 1393 | 355 |
Qwen2.5-3B | 8 | 75.5 | 77.4 | 80.9 | 50.8 | 83.7 | 47.1 | 0 | 59.3 | 96 | 100 | 149 | 424 | 85 | 240 | 1544 | 377
Qwen2.5-3B (SFT) | 1 | 72.8 | 72.4 | 35.7 | 20.9 | 62.3 | 37.4 | 0 | 43.1 | 99 | 108 | 145 | 229 | 126 | 311 | 694 | 245
Qwen2.5-3B (SFT) | 8 | 75.5 | 77.4 | 56.0 | 27.6 | 74.7 | 43.5 | 0 | 50.7 | 97 | 103 | 132 | 231 | 108 | 309 | 537 | 217
Qwen2.5-3B (SFT+GRPO) | 1 | 79.7 | 79.0 | 88.7 | 66.6 | 92.0 | 52.6 | 6.7 | 66.5 | 425 | 501 | 788 | 1586 | 630 | 994 | 3027 | 1136
Qwen2.5-3B (SFT+GRPO) | 8 | 80.3 | 80.0 | 91.4 | 74.0 | 94.7 | 56.2 | 6.7 | 69.0 | 429 | 506 | 802 | 1590 | 638 | 996 | 3247 | 1172
ARM-3B | 1 | 79.8 | 78.0 | 83.8 | 62.9 | 89.7 | 50.0 | 6.7 | 64.4 | 118 | 156 | 346 | 1013 | 264 | 436 | 2958 | 756 |
ARM-3B | 8 | 80.1 | 78.0 | 90.8 | 72.8 | 95.0 | 53.8 | 6.7 | 68.2 | 123 | 169 | 359 | 1036 | 246 | 430 | 3083 | 778
Δ | | -0.2 | -2.0 | -0.6 | -1.2 | +0.3 | -2.4 | 0 | -0.8 | -71.3% | -66.6% | -55.2% | -34.8% | -61.4% | -56.8% | -5.1% | -33.6%
Qwen2.5-7B | 1 | 76.7 | 78.6 | 81.6 | 50.1 | 81.0 | 51.7 | 3.3 | 60.4 | 64 | 83 | 156 | 376 | 99 | 182 | 767 | 247 |
Qwen2.5-7B | 8 | 82.0 | 86.4 | 89.9 | 64.7 | 89.7 | 62.0 | 3.3 | 68.3 | 66 | 74 | 156 | 370 | 92 | 183 | 881 | 260
Qwen2.5-7B (SFT) | 1 | 80.8 | 81.2 | 54.4 | 30.4 | 76.0 | 48.2 | 0 | 53.0 | 136 | 150 | 184 | 348 | 126 | 245 | 1239 | 347
Qwen2.5-7B (SFT) | 8 | 83.9 | 84.6 | 79.4 | 42.4 | 88.0 | 56.0 | 0 | 62.0 | 141 | 137 | 185 | 361 | 141 | 274 | 1023 | 323
Qwen2.5-7B (SFT+GRPO) | 1 | 83.1 | 82.2 | 92.8 | 79.4 | 93.7 | 64.3 | 16.7 | 73.2 | 491 | 651 | 739 | 1410 | 587 | 1133 | 3196 | 1173
Qwen2.5-7B (SFT+GRPO) | 8 | 83.7 | 84.6 | 94.8 | 84.9 | 95.3 | 69.3 | 20.0 | 76.1 | 496 | 625 | 745 | 1415 | 586 | 1135 | 3145 | 1164
ARM-7B | 1 | 86.1 | 84.4 | 89.2 | 73.9 | 92.0 | 61.4 | 16.7 | 72.0 | 136 | 159 | 305 | 889 | 218 | 401 | 3253 | 766 |
ARM-7B | 8 | 85.7 | 85.8 | 93.7 | 82.6 | 95.3 | 67.9 | 20.0 | 75.9 | 134 | 154 | 297 | 893 | 218 | 413 | 3392 | 786
Δ | | +2.0 | +1.2 | -1.1 | -2.3 | 0 | -1.4 | 0 | -0.2 | -73.0% | -75.4% | -60.1% | -36.9% | -62.8% | -63.6% | +7.9% | -32.5%
Qwen2.5-14B | 1 | 79.9 | 83.8 | 84.9 | 52.7 | 84.7 | 56.8 | 3.3 | 63.7 | 56 | 60 | 132 | 335 | 77 | 139 | 611 | 201 |
Qwen2.5-14B | 8 | 83.8 | 90.2 | 92.3 | 68.4 | 91.7 | 67.4 | 3.3 | 71.0 | 55 | 60 | 131 | 325 | 81 | 131 | 735 | 217
Qwen2.5-14B (SFT) | 1 | 81.8 | 88.0 | 62.6 | 37.4 | 84.0 | 53.5 | 0 | 58.2 | 155 | 140 | 161 | 276 | 152 | 254 | 527 | 238
Qwen2.5-14B (SFT) | 8 | 85.0 | 91.4 | 86.4 | 48.8 | 91.7 | 64.4 | 3.3 | 67.3 | 149 | 141 | 165 | 288 | 140 | 247 | 493 | 232
Qwen2.5-14B (SFT+GRPO) | 1 | 85.4 | 93.0 | 94.8 | 81.7 | 93.7 | 70.5 | 20.0 | 77.0 | 558 | 531 | 693 | 1805 | 565 | 945 | 4031 | 1304
Qwen2.5-14B (SFT+GRPO) | 8 | 85.8 | 94.2 | 96.1 | 87.1 | 95.3 | 77.0 | 20.0 | 79.4 | 552 | 537 | 696 | 1810 | 565 | 943 | 3723 | 1261
ARM-14B | 1 | 85.3 | 91.8 | 92.5 | 79.1 | 93.3 | 66.6 | 20.0 | 75.5 | 146 | 128 | 294 | 903 | 212 | 420 | 3871 | 853 |
ARM-14B | 8 | 85.6 | 91.8 | 96.3 | 86.4 | 95.7 | 72.1 | 23.3 | 78.7 | 145 | 134 | 293 | 910 | 189 | 415 | 3996 | 869
Δ | | -0.2 | -2.4 | +0.2 | -0.7 | +0.4 | -4.9 | +3.3 | -0.7 | -73.7% | -75.0% | -57.9% | -49.7% | -66.5% | -56.0% | +7.3% | -31.1%
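
The maj@8 numbers above aggregate eight samples per question by majority vote over final answers, with token cost averaged over a single output (per the table caption). A minimal sketch of that aggregation; the tie-breaking choice (first-seen answer wins) is an assumption:

```python
from collections import Counter

def maj_at_k(samples):
    """samples: list of (answer, token_count) pairs from k independent generations.

    Returns (majority-vote answer, average tokens per single output).
    Counter.most_common breaks ties by first-insertion order.
    """
    answers = [a for a, _ in samples]
    majority = Counter(answers).most_common(1)[0][0]
    avg_tokens = sum(t for _, t in samples) / len(samples)
    return majority, avg_tokens
```

For example, `maj_at_k([("42", 100), ("41", 90), ("42", 110), ("42", 105)])` returns `("42", 101.25)`.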
(Figure) The hatched areas indicate the percentage of correct answers that were generated using the selected reasoning format.
Key Takeaways:
ARM-7B | CSQA† Acc. | CSQA† Tok. | OBQA‡ Acc. | OBQA‡ Tok. | GSM8K† Acc. | GSM8K† Tok. | MATH† Acc. | MATH† Tok. | SVAMP‡ Acc. | SVAMP‡ Tok. | BBH‡ Acc. | BBH‡ Tok. | AIME'25‡ Acc. | AIME'25‡ Tok. | Avg. Acc. | Avg. Tok.
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
Adaptive | 86.1 | 136 | 84.4 | 159 | 89.2 | 305 | 73.9 | 889 | 92.0 | 218 | 61.4 | 401 | 16.7 | 3253 | 72.0 | 766 |
Inst (Direct) | 84.1 | 10 | 81.8 | 10 | 22.9 | 11 | 23.1 | 13 | 67.0 | 11 | 44.7 | 21 | 0 | 12 | 46.2 | 13
Inst (Short CoT) | 81.3 | 33 | 77.4 | 35 | 85.0 | 124 | 70.9 | 633 | 86.7 | 66 | 49.7 | 101 | 10.0 | 2010 | 65.9 | 428
Inst (Code) | 84.4 | 140 | 81.6 | 147 | 84.2 | 285 | 65.9 | 559 | 88.3 | 182 | 57.9 | 344 | 10.0 | 1821 | 67.5 | 497
Inst (Long CoT) | 84.0 | 259 | 87.4 | 294 | 91.8 | 426 | 77.2 | 1220 | 94.3 | 340 | 66.9 | 660 | 20.0 | 4130 | 74.5 | 1047
Consensus | 85.8 | 228 | 87.0 | 260 | 92.9 | 777 | 78.4 | 2281 | 95.7 | 433 | 66.4 | 1039 | 20.0 | 7973 | 75.2 | 1856 |
Long CoT usage: CSQA† 12.9%, OBQA‡ 21.4%, GSM8K† 79.8%, MATH† 79.2%, SVAMP‡ 36.3%, BBH‡ 56.3%, AIME'25‡ 100%, Avg. 55.1%.
The table above compares ARM's adaptive mode against each instruction-guided format and the consensus-guided mode.
@article{wu2025arm,
  title={ARM: Adaptive Reasoning Model},
  author={Wu, Siye and Xie, Jian and Zhang, Yikai and Chen, Aili and Zhang, Kai and Su, Yu and Xiao, Yanghua},
  journal={arXiv preprint arXiv:2505.20258},
  year={2025}
}