Beyond Autoregression: Fast LLMs via Self-Distillation Through Time

CLAIRE - School of Computer and Communication Sciences

TL;DR

We present a distillation method for discrete diffusion language models (DDLMs) that enables sampling text 4-8x faster than AR models that use KV-caching.

Our best student matches GPT-2 with nucleus sampling using only 32 model evaluations, generating 32 tokens per call on average.

Distillation targets and latency.

(a) The distillation targets are the log probabilities of the step that denoises each token, concatenated with the log probabilities of the last step for tokens that remain masked. (b) Latency of the distilled models. Sampling with 16 steps achieves an 8x speedup over the AR baseline that uses KV-caching.

Summary

Recent discrete diffusion language models (DDLMs) such as MDLM and SEDD approach or match similarly sized AR models in generation quality. Importantly, DDLMs can generate tokens in any order and in parallel, unlike regular autoregressive models.

However, sampling from DDLMs requires thousands of steps to achieve good performance. Additionally, since DDLMs use a bidirectional architecture, every position attends to the whole sequence, whose content changes at each denoising step, so KV-caching is not applicable.

To reduce the number of sampling steps while retaining performance, we propose Self-Distillation Through Time (SDTT), a novel distillation method for DDLMs.

Most off-the-shelf distillation methods for continuous diffusion rely on deterministic mappings from noise to images, such as DDIM, which have no direct counterpart in discrete diffusion. Nonetheless, we demonstrate that SDTT can reduce the number of sampling steps of pre-trained MDLMs by a factor of 32-64.

Importantly, our final student can generate samples with lower perplexity than GPT-2 with nucleus sampling in 32 steps.

Our method is simple to implement and relatively cheap to run. Additionally, we release training and test code along with the distilled models.

Recent studies have shown that one can improve the performance of a fixed model by scaling up computational resources at inference time. In this work, we improve the decoding speed of LLMs by moving away from AR modeling.

Perplexity and accuracy

(a) SDTT on small models trained for 1M steps. Student curves labeled dt=1/k correspond to students that match the teacher in k sampling steps. Successive lines correspond to additional SDTT rounds. With 32 sampling steps, SDTT students outperform the teacher and GPT-2 with nucleus sampling. (b) Accuracy with one-step greedy decoding of the last word. Distillation with the KL divergence (KLD) loss retains the teacher's performance, unlike the mean squared error (MSE) and total variation distance (TVD) losses.

Method

Our method matches, in k steps, the distribution of samples that the teacher generates with m steps, where k is smaller than m. Practically, we learn to "collapse" multiple teacher predictions into a single step. We generate the distillation targets with the following algorithm (see also the figure at the top of the page and the sketch below):

Distillation teacher targets.
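To make the target construction concrete, here is a minimal PyTorch-style sketch of the procedure. It is an illustration under assumptions, not the released implementation: the teacher is assumed to return per-position logits over the vocabulary, and MASK_ID, num_teacher_steps, unmask_frac, and the random unmasking rule are stand-ins for the MDLM sampler and its schedule.

    import torch
    import torch.nn.functional as F

    MASK_ID = 0  # illustrative; the real mask-token id comes from the tokenizer

    @torch.no_grad()
    def sdtt_targets(teacher, x_t, num_teacher_steps=2, unmask_frac=0.5):
        """Run the teacher for a few small denoising steps and record, for each
        position, the log probabilities of the step that denoises it; tokens that
        are still masked at the end keep the log probabilities of the last step."""
        batch, length = x_t.shape
        x = x_t.clone()
        targets = None
        for step in range(num_teacher_steps):
            log_probs = F.log_softmax(teacher(x), dim=-1)  # (batch, length, vocab)
            if targets is None:
                targets = torch.zeros_like(log_probs)
            masked = x == MASK_ID
            # Pick which currently-masked positions to denoise at this step.
            # A random fraction is used purely for illustration; MDLM follows its own schedule.
            unmask = masked & (torch.rand(batch, length, device=x.device) < unmask_frac)
            if step == num_teacher_steps - 1:
                record = masked   # last step: also cover tokens that stay masked
            else:
                record = unmask   # earlier steps: only tokens denoised right now
            targets[record] = log_probs[record]
            if unmask.any():
                # Sample tokens for the denoised positions and write them back.
                x[unmask] = torch.distributions.Categorical(logits=log_probs[unmask]).sample()
        return x, targets  # targets are only meaningful at positions that were masked in x_t

The student is then trained to reproduce these targets from x_t in a single call; applying SDTT over several rounds (the successive curves in the perplexity figure above) is what yields the 32-64 fold reduction in sampling steps.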

Given the distillation targets, we then train the student using the Kullback-Leibler divergence between the teacher targets and the student predictions:
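In our notation (a sketch of the objective; the paper's exact normalization and weighting may differ), writing q(· | x_t, ℓ) for the teacher target at position ℓ and p_θ for the student, the loss sums the token-level divergence over masked positions:

\[
\mathcal{L}_{\mathrm{SDTT}}(\theta) \;=\; \mathbb{E}_{x_t}\Bigg[\sum_{\ell \,:\, x_t^{\ell} = \texttt{[MASK]}} D_{\mathrm{KL}}\Big(q(\cdot \mid x_t, \ell)\;\Big\|\;p_\theta(\cdot \mid x_t, \ell)\Big)\Bigg]
\]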

Example samples (32 steps)

Today is a great day. The sun is shining...

Before SDTT:

the horizon has expanded. A blinked the sun. Then the sun went two feet dark. somewhere outside the football arena was listening discreetly. I wondered who. Two children, a they would care?

The players, and the suitors on hard hats of admirers, still diving into deep breaths in the natural comforts of all the sun. City's home tonight and a massive crowd, against a friendly crowd not particularly, not...

After SDTT:

the atmosphere is warm and cozy. The team has people really mixing and laughing. This is the kind of moment where we see 'good feelings' in people that we're loving with happiness and all the time. We're going to learn from each other, have something good, have fun together. I want to do a similar thing for people like that in the future, which is really about encouraging people who wanted to come here before, these...

All candidates will be provided with...

Before SDTT:

their credits, the pictures they present, and the terms used to extoll mark the commitment. In well as all electoral duties, suffice in the aspect that

Commitment with disclosure at the top of the ranking (reasons at the top of the election makes very simple participation important to do effectively)

Required to do fully and please be Particip to Be fun for as well as different qualities to rank and commit to the endorsements and truly control of the...

After SDTT:

details and will travel with them to their preferred location. On May 9th the rest will receive ' winners’ set up through Card222.

On May 10th, and using the correct results (redeem link on the back) others can select the app sign-up for be individually and for using the existing scoring platform.

Korean athletes will have panels on May 14th and and are hoping to Sydney a 2nd as soon as...

As the news reported...

Before SDTT:

ixixkaval, 50 is Belgica, worker from marticius, an IT-world, from a Brilsen town and lived from Switzerland, eager to have a speed to open Andu 8.0 and turn around, he said, with the Black47 file, which includes the many files he and his wife computer deleted, due to run on September 20, 2016, according to Europa news reports. Related Portugal


After SDTT:

 Erika Dike, 20-year-old Nilesville man in police custody, finally was found in the Rio Grande Sound river, around 3:20 p.m. Friday night.

Prosecutors have not yet ruled on Dike being a suspect or a member of the incident’s sheriff's department in Mesa before the death.

But ABC police are detectives soon received messages stating that Dike had been “trashed”

Paper abstract

Autoregressive (AR) Large Language Models (LLMs) have demonstrated significant success across numerous tasks. However, the AR modeling paradigm presents certain limitations; for instance, contemporary autoregressive LLMs are trained to generate one token at a time, which can result in noticeable latency.

Recent advances have indicated that search and repeated sampling can enhance performance in various applications, such as theorem proving, code generation, and alignment, by utilizing greater computational resources during inference. In this study, we demonstrate that diffusion language models are capable of generating at least 32 tokens simultaneously, while exceeding the performance of AR models in text quality and on the LAMBADA natural language understanding benchmark.

This outcome is achieved through a novel distillation method for discrete diffusion models, which reduces the number of inference steps by a factor of 32-64. Practically, our models, even without caching, can generate tokens at a rate that is up to 8 times faster than AR models employing KV-caching, and we anticipate further improvements with the inclusion of caching. Moreover, we demonstrate the efficacy of our approach for diffusion language models with up to 860M parameters.

BibTeX

@article{deschenaux2024autoregressionfastllmsselfdistillation,
        title={Beyond Autoregression: Fast LLMs via Self-Distillation Through Time},
        author={Deschenaux, Justin and Gulcehre, Caglar},
        year={2024},
        eprint={2410.21035},
        archivePrefix={arXiv},
        primaryClass={cs.LG},
        url={https://arxiv.org/abs/2410.21035}, 
      }