硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

機器之心原創

作者:蔣思源

這是一篇神奇的論文,以前一層一層疊加的神經網路似乎突然變得連續了,反向傳播也似乎不再需要一點一點往前傳、一層一層更新引數了。

在最近結束的 NeruIPS 2018 中,來自多倫多大學的陳天琦等研究者成為最佳論文的獲得者。他們提出了一種名為神經常微分方程的模型,這是新一類的深度神經網路。神經常微分方程不拘於對已有架構的修修補補,它完全從另外一個角度考慮如何以連續的方式藉助神經網路對資料建模。在陳天琦的講解下,機器之心將向各位讀者介紹這一令人興奮的神經網路新家族。

在與機器之心的訪談中,陳天琦的導師 David Duvenaud 教授談起這位學生也是讚不絕口。Duvenaud 教授認為陳天琦不僅是位理解能力超強的學生,鑽研起問題來也相當認真透徹。他說:「天琦很喜歡提出新想法,他有時會在我提出建議一週後再反饋:『老師你之前建議的方法不太合理。但是我研究出另外一套合理的方法,結果我也做出來了。』」Ducenaud 教授評價道,現如今人工智慧熱度有增無減,教授能找到優秀博士生基本如同「雞生蛋還是蛋生雞」的問題,頂尖學校的教授通常能快速地招納到博士生,「我很幸運地能在事業起步階段就遇到陳天琦如此優秀的學生。」

本文主要介紹神經常微分方程背後的細想與直觀理解,很多延伸的概念並沒有詳細解釋,例如大大降低計算複雜度的連續型流模型和官方 PyTorch 程式碼實現等。這一篇文章重點對比了神經常微分方程(ODEnet)與殘差網路,我們不僅能透過這一部分了解如何從熟悉的 ResNet 演化到 ODEnet,同時還能還有新模型的前向傳播過程和特點。

其次文章比較關注 ODEnet 的反向傳播過程,即如何透過解常微分方程直接把梯度求出來。這一部分與傳統的反向傳播有很多不同,因此先理解反向傳播再看原始碼可能是更好的選擇。值得注意的是,ODEnet 的反傳只有常數級的記憶體佔用成本。

ODEnet 的 PyTorch 實現地址:https://github。com/rtqichen/torchdiffeq

ODEnet 論文地址:https://arxiv。org/abs/1806。07366

如下展示了文章的主要結構:

常微分方程

從殘差網路到微分方程

從微分方程到殘差網路

網路對比

神經常微分方程

反向傳播

反向傳播怎麼做

連續型的歸一化流

變數代換定理

作者:蔣思源

其實初讀這篇論文,還是有一些疑惑的,因為很多概念都不是我們所熟知的。因此如果想要了解這個模型,那麼同學們,你們首先需要回憶高數上的微分方程。有了這樣的概念後,我們就能愉快地連續化神經網路層級,並構建完整的神經常微分方程。

常微分方程即只包含單個自變數 x、未知函式 f(x) 和未知函式的導數 f‘(x) 的等式,所以說 f’(x) = 2x 也算一個常微分方程。但更常見的可以表示為 df(x)/dx = g(f(x), x),其中 g(f(x), x) 表示由 f(x) 和 x 組成的某個表示式,這個式子是擴充套件一般神經網路的關鍵,我們在後面會討論這個式子怎麼就連續化了神經網路層級。

一般對於常微分方程,我們希望解出未知的 f(x),例如 f‘(x) = 2x 的通解為 f(x)=x^2 +C,其中 C 表示任意常數。而在工程中更常用數值解,即給定一個初值 f(x_0),我們希望解出末值 f(x_1),這樣並不需要解出完整的 f(x),只需要一步步逼近它就行了。

現在回過頭來討論我們熟悉的神經網路,本質上不論是全連線、迴圈還是卷積網路,它們都類似於一個非常複雜的複合函式,複合的次數就等於層級的深度。例如兩層全連線網路可以表示為 Y=f(f(X, θ1), θ2),因此每一個神經網路層級都類似於萬能函式逼近器。

因為整體是複合函式,所以很容易接受複合函式的求導方法:鏈式法則,並將梯度從最外一層的函式一點點先向裡面層級的函式傳遞,並且每傳到一層函式,就可以更新該層的引數 θ。現在問題是,我們前向傳播過後需要保留所有層的啟用值,並在沿計算路徑反傳梯度時利用這些啟用值。這對記憶體的佔用非常大,因此也就限制了深度模型的訓練過程。

