Low-Rank Gradient Descent

Bibliographic Details
Main Authors: Romain Cosson, Ali Jadbabaie, Anuran Makur, Amirhossein Reisizadeh, Devavrat Shah
Format: Article
Language: English
Published: IEEE 2023-01-01
Series: IEEE Open Journal of Control Systems
Subjects:
Online Access: https://ieeexplore.ieee.org/document/10250907/
Description
Summary: Several recent empirical studies demonstrate that important machine learning tasks, such as training deep neural networks, exhibit a low-rank structure, where most of the variation in the loss function occurs only in a few directions of the input space. In this article, we leverage such low-rank structure to reduce the high computational cost of canonical gradient-based methods such as gradient descent (GD). Our proposed Low-Rank Gradient Descent (LRGD) algorithm finds an $\epsilon$-approximate stationary point of a $p$-dimensional function by first identifying $r \leq p$ significant directions, and then estimating the true $p$-dimensional gradient at every iteration by computing directional derivatives only along those $r$ directions. We establish that the "directional oracle complexities" of LRGD for strongly convex and non-convex objective functions are $\mathcal{O}(r \log(1/\epsilon) + rp)$ and $\mathcal{O}(r/\epsilon^{2} + rp)$, respectively. Therefore, when $r \ll p$, LRGD provides a significant improvement over the known complexities of $\mathcal{O}(p \log(1/\epsilon))$ and $\mathcal{O}(p/\epsilon^{2})$ for GD in the strongly convex and non-convex settings, respectively. Furthermore, we formally characterize the classes of exactly and approximately low-rank functions. Empirically, on real and synthetic data, LRGD provides significant gains over GD when the data has low-rank structure, and in the absence of such structure, LRGD does not degrade performance compared to GD. This suggests that LRGD could be used in place of GD in practice in any setting.
ISSN: 2694-085X
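
The summary describes LRGD in two phases: first identify $r \leq p$ significant directions, then take descent steps using only $r$ directional derivatives per iteration instead of a full $p$-dimensional gradient. The sketch below is a minimal, non-authoritative Python illustration of that idea; the subspace-identification heuristic (top-$r$ singular vectors of a few sampled gradients) and all names (lrgd_sketch, num_probes, fd_eps, etc.) are assumptions made for illustration, not the authors' actual procedure or API.

```python
import numpy as np

def lrgd_sketch(f, grad_f, x0, r, step_size=0.1, num_iters=100,
                fd_eps=1e-6, num_probes=None):
    """Illustrative LRGD-style descent: r directional derivatives per step."""
    p = x0.shape[0]
    num_probes = num_probes if num_probes is not None else r

    # Phase 1 (assumed heuristic, not necessarily the paper's procedure):
    # estimate r "significant" directions as the top-r left singular vectors
    # of gradients sampled near x0. Each full gradient costs about p
    # directional derivatives, which is where an O(rp) identification cost
    # would come from.
    probes = x0 + 0.1 * np.random.randn(num_probes, p)
    G = np.stack([grad_f(z) for z in probes], axis=1)   # shape (p, num_probes)
    U, _, _ = np.linalg.svd(G, full_matrices=False)
    U = U[:, :r]                                        # shape (p, r)

    # Phase 2: descend using only r forward-difference directional
    # derivatives per iteration, reconstructing a low-rank gradient estimate.
    x = x0.astype(float)
    for _ in range(num_iters):
        fx = f(x)
        dirs = np.array([(f(x + fd_eps * U[:, j]) - fx) / fd_eps
                         for j in range(r)])
        x = x - step_size * (U @ dirs)                  # low-rank gradient step
    return x
```

For intuition on the stated bounds: in the non-convex regime with $p = 10^{4}$, $r = 10$, and $\epsilon = 10^{-2}$, GD's bound of $p/\epsilon^{2}$ is $10^{8}$ directional derivatives, whereas LRGD's bound of $r/\epsilon^{2} + rp$ is roughly $2 \times 10^{5}$.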