MaD Seminar: Feature Learning with SGD

Speaker: Jason Lee

Location: 60 Fifth Avenue, Room 150

Date: Thursday, March 30, 2023

I will present two recent works on analyzing feature learning for SGD. The first focuses on the task of learning a single-index model \sigma(w* · x) with respect to the isotropic Gaussian distribution in d dimensions, including the special case when \sigma is a kth-order Hermite polynomial, which corresponds to the Gaussian analog of parity learning. Prior work has shown that the sample complexity of learning w* is governed by the information exponent k* of the link function \sigma, which is defined as the index of the first nonzero Hermite coefficient of \sigma. Prior upper bounds have shown that n > d^{k*-1} samples suffice for learning w* and that this is tight for online SGD (Ben Arous et al., 2020). However, the CSQ lower bound for gradient-based methods only shows that n > d^{k*/2} samples are necessary. In this work, we close the gap between the upper and lower bounds by showing that online SGD on the Gaussian-smoothed loss learns w* with n > d^{k*/2} samples.
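To make the definition of the information exponent concrete, here is a minimal Python sketch that estimates the Hermite coefficients of a link function by Gauss-Hermite quadrature and returns the index of the first nonzero one. The function names, the truncation degree, the tolerance, and the choice to skip the constant (k = 0) coefficient are assumptions made for illustration; none of this is taken from the works being presented.

```python
import math

import numpy as np
from numpy.polynomial import hermite_e as He


def hermite_coefficients(sigma, max_degree=8, quad_points=60):
    """Estimate c_k = E[sigma(Z) He_k(Z)] / sqrt(k!) for Z ~ N(0, 1), k = 0..max_degree."""
    # Nodes and weights for the weight function exp(-x^2 / 2).
    nodes, weights = He.hermegauss(quad_points)
    # Rescale so the weights sum to 1, i.e. they integrate against the N(0, 1) density.
    weights = weights / math.sqrt(2.0 * math.pi)
    values = sigma(nodes)
    coeffs = []
    for k in range(max_degree + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0  # coefficient vector selecting the probabilists' Hermite polynomial He_k
        he_k = He.hermeval(nodes, basis)
        coeffs.append(float(np.sum(weights * values * he_k)) / math.sqrt(math.factorial(k)))
    return np.array(coeffs)


def information_exponent(sigma, max_degree=8, tol=1e-8):
    """Index of the first Hermite coefficient (k >= 1, an assumed convention) above tol."""
    coeffs = hermite_coefficients(sigma, max_degree)
    for k in range(1, max_degree + 1):
        if abs(coeffs[k]) > tol:
            return k
    return None  # no nonzero coefficient found up to max_degree


# The third Hermite polynomial He_3(z) = z^3 - 3z as a link function.
k_star = information_exponent(lambda z: z**3 - 3 * z)
```

With this sketch, the third Hermite polynomial has information exponent 3 and tanh has information exponent 1. Ignoring constants and logarithmic factors, for k* = 3 and d = 100 the bounds quoted above read d^{k*-1} = 10^4 versus d^{k*/2} = 10^3.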

Next, we turn to the problem of learning multi-index models f(x) = g(Ux), where U encodes a latent representation of low dimension. Significant prior work has established that neural networks trained by gradient descent behave like kernel methods, even though kernel methods perform significantly worse in practice. In this work, however, we demonstrate that for this large class of functions there is a large gap between kernel methods and gradient descent on a two-layer neural network, by showing that gradient descent learns representations relevant to the target task. We also demonstrate that these representations allow for efficient transfer learning, which is impossible in the kernel regime. Specifically, we consider the problem of learning polynomials which depend on only a few relevant directions, i.e. of the form f*(x) = g(Ux), where U is an r by d matrix, so that Ux lies in R^r. When the degree of f* is p, it is known that n ≍ d^p samples are necessary to learn f* in the kernel regime. Our primary result is that gradient descent learns a representation of the data which depends only on the directions relevant to f*. This results in an improved sample complexity of n ≍ d^2 r + d r^p. Furthermore, in a transfer learning setup where the data distributions in the source and target domains share the same representation U but have different polynomial heads, we show that a popular heuristic for transfer learning has a target sample complexity independent of d.
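As a rough illustration of the multi-index setup, the following Python sketch draws Gaussian inputs, applies a polynomial head g to the r-dimensional projection Ux, and compares the two sample-complexity scalings quoted above. The function names, the orthonormal-row choice of U, and the example head are hypothetical and chosen only for illustration; constants and logarithmic factors are ignored.

```python
import numpy as np


def sample_multi_index(n, d, r, g, rng):
    """Draw n samples (x, g(Ux)) with x ~ N(0, I_d) and a random r-by-d matrix U."""
    # Assumption for illustration: U has orthonormal rows, obtained from a QR factorization.
    q, _ = np.linalg.qr(rng.standard_normal((d, r)))  # q is d-by-r with orthonormal columns
    U = q.T
    X = rng.standard_normal((n, d))
    y = g(X @ U.T)  # the labels depend on x only through the r coordinates of Ux
    return X, y, U


def kernel_regime_scale(d, p):
    """Scaling d^p for the kernel regime, with constants dropped."""
    return d**p


def feature_learning_scale(d, r, p):
    """Scaling d^2 r + d r^p for gradient descent, with constants dropped."""
    return d**2 * r + d * r**p


rng = np.random.default_rng(0)
# Hypothetical degree-2 polynomial head acting on the projected coordinates.
X, y, U = sample_multi_index(n=50, d=20, r=3, g=lambda Z: (Z**2).sum(axis=1), rng=rng)
```

For example, with d = 100, r = 2 and p = 4, the kernel-regime scaling is 100^4 = 10^8, while d^2 r + d r^p = 20000 + 1600 = 21600.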