Cross Entropy:深度學習分類任務中不可或缺的損失函數
在深度學習分類任務中,Cross Entropy 損失函數用於衡量模型預測分布與真實標籤之間的差距,本文將從熵的本質、數學推導到應用與局限,解析其原理與最佳實踐。
引言
在深度學習領域中,Cross Entropy 是一種極為常見的損失函數 (Loss Function),尤其在處理分類 (Classification) 任務時,它扮演著至關重要的角色。今天,我們將深入探討這個重要的損失函數。
什麼是 Entropy?從熱力學出發
要理解 Cross Entropy,我們先從「Entropy (熵)」這個概念開始。在熱力學中,熵是用來衡量一個系統可能存在的微觀狀態數量,簡單來說,就是描述系統的可能性。系統的可能狀態越多,其熵就越高。
讓我們用一個例子來解釋這個系統中可能存在的狀態是什麼意思:
想像你打算「放生」一隻橡膠小黃鴨。
- 情境一:浴缸中的小黃鴨 如果你將小黃鴨放在家裡的浴缸中,並且不去干預它,那麼十年後,你幾乎可以肯定它仍然會在浴缸裡。這個系統的可能狀態非常有限,因此我們會說他的 Entropy (熵) 很低。
- 情境二:太平洋中央的小黃鴨 但如果你將這隻小黃鴨放到浩瀚的太平洋中央,十年後它又會在何處呢?它可能會隨著洋流漂向世界各地任何一個角落。這個系統的可能狀態變得極其龐大,因此它的熵很高。我們會說,相較於浴缸中的情況,太平洋中小黃鴨的 Entropy (熵) 顯著增加了。
系統可能存在的狀態數量,往往直接反映了其混亂程度或不確定性。試想一下:
一個習慣隨手亂丟衣服的人,比起總是整齊收納衣物的人,他的臭襪子可能出現的地點就多出許多。你的襪子可能在門口、床頭、沙發下,而不是只會出現在洗衣籃這個單一且確定的狀態。當襪子的可能位置分散在各處時,你的房間顯然就呈現出一種混亂且難以預測的狀態,這就是高熵的體現。相反地,總是收納整齊的房間,物品的可能狀態較少且固定,因此呈現低熵的狀態。
訊息熵 (Information Entropy)
資訊理論的奠基者 Claude Shannon 在 1948 年發表劃時代的論文《A Mathematical Theory of Communication》,首次將熵的概念引入資訊領域。他提出了一個核心思想:
一則訊息的不確定性或難以預測性越高,就需要越多的位元 (bit) 來進行編碼,這也意味著它所攜帶的資訊量越大。
Shannon 當時致力於解決無線電傳輸中的編碼效率問題。他發現,透過對高頻率出現的字母使用較短的編碼,而對低頻率字母使用較長的編碼,可以實現更有效率的資訊傳遞。Shannon 的這一洞見巧妙地將機率分布與編碼長度聯繫起來。
\[ I(x) = -\log_2(P(x)) \]
Cross Entropy 是什麼?
在分類任務中,評估模型預測的品質不能僅僅依賴其是否正確,模型對於預測結果的確定程度同樣重要。例如,一個貓狗分類器會透過 Softmax 函數將輸出轉換為總和為 1 的機率分布。然而,即使兩個模型都成功地將圖片分類為貓或狗,它們的輸出也可能呈現顯著差異。一個模型的輸出可能是 [0.9 (貓), 0.1 (狗)],而另一個可能是 [0.51 (貓), 0.49 (狗)]。雖然兩者都做出了正確的判斷,但顯然輸出為 [0.9, 0.1] 的模型對其預測具有更高的信心,這通常也是我們更期望獲得的結果。
基於 Shannon 的發現,我們得以將機率分布量化為數值。這為我們提供了一種量化模型預測與實際標籤資料之間差距的方法,而這個方法就是 Cross Entropy。
簡單來說,Cross Entropy 是一種用來衡量兩個機率分布之間差異的函數。在分類問題中,我們的模型會預測出一個機率分布 (Q),代表著它認為某個樣本屬於各個類別的可能性。而真實的標籤 (P) 則通常是一個「one-hot」分布,也就是只有正確類別的機率是 1,其他類別都是 0。
Cross Entropy 的數學公式是這樣的:
\[ H(P,Q)=−∑P(i)logQ(i) \]
但別被數學公式嚇到!在分類任務的特殊情況下,由於真實標籤 (P) 是 one-hot 編碼,這個公式會變得非常簡潔:
\[ Cross Entropy = -log(模型對正確類別的預測機率) \]
舉個例子,如果我們的真實標籤是第二類(比如「貓」),而模型預測各類別的機率是 [0.1 (狗), 0.7 (貓), 0.2 (鳥)],那麼這個預測的 Cross Entropy 就是:
\[ Cross Entropy = -log(0.7) ≈ 0.3567 \]
這個數值越小,代表模型的預測越接近真實情況。
相同的 Cross Entropy,不同的理解?
一個有趣的現象是,只要模型對正確類別的預測機率相同,即使它對其他類別的預測機率如何分配,Cross Entropy 的值也會一樣。例如:
- 真實標籤:[0, 1, 0]
- 模型 A 預測:[0.1, 0.6, 0.3]
- 模型 B 預測:[0.01, 0.6, 0.39]
在上述的例子中,這兩個模型的 Cross Entropy 都是 -log(0.6)。這是因為 Cross Entropy 主要關注模型對正確答案的信心程度,其他部分的機率分布並不重要。
Cross Entropy 的極端情況
Cross Entropy 對於預測錯誤的情況非常敏感
- 真實標籤:[0, 1, 0]
- 模型預測:[0.9, 0, 0.1]
預測正確類別的機率是 0: Loss 會趨向無限大(因為 log(0) 是無窮小)。這代表模型完全犯錯,損失會非常巨大,促使模型在訓練時極力避免這種情況。
- 真實標籤:[0, 1, 0]
- 模型預測:[0, 1, 0]
預測正確類別的機率是 1: Loss 會達到最小值 0(因為 log(1) = 0)。這代表模型完全正確且非常有信心。
這種對於預測錯誤的敏感性,會讓模型更好從錯誤的預測中學習。
Cross Entropy 的局限性
儘管 Cross Entropy 在分類任務中非常有效,但它也有其局限性。它主要關注模型對正確類別的預測,而忽略了模型對其他錯誤類別的預測分布情況。
這在某些需要考慮整體預測分布的任務中可能會成為問題,例如:
- 語意分類: 我們可能希望模型不僅預測出正確的意圖,還要對其他可能的意圖有合理的預測。
- 近似推論: 我們可能需要模型預測的分布盡可能接近真實的複雜分布,而不僅僅是正確的類別。
如何改進?
針對 Cross Entropy 的這些局限性,研究人員也提出了改進的方法:
- KL Divergence (KL 散度): 這個指標可以直接衡量兩個機率分布的整體差異,不僅考慮了模型對正確類別的預測。
- Label Smoothing (標籤平滑): 這個技巧並非直接修改 Loss 函數,而是將原本硬性的 0-1 真實標籤轉換為一個更連續的分布,例如 [0.1, 0.8, 0.1]。這樣可以避免模型過度自信地預測單一類別,鼓勵模型學習到更多類別之間的關係。
Cross Entropy vs MSE:分類問題選誰?
你可能會好奇,為什麼分類問題更常用 Cross Entropy 而不是均方誤差 (Mean Squared Error, MSE) 呢?讓我們來比較一下:
MSE 的缺點:
- 梯度消失: 對於分類問題,特別是在模型輸出的機率接近 0 或 1 時,MSE 的梯度可能會變得非常小,導致訓練緩慢甚至停滯(梯度消失)。
- 錯誤懲罰不合理: MSE 對於預測錯誤的懲罰是線性的,無法很好地區分「錯一點」和「錯很多」的情況。例如,將貓預測成狗和將貓預測成飛機在 MSE 看來可能差異不大,但實際上後者是更離譜的錯誤。
實例對比:
假設真實標籤是第二類,模型預測為 [0.1, 0.7, 0.2]。
- Cross Entropy ≈ 0.3567
- MSE ≈ 0.0867
雖然 MSE 的 Loss 看起來更小,但它並不能真正反映出模型在分類上的表現,特別是當預測非常離譜時,MSE 的 Loss 可能仍然很小。