Hardware-Aware Efficient Primitives for Machine Learning

Speaker: Dan Fu

Location: 60 Fifth Avenue, Room 150

Date: Monday, March 18, 2024

Efficiency is increasingly tied to quality in machine learning, with more efficient training algorithms leading to more powerful models. However, today's most popular machine learning models are built on asymptotically inefficient primitives. For example, attention in Transformers scales quadratically in the input size, while MLPs scale quadratically in model dimension. In this talk, I discuss my work on improving the efficiency of the core primitives in machine learning, with an emphasis on hardware-aware algorithms and long-context applications. First, I focus on replacing attention with gated state space models (SSMs) and convolutions, which scale sub-quadratically in context length. I describe the H3 (Hungry Hungry Hippos) architecture, a gated SSM architecture that matches Transformers in quality up to 3B parameters and achieves 2.4x faster inference. Second, I focus on developing hardware-aware algorithms for SSMs and convolutions. I describe FlashFFTConv, a fast algorithm for computing SSMs and convolutions on GPUs by optimizing the Fast Fourier Transform (FFT). FlashFFTConv yields up to 7x speedup and 5x memory savings, even over vendor solutions from Nvidia. Third, I briefly touch on how the same techniques can be used to achieve sub-quadratic scaling in model dimension. I describe Monarch Mixer, which uses a generalization of the FFT to achieve sub-quadratic scaling in both sequence length and model dimension. Throughout the talk, I give examples of how these ideas are beginning to take hold, with gated SSMs and their variants now leading to state-of-the-art performance in long-context language models, embedding models, and DNA foundation models.
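The sub-quadratic scaling of long convolutions comes from computing them with the FFT. The sketch below is a minimal illustration, assuming NumPy and a hypothetical helper name `fft_causal_conv`. It shows the textbook FFT convolution, computing a causal convolution of a length-N input with a length-N filter in O(N log N) time, and checks the result against NumPy's direct O(N^2) `np.convolve`. It is not FlashFFTConv itself, which concerns how this FFT computation is mapped onto GPU hardware; it only shows the underlying primitive that such kernels accelerate.

```python
import numpy as np


def fft_causal_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Hypothetical helper: causal convolution y[t] = sum_{s<=t} k[s] * u[t-s].

    Runs in O(N log N) via the FFT instead of O(N^2) directly.
    """
    n = u.shape[-1]
    # Zero-pad to 2N so the FFT's circular convolution does not wrap around.
    fft_size = 2 * n
    u_f = np.fft.rfft(u, n=fft_size)
    k_f = np.fft.rfft(k, n=fft_size)
    # Pointwise product in frequency space == convolution in time.
    y = np.fft.irfft(u_f * k_f, n=fft_size)
    # Keep only the first N outputs, which form the causal part.
    return y[..., :n]


rng = np.random.default_rng(0)
N = 1024
u = rng.standard_normal(N)  # input sequence
k = rng.standard_normal(N)  # long filter, as long as the sequence itself

y_fft = fft_causal_conv(u, k)
# np.convolve computes the full (2N - 1)-length convolution directly;
# its first N entries are the causal outputs.
y_direct = np.convolve(u, k)[:N]

print(np.allclose(y_fft, y_direct))
```

Padding to length 2N is the design choice that matters here: without it, the FFT computes a circular convolution, and outputs near the end of the sequence would wrap around and contaminate the early ones.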