如何直觀地理解條件隨機場,並透過PyTorch簡單地實現

條件隨機場是一種無向圖模型,且相對於深度網路有非常多的優勢,因此現在很多研究者結合條件隨機場(CRF)與深度網路獲得更魯棒和可解釋的模型。本文結合 PyTorch 從基本的機率定義到模型實現直觀地介紹了 CRF 的基本概念,有助於讀者進一步理解完整理論。

假設我們有兩個相同的骰子,但是其中的一個是公平的,每個點數出現的機率相同;另一個骰子則被做了手腳,數字 6 出現的機率為 80%,而數字 1-5 出現的機率都為 4%。如果我給你一個 15 次投擲骰子的序列,你能預測出我每次投擲用的是哪一枚骰子嗎?

如何直觀地理解條件隨機場,並透過PyTorch簡單地實現

為了得到較高的準確率,一個簡單的模型是,每當「6」出現的時候,我們那就預測使用了有偏的骰子,而出現其他數字時則預測使用了公平的骰子。實際上,如果我們在每次投擲時等可能地使用任意一個骰子,那麼這個簡單的規則就是你可以做到的最好預測。

但是,設想一種情況:如果在使用了公平的骰子後,我們下一次投擲時使用有偏的骰子的機率為 90%,結果會怎樣呢?如果下一次投擲出現了一個「3」,上述模型會預測我們使用了公平的骰子,但是實際上我們使用有偏的骰子是一個可能性更大的選項。我們可以透過貝葉斯定理來進行驗證這個說法:

如何直觀地理解條件隨機場,並透過PyTorch簡單地實現

其中隨機變數 y_i 是第 i 次投擲所用的骰子型別,x_i 是第 i 次投擲得到的點數。

我們的結論是,在每一步中作出可能性最大的選擇只是可行策略之一,因為我們同時可能選擇其它的骰子。更有可能的情況是,以前對骰子的選擇情況影響了我未來會做出怎樣的選擇。為了成功地進行預測,你將不得不考慮到每次投擲之間的相互依賴關係。

條件隨機場(CRF)是一個用於預測與輸入序列相對應標註序列的標準模型。目前有許多關於條件隨機場的教程,但是我所看到的教程都會陷入以下兩種情況其中之一:1)全都是理論,但沒有展示如何實現它們 2)為複雜的機器學習問題編寫的程式碼缺少解釋,不能令讀者對程式碼有直觀的理解。

之所以這些作者選擇寫出全是理論或者包含可讀性很差的程式碼教程,是因為條件隨機場從屬於一個更廣更深的課題「機率圖模型」。所以要想深入涵蓋其理論和實現可能需要寫一本書,而不是一篇博文,這種情況也使得學習條件隨機場的知識比它原本所需要的更困難。

本教程的目標是涵蓋恰到好處的理論知識,以便你能對 CRF 有一個基本的印象。此外我們還會透過一個簡單的問題向你展示如何實現條件隨機場,你可以在自己的膝上型電腦上覆現它。這很可能讓你具有將這個簡單的條件隨機場示例加以改造,用於更復雜問題所需要的直觀理解。

理論

我們對於理論的討論將分為三個部分:1)指定模型引數 2)如何估計這些引數 3)利用這些引數進行預測,這三大類適用於任何統計機器學習模型。因此從這個意義上說,條件隨機場並沒有什麼特別的,但這並不意味著條件隨機場就和 logistic 迴歸模型一樣簡單。我們會發現,一旦我們要面對一連串的預測而不是單一的預測,事情就會變得更加複雜。

指定模型引數

在這個簡單的問題中,我們需要擔心的唯一的引數就是與從一次投擲轉換到下一次投擲狀態的分佈。我們有六種狀態需要考慮,因此我們將它們儲存在一個 2*3 的「轉移矩陣」中。

如何直觀地理解條件隨機場,並透過PyTorch簡單地實現

第一列對應於「從前一次投擲使用公平骰子的狀態,轉換到當前使用公平骰子狀態的機率或成本(第一行的值),或轉換到有偏骰子狀態的機率(第二行的值)」。因此,第一列中的第一個元素編碼了在給定我本次投擲使用了公平骰子的前提下,預測下一次投擲使用公平骰子的機率。如果資料顯示,我不太可能在連續使用公平骰子,模型會學習到這個機率應該很低,反之亦然。同樣的邏輯也適用於第二列。

矩陣的第一和第二列假設我們知道在前一次投擲中使用了哪個骰子,因此我們必須將第一次投擲作為一個特例來對待。我們將把相應的機率儲存在第三列中。

引數估計

假設給定一個投擲的集合 X* *以及它們相應的骰子標籤 Y。我們將會找到使整個訓練資料的負對數似然最小的轉移矩陣 T。我將會向你展示單個骰子投擲序列的似然和負對數似然是什麼樣的。為了在整個資料集上得到它,你要對所有的序列取平均。

P(x_i | y_i) 是在給定當前的骰子標籤的前提條件下,觀測到一個給定骰子投擲點數的機率。舉例而言,如果 y_i 為公平骰子,則 P(x_i | y_i) = 1/6。另一項 T(y_i | y_{i-1}) 是從上一個骰子標籤轉換到當前投資標籤的機率,我們可以直接從轉移矩陣中讀取出這個機率。

請注意在分母中,我們是怎樣在所有可能標籤 y‘ 的序列上進行求和的。在傳統的二分類問題 logistic 迴歸中,我們在分母中會有兩個項。但是現在,我們要處理的是標註序列,並且對於一個長度為 15 的序列來說,一共有 2^15 種可能的標籤序列,所以分母項是十分巨大的。條件隨機場的「秘密武器」是,它假定當前的骰子標籤僅僅只取決於之前的骰子標籤,來高效地計算這個大規模求和。

這個秘密武器被稱為「前向-後向演算法」。對該演算法的深入討論超出了這篇博文的範圍,因此這裡不做詳細的解釋。

序列預測

一旦我們估計出了我們的轉移矩陣,我們可以使用它去找到在給定一個投擲序列的條件下,最有可能的骰子標註序列。要做到這一點,最簡單的方法就是計算出所有可能的序列的似然,但這即使對於中等長度的序列也是十分困難的。正如我們在引數估計中所做的那樣,我們將不得不用一種特殊的演算法高效地搜尋可能性最大的序列。這個演算法與「向前-向後演算法」很相近,它被稱為「維特比演算法」。

具體實現參見:https://mp。weixin。qq。com/s/1KAbFAWC3jgJTE-zp5Qu6g