Neural tangent kernel – Wikipedia

Type of kernel induced by artificial neural networks

In the study of artificial neural networks (ANNs), the neural tangent kernel (NTK) is a kernel that describes the evolution of deep artificial neural networks during their training by gradient descent. It allows ANNs to be studied using theoretical tools from kernel methods.

For most common neural network architectures, in the limit of large layer width the NTK becomes constant. This enables simple closed form statements to be made about neural network predictions, training dynamics, generalization, and loss surfaces. For example, it guarantees that wide enough ANNs converge to a global minimum when trained to minimize an empirical loss. The NTK of large width networks is also related to several other large width limits of neural networks.

The NTK was introduced in 2018 by Arthur Jacot, Franck Gabriel and Clément Hongler.[1] It was implicit in contemporaneous work on overparameterization.[2][3][4][5]

Definition

Scalar output case

An ANN with scalar output consists of a family of functions $f(\cdot;\theta):\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R}$ parametrized by a vector of parameters $\theta\in\mathbb{R}^{P}$.

The NTK is a kernel $\Theta:\mathbb{R}^{n_{\mathrm{in}}}\times\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R}$ defined by

$$\Theta(x,y;\theta)=\sum_{p=1}^{P}\partial_{\theta_p}f(x;\theta)\,\partial_{\theta_p}f(y;\theta).$$

In the language of kernel methods, the NTK $\Theta$ is the kernel associated with the feature map $\left(x\mapsto\partial_{\theta_p}f(x;\theta)\right)_{p=1,\ldots,P}$.
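As a minimal sketch of this definition, the snippet below evaluates the NTK of a toy network as the inner product of parameter gradients. The network (two tanh hidden units, scalar input, so $n_{\mathrm{in}}=1$), the parameter values and the function names `f`, `grad_f` and `ntk` are all assumptions chosen for illustration.

```python
import math

def f(x, theta):
    """Toy scalar-output network: sum_j a_j * tanh(w_j * x + b_j).

    theta is a flat list [w_1, b_1, a_1, w_2, b_2, a_2, ...] (assumed layout).
    """
    return sum(a * math.tanh(w * x + b)
               for w, b, a in zip(theta[0::3], theta[1::3], theta[2::3]))

def grad_f(x, theta):
    """Gradient of f(x; theta) with respect to theta, i.e. the feature map."""
    g = []
    for w, b, a in zip(theta[0::3], theta[1::3], theta[2::3]):
        t = math.tanh(w * x + b)
        g += [a * (1 - t * t) * x,  # derivative with respect to w_j
              a * (1 - t * t),      # derivative with respect to b_j
              t]                    # derivative with respect to a_j
    return g

def ntk(x, y, theta):
    """Theta(x, y; theta) = sum_p d_p f(x; theta) * d_p f(y; theta)."""
    return sum(gx * gy for gx, gy in zip(grad_f(x, theta), grad_f(y, theta)))

theta = [0.5, -0.1, 1.0, -0.3, 0.2, 0.7]  # arbitrary illustrative parameters
xs = [-1.0, 0.0, 2.0]
gram = [[ntk(x, y, theta) for y in xs] for x in xs]  # 3 x 3 Gram matrix
```

Because the kernel is an inner product of feature vectors, the Gram matrix `gram` is symmetric and positive semi-definite for any choice of `theta`.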

Vector output case

An ANN with vector output of size $n_{\mathrm{out}}$ consists of a family of functions $f(\cdot;\theta):\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R}^{n_{\mathrm{out}}}$ parametrized by a vector of parameters $\theta\in\mathbb{R}^{P}$.

In this case, the NTK $\Theta:\mathbb{R}^{n_{\mathrm{in}}}\times\mathbb{R}^{n_{\mathrm{in}}}\to\mathcal{M}_{n_{\mathrm{out}}}(\mathbb{R})$ is a matrix-valued kernel, with values in the space of $n_{\mathrm{out}}\times n_{\mathrm{out}}$ matrices, defined by

$$\Theta_{k,l}(x,y;\theta)=\sum_{p=1}^{P}\partial_{\theta_p}f_k(x;\theta)\,\partial_{\theta_p}f_l(y;\theta).$$
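In the vector case each kernel value is an $n_{\mathrm{out}}\times n_{\mathrm{out}}$ matrix whose $(k,l)$ entry pairs the parameter gradient of output $k$ at $x$ with that of output $l$ at $y$. The sketch below illustrates this on a hypothetical two-output toy network. The Jacobian is approximated by central finite differences, which is an illustrative shortcut rather than part of the definition, and all names and values are assumptions.

```python
import math

def f(x, theta):
    """Toy network with n_out = 2 outputs sharing one tanh hidden unit."""
    h = math.tanh(theta[2] * x)
    return [theta[0] * h, theta[1] * h]

def jacobian(x, theta, eps=1e-6):
    """J[k][p] ~ d f_k(x; theta) / d theta_p, by central finite differences."""
    cols = []
    for p in range(len(theta)):
        plus = list(theta); plus[p] += eps
        minus = list(theta); minus[p] -= eps
        fp, fm = f(x, plus), f(x, minus)
        cols.append([(a - b) / (2 * eps) for a, b in zip(fp, fm)])
    # cols is indexed [p][k]; transpose to [k][p]
    return [[cols[p][k] for p in range(len(theta))] for k in range(len(cols[0]))]

def ntk(x, y, theta):
    """Matrix-valued NTK: K[k][l] = sum_p J_x[k][p] * J_y[l][p]."""
    Jx, Jy = jacobian(x, theta), jacobian(y, theta)
    return [[sum(a * b for a, b in zip(Jx[k], Jy[l])) for l in range(len(Jy))]
            for k in range(len(Jx))]

theta = [1.0, -2.0, 0.5]   # arbitrary illustrative parameters
K = ntk(0.3, 0.3, theta)   # 2 x 2 matrix
```

The off-diagonal entry `K[0][1]` is nonzero here only because both outputs depend on the shared hidden weight `theta[2]`.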

Derivation

When optimizing the parameters $\theta\in\mathbb{R}^{P}$ of an ANN to minimize an empirical loss through gradient descent, the NTK governs the dynamics of the ANN output function $f_{\theta}$ throughout the training.

Scalar output case

For a dataset $(x_i)_{i=1,\ldots,n}\subset\mathbb{R}^{n_{\mathrm{in}}}$ with scalar labels $(z_i)_{i=1,\ldots,n}\subset\mathbb{R}$ and a loss function $c:\mathbb{R}\times\mathbb{R}\to\mathbb{R}$, the associated empirical loss, defined on functions $f:\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R}$, is given by

$$\mathcal{C}(f)=\sum_{i=1}^{n}c\left(f(x_i),z_i\right).$$

When the ANN $f(\cdot;\theta):\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R}$ is trained to fit the dataset (i.e. minimize $\mathcal{C}$) via continuous-time gradient descent, the parameters $(\theta(t))_{t\geq 0}$ evolve through the ordinary differential equation:

$$\partial_t\theta(t)=-\nabla\mathcal{C}\left(f(\cdot;\theta(t))\right).$$

During training, the ANN output function follows an evolution differential equation given in terms of the NTK:

$$\partial_t f(x;\theta(t))=-\sum_{i=1}^{n}\Theta(x,x_i;\theta(t))\,\partial_w c(w,z_i)\big|_{w=f(x_i;\theta(t))}.$$

This equation shows how the NTK drives the dynamics of $f(\cdot;\theta(t))$ in the space of functions $\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R}$ during training.
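The evolution equation follows from the chain rule applied to the parameter dynamics. As a worked sketch, assume the gradient flow $\partial_t\theta_p(t)=-\partial_{\theta_p}\mathcal{C}(f(\cdot;\theta(t)))$, the empirical loss $\mathcal{C}(f)=\sum_{i}c(f(x_i),z_i)$, and the NTK written as a sum of products of parameter derivatives:

```latex
\begin{aligned}
\partial_t f(x;\theta(t))
  &= \sum_{p=1}^{P} \partial_{\theta_p} f(x;\theta(t))\,\partial_t\theta_p(t) \\
  &= -\sum_{p=1}^{P} \partial_{\theta_p} f(x;\theta(t))\,
      \partial_{\theta_p}\mathcal{C}\left(f(\cdot;\theta(t))\right) \\
  &= -\sum_{p=1}^{P} \partial_{\theta_p} f(x;\theta(t))
      \sum_{i=1}^{n} \partial_w c(w,z_i)\big|_{w=f(x_i;\theta(t))}\,
      \partial_{\theta_p} f(x_i;\theta(t)) \\
  &= -\sum_{i=1}^{n} \Theta(x,x_i;\theta(t))\,
      \partial_w c(w,z_i)\big|_{w=f(x_i;\theta(t))}.
\end{aligned}
```

The last step exchanges the two sums and recognizes $\sum_{p}\partial_{\theta_p}f(x;\theta)\,\partial_{\theta_p}f(x_i;\theta)$ as $\Theta(x,x_i;\theta)$.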

Vector output case

For a dataset $(x_i)_{i=1,\ldots,n}\subset\mathbb{R}^{n_{\mathrm{in}}}$ with vector labels $(z_i)_{i=1,\ldots,n}\subset\mathbb{R}^{n_{\mathrm{out}}}$ and a loss function $c:\mathbb{R}^{n_{\mathrm{out}}}\times\mathbb{R}^{n_{\mathrm{out}}}\to\mathbb{R}$, the corresponding empirical loss on functions $f:\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R}^{n_{\mathrm{out}}}$ is defined by

$$\mathcal{C}(f)=\sum_{i=1}^{n}c\left(f(x_i),z_i\right).$$

The training of $f_{\theta(t)}$ through continuous-time gradient descent yields the following evolution in function space driven by the NTK:

$$\partial_t f_k(x;\theta(t))=-\sum_{i=1}^{n}\sum_{l=1}^{n_{\mathrm{out}}}\Theta_{k,l}(x,x_i;\theta(t))\,\partial_{w_l}c(w,z_i)\big|_{w=f(x_i;\theta(t))}.$$

Interpretation

The NTK $\Theta(x,x_i;\theta)$ represents the influence of the loss gradient $\partial_w c(w,z_i)\big|_{w=f(x_i;\theta)}$ with respect to example $i$ on the evolution of ANN output $f(x;\theta)$ through a gradient descent step: in the scalar case, for a step of size $\epsilon$, this reads

$$f(x;\theta(t+\epsilon))-f(x;\theta(t))\approx-\epsilon\sum_{i=1}^{n}\Theta(x,x_i;\theta(t))\,\partial_w c(w,z_i)\big|_{w=f(x_i;\theta(t))}.$$

In particular, each data point $x_i$ influences the evolution of the output $f(x;\theta)$ for each $x$ throughout the training, in a way that is captured by the NTK $\Theta(x,x_i;\theta)$.
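This first-order relation can be checked numerically. The sketch below takes one small gradient descent step on a hypothetical one-unit tanh network and compares the actual change of the output at a test point with the NTK-weighted sum of loss gradients. The squared loss $c(w,z)=\tfrac{1}{2}(w-z)^2$, the data and the parameter values are all assumptions.

```python
import math

def f(x, theta):
    w, b, a = theta
    return a * math.tanh(w * x + b)

def grad_f(x, theta):
    w, b, a = theta
    t = math.tanh(w * x + b)
    return [a * (1 - t * t) * x, a * (1 - t * t), t]

def ntk(x, y, theta):
    return sum(p * q for p, q in zip(grad_f(x, theta), grad_f(y, theta)))

# Assumed toy data; with c(w, z) = (w - z)^2 / 2 the loss gradient is w - z.
xs, zs = [-1.0, 0.5, 2.0], [0.3, -0.2, 1.0]
theta = [0.8, 0.1, -0.5]
eps = 1e-4  # gradient descent step size

# One gradient descent step on C(theta) = sum_i c(f(x_i; theta), z_i).
residuals = [f(xi, theta) - zi for xi, zi in zip(xs, zs)]
grad_C = [sum(r * grad_f(xi, theta)[p] for r, xi in zip(residuals, xs))
          for p in range(len(theta))]
theta_next = [t - eps * g for t, g in zip(theta, grad_C)]

x = 1.3  # test point outside the training set
actual = f(x, theta_next) - f(x, theta)
predicted = -eps * sum(ntk(x, xi, theta) * r for xi, r in zip(xs, residuals))
```

The two quantities `actual` and `predicted` differ only by terms of second order in `eps`.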

Large-width limit

Recent theoretical and empirical work in deep learning has shown the performance of ANNs to strictly improve as their layer widths grow larger.[6][7] For various ANN architectures, the NTK yields precise insight into the training in this large-width regime.[1][8][9][10][11][12]

Wide fully-connected ANNs have a deterministic NTK, which remains constant throughout training

Consider an ANN with fully-connected layers $\ell=0,\ldots,L$ of widths $n_0=n_{\mathrm{in}},n_1,\ldots,n_L=n_{\mathrm{out}}$, so that $f(\cdot;\theta)=R_{L-1}\circ\cdots\circ R_0$, where $R_{\ell}=\sigma\circ A_{\ell}$ is the composition of an affine transformation $A_{\ell}$ with the pointwise application of a nonlinearity $\sigma:\mathbb{R}\to\mathbb{R}$, and where $\theta$ parametrizes the maps $A_0,\ldots,A_{L-1}$. The parameters $\theta\in\mathbb{R}^{P}$ are initialized randomly, in an independent, identically distributed way.

As the widths grow, the NTK’s scale is affected by the exact parametrization of the $A_{\ell}$’s and by the parameter initialization. This motivates the so-called NTK parametrization $A_{\ell}(x)=\frac{1}{\sqrt{n_{\ell}}}W^{(\ell)}x+b^{(\ell)}$. This parametrization ensures that if the parameters $\theta\in\mathbb{R}^{P}$ are initialized as standard normal variables, the NTK has a finite nontrivial limit. In the large-width limit, the NTK converges to a deterministic (non-random) limit $\Theta_{\infty}$, which stays constant in time.
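The concentration of the NTK at initialization can be observed empirically. The sketch below draws the kernel value of a one-hidden-layer ReLU network several times at two widths and records how much the draws differ. The architecture (scalar input, linear scalar output with a bias, NTK parametrization with standard normal parameters), the widths and the seed are assumptions made for illustration.

```python
import random

def empirical_ntk(x, y, width, rng):
    """NTK at initialization of the assumed network

        f(x) = b1 + (1 / sqrt(width)) * sum_j a_j * relu(w_j * x + b_j),

    with all parameters drawn i.i.d. from N(0, 1).
    """
    total = 1.0  # contribution of the output bias b1
    for _ in range(width):
        w, b, a = rng.gauss(0, 1), rng.gauss(0, 1), rng.gauss(0, 1)
        ux, uy = w * x + b, w * y + b
        hx, hy = max(ux, 0.0), max(uy, 0.0)    # relu activations
        dx, dy = float(ux > 0), float(uy > 0)  # relu derivatives
        total += hx * hy / width                        # gradients w.r.t. a_j
        total += a * a * dx * dy * (x * y + 1) / width  # gradients w.r.t. w_j, b_j
    return total

rng = random.Random(0)
x, y = 0.5, -1.0
spread = {}
for width in (10, 10000):
    samples = [empirical_ntk(x, y, width, rng) for _ in range(5)]
    spread[width] = max(samples) - min(samples)  # variability across draws
```

Under these assumptions the spread across random initializations shrinks as the width grows, in line with the kernel approaching a deterministic limit.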

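The NTK $\Theta_{\infty}$ is explicitly given by $\Theta_{\infty}=\Theta^{(L)}$, where $\Theta^{(L)}$ is determined by the set of recursive equations:

$$\begin{aligned}
\Theta^{(1)}(x,y)&=\Sigma^{(1)}(x,y),\\
\Sigma^{(1)}(x,y)&=\frac{1}{n_{\mathrm{in}}}x^{T}y+1,\\
\Theta^{(\ell+1)}(x,y)&=\Theta^{(\ell)}(x,y)\,\dot{\Sigma}^{(\ell+1)}(x,y)+\Sigma^{(\ell+1)}(x,y),\\
\Sigma^{(\ell+1)}(x,y)&=L_{\Sigma^{(\ell)}}^{\sigma}(x,y)+1,\\
\dot{\Sigma}^{(\ell+1)}(x,y)&=L_{\Sigma^{(\ell)}}^{\dot{\sigma}}(x,y),
\end{aligned}$$

where $\dot{\sigma}$ is the derivative of $\sigma$ and $L_{K}^{f}$ denotes the kernel defined in terms of the Gaussian expectation:

$$L_{K}^{f}(x,y)=\mathbb{E}_{(X,Y)\sim\mathcal{N}\left(0,\begin{pmatrix}K(x,x)&K(x,y)\\K(y,x)&K(y,y)\end{pmatrix}\right)}\left[f(X)\,f(Y)\right].$$

In this formula the kernels $\Sigma^{(\ell)}$ are the ANN’s so-called activation kernels.[13][14][15]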
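For the ReLU nonlinearity the Gaussian expectations have closed forms (the arc-cosine kernels of Cho and Saul [13]), so the limiting kernel can be evaluated directly. The sketch below assumes the recursion $\Theta^{(1)}=\Sigma^{(1)}$, $\Sigma^{(1)}(x,y)=x^{T}y/n_{\mathrm{in}}+1$, $\Sigma^{(\ell+1)}=L^{\sigma}_{\Sigma^{(\ell)}}+1$, $\dot{\Sigma}^{(\ell+1)}=L^{\dot{\sigma}}_{\Sigma^{(\ell)}}$ and $\Theta^{(\ell+1)}=\Theta^{(\ell)}\dot{\Sigma}^{(\ell+1)}+\Sigma^{(\ell+1)}$. The function names are hypothetical.

```python
import math

def relu_expectations(kxx, kxy, kyy):
    """For (X, Y) ~ N(0, [[kxx, kxy], [kxy, kyy]]) return
    E[relu(X) relu(Y)] and E[relu'(X) relu'(Y)] in closed form."""
    norm = math.sqrt(kxx * kyy)
    cos_phi = max(-1.0, min(1.0, kxy / norm))  # clamp against rounding error
    phi = math.acos(cos_phi)
    e_sigma = norm * (math.sin(phi) + (math.pi - phi) * cos_phi) / (2 * math.pi)
    e_dsigma = (math.pi - phi) / (2 * math.pi)
    return e_sigma, e_dsigma

def ntk_relu(x, y, depth):
    """Theta^(depth)(x, y) under the assumed recursion, for ReLU."""
    n_in = len(x)
    dot = lambda u, v: sum(a * b for a, b in zip(u, v))
    # Sigma^(1) evaluated on the pairs (x, x), (x, y) and (y, y)
    sxx, sxy, syy = (dot(x, x) / n_in + 1,
                     dot(x, y) / n_in + 1,
                     dot(y, y) / n_in + 1)
    theta = sxy  # Theta^(1) = Sigma^(1)
    for _ in range(depth - 1):
        e_xy, de_xy = relu_expectations(sxx, sxy, syy)
        # On the diagonal E[relu(X)^2] = K(x, x) / 2.
        sxx, sxy, syy = sxx / 2 + 1, e_xy + 1, syy / 2 + 1
        theta = theta * de_xy + sxy  # Theta^(l+1) = Theta^(l) * dSigma + Sigma
    return theta

value = ntk_relu([1.0, 0.0], [0.0, 1.0], depth=3)
```

Only the three entries of each $2\times 2$ covariance need to be tracked from layer to layer, which keeps the recursion at constant cost per layer for a given pair of inputs.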

Wide fully connected networks are linear in their parameters throughout training

The NTK describes the evolution of neural networks under gradient descent in function space. Dual to this perspective is an understanding of how neural networks evolve in parameter space, since the NTK is defined in terms of the gradient of the ANN’s outputs with respect to its parameters. In the infinite width limit, the connection between these two perspectives becomes especially interesting. The NTK remaining constant throughout training at large widths co-occurs with the ANN being well described throughout training by its first order Taylor expansion around its parameters at initialization:[10]

$$f(x;\theta_0+\Delta\theta)=f(x;\theta_0)+\Delta\theta\cdot\nabla_{\theta}f(x;\theta_0)+\mathcal{O}\left(\min(n_1,\ldots,n_{L-1})^{-1/2}\right).$$
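As a small illustration of the linearized model, the sketch below builds the first-order Taylor expansion `f_lin` of a hypothetical one-unit tanh network around its initial parameters `theta0` and measures how far it is from the true network along a fixed direction. The network, the direction and the scales are assumptions. At fixed width this only illustrates the Taylor expansion itself, not the large-width statement.

```python
import math

def f(x, theta):
    w, b, a = theta
    return a * math.tanh(w * x + b)

def grad_f(x, theta):
    w, b, a = theta
    t = math.tanh(w * x + b)
    return [a * (1 - t * t) * x, a * (1 - t * t), t]

def f_lin(x, theta, theta0):
    """First-order Taylor expansion of f(x; .) around theta0."""
    g = grad_f(x, theta0)
    return f(x, theta0) + sum(gp * (tp - t0p)
                              for gp, tp, t0p in zip(g, theta, theta0))

theta0 = [0.8, 0.1, -0.5]     # assumed parameters at initialization
direction = [0.3, -0.2, 0.4]  # assumed direction of parameter change

def lin_error(scale, x=1.3):
    """Gap between the network and its linearization at theta0 + scale * direction."""
    theta = [t + scale * d for t, d in zip(theta0, direction)]
    return abs(f(x, theta) - f_lin(x, theta, theta0))

errors = [lin_error(s) for s in (0.1, 0.05, 0.025)]
```

Halving the parameter displacement roughly quarters the gap, as expected from a second-order remainder. The gradient of `f_lin` with respect to `theta` is `grad_f(x, theta0)` for every `theta`, so the tangent kernel of the linearized model does not change during training.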

Other architectures

The NTK can be studied for various ANN architectures,[11] in particular convolutional neural networks (CNNs),[16] recurrent neural networks (RNNs) and transformers.[17] In such settings, the large-width limit corresponds to letting the number of parameters grow, while keeping the number of layers fixed: for CNNs, this involves letting the number of channels grow.

Applications

Convergence to a global minimum

For a convex loss functional $\mathcal{C}$ with a global minimum, if the NTK remains positive-definite during training, the loss of the ANN $\mathcal{C}\left(f(\cdot;\theta(t))\right)$ converges to that minimum as $t\to\infty$. This positive-definiteness property has been shown in a number of cases, yielding the first proofs that large-width ANNs converge to global minima during training.[1][8][18]
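A minimal numerical sketch of this convergence, assuming the kernel is frozen, the loss is the squared loss $c(w,z)=\tfrac{1}{2}(w-z)^2$, and the continuous-time dynamics are discretized with explicit Euler steps. The Gram matrix, labels and step size below are made-up illustrative values.

```python
# Assumed frozen Gram matrix Theta(x_i, x_j) on n = 3 training points
# (symmetric and strictly diagonally dominant, hence positive-definite).
K = [[2.0, 0.5, 0.1],
     [0.5, 1.5, 0.3],
     [0.1, 0.3, 1.0]]
z = [1.0, -1.0, 0.5]   # labels
f = [0.0, 0.0, 0.0]    # network outputs on the training points at t = 0 (assumed)

def loss(f):
    return 0.5 * sum((fi - zi) ** 2 for fi, zi in zip(f, z))

dt, losses = 0.05, []
for step in range(400):
    losses.append(loss(f))
    residual = [fi - zi for fi, zi in zip(f, z)]
    # Euler step of d f(x_i)/dt = -sum_j K[i][j] * (f(x_j) - z_j)
    f = [fi - dt * sum(K[i][j] * residual[j] for j in range(3))
         for i, fi in enumerate(f)]
losses.append(loss(f))
```

With a positive-definite kernel every component of the residual decays, so the recorded losses decrease toward the global minimum, which is zero for this loss.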

Kernel methods

The NTK gives a rigorous connection between the inference performed by infinite-width ANNs and that performed by kernel methods: when the loss function is the least-squares loss, the inference performed by an ANN is in expectation equal to the kernel ridge regression (with zero ridge) with respect to the NTK $\Theta_{\infty}$. This suggests that the performance of large ANNs in the NTK parametrization can be replicated by kernel methods for suitably chosen kernels.[1][11]
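Kernel regression with zero ridge predicts $\Theta(x,X)\,\Theta(X,X)^{-1}z$ on training inputs $X$ with labels $z$. The sketch below applies this with the closed-form two-layer ReLU kernel for scalar inputs, assuming the usual recursion with unit bias variance. The kernel choice, the data and the small Gaussian-elimination solver are illustrative assumptions.

```python
import math

def ntk_relu_2(x, y):
    """Assumed two-layer infinite-width ReLU NTK for scalar inputs."""
    sxx, sxy, syy = x * x + 1, x * y + 1, y * y + 1
    norm = math.sqrt(sxx * syy)
    c = max(-1.0, min(1.0, sxy / norm))
    phi = math.acos(c)
    e = norm * (math.sin(phi) + (math.pi - phi) * c) / (2 * math.pi)
    de = (math.pi - phi) / (2 * math.pi)
    return sxy * de + e + 1

def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for k in range(col, n + 1):
                M[r][k] -= factor * M[col][k]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (M[i][n] - sum(M[i][j] * x[j] for j in range(i + 1, n))) / M[i][i]
    return x

xs = [-1.0, 0.0, 1.5]   # assumed training inputs
zs = [0.5, -0.3, 1.0]   # assumed training labels
gram = [[ntk_relu_2(a, b) for b in xs] for a in xs]
alpha = solve(gram, zs)  # alpha = Theta(X, X)^{-1} z

def predict(x):
    """Ridgeless kernel regression prediction at x."""
    return sum(a * ntk_relu_2(x, xi) for a, xi in zip(alpha, xs))
```

With zero ridge the predictor interpolates the training labels, so `predict(xs[i])` reproduces `zs[i]` up to rounding.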

Software libraries

Neural Tangents is a free and open-source Python library used for computing and doing inference with the infinite width NTK and neural network Gaussian process (NNGP) corresponding to various common ANN architectures.[19]

References

  1. Jacot, Arthur; Gabriel, Franck; Hongler, Clement (2018), Bengio, S.; Wallach, H.; Larochelle, H.; Grauman, K. (eds.), “Neural Tangent Kernel: Convergence and Generalization in Neural Networks” (PDF), Advances in Neural Information Processing Systems 31, Curran Associates, Inc., pp. 8571–8580, arXiv:1806.07572, Bibcode:2018arXiv180607572J, retrieved 2019-11-27
  2. Li, Yuanzhi; Liang, Yingyu (2018). “Learning overparameterized neural networks via stochastic gradient descent on structured data”. Advances in Neural Information Processing Systems. arXiv:1808.01204.
  3. Allen-Zhu, Zeyuan; Li, Yuanzhi; Song, Zhao (2018). “A convergence theory for deep learning via overparameterization”. International Conference on Machine Learning. arXiv:1811.03962.
  4. Du, Simon S; Zhai, Xiyu; Poczos, Barnabas; Singh, Aarti (2019). “Gradient descent provably optimizes over-parameterized neural networks”. International Conference on Learning Representations. arXiv:1810.02054.
  5. Zou, Difan; Cao, Yuan; Zhou, Dongruo; Gu, Quanquan (2020). “Gradient descent optimizes over-parameterized deep ReLU networks”. Machine Learning. 109: 467–492.
  6. Novak, Roman; Bahri, Yasaman; Abolafia, Daniel A.; Pennington, Jeffrey; Sohl-Dickstein, Jascha (2018-02-15). “Sensitivity and Generalization in Neural Networks: an Empirical Study”. arXiv:1802.08760. Bibcode:2018arXiv180208760N.
  7. Canziani, Alfredo; Paszke, Adam; Culurciello, Eugenio (2016-11-04). “An Analysis of Deep Neural Network Models for Practical Applications”. arXiv:1605.07678. Bibcode:2016arXiv160507678C.
  8. Allen-Zhu, Zeyuan; Li, Yuanzhi; Song, Zhao (2018-11-09). “A Convergence Theory for Deep Learning via Over-Parameterization”. International Conference on Machine Learning: 242–252. arXiv:1811.03962.
  9. Du, Simon; Lee, Jason; Li, Haochuan; Wang, Liwei; Zhai, Xiyu (2019-05-24). “Gradient Descent Finds Global Minima of Deep Neural Networks”. International Conference on Machine Learning: 1675–1685. arXiv:1811.03804.
  10. Lee, Jaehoon; Xiao, Lechao; Schoenholz, Samuel S.; Bahri, Yasaman; Novak, Roman; Sohl-Dickstein, Jascha; Pennington, Jeffrey (2020). “Wide neural networks of any depth evolve as linear models under gradient descent”. Journal of Statistical Mechanics: Theory and Experiment. 2020 (12): 124002. arXiv:1902.06720. Bibcode:2020JSMTE2020l4002L. doi:10.1088/1742-5468/abc62b. S2CID 62841516.
  11. Arora, Sanjeev; Du, Simon S; Hu, Wei; Li, Zhiyuan; Salakhutdinov, Russ R; Wang, Ruosong (2019), “On Exact Computation with an Infinitely Wide Neural Net”, NeurIPS: 8139–8148, arXiv:1904.11955
  12. Huang, Jiaoyang; Yau, Horng-Tzer (2019-09-17). “Dynamics of Deep Neural Networks and Neural Tangent Hierarchy”. arXiv:1909.08156 [cs.LG].
  13. Cho, Youngmin; Saul, Lawrence K. (2009), Bengio, Y.; Schuurmans, D.; Lafferty, J. D.; Williams, C. K. I. (eds.), “Kernel Methods for Deep Learning” (PDF), Advances in Neural Information Processing Systems 22, Curran Associates, Inc., pp. 342–350, retrieved 2019-11-27
  14. Daniely, Amit; Frostig, Roy; Singer, Yoram (2016), Lee, D. D.; Sugiyama, M.; Luxburg, U. V.; Guyon, I. (eds.), “Toward Deeper Understanding of Neural Networks: The Power of Initialization and a Dual View on Expressivity” (PDF), Advances in Neural Information Processing Systems 29, Curran Associates, Inc., pp. 2253–2261, arXiv:1602.05897, Bibcode:2016arXiv160205897D, retrieved 2019-11-27
  15. Lee, Jaehoon; Bahri, Yasaman; Novak, Roman; Schoenholz, Samuel S.; Pennington, Jeffrey; Sohl-Dickstein, Jascha (2018-02-15). “Deep Neural Networks as Gaussian Processes”.
  16. Yang, Greg (2019-02-13). “Scaling Limits of Wide Neural Networks with Weight Sharing: Gaussian Process Behavior, Gradient Independence, and Neural Tangent Kernel Derivation”. arXiv:1902.04760 [cs.NE].
  17. Hron, Jiri; Bahri, Yasaman; Sohl-Dickstein, Jascha; Novak, Roman (2020-06-18). “Infinite attention: NNGP and NTK for deep attention networks”. International Conference on Machine Learning. 2020. arXiv:2006.10540. Bibcode:2020arXiv200610540H.
  18. Allen-Zhu, Zeyuan; Li, Yuanzhi; Song, Zhao (2018-10-29). “On the convergence rate of training recurrent neural networks”. NeurIPS. arXiv:1810.12065.
  19. Novak, Roman; Xiao, Lechao; Hron, Jiri; Lee, Jaehoon; Alemi, Alexander A.; Sohl-Dickstein, Jascha; Schoenholz, Samuel S. (2019-12-05), “Neural Tangents: Fast and Easy Infinite Neural Networks in Python”, International Conference on Learning Representations (ICLR), vol. 2020, arXiv:1912.02803, Bibcode:2019arXiv191202803N