神經常微分方程走了另一條道路,它使用神經網路引數化隱藏狀態的導數,而不是如往常那樣直接引數化隱藏狀態。這裡引數化隱藏狀態的導數就類似構建了連續性的層級與引數,而不再是離散的層級。因此引數也是一個連續的空間,我們不需要再分層傳播梯度與更新引數。總而言之,神經微分方程在前向傳播過程中不儲存任何中間結果,因此它只需要近似常數級的記憶體成本。

常微分方程

殘差網路是一類特殊的卷積網路,它透過殘差連線而解決了梯度反傳問題,即當神經網路層級非常深時,梯度仍然能有效傳回輸入端。下圖為原論文中殘差模組的結構,殘差塊的輸出結合了輸入資訊與內部卷積運算的輸出資訊,這種殘差連線或恆等對映表示深層模型至少不能低於淺層網路的準確度。這樣的殘差模組堆疊幾十上百個就是非常深的殘差神經網路。

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

如果我們將上面的殘差模組更加形式化地表示為以下方程:

其中 h_t 是第 t 層隱藏單元的輸出值,f 為透過θ_t 引數化的神經網路。該方程式表示上圖的整個殘差模組,如果我們其改寫為殘差的形式,即 h_t+1 - h_t = f(h_t, θ_t )。那麼我們可以看到神經網路 f 引數化的是隱藏層之間的殘差,f 同樣不是直接引數化隱藏層。

ResNet 假設層級的離散的,第 t 層到第 t+1 層之間是無定義的。那麼如果這中間是有定義的呢?殘差項 h_t0 - h_t1 是不是就應該非常小,以至於接近無窮小?這裡我們少考慮了分母,即殘差項應該表示為 (h_t+1 - h_t )/1,分母的 1 表示兩個離散的層級之間相差 1。所以再一次考慮層級間有定義,我們會發現殘差項最終會收斂到隱藏層對 t 的導數,而神經網路實際上引數化的就是這個導數。

所以若我們在層級間加入更多的層,且最終趨向於添加了無窮層時,神經網路就連續化了。可以說殘差網路其實就是連續變換的尤拉離散化,是一個特例,我們可以將這種連續變換形式化地表示為一個常微分方程:

如果從導數定義的角度來看,當 t 的變化趨向於無窮小時,隱藏狀態的變化 dh(t) 可以透過神經網路建模。當 t 從初始一點點變化到終止,那麼 h(t) 的改變最終就代表著前向傳播結果。這樣利用神經網路引數化隱藏層的導數,就確確實實連續化了神經網路層級。

現在若能得出該常微分方程的數值解,那麼就相當於完成了前向傳播。具體而言,若 h(0)=X 為輸入影象,那麼終止時刻的隱藏層輸出 h(T) 就為推斷結果。這是一個常微分方程的初值問題,可以直接透過黑箱的常微分方程求解器(ODE Solver)解出來。而這樣的求解器又能控制數值誤差,因此我們總能在計算力和模型準確度之間做權衡。

形式上來說,現在就需要變換方程 (2) 以求出數值解,即給定初始狀態 h(t_0) 和神經網路的情況下求出終止狀態 h(t_1):

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

如上所示,常微分方程的數值解 h(t_1) 需要求神經網路 f 從 t_0 到 t_1 的積分。我們完全可以利用 ODE solver 解出這個值,這在數學物理領域已經有非常成熟的解法,我們只需要將其當作一個黑盒工具使用就行了。

從殘差網路到微分方程

前面提到過殘差網路是神經常微分方程的特例,可以說殘差網路是尤拉方法的離散化。兩三百年前解常微分方程的尤拉法非常直觀,即 h(t +Δt) = h(t) + Δt×f(h(t), t)。每當隱藏層沿 t 走一小步Δt,新的隱藏層狀態 h(t +Δt) 就應該近似在已有的方向上邁一小步。如果這樣一小步一小步從 t_0 走到 t_1,那麼就求出了 ODE 的數值解。

如果我們令 Δt 每次都等於 1,那麼離散化的尤拉方法就等於殘差模組的表示式 h(t+1) = h(t) + f(h(t), t)。但是尤拉法只是解常微分方程最基礎的方法,它每走一步都會產生一點誤差,且誤差會累積起來。近百年來,數學家構建了很多現代 ODE 求解方法,它們不僅能保證收斂到真實解,同時還能控制誤差水平。

