• <xmp id="om0om">
  • <table id="om0om"><noscript id="om0om"></noscript></table>
  • 3 月 19 日下午 2 點,鎖定 NVIDIA AI 網絡中文專場。立即注冊觀看
    數據中心/云端

    如何在 NVIDIA Llama-3.1-Minitron 4B 模型上修剪和提煉 Llama-3.1 8B

    由于大型語言模型(LLM)的有效性和通用性,它們現在已經成為自然語言處理和理解領域的主導力量。LLM(例如 Llama 3.1 405BNVIDIA Nemotron-4 340B)在編碼、推理和數學等許多具有挑戰性的任務中表現出色。但是,它們的部署需要大量資源。因此,業內出現了另一種趨勢,即開發小型語言模型(SLM),這些模型在許多語言任務中足夠精通,但部署給大眾的成本要便宜得多。

    最近,NVIDIA 研究人員表明,結構化權重剪枝與知識提煉相結合,形成了一種有效且高效的策略,可以從初始較大的同級產品中逐步獲得較小的語言模型。NVIDIA Minitron 8B4B 是如此小的模型,通過在 NVIDIA Nemotron 系列中剪枝和提煉其較大的 15B 同級產品來獲得。

    剪枝和提煉可帶來以下優勢:

    • 與從頭開始訓練相比,MMLU 分數提高了 16%。
    • 每個新增模型所需的訓練令牌更少,約 ~100B 個令牌,減少高達 40 倍。
    • 與從頭開始訓練所有模型相比,訓練一系列模型可節省高達 1.8x 的計算成本。
    • 性能可與使用高達 15T 令牌訓練的 Mistral 7B、Gemma 7B 和 Llama-3 8B 相媲。

    本文還介紹了一套適用于 LLM 的實用且有效的結構化壓縮最佳實踐,這些實踐將深度、寬度、注意力和 MLP 剪枝與基于知識提煉的再訓練相結合。

    在本文中,我們首先討論這些最佳實踐,然后展示其在應用于 Llama 3.1 8B 模型以獲得 Llama-3.1-Minitron 4B 模型時的有效性。Llama-3.1-Minitron 4B 與類似大小的先進開源模型(包括 Minitron 4B、Phi-2 2.7B、Gemma2 2.6B 和 Qwen2-1.5B)相比性能較好。Llama-3.1-Minitron 4B 將很快發布到 NVIDIA Hugging Face 集合中,等待批準。

    修剪和提煉

    剪枝是通過丟棄圖層(深度剪枝)或丟棄神經元、注意力頭和嵌入通道(寬度剪枝)來縮小模型并使其更加精簡的過程。通常情況下,剪枝會伴隨一定數量的重新訓練,以恢復模型的準確性。

    模型提煉是一種技術,用于將知識從大型、復雜的模型(通常稱為教師模型)轉移到更小、更簡單的學生模型。其目標是創建更高效的模型,在更快、更低資源消耗的情況下,保留大型原始模型的大部分預測能力。

    經典知識提煉與 SDG 微調

    主要有兩種distillation:

    • SDG 微調:較大的教師模型生成的合成數據用于進一步微調較小的預訓練學生模型。在這里,學生僅模仿教師預測的最終令牌。例如 Azure AI Studio 中的 Llama 3.1 Azure Distillation 和 AWS 使用 Llama 3.1 405B 生成合成數據和進行提煉以微調較小的模型教程。
    • 經典知識提煉學生在訓練數據集上模仿教師的 logits 和其他中間狀態,而不僅僅是學習必須預測的令牌。這可以被視為提供更好的標簽(與一次性標簽相比的分布)。即使使用相同的數據,梯度也包含更豐富的反饋,從而提高訓練準確性和效率。但是,由于 logits 太大,無法存儲,因此必須為這種提煉方式提供訓練框架支持。

    這兩種提煉方式是相互補充的,而不是相互排斥的。本文主要介紹經典知識提煉方法。

    修剪和提煉程序

    我們建議將剪枝與經典知識提煉相結合作為一種資源節約型再訓練技術(圖 1)。

    1. 我們從一個 15B 的模型開始。我們估計每個組件(層、神經元、頭部和嵌入通道)的重要性,然后對模型進行排名并將其裁剪為目標大小:一個 8B 的模型。
    2. 我們使用模型提煉執行輕度再訓練程序,原始模型作為教師,經過剪枝的模型作為學生。
    3. 訓練完成后,小型模型 (8B) 將作為起點,用于裁剪和提取較小的 4B 模型。
    The diagram shows progressively pruning and distilling models of smaller sizes, from 15B to 8B and from 8B to 4B.
    圖 1.迭代模型剪枝和提煉程序

    圖 1 顯示了單個模型的剪枝和提煉過程(上)以及模型剪枝和提煉鏈(下)。在后者中,前一個階段的輸出模型將作為下一個階段的輸入模型。

    重要性分析

    要對模型進行剪枝,務必要了解模型的哪些部分很重要。我們建議使用純粹基于激活的重要性估計策略,該策略使用小型(1024 個樣本)的校正數據集,同時計算考慮的所有軸(深度、神經元、頭部和嵌入通道)的敏感度信息,并且僅使用前向傳播通道。與依賴梯度信息并需要反向傳播通道的策略相比,此策略的實施更直接、更經濟高效。

    在剪枝時,您可以在給定軸或軸組合的剪枝和重要性估計之間進行迭代交替。但是,我們的經驗工作表明,使用單次重要性估計已經足夠,而迭代估計并沒有任何好處。

    通過經典知識提煉進行再訓練

    圖 2 展示了學生模型 (剪枝模型) 的提煉過程,該模型從具有 M 層的教師模型 (原始未剪枝模型) 中提煉 N 層。學生通過最小化嵌入輸出損失、對數損失以及跨學生塊 S 和教師塊 T 映射的 Transformer 特定編碼器損失的組合來學習。

    The workflow diagram shows classical knowledge distillation from teacher to student, with loss function from several layers of the transformer architecture.
    圖 2.蒸訓練損失

    剪枝和提煉最佳實踐

    基于我們在緊湊語言模型中通過剪枝和知識提煉進行的大量消融研究,我們將所學到的經驗總結為幾種結構化壓縮最佳實踐:

    • Sizing:
      • 要訓練一系列LLMs,請先訓練最大的LLM,然后進行迭代剪枝和提煉,以獲得較小的LLMs。
      • 如果使用多階段訓練策略訓練最大的模型,則最好對從訓練的最后階段獲得的模型進行剪枝和重新訓練。
      • 剪枝最接近目標大小的可用源模型。
    • Pruning:
      • 首選寬度而非深度剪枝。這在所考慮的模型比例 (約 150B) 方面效果很好。
      • 使用單次重要性估計。迭代重要性估計沒有任何益處。
    • Retraining:
      • 僅使用蒸損失進行重新訓練,而不是使用傳統訓練。
      • 當深度顯著降低時,使用logit加中間狀態加嵌入提煉。
      • 當深度未顯著降低時,使用 Logit 純提煉。

    Llama-3.1-Minitron:將最佳實踐付諸實踐

    Meta 最近推出了功能強大的 Llama 3.1 模型系列,這是第一波開源模型,在許多基準測試中可與閉源模型相媲美。Llama 3.1 的范圍從龐大的 405B 模型到 70B 和 8B。

    憑借 Nemotron 提煉經驗,我們著手將 Llama 3.1 8B 模型提煉成更小、更高效的 4B 同級產品:

    • 教師微調
    • 僅深度剪枝
    • 僅寬度的剪枝
    • 準確性基準測試
    • 性能基準測試

    教師微調

    為了修正訓練模型所用的原始數據集上的分布偏移,我們首先在數據集上微調了未刪減的 8B 模型(94 億個令牌)。實驗表明,在不修正分布偏移的情況下,教師模型在提煉數據集時提供次優指導。

    僅深度剪枝

    為了從 8B 擴展到 4B,我們剪枝了 16 層(50%)。我們首先評估了每一層或連續子組的重要性,方法是將其從模型中刪除,并觀察下游任務的 LM 損失增加或準確性降低。

    圖 5 顯示了在驗證集上刪除 1、2、8 或 16 層后 Language Model 的損失值。例如,如果我們刪除前 16 層,則第 16 層的紅色圖表示 Language Model 的損失。如果我們留下第一層并刪除第 2 至 17 層,則第 17 層表示 Language Model 的損失。我們觀察到開始和結束的層最為重要。

    Line chart showing multiple sets of layer importance in depth-only pruning as measured by lm_loss. Layers at the beginning and the end are most important.
    圖 5.僅深度剪枝中的層重要性

    但是,我們發現此 LM 損失不一定與下游性能直接相關。

    圖 6 顯示了每個剪枝模型的 Winogrande 精度。這表明最好刪除第 16 層到第 31 層,其中 31 層是倒數第二層,其中剪枝模型的 5 次射擊精度明顯高于隨機(0.5)。我們根據這個見解刪除了第 16 層到第 31 層。

    Line chart shows the best accuracy on layer 32 out of layers 16-32.
    圖 6.移除 16 層時 Winogrande 任務的準確性

    僅寬度的剪枝

    我們沿著寬度軸剪枝了嵌入(hidden)和 MLP 中間維度,以壓縮 Llama 3.1 8B。具體來說,我們使用前面介紹的基于激活的策略計算了每個注意力頭、嵌入通道和 MLP 隱藏維度的重要性分數。根據重要性估計,我們:

    • 將 MLP 中間維度從 14336 剪枝 (修剪) 到 9216、
    • 將隱藏大小從 4096 縮減到 3072、
    • 重新訓練注意力頭數和層數。

    值得一提的是,在一次性剪枝后,寬度剪枝的 LM 損失要高于深度剪枝。但是,經過簡短的重新訓練后,趨勢就會逆轉。

    準確性基準測試

    我們使用以下參數對模型進行提煉:

    • 峰值學習率 = 1e-4
    • 最低學習率 = 1e-5
    • 40 個步驟的線性預熱
    • 余弦衰減計劃
    • 全局批量大小 = 1152

    表 1 顯示了 Llama-3.1-Minitron 4B 模型變體(寬度剪枝和深度剪枝)在跨多個領域的基準測試中的比較性能,與原始 Llama 3.1 8B 模型和其他類似大小的模型進行比較。

    總體而言,我們再次確認了寬度剪枝策略相比遵循最佳實踐的深度剪枝策略的有效性。

    基準測試 射擊次數 指標 Llama-3.1 8B Minitron 4B Llama-3.1-Minitron 4B Phi-2 27 億 Gemma2 2.6 B? Qwen2-1.5 B?
    寬度剪枝 深度剪枝 寬度剪枝
    Winogrande 5 acc 0.7272 0.7403 以上 0.7214 0.7348 0.7400++ 0.709 0.662
    arc_challenge 25 acc_norm 0.5794 0.5085 0.5256 0.5555% 0.6100% 0.554 0.439
    MMLU 5 acc 0.6528 0.5860%* 0.5871 0.6053% 0.5749 0.513 0.565
    希臘 10 acc_norm 0.8180 0.7496 0.7321 0.7606% 0.7524++ 0.73 0.666
    GSM8K 5 acc 0.4860 0.2411 0.1676 0.4124 0.5500%* 0.239 0.585%
    真實 0 mc2 0.4506 0.4288 0.3817 0.4289 0.4400++ 0.459%
    XLSum en (20%) 3 rougeL 0.3005 0.2954% 0.2722 0.2867%* 0.0100
    MBPP 0 pass = 1 0.4227 0.2817 0.3067 0.324 0.700 以上 0.29 0.374%*
    訓練令牌 15T 94 億 1.4 T 3T 7T
    表 1、與類似規模的基礎社區模型相比,Minitron 4B 基礎模型的準確性

    *最佳模型**次佳模型 – 不可用結果? 模型發行商在模型報告中報告的結果。

    為了驗證提煉的模型是否可以成為強指令模型,我們使用 NeMo-Aligner 微調了 Llama-3.1-Minitron 4B 模型。我們使用了 Nemotron-4 340B 的訓練數據,并在 IFEvalMT-BenchChatRAG-BenchBerkeley 函數調用排行榜(BFCL)上評估了模型,以測試指令遵循、角色扮演、RAG 和函數調用功能。我們確認,Llama-3.1-Minitron 4B 模型可以成為強指令模型,其性能優于其他基準 SLM(表 2)。

    ? Minitron 4B Llama-3.1-Minitron 4B Gemma 2B Phi-2 27 億 Gemma2 2.6 B Qwen2-1.5 B
    基準測試 寬度剪枝 深度剪枝 寬度剪枝
    IFEval 測試 0.4484 0.4257 0.5239++ 0.4050 0.4400 0.6451% 0.3981
    MT 工作臺 5.61 5.64 6.34%* 5.19 4.29 7.73% 5.22
    聊天 AG? 0.4111++ 0.4013 0.4399% 0.3331 0.3760 0.3745 0.2908
    BFCL 0.6423 0.6680% 0.6493++ 0.700 0.2305 0.3562 0.3275
    訓練令牌 94 億 3T 1.4 T 2T 7T
    表 2.Accuracy of aligned Minitron 4B base models compared to similarly sized aligned community models表 2.與類似大小的對齊社區模型相比,Minitron 4B 基礎模型的準確性

    *最佳模型**次優模型*基于 ChatRAG 的代表性子集,而非整個基準測試。

    性能基準測試

    我們使用 NVIDIA TensorRT-LLM(一個用于優化 LLM 推理的開源工具包)優化了 Llama 3.1 8B 和 Llama-3.1-Minitron 4B 模型。

    圖 7 和圖 8 顯示了不同用例下不同模型在 FP8 和 FP16 精度下的每秒吞吐量請求,這些請求以輸入序列長度/輸出序列長度 (ISL/OSL) 組合表示,在單個 NVIDIA H100 80GB GPU 上,以 8B 模型的批量大小為 32、4B 模型的批量大小為 64,這要歸功于單個 GPU 上較小的權重允許批量更大。

    Llama-3.1-Minitron-4B-Depth-Base 變體的運行速度最快,平均吞吐量約為 Llama 3.1 8B 的 2.7 倍,而 Llama-3.1-Minitron-4B-Width-Base 變體的平均吞吐量約為 Llama 3.1 8B 的 1.8 倍。在 FP8 中部署所有三種模型時,與 BF16 相比,性能也提升了約 1.3 倍。

    Bar chart shows the Llama-Minitron-3.1-4B-Depth-Base model being the fastest, followed by Llama-3.1-Minitron 4B-Width-Base and LLama 3.1 8B.
    圖 7. 不同輸入/輸出長度組合下請求 BF16 吞吐量的性能基準測試
    Bar chart shows the Llama-3.1-Minitron-4B-Depth-Base model being fastest, followed by Llama-3.1-Minitron-4B-Width-Base and LLama 3.1 8B.
    圖 8. 不同輸入/輸出長度組合下請求 FP8 吞吐量的性能基準測試

    組合:對于 Llama 3.1 8B,BS = 32;對于 Llama-3.1-Minitron 4B 型號,BS = 64.1 塊 H100 80GB GPU。

    結束語

    剪枝和傳統知識提煉是一種極具成本效益的方法,可以逐步獲得較小尺寸的LLM,與跨所有領域的從頭開始訓練相比,可以實現更高的準確性。相比于合成數據式微調或從頭開始預訓練,它是一種更有效、更高效的方法。

    Llama-3.1-Minitron 4B 是我們首次使用先進的開源 Llama 3.1 系列。要在 NVIDIA NeMo 中使用 Llama-3.1 的 SDG 微調,請參閱 GitHub 上的/sdg-law-title-generation notebook。

    有關更多信息,請參閱以下資源:

    致謝

    如果沒有 NVIDIA 的許多人的貢獻,這項工作是不可能完成的。例如,核心團隊Sharath Turuvekere Sreenivas、Saurav Muralidharan、Marcin Chochowski、Raviraj Joshi;顧問:顧問:顧問:Mostofa Patwary、Mohammad Shoeybi、Bryan Catanzaro、Jan Kautz、Pavlo Molchanov;教學調整Ameya Sunil Mahabaleshwarkar、Hayley Ross、Brandon Rowlett、Oluwatobi Olabiyi、Shizhe Diao、Yoshi Suhara;數據集Sanjeev Satheesh、Shengyang Sun、Jiaqi Zeng、Zhilin Wang、Yi Dong、Zihan Liu、Rajarshi Roy、Wei Ping、Makesh Narsimhan Sreedhar、Oleksii Kuchaiev;TRT-LLM:TRT-LLM:TRT-LLM:Bobby Chen、James Shen;HF 支持Ao Tang、Greg Heinrich;模型優化Chenhan Yu;討論和反饋Daniel Korzekwa;博客后期準備Vinh Nguyen、Sharath Turuvekere Sreenivas。

    ?

    +2

    標簽

    人人超碰97caoporen国产