Meta’s new multi-token prediction makes AI models up to 3X faster




In a recent study, researchers at Meta, Ecole des Ponts ParisTech and Université Paris-Saclay suggest improving the accuracy and speed of AI large language models (LLMs) by making them predict multiple tokens simultaneously.

This goes against the classic structure of auto-regressive language models, which have been designed to predict one token at a time.

While multi-token prediction is not a universal solution for every type of model and language task, it provides substantial benefits in some areas, including up to three times faster inference and better performance on generative tasks.

Although there is still plenty of room for improvement, the technique could become a powerful tool for some LLM applications.


Limits of next-token prediction

The classic way to train LLMs is known as “next-token prediction,” a self-supervised learning technique where the model is given a sequence of tokens and must predict the next one.
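To make the self-supervised setup concrete, here is a minimal sketch (not code from the study) of how a token sequence turns into next-token training examples: at each position the model sees everything so far and the target is simply the following token. Words stand in for tokens here purely for readability; `next_token_pairs` is a hypothetical helper.

```python
def next_token_pairs(tokens):
    """Split a token sequence into (context, target) pairs for next-token prediction.

    At every position t the model sees tokens[:t + 1] and must predict tokens[t + 1].
    """
    return [(tokens[: t + 1], tokens[t + 1]) for t in range(len(tokens) - 1)]


tokens = ["The", "cat", "sat", "down"]
for context, target in next_token_pairs(tokens):
    print(context, "->", target)
```

No human labeling is needed: the text itself supplies every target, which is why the technique scales to very large corpora.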

At generation time, the model appends the predicted token to the input and repeats the process, one token at a time. By training this way over and over on large corpora of text, the model learns general patterns that allow it to output coherent passages of text.
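The one-token-at-a-time generation loop can be sketched as follows. This is an illustrative toy, assuming greedy decoding; the `BIGRAMS` table and `toy_model` function are hypothetical stand-ins for a trained LLM, which would compute its prediction from the whole context rather than just the last token.

```python
def generate(predict_next, prompt, max_new_tokens):
    """Greedy autoregressive decoding: predict one token, append it, repeat."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(predict_next(tokens))  # one model call per generated token
    return tokens


# Hypothetical stand-in for a trained model: a fixed bigram lookup table.
BIGRAMS = {"the": "cat", "cat": "sat", "sat": "on", "on": "the"}


def toy_model(tokens):
    return BIGRAMS.get(tokens[-1], "<eos>")


print(generate(toy_model, ["the"], 4))  # ['the', 'cat', 'sat', 'on', 'the']
```

The key cost to notice is that every new token requires a full call to the model, which is the bottleneck multi-token prediction later tries to relieve.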

Researchers have studied and documented the limitations of next-token prediction in acquiring language, world knowledge and reasoning capabilities.

For example, by focusing only on the next token, the model becomes too sensitive to local patterns and overlooks predictions that require reasoning over longer horizons. Models trained on next-token prediction also require huge amounts of data to reach levels of fluency that humans acquire with much less text.

The new study by Meta is predicated on the hypothesis that “training language models to predict multiple future tokens at once results in higher sample efficiency.”

Multi-token prediction

Multi-token prediction instructs the LLM to predict several future tokens from each position in the training corpus at the same time. The researchers propose a simple multi-token prediction architecture that does not require extra training time or memory overhead.
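A rough sketch of what changes in training, under simplifying assumptions: instead of one target per position, each position now has `n` targets (the next `n` tokens), and the training loss adds up one cross-entropy term per predicted token. The helper names and the dictionary-of-probabilities format below are hypothetical, and positions near the end of the sequence that lack `n` future tokens are simply skipped here. With `n = 1` this reduces to ordinary next-token prediction.

```python
import math


def multi_token_targets(tokens, n):
    """For each position t, the n future tokens tokens[t+1 : t+n+1] to be predicted.

    Positions without n future tokens are skipped in this sketch.
    """
    return [tokens[t + 1 : t + 1 + n] for t in range(len(tokens) - n)]


def multi_token_loss(head_probs, targets):
    """Sum of per-token cross-entropy terms at one position.

    head_probs[k] maps token -> predicted probability for the token k+1 steps ahead.
    """
    return sum(-math.log(probs[tok]) for probs, tok in zip(head_probs, targets))


print(multi_token_targets([1, 2, 3, 4, 5], 2))  # [[2, 3], [3, 4], [4, 5]]
loss = multi_token_loss([{"sat": 0.5}, {"on": 0.25}], ["sat", "on"])
```

Because every future token contributes its own loss term, the same text yields more training signal per position, which is the intuition behind the "higher sample efficiency" hypothesis.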

The multi-token prediction language model is based on the Transformer architecture used in most LLMs, with some modifications. It keeps the main structure of the Transformer as a shared trunk, but instead of a single output head, it has multiple independent output heads, one for each future token it predicts.
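The shared-trunk, multiple-heads layout can be illustrated with a deliberately tiny toy in plain Python. Everything here is a hypothetical simplification: the "trunk" is just the mean of random token embeddings rather than a Transformer, and each head is assumed to be a separate linear layer, which is not necessarily how Meta's heads are built. The point is only the data flow: the trunk runs once, and every head reads the same hidden state to score a different future position.

```python
import random


class MultiTokenModel:
    """Toy sketch: one shared trunk feeding n independent output heads."""

    def __init__(self, vocab_size, hidden_size, n_heads, seed=0):
        rng = random.Random(seed)
        # Hypothetical token embeddings used by the toy trunk.
        self.embed = [[rng.gauss(0, 1) for _ in range(hidden_size)] for _ in range(vocab_size)]
        # One (vocab_size x hidden_size) output matrix per head.
        self.heads = [
            [[rng.gauss(0, 1) for _ in range(hidden_size)] for _ in range(vocab_size)]
            for _ in range(n_heads)
        ]

    def trunk(self, tokens):
        """Stand-in for the Transformer trunk: average the context's embeddings."""
        vecs = [self.embed[t] for t in tokens]
        return [sum(col) / len(vecs) for col in zip(*vecs)]

    def forward(self, tokens):
        """Return one logit vector per head; head k scores the token k+1 positions ahead."""
        h = self.trunk(tokens)  # computed once, shared by every head
        return [[sum(w * x for w, x in zip(row, h)) for row in head] for head in self.heads]


model = MultiTokenModel(vocab_size=10, hidden_size=8, n_heads=4)
logits = model.forward([1, 2, 3])
guesses = [max(range(len(l)), key=l.__getitem__) for l in logits]  # one guess per future position
```

Since the expensive trunk is shared, adding heads costs comparatively little, which is consistent with the researchers' claim of no extra training time or memory overhead.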

Figure: Transformer architecture with multi-token prediction

During inference, the model can rely on the next-token prediction head alone, like a standard model, and optionally use the additional output heads to speed up decoding. The approach builds on several similar works in the field.
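Using the extra heads to speed up decoding works roughly like self-speculative decoding: the heads draft several tokens at once, and the next-token head checks them, keeping the longest run of drafts it agrees with. The sketch below is a simplified illustration, not the researchers' implementation; `draft_heads`, `next_token`, and the `BIGRAMS` table are hypothetical stand-ins. A real system verifies all drafted prefixes in a single batched forward pass, whereas this sketch calls `next_token` once per prefix for clarity.

```python
def self_speculative_step(draft_heads, next_token, tokens):
    """One decoding step of a simplified self-speculative scheme.

    draft_heads(tokens) -> list of guessed tokens for positions t+1, t+2, ...
    (the first guess comes from the next-token head itself).
    next_token(tokens) -> the next-token head's prediction.
    """
    drafts = draft_heads(tokens)
    accepted = [drafts[0]]  # head 1 is the next-token head, so its guess is always kept
    for guess in drafts[1:]:
        if next_token(tokens + accepted) != guess:
            break  # first disagreement: discard this draft and everything after it
        accepted.append(guess)
    return tokens + accepted


# Hypothetical toy model for illustration.
BIGRAMS = {"the": "cat", "cat": "sat", "sat": "on", "on": "the"}


def next_token(tokens):
    return BIGRAMS.get(tokens[-1], "<eos>")


def draft_heads(tokens):
    # Hypothetical heads: right for the first three positions, wrong on the fourth.
    first = next_token(tokens)
    second = next_token(tokens + [first])
    third = next_token(tokens + [first, second])
    return [first, second, third, "dog"]


print(self_speculative_step(draft_heads, next_token, ["the"]))  # ['the', 'cat', 'sat', 'on']
```

Because a draft is kept only when the next-token head would have produced it anyway, the output matches plain greedy decoding; the gain is that one step here emitted three tokens instead of one.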

“While cost-free and simple, multi-token prediction is an effective modification to train stronger and faster transformer models,” the researchers write.

Multi-token prediction in action

The researchers tested the new multi-token prediction scheme on a variety of tasks with models of 300 million to 13 billion parameters. 

Their findings include several interesting observations. On smaller models, multi-token prediction leads to worse results, but it becomes increasingly useful as model size increases. For example, when trained for 4-token prediction, models with 6.7 billion and 13 billion parameters improved by several percentage points over the single-token prediction baseline on the MBPP coding benchmark. “It is possible, with the exact same computational budget, to squeeze much more performance out of large language models given a fixed dataset using multi-token prediction,” the researchers write.


Multi-token prediction also makes models up to three times faster at inference time across a wide range of batch sizes, according to the researchers. “Pretraining with multi-token prediction allows the additional heads to be much more accurate than a simple finetuning of a next-token prediction model, thus allowing our models to unlock self-speculative decoding’s full potential,” the researchers write. 
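Why the extra heads' accuracy matters for speed can be seen with a back-of-the-envelope model. Assume, purely for illustration (this is not the researchers' analysis), that with `n` heads the first token is always kept and each further draft is accepted with an independent probability `p`, but only if all earlier drafts were accepted. The expected number of tokens emitted per forward pass is then 1 + p + p² + … + pⁿ⁻¹, which reaches `n` only when the heads are always right.

```python
def expected_tokens_per_step(n_heads, p):
    """Expected tokens emitted per forward pass under an i.i.d. acceptance assumption.

    The first token always comes from the next-token head; the i-th extra draft is
    kept only if it and every earlier draft match, which happens with probability p**i.
    """
    return sum(p ** k for k in range(n_heads))


for p in (0.0, 0.5, 0.8, 1.0):
    print(p, expected_tokens_per_step(4, p))
```

Real wall-clock speedups are lower than this count, since verification and the extra heads cost some compute; the takeaway is simply that more accurate heads, as pretraining reportedly provides, translate directly into more tokens per step.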

The study also shows that multi-token prediction promotes learning longer-term patterns, especially in experiments where the model is trained on “byte-level tokenization,” where every byte is considered a single token. In these experiments, multi-byte prediction outperforms the baseline single-byte prediction models by a wide margin.

This is especially important for applications where there is no predefined vocabulary and the model must learn to work with very small chunks of information.
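For readers unfamiliar with byte-level tokenization, here is a small sketch (not from the study) of what it looks like: the text's UTF-8 bytes are used directly as tokens, so there is no learned tokenizer vocabulary, only the 256 possible byte values. A single character can span several tokens, which is one reason predicting a few bytes ahead at once can help. The helper names are hypothetical.

```python
def byte_tokens(text):
    """Byte-level tokenization: every UTF-8 byte is one token (256 possible values)."""
    return list(text.encode("utf-8"))


def multi_byte_targets(data, n):
    """For each position t, the n following bytes a multi-byte model would predict."""
    return [data[t + 1 : t + 1 + n] for t in range(len(data) - n)]


data = byte_tokens("café")
print(data)                         # [99, 97, 102, 195, 169]; "é" takes two bytes
print(multi_byte_targets(data, 2))  # [[97, 102], [102, 195], [195, 169]]
```

Note how the final position's targets cover both bytes of "é": a single-byte model must get there one fragment at a time, while a multi-byte model is trained to anticipate the whole character.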

Multi-token prediction still has room for improvement. For example, the optimal number of tokens to predict depends on the kind of task and the model size. The scientists are considering multiple future directions of research, including techniques to automatically choose the optimal number of tokens to predict and studying the dynamics between vocabulary size and multi-token prediction.

What might make this research and its future iterations useful for enterprise applications is the potential to provide faster inference and higher accuracy at little or no extra cost for generative tasks such as code completion. Because it leaves most of the LLM architecture intact, it can also be compatible with other optimization techniques for the Transformer block.


