Low-Rank Gradient Descent


Bibliographic Details
Main Authors: Romain Cosson, Ali Jadbabaie, Anuran Makur, Amirhossein Reisizadeh, Devavrat Shah
Format: Article
Language: English
Published: IEEE 2023-01-01
Series: IEEE Open Journal of Control Systems
Subjects: Active subspace; first order optimization; low-rank functions; oracle complexity
Online Access: https://ieeexplore.ieee.org/document/10250907/
author Romain Cosson
Ali Jadbabaie
Anuran Makur
Amirhossein Reisizadeh
Devavrat Shah
collection DOAJ
description Several recent empirical studies demonstrate that important machine learning tasks, such as training deep neural networks, exhibit a low-rank structure, where most of the variation in the loss function occurs only in a few directions of the input space. In this article, we leverage such low-rank structure to reduce the high computational cost of canonical gradient-based methods such as gradient descent (<monospace>GD</monospace>). Our proposed <italic>Low-Rank Gradient Descent</italic> (<monospace>LRGD</monospace>) algorithm finds an <inline-formula><tex-math notation="LaTeX">$\epsilon$</tex-math></inline-formula>-approximate stationary point of a <inline-formula><tex-math notation="LaTeX">$p$</tex-math></inline-formula>-dimensional function by first identifying <inline-formula><tex-math notation="LaTeX">$r \leq p$</tex-math></inline-formula> significant directions, and then estimating the true <inline-formula><tex-math notation="LaTeX">$p$</tex-math></inline-formula>-dimensional gradient at every iteration by computing directional derivatives only along those <inline-formula><tex-math notation="LaTeX">$r$</tex-math></inline-formula> directions. We establish that the "directional oracle complexities" of <monospace>LRGD</monospace> for strongly convex and non-convex objective functions are <inline-formula><tex-math notation="LaTeX">${\mathcal {O}}(r \log (1/\epsilon) + rp)$</tex-math></inline-formula> and <inline-formula><tex-math notation="LaTeX">${\mathcal {O}}(r/\epsilon ^{2} + rp)$</tex-math></inline-formula>, respectively. Therefore, when <inline-formula><tex-math notation="LaTeX">$r \ll p$</tex-math></inline-formula>, <monospace>LRGD</monospace> provides a significant improvement over the known complexities of <inline-formula><tex-math notation="LaTeX">${\mathcal {O}}(p \log (1/\epsilon))$</tex-math></inline-formula> and <inline-formula><tex-math notation="LaTeX">${\mathcal {O}}(p/\epsilon ^{2})$</tex-math></inline-formula> of <monospace>GD</monospace> in the strongly convex and non-convex settings, respectively. Furthermore, we formally characterize the classes of exactly and approximately low-rank functions. Empirically, using real and synthetic data, <monospace>LRGD</monospace> provides significant gains over <monospace>GD</monospace> when the data has low-rank structure, and in the absence of such structure, <monospace>LRGD</monospace> does not degrade performance compared to <monospace>GD</monospace>. This suggests that <monospace>LRGD</monospace> could be used in practice in place of <monospace>GD</monospace> in any setting.
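The LRGD iteration described in the abstract can be sketched as follows. This is an illustrative reconstruction, not the authors' implementation: the function name `lrgd` and its parameters are hypothetical, and for simplicity the <inline-formula><tex-math notation="LaTeX">$r$</tex-math></inline-formula> directional derivatives are obtained here by projecting a full gradient oracle onto the chosen directions, whereas LRGD's complexity gain comes from querying each directional derivative <inline-formula><tex-math notation="LaTeX">$u_i^{\top}\nabla f(x)$</tex-math></inline-formula> directly.

```python
import numpy as np

def lrgd(grad, x0, U, step=0.1, iters=100):
    """Sketch of Low-Rank Gradient Descent (LRGD).

    Instead of the full p-dimensional gradient, each iteration uses only
    r directional derivatives along the columns of U (a p x r matrix with
    orthonormal columns spanning the "significant directions"), and takes
    a step along the reconstructed approximate gradient U @ (U.T @ grad(x)).
    """
    x = np.asarray(x0, dtype=float).copy()
    for _ in range(iters):
        # r directional derivatives = projection of the gradient onto span(U)
        g_low = U.T @ grad(x)          # shape (r,), one oracle call per column
        x = x - step * (U @ g_low)     # approximate gradient step in R^p
    return x
```

When the objective truly varies only within span(U), this update coincides with GD while costing r directional-oracle calls per iteration instead of p.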
format Article
id doaj.art-6ea85ff32aa7488f883684b55870207f
institution Directory Open Access Journal
issn 2694-085X
language English
publishDate 2023-01-01
publisher IEEE
series IEEE Open Journal of Control Systems
spelling doaj.art-6ea85ff32aa7488f883684b55870207f
DOI: 10.1109/OJCSYS.2023.3315088 (IEEE document 10250907)
Published: IEEE Open Journal of Control Systems (ISSN 2694-085X), vol. 2, pp. 380-395, 2023-01-01
Title: Low-Rank Gradient Descent
Authors:
Romain Cosson (https://orcid.org/0009-0004-8784-7112), National Institute for Research in Digital Science and Technology, Paris, France
Ali Jadbabaie (https://orcid.org/0000-0003-1122-3069), Laboratory for Information and Decision Systems, Massachusetts Institute of Technology, Cambridge, MA, USA
Anuran Makur (https://orcid.org/0000-0002-2978-8116), Department of Computer Science and School of Electrical and Computer Engineering, Purdue University, West Lafayette, IN, USA
Amirhossein Reisizadeh (https://orcid.org/0000-0002-1730-8402), Laboratory for Information and Decision Systems, Massachusetts Institute of Technology, Cambridge, MA, USA
Devavrat Shah (https://orcid.org/0000-0003-0737-3259), Laboratory for Information and Decision Systems, Massachusetts Institute of Technology, Cambridge, MA, USA
Keywords: Active subspace; first order optimization; low-rank functions; oracle complexity
Online Access: https://ieeexplore.ieee.org/document/10250907/
title Low-Rank Gradient Descent
topic Active subspace
first order optimization
low-rank functions
oracle complexity
url https://ieeexplore.ieee.org/document/10250907/