When a human-AI conversation involves many rounds of continuous dialogue, the powerful large language machine-learning models that drive chatbots like ChatGPT sometimes start to break down, causing the bots' performance to degrade dramatically.
Researchers at MIT and elsewhere have pinpointed the surprising cause of this problem and developed a simple solution that allows chatbots to maintain nonstop conversations without crashes or slowdowns.
Their method involves a tweak to the key-value cache (which acts like a conversation memory) at the core of many large language models. In some methods, when this cache needs to hold more information than it has capacity for, the first pieces of data are bumped out, which can cause the model to fail.
The researchers' method allows the chatbot to keep chatting no matter how long the conversation goes on by ensuring that the first few data points remain in memory.
This method, called StreamingLLM, enables a model to remain efficient even when a conversation stretches past 4 million words. When compared to another method that avoids crashing by constantly recomputing parts of past conversations, StreamingLLM performed more than 22 times faster.
This could allow a chatbot to conduct long conversations throughout the workday without needing to be continually rebooted, enabling efficient AI assistants for tasks such as copywriting, editing, or generating code.
“This method now allows us to persistently deploy large language models. By making a chatbot that we can always chat with, and that can always respond based on our recent conversations, we could use these chatbots in new applications,” says Guangxuan Xiao, a graduate student in electrical engineering and computer science (EECS) and lead author of a paper on StreamingLLM.
Xiao's co-authors include his advisor, Song Han, an associate professor in EECS, a member of the MIT-IBM Watson AI Lab, and a distinguished scientist at NVIDIA; as well as Yuandong Tian, a research scientist at Meta AI; Beidi Chen, an assistant professor at Carnegie Mellon University; and senior author Mike Lewis, a research scientist at Meta AI. The work will be presented at the International Conference on Learning Representations.
A mysterious phenomenon
Large language models encode data, such as the words in a user query, into representations called tokens. Many models employ an attention mechanism that uses these tokens to generate new text.
Typically, an AI chatbot writes new text based on the text it has just seen, so it stores recent tokens in memory, called a KV cache, to use later. The attention mechanism builds a grid that includes all the tokens in the cache, an “attention map,” which charts how strongly each token, or word, relates to every other token.
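To make this concrete, here is a toy sketch, not the researchers' code, of how such an attention map can be computed: pairwise relatedness scores between cached token representations, followed by a softmax so each row of the grid sums to 1. Real models add causal masking and multiple attention heads; all names here are illustrative.

```python
import numpy as np

def attention_map(queries, keys):
    """Toy scaled dot-product attention over cached tokens.

    queries, keys: arrays of shape (num_tokens, dim).
    Returns a (num_tokens, num_tokens) grid where entry [i, j] says
    how strongly token i attends to token j; each row sums to 1.
    """
    dim = queries.shape[-1]
    scores = queries @ keys.T / np.sqrt(dim)      # pairwise relatedness
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)                      # softmax, row by row
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 8))  # four cached tokens, 8-dim each
print(attention_map(tokens, tokens).round(2))
```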
Understanding these relationships is one feature that enables large language models to generate human-like text.
However, if the cache becomes very large, the attention map may become even larger, slowing down computation.
Model performance also degrades if encoding a piece of content requires more tokens than the cache can hold. For example, one popular model can store 4,096 tokens, while an academic paper runs to about 10,000 tokens.
To solve these problems, researchers employ a “sliding cache” that bumps out the oldest tokens to make room for new ones. However, the model's performance often plummets as soon as that first token is evicted, rapidly reducing the quality of newly generated words.
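As a minimal sketch of such a sliding cache, assuming plain token storage rather than the per-layer key and value tensors real systems keep, a fixed-length queue that silently drops its oldest entry captures the behavior:

```python
from collections import deque

# A deque with maxlen evicts the oldest item once capacity is reached,
# which is exactly the "bump out the first token" behavior described above.
cache = deque(maxlen=4)
for token in ["The", "cat", "sat", "on", "the", "mat"]:
    cache.append(token)

print(list(cache))  # ['sat', 'on', 'the', 'mat'] -- "The" and "cat" are gone
```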
In the new paper, the researchers realized that if the first token is kept in the sliding cache, the model maintains its performance even when the cache size is exceeded.
But this didn't make sense. The first word in a novel most likely has nothing to do with the last word. So why is the first word so important for the model to generate the latest word?
In their new paper, researchers also uncovered the cause of this phenomenon.
Distracted
Some models use a softmax operation in their attention mechanism, which assigns each token a score representing how strongly it relates to every other token. The softmax operation requires all the attention scores to sum to 1. Since most tokens aren't strongly related, their attention scores are very low, and the model dumps any remaining attention score into the first token.
The researchers call this first token the “attention sink.”
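A toy calculation, with made-up scores, shows why that leftover mass has to land somewhere: because softmax weights must sum to 1, a token that is only weakly related to everything else still hands most of its attention to whichever position the model has learned to treat as the sink.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Hypothetical relatedness scores for the newest token against the cache:
# a learned high score on token 0 (the sink), low scores everywhere else.
scores = np.array([4.0, 0.1, 0.3, 0.2, 0.1])
weights = softmax(scores)
print(weights.round(3))  # ~[0.92, 0.019, 0.023, 0.021, 0.019]
print(weights.sum())     # 1.0 -- the slack piles up on the sink
```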
“We need an attention sink, and the model decides to use the first token as the attention sink because it is globally visible; every other token can see it. We found that we must always keep the attention sink in the cache to maintain the model dynamics,” Han says.
In building StreamingLLM, researchers found that optimal performance was achieved with four attention sink tokens at the beginning of the sliding cache.
They also found that the positional encoding of each token must stay the same, even as new tokens are added and others are bumped out. If token 5 is bumped out, token 6 must stay encoded as 6, even though it is now the fifth token in the cache.
By combining these two ideas, StreamingLLM can maintain a continuous conversation while outperforming popular methods that use recomputation.
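Concretely, the combined eviction rule can be sketched as below: pin the first few tokens as attention sinks and slide a window over the rest. This is an illustrative reconstruction, not the released implementation; it stores plain tokens instead of key/value tensors and leaves out the positional-encoding bookkeeping described above.

```python
class SinkCache:
    """Toy cache: the first n_sink entries are never evicted;
    the rest behave as a sliding window."""

    def __init__(self, capacity, n_sink=4):
        assert capacity > n_sink
        self.capacity = capacity
        self.n_sink = n_sink
        self.entries = []

    def append(self, token):
        self.entries.append(token)
        if len(self.entries) > self.capacity:
            # Evict the oldest non-sink entry; the sinks stay put.
            del self.entries[self.n_sink]

cache = SinkCache(capacity=6, n_sink=2)
for tok in ["t0", "t1", "t2", "t3", "t4", "t5", "t6", "t7"]:
    cache.append(tok)
print(cache.entries)  # ['t0', 't1', 't4', 't5', 't6', 't7']
```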
For example, when the cache holds 256 tokens, the recomputation method takes 63 milliseconds to decode a new token, while StreamingLLM takes 31 milliseconds. But if the cache size grows to 4,096 tokens, recomputation requires 1,411 milliseconds per new token, while StreamingLLM needs just 65.
“StreamingLLM’s innovative approach, centered on the attention mechanism, ensures stable memory usage and performance even when processing texts up to 4 million tokens long,” says Yang You, a presidential young professor of computer science at the National University of Singapore, who was not involved with this work. “This capability isn't just impressive; it's transformative, enabling StreamingLLM to be applied across a wide variety of AI applications. Its performance and versatility mark it as a highly promising technology poised to revolutionize how we approach AI-driven generation applications.”
Tianqi Chen, an assistant professor of machine learning and computer science at Carnegie Mellon University who was not involved with this research, agrees: “StreamingLLM enables the seamless extension of the conversation length of large language models. We used it to successfully deploy a Mistral model on the iPhone.”
The researchers also explored the use of attention sinks during model training by prepending several placeholder tokens to every training sample.
They found that training with attention sinks allowed a model to maintain its performance with only one attention sink in its cache, rather than the four that are usually required to stabilize a pretrained model's performance.
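As a hypothetical sketch of that training tweak (the SINK_ID value and the helper function are illustrative assumptions, not the paper's code), the placeholder could simply be spliced onto the front of every tokenized training sample:

```python
SINK_ID = 0  # hypothetical id reserved for a dedicated, learnable sink token

def add_sinks(token_ids, n_sinks=1):
    """Prepend placeholder sink tokens to one tokenized training sample."""
    return [SINK_ID] * n_sinks + list(token_ids)

print(add_sinks([17, 42, 8]))  # [0, 17, 42, 8]
```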
However, while StreamingLLM enables a model to conduct a continuous conversation, the model cannot remember words that aren't stored in the cache. In the future, the researchers plan to target this limitation by investigating methods to retrieve tokens that have been evicted or to enable the model to memorize previous conversations.
StreamingLLM has been incorporated into TensorRT-LLM, NVIDIA's large language model optimization library.
This research is funded in part by the MIT-IBM Watson AI Lab, the MIT Science Hub, and the National Science Foundation.