Skip to main content
CenXiv.org
此网站处于试运行阶段,支持我们!
我们衷心感谢所有贡献者的支持。
贡献
赞助
cenxiv logo > stat > arXiv:2205.01445

帮助 | 高级搜索

统计学 > 机器学习

arXiv:2205.01445 (stat)
[提交于 2022年5月3日 ]

标题: 特征学习的高维渐近性:一步梯度如何改进表示

标题: High-dimensional Asymptotics of Feature Learning: How One Gradient Step Improves the Representation

Authors:Jimmy Ba, Murat A. Erdogdu, Taiji Suzuki, Zhichao Wang, Denny Wu, Greg Yang
摘要: 我们研究两层神经网络中第一层参数 $\boldsymbol{W}$ 的首次梯度下降步骤:$f(\boldsymbol{x}) = \frac{1}{\sqrt{N}}\boldsymbol{a}^\top\sigma(\boldsymbol{W}^\top\boldsymbol{x})$,其中 $\boldsymbol{W}\in\mathbb{R}^{d\times N}, \boldsymbol{a}\in\mathbb{R}^{N}$ 随机初始化,训练目标是经验均方误差(MSE)损失:$\frac{1}{n}\sum_{i=1}^n (f(\boldsymbol{x}_i)-y_i)^2$。在比例渐近极限下,当 $n,d,N\to\infty$ 以相同速率变化,并且处于理想的学生-教师设定中时,我们证明首次梯度更新包含一个秩-1“尖峰”,这导致第一层权重与教师模型 $f^*$ 的线性组件之间的对齐。 为了刻画这种对齐的影响,我们计算了在单指标模型 $f^*$ 下,学习率 $\eta$ 条件下,对 $\boldsymbol{W}$ 进行一次梯度下降步长后的共轭核上的岭回归预测风险。 我们考虑了初始学习率 $\eta$ 的两种缩放方式。 对于较小的 $\eta$,我们建立了训练特征映射的高斯等价性质,并证明学习到的核比初始随机特征模型有所改进,但无法击败最佳线性模型。 而对于足够大的 $\eta$,我们证明了对于某些 $f^*$,在训练特征上相同的岭估计量可以超越这个“线性区域”,并且优于广泛的随机特征和旋转不变核。 我们的结果表明,即使一次梯度步长也可以比随机特征带来显著优势,并突出了学习率缩放在训练初始阶段的作用。
摘要: We study the first gradient descent step on the first-layer parameters $\boldsymbol{W}$ in a two-layer neural network: $f(\boldsymbol{x}) = \frac{1}{\sqrt{N}}\boldsymbol{a}^\top\sigma(\boldsymbol{W}^\top\boldsymbol{x})$, where $\boldsymbol{W}\in\mathbb{R}^{d\times N}, \boldsymbol{a}\in\mathbb{R}^{N}$ are randomly initialized, and the training objective is the empirical MSE loss: $\frac{1}{n}\sum_{i=1}^n (f(\boldsymbol{x}_i)-y_i)^2$. In the proportional asymptotic limit where $n,d,N\to\infty$ at the same rate, and an idealized student-teacher setting, we show that the first gradient update contains a rank-1 "spike", which results in an alignment between the first-layer weights and the linear component of the teacher model $f^*$. To characterize the impact of this alignment, we compute the prediction risk of ridge regression on the conjugate kernel after one gradient step on $\boldsymbol{W}$ with learning rate $\eta$, when $f^*$ is a single-index model. We consider two scalings of the first step learning rate $\eta$. For small $\eta$, we establish a Gaussian equivalence property for the trained feature map, and prove that the learned kernel improves upon the initial random features model, but cannot defeat the best linear model on the input. Whereas for sufficiently large $\eta$, we prove that for certain $f^*$, the same ridge estimator on trained features can go beyond this "linear regime" and outperform a wide range of random features and rotationally invariant kernels. Our results demonstrate that even one gradient step can lead to a considerable advantage over random features, and highlight the role of learning rate scaling in the initial phase of training.
评论: 71页
主题: 机器学习 (stat.ML) ; 机器学习 (cs.LG); 统计理论 (math.ST)
引用方式: arXiv:2205.01445 [stat.ML]
  (或者 arXiv:2205.01445v1 [stat.ML] 对于此版本)
  https://doi.org/10.48550/arXiv.2205.01445
通过 DataCite 发表的 arXiv DOI

提交历史

来自: Denny Wu [查看电子邮件]
[v1] 星期二, 2022 年 5 月 3 日 12:09:59 UTC (1,403 KB)
全文链接:

获取论文:

    查看标题为《》的 PDF
  • 查看中文 PDF
  • 查看 PDF
  • TeX 源代码
  • 其他格式
查看许可
当前浏览上下文:
stat
< 上一篇   |   下一篇 >
新的 | 最近的 | 2022-05
切换浏览方式为:
cs
cs.LG
math
math.ST
stat.ML
stat.TH

参考文献与引用

  • NASA ADS
  • 谷歌学术搜索
  • 语义学者
a 导出 BibTeX 引用 加载中...

BibTeX 格式的引用

×
数据由提供:

收藏

BibSonomy logo Reddit logo

文献和引用工具

文献资源探索 (什么是资源探索?)
连接的论文 (什么是连接的论文?)
Litmaps (什么是 Litmaps?)
scite 智能引用 (什么是智能引用?)

与本文相关的代码,数据和媒体

alphaXiv (什么是 alphaXiv?)
CatalyzeX 代码查找器 (什么是 CatalyzeX?)
DagsHub (什么是 DagsHub?)
Gotit.pub (什么是 GotitPub?)
Hugging Face (什么是 Huggingface?)
带有代码的论文 (什么是带有代码的论文?)
ScienceCast (什么是 ScienceCast?)

演示

复制 (什么是复制?)
Hugging Face Spaces (什么是 Spaces?)
TXYZ.AI (什么是 TXYZ.AI?)

推荐器和搜索工具

影响之花 (什么是影响之花?)
核心推荐器 (什么是核心?)
IArxiv 推荐器 (什么是 IArxiv?)
  • 作者
  • 地点
  • 机构
  • 主题

arXivLabs:与社区合作伙伴的实验项目

arXivLabs 是一个框架,允许合作伙伴直接在我们的网站上开发和分享新的 arXiv 特性。

与 arXivLabs 合作的个人和组织都接受了我们的价值观,即开放、社区、卓越和用户数据隐私。arXiv 承诺这些价值观,并且只与遵守这些价值观的合作伙伴合作。

有一个为 arXiv 社区增加价值的项目想法吗? 了解更多关于 arXivLabs 的信息.

这篇论文的哪些作者是支持者? | 禁用 MathJax (什么是 MathJax?)
  • 关于
  • 帮助
  • contact arXivClick here to contact arXiv 联系
  • 订阅 arXiv 邮件列表点击这里订阅 订阅
  • 版权
  • 隐私政策
  • 网络无障碍帮助
  • arXiv 运营状态
    通过...获取状态通知 email 或者 slack

京ICP备2025123034号