Scalable AI Architectures

Apr 30, 2024

Today, two neural architectures scale to the limits of available data and compute:

  1. Transformers
  2. Diffusion

There’s a myth spreading that scalable architectures aren’t rare, and that every architecture scales with enough optimization. However, decades of research have revealed countless architectures – inspired by all manner of physics, neuroscience, fruit fly mating habits, and mathematics – that don’t scale.

Even if data and compute are the most urgent bottlenecks in advancing today’s capabilities, we shouldn’t take scalable architectures for granted. We should value them, advance them, and find better architectures that address their limitations.

Neural architectures that don’t (black) and do (blue) scale

Finding new scalable architectures doesn’t necessarily require massive training runs, thanks to scaling laws. Scaling laws are empirical relationships that predict how model performance changes as dataset size, model size, and training compute vary. In other words, if we measure an architecture’s performance across enough small training runs along those dimensions, we can extrapolate its performance to much larger runs.
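
As a rough illustration of that extrapolation, here is a minimal sketch that fits a power law to a handful of small runs and projects it to a larger model. It assumes loss follows loss = a * size^(-b), which is one common functional form and not something this post specifies. The run sizes and losses are synthetic numbers made up for the example, and `fit_power_law` and `predict_loss` are hypothetical helper names.

```python
import math

def fit_power_law(sizes, losses):
    """Least-squares fit of loss = a * size**(-b), done in log-log space."""
    xs = [math.log(n) for n in sizes]
    ys = [math.log(l) for l in losses]
    count = len(xs)
    mean_x = sum(xs) / count
    mean_y = sum(ys) / count
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx                      # slope of the straight line in log-log space
    intercept = mean_y - slope * mean_x
    return math.exp(intercept), -slope     # (a, b)

def predict_loss(a, b, size):
    """Evaluate the fitted power law at a given model size."""
    return a * size ** (-b)

# ASSUMPTION: synthetic "small run" results generated from a known power law,
# standing in for real measurements from cheap training runs.
small_sizes = [1e6, 3e6, 1e7, 3e7, 1e8]
small_losses = [10.0 * n ** -0.07 for n in small_sizes]

a, b = fit_power_law(small_sizes, small_losses)
predicted_at_1b = predict_loss(a, b, 1e9)  # extrapolate 10x past the largest run
```

With real measurements the points will not sit exactly on a line, so the quality of the fit (and how far it can be trusted beyond the largest run) has to be checked rather than assumed.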

For many neural architectures, scaling laws level off well below the 1 billion-parameter mark. Kaplan’s 2020 paper, “Scaling Laws for Neural Language Models”, shows the scaling laws of traditional LSTMs plateauing at just 10 million parameters. Tay’s 2022 paper, “Scaling Laws vs Model Architectures”, fits scaling laws for ten different architectures and shows five non-transformer architectures plateauing well before the 1 billion-parameter mark.
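
One simple way to spot such a plateau in a set of small runs is to look at the local slope of loss against model size on a log-log plot and flag the point where it flattens. The sketch below is only an illustration: the loss values are invented, the 0.01 threshold is an arbitrary assumption, and `local_slopes` and `plateau_size` are hypothetical helpers, not a method taken from either paper.

```python
import math

def local_slopes(sizes, losses):
    """Log-log slope of loss vs. model size between consecutive runs."""
    points = list(zip(sizes, losses))
    return [
        (math.log(l1) - math.log(l0)) / (math.log(n1) - math.log(n0))
        for (n0, l0), (n1, l1) in zip(points, points[1:])
    ]

def plateau_size(sizes, losses, min_slope=0.01):
    """Return the first size after which loss stops improving, else None.

    ASSUMPTION: "stops improving" means the log-log slope is shallower
    than -min_slope; the threshold is arbitrary.
    """
    for size, slope in zip(sizes, local_slopes(sizes, losses)):
        if slope > -min_slope:
            return size
    return None

# ASSUMPTION: made-up losses for two hypothetical architectures.
sizes = [1e6, 1e7, 1e8, 1e9]
still_scaling = [4.0, 3.4, 2.9, 2.5]     # keeps improving at every step
saturating = [4.0, 3.6, 3.58, 3.575]     # flattens after 10M parameters

scaling_result = plateau_size(sizes, still_scaling)
saturating_result = plateau_size(sizes, saturating)
```

Here `scaling_result` is `None` and `saturating_result` is the 10M-parameter run, the last size before the curve goes flat.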

As a result, ruling out new architectures by examining their scaling laws can be quite cheap. According to Databricks, a 1 billion-parameter training run costs only $2,000.
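
For a sense of where a figure like that comes from, here is a back-of-the-envelope sketch. Every input is an assumption chosen for illustration and none comes from Databricks: the widely used approximation of about 6 FLOPs per parameter per training token, 20 training tokens per parameter, a GPU with 312 TFLOP/s peak throughput running at 40% utilization, and a rental price of $2 per GPU-hour. `estimate_training_cost` is a hypothetical helper.

```python
def estimate_training_cost(params, tokens, peak_flops_per_gpu,
                           utilization, dollars_per_gpu_hour):
    """Rough training cost from parameter count and token count."""
    # ASSUMPTION: ~6 FLOPs per parameter per token (forward + backward pass).
    train_flops = 6 * params * tokens
    # Sustained throughput is only a fraction of the hardware's peak.
    effective_flops_per_second = peak_flops_per_gpu * utilization
    gpu_hours = train_flops / effective_flops_per_second / 3600
    return gpu_hours, gpu_hours * dollars_per_gpu_hour

# ASSUMPTION: all inputs below are illustrative, not measured or quoted values.
gpu_hours, dollars = estimate_training_cost(
    params=1e9,                 # 1 billion parameters
    tokens=20e9,                # 20 tokens per parameter
    peak_flops_per_gpu=312e12,  # peak FLOP/s of the assumed GPU
    utilization=0.4,            # fraction of peak actually achieved
    dollars_per_gpu_hour=2.0,   # assumed rental price
)
```

The estimate scales linearly with every input, so it is best read as an order-of-magnitude check on why small-scale runs are affordable, not as a reproduction of the Databricks number.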

Traditional LSTMs’ scaling law plateauing in “Scaling Laws for Neural Language Models” (Kaplan, 2020)

Non-transformers’ scaling laws plateauing in “Scaling Laws vs Model Architectures” (Yi Tay, 2022)

In fact, it was OpenAI’s 2019 GPT-2 transformer training run at 1.5 billion parameters that gave them the confidence to invest in a GPT-3 training run at 175 billion parameters, over two orders of magnitude larger:

  1. Scaling laws: “by the time we made GPT-2… you could look at the scaling laws and see what was going to happen.” - Sam Altman, On with Kara Swisher, 3/23/23
  2. Benchmarks: “state-of-the-art on Winograd Schema, LAMBADA” - OpenAI Announcement
  3. Emergent capabilities: “capable of generating samples… that feel close to human quality” - OpenAI Announcement

GPT scaling laws in “Scaling Laws for Neural Language Models” (Kaplan, 2020)

The straightforward next step to improve model capabilities is continuing to scale the architectures that work, by increasing the size of our data centers and datasets. However, these architectures have major limitations, such as logarithmic scaling (each fixed gain in performance demands a multiplicative increase in data and compute) and limited context, which means there's still an opportunity to innovate in core architecture research.

Positively bending our best scaling laws would help not only at the top end, by increasing the capabilities of models trained at the largest data centers on the largest datasets, but also at the middle and low ends, by allowing capable models to be used at affordable prices.

Some of the most exciting lines of work to improve existing scalable architectures include:

  1. Planning (e.g., tree search, reinforcement learning)
  2. Longer context (e.g., Stanford’s FlashAttention, UC Berkeley’s Ring Attention)
  3. Ensembling (e.g., Google’s mixture of experts, Sakana’s evolutionary model merging)

There are also early candidates for new scalable architectures:

  1. State-space models (e.g., Stanford’s S4, Hyena, Cartesia’s Mamba, AI21 Labs’ Jamba)
  2. RNN variants (e.g., Bo Peng’s RWKV)
  3. Hybrid (“striped”) variants that interleave the above with existing architectures such as attention

Notably, these new architectures are mostly driven by small teams outside the major industry labs. For example, S4 has three authors, Mamba has two authors, and RWKV is primarily developed by Bo Peng.

These emerging architectures are already proving useful in real-world domains. For example, Arc Institute's Evo, a 7B biological model, is based on an architecture that avoids the quadratic complexity of attention, allowing it to process vast volumes of biological sequence data within a 131K context window.
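
To see why avoiding quadratic complexity matters at that length, the sketch below compares how the work grows when the context is stretched from 2,048 to 131,072 tokens. It is a simplified counting model, not Evo's actual cost: it assumes full attention scores every pair of tokens, while a linear-time sequence model does a constant amount of work per token. The 2,048-token baseline is an arbitrary choice, 131K is taken to mean 131,072, and both function names are hypothetical.

```python
def attention_pairs(context_len):
    """Token pairs scored by full self-attention: grows quadratically."""
    return context_len * context_len

def linear_steps(context_len):
    """Per-token steps of a linear-time sequence model: grows linearly."""
    return context_len

# ASSUMPTION: 2,048 is an arbitrary baseline; 131,072 stands in for "131K".
short_context, long_context = 2_048, 131_072

attention_growth = attention_pairs(long_context) // attention_pairs(short_context)
linear_growth = linear_steps(long_context) // linear_steps(short_context)
```

Under these assumptions, a 64x longer context costs 64x more work for the linear-time model but 4,096x more for full attention.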

We're still early in the history of AI, and our architectures are far from the optimum that physical limits allow. While scaling via data center buildouts and data acquisition is the clear next step in advancing model capabilities, there are still major breakthroughs to come from the scalable architectures themselves.

Notes 

  1. I opted for simplicity over accuracy in the discussion of scaling laws. In reality, they’re not as well-understood as many believe. We only know a small sliver of how models scale above a certain level, and it’s hard to answer questions beyond “What is the best model I can train with a fixed budget of $__?”. Factors like data quality are also hard to quantify.
  2. Scaling law papers often prioritize speed over rigor (see failed replications), because the field is moving so fast.
  3. What will be the real bottlenecks in scaling known architectures? Data? GPU supply? Energy? Dollars? Data centers? Researchers?
  4. What should small teams outside the major industry labs work on?
  5. There are many other promising approaches to improving models that don't involve modifying the architecture, like novel training methods (e.g. UL2R), better optimization, synthetic data, data curation, compression, and alignment.

Thanks to friends and colleagues who provided feedback on early drafts of this.