陳天琦等研究者構建的 ODE 網路就使用了一種適應性的 ODE solver,它不像尤拉法移動固定的步長,相反它會根據給定的誤差容忍度選擇適當的步長逼近真實解。如下圖所示,左邊的殘差網路定義有限轉換的離散序列,它從 0 到 1 再到 5 是離散的層級數,且在每一層透過啟用函式做一次非線性轉換。此外,黑色的評估位置可以視為神經元,它會對輸入做一次轉換以修正傳遞的值。而右側的 ODE 網路定義了一個向量場,隱藏狀態會有一個連續的轉換,黑色的評估點也會根據誤差容忍度自動調整。

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

從微分方程到殘差網路

在 David 的 Oral 演講中,他以兩段虛擬碼展示了 ResNet 與 ODEnet 之間的差別。如下展示了 ResNet 的主要過程,其中 f 可以視為卷積層,ResNet 為整個模型架構。在卷積層 f 中,h 為上一層輸出的特徵圖,t 確定目前是第幾個卷積層。ResNet 中的迴圈體為殘差連線,因此該網路一共 T 個殘差模組,且最終返回第 T 層的輸出值。

deff(h, t, θ):

return nnet(h, θ_t)

defresnet(h):

for t in [1:T]:

h = h + f(h, t, θ)

return h

相比常見的 ResNet,下面的虛擬碼就比較新奇了。首先 f 與前面一樣定義的是神經網路,不過現在它的引數θ是一個整體,同時 t 作為獨立引數也需要饋送到神經網路中,這表明層級之間也是有定義的,它是一種連續的網路。而整個 ODEnet 不需要透過迴圈搭建離散的層級,它只要透過 ODE solver 求出 t_1 時刻的 h 就行了。

deff(h, t, θ):

return nnet([h, t], θ)

defODEnet(h, θ):

return ODESolver(f, h, t_0, t_1, θ)

除了計算過程不一樣,陳天琦等研究者還在 MNSIT 測試了這兩種模型的效果。他們使用帶有 6 個殘差模組的 ResNet,以及使用一個 ODE Solver 代替這些殘差模組的 ODEnet。以下展示了不同網路在 MNSIT 上的效果、引數量、記憶體佔用量和計算複雜度。

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

其中單個隱藏層的 MLP 引用自 LeCun 在 1998 年的研究,其隱藏層只有 300 個神經元,但是 ODEnet 在有相似引數量的情況下能獲得顯著更好的結果。上表中 L 表示神經網路的層級數,L tilde 表示 ODE Solver 中的評估次數,它可以近似代表 ODEnet 的「層級深度」。值得注意的是,ODEnet 只有常數級的記憶體佔用,這表示不論層級的深度如何增加,它的記憶體佔用基本不會有太大的變化。

網路對比

在與 ResNet 的類比中,我們基本上已經瞭解了 ODEnet 的前向傳播過程。首先輸入資料 Z(t_0),我們可以透過一個連續的轉換函式(神經網路)對輸入進行非線性變換,從而得到 f。隨後 ODESolver 對 f 進行積分,再加上初值就可以得到最後的推斷結果。如下所示,殘差網路只不過是用一個離散的殘差連線代替 ODE Solver。

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

在前向傳播中,ODEnet 還有幾個非常重要的性質,即模型的層級數與模型的誤差控制。首先因為是連續模型,其並沒有明確的層級數,因此我們只能使用相似的度量確定模型的「深度」,作者在這篇論文中採用 ODE Solver 評估的次數作為深度。

其次,深度與誤差控制有著直接的聯絡,ODEnet 透過控制誤差容忍度能確定模型的深度。因為 ODE Solver 能確保在誤差容忍度之內逼近常微分方程的真實解,改變誤差容忍度就能改變神經網路的行為。一般而言,降低 ODE Solver 的誤差容忍度將增加函式的評估的次數,因此類似於增加了模型的「深度」。調整誤差容忍度能允許我們在準確度與計算成本之間做權衡,因此我們在訓練時可以採用高準確率而學習更好的神經網路,在推斷時可以根據實際計算環境調整為較低的準確度。

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

如原論文的上圖所示,a 圖表示模型能保證在誤差範圍為內,且隨著誤差降低,前向傳播的函式評估數增加。b 圖展示了評估數與相對計算時間的關係。d 圖展示了函式評估數會隨著訓練的增加而自適應地增加,這表明隨著訓練的進行,模型的複雜度會增加。

c 圖比較有意思,它表示前向傳播的函式評估數大致是反向傳播評估數的一倍,這恰好表示反向傳播中的 adjoint sensitivity 方法不僅記憶體效率高,同時計算效率也比直接透過積分器的反向傳播高。這主要是因為 adjoint sensitivity 並不需要依次傳遞到前向傳播中的每一個函式評估,即梯度不透過模型的深度由後向前一層層傳。

神經常微分方程

