• <xmp id="om0om">
  • <table id="om0om"><noscript id="om0om"></noscript></table>
  • 加速計算

    探索 FP8 訓練中 Debug 思路與技巧

    目前,市場上許多公司都積極開展基于 FP8 的大模型訓練,以提高計算效率和性能。 在此,我們整理并總結了客戶及 NVIDIA 技術團隊在 FP8 模型訓練過程中的 debug 思路和方法,供大家參考。

    在討論之前,建議大家使用我們推薦的 FP8 訓練的 Recipe,即使用 Delayed scaling,在History length為1024的窗口中選取最大的amax數值作為計算scaling factor的方法。當然,我們也在不斷優化這個 Recipe,未來隨著更多 FP8 的實踐案例,將繼續為大家總結和分享,期待共同探索和優化 debug 的思路和方案。

    在收集和整理了大量 FP8 訓練的案例后,我們發現,FP8 訓練中遇到的問題一般可以分成以下三類:

    第一類問題:Spike Issue

    Spike Issue 其實并不是 FP8 訓練所特有的,在 BF16 中也可能會遇到此類問題,并且實際上根據 NVIDIA 技術團隊內部訓練的一些曲線,可以看到 FP8 的 Spike Issue 要比 BF16 還要小一些。所以,如果遇到了Spike Issue,很多情況下可以暫時不用特別關注 FP8。另外,這里推薦兩篇關于 Spike 的研究,供大家參考。

    整體上,如果我們遇到的 Spike 和曾經在 BF16 上遇到的差不多,這種情況很可能不是 FP8 的問題。當然,也有例外的情況,比如我們遇到的 Spike 需要很多迭代步才能夠恢復正常,那這種情況下可以說明這個 loss 和 BF16 有本質上的差異, 可以考慮是第二類問題。

    第二類問題:FP8 Loss BF16不匹配或者發散

    在 Validation loss 曲線上,不論是預訓練還是 SFT,如果有 BF16 作為 Baseline,并且可以看到 FP8 和 BF16 有差距,這種情況下應該如何處理?

    一般這類問題可以分成兩種情況,包括:

    • 情況 1:在訓練的初始階段,不論是 Train from scratch 還是 Continue train,如果剛切換到 FP8 進行訓練,一開始就出現了 Loss 比較大或者直接跑飛,這種情況下大概率是軟件問題造成的,因此建議大家使用 NVIDIA 最新的 Transformer Engine 和 Megatron Core 的軟件棧,這樣很多軟件的問題可以及時被修復,從而讓大家少跑一些彎路。同時還有另外一種情況,在軟件不斷的更新過程中,為了性能的優化會增加很多新的特性。如果一些特性是剛剛加入的,可能在 FP8 上暫時還沒有遇到特殊情況,因此建議,大家如果使用了一些很新的特性,屆時可以先嘗試關閉掉這些新特性,檢查是否是由于這些新特性的實現不夠完善造成 Loss 的問題。
    • 情況 2:我們已經訓練了一段時間,比如已經訓練了幾百 Billion 的 Tokens,Loss 出現了差距,這種情況一般就不是軟件問題了。問題可能是給大家推薦的這個 Recipe 并不適用于某些數據集或某些模型結構。這種情況下,可以通過下面的案例去進行拆解。

    第三類問題:FP8 loss 非常吻合,但是 Downstream tasks 會有一些差異

    訓練中,我們的 Validation loss 曲線吻合的非常好,比如 Loss 差距的量級大概是在十的負三次方,但是在一些下游任務上打分可能會出現問題,那應該如何處理?這樣的問題一般分為兩種情況,包括:

    • 情況 1:進行下游任務打分的時候,會進行多任務打分。如果所有的任務和 BF16 baseline 對比,或者和當時上一代的模型對比,打分結果差異很大,這種情況大概率是評估過程中出現了問題。比如,Checkpoint 導出來的格式不對,或者 Scale 沒有取對,等評估流程的問題。因此我們還需要進行排除,確認是否是導出模型和評估流程出現了問題。
    • 情況 2:另一種情況,如前文提到的“在訓練了幾百 Billion 的 Token 之后,Loss 出現了差距”,和這種情況很相似,此時大部分任務都沒問題,只有個別的一兩個任務發現跟 BF16 的 Baseline 有明顯差距,如 3% 或者 5% 的掉點。這種情況下,建議改變 FP8 訓練的 Recipe,默認的 Recipe 是 Delayed scaling,即選用先前迭代步存下來的scale值,我們可以替換成 Current scaling,即選用當前迭代步的scale值,或者把部分的矩陣做一些回退到 BF16 的操作,具體方法下文會進行介紹。

    以下是一個案例,通過這個案例,可以初步了解哪些方法在現階段可以進行嘗試。

    這是一個類似于 Llama 2 的模型,雖然模型規模較小,但已經訓練了 1.1T 個 Tokens,使用了如下的推薦的配置,包括:

    • Pytorch 23.10 版本
    • TE Commit 為 d76118d
    • FP8 format:hybird
    • History Length:1024
    • Algo:Max
    • FP8 Wgrad Override:True

    我們發現,比較接近 Loss 末尾的時候,差異就會隨之出現,并且顯然已經不是十的負三次方的量級,這種情況下,可以考慮以下的步驟進行問題的排查。

    第一步:Sequence Parallel off

    在軟件前期的時候,首先盡可能嘗試關閉一些根據經驗判斷可能有問題的特性。比如在引入 FP8 初期,軟件上的 Sequence Paralleism(SP)經常會引起一些問題,因此可以先嘗試進行關閉,如果發現關閉后并沒有問題,可以初步判斷 Loss 不是由軟件引起的,從而大概率可以推斷是 Recipe 不夠完善造成的。

    第二步:我們可以做一個恢復性實驗

    嘗試看一下當前訓練出現問題的 FP8 的 Checkpoint,比如最后一個點,把這個 Checkpoint 切換到 BF16 訓練,查看是否可以恢復到 BF16 的 Baseline。我們目前遇到的的大多數情況都是可以恢復的。因此在這個基礎的情況下,可以繼續嘗試下一步 debug 的方法。

    • 第三步:三類矩陣的問題排查

    大多數情況下,整個模型跑在FP8上的并不多見。對于 Transformer layer 的每個 Gemm 來說,整個訓練過程中,有三類矩陣跑在 FP8 上,包括它的前向 Fprop,以及反向 Wgrad 和 Dgrad,因此現在需要判斷三類矩陣的哪個矩陣出了問題?當然,更細致一些應該判斷具體是哪一個 Transformer layer 的矩陣出了問題。不過,這個特性還在開發過程中,目前還是一個比較初步的判斷,需要檢查是前向的矩陣還是反向的兩個矩陣其中之一出現了差錯。因此這一步中,可以首先把這三類矩陣全部轉成 BF16 訓練。不過,我們做的是一個 Fake quantization,通俗的解釋就是使用 BF16 進行訓練,但是在做 BF16 計算之前,會先把它的輸入 Cast 成 FP8,然后再 Cast back 回到 BF16。這個時候,其實數據表示它已經是 FP8 表示范圍內的值了, 自然這個 scaling 使用的就是 Current scaling,或者說沒有 Scaling。這種情況下,會發現把三類矩陣全部都切回 Fake quantization 進行訓練的時候,此時的 Loss 曲線是可以貼近 BF16 Baseline 的。因此,下面需要一個矩陣一個矩陣的進行排除。

    三類矩陣包括前向的 Fprop,以及反向的 Wgrad 和 Dgrad。因此我們可以遵循一個相對簡單的思路 – 逐一嘗試,就是每次訓練把其中一個矩陣設置為 BF16 計算, 經我們嘗試后,可以看到

    • 在 Fprop 矩陣上面做 BF16 計算,會發現對 Loss 的影響并不是很大。
    • 在 Wgrad 矩陣上面做 BF16 計算,影響也非常小。
    • 在 Dgrad 矩陣上面做 BF16 計算,即只有 Dgrad 計算執行在 BF16,而 Fprop 和 Wgrad 全部執行在 FP8,此時會發現 Loss 會回到 BF16 的 Baseline。

    現在我們已經定位到了有問題的矩陣是 Dgrad,是否還有方法再做進一步的挽救從而避免性能損失太多?這種情況下,可以去進行以下嘗試。

    在 Transformer Engine(TE)的后續版本中,計劃支持用戶使用 Current scaling,即還是使用 FP8 去做 Gemm 的運算。但是我們不用前面給大家推薦這個 Delayed scaling recipe,而是使用當前輸入的 scale 值,雖然會損失一點性能,但是相比于把整個 Gemm 回退到 BF16 做計算,它的性能損失會小很多。

    當對 Dgrad 使用了 Current scaling 之后,會發現 Loss 曲線已經和 BF16 的 Baseline 吻合了。

    以上這就是一個相對完整的一個 debug 的思路,供大家參考和討論原始演講視頻,可以參考:NVIDIA 專家面對面技術沙龍|大模型訓練專場_嗶哩嗶哩_bilibili

    +4

    標簽

    人人超碰97caoporen国产