Neural Garbage Collection: Learning to Forget while Learning to Reason
1️⃣ One-Sentence Summary
This paper proposes a method called Neural Garbage Collection that lets a language model, while learning to reason end-to-end via reinforcement learning, autonomously learn to selectively discard information from its KV cache, compressing peak cache usage by 2-3x while maintaining high accuracy and addressing the memory bottleneck of long chain-of-thought reasoning.
2️⃣ Abstract
Chain-of-thought reasoning has driven striking advances in language model capability, yet every reasoning step grows the KV cache, creating a bottleneck to scaling this paradigm further. Current approaches manage these constraints on the model's behalf using hand-designed criteria. A more scalable approach would let end-to-end learning subsume this design choice entirely, following a broader pattern in deep learning. After all, if a model can learn to reason, why can't it learn to forget? We introduce Neural Garbage Collection (NGC), in which a language model learns to forget while learning to reason, trained end-to-end from outcome-based task reward alone. As the model reasons, it periodically pauses, decides which KV cache entries to evict, and continues to reason conditioned on the remaining cache. By treating tokens in a chain-of-thought and cache-eviction decisions as discrete actions sampled from the language model, we can use reinforcement learning to jointly optimize how the model reasons and how it manages its own memory: what the model evicts shapes what it remembers, what it remembers shapes its reasoning, and the correctness of that reasoning determines its reward. Crucially, the model learns this behavior entirely from a single learning signal, the outcome-based task reward, without supervised fine-tuning or proxy objectives. On Countdown, AMC, and AIME tasks, NGC maintains strong accuracy relative to the full-cache upper bound while compressing peak KV cache size 2-3x, and substantially outperforms eviction baselines. Our results are a first step toward a broader vision in which end-to-end optimization drives both capability and efficiency in language models.
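The abstract leaves implementation details to the paper, but the core loop it describes, generate tokens, periodically pause to sample per-entry eviction decisions, then continue conditioned on the surviving cache, can be sketched compactly. Below is a minimal, illustrative toy in PyTorch: a GRU stands in for the transformer, a Python list of hidden states stands in for the KV cache, the eviction head (`evict_head`), the eviction interval (`EVICT_EVERY`), and the toy reward are all our assumptions rather than the paper's code, and plain REINFORCE stands in for whatever policy-gradient algorithm NGC actually uses.

```python
# Hypothetical sketch of an NGC-style decode/evict/reinforce loop.
# Not the paper's implementation: a GRU replaces the transformer, a list of
# hidden states replaces the per-layer KV cache, and the task/reward are toys.
import torch
import torch.nn as nn

VOCAB, HIDDEN, EVICT_EVERY = 32, 16, 4

class ToyReasoner(nn.Module):
    """GRU stand-in for a decoder with an explicit, evictable cache."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, HIDDEN)
        self.rnn = nn.GRUCell(HIDDEN, HIDDEN)
        self.lm_head = nn.Linear(HIDDEN, VOCAB)  # next-token logits
        self.evict_head = nn.Linear(HIDDEN, 1)   # per-entry keep/evict logit

    def step(self, token, h, cache):
        h = self.rnn(self.embed(token), h)
        if cache:  # condition the step on whatever survives in the cache
            keys = torch.cat(cache)                           # (N, HIDDEN)
            attn = torch.softmax(keys @ h.squeeze(0), dim=0)  # (N,)
            h = h + attn.unsqueeze(0) @ keys                  # attention readout
        return self.lm_head(h), h

def rollout(model, max_steps=16):
    h = torch.zeros(1, HIDDEN)
    token = torch.zeros(1, dtype=torch.long)
    cache, logps = [], []
    for t in range(max_steps):
        logits, h = model.step(token, h, cache)
        dist = torch.distributions.Categorical(logits=logits)
        token = dist.sample()
        logps.append(dist.log_prob(token).sum())  # each token is an action
        cache.append(h)
        if (t + 1) % EVICT_EVERY == 0:  # pause and garbage-collect the cache
            keep_logits = model.evict_head(torch.cat(cache)).squeeze(-1)
            keep = torch.distributions.Bernoulli(logits=keep_logits)
            mask = keep.sample()
            logps.append(keep.log_prob(mask).sum())  # eviction is an action too
            cache = [c for c, m in zip(cache, mask) if m > 0.5]
    return torch.stack(logps).sum(), token

model = ToyReasoner()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
TARGET = 7  # toy task: reward only if the final token matches a target
for _ in range(3):
    logp, last_token = rollout(model)
    reward = 1.0 if last_token.item() == TARGET else 0.0  # outcome-based reward
    (-reward * logp).backward()  # REINFORCE on the single task-reward signal
    opt.step()
    opt.zero_grad()
```

The point of the sketch is the plumbing: because token emissions and eviction masks contribute to one joint log-probability, a single outcome reward at the end of the rollout optimizes reasoning and memory management together, exactly the coupling the abstract describes ("what the model evicts shapes what it remembers..."). A real system would evict per-layer key/value tensors and would likely use a variance-reduced policy-gradient method rather than vanilla REINFORCE.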
Source: arXiv:2604.18002