師從同門的 Jesse Bettencourt 向機器之心介紹道,「天琦最擅長的就是耐心講解。」當他遇到任何無論是程式碼問題,理論問題還是數學問題,一旦是問了同桌的天琦,對方就一定會慢慢地花時間把問題講清楚、講透徹。而 ODEnet 的反向傳播,就是這樣一種需要耐心講解的問題。

ODEnet 的反向傳播與常見的反向傳播有一些不同,我們可能需要仔細查閱原論文與對應的附錄證明才能有較深的理解。此外,作者給出了 ODEnet 的 PyTorch 實現,我們也可以透過它瞭解實現細節。

正如作者而言,訓練一個連續層級網路的主要技術難點在於令梯度穿過 ODE Solver 的反向傳播。其實如果令梯度沿著前向傳播的計算路徑反傳回去是非常直觀的,但是記憶體佔用會比較大而且數值誤差也不能控制。作者的解決方案是將前向傳播的 ODE Solver 視為一個黑箱操作,梯度很難或根本不需要傳遞進去,只需要「繞過」就行了。

作者採用了一種名為 adjoint method 的梯度計算方法來「繞過」前向傳播中的 ODE Solver,即模型在反傳中透過第二個增廣 ODE Solver 算出梯度,其可以逼近按計算路徑從 ODE Solver 傳遞迴的梯度,因此可用於進一步的引數更新。這種方法如上圖 c 所示不僅在計算和記憶體非常有優勢,同時還能精確地控制數值誤差。

具體而言,若我們的損失函式為 L(),且它的輸入為 ODE Solver 的輸出:

我們第一步需要求 L 對 z(t) 的導數,或者說模型損失的變化如何取決於隱藏狀態 z(t) 的變化。其中損失函式 L 對 z(t_1) 的導數可以為整個模型的梯度計算提供入口。作者將這一個導數稱為 adjoint a(t) = -dL/z(t),它其實就相當於隱藏層的梯度。

在基於鏈式法則的傳統反向傳播中,我們需要從後一層對前一層求導以傳遞梯度。而在連續化的 ODEnet 中,我們需要將前面求出的 a(t) 對連續的 t 進行求導,由於 a(t) 是損失 L 對隱藏狀態 z(t) 的導數,這就和傳統鏈式法則中的傳播概念基本一致。下式展示了 a(t) 的導數,它能將梯度沿著連續的 t 向前傳,附錄 B。1 介紹了該式具體的推導過程。

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

在獲取每一個隱藏狀態的梯度後,我們可以再求它們對引數的導數,並更新引數。同樣在 ODEnet 中,獲取隱藏狀態的梯度後,再對引數求導並積分後就能得到損失對引數的導數,這裡之所以需要求積分是因為「層級」t 是連續的。這一個方程式可以表示為:

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

綜上,我們對 ODEnet 的反傳過程主要可以直觀理解為三步驟,即首先求出梯度入口伴隨 a(t_1),再求 a(t) 的變化率 da(t)/dt,這樣就能求出不同時刻的 a(t)。最後藉助 a(t) 與 z(t),我們可以求出損失對引數的梯度,並更新引數。當然這裡只是簡要的直觀理解,更完整的反傳過程展示在原論文的演算法 1。

反向傳播

在演算法 1 中,陳天琦等研究者展示瞭如何藉助另一個 OED Solver 一次性求出反向傳播的各種梯度和更新量。要理解演算法 1,首先我們要熟悉 ODESolver 的表達方式。例如在 ODEnet 的前向傳播中,求解過程可以表示為 ODEsolver(z(t_0), f, t_0, t_1, θ),我們可以理解為從 t_0 時刻開始令 z(t_0) 以變化率 f 進行演化,這種演化即 f 在 t 上的積分,ODESolver 的目標是透過積分求得 z(t_1)。

同樣我們能以這種方式理解演算法 1,我們的目的是利用 ODESolver 從 z(t_1) 求出 z(t_0)、從 a(t_1) 按照方程 4 積出 a(t_0)、從 0 按照方程 5 積出 dL/dθ。最後我們只需要使用 dL/dθ 更新神經網路 f(z(t), t, θ) 就完成了整個反向傳播過程。

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

如上所示,若初始給定引數θ、前向初始時刻 t_0 和終止時刻 t_1、終止狀態 z(t_1) 和梯度入口 L/z(t_1)。接下來我們可以將三個積分都並在一起以一次性解出所有量,因此我們可以定義初始狀態 s_0,它們是解常微分方程的初值。

