This paper introduces a reinforcement learning approach to train lightweight decoding adapters that dynamically select sampling strategies for LLMs at inference time, without fine-tuning the LLM itself. The approach frames decoding as a contextual bandit problem at the sequence level and a POMDP at the token level, allowing for adaptive selection of decoding strategies based on prompt embeddings, model features, and available compute budgets. Experiments on MATH and CodeContests demonstrate that the learned adapters improve accuracy-budget tradeoffs, achieving up to 10.2% improvement in Pass@1 accuracy on MATH compared to static baselines.
Forget fixed decoding parameters: this RL-trained adapter dynamically adjusts LLM sampling strategies at inference, boosting accuracy by up to 10% under tight compute budgets.
Decoding from large language models (LLMs) typically relies on fixed sampling hyperparameters (e.g., temperature, top-p), despite substantial variation in task difficulty and uncertainty across prompts and individual decoding steps. We propose to learn adaptive decoding policies that dynamically select sampling strategies at inference time, conditioned on available compute resources. Rather than fine-tuning the language model itself, we introduce lightweight decoding adapters trained with reinforcement learning and verifiable terminal rewards (e.g., correctness on math and coding tasks). At the sequence level, we frame decoding as a contextual bandit problem: a policy selects a decoding strategy (e.g., greedy, top-k, min-p) for each prompt, conditioned on the prompt embedding and a parallel sampling budget. At the token level, we model decoding as a partially observable Markov decision process (POMDP), where a policy selects sampling actions at each token step based on internal model features and the remaining token budget. Experiments on the MATH and CodeContests benchmarks show that the learned adapters improve the accuracy-budget tradeoff: on MATH, the token-level adapter improves Pass@1 accuracy by up to 10.2% over the best static baseline under a fixed token budget, while the sequence-level adapter yields 2-3% gains under fixed parallel sampling. Ablation analyses support the contribution of both sequence- and token-level adaptation.
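To make the sequence-level formulation concrete, the sketch below implements a minimal contextual bandit of the kind the abstract describes: a small linear-softmax adapter maps a context vector (prompt embedding concatenated with a normalized parallel sampling budget) to a distribution over discrete decoding strategies, and is trained with REINFORCE against a binary terminal reward. Everything here is illustrative, not the paper's code: the toy reward function stands in for "decode with the chosen strategy and verify the answer," and the embedding dimension, learning rate, and strategy set are assumptions.

```python
import numpy as np

# Hypothetical sequence-level decoding adapter: a linear softmax policy over
# discrete decoding strategies, trained with REINFORCE and a running baseline.
STRATEGIES = ["greedy", "top_k", "min_p"]  # assumed action set from the abstract
EMB_DIM = 8                                # toy prompt-embedding size (assumption)

rng = np.random.default_rng(0)
W = np.zeros((EMB_DIM + 1, len(STRATEGIES)))  # +1 feature for the budget

def policy_probs(context):
    """Softmax over strategy logits for one context vector."""
    logits = context @ W
    z = np.exp(logits - logits.max())
    return z / z.sum()

def toy_reward(strategy_idx, context):
    """Stand-in for 'decode, then verify': min_p wins when the first
    embedding coordinate is positive, greedy otherwise (pure fiction,
    just to give the bandit something learnable)."""
    best = 2 if context[0] > 0 else 0
    return 1.0 if strategy_idx == best else 0.0

baseline = 0.0
for step in range(3000):
    emb = rng.standard_normal(EMB_DIM)
    budget = rng.integers(1, 9) / 8.0              # normalized parallel budget
    ctx = np.concatenate([emb, [budget]])
    p = policy_probs(ctx)
    a = rng.choice(len(STRATEGIES), p=p)           # sample a decoding strategy
    r = toy_reward(a, ctx)                         # verifiable terminal reward
    baseline += 0.01 * (r - baseline)              # running-mean baseline
    grad_log = -p
    grad_log[a] += 1.0                             # d log pi(a|ctx) / d logits
    W += 0.1 * (r - baseline) * np.outer(ctx, grad_log)
```

The token-level POMDP variant from the abstract would differ mainly in what feeds the policy: instead of one decision per prompt, the adapter would act at every token step, conditioning on internal model features and the remaining token budget, with the same verifiable reward arriving only at the end of the sequence.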