注意第一個初值 z(t_1),其實在前向傳播中,從 z(t_0) 到 z(t_1) 都已經算過一遍了,但是模型並不會保留計算結果,因此也就只有常數級的記憶體成本。此外,在算 a(t) 時需要知道對應的 z(t),例如 L/z(t_0) 就要求知道 z(t_0) 的值。如果我們不能儲存中間狀態的話,那麼也可以從 z(t_1) 到 z(t_0) 反向再算一遍中間狀態。這個計算過程和前向過程基本一致,即從 z(t_1) 開始以變化率 f 進行演化而推出 z(t_0)。

定義 s_0 後,我們需要確定初始狀態都是怎樣「演化」到終止狀態的,定義這些演化的即前面方程 (3)、(4) 和 (5) 的被積函式,也就是演算法 1 中 aug_dynamics() 函式所定義的。

其中 f(z(t), t, θ) 從 t_1 到 t_0 積出來為 z(t_0),這第一個常微分方程是為了給第二個提供條件。而-a(t)*L/z(t) 從 t_1 到 t_0 積出來為 a(t_0),它類似於傳統神經網路中損失函式對第一個隱藏層的導數,整個 a(t) 就相當於隱藏層的梯度。只有獲取積分路徑中所有隱藏層的梯度,我們才有可能進一步解出損失函式對引數的梯度。

因此反向傳播中的第一個和第二個常微分方程 都是為第三個微分方程提供條件,即 a(t) 和 z(t)。最後,從 t_1 到 t_0 積分 -a(t)*f(z(t), t, θ)/θ 就能求出 dL/dθ。只需要一個積分,我們不再一層層傳遞梯度並更新該層特定的引數。

如下虛擬碼所示,完成反向傳播的步驟很簡單。先定義各變數演化的方法,再結合將其結合初始化狀態一同傳入 ODESolver 就行了。

deff_and_a([z, a], t):

return[f, -a*df/da, -a*df/dθ]

[z0, dL/dx, dL/dθ] =

ODESolver([z(t1), dL/dz(t), 0], f_and_a, t1, t0)

反向傳播怎麼做

這種連續型轉換有一個非常重要的屬性,即流模型中最基礎的變數代換定理可以便捷快速地計算得出。在論文的第四節中,作者根據這樣的推導結果構建了一個新型可逆密度模型,它能克服 Glow 等歸一化流模型的缺點,並直接透過最大似然估計訓練。

連續型的歸一化流

對於機率密度估計中的變數代換定理,我們可以從單變數的情況開始。若給定一個隨機變數 z 和它的機率密度函式 zπ(z),我們希望使用對映函式 x=f(z) 構建一個新的隨機變數。函式 f 是可逆的,即 z=g(x),其中 f 和 g 互為逆函式。現在問題是如何推斷新變數的未知機率密度函式 p(x)?

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

透過定義,積分項 ∫π(z)dz 表示無限個無窮小的矩形面積之和,其中積分元Δz 為積分小矩形的寬,小矩形在位置 z 的高為機率密度函式 π(z) 定義的值。若使用 f^1(x) 表示 f(x) 的逆函式,當我們替換變數的時候,z=f^1(x) 需要服從 Δz/Δx=(f^1(x))′。多變數的變數代換定理可以從單變數推廣而出,其中 det f/z 為函式 f 的雅可比行列式:

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

一般使用變數代換定理需要計算雅可比矩陣f/z 的行列式,這是主要的限制,最近的研究工作都在權衡歸一化流模型隱藏層的表達能力與計算成本。但是研究者發現,將離散的層級替換為連續的轉換,可以簡化計算,我們只需要算雅可比矩陣的跡就行了。核心的定理 1 如下所示:

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

在普通的變數代換定理中,分佈的變換函式 f(或神經網路)必須是可逆的,而且要製作可逆的神經網路也很複雜。在陳天琦等研究者定理裡,不論 f 是什麼樣的神經網路都沒問題,它天然可逆,所以這種連續化的模型對流模型的應用應該非常方便。

如下所示,隨機變數 z(t_0) 及其分佈可以透過一個連續的轉換演化到 z(t_1) 及其分佈:

硬核NeruIPS 2018最佳論文,一個神經了的常微分方程

此外,連續型流模型還有很多性質與優勢,但這裡並不展開。變數代換定理 1 在附錄 A 中有完整的證明,感興趣的讀者可查閱原論文了解細節。

最後,神經常微分方程是一種全新的框架,除了流模型外,很多方法在連續變換的改變下都有新屬性,這些屬性可能在離散啟用的情況下很難獲得。也許未來會有很多的研究關注這一新模型,連續化的神經網路也會變得多種多樣。

變數代換定理