※本記事は、Stanford大学のCME295コース「Transformers and Large Language Models」の2025年10月17日に実施されたLecture 4 - LLM Trainingの講義内容を基に作成されています。講義の詳細情報およびコースのシラバスは https://cme295.stanford.edu/syllabus/ でご覧いただけます。また、Stanfordの大学院プログラムについては https://online.stanford.edu/graduate からアクセス可能です。
本講義は、Afshine Amidi氏とShervine Amidi氏により行われました。両氏はStanford大学のAdjunct Lecturer(非常勤講師)として、本コースを担当されています。
本記事では、約1時間45分にわたる講義内容を要約・再構成しております。講義ではプレトレーニング、量子化、ハードウェア最適化、教師ありファインチューニング(SFT)、パラメータ効率的ファインチューニング(LoRA)などのトピックが包括的にカバーされています。なお、本記事の内容は原講義の見解を正確に反映するよう努めていますが、要約や解釈による誤りがある可能性もありますので、正確な情報や詳細な文脈については、オリジナルの講義動画をご覧いただくことをお勧めいたします。講義の完全版は、Stanford CME295のYouTubeプレイリストからアクセスできます。
1. Introduction
1.1 講義の概要と目的
Tatsunori Hashimoto: 皆さん、こんにちは。今日の講義では、大規模言語モデル(LLM)のトレーニングについて詳しく見ていきます。前回の講義では、Transformerアーキテクチャの詳細を学びました。レイヤーノーマライゼーション、アテンションメカニズム、フィードフォワードネットワークなど、モデルを構成する個々のコンポーネントについて深く掘り下げましたね。今日はそこから一歩進んで、これらのモデルを実際にどのようにトレーニングするのか、その実践的な側面に焦点を当てます。
Tatsunori Hashimoto: 今日の講義の主な目標は、LLMトレーニングの3つの重要な段階を理解することです。まず最初の段階は「プレトレーニング」です。これは、インターネット上の膨大なテキストデータを使って、モデルに言語の基本的な理解を教え込むプロセスです。次に「ファインチューニング」があります。これは、プレトレーニングされたモデルを特定のタスクやドメインに適応させる段階です。そして最後に「アライメント」です。これは、モデルの振る舞いを人間の価値観や期待に合わせて調整するプロセスで、モデルをより安全で有用なものにするために重要です。
Tatsunori Hashimoto: 今日の講義では、これら3つの段階それぞれについて、どのように機能するのか、どのような課題があるのか、そして最新の研究でどのような進展があるのかを見ていきます。特に、プレトレーニングの計算コストの問題、スケーリング法則とそれがモデル設計にどう影響するか、そしてトレーニングを効率化するための様々な最適化技術について詳しく説明します。Data ParallelismやZeRO、Model Parallelism、Flash Attentionなど、大規模モデルを実際にトレーニング可能にする技術を取り上げます。
Tatsunori Hashimoto: さらに、量子化や混合精度トレーニングといった、メモリと計算効率を向上させる手法についても議論します。後半では、教師ありファインチューニング、特にInstruction Tuningに焦点を当て、モデルが自然言語の指示に従えるようにする方法を学びます。そして最後に、LoRAやQLoRAといったパラメータ効率的なファインチューニング手法を紹介します。これらは、限られた計算リソースでも大規模モデルを効果的にファインチューニングできる革新的なアプローチです。
Tatsunori Hashimoto: 本日の講義を通じて、皆さんには単にこれらの技術が存在することを知るだけでなく、なぜそれらが必要なのか、どのように機能するのか、そして実際のプロジェクトでどのように適用すればよいのかを深く理解していただきたいと思います。LLMのトレーニングは非常にリソース集約的なプロセスですが、適切な技術と戦略を用いることで、より効率的かつ効果的にモデルを構築することができます。それでは、プレトレーニングから始めましょう。
1.2 LLMトレーニングの全体像
Tatsunori Hashimoto: LLMのトレーニングプロセス全体を理解するために、まずその大きな流れを俯瞰してみましょう。現代の大規模言語モデルは、単一のトレーニングステップで完成するわけではありません。むしろ、それぞれが異なる目的を持つ複数の段階を経て、段階的に能力を獲得していくのです。
Tatsunori Hashimoto: トレーニングパイプラインの最初の段階はプレトレーニングです。これは最も計算集約的で時間がかかるフェーズです。ここでは、インターネット上のウェブページ、書籍、学術論文、コード、その他あらゆる種類のテキストデータから構成される巨大なコーパスを使用します。モデルは次のトークンを予測するという、一見シンプルに見えるタスクを通じて学習します。つまり、与えられたテキストの文脈から、次に来るべき単語や文字を予測するのです。この過程で、モデルは文法、事実知識、推論能力、さらには世界についての一般的な理解を獲得していきます。
Tatsunori Hashimoto: プレトレーニングが完了すると、私たちは汎用的な言語理解能力を持つベースモデルを手に入れます。しかし、このモデルはまだ特定のタスクに最適化されているわけではありません。そこで次の段階、ファインチューニングに移ります。ファインチューニングでは、プレトレーニングされたモデルを特定の用途に適応させます。たとえば、医療文書の分析、法律文書の要約、コード生成、あるいは特定のドメインの質問応答など、様々なタスクに特化させることができます。
Tatsunori Hashimoto: ファインチューニングの中でも特に重要なのがInstruction Tuningです。これは、モデルが自然言語の指示に従って応答できるようにするプロセスです。「この文章を要約してください」「このコードのバグを修正してください」「この質問に答えてください」といった指示をモデルに与え、適切に応答できるようにトレーニングします。これにより、モデルは単にテキストを生成するだけでなく、ユーザーの意図を理解し、それに沿った有用な出力を生成できるようになります。
Tatsunori Hashimoto: トレーニングパイプラインの最終段階はアライメントです。ここでは、モデルの振る舞いを人間の価値観や期待に合わせて調整します。これには、有害なコンテンツの生成を避ける、より正確で誠実な応答を生成する、ユーザーの真の意図を理解するといった要素が含まれます。アライメントには、人間からのフィードバックを用いた強化学習(RLHF: Reinforcement Learning from Human Feedback)などの技術が使われます。これについては、次回以降の講義で詳しく扱います。
Tatsunori Hashimoto: これら3つの段階は、それぞれ異なる特性を持っています。プレトレーニングは最も大規模で、何兆ものトークンを使い、数千から数万のGPU時間を必要とします。一方、ファインチューニングは比較的小規模なデータセット、多くの場合数千から数万の例を使い、数時間から数日で完了します。アライメントは更に少ないデータで行われることが多く、人間の評価者からの高品質なフィードバックに依存します。
Tatsunori Hashimoto: 産業界では、多くの企業がこのパイプライン全体を実装しています。OpenAIのGPTシリーズ、AnthropicのClaude、GoogleのGemini、MetaのLlamaなど、これらすべてのモデルは基本的にこの3段階のプロセスを経ています。ただし、各企業は独自の工夫や最適化を加えており、使用するデータの種類、トレーニングの期間、アライメントの方法などに違いがあります。
Tatsunori Hashimoto: 重要なのは、これらの段階が独立しているわけではなく、相互に関連しているということです。プレトレーニングの質がファインチューニングの効果に影響し、ファインチューニングの方法がアライメントの容易さに影響します。また、各段階で使用する技術的な最適化手法、たとえばデータ並列化やモデル並列化、混合精度トレーニングなどは、すべての段階で共通して適用できるものです。今日の講義では、これらの技術的な詳細に深く踏み込んでいきます。
2. Pretraining(事前学習)
2.1 事前学習の定義と目的
Tatsunori Hashimoto: それでは、LLMトレーニングの最初の段階であるプレトレーニングについて詳しく見ていきましょう。プレトレーニングとは何か、なぜこれほど重要なのかを理解することが、LLM全体を理解する上での基礎となります。
Tatsunori Hashimoto: プレトレーニングの核心は、次トークン予測というタスクにあります。これは非常にシンプルな概念です。モデルに一連のトークン、つまり単語や文字の列を与えると、モデルは次に来るべきトークンを予測しようとします。たとえば、「The cat sat on the」という文章があれば、モデルは次に「mat」が来る確率が高いと予測するでしょう。この予測タスクを何兆回も繰り返すことで、モデルは言語の構造、文法、意味、さらには世界についての知識を学習していくのです。
Tatsunori Hashimoto: なぜこのような単純なタスクが強力なのでしょうか。それは、次トークン予測が実は非常に難しいタスクだからです。適切に次のトークンを予測するためには、モデルは文脈を理解し、文法規則を把握し、事実関係を知っていて、さらには文章の意図や目的まで理解する必要があります。たとえば、「The capital of France is」という文があれば、次に「Paris」が来ることを予測するためには、フランスの首都がパリであるという事実知識が必要です。あるいは、プログラミングコードの文脈で「for i in range(」とあれば、次に数字が来る可能性が高いと予測するためには、プログラミング言語の構文を理解している必要があります。
Tatsunori Hashimoto: プレトレーニングの最大の目的は、汎用的な言語表現を学習することです。特定のタスクに特化するのではなく、言語全般についての深い理解を獲得することが目標です。この汎用的な理解は、後のファインチューニング段階で様々な下流タスクに適応する際の強力な基盤となります。プレトレーニングされたモデルは、文章分類、質問応答、要約、翻訳、コード生成など、多様なタスクに対して少量の追加トレーニングで適応できるようになります。
Tatsunori Hashimoto: プレトレーニングにおいて重要なのは、自己教師あり学習(self-supervised learning)という性質です。次トークン予測では、ラベル付けされたデータを必要としません。テキストデータそのものが、入力と正解ラベルの両方を提供してくれます。文章の最初の部分が入力となり、次のトークンが正解ラベルとなるからです。これにより、インターネット上に存在する膨大な量のテキストデータをすべて活用できます。人間が手動でラベル付けする必要がないため、数兆トークンという規模のデータセットでトレーニングすることが可能になるのです。
Tatsunori Hashimoto: プレトレーニングを通じて、モデルは複数のレベルで知識を獲得します。最も基本的なレベルでは、統語論的な知識、つまり文法や言語の構造を学びます。主語と動詞の一致、語順、句構造などです。次に、意味論的な知識、つまり単語やフレーズの意味、それらの関係性を学びます。さらに高次のレベルでは、事実知識、つまり世界についての具体的な情報を獲得します。歴史的な出来事、科学的な事実、地理的な情報などです。そして最も抽象的なレベルでは、推論能力や常識的な理解も獲得します。
Tatsunori Hashimoto: 興味深いのは、これらすべての能力が明示的に教えられるわけではないということです。モデルは単に次のトークンを予測するという目的関数を最適化しているだけなのに、その過程でこれらの複雑な能力が創発的に現れてくるのです。これは、大規模言語モデルの最も驚くべき特性の一つであり、スケーリング法則の研究が示すように、モデルのサイズとデータ量を増やすにつれて、これらの能力はさらに向上していきます。
Tatsunori Hashimoto: プレトレーニングのもう一つの重要な側面は、転移学習の基盤を提供することです。一度プレトレーニングされたモデルは、様々な下流タスクに適応できる汎用的な特徴表現を持っています。これは、コンピュータビジョンにおいてImageNetで事前学習されたモデルが様々な画像認識タスクに転移できるのと同じ原理です。ただし、言語の場合、その汎用性はさらに広範囲に及びます。なぜなら、言語は本質的に多目的なコミュニケーション手段だからです。
Tatsunori Hashimoto: プレトレーニングは、LLM開発における最もコストのかかる段階でもあります。計算リソース、時間、そして電力の面で膨大な投資が必要です。しかし、一度プレトレーニングが完了すれば、そのモデルは多くの異なる用途に再利用できます。これが、多くの研究機関や企業がプレトレーニングされたモデルを公開している理由です。LlamaやMistralといったオープンソースモデルは、コミュニティ全体がこの高価なプレトレーニング段階の成果を共有できるようにしています。
2.2 事前学習のデータセットと手法
Tatsunori Hashimoto: プレトレーニングの成功は、使用するデータセットの質と量に大きく依存します。現代のLLMは、インターネット上から収集された膨大な量のテキストデータでトレーニングされていますが、このデータの選択と処理は単純な作業ではありません。どのようなデータソースが使われ、どのように処理されるのかを見ていきましょう。
Tatsunori Hashimoto: 典型的なプレトレーニングデータセットは、複数の異なるソースから構成されています。最も大きな部分を占めるのは、ウェブテキストです。Common Crawlというプロジェクトがインターネット全体をクロールして収集したデータがよく使われます。これには、ニュース記事、ブログ投稿、フォーラムの議論、製品レビュー、ウィキペディアの記事など、あらゆる種類のウェブコンテンツが含まれます。ただし、Common Crawlのデータは非常に雑多で、低品質なコンテンツも多く含まれているため、そのまま使うわけにはいきません。
Tatsunori Hashimoto: 書籍もプレトレーニングデータの重要なソースです。Books3やProject Gutenbergなどのコーパスには、小説、ノンフィクション、学術書など、長文の構造化されたテキストが含まれています。書籍は一般的にウェブテキストよりも質が高く、より一貫性のある長い文脈を提供します。これは、モデルが長距離の依存関係や複雑な論理構造を学習するのに役立ちます。
Tatsunori Hashimoto: 学術論文も重要なデータソースです。arXivやPubMedなどから収集された論文は、専門的な知識や科学的な推論パターンを学習するのに有用です。特に、数学や科学の問題を解く能力を向上させるためには、学術コンテンツが不可欠です。
Tatsunori Hashimoto: そして、近年特に重要性が増しているのがコードです。GitHubなどのリポジトリから収集されたプログラミングコードは、多くの最新モデルのトレーニングに含まれています。コードを学習することで、モデルは論理的思考や構造化された問題解決の能力を獲得します。興味深いことに、コードでトレーニングされたモデルは、コード生成タスクだけでなく、一般的な推論タスクでも性能が向上することが観察されています。
Tatsunori Hashimoto: データの収集が完了したら、次は前処理とフィルタリングの段階です。これは極めて重要なステップで、最終的なモデルの性能に大きな影響を与えます。まず、重複除去を行います。インターネット上には同じコンテンツが何度も繰り返し現れることがあり、これらの重複はモデルの学習を歪める可能性があります。多くの場合、文書レベルまたは段落レベルでの重複除去が行われます。
Tatsunori Hashimoto: 品質フィルタリングも重要です。自動生成されたスパム、意味不明なテキスト、過度に短いまたは長い文書、フォーマットが壊れたコンテンツなどを除去します。これには、様々なヒューリスティックが使われます。たとえば、単語の平均長、句読点の比率、ストップワードの頻度、言語モデルのパープレキシティスコアなどを基準にフィルタリングします。一部のプロジェクトでは、高品質なコンテンツで訓練された分類器を使って、品質の低いデータを識別し除去します。
Tatsunori Hashimoto: 有害コンテンツのフィルタリングも行われます。ヘイトスピーチ、露骨な性的コンテンツ、暴力的なコンテンツなどを検出し除去するために、キーワードベースのフィルタリングや、専用の分類器が使用されます。ただし、これは微妙なバランスを要する作業です。過度にフィルタリングすると、モデルの能力が制限される可能性がありますし、不十分だと有害な出力を生成するリスクが高まります。
Tatsunori Hashimoto: 個人識別情報(PII)の除去も重要な前処理ステップです。メールアドレス、電話番号、住所、社会保障番号などの個人情報を検出し、除去または匿名化します。これはプライバシー保護の観点から必須の処理です。
Tatsunori Hashimoto: データの構成比率も慎重に設計されます。すべてのデータソースを均等に使うわけではなく、ウェブテキスト、書籍、論文、コードなど、それぞれの比率を調整します。たとえば、GPT-3では、Common Crawlのデータは60%程度、書籍が16%、ウィキペディアが3%といった具合に配分されています。この比率は、モデルの最終的な性能特性に影響を与えます。
Tatsunori Hashimoto: トークン化もデータ処理の重要な側面です。テキストをモデルが処理できる形式に変換する必要があります。多くのモデルでは、Byte Pair Encoding(BPE)やSentencePieceといったサブワードトークン化手法が使われます。これらの手法は、頻出する単語やサブワードの組み合わせを効率的にエンコードしながら、未知の単語にも対応できる柔軟性を提供します。
Tatsunori Hashimoto: バッチ処理の方法も工夫されています。トレーニング時には、異なる長さの文書を効率的にバッチ化する必要があります。多くの実装では、コンテキストウィンドウのサイズ、たとえば2048トークンや4096トークンといった固定長にテキストを分割し、これを基本単位としてバッチを構成します。文書の境界をまたぐ場合でも、そのまま連結して処理することが一般的です。
Tatsunori Hashimoto: データの多様性も考慮されます。特定のトピックやスタイルに偏りすぎないよう、様々なジャンル、言語レジスター、難易度のテキストをバランスよく含めることが重要です。これにより、モデルは幅広い状況で適切に機能できるようになります。
Tatsunori Hashimoto: 最近の研究では、データの質が量よりも重要である場合があることが示されています。単純にデータ量を増やすだけでなく、高品質で多様なデータを選択的に使用することで、より少ないトークン数でも優れた性能を達成できることがわかってきました。これは、Chinchilla法則とも関連する重要な知見で、後ほど詳しく議論します。
2.3 事前学習の課題とコスト
Tatsunori Hashimoto: プレトレーニングは、LLM開発の中で最もコストがかかり、最も困難な段階です。その規模と複雑さは、多くの研究機関や企業にとって大きな障壁となっています。具体的にどれほどのリソースが必要なのか、そしてどのような課題があるのかを見ていきましょう。
Tatsunori Hashimoto: まず、計算リソースの膨大さについて話しましょう。現代の大規模言語モデルのトレーニングには、数千から数万のGPU時間が必要です。たとえば、GPT-3のトレーニングには、約3.14×10の23乗FLOPs、つまり314ゼタFLOPsの計算が必要だったと推定されています。これをNVIDIA V100 GPUで実行すると仮定すると、数千個のGPUを数週間から数ヶ月稼働させ続ける必要があります。より大規模なモデル、たとえば5000億パラメータや1兆パラメータのモデルになると、必要な計算量はさらに桁違いに増加します。
Tatsunori Hashimoto: 金銭的なコストも驚異的です。GPT-3のトレーニングコストは、約460万ドルから1200万ドルと推定されています。これは使用するハードウェアや期間によって変動しますが、いずれにしても莫大な投資です。さらに大規模なモデルでは、トレーニングコストが数千万ドルに達することもあります。これには、GPUのレンタルや購入費用だけでなく、電力コスト、冷却システム、データセンターのインフラストラクチャ、そして専門的なエンジニアリングチームの人件費も含まれます。
Tatsunori Hashimoto: 時間的なコストも無視できません。大規模モデルのプレトレーニングには、数週間から数ヶ月かかります。GPT-3は約34日間のトレーニングを要したと報告されています。これは、数千のGPUを並列に使用した場合の時間です。単一のGPUで同じことをしようとすれば、何年もかかるでしょう。この長い時間は、研究開発のサイクルを遅くし、迅速な実験やイテレーションを困難にします。
Tatsunori Hashimoto: メモリの制約も大きな課題です。大規模なTransformerモデルは、パラメータだけで膨大なメモリを消費します。175億パラメータのモデルを32ビット浮動小数点数で保存すると、700GB以上のメモリが必要です。しかし、トレーニング時には、パラメータだけでなく、勾配、オプティマイザの状態、アクティベーション(中間計算結果)なども保存する必要があります。これらを合計すると、単一のGPUでは到底収まりません。NVIDIA A100の最大メモリ容量は80GBですから、大規模モデルを単一デバイスでトレーニングすることは物理的に不可能なのです。
Tatsunori Hashimoto: この問題を解決するために、モデルを複数のGPUに分散させる必要があります。しかし、これは新たな課題を生み出します。GPUメモリ間の通信オーバーヘッドです。データを複数のデバイス間で転送する際、通信帯域幅がボトルネックになります。特に、勾配の同期や重みの更新を行う際には、大量のデータを転送する必要があり、これが全体のトレーニング速度を低下させる要因となります。効率的な並列化戦略を設計することが、大規模トレーニングの成功の鍵となります。
Tatsunori Hashimoto: 数値的安定性の問題もあります。何兆もの演算を行う中で、数値誤差が蓄積し、勾配消失や勾配爆発といった問題が発生する可能性があります。特に、混合精度トレーニングを使用する場合、16ビット浮動小数点数の限られた数値範囲が問題になることがあります。適切な初期化、学習率のスケジューリング、グラディエントクリッピングなど、様々なテクニックを駆使して、安定したトレーニングを維持する必要があります。
Tatsunori Hashimoto: チェックポイントとリカバリーの課題もあります。数週間にわたるトレーニングの途中で、ハードウェア障害やソフトウェアのバグが発生する可能性は高いです。定期的にモデルの状態を保存し、問題が発生した場合に迅速に回復できる仕組みが不可欠です。しかし、数百GBから数TBに及ぶモデルの状態を頻繁に保存することは、それ自体が大きなオーバーヘッドとなります。効率的なチェックポイント戦略を設計する必要があります。
Tatsunori Hashimoto: ハイパーパラメータの調整も困難です。学習率、バッチサイズ、ウォームアップステップ、重み減衰など、多くのハイパーパラメータがモデルの最終性能に影響します。しかし、大規模モデルでこれらを調整するには、各実験に数週間から数ヶ月かかるため、徹底的なハイパーパラメータ探索は現実的ではありません。多くの場合、小規模なモデルでの実験結果や、以前のトレーニング実行からの知見を基に、ハイパーパラメータを慎重に選択する必要があります。
Tatsunori Hashimoto: 環境への影響も無視できない課題です。大規模モデルのトレーニングには、膨大な電力が消費されます。GPT-3のトレーニングは、約1,287MWhの電力を消費したと推定されており、これは約552トンのCO2排出に相当します。AI研究コミュニティでは、この環境コストを削減し、より持続可能なトレーニング手法を開発することが重要な課題として認識されています。
Tatsunori Hashimoto: データの質と量のバランスも課題です。より多くのデータでトレーニングすれば必ず性能が向上するわけではありません。低品質なデータは、モデルの性能を損なう可能性があります。一方で、高品質なデータだけを選択すると、データ量が不足する可能性があります。適切なバランスを見つけることが重要ですが、これは経験とドメイン知識を要する難しい問題です。
Tatsunori Hashimoto: これらすべての課題があるにもかかわらず、プレトレーニングは依然としてLLM開発の中核です。なぜなら、一度プレトレーニングが完了すれば、そのモデルは多様なタスクに適応でき、多くの異なる用途に再利用できるからです。このため、大手テクノロジー企業や研究機関は、莫大なリソースを投じてプレトレーニングを行い、その成果をコミュニティと共有しています。今日の講義の後半で説明する様々な最適化技術は、これらの課題を軽減し、より効率的なプレトレーニングを可能にするために開発されてきました。
3. FLOPs, FLOPS
3.1 FLOPsとFLOPSの定義と違い
Tatsunori Hashimoto: プレトレーニングのコストについて議論する前に、計算量を測定するための基本的な用語を明確にしておく必要があります。FLOPsとFLOPSという2つの用語がありますが、これらは似ているように見えて全く異なる概念です。この違いを理解することは、LLMのトレーニングコストを正確に見積もり、議論する上で極めて重要です。
Tatsunori Hashimoto: まず、FLOPsについて説明しましょう。FLOPsは「Floating Point Operations」の複数形で、浮動小数点演算の総数を表します。これは、ある計算タスクを完了するために必要な演算の絶対的な量です。たとえば、2つの行列を掛け算する場合、必要な乗算と加算の回数がFLOPsで表されます。FLOPsは作業の総量を示す指標であり、時間の概念は含まれていません。1000 FLOPsの計算タスクは、それを実行するのに1秒かかろうが1時間かかろうが、常に1000 FLOPsなのです。
Tatsunori Hashimoto: 一方、FLOPSは「Floating Point Operations Per Second」の略で、1秒あたりの浮動小数点演算数を表します。これは計算速度、つまりハードウェアの性能を示す指標です。たとえば、あるGPUが100 TFLOPSの性能を持つと言う場合、そのGPUは1秒間に100兆回の浮動小数点演算を実行できることを意味します。FLOPSは、ハードウェアがどれだけ速く計算できるかを示しています。
Tatsunori Hashimoto: この2つの関係を理解するために、シンプルな例を考えてみましょう。あるタスクが1ペタFLOPs、つまり10の15乗FLOPsを必要とするとします。これをピーク性能が100 TFLOPSのGPUで実行する場合、理想的には10,000秒、約2.8時間で完了することになります。計算式は簡単です。総FLOPs数をFLOPSで割ればよいのです。1×10の15乗FLOPs ÷ 100×10の12乗FLOPS = 10,000秒となります。
Tatsunori Hashimoto: ただし、実際にはこれほど単純ではありません。GPUのピーク性能は理論値であり、実際のアプリケーションではメモリ帯域幅、通信オーバーヘッド、カーネルの効率などの要因により、ピーク性能の一部しか達成できないことがほとんどです。実際の利用率は、タスクによって20%から80%程度の範囲で変動します。これをMFU(Model FLOPs Utilization)と呼び、後ほど詳しく説明します。
Tatsunori Hashimoto: 表記の混乱を避けることも重要です。FLOPsは複数形なので小文字のsで終わりますが、FLOPSは頭字語なのですべて大文字です。しかし、文献や論文では必ずしも統一されていないため、文脈から判断する必要があることがよくあります。一般的なルールとして、非常に大きな数字(テラ、ペタ、エクサなど)が単位なしで示されている場合はFLOPs(総演算数)を指し、ハードウェアの仕様を議論している場合はFLOPS(演算速度)を指していると考えられます。
Tatsunori Hashimoto: LLMのトレーニングコンテキストでは、両方の概念が頻繁に使われます。モデルのトレーニングに必要な総計算量を議論する際にはFLOPsを使い、使用するハードウェアの能力を議論する際にはFLOPSを使います。たとえば、「このモデルは3×10の23乗FLOPsでトレーニングされた」と言う場合、これは総作業量です。そして、「このクラスタは500 PFLOPSの計算能力を持つ」と言う場合、これはそのクラスタの1秒あたりの計算速度を指しています。
Tatsunori Hashimoto: スケールの感覚を掴むために、いくつかの具体例を見てみましょう。NVIDIA A100 GPUは、FP32精度で約19.5 TFLOPS、TF32精度で約156 TFLOPS、FP16精度で約312 TFLOPSの性能を持っています。より新しいH100では、FP16で約2000 TFLOPS、つまり2 PFLOPSに達します。一方、GPT-3のトレーニングには約314ゼタFLOPs、つまり3.14×10の23乗FLOPsが必要だったと推定されています。これらの数字を比較することで、大規模モデルのトレーニングがいかに計算集約的であるかがわかります。
Tatsunori Hashimoto: この区別を理解することは、トレーニング時間やコストを見積もる際に不可欠です。総FLOPs数がわかれば、特定のハードウェア構成でどれくらいの時間がかかるか、あるいは特定の期間内にトレーニングを完了させるにはどれだけのハードウェアが必要かを計算できます。これは、プロジェクトの計画や予算策定において極めて重要な情報となります。
3.2 計算量の測定方法
Tatsunori Hashimoto: LLMのトレーニングに必要な計算量を正確に見積もることは、プロジェクトの計画や予算策定において非常に重要です。Transformerモデルの計算量は、モデルのアーキテクチャパラメータから数学的に導出できます。具体的にどのように計算するのか見ていきましょう。
Tatsunori Hashimoto: Transformerモデルのトレーニングに必要なFLOPsは、基本的に次の公式で近似できます。トレーニング1ステップあたりのFLOPsは、およそ6×N×Dです。ここで、Nはモデルのパラメータ数、Dはトレーニングに使用するトークン数です。この係数6という数字は、forward passとbackward passの両方を考慮したものです。forward passには約2×N×D FLOPsが必要で、backward passにはその2倍、つまり4×N×D FLOPsが必要だからです。
Tatsunori Hashimoto: なぜforward passが2×N×Dなのか、もう少し詳しく見てみましょう。Transformerの各レイヤーでは、主に行列乗算が行われます。アテンション機構では、Query、Key、Valueの計算、アテンションスコアの計算、そして出力の計算があります。フィードフォワードネットワークでは、2つの大きな行列乗算が行われます。これらすべてを合計すると、各トークンの処理に約2N回の演算が必要になります。したがって、Dトークンを処理する場合、2×N×D FLOPsとなります。
Tatsunori Hashimoto: backward passが2倍の計算量を必要とする理由は、勾配計算の性質にあります。forward passで行ったすべての計算について、backward passではその勾配を計算する必要があります。さらに、重みの勾配とアクティベーションの勾配の両方を計算する必要があるため、計算量はforward passの約2倍になります。これを合計すると、2 + 4 = 6となり、6×N×Dという公式が得られます。
Tatsunori Hashimoto: 具体例で確認してみましょう。GPT-3は1750億パラメータ、つまりN = 1.75×10の11乗です。トレーニングには約3000億トークン、D = 3×10の11乗が使われたと報告されています。この公式を適用すると、6 × 1.75×10の11乗 × 3×10の11乗 = 3.15×10の23乗FLOPsとなります。これは約315ゼタFLOPsで、実際に報告されている値とほぼ一致します。
Tatsunori Hashimoto: この公式は近似であり、いくつかの簡略化を含んでいることに注意が必要です。実際には、レイヤーノーマライゼーション、ソフトマックス計算、その他の小さな演算もありますが、これらは全体の計算量に比べれば無視できる程度です。大規模モデルでは、計算量の大部分は行列乗算によるもので、それが6×N×Dという公式でよく捉えられています。
Tatsunori Hashimoto: より詳細な計算が必要な場合は、各コンポーネントを個別に考慮できます。たとえば、セルフアテンション層の計算量は、シーケンス長をLとすると、O(L²×d)です。ここでdは隠れ層の次元数です。長いシーケンスを扱う場合、このL²の項が支配的になることがあります。しかし、典型的なトレーニング設定では、シーケンス長は2048や4096といった固定値に制限されており、パラメータ数Nやトークン数Dに比べれば相対的に小さいため、全体の計算量への影響は限定的です。
Tatsunori Hashimoto: 推論時の計算量は、トレーニング時とは異なります。推論では、forward passのみを実行すればよいため、1トークン生成あたり約2×N FLOPsが必要です。ただし、自己回帰的な生成の場合、各トークンを生成するたびにKVキャッシュを更新しながら処理を進めるため、シーケンス長に応じた追加の計算が発生します。長いテキストを生成する場合、この累積的な計算量は無視できなくなります。
Tatsunori Hashimoto: バッチサイズも考慮する必要があります。バッチサイズBでトレーニングする場合、各ステップの計算量は6×N×B×Lとなります。ここでLは各サンプルのシーケンス長です。総トレーニングステップ数がSであれば、総FLOPsは6×N×B×L×Sです。これは6×N×Dと同等ですが、Dが総トークン数、つまりB×L×Sに等しいためです。
Tatsunori Hashimoto: MFU(Model FLOPs Utilization)という概念も重要です。これは、実際に達成された計算速度をハードウェアのピーク性能で割ったものです。たとえば、A100 GPUのピーク性能が312 TFLOPSで、実際のトレーニングで100 TFLOPSを達成している場合、MFUは約32%です。高効率なトレーニング実装では、MFUが50%から60%に達することもありますが、これは様々な最適化技術を駆使した結果です。
Tatsunori Hashimoto: 計算量を見積もる際には、メモリアクセスのコストも考慮する必要があります。現代のGPUでは、計算速度よりもメモリ帯域幅がボトルネックになることが多いです。特に、小さなバッチサイズや短いシーケンスでトレーニングする場合、メモリアクセスの待ち時間が支配的になり、理論的な計算速度を達成できないことがあります。これが、Flash Attentionのようなメモリ効率的なアルゴリズムが重要である理由です。
Tatsunori Hashimoto: 計算量の見積もりは、異なるモデルアーキテクチャを比較する際にも有用です。同じ計算予算で、より浅く広いモデルと、より深く狭いモデルのどちらが良いかを判断する際、FLOPs計算は客観的な比較基準を提供します。また、Mixture of Expertsのような条件付き計算を使用するアーキテクチャでは、実際に活性化するパラメータ数がサンプルによって異なるため、平均的なFLOPsを計算する必要があります。
Tatsunori Hashimoto: これらの計算方法を理解することで、新しいモデルをトレーニングする前に、必要な時間とリソースを現実的に見積もることができます。これは、研究プロジェクトの実行可能性を評価し、計算予算を効果的に配分する上で不可欠なスキルです。
3.3 トレーニングコストの見積もり
Tatsunori Hashimoto: FLOPsの計算方法を理解したところで、実際のトレーニングコストを見積もる実践的なアプローチを見ていきましょう。総FLOPs数がわかれば、特定のハードウェア構成でどれくらいの時間がかかり、どれだけの費用が必要かを計算できます。
Tatsunori Hashimoto: 基本的な計算式は非常にシンプルです。トレーニング時間は、総FLOPs数を実効FLOPS(実際に達成される1秒あたりの演算速度)で割ったものです。たとえば、3×10の23乗FLOPsのトレーニングタスクがあり、実効性能が100 TFLOPSのGPUを使用する場合、必要な時間は3×10の23乗 ÷ 100×10の12乗 = 3×10の9乗秒、つまり約34,722日、または約95年になります。これは明らかに単一GPUでは非現実的ですね。
Tatsunori Hashimoto: そこで並列化が必要になります。1000個のGPUを並列に使用すれば、理想的には時間を1000分の1に短縮できます。先ほどの例では、約34.7日になります。これは実際にGPT-3のトレーニングにかかった時間とほぼ一致します。ただし、これは理想的な並列化効率を仮定しています。実際には、通信オーバーヘッドやその他の非効率性により、完全な線形スケーリングは達成できません。
Tatsunori Hashimoto: より現実的な見積もりのために、MFU(Model FLOPs Utilization)を考慮する必要があります。MFUは通常40%から60%の範囲です。最適化されていない実装では20%程度まで下がることもあります。仮にMFUが50%だとすると、ハードウェアのピーク性能の半分しか実際には利用できていないことになります。A100 GPUのFP16ピーク性能は312 TFLOPSですが、実効性能は約156 TFLOPSになります。この実効値を使って時間を計算する必要があります。
Tatsunori Hashimoto: コストの見積もりも重要です。クラウドプロバイダーでのGPUレンタル料金は、使用するGPUの種類と地域によって異なります。たとえば、NVIDIA A100 80GB GPUは、Amazon Web Servicesでは1時間あたり約4ドルから5ドルです。Google CloudやMicrosoft Azureでも同様の価格帯です。もし1000個のA100を34日間(816時間)使用するとすれば、コストは1000 × 816 × 4.5 = 約367万ドルになります。これに、ストレージコスト、ネットワーク転送費用、その他の間接費用を加えると、総コストは400万ドルから500万ドルに達します。
Tatsunori Hashimoto: 電力コストも無視できません。A100 GPUは約400ワットの電力を消費します。1000個のGPUを34日間稼働させると、約326,400 kWh(キロワット時)の電力が必要です。商業用電力料金が1 kWhあたり0.10ドルだとすると、電力コストだけで約32,640ドルになります。データセンター全体の冷却や他のインフラを含めると、この数字はさらに増加します。実際には、電力コストは総運用コストの中でかなりの割合を占めます。
Tatsunori Hashimoto: より小規模なモデルの例も見てみましょう。70億パラメータのモデル、たとえばLlama-2 7Bクラスのモデルを考えます。Chinchilla法則に従って、パラメータ数の20倍のトークン、つまり1400億トークンでトレーニングするとします。必要なFLOPsは、6 × 7×10の9乗 × 1.4×10の11乗 = 5.88×10の21乗FLOPsです。これをA100 8台で実行する場合、実効性能を1台あたり150 TFLOPSとすると、総実効性能は1.2 PFLOPSです。トレーニング時間は5.88×10の21乗 ÷ 1.2×10の15乗 = 約4.9×10の6乗秒、つまり約56.7日になります。
Tatsunori Hashimoto: コスト面では、A100 8台を57日間(1368時間)使用すると、8 × 1368 × 4.5 = 約49,248ドルになります。これは、個人の研究者や小規模な研究グループにとっても、慎重に検討すべき投資額です。ただし、GPT-3のような超大規模モデルと比較すれば、はるかに手の届きやすい範囲です。
Tatsunori Hashimoto: チェックポイント保存のストレージコストも考慮する必要があります。70億パラメータのモデルをBF16形式で保存すると、約14GBのストレージが必要です。しかし、トレーニング中はオプティマイザの状態や勾配も保存するため、実際にはその数倍のストレージが必要になります。定期的にチェックポイントを保存し、複数のバージョンを保持する場合、数百GBから数TBのストレージが必要になることがあります。
Tatsunori Hashimoto: 実際のプロジェクトでは、予期しない問題や実験の失敗も考慮する必要があります。ハイパーパラメータが適切でなかったり、データの問題が発見されたりして、トレーニングをやり直す必要が生じることがあります。したがって、初期見積もりの1.5倍から2倍のバッファーを持つことが賢明です。もし1回のトレーニング実行に50,000ドルかかると見積もった場合、総予算としては75,000ドルから100,000ドルを確保しておくべきです。
Tatsunori Hashimoto: 計算予算を最適に配分することも重要です。限られた予算内で最良の結果を得るためには、モデルサイズとトレーニングトークン数のバランスを慎重に選ぶ必要があります。これがまさに次のセクションで議論するスケーリング法則とChinchilla法則の核心です。同じ計算予算で、より大きなモデルを少ないデータでトレーニングするか、より小さなモデルをより多くのデータでトレーニングするか、どちらが良いのでしょうか。
Tatsunori Hashimoto: 人件費も忘れてはいけません。大規模トレーニングプロジェクトには、機械学習エンジニア、データエンジニア、インフラエンジニアなど、専門的なチームが必要です。これらの人件費は、ハードウェアコストと同等かそれ以上になることがあります。包括的なコスト見積もりには、これらすべての要素を含める必要があります。
Tatsunori Hashimoto: これらの計算方法を理解することで、プロジェクトの実現可能性を現実的に評価できます。どれだけの予算があり、どのくらいの時間枠で、どの程度のモデルサイズが達成可能かを明確に把握できるのです。次のセクションでは、この計算予算をどのように最適に配分するかという、スケーリング法則について詳しく見ていきます。
4. Scaling laws, Chinchilla law(スケーリング法則)
4.1 スケーリング法則の基本概念
Tatsunori Hashimoto: スケーリング法則は、大規模言語モデルの研究において最も重要な発見の一つです。これは、モデルの性能がモデルサイズ、データ量、そして計算量とどのように関係するかを定量的に予測できることを示しています。この法則を理解することで、限られた計算予算を最適に配分し、最良の結果を得る方法を科学的に決定できます。
Tatsunori Hashimoto: スケーリング法則の基本的な洞察は、モデルの損失関数、つまり性能の指標が、いくつかの主要な変数のべき乗則(power law)に従うということです。具体的には、モデルのパラメータ数N、トレーニングデータのトークン数D、そして計算量Cが増加すると、検証損失(validation loss)が予測可能な形で減少します。この関係は、非常に広い範囲のスケールにわたって驚くほど滑らかで一貫しています。
Tatsunori Hashimoto: 初期の重要な研究は、2020年にOpenAIのチームによって発表されました。彼らは、トランスフォーマーモデルを様々なサイズとデータセットでトレーニングし、その性能を体系的に測定しました。その結果、モデルのパラメータ数を10倍にすると、損失がほぼ一定の割合で減少することを発見しました。同様に、トレーニングデータ量を増やしても、予測可能な改善が見られました。
Tatsunori Hashimoto: より具体的には、損失Lは次のような形で表現できます。L(N) = aN^(-α) + L∞です。ここで、Nはパラメータ数、aとαは実験的に決定される定数、L∞は理想的な無限大のモデルでの到達可能な最小損失です。同様に、データ量についても、L(D) = bD^(-β) + L_∞という関係があります。これらのべき指数αとβは、通常0.05から0.1程度の値を取ります。
Tatsunori Hashimoto: この法則が示す重要な点は、スケールの増加による改善は対数的ではなく、べき乗則に従うということです。つまり、一定の改善を得るためには、リソースを指数的に増やす必要があります。たとえば、損失を半分にするためには、モデルサイズを10倍から100倍にする必要があるかもしれません。これは、より良いモデルを作るには莫大なリソースが必要になることを意味しますが、同時に、その改善が予測可能であることも示しています。
Tatsunori Hashimoto: OpenAIの研究では、もう一つ重要な発見がありました。それは、モデルサイズ、データ量、計算量の3つの要素の中で、モデルサイズが最も重要であるように見えたことです。彼らの結論は、限られた計算予算がある場合、できるだけ大きなモデルをトレーニングし、データ量は相対的に少なくても構わないというものでした。この考え方は、GPT-3の設計に影響を与えました。GPT-3は1750億パラメータという非常に大きなモデルですが、トレーニングトークン数は約3000億と、パラメータ数の2倍未満でした。
Tatsunori Hashimoto: しかし、この初期の理解には重要な見落としがありました。2022年にDeepMindから発表されたChinchillaの研究は、スケーリング法則に対する理解を大きく変えました。彼らは、モデルサイズとデータ量の最適なバランスについて、より注意深い分析を行いました。その結果、従来考えられていたよりも、データ量がはるかに重要であることが明らかになったのです。
Tatsunori Hashimoto: スケーリング法則を理解する上でもう一つ重要な概念は、計算最適性(compute-optimal)です。与えられた計算予算Cに対して、最良の性能を達成するためのモデルサイズNとデータ量Dの組み合わせは何か、という問いです。計算予算Cは、C ≈ 6NDという関係で、モデルサイズとデータ量に関連しています。同じ計算予算でも、大きなモデルを短くトレーニングするか、小さなモデルを長くトレーニングするか、様々な選択肢があります。
Tatsunori Hashimoto: スケーリング法則の美しさは、その予測可能性にあります。小規模な実験から得られたデータを使って、はるかに大規模なモデルの性能を予測できるのです。これにより、実際に何百万ドルもかけて大規模モデルをトレーニングする前に、その期待される性能を見積もることができます。これは、研究開発の効率を大幅に向上させる強力なツールです。
Tatsunori Hashimoto: ただし、スケーリング法則にも限界があることを認識しておく必要があります。これらの法則は主に検証損失、つまりパープレキシティのような基本的な指標に関するものです。しかし、実際のタスク性能、特に推論能力や複雑な問題解決能力は、必ずしも損失と直線的な関係にあるわけではありません。大規模モデルでは、創発的能力(emergent abilities)と呼ばれる、小規模モデルでは見られなかった能力が突然現れることがあります。これらは、スケーリング法則だけでは完全には予測できません。
Tatsunori Hashimoto: また、スケーリング法則は主にプレトレーニング段階に適用されます。ファインチューニングやアライメント段階では、異なるダイナミクスが働く可能性があります。さらに、異なるアーキテクチャや異なるデータ分布では、スケーリングの挙動が変わる可能性もあります。それでも、スケーリング法則はLLM研究の基礎的な原理として、モデル設計と計算予算の配分に関する意思決定を導いています。
4.2 Chinchilla法則の発見と意義
Tatsunori Hashimoto: 2022年、DeepMindのチームが発表したChinchillaの論文は、スケーリング法則に対する理解を根本的に変えました。この研究は、従来の常識を覆し、大規模言語モデルの設計戦略に大きな影響を与えることになります。Chinchilla法則とは何か、そしてなぜそれがこれほど重要なのかを詳しく見ていきましょう。
Tatsunori Hashimoto: DeepMindの研究者たちは、OpenAIの初期のスケーリング法則研究には重要な見落としがあることに気づきました。それは、モデルサイズとデータ量の最適なバランスについての分析が不十分だったということです。OpenAIの研究では、主にモデルサイズを変えながら性能を測定していましたが、各モデルサイズに対して本当に最適なデータ量でトレーニングしていたわけではありませんでした。
Tatsunori Hashimoto: DeepMindのアプローチは、より体系的でした。彼らは、様々なモデルサイズと様々なトレーニングトークン数の組み合わせで、400以上の異なるモデルをトレーニングしました。モデルサイズは7000万パラメータから160億パラメータまで、トレーニングトークン数は50億から5000億まで、幅広い範囲をカバーしました。そして、同じ計算予算の下で、どの組み合わせが最良の性能を達成するかを分析したのです。
Tatsunori Hashimoto: その結果は衝撃的でした。従来、Gopherという700億パラメータのモデルを3000億トークンでトレーニングしていましたが、同じ計算予算を使って、Chinchillaという700億パラメータのモデルを1.4兆トークンでトレーニングすると、性能が大幅に向上したのです。つまり、同じパラメータ数でも、トレーニングデータを約4.7倍に増やすことで、より良いモデルが得られることがわかったのです。
Tatsunori Hashimoto: さらに興味深いのは、Chinchilla(700億パラメータ、1.4兆トークン)が、はるかに大きなGopher(2800億パラメータ、3000億トークン)よりも優れた性能を示したことです。これは、単にモデルを大きくするだけでは最適ではなく、モデルサイズとデータ量のバランスが極めて重要であることを明確に示しました。実際、Chinchillaは様々な下流タスクでGopherを上回り、計算効率の面でも優れていました。
Tatsunori Hashimoto: DeepMindの分析から導かれた重要な結論は、compute-optimal(計算最適)なトレーニングでは、モデルのパラメータ数とトレーニングトークン数をほぼ等しく増やすべきだということです。より正確には、計算予算を10倍にする場合、モデルサイズを約3.2倍に、トレーニングトークン数を約3.2倍にするのが最適だとされました。これは、両者をバランスよくスケールさせるというアプローチです。
Tatsunori Hashimoto: この発見の重要性は、既存の多くの大規模モデルが実は「アンダートレーニング」、つまりトレーニング不足の状態にあることを示唆している点です。GPT-3は1750億パラメータですが、トレーニングトークン数は約3000億、つまりパラメータ数の約1.7倍でした。Chinchilla法則によれば、同じ計算予算でもっと多くのトークンでトレーニングすべきだったことになります。これは、多くの組織が莫大なリソースを投じてトレーニングしたモデルが、実は最適な性能を発揮していなかった可能性があることを意味します。
Tatsunori Hashimoto: Chinchilla法則は、具体的には次のように表現できます。与えられた計算予算Cに対して、最適なモデルサイズNoptとデータ量Doptは、Nopt ∝ C^a、Dopt ∝ C^bという関係にあります。ここで、aとbはともに約0.5です。つまり、計算予算を増やすとき、その平方根に比例してモデルサイズとデータ量の両方を増やすべきだということです。これにより、C ≈ 6NDの関係が保たれます。
Tatsunori Hashimoto: より実用的な形で表現すると、Chinchilla法則は「パラメータ1つあたり約20トークンでトレーニングすべき」という指針を提供します。これは初期の推奨値で、後の研究では、この比率はタスクやデータの質によって多少変動することが示されていますが、重要な出発点となります。たとえば、100億パラメータのモデルなら、約2000億トークンでトレーニングすべきだということになります。
Tatsunori Hashimoto: この発見は、AI業界全体に波及効果をもたらしました。Chinchilla論文の発表後、多くの組織がモデル設計戦略を見直しました。MetaのLlamaシリーズは、この法則を意識して設計されています。Llama 2の70億パラメータモデルは2兆トークンでトレーニングされており、これはパラメータ数の約286倍で、Chinchilla法則よりもさらに多くのデータを使用しています。同様に、MistralやGemmaといった他のモデルも、より多くのトークンでトレーニングする傾向にあります。
Tatsunori Hashimoto: Chinchilla法則の意義は、単に技術的な最適化にとどまりません。それは、リソース配分の戦略的な意思決定に影響を与えます。もし計算予算が限られている場合、非常に大きなモデルを少ないデータでトレーニングするよりも、中規模のモデルを十分なデータでトレーニングする方が良いということです。これは特に、リソースが限られた研究グループや企業にとって重要な洞察です。
Tatsunori Hashimoto: また、Chinchilla法則は推論コストの観点からも重要です。同じ性能を達成できるなら、より小さなモデルの方が推論時の計算コストが低くなります。Chinchillaが示したのは、適切にトレーニングされた中規模のモデルが、不十分にトレーニングされた大規模モデルに匹敵する、あるいはそれを上回る性能を発揮できるということです。これは、デプロイメントのコストと効率性の面で大きな利点をもたらします。
Tatsunori Hashimoto: ただし、Chinchilla法則にも限界があることを認識しておく必要があります。この法則は主に、プレトレーニングの検証損失を最小化することに焦点を当てています。しかし、実際の使用シナリオでは、他の要因も考慮する必要があります。たとえば、より大きなモデルは、ファインチューニング後により良い性能を示す可能性があります。また、in-context learning(文脈内学習)の能力は、モデルサイズが大きいほど向上する傾向があります。
Tatsunori Hashimoto: さらに、Chinchilla法則は、一度きりのトレーニングを想定しています。しかし、実際には、モデルは継続的にファインチューニングされたり、複数の異なるタスクに適応されたりします。そのような場合、より大きなベースモデルを持つことには、追加の価値があるかもしれません。このため、一部の組織は、Chinchilla法則が示唆するよりも大きなモデルをトレーニングすることを選択しています。
Tatsunori Hashimoto: それでも、Chinchilla法則は、LLMトレーニングにおける最も影響力のある発見の一つです。それは、科学的な根拠に基づいてリソースを配分し、計算予算を最大限に活用する方法を示してくれます。今日、多くの最先端モデルの設計は、この法則の影響を明確に反映しています。
4.3 モデルサイズとデータ量の最適なバランス
Tatsunori Hashimoto: Chinchilla法則が示した最も重要な洞察は、モデルサイズとトレーニングデータ量を等しくスケールさせるべきだということです。しかし、この「等しく」とは具体的にどういう意味なのか、そして実際のモデル設計にどう応用すべきかを詳しく見ていきましょう。
Tatsunori Hashimoto: Chinchilla論文での分析によると、計算予算Cを2倍にする場合、モデルのパラメータ数Nとトレーニングトークン数Dの両方を約1.4倍(√2倍)にするのが最適です。これにより、C = 6NDの関係が保たれます。より一般的には、計算予算をk倍にする場合、NとDの両方を√k倍にすべきだということです。計算予算を10倍にするなら、NとDをそれぞれ約3.16倍にします。100倍なら、それぞれ10倍にします。
Tatsunori Hashimoto: この等量スケーリングの原則は、パラメータ1つあたりに必要なトークン数という形でも表現できます。Chinchilla論文の初期の推奨値は、パラメータ1つあたり約20トークンでした。つまり、100億パラメータのモデルなら2000億トークン、700億パラメータのモデルなら1.4兆トークンでトレーニングすべきだということです。この比率は、計算最適性を達成するための簡単な経験則として広く使われるようになりました。
Tatsunori Hashimoto: しかし、最近の研究では、この比率はもう少し幅があることが示されています。一部の研究では、パラメータ1つあたり50トークンから100トークン、あるいはそれ以上でトレーニングすることで、さらなる性能向上が得られることが報告されています。これは、「過剰トレーニング」と呼ばれることもありますが、特定のユースケースでは正当化される選択です。
Tatsunori Hashimoto: この比率の選択は、トレードオフを伴います。より多くのトークンでトレーニングすると、プレトレーニングのコストは増加しますが、より小さなモデルで同等の性能を達成できる可能性があります。小さなモデルは、推論時の計算コストが低く、デプロイが容易で、レイテンシも短くなります。したがって、プレトレーニングに多くのリソースを投じて、デプロイメントでのコスト削減を図るという戦略は、多くの実用シナリオで合理的です。
Tatsunori Hashimoto: 具体例で見てみましょう。Llama 2シリーズは、この原則の実践例です。Llama 2の7Bモデルは2兆トークンでトレーニングされており、これはパラメータ数の約286倍です。13Bモデルと70Bモデルも同様に2兆トークンでトレーニングされています。これは、Chinchilla法則の初期推奨値よりもはるかに多くのデータを使用しており、特に小さなモデルでは「過剰トレーニング」に相当します。しかし、これにより、7Bモデルは非常に高い性能を達成し、多くのタスクでより大きなモデルに匹敵する結果を出しています。
Tatsunori Hashimoto: Mistralも興味深い例です。Mistral 7Bは、公式には詳細が明らかにされていませんが、推定では数兆トークンでトレーニングされています。その結果、7Bという比較的小さなサイズにもかかわらず、多くのベンチマークで13Bや30Bクラスのモデルを上回る性能を示しています。これは、十分なデータでトレーニングすれば、小さなモデルでも驚くべき能力を獲得できることを示しています。
Tatsunori Hashimoto: 最適なバランスを決定する際には、いくつかの要因を考慮する必要があります。まず、プレトレーニング予算です。限られた計算予算がある場合、Chinchilla法則に従って、モデルサイズとデータ量をバランスさせることが最も効率的です。しかし、プレトレーニングに十分なリソースがあり、その後の推論コストを最小化したい場合は、より小さなモデルをより多くのデータでトレーニングする戦略が有効です。
Tatsunori Hashimoto: 次に、ターゲットとするユースケースです。in-context learningの能力が重要な場合、より大きなモデルが有利かもしれません。大きなモデルは、より多くの例を文脈内で処理し、より複雑なパターンを認識する能力が高い傾向があります。一方、特定のタスクにファインチューニングすることが前提なら、適切にプレトレーニングされた中規模モデルで十分な場合が多いです。
Tatsunori Hashimoto: データの可用性も重要な制約です。高品質なトレーニングデータには限りがあります。インターネット上の利用可能なテキストデータは膨大ですが、無限ではありません。一部の推定では、高品質な英語テキストデータは約10兆から15兆トークン程度とされています。これを超えてトレーニングを続ける場合、データの重複使用や、低品質データの混入、多言語データの追加などが必要になります。
Tatsunori Hashimoto: 実際、最新の超大規模モデルの中には、15兆トークン以上でトレーニングされているものもあります。これは、データを複数回見ることになります。興味深いことに、適切に管理すれば、データの複数エポックでのトレーニングも効果的であることが示されています。ただし、過度の繰り返しは過学習のリスクを伴います。
Tatsunori Hashimoto: もう一つの考慮事項は、トレーニング時間です。より多くのトークンでトレーニングすることは、より長い時間がかかることを意味します。時間が制約となる場合、たとえば市場投入までの時間が重要な場合、より大きなモデルを短期間でトレーニングする方が実用的かもしれません。これは、計算最適ではありませんが、ビジネス上の制約を考慮した合理的な選択です。
Tatsunori Hashimoto: 継続的な学習やモデルの更新も考慮すべき点です。もし、モデルを定期的に新しいデータで更新する予定があるなら、初期のプレトレーニングでやや控えめなデータ量でスタートし、後で追加トレーニングを行うという戦略も有効です。これは、特に急速に変化する分野や、継続的なデータ収集が可能な環境で有用です。
Tatsunori Hashimoto: 研究コミュニティでは、最適な比率についての議論が続いています。一部の研究者は、現在の主流モデルはまだ十分にトレーニングされていないと主張しています。彼らは、さらに多くのトークンでトレーニングすることで、より小さく、より効率的なモデルを作れると考えています。一方で、他の研究者は、ある点を超えると収穫逓減の法則が働き、追加のトレーニングの利益が減少すると指摘しています。
Tatsunori Hashimoto: 実践的な推奨としては、プレトレーニング予算が限られている場合は、Chinchilla法則の20トークン/パラメータという比率を出発点として使うのが良いでしょう。もし推論効率が重要で、プレトレーニングにより多くのリソースを投じられるなら、50から100トークン/パラメータ、あるいはそれ以上の比率を検討する価値があります。いずれにせよ、小規模な実験で様々な比率をテストし、自分のユースケースに最適なバランスを見つけることが重要です。
Tatsunori Hashimoto: 最後に、この最適化は静的なものではないことを理解しておくことが重要です。ハードウェアの進化、新しいアーキテクチャ、改善されたトレーニング技術などにより、最適なバランスは時間とともに変化する可能性があります。継続的に最新の研究をフォローし、自分のプロジェクトに適用することが、効率的なLLM開発の鍵となります。
4.4 実験結果と実用的な示唆
Tatsunori Hashimoto: Chinchilla法則の理論的な理解だけでなく、実際の実験結果がどのようにこの法則を裏付け、そして実用的な意思決定にどう影響するかを見ていくことが重要です。DeepMindの元の研究と、その後の業界での適用例から、多くの重要な教訓を学ぶことができます。
Tatsunori Hashimoto: Chinchilla論文での最も印象的な結果の一つは、直接的な性能比較です。Gopherという2800億パラメータのモデルは、3000億トークンでトレーニングされました。一方、Chinchillaは700億パラメータ、つまりGopherの4分の1のサイズですが、1.4兆トークン、つまりGopherの約4.7倍のデータでトレーニングされました。両者は同じ計算予算、約5.76×10の23乗FLOPsを使用しています。
Tatsunori Hashimoto: 結果は明確でした。Chinchillaは、ほぼすべての評価タスクでGopherを上回りました。MMLU(Massive Multitask Language Understanding)というベンチマークでは、Gopherが60%の精度だったのに対し、Chinchillaは67%を達成しました。BIG-bench benchmarkでも同様の傾向が見られ、Chinchillaは一貫して優れた性能を示しました。これは、同じ計算予算でも、リソースの配分方法によって大きな性能差が生まれることを明確に示しています。
Tatsunori Hashimoto: さらに興味深いのは、Chinchillaの推論効率です。700億パラメータのモデルは、2800億パラメータのモデルよりもはるかに小さいため、推論時の計算コストは約4分の1です。メモリ使用量も大幅に少なく、より少ないGPUでデプロイできます。つまり、Chinchillaはトレーニングコストは同じでありながら、より良い性能とより低い推論コストを同時に達成したのです。これは、ビジネスの観点から見ても非常に魅力的な結果です。
Tatsunori Hashimoto: この発見は、業界に波及しました。MetaのLlamaシリーズは、Chinchilla法則の影響を明確に受けています。Llama 1は、7B、13B、33B、65Bという複数のサイズで提供されましたが、すべて1兆から1.4兆トークンでトレーニングされています。特に7Bモデルは、パラメータ数の約140倍以上のトークンでトレーニングされており、Chinchilla法則の推奨を大幅に超えています。
Tatsunori Hashimoto: Llama 2では、この傾向がさらに進みました。7B、13B、70Bのすべてのモデルが2兆トークンでトレーニングされています。7Bモデルでは、これはパラメータ数の約286倍に相当します。この「過剰トレーニング」戦略の結果、Llama 2 7Bは驚くべき性能を示し、多くのベンチマークで以前の13Bや30Bクラスのモデルに匹敵する、あるいはそれを上回る結果を出しました。これは、適切にトレーニングされた小規模モデルの潜在能力を示す素晴らしい例です。
Tatsunori Hashimoto: Mistral 7Bも同様の戦略を採用しているようです。正確なトレーニング詳細は公開されていませんが、その性能から判断すると、非常に多くのトークンでトレーニングされていることは明らかです。Mistral 7Bは、多くのベンチマークでLlama 2 13Bを上回り、一部のタスクでは34Bクラスのモデルに匹敵する性能を示しています。7Bというサイズで、これほどの能力を実現できることは、Chinchilla法則の実用的な価値を強く裏付けています。
Tatsunori Hashimoto: ただし、すべての組織がこの戦略を採用しているわけではありません。一部の超大規模モデル、たとえば5000億パラメータを超えるようなモデルでは、Chinchilla法則が推奨するほど多くのトークンでトレーニングすることは現実的でない場合があります。それには、データの可用性、トレーニング時間、そしてコストの制約があります。このような場合、「アンダートレーニング」を意図的に受け入れることになります。
Tatsunori Hashimoto: 実験から得られた重要な洞察の一つは、トレーニングトークン数を増やすことの収穫逓減です。最初の数千億トークンは性能向上に大きく貢献しますが、それを超えると、追加のトークンによる改善は徐々に小さくなります。たとえば、2兆トークンから3兆トークンに増やした場合の改善は、1兆トークンから2兆トークンに増やした場合よりも小さいでしょう。この非線形性を理解することは、実際的な決定を下す上で重要です。
Tatsunori Hashimoto: データの質も重要な要因であることが明らかになっています。単純にトークン数を増やすだけでなく、高品質で多様なデータを使用することが重要です。一部の研究では、慎重にキュレーションされた1兆トークンのデータセットが、フィルタリングされていない2兆トークンのデータセットよりも良い結果をもたらすことが示されています。したがって、Chinchilla法則を適用する際には、データの量だけでなく質にも注意を払うべきです。
Tatsunori Hashimoto: 実用的な示唆として、新しいプロジェクトを開始する際の推奨アプローチを考えてみましょう。まず、利用可能な計算予算を明確にします。次に、Chinchilla法則を使って、その予算で達成可能なモデルサイズとトレーニングトークン数の組み合わせを計算します。もし推論効率が重要なら、やや小さめのモデルでより多くのトークンでトレーニングする方向に寄せます。逆に、in-context learningや少数例学習が重要なら、やや大きめのモデルを選択するかもしれません。
Tatsunori Hashimoto: 小規模な実験を行うことも推奨されます。本格的なトレーニングを始める前に、小規模なモデル、たとえば1億パラメータや10億パラメータのモデルで、異なるデータ量での性能をテストします。これにより、自分のデータセットとタスクに対して、どの程度のトレーニングが必要かの感覚を掴めます。小規模モデルでの実験結果をスケーリング法則で外挿することで、大規模モデルの性能を予測できます。
Tatsunori Hashimoto: タイムラインの制約も考慮する必要があります。もし6ヶ月以内にモデルを完成させる必要があるなら、利用可能な計算リソースから、その期間内にトレーニングできるモデルサイズとデータ量を逆算します。これにより、理論的な最適性と実際的な制約のバランスを取ることができます。
Tatsunori Hashimoto: 継続的な改善の観点も重要です。最初のモデルをChinchilla法則に従ってトレーニングした後、追加のデータが利用可能になったり、計算リソースが増えたりした場合、継続的なトレーニングを行うことができます。この場合、学習率を慎重に調整し、新しいデータと古いデータのバランスを取る必要がありますが、初期モデルをさらに改善する効果的な方法です。
Tatsunori Hashimoto: 最後に、Chinchilla法則は出発点であり、絶対的なルールではないことを理解しておくことが重要です。異なるアーキテクチャ、異なるデータ分布、異なる目的関数では、最適なバランスが変わる可能性があります。自分のユースケースに合わせて実験し、データを収集し、継続的に最適化することが、成功する LLMプロジェクトの鍵です。Chinchilla法則は、その実験を導く強力な指針を提供してくれます。
4.4 実験結果と実用的な示唆
Tatsunori Hashimoto: Chinchilla法則の理論的な理解だけでなく、実際の実験結果がどのようにこの法則を裏付け、そして実用的な意思決定にどう影響するかを見ていくことが重要です。DeepMindの元の研究と、その後の業界での適用例から、多くの重要な教訓を学ぶことができます。
Chinchilla論文での最も印象的な結果の一つは、直接的な性能比較です。Gopherという2800億パラメータのモデルは、3000億トークンでトレーニングされました。一方、Chinchillaは700億パラメータ、つまりGopherの4分の1のサイズですが、1.4兆トークン、つまりGopherの約4.7倍のデータでトレーニングされました。両者は同じ計算予算、約5.76×10の23乗FLOPsを使用しています。
結果は明確でした。Chinchillaは、ほぼすべての評価タスクでGopherを上回りました。MMLU(Massive Multitask Language Understanding)というベンチマークでは、Gopherが60%の精度だったのに対し、Chinchillaは67%を達成しました。BIG-bench benchmarkでも同様の傾向が見られ、Chinchillaは一貫して優れた性能を示しました。これは、同じ計算予算でも、リソースの配分方法によって大きな性能差が生まれることを明確に示しています。
さらに興味深いのは、Chinchillaの推論効率です。700億パラメータのモデルは、2800億パラメータのモデルよりもはるかに小さいため、推論時の計算コストは約4分の1です。メモリ使用量も大幅に少なく、より少ないGPUでデプロイできます。つまり、Chinchillaはトレーニングコストは同じでありながら、より良い性能とより低い推論コストを同時に達成したのです。これは、ビジネスの観点から見ても非常に魅力的な結果です。
この発見は、業界に波及しました。MetaのLlamaシリーズは、Chinchilla法則の影響を明確に受けています。Llama 1は、7B、13B、33B、65Bという複数のサイズで提供されましたが、すべて1兆から1.4兆トークンでトレーニングされています。特に7Bモデルは、パラメータ数の約140倍以上のトークンでトレーニングされており、Chinchilla法則の推奨を大幅に超えています。
Llama 2では、この傾向がさらに進みました。7B、13B、70Bのすべてのモデルが2兆トークンでトレーニングされています。7Bモデルでは、これはパラメータ数の約286倍に相当します。この「過剰トレーニング」戦略の結果、Llama 2 7Bは驚くべき性能を示し、多くのベンチマークで以前の13Bや30Bクラスのモデルに匹敵する、あるいはそれを上回る結果を出しました。これは、適切にトレーニングされた小規模モデルの潜在能力を示す素晴らしい例です。
Mistral 7Bも同様の戦略を採用しているようです。正確なトレーニング詳細は公開されていませんが、その性能から判断すると、非常に多くのトークンでトレーニングされていることは明らかです。Mistral 7Bは、多くのベンチマークでLlama 2 13Bを上回り、一部のタスクでは34Bクラスのモデルに匹敵する性能を示しています。7Bというサイズで、これほどの能力を実現できることは、Chinchilla法則の実用的な価値を強く裏付けています。
ただし、すべての組織がこの戦略を採用しているわけではありません。一部の超大規模モデル、たとえば5000億パラメータを超えるようなモデルでは、Chinchilla法則が推奨するほど多くのトークンでトレーニングすることは現実的でない場合があります。それには、データの可用性、トレーニング時間、そしてコストの制約があります。このような場合、「アンダートレーニング」を意図的に受け入れることになります。
実験から得られた重要な洞察の一つは、トレーニングトークン数を増やすことの収穫逓減です。最初の数千億トークンは性能向上に大きく貢献しますが、それを超えると、追加のトークンによる改善は徐々に小さくなります。たとえば、2兆トークンから3兆トークンに増やした場合の改善は、1兆トークンから2兆トークンに増やした場合よりも小さいでしょう。この非線形性を理解することは、実際的な決定を下す上で重要です。
データの質も重要な要因であることが明らかになっています。単純にトークン数を増やすだけでなく、高品質で多様なデータを使用することが重要です。一部の研究では、慎重にキュレーションされた1兆トークンのデータセットが、フィルタリングされていない2兆トークンのデータセットよりも良い結果をもたらすことが示されています。したがって、Chinchilla法則を適用する際には、データの量だけでなく質にも注意を払うべきです。
実用的な示唆として、新しいプロジェクトを開始する際の推奨アプローチを考えてみましょう。まず、利用可能な計算予算を明確にします。次に、Chinchilla法則を使って、その予算で達成可能なモデルサイズとトレーニングトークン数の組み合わせを計算します。もし推論効率が重要なら、やや小さめのモデルでより多くのトークンでトレーニングする方向に寄せます。逆に、in-context learningや少数例学習が重要なら、やや大きめのモデルを選択するかもしれません。
小規模な実験を行うことも推奨されます。本格的なトレーニングを始める前に、小規模なモデル、たとえば1億パラメータや10億パラメータのモデルで、異なるデータ量での性能をテストします。これにより、自分のデータセットとタスクに対して、どの程度のトレーニングが必要かの感覚を掴めます。小規模モデルでの実験結果をスケーリング法則で外挿することで、大規模モデルの性能を予測できます。
タイムラインの制約も考慮する必要があります。もし6ヶ月以内にモデルを完成させる必要があるなら、利用可能な計算リソースから、その期間内にトレーニングできるモデルサイズとデータ量を逆算します。これにより、理論的な最適性と実際的な制約のバランスを取ることができます。
継続的な改善の観点も重要です。最初のモデルをChinchilla法則に従ってトレーニングした後、追加のデータが利用可能になったり、計算リソースが増えたりした場合、継続的なトレーニングを行うことができます。この場合、学習率を慎重に調整し、新しいデータと古いデータのバランスを取る必要がありますが、初期モデルをさらに改善する効果的な方法です。
最後に、Chinchilla法則は出発点であり、絶対的なルールではないことを理解しておくことが重要です。異なるアーキテクチャ、異なるデータ分布、異なる目的関数では、最適なバランスが変わる可能性があります。自分のユースケースに合わせて実験し、データを収集し、継続的に最適化することが、成功する LLMプロジェクトの鍵です。Chinchilla法則は、その実験を導く強力な指針を提供してくれます。
5. Training optimizations overview(トレーニング最適化の概要)
5.1 最適化が必要な理由
Tatsunori Hashimoto: ここまで、LLMのトレーニングには膨大な計算リソースが必要であることを見てきました。しかし、単にリソースを投入するだけでは不十分です。大規模モデルを効率的にトレーニングするためには、様々な最適化技術が不可欠です。なぜこれらの最適化が必要なのか、具体的な問題点を見ていきましょう。
最も基本的な問題は、単一のGPUではほとんどの大規模モデルをトレーニングできないということです。現代の最も強力なGPUの一つであるNVIDIA A100は、80GBのメモリを搭載しています。これは一見十分に思えるかもしれませんが、実際には全く足りません。70億パラメータのモデルでさえ、単一のA100でトレーニングすることは困難です。
なぜでしょうか。まず、モデルのパラメータそのものがメモリを消費します。70億パラメータをFP32(32ビット浮動小数点数)で保存すると、約28GBのメモリが必要です。しかし、トレーニング時にはパラメータだけを保存すればよいわけではありません。勾配も保存する必要があります。これはパラメータと同じサイズなので、さらに28GB必要です。さらに、AdamやAdamWのような一般的なオプティマイザは、各パラメータに対してモーメンタムと二次モーメントの状態を保持します。これはパラメータサイズの2倍、つまり56GBです。ここまでで、28 + 28 + 56 = 112GBとなり、すでにA100の80GBを超えてしまいます。
そして、これだけではありません。トレーニング中には、アクティベーション、つまり各レイヤーの中間計算結果も保存する必要があります。これらは backward pass で勾配を計算するために必要です。バッチサイズやシーケンス長によっては、アクティベーションだけで数十GBのメモリを消費することがあります。さらに、一時的なバッファやカーネルの作業スペースなども必要です。
具体的な数字で見てみましょう。175億パラメータのモデル、つまりGPT-3クラスのモデルを考えます。FP32でトレーニングする場合、パラメータだけで700GB、勾配で700GB、オプティマイザの状態で1400GB、合計で約2800GBのメモリが必要です。これは、A100の80GBメモリの35倍に相当します。明らかに、単一デバイスでは到底扱えません。
混合精度トレーニングを使ってFP16に削減しても、メモリ使用量は半分程度にしかならず、依然として単一GPUには収まりません。70億パラメータのモデルでさえ、FP16でも約60GBから80GB以上のメモリが必要になることが多く、バッチサイズやシーケンス長によっては単一A100でも厳しい状況です。
メモリの問題だけではありません。トレーニング速度も重要な課題です。単一GPUでトレーニングできたとしても、完了までに何ヶ月も、場合によっては何年もかかるようでは実用的ではありません。たとえば、70億パラメータのモデルを2兆トークンでトレーニングする場合、必要な計算量は約8.4×10の22乗FLOPsです。A100のFP16ピーク性能が312 TFLOPSで、実効性能がその50%、つまり156 TFLOPSだとすると、必要な時間は約5.38×10の8乗秒、つまり約6227日、17年以上かかる計算になります。これは明らかに許容できません。
したがって、並列化が必須です。複数のGPUを使って計算を分散させ、トレーニング時間を短縮する必要があります。しかし、単純に複数のGPUを使えばよいというわけではありません。効率的な並列化戦略がなければ、通信オーバーヘッドやその他の非効率性により、スケーリング効率が著しく低下します。理想的には、GPUの数を2倍にすればトレーニング時間が半分になるべきですが、実際には様々な要因でこの線形スケーリングは達成困難です。
通信がボトルネックになることも大きな問題です。複数のGPU間でデータを転送する際、ネットワーク帯域幅が制約となります。特に、勾配の同期やモデルの重みの共有を行う際には、大量のデータを転送する必要があります。GPUのメモリ帯域幅は非常に高速ですが、GPU間の通信、特に異なるノード間の通信は、それよりもはるかに遅くなります。効率的な並列化戦略では、この通信コストを最小化する必要があります。
計算効率の問題もあります。GPU は理論的には非常に高いピーク性能を持っていますが、実際のワークロードでそのピーク性能を達成することは困難です。メモリアクセスパターンが非効率だったり、小さな行列演算が多かったり、計算とメモリアクセスのバランスが悪かったりすると、GPUの計算ユニットが十分に活用されず、実効性能が大幅に低下します。Model FLOPs Utilization(MFU)を向上させ、ハードウェアの潜在能力を最大限引き出すことが重要です。
コストの観点からも最適化は重要です。GPUレンタルは非常に高価で、A100 1台で1時間あたり4ドルから5ドルかかります。数百台、数千台のGPUを数週間から数ヶ月稼働させると、コストは数百万ドルに達します。トレーニング時間を半分に短縮できれば、コストも半分になります。あるいは、同じ予算でより大きなモデルや、より多くのデータでトレーニングできるようになります。効率的なトレーニングは、経済的にも極めて重要なのです。
環境への影響も無視できません。大規模モデルのトレーニングには膨大な電力が消費され、それに伴うCO2排出も問題となっています。トレーニング効率を向上させることは、環境負荷を削減することにも直結します。より少ないエネルギーで同じ結果を達成できれば、それは持続可能なAI研究に貢献します。
実験のイテレーション速度も重要な要因です。研究開発では、様々なハイパーパラメータやアーキテクチャを試す必要があります。各実験に数週間かかるようでは、十分な探索ができません。最適化技術によりトレーニング時間を短縮できれば、より多くの実験を行い、より良いモデルを発見できる可能性が高まります。
これらすべての課題に対処するために、様々な最適化技術が開発されてきました。データ並列化、モデル並列化、パイプライン並列化といった並列化戦略。ZeROのようなメモリ効率化手法。Flash Attentionのような計算効率化アルゴリズム。混合精度トレーニングや量子化のような数値精度の最適化。これらの技術を組み合わせることで、大規模モデルのトレーニングが実現可能になり、より効率的になるのです。これから、これらの技術を一つずつ詳しく見ていきます。
5.2 主要な最適化手法の分類
Tatsunori Hashimoto: LLMトレーニングの最適化技術は多岐にわたりますが、それらを体系的に理解するために、主要なカテゴリーに分類して見ていきましょう。これらの技術は、それぞれ異なる問題に対処し、多くの場合は組み合わせて使用されます。
まず、最も基本的で重要なカテゴリーは並列化手法です。これは、複数のGPUやTPUを使って計算を分散させる技術で、大きく3つのタイプに分けられます。1つ目はデータ並列化(Data Parallelism)です。これは最もシンプルな並列化戦略で、各GPUがモデルの完全なコピーを保持し、異なるデータバッチを処理します。各GPUは独立してforward passとbackward passを実行し、最後に勾配を集約してパラメータを更新します。実装が比較的簡単で、多くのフレームワークでデフォルトでサポートされています。
2つ目はモデル並列化(Model Parallelism)です。これは、モデル自体を複数のGPUに分割する手法です。単一GPUにモデル全体が収まらない場合に必須となります。モデル並列化にはさらに2つのサブタイプがあります。テンソル並列化(Tensor Parallelism)は、個々のレイヤー内の行列演算を複数のGPUに分割します。たとえば、大きな行列乗算を複数の小さな行列乗算に分解し、それぞれを異なるGPUで実行します。パイプライン並列化(Pipeline Parallelism)は、モデルのレイヤーを複数のステージに分割し、各ステージを異なるGPUに配置します。これにより、データが連続的にパイプラインを流れていくような処理が可能になります。
3つ目の並列化手法として、最近ではシーケンス並列化(Sequence Parallelism)も注目されています。これは、特に長いシーケンスを処理する際に、シーケンス次元を複数のGPUに分割する手法です。ただし、これは比較的新しい技術で、すべての状況で適用できるわけではありません。
次に重要なカテゴリーは、メモリ効率化手法です。これらは、限られたGPUメモリをより効果的に使用するための技術です。最も重要なものの一つがZeRO(Zero Redundancy Optimizer)です。ZeROは、従来のデータ並列化で発生する冗長なメモリ使用を削減します。通常のデータ並列化では、各GPUがモデルのパラメータ、勾配、オプティマイザの状態すべてのコピーを保持しますが、ZeROはこれらを複数のGPUに分割して保存することで、メモリ使用量を大幅に削減します。
Gradient checkpointing(勾配チェックポイント)も重要なメモリ効率化手法です。通常、backward passで勾配を計算するために、forward pass中のすべてのアクティベーションを保存しておく必要があります。しかし、gradient checkpointingでは、一部のアクティベーションのみを保存し、必要に応じて再計算します。これにより、メモリ使用量を大幅に削減できますが、計算時間は若干増加します。メモリと計算時間のトレードオフです。
Activation recomputation(アクティベーション再計算)もこのカテゴリーに含まれます。これはgradient checkpointingと似た概念ですが、より細かい粒度で適用されることがあります。特定のレイヤーやモジュールのアクティベーションを保存せず、必要時に再計算することで、メモリを節約します。
3つ目の主要カテゴリーは、計算効率化手法です。これらは、同じ計算をより高速に実行するための技術です。Flash Attentionはその代表例で、アテンション機構の計算をメモリ階層を考慮して最適化することで、大幅な高速化を実現します。標準的なアテンション実装では、中間行列を高帯域メモリ(HBM)に書き込みますが、Flash Attentionはこれをオンチップのより高速なSRAMで処理し、メモリアクセスを削減します。
カーネル融合(Kernel Fusion)も重要な計算効率化技術です。複数の小さな演算を一つの大きなカーネルにまとめることで、メモリアクセスのオーバーヘッドを削減します。たとえば、行列乗算の後にバイアス加算とアクティベーション関数を適用する場合、これらを別々のカーネルで実行するのではなく、一つの融合されたカーネルで実行することで効率が向上します。
4つ目のカテゴリーは、数値精度の最適化です。混合精度トレーニング(Mixed Precision Training)は、計算の一部を低精度(FP16やBF16)で実行することで、メモリ使用量を削減し、計算速度を向上させます。ただし、数値的安定性を保つために、マスターウェイトはFP32で保持し、損失スケーリングなどの技術を使用します。
量子化(Quantization)は、さらに低いビット精度(8ビット、4ビット、場合によっては2ビット)を使用する技術です。これは主に推論時に使用されますが、最近ではQLoRAのように、トレーニング時にも適用される手法が登場しています。量子化により、メモリ使用量と計算コストを大幅に削減できますが、精度の低下とのトレードオフがあります。
5つ目のカテゴリーは、通信最適化です。並列トレーニングでは、GPU間の通信が大きなボトルネックになることがあります。勾配圧縮(Gradient Compression)は、転送するデータ量を削減するために、勾配をより少ないビットで表現したり、スパース化したりする技術です。これにより、通信時間を短縮できますが、収束性能に影響を与える可能性があります。
通信とコンピュートのオーバーラップ(Overlap of Communication and Computation)も重要な最適化です。backward passで勾配を計算しながら、すでに計算された勾配を並行して通信することで、通信の待ち時間を隠蔽します。これにより、全体のトレーニング時間を短縮できます。
6つ目のカテゴリーは、ハイパーパラメータとスケジューリングの最適化です。学習率スケジューリング、バッチサイズの動的調整、ウォームアップ戦略など、トレーニングプロセス自体を最適化する技術です。これらは直接的なハードウェア効率化ではありませんが、収束速度と最終的なモデル性能に大きく影響します。
最後に、システムレベルの最適化があります。これには、効率的なデータローディング、プリフェッチング、キャッシング戦略などが含まれます。GPUが計算に集中できるよう、データの準備や転送がボトルネックにならないようにする工夫です。また、チェックポイントの効率的な保存と読み込み、障害からの自動回復機構なども、大規模で長時間のトレーニングジョブでは重要です。
これらの最適化手法は、相互に補完的です。実際の大規模トレーニングでは、複数の技術を組み合わせて使用します。たとえば、データ並列化とモデル並列化を組み合わせたハイブリッド並列化、ZeROと混合精度トレーニングの組み合わせ、Flash Attentionとgradient checkpointingの併用などが一般的です。これから、これらの主要な技術について、それぞれ詳しく見ていきます。
5.3 メモリとコンピュート効率のトレードオフ
Tatsunori Hashimoto: 最適化技術を理解する上で極めて重要なのは、メモリ効率と計算効率の間には本質的なトレードオフが存在するということです。多くの場合、一方を改善しようとすると、もう一方が犠牲になります。このトレードオフを理解し、自分のユースケースに応じて適切なバランスを選択することが、効率的なトレーニングの鍵となります。
最も典型的な例は、gradient checkpointingです。この技術は、メモリ使用量を大幅に削減できます。通常、forward pass中に生成されるすべてのアクティベーションを保存すると、大量のメモリが必要になります。特に大きなバッチサイズや長いシーケンスを扱う場合、アクティベーションだけで数十GBのメモリを消費することがあります。Gradient checkpointingでは、一部のアクティベーションのみを保存し、backward pass中に必要に応じて再計算します。
たとえば、24層のTransformerモデルで、4層ごとにチェックポイントを設定すると、保存するアクティベーションは約4分の1になります。これにより、メモリ使用量を大幅に削減できます。実際、gradient checkpointingを使用することで、同じGPUでバッチサイズを2倍から3倍に増やせることがよくあります。しかし、代償として、backward pass中にforward passの一部を再実行する必要があるため、計算時間が約20%から33%増加します。
この計算時間の増加は、実は思ったほど悪くありません。なぜなら、バッチサイズを増やせることで、GPU利用率が向上し、全体のスループットが改善する場合が多いからです。小さなバッチサイズでは、GPUの計算ユニットが十分に活用されず、効率が低下します。Gradient checkpointingによりバッチサイズを2倍にできれば、計算時間が33%増加しても、単位時間あたりに処理できるトークン数は増加する可能性があります。したがって、これは多くの場合、受け入れられるトレードオフです。
Flash Attentionは、逆の例を提供します。これは計算効率を向上させる技術ですが、実装の複雑さが増します。標準的なアテンション実装は、中間行列(アテンションスコアなど)を高帯域メモリ(HBM)に書き込みます。これは実装がシンプルで理解しやすいですが、メモリアクセスのオーバーヘッドが大きくなります。Flash Attentionは、タイリングとカーネル融合を使用して、これらの中間行列をオンチップのSRAMに保持し、HBMへのアクセスを最小化します。
この最適化により、アテンション計算は2倍から4倍高速になり、メモリ使用量も削減されます。しかし、実装は大幅に複雑になります。低レベルのCUDAカーネルを書く必要があり、異なるハードウェアやシーケンス長に対して調整が必要です。また、すべての操作が融合されているため、デバッグが困難になります。それでも、大規模トレーニングでは、この性能向上は実装の複雑さを正当化するに十分な価値があります。
モデル並列化とデータ並列化の選択も、トレードオフを伴います。データ並列化は実装がシンプルで、通信パターンも比較的単純です。各GPUは独立して計算を行い、backward passの最後に勾配を集約するだけです。しかし、モデル全体を各GPUに保持する必要があるため、メモリ使用量が多くなります。非常に大きなモデルでは、単一GPUに収まらないため、データ並列化だけでは対応できません。
モデル並列化、特にテンソル並列化は、モデルを複数のGPUに分割できるため、より大きなモデルをトレーニングできます。しかし、各演算ごとにGPU間で通信が必要になるため、通信オーバーヘッドが大きくなります。特に、GPUが異なるノードに配置されている場合、ネットワークレイテンシが性能のボトルネックになります。したがって、テンソル並列化は通常、同一ノード内の高速インターコネクト(NVLinkなど)で接続されたGPU間でのみ使用され、ノード間ではデータ並列化やパイプライン並列化を使用します。
パイプライン並列化は、テンソル並列化よりも通信頻度が少ないため、ノード間でも使用できます。各ステージは連続するレイヤーのグループを処理し、前のステージからアクティベーションを受け取り、次のステージに渡します。しかし、パイプラインバブル、つまりGPUがアイドル状態になる時間が発生します。パイプラインの最初と最後のステージでは、他のステージが処理を完了するのを待つ必要があるからです。このバブルを最小化するために、マイクロバッチングなどの技術が使われますが、完全には解消できません。
ZeROは、メモリ効率とシンプルさのバランスを取る優れた例です。ZeRO Stage 1では、オプティマイザの状態のみを分割します。これにより、メモリ使用量が削減されますが、通信オーバーヘッドは最小限です。ZeRO Stage 2では、勾配も分割され、さらなるメモリ削減が得られますが、通信量は増加します。ZeRO Stage 3では、パラメータ自体も分割され、最大のメモリ削減が得られますが、通信オーバーヘッドも最大になります。
実際には、ZeRO Stage 2が最もバランスが取れていることが多く、広く使用されています。Stage 3は、非常に大きなモデルで絶対的にメモリが不足する場合にのみ使用されます。なぜなら、forward pass中にもパラメータを通信する必要があるため、オーバーヘッドが大きいからです。このように、ZeROのステージ選択自体が、メモリと通信のトレードオフを調整する手段となっています。
混合精度トレーニングも興味深いトレードオフを提供します。FP16やBF16を使用することで、メモリ使用量が半分になり、多くのハードウェアでは計算も高速化されます。NVIDIA のTensor Coreは、低精度演算で大幅に高いスループットを提供します。しかし、数値的安定性の懸念があります。FP16の表現可能な数値範囲は限られており、非常に小さな勾配がアンダーフローする可能性があります。
これに対処するため、損失スケーリングやマスターウェイトの保持など、追加の技術が必要になります。これらは実装の複雑さを増し、わずかな計算オーバーヘッドを追加します。BF16(Brain Float 16)は、FP16よりも広い指数範囲を持つため、損失スケーリングなしで使用できることが多く、実装がシンプルになります。しかし、仮数部の精度はFP16より低いため、一部のタスクでは精度に影響が出る可能性があります。
バッチサイズの選択も、重要なトレードオフです。大きなバッチサイズは、GPU利用率を向上させ、並列化効率を高めます。また、勾配の推定がより安定し、収束が安定することがあります。しかし、大きなバッチサイズはメモリを多く消費し、汎化性能に悪影響を与える可能性があります。さらに、非常に大きなバッチサイズでは、学習率などのハイパーパラメータを調整する必要があります。
実践的には、利用可能なメモリを最大限活用できるバッチサイズを選び、必要に応じてgradient accumulationを使用してeffective batch sizeを増やすアプローチが一般的です。Gradient accumulationでは、複数のミニバッチで勾配を累積してから重みを更新します。これにより、メモリを増やすことなく、大きなeffective batch sizeを実現できますが、更新の頻度が減るため、トレーニング時間は長くなります。
通信圧縮も明確なトレードオフがあります。勾配を圧縮して転送することで、通信時間を短縮できます。しかし、圧縮と解凍には計算コストがかかり、さらに圧縮による情報損失が収束速度や最終性能に影響する可能性があります。高い圧縮率を使用すると、通信時間は大幅に削減されますが、収束に必要なイテレーション数が増加する可能性があります。
これらすべてのトレードオフを考慮して、最適な設定を見つけることが、効率的なトレーニングの鍵です。一般的には、まずメモリ制約を満たすための最適化を適用し、その上で計算効率を最大化する戦略が有効です。たとえば、まずZeROやgradient checkpointingでメモリ使用量を管理し、その後Flash Attentionや混合精度トレーニングで速度を向上させます。小規模な実験で様々な組み合わせをテストし、自分のハードウェア構成とモデルサイズに最適なバランスを見つけることが推奨されます。
6. Data parallelism with ZeRO
6.1 データ並列化の基本原理
Tatsunori Hashimoto: データ並列化は、大規模モデルのトレーニングで最も基本的で広く使われている並列化戦略です。その概念はシンプルですが、効果的です。まずは、標準的なデータ並列化がどのように機能するかを理解しましょう。
データ並列化の基本的なアイデアは、複数のGPUが同じモデルのコピーを持ち、異なるデータバッチを処理するというものです。たとえば、4つのGPUでトレーニングする場合、各GPUはモデル全体の完全なコピーを保持します。トレーニングの各ステップで、ミニバッチ全体を4つのサブバッチに分割し、各GPUが1つずつ処理します。
具体的なプロセスを見てみましょう。まず、各GPUは自分に割り当てられたデータバッチでforward passを実行します。このとき、各GPUは独立して計算を行い、他のGPUとの通信は必要ありません。たとえば、GPU 0はバッチ0-31のサンプルを処理し、GPU 1はバッチ32-63を処理し、という具合です。各GPUは自分のデータに対して損失を計算します。
次に、backward passが実行されます。これも各GPUで独立して行われます。各GPUは、自分が処理したデータに基づいて勾配を計算します。ここまでは、GPUは完全に独立して動作しており、通信は発生していません。これが、データ並列化が実装しやすく、スケーラビリティが高い理由です。
しかし、backward passが完了した後、重要なステップが必要になります。それが勾配の集約です。各GPUは異なるデータバッチで計算した勾配を持っていますが、モデルを一貫して更新するためには、これらの勾配を平均する必要があります。ここでAll-reduceという集団通信操作が使われます。
All-reduceは、すべてのGPUが持つ勾配を合計し、その結果をすべてのGPUに配布する操作です。たとえば、4つのGPUがそれぞれ勾配g1、g2、g3、g4を計算した場合、All-reduceの後、すべてのGPUは(g1 + g2 + g3 + g4) / 4という平均勾配を持つことになります。この平均勾配を使って、各GPUは自分のモデルのコピーを更新します。
All-reduceの効率は、ネットワークトポロジーとアルゴリズムに依存します。最も単純な実装では、1つのGPUがすべての勾配を収集し、平均を計算して、結果を他のGPUに配布します。しかし、これは非効率です。より効率的なアルゴリズムとして、Ring All-reduceがあります。これは、GPUをリング状に配置し、各GPUが隣接するGPUとのみ通信する方法です。N個のGPUがある場合、Ring All-reduceは2(N-1)のステップで完了し、各ステップで転送されるデータ量は全体の1/Nです。
もう一つの効率的なアルゴリズムは、Tree All-reduceです。これは、GPUを木構造で配置し、階層的に勾配を集約します。ネットワークトポロジーによっては、Ring All-reduceよりも高速な場合があります。NCCLという NVIDIA の通信ライブラリは、ハードウェアに応じて最適なAll-reduceアルゴリズムを自動的に選択します。
データ並列化の大きな利点は、実装のシンプルさです。PyTorchでは、DistributedDataParallelやDataParallelというモジュールを使うだけで、自動的にデータ並列化が実現されます。モデルコード自体をほとんど変更する必要がありません。また、スケーリング効率も比較的高く、通信はbackward pass後の一度だけなので、計算と通信の比率が良好です。
もう一つの利点は、デバッグのしやすさです。各GPUは同じコードを実行し、同じモデルを持っているため、動作を理解しやすく、問題を特定しやすいです。また、単一GPUでのトレーニングと挙動が似ているため、小規模で検証してから大規模にスケールすることが容易です。
データ並列化は、特に同一ノード内の複数GPUでは非常に効率的です。現代のサーバーは、通常8個のGPUを搭載しており、これらはNVLinkという高速インターコネクトで接続されています。NVLinkの帯域幅は非常に高く、最大600 GB/s程度に達します。このような環境では、All-reduceのオーバーヘッドは最小限で、ほぼ線形のスケーリングが可能です。
ノード間でのデータ並列化も可能ですが、効率は低下します。ノード間の通信はInfiniBandやEthernetなどのネットワークを経由するため、帯域幅が低く、レイテンシも高くなります。それでも、勾配通信は一度だけなので、計算が十分に重い場合は、許容できるオーバーヘッドです。実際、数百から数千のGPUでのデータ並列化も珍しくありません。
しかし、標準的なデータ並列化には大きな欠点があります。それは、メモリの冗長性です。各GPUがモデルの完全なコピー、勾配の完全なコピー、そしてオプティマイザの状態の完全なコピーを保持する必要があります。たとえば、8つのGPUでトレーニングする場合、モデルのパラメータ、勾配、オプティマイザの状態が8回重複して保存されます。これは非常に無駄です。
この冗長性の問題を解決するために開発されたのがZeROです。ZeROは、データ並列化の基本的な構造を保ちながら、メモリの冗長性を削減する革新的なアプローチです。標準的なデータ並列化では、各GPUが同じ情報を保持していますが、ZeROはこの情報を複数のGPUに分散させ、必要に応じて通信することで、メモリ使用量を大幅に削減します。これについて、次のセクションで詳しく見ていきましょう。
6.2 ZeRO(Zero Redundancy Optimizer)の仕組み
Tatsunori Hashimoto: ZeRO、正式にはZero Redundancy Optimizerは、Microsoftの研究者によって開発された画期的な最適化技術です。ZeROの核心的なアイデアは、標準的なデータ並列化における冗長性を排除することで、メモリ効率を劇的に向上させることです。その仕組みを詳しく見ていきましょう。
まず、トレーニング中のメモリ使用量を理解する必要があります。大規模モデルをトレーニングする際、メモリは主に3つのカテゴリーに分類されます。1つ目はモデルの状態(Model States)です。これには、パラメータ自体、勾配、そしてオプティマイザの状態が含まれます。2つ目はアクティベーション(Activations)で、forward pass中の各レイヤーの出力を保存したものです。3つ目は一時的なバッファやワークスペースです。
モデルの状態の内訳をより詳しく見てみましょう。まずパラメータがあります。70億パラメータのモデルをFP16で保存すると、約14GBです。次に勾配があります。これもパラメータと同じサイズなので、14GBです。そして、オプティマイザの状態があります。AdamやAdamWのような一般的なオプティマイザは、各パラメータに対してモーメンタムと二次モーメントを保持します。これらはFP32で保存されることが多く、パラメータサイズの2倍、つまり28GBになります。合計すると、14 + 14 + 28 = 56GBです。
標準的なデータ並列化では、この56GBが各GPUで完全に重複しています。8つのGPUでトレーニングする場合、実質的に56 × 8 = 448GBのメモリが使用されますが、その大部分は冗長です。ZeROは、この冗長性に着目しました。
ZeROの基本的な洞察は次の通りです。データ並列化では、各GPUは異なるデータバッチを処理していますが、最終的にはすべてのGPUが同じ勾配(の平均)と同じパラメータを持ちます。それならば、なぜすべてのGPUがこれらすべてを保存する必要があるのでしょうか。代わりに、これらを複数のGPUに分割して保存し、必要なときだけ通信すれば、メモリを大幅に節約できるはずです。
ZeROは、この分割を段階的に適用する3つのステージを提供します。各ステージは、異なるレベルのメモリ削減と通信オーバーヘッドのトレードオフを提供します。
まず、ZeRO Stage 1について説明しましょう。これは最も保守的なアプローチで、オプティマイザの状態のみを分割します。各GPUは、全パラメータの一部に対するオプティマイザの状態のみを保持します。たとえば、8つのGPUがある場合、各GPUはオプティマイザの状態の8分の1を保持します。パラメータと勾配は、依然として各GPUで完全に保持されます。
Stage 1の動作を具体的に見てみましょう。Backward passの後、All-reduceで勾配を集約します。この時点で、すべてのGPUは平均勾配を持っています。次に、パラメータの更新段階で、各GPUは自分が担当するパラメータのサブセットのみを更新します。たとえば、GPU 0はパラメータ0から12.5%を更新し、GPU 1は12.5%から25%を更新し、といった具合です。更新後、各GPUは自分が更新したパラメータを他のGPUにブロードキャストし、すべてのGPUが最新のパラメータを持つようにします。
Stage 1のメモリ削減効果を計算してみましょう。オプティマイザの状態は通常、パラメータサイズの2倍です。70億パラメータのモデルでは28GBです。8つのGPUで分割すると、各GPUは28 / 8 = 3.5GBのオプティマイザの状態のみを保持します。パラメータ(14GB)と勾配(14GB)は依然として各GPUで保持されるため、各GPUのモデル状態は14 + 14 + 3.5 = 31.5GBとなり、元の56GBから約44%削減されます。
追加の通信コストは比較的小さいです。パラメータ更新後のAll-gather操作(更新されたパラメータを収集して配布する操作)が必要ですが、これは backward pass後のAll-reduceと同じ通信量です。したがって、通信コストは約2倍になりますが、これは多くの場合、許容可能です。
次に、ZeRO Stage 2を見てみましょう。Stage 2では、オプティマイザの状態に加えて、勾配も分割します。各GPUは、全パラメータの一部に対する勾配のみを保持します。これにより、さらなるメモリ削減が可能になります。
Stage 2の動作は少し複雑です。Backward passでは、各GPUは最初に完全な勾配を計算します。しかし、All-reduceの代わりに、Reduce-scatterという操作を使用します。Reduce-scatterは、すべての勾配を合計しますが、結果を全GPUに配布するのではなく、各GPUに異なる部分を配布します。たとえば、GPU 0は勾配の最初の8分の1を受け取り、GPU 1は次の8分の1を受け取り、といった具合です。
各GPUは、自分が受け取った勾配を使って、対応するパラメータのサブセットを更新します。更新後、各GPUは自分が更新したパラメータを他のGPUにブロードキャストします。これにより、すべてのGPUが最新の完全なパラメータセットを持つようになります。
Stage 2のメモリ削減は更に大きくなります。勾配も8分の1になるため、各GPUのモデル状態は14(パラメータ)+ 1.75(勾配の8分の1)+ 3.5(オプティマイザの状態の8分の1)= 19.25GBとなり、元の56GBから約66%削減されます。
通信パターンはStage 1よりもわずかに複雑ですが、総通信量は同じです。Reduce-scatterとAll-gatherの組み合わせは、All-reduceと同等の通信量です。したがって、Stage 2は、Stage 1と同じ通信コストで、より大きなメモリ削減を実現します。これが、Stage 2が実践で最も広く使用される理由です。
最後に、ZeRO Stage 3を見てみましょう。これは最も積極的なアプローチで、パラメータ自体も分割します。各GPUは、全パラメータの一部のみを保持します。これにより、最大のメモリ削減が可能になりますが、通信オーバーヘッドも最大になります。
Stage 3では、forward pass中にもパラメータの通信が必要になります。各レイヤーを計算する前に、そのレイヤーのパラメータをすべてのGPUから収集する必要があります。計算後、収集したパラメータを破棄します。Backward passでも同様で、各レイヤーの勾配を計算する前に、パラメータを収集し、計算後に破棄します。
Stage 3のメモリ削減は最大です。各GPUのモデル状態は、1.75(パラメータの8分の1)+ 1.75(勾配の8分の1)+ 3.5(オプティマイザの状態の8分の1)= 7GBとなり、元の56GBから約87.5%削減されます。これは驚異的な削減率です。
しかし、通信コストは大幅に増加します。Forward passとbackward passの両方で、各レイヤーごとにパラメータのAll-gather操作が必要になります。レイヤー数が多い場合、これは大きなオーバーヘッドになります。特に、ノード間の通信が遅い環境では、Stage 3のオーバーヘッドは無視できません。
実践では、Stage 3は、モデルが非常に大きくて、Stage 2でも単一GPUに収まらない場合にのみ使用されます。たとえば、1750億パラメータのGPT-3クラスのモデルや、それ以上の規模のモデルでは、Stage 3が必要になることがあります。中規模のモデル、たとえば70億から130億パラメータ程度では、Stage 2で十分なことが多いです。
ZeROの美しさは、その柔軟性にあります。ユースケースに応じて、適切なステージを選択できます。メモリに余裕があり、通信を最小化したい場合はStage 1を使用し、バランスを取りたい場合はStage 2を、極限のメモリ効率が必要な場合はStage 3を使用します。さらに、ZeROは他の最適化技術、たとえば混合精度トレーニングやgradient checkpointingと組み合わせることができ、相乗効果を生み出します。
6.3 ZeROのステージ(Stage 1, 2, 3)
Tatsunori Hashimoto: 前のセクションでZeROの基本的な仕組みを説明しましたが、ここでは3つのステージをより詳細に比較し、それぞれの適用場面と実装上の考慮事項を見ていきましょう。各ステージは、メモリ削減、通信オーバーヘッド、実装の複雑さの間で異なるトレードオフを提供します。
ZeRO Stage 1は、最も保守的で、導入しやすいオプションです。オプティマイザの状態のみを分割するため、標準的なデータ並列化からの移行が比較的スムーズです。具体的な数字で見てみましょう。N個のGPUでトレーニングする場合、各GPUのメモリ使用量は次のようになります。パラメータはΨバイト、勾配もΨバイト、オプティマイザの状態は通常2Ψバイトです(AdamやAdamWの場合)。Stage 1では、オプティマイザの状態が2Ψ/Nに削減されるため、総メモリは Ψ + Ψ + 2Ψ/N です。
8つのGPUを使用する場合、Stage 1のメモリ削減率を計算できます。元のメモリ使用量は4Ψ(パラメータ、勾配、オプティマイザの状態の合計)です。Stage 1では、Ψ + Ψ + 2Ψ/8 = 2.25Ψとなり、約44%の削減です。16個のGPUでは、2Ψ/16 = 0.125Ψとなり、2.125Ψで約47%の削減です。GPU数が増えるほど、削減率は向上しますが、オプティマイザの状態のみを分割しているため、改善には限界があります。
Stage 1の通信パターンは比較的シンプルです。Backward pass後、標準的なAll-reduceで勾配を集約します。次に、各GPUが自分の担当するパラメータを更新します。更新後、All-gather操作で更新されたパラメータを全GPUに配布します。この追加のAll-gatherは、All-reduceと同じ通信量です。したがって、通信コストは標準的なデータ並列化の約2倍ですが、これは多くの場合、計算時間に比べて小さなオーバーヘッドです。
Stage 1は、メモリが比較的余裕がある場合や、ノード間の通信が遅い環境で特に有用です。たとえば、70億パラメータのモデルを8つのA100(各80GB)でトレーニングする場合、Stage 1だけで十分にメモリに収まることがあります。通信オーバーヘッドが最小なため、トレーニング速度への影響も小さく抑えられます。
ZeRO Stage 2は、実践で最も広く使用されているステージです。オプティマイザの状態に加えて勾配も分割することで、大幅なメモリ削減を実現しながら、通信オーバーヘッドを許容範囲に抑えます。Stage 2での各GPUのメモリ使用量は、Ψ + Ψ/N + 2Ψ/N = Ψ + 3Ψ/N です。
8つのGPUでは、Ψ + 3Ψ/8 = 1.375Ψとなり、元の4Ψから約66%の削減です。16個のGPUでは、Ψ + 3Ψ/16 = 1.1875Ψで、約70%の削減です。Stage 2では、GPU数が増えるほど、削減率が大きく向上します。これは、勾配とオプティマイザの状態の両方がGPU数に比例して削減されるためです。
Stage 2の通信パターンは、Stage 1よりもわずかに複雑ですが、効率的です。Backward pass後、All-reduceの代わりにReduce-scatter操作を使用します。Reduce-scatterは、すべてのGPUの勾配を合計し、各GPUに異なる部分を配布します。これにより、各GPUは自分が担当するパラメータの勾配のみを受け取ります。次に、各GPUはこれらのパラメータを更新し、All-gather操作で更新されたパラメータを全GPUに配布します。
重要なのは、Reduce-scatterとAll-gatherの組み合わせは、All-reduceと同じ通信量だということです。数学的に、Reduce-scatter + All-gatherは、All-reduceと等価です。したがって、Stage 2は、Stage 1と同じ総通信量で、より大きなメモリ削減を実現します。これが、Stage 2が「スイートスポット」と見なされる理由です。
Stage 2は、中規模から大規模のモデル、たとえば70億から700億パラメータの範囲で特に効果的です。この範囲のモデルでは、Stage 2により、より少ないGPUでトレーニングできるようになり、コストを大幅に削減できます。また、同じGPU数でも、より大きなバッチサイズやより長いシーケンス長を使用できるようになります。
ZeRO Stage 3は、最も積極的なメモリ最適化を提供します。パラメータ、勾配、オプティマイザの状態のすべてを分割します。各GPUのメモリ使用量は、Ψ/N + Ψ/N + 2Ψ/N = 4Ψ/N です。これは、元のメモリ使用量4ΨをN個のGPUで均等に分割したことになります。
8つのGPUでは、4Ψ/8 = 0.5Ψとなり、元の4Ψから87.5%の削減です。16個のGPUでは、4Ψ/16 = 0.25Ψで、93.75%の削減です。32個のGPUでは、4Ψ/32 = 0.125Ψで、96.875%の削減です。Stage 3のメモリ削減は、GPU数に完全に線形にスケールします。これにより、理論的には、十分な数のGPUがあれば、どれほど大きなモデルでもトレーニングできます。
しかし、Stage 3には大きなトレードオフがあります。通信オーバーヘッドが大幅に増加するのです。Forward pass中、各レイヤーを計算する前に、そのレイヤーのパラメータをAll-gather操作で収集する必要があります。計算が完了したら、メモリを解放するために、収集したパラメータを破棄します。次のレイヤーでも同じプロセスを繰り返します。
Backward passでも同様のことが起こります。各レイヤーの勾配を計算する前に、そのレイヤーのパラメータをAll-gatherで収集します。勾配計算後、パラメータを破棄し、計算された勾配の自分の担当部分のみを保持します。モデルにL個のレイヤーがある場合、forward passとbackward passの両方でL回のAll-gather操作が必要になります。合計で2L回です。
この頻繁な通信は、特にノード間の通信が遅い環境では、大きなボトルネックになります。同一ノード内のNVLink接続されたGPUでは、Stage 3でも許容可能な性能を達成できることがありますが、InfiniBandなどのノード間ネットワークでは、オーバーヘッドが無視できなくなります。実際のベンチマークでは、Stage 3は、Stage 2に比べて10%から30%程度遅くなることがあります。
Stage 3が真に必要なのは、モデルが非常に大きくて、Stage 2でも単一GPUに収まらない場合です。たとえば、1750億パラメータのモデルをFP16でトレーニングする場合、Stage 2でも各GPUに約200GB以上のメモリが必要になることがあります。これは、A100の80GBを大幅に超えています。このような場合、Stage 3を使用することで、より多くのGPUにモデルを分散させ、各GPUのメモリ要件を削減できます。
実装上の考慮事項もあります。Stage 1とStage 2は、DeepSpeedやFairScaleなどのライブラリで比較的簡単に使用できます。コードの変更は最小限で、通常は数行の設定変更で済みます。Stage 3は、やや複雑になることがあります。特に、カスタムモデルアーキテクチャや特殊な演算を使用している場合、パラメータの分割とAll-gatherのタイミングを慎重に管理する必要があります。
ステージの選択は、具体的な状況に依存します。経験則として、メモリに余裕がある場合や通信を最小化したい場合はStage 1、バランスの取れた選択肢としてはStage 2、極限のメモリ効率が必要で通信オーバーヘッドを許容できる場合はStage 3を使用します。多くのプロジェクトでは、まずStage 2から始め、メモリが不足する場合にのみStage 3に移行するアプローチが推奨されます。
また、ZeROのステージは、gradient checkpointingや混合精度トレーニングなどの他の最適化技術と組み合わせることができます。たとえば、Stage 2とgradient checkpointingを組み合わせることで、モデル状態とアクティベーションの両方のメモリを削減し、より大きなモデルやバッチサイズを実現できます。このような組み合わせにより、限られたハードウェアリソースでも、驚くほど大規模なモデルをトレーニングできるようになります。
6.4 メモリ削減効果と実装上の工夫
Tatsunori Hashimoto: ZeROの理論的なメモリ削減効果を理解したところで、実際のシナリオでの具体的な効果と、実装する際の重要な工夫について見ていきましょう。理論と実践の間には、しばしばギャップがあり、そのギャップを埋めるための技術が重要です。
まず、具体的な数値例でメモリ削減効果を確認しましょう。70億パラメータのモデルを8つのA100 GPUでトレーニングするケースを考えます。混合精度トレーニング(FP16計算、FP32マスターウェイト)を使用するとします。標準的なデータ並列化では、各GPUは次のメモリを必要とします。パラメータのFP16コピーが14GB、勾配が14GB、FP32マスターウェイトが28GB、オプティマイザのモーメンタムと分散が28GB、合計で84GBです。これは、A100の80GBを超えています。
ZeRO Stage 2を適用すると、劇的に変わります。パラメータのFP16コピーは依然として14GBですが、勾配は14GB / 8 = 1.75GBに、FP32マスターウェイトは28GB / 8 = 3.5GBに、オプティマイザの状態も28GB / 8 = 3.5GBに削減されます。合計は14 + 1.75 + 3.5 + 3.5 = 22.75GBとなり、約73%の削減です。これにより、バッチサイズを大幅に増やしたり、より長いシーケンスを処理したりできるようになります。
さらに大規模なモデルでの効果も見てみましょう。175億パラメータのモデルを64個のA100でトレーニングする場合を考えます。標準的なデータ並列化では、各GPUに約210GBのメモリが必要で、これは到底不可能です。ZeRO Stage 2では、約52GBに削減され、A100の80GBに収まるようになります。ZeRO Stage 3を使用すれば、さらに約13GBまで削減でき、より大きなバッチサイズや、gradient checkpointingなしでのトレーニングが可能になります。
実装上の重要な工夫の一つは、通信とコンピュートのオーバーラップです。ZeRO Stage 2では、backward pass中に勾配を計算しながら、すでに計算された勾配をReduce-scatter操作で通信できます。DeepSpeedの実装では、この自動オーバーラップが組み込まれています。各レイヤーの勾配が計算されると、すぐにバックグラウンドで通信が開始され、次のレイヤーの勾配計算と並行して実行されます。
この技術は、backward passがレイヤーを逆順に処理するという性質を利用しています。最後のレイヤーの勾配が最初に計算されるため、その通信を開始しながら、前のレイヤーの勾配計算を続けられます。理想的には、通信時間が計算時間内に完全に隠蔽され、実質的なオーバーヘッドがゼロになります。実際には、完全なオーバーラップは難しいですが、かなりの部分を隠蔽できます。
もう一つの重要な工夫は、パーティショニング戦略です。モデルのパラメータを複数のGPUに分割する際、単純に連続するパラメータを均等に分割するだけでなく、より洗練された戦略を使用できます。たとえば、各レイヤーのパラメータを複数のGPUに分散させることで、負荷のバランスを改善できます。一部のレイヤーは他のレイヤーよりも大きいため、単純な連続分割では、一部のGPUに負荷が偏る可能性があります。
メモリの断片化も実践的な問題です。ZeRO Stage 3では、forward pass中にパラメータを動的にAll-gatherし、計算後に解放します。この頻繁な割り当てと解放は、メモリの断片化を引き起こす可能性があります。DeepSpeedは、この問題に対処するために、事前に割り当てられたバッファプールを使用します。パラメータのAll-gather用のバッファを事前に確保し、再利用することで、断片化を最小化し、割り当てのオーバーヘッドも削減します。
バケッティング(bucketing)という技術も重要です。多数の小さなパラメータテンソルを個別に通信するのは非効率です。代わりに、複数の小さなテンソルを一つのバケットにまとめて、一度の通信操作で転送します。これにより、通信の起動オーバーヘッドを削減し、ネットワーク帯域幅をより効率的に利用できます。DeepSpeedでは、デフォルトで500MBのバケットサイズが使用されますが、これは設定で調整できます。
ZeRO-Offloadという拡張機能も注目に値します。これは、オプティマイザの状態をGPUメモリではなくCPUメモリに保存する技術です。更新ステップ中にのみ、必要なデータをGPUに転送します。これにより、GPUメモリをさらに節約できますが、CPU-GPU間の転送オーバーヘッドが発生します。ZeRO-Offloadは、メモリが極端に不足している場合や、CPUメモリが豊富にある環境で有用です。
ZeRO-Infinityは、さらに進んだ概念で、NVMe SSDなどの二次記憶装置にもデータをオフロードします。これにより、単一のマシンのメモリ容量をはるかに超える巨大なモデルをトレーニングできます。ただし、ディスクI/Oは非常に遅いため、慎重なデータ管理とプリフェッチ戦略が必要です。ZeRO-Infinityは、主に研究目的や、極限的に大きなモデルの実験に使用されます。
実装の観点から、DeepSpeedは最も成熟したZeRO実装を提供しています。PyTorchとの統合が良好で、比較的簡単に既存のコードに組み込めます。基本的な使用方法は、モデルとオプティマイザをDeepSpeedEngineでラップし、設定ファイルでZeROのステージを指定するだけです。たとえば、設定ファイルで"zero_optimization": {"stage": 2}と指定すれば、ZeRO Stage 2が有効になります。
FairScaleもZeROの実装を提供しており、特にFully Sharded Data Parallel(FSDP)という名前で知られています。FSDPは、ZeRO Stage 3に相当する機能を提供し、PyTorchコミュニティで広く使用されています。PyTorch 1.11以降では、FSDPが公式にPyTorchに統合されており、torch.distributed.fsdpモジュールからアクセスできます。
実際のプロジェクトでZeROを使用する際の推奨アプローチは、段階的に導入することです。まず、小規模なモデルやデータセットで、ZeRO Stage 1または2をテストし、正しく動作することを確認します。次に、メモリ使用量とトレーニング速度を測定し、標準的なデータ並列化と比較します。メモリ削減効果が期待通りであることを確認したら、本格的なトレーニングに移行します。
デバッグも重要な考慮事項です。ZeROを使用すると、パラメータや勾配が複数のGPUに分散されるため、デバッグが複雑になることがあります。DeepSpeedは、ZeROを無効にして標準的なデータ並列化に戻すオプションを提供しており、問題の切り分けに役立ちます。また、小規模な設定(少数のGPU、小さなモデル)でまず検証し、段階的にスケールアップすることが推奨されます。
チェックポイントの保存と読み込みも、ZeROでは特別な処理が必要です。Stage 2や3では、各GPUがパラメータの一部のみを保持しているため、チェックポイントを保存する際には、すべてのGPUからパラメータを収集する必要があります。DeepSpeedは、この処理を自動化する機能を提供しており、通常のモデルと同じようにチェックポイントを扱えます。ただし、大規模モデルでは、チェックポイントの保存に時間がかかる場合があるため、定期的な保存の頻度を調整する必要があります。
最後に、ZeROはツールの一つであり、万能ではないことを理解しておくことが重要です。中規模のモデルで十分なGPUメモリがある場合、標準的なデータ並列化の方がシンプルで、わずかに高速かもしれません。しかし、メモリが制約となる場合、ZeROは非常に強力なソリューションです。適切なステージを選択し、他の最適化技術と組み合わせることで、限られたリソースで驚くほど大規模なモデルをトレーニングできるようになります。
7. Model parallelism(モデル並列化)
7.1 モデル並列化の必要性
Tatsunori Hashimoto: データ並列化とZeROは非常に強力ですが、それでも解決できない問題があります。それは、モデルが単一のGPUに物理的に収まらない場合です。ZeRO Stage 3でさえ、各GPUはforward pass中に少なくとも1つのレイヤーの完全なパラメータをメモリに保持する必要があります。モデルが極端に大きくなると、これが問題になります。ここでモデル並列化が必要になります。
モデル並列化の基本的な動機を理解するために、具体的なシナリオを考えてみましょう。5000億パラメータのモデルをトレーニングしたいとします。このモデルをFP16で保存するだけで1TB(テラバイト)のメモリが必要です。さらに、アクティベーション、勾配、オプティマイザの状態を含めると、数TBに達します。NVIDIA A100の最大メモリは80GBですから、どれほどZeROを活用しても、このモデルを単一のGPUでは扱えません。
実際には、個々のレイヤーやコンポーネントが単一GPUのメモリを超えることもあります。たとえば、非常に大きな語彙サイズを持つ埋め込み層や、非常に大きな隠れ層次元を持つフィードフォワードネットワークなどです。GPT-3の最終レイヤーには、12,288次元から50,257語彙への射影があり、これだけでFP32で約2.5GBのメモリを消費します。より大きなモデルでは、単一のレイヤーが数十GBになることもあります。
モデル並列化は、モデル自体を複数のデバイスに分割することで、この問題を解決します。重要なのは、これはデータ並列化とは異なるアプローチだということです。データ並列化では、各デバイスがモデル全体のコピーを持ち、異なるデータを処理します。モデル並列化では、各デバイスがモデルの一部のみを保持し、同じデータを協調して処理します。
モデル並列化には、主に2つのタイプがあります。テンソル並列化(Tensor Parallelism)とパイプライン並列化(Pipeline Parallelism)です。これらは、モデルを分割する「方向」が異なります。テンソル並列化は、個々のレイヤー内の演算を複数のデバイスに分割します。パイプライン並列化は、レイヤー全体を複数のステージに分割し、各ステージを異なるデバイスに配置します。
テンソル並列化の必要性を理解するために、Transformerの大きなレイヤーを考えましょう。フィードフォワードネットワークでは、通常、隠れ層次元dから4dへの拡大、そして4dからdへの縮小という2つの大きな行列乗算があります。dが12,288の場合、最初の重み行列は12,288 × 49,152で、FP16で約1.2GBです。2番目の重み行列は49,152 × 12,288で、同じく約1.2GBです。単一のフィードフォワード層だけで2.4GBです。96層のモデルでは、フィードフォワード層だけで230GB以上になります。
この大きな行列演算を複数のGPUに分割できれば、各GPUのメモリ負担を軽減できます。たとえば、4つのGPUで分割すれば、各GPUは約0.3GBずつを保持すればよくなります。これがテンソル並列化の基本的なアイデアです。各GPUが行列の一部を保持し、その部分に対する演算を実行します。
パイプライン並列化は、異なる問題に対処します。モデルが非常に深い場合、つまり多数のレイヤーを持つ場合、これらのレイヤーを複数のGPUに分散させることができます。たとえば、96層のモデルを8つのGPUに分割する場合、各GPUは12層を担当します。最初のGPUはレイヤー1-12を、2番目のGPUはレイヤー13-24を、といった具合です。
パイプライン並列化の利点は、実装が比較的シンプルで、通信パターンが明確なことです。各ステージは、前のステージからアクティベーションを受け取り、計算を行い、次のステージにアクティベーションを渡します。通信は、ステージ間の境界でのみ発生します。これは、テンソル並列化のように、各演算内で頻繁に通信する必要がないという点で優れています。
しかし、パイプライン並列化には固有の課題があります。最も大きな問題は、パイプラインバブルです。パイプラインの最初のGPUは、データの最初のバッチを処理し終えると、そのアクティベーションを次のGPUに渡します。しかし、その間、最初のGPUはアイドル状態になります。同様に、最後のGPUも、パイプラインの前段が処理を完了するまで待機する必要があります。このアイドル時間がバブルです。
テンソル並列化は、このバブル問題を避けられますが、より頻繁な通信を必要とします。各行列演算で、GPUは部分結果を交換する必要があります。したがって、テンソル並列化は、非常に高速なインターコネクト、たとえばNVLinkを持つ同一ノード内のGPU間でのみ効率的です。ノード間では、通信レイテンシが大きすぎて、性能が大幅に低下します。
実際の大規模トレーニングでは、これらの並列化手法を組み合わせることが一般的です。これをハイブリッド並列化と呼びます。たとえば、同一ノード内の8つのGPUでテンソル並列化を使用し、複数のノード間でパイプライン並列化とデータ並列化を使用します。さらに、ZeROを加えることで、メモリ効率をさらに向上させます。
具体例を見てみましょう。1兆パラメータのモデルを512個のGPU(64ノード×8GPU)でトレーニングするとします。各ノード内の8つのGPUで、8方向のテンソル並列化を使用します。これにより、各GPUは行列の8分の1を保持します。次に、8ノードにわたって8段階のパイプライン並列化を使用します。各パイプラインステージは、モデルの8分の1のレイヤーを担当します。そして、8つの独立したパイプラインレプリカを作成し、データ並列化を実現します。さらに、ZeRO Stage 1または2を適用して、各レプリカ内のメモリを最適化します。
このような複雑な並列化戦略は、慎重な設計と調整を必要とします。各次元での並列化度、通信パターン、メモリ使用量、そして計算効率のバランスを取る必要があります。間違った選択は、性能の大幅な低下やメモリ不足につながります。しかし、適切に設計されれば、数千のGPUで効率的にスケールし、これまで不可能だった規模のモデルをトレーニングできます。
モデル並列化は、単に大きなモデルをトレーニングするための手段ではありません。それは、計算リソースをどのように組織化するかという、より広い設計空間を開きます。同じ総計算予算でも、異なる並列化戦略により、異なるトレードオフが得られます。次のセクションでは、テンソル並列化とパイプライン並列化の詳細を見ていきます。
7.2 テンソル並列化とパイプライン並列化
Tatsunori Hashimoto: テンソル並列化とパイプライン並列化は、モデルを分割する2つの根本的に異なるアプローチです。それぞれの仕組み、利点、そして課題を詳しく見ていきましょう。
テンソル並列化は、個々のレイヤー内の行列演算を複数のGPUに分割します。Transformerの文脈で具体的に説明しましょう。セルフアテンション層には、Query、Key、Valueの3つの射影があり、それぞれが大きな行列乗算です。入力の隠れ状態をXとし、次元をd、射影の重み行列をWq、Wk、Wvとします。これらの行列はすべて、d × dのサイズです。
テンソル並列化では、これらの重み行列を列方向に分割します。たとえば、4つのGPUでテンソル並列化を行う場合、Wqを4つの部分行列Wq1、Wq2、Wq3、Wq4に分割します。各部分行列はd × (d/4)のサイズです。GPU 0はWq1を、GPU 1はWq2を保持し、といった具合です。
Forward pass中、入力XはすべてのGPUに複製されます。各GPUは、自分の持つ重み行列の部分と入力Xを乗算します。GPU 0はX × Wq1を計算し、GPU 1はX × Wq2を計算します。これらの部分的な結果を連結すると、完全なクエリ行列Qが得られます。重要なのは、この連結は明示的な通信を必要としないことです。各GPUは自分の結果を保持し、後続の演算で必要に応じて使用します。
アテンション機構では、さらに工夫があります。アテンションスコアの計算とソフトマックスの適用は、各GPUで独立して行えます。各GPUは、自分の担当するヘッドのサブセットに対してアテンションを計算します。マルチヘッドアテンションの構造が、テンソル並列化と自然に相性が良いのです。たとえば、32個のアテンションヘッドがあり、4つのGPUで分割する場合、各GPUは8個のヘッドを担当します。
しかし、アテンション層の出力射影では、通信が必要になります。各GPUが計算した部分的な出力を合計して、最終的な出力を得る必要があります。ここでAll-reduce操作が使用されます。この通信は、各レイヤーごとに発生します。Transformerモデルにはレイヤーが多数あるため、頻繁な通信が必要になります。
フィードフォワードネットワークでも同様の戦略が使えます。最初の線形層(d → 4dの拡大)を列方向に分割し、各GPUがその一部を計算します。活性化関数(GELUやReLU)は各GPUで独立して適用できます。2番目の線形層(4d → dの縮小)を行方向に分割し、最後にAll-reduceで結果を合計します。
Megatron-LMは、テンソル並列化の最も成功した実装の一つです。NVIDIA の研究者が開発したこのフレームワークは、Transformerの構造を詳細に分析し、通信を最小化する巧妙な分割戦略を採用しています。Megatron-LMでは、forward passでf個の通信ポイント、backward passでb個の通信ポイントがあり、各レイヤーごとに合計2回のAll-reduce操作が必要です。
テンソル並列化の大きな利点は、パイプラインバブルがないことです。すべてのGPUが常に計算を行っており、アイドル時間がありません。また、負荷が自然に均等に分散されます。各GPUは同じ量の計算を行います。しかし、欠点は頻繁な通信です。各レイヤーでAll-reduce操作が必要なため、GPUメモリ間の高速な相互接続が不可欠です。
パイプライン並列化は、異なるアプローチを取ります。モデルのレイヤーを連続するグループに分割し、各グループを異なるGPU(またはGPUのグループ)に配置します。たとえば、96層のモデルを4つのステージに分割する場合、ステージ1はレイヤー1-24を、ステージ2はレイヤー25-48を、ステージ3はレイヤー49-72を、ステージ4はレイヤー73-96を担当します。
単純なパイプライン並列化では、データはステージを順次通過します。最初のバッチがステージ1を通過し、その出力がステージ2に送られ、という具合です。しかし、これでは大きなバブルが発生します。ステージ1が最初のバッチを処理している間、他のすべてのステージはアイドル状態です。ステージ2が処理を開始するころには、ステージ1は再びアイドルになります。
この問題を軽減するために、マイクロバッチングという技術が使われます。大きなミニバッチを複数の小さなマイクロバッチに分割し、これらを順次パイプラインに投入します。たとえば、バッチサイズ128を8つのマイクロバッチ(各16サンプル)に分割します。ステージ1が最初のマイクロバッチを処理し終えると、それをステージ2に送り、すぐに2番目のマイクロバッチの処理を開始します。
このようにすることで、複数のマイクロバッチが同時にパイプライン内の異なるステージで処理されます。ステージ1が4番目のマイクロバッチを処理している間、ステージ2は3番目を、ステージ3は2番目を、ステージ4は1番目を処理できます。これにより、バブルを大幅に削減できます。
GPipeとPipeDreamは、パイプライン並列化の2つの主要なアプローチです。GPipeは、forward passで全マイクロバッチを処理した後に、backward passですべてを処理します。これはシンプルですが、アクティベーションをすべて保存する必要があるため、メモリ使用量が大きくなります。GPipeは、gradient checkpointingと組み合わせることで、この問題を軽減します。
PipeDreamは、より複雑ですが効率的なアプローチです。1F1B(One Forward, One Backward)スケジューリングを使用します。各マイクロバッチのforward passの後、すぐにそのbackward passを実行します。これにより、アクティベーションを長期間保持する必要がなくなり、メモリ効率が向上します。また、バブルも削減されます。
パイプライン効率を測定するために、バブル時間の割合を計算できます。P個のパイプラインステージとM個のマイクロバッチがある場合、理想的な時間はM × tで、tは1マイクロバッチの処理時間です。しかし、実際には、(M + P - 1) × tかかります。追加の(P - 1) × tがバブルです。バブル率は、(P - 1) / (M + P - 1)です。マイクロバッチ数Mを増やすことで、バブル率を減らせます。
たとえば、P = 4でM = 16の場合、バブル率は3 / 19 ≈ 15.8%です。M = 32なら、3 / 35 ≈ 8.6%に改善されます。しかし、マイクロバッチ数を増やすことは、各マイクロバッチのサイズを小さくすることを意味し、GPU利用率の低下につながる可能性があります。したがって、適切なバランスを見つける必要があります。
パイプライン並列化の利点は、通信頻度が低いことです。ステージ間の境界でのみアクティベーション(forward pass)と勾配(backward pass)を転送すればよく、レイヤー内での通信は不要です。これにより、ノード間での使用が現実的になります。InfiniBandなどのネットワークでも、許容可能な性能を達成できます。
しかし、欠点もあります。負荷のバランスを取るのが難しいことがあります。すべてのステージが同じ計算時間を持つように、レイヤーを分割する必要がありますが、Transformerの異なるレイヤーは必ずしも同じ計算コストではありません。埋め込み層や出力層は、中間のTransformerレイヤーと異なる特性を持ちます。不均等な分割は、遅いステージがボトルネックとなり、全体の効率を低下させます。
実際の大規模トレーニングでは、テンソル並列化とパイプライン並列化を組み合わせることが一般的です。同一ノード内の高速接続されたGPUでテンソル並列化を使用し、ノード間でパイプライン並列化を使用します。たとえば、各ノードに8つのGPUがある場合、ノード内で8方向のテンソル並列化を行い、複数のノード間でパイプライン並列化を行います。
このハイブリッドアプローチは、両方の利点を活用します。テンソル並列化により、大きなレイヤーを単一ノード内で効率的に処理でき、パイプライン並列化により、モデル全体を複数のノードに分散できます。さらに、データ並列化を追加することで、スループットをさらに向上させます。複数の独立したモデルレプリカを作成し、各レプリカが異なるデータバッチを処理します。
このような3次元並列化(テンソル、パイプライン、データ)は、現代の超大規模モデルトレーニングの標準的なアプローチになっています。適切に設定されれば、数千のGPUで高い効率を達成できます。ただし、これらの次元での最適な分割度を見つけることは、非自明な最適化問題であり、モデルのサイズ、ハードウェアの構成、ネットワークトポロジーに依存します。
7.3 実装方法と通信オーバーヘッド
Tatsunori Hashimoto: モデル並列化の理論を理解したところで、実際の実装方法と、避けられない通信オーバーヘッドをどのように最小化するかを見ていきましょう。これらの詳細が、理論上の性能と実際の性能の差を決定します。
テンソル並列化の実装から始めましょう。Megatron-LMは、最も広く使われているテンソル並列化の実装です。その核心的なアイデアは、行列演算を注意深く分割し、通信を最小化することです。具体的な実装を見てみましょう。線形層Y = XWを考えます。Xは入力で形状は[b, s, d]、bはバッチサイズ、sはシーケンス長、dは隠れ層次元です。Wは重み行列で形状は[d, d]です。
列並列の場合、Wを列方向に分割します。N個のGPUがある場合、W = [W1, W2, ..., WN]となり、各Wiは[d, d/N]です。各GPUはX × Wiを独立して計算します。結果は[b, s, d/N]の形状で、各GPUが保持します。重要なのは、この時点では通信が不要だということです。結果が自然に分割されているため、後続の演算がこの分割形式を利用できれば、通信を避けられます。
マルチヘッドアテンションは、この性質を完璧に利用できます。各GPUは、アテンションヘッドのサブセットを担当し、独立して計算します。しかし、アテンション層の最後の出力射影では、結果を合計する必要があります。ここでAll-reduce操作が入ります。各GPUの部分的な出力を合計し、完全な出力を全GPUに配布します。
行並列も使用されます。行列Wを行方向に分割する場合、各GPUがW1、W2、...、WNを保持し、各Wiは[d/N, d]の形状です。この場合、入力Xを分割する必要があります。まず、Scatter操作でXを分割し、各GPUに配布します。各GPUはXi × Wiを計算します。結果は[b, s, d]で、各GPUが異なる部分を保持します。最後にAll-gatherで完全な結果を再構成するか、All-reduceで合計します。
Megatron-LMの巧妙な点は、列並列と行並列を交互に配置することで、一部の通信を省略できることです。たとえば、アテンション層のQKV射影は列並列で、出力射影は行並列です。列並列の出力がすでに分割されているため、それをそのまま行並列の入力として使えます。中間でのAll-gatherが不要になります。同様に、フィードフォワードネットワークでも、最初の線形層を列並列に、2番目を行並列にすることで、通信を削減できます。
この最適化により、各Transformerレイヤーで必要なAll-reduce操作は2回だけになります。アテンション層の出力で1回、フィードフォワードネットワークの出力で1回です。96層のモデルでは、forward passで192回、backward passで192回、合計384回のAll-reduce操作が必要です。
通信コストを定量化してみましょう。各All-reduce操作で転送されるデータ量は、バッチサイズb、シーケンス長s、隠れ層次元dに依存します。1つのレイヤーの出力は[b, s, d]の形状で、FP16の場合、2bsd バイトです。Ring All-reduceアルゴリズムでは、各GPUは約2bsd × (N-1)/Nバイトを送受信します。Nが大きい場合、これは約2bsdバイトに近づきます。
具体例で計算しましょう。b = 32、s = 2048、d = 12288、N = 8、FP16を仮定します。1回のAll-reduceで転送されるデータは、約2 × 32 × 2048 × 12288 × 2 バイト ≈ 3.2 GBです。96層のモデルでは、forward passだけで192 × 3.2 GB = 614 GBの通信が必要です。NVLinkの帯域幅が600 GB/sの場合、通信時間は約1秒です。
しかし、実際には通信と計算をオーバーラップできます。Backward pass中、各レイヤーの勾配が計算されると、すぐにAll-reduce通信を開始できます。次のレイヤーの勾配計算と並行して通信が進行します。理想的には、計算時間が通信時間より長ければ、通信コストを完全に隠蔽できます。実際には、50%から80%程度の隠蔽が達成されることが多いです。
パイプライン並列化の実装は、異なる課題を提示します。GPipeの実装では、まずモデルをP個のステージに分割します。各ステージは、連続するレイヤーのグループです。PyTorchでは、torch.nn.Sequentialを使って各ステージを定義できます。ミニバッチをM個のマイクロバッチに分割し、これらを順次パイプラインに送り込みます。
GPipeのスケジューリングは比較的シンプルです。すべてのマイクロバッチのforward passを完了してから、すべてのbackward passを実行します。各マイクロバッチのアクティベーションを保存する必要がありますが、gradient checkpointingを使用してメモリを節約できます。実装は、各ステージでforward関数とbackward関数を定義し、適切なタイミングでステージ間でテンソルを送受信するだけです。
PipeDreamの1F1Bスケジューリングは、より複雑です。ウォームアップフェーズ、定常状態フェーズ、クールダウンフェーズの3つのフェーズがあります。ウォームアップフェーズでは、最初のいくつかのマイクロバッチのforward passを実行します。定常状態フェーズでは、各forward passの後にすぐにbackward passを実行します。クールダウンフェーズでは、残りのbackward passを完了します。
この複雑なスケジューリングを実装するには、各ステージが自分の状態を追跡し、適切なタイミングで動作を切り替える必要があります。DeepSpeedのパイプライン実装は、このロジックを自動化しており、ユーザーは単にモデルを分割するだけで済みます。DeepSpeedは、最適なマイクロバッチ数とスケジューリング戦略を自動的に決定します。
パイプライン並列化の通信コストは、テンソル並列化とは異なります。ステージ間でアクティベーションと勾配のみを転送すればよく、レイヤー内での通信は不要です。各マイクロバッチについて、forward pass中に1回のアクティベーション転送、backward pass中に1回の勾配転送が必要です。
転送されるデータ量は、ステージ間の境界でのテンソルサイズに依存します。通常、これは[micro_batch_size, sequence_length, hidden_dimension]です。マイクロバッチサイズが16、シーケンス長が2048、隠れ層次元が12288、FP16の場合、1回の転送は約16 × 2048 × 12288 × 2 バイト ≈ 800 MBです。M = 32のマイクロバッチでは、forward とbackwardで合計32 × 2 × 800 MB = 51.2 GBの通信が必要です。
これは、テンソル並列化の614 GB(forward passのみ)と比較してはるかに少ないです。ただし、パイプライン並列化には、バブルというコストがあります。前述のように、バブル率は(P-1)/(M+P-1)です。P = 4、M = 32の場合、約8.6%のバブルがあります。これは、計算リソースの8.6%が無駄になることを意味します。
ハイブリッド並列化では、これらの通信パターンが組み合わされます。各パイプラインステージ内でテンソル並列化を使用する場合、テンソル並列化の通信は同一ノード内で高速に実行され、パイプライン並列化の通信はノード間で実行されます。通信トポロジーを考慮した最適化が重要になります。
NCCL(NVIDIA Collective Communications Library)は、これらの集団通信操作を効率的に実装するための重要なライブラリです。NCCLは、ハードウェアトポロジーを自動的に検出し、最適な通信アルゴリズムとルーティングを選択します。NVLink、PCIe、InfiniBandなど、利用可能な相互接続を最大限活用します。
実装上のベストプラクティスとして、通信とコンピュートのオーバーラップを最大化することが重要です。PyTorchでは、非同期通信操作を使用できます。たとえば、dist.all_reduce(tensor, async_op=True)は、通信をバックグラウンドで開始し、すぐに制御を返します。計算を続けながら通信を進行させ、結果が必要になる前にwait()を呼び出して完了を待ちます。
もう一つの重要な最適化は、通信のバッチ化です。複数の小さなテンソルを個別に送信するのではなく、一つの大きなテンソルに連結してから送信します。これにより、通信の起動オーバーヘッドを削減できます。Megatron-LMは、このような最適化を自動的に適用します。
最後に、デバッグとプロファイリングが重要です。モデル並列化のバグは、特定しにくいことがあります。異なるGPUで異なる計算が行われ、複雑な通信パターンがあるためです。NVIDIA Nsight SystemsやPyTorch Profilerなどのツールを使用して、通信と計算の時間を測定し、ボトルネックを特定することが推奨されます。通信時間が計算時間を支配している場合、並列化戦略を再考する必要があるかもしれません。
8. Flash Attention
8.1 標準的なAttentionの計算とメモリボトルネック
Tatsunori Hashimoto: Flash Attentionについて議論する前に、まず標準的なアテンション機構がどのように計算され、なぜメモリがボトルネックになるのかを深く理解する必要があります。アテンションは、Transformerの核心的なコンポーネントであり、その効率がモデル全体の性能に大きく影響します。
標準的なセルフアテンションの計算プロセスを段階的に見ていきましょう。入力として、シーケンス長Nと隠れ層次元dを持つテンソルXがあります。形状は[N, d]です。まず、Query、Key、Valueの3つの射影を計算します。Q = X × Wq、K = X × Wk、V = X × Wvです。各重み行列は[d, d]の形状で、結果のQ、K、Vはすべて[N, d]です。
次に、アテンションスコアを計算します。これは、QとKの転置の行列乗算です。S = Q × K^T。この演算により、[N, N]の形状のスコア行列が生成されます。ここで重要な点は、シーケンス長Nに対してN²の要素を持つ行列が作られることです。この行列は、各クエリ位置が各キー位置とどれだけ関連しているかを表します。
実際には、スコアを√dでスケーリングします。S' = S / √d。これは数値安定性のためです。次に、最も重要なステップ、ソフトマックス関数の適用があります。P = softmax(S')です。ソフトマックスは行ごとに適用され、各行が確率分布(合計が1)になります。Pも[N, N]の形状です。
最後に、このアテンション確率行列をValue行列と乗算します。O = P × V。これにより、[N, d]の形状の出力が得られます。これがアテンションの出力です。マルチヘッドアテンションでは、このプロセスを複数のヘッドで並行して実行し、結果を連結します。
さて、メモリ使用量を分析してみましょう。まず、Q、K、Vはそれぞれ[N, d]で、FP16の場合、各々が2Ndバイトです。3つ合わせて6Ndバイトです。次に、スコア行列Sは[N, N]で、2N²バイトです。アテンション確率行列Pも[N, N]で、2N²バイトです。出力Oは[N, d]で、2Ndバイトです。
総メモリ使用量は、約6Nd + 4N² + 2Nd = 8Nd + 4N²バイトです。重要なのは、N²の項です。シーケンス長が増加すると、この項が支配的になります。たとえば、N = 2048、d = 768の場合を計算してみましょう。8 × 2048 × 768 + 4 × 2048² ≈ 12.6 MB + 16.8 MB = 29.4 MBです。ここではすでにN²の項が大きな割合を占めています。
N = 4096に増やすと、8 × 4096 × 768 + 4 × 4096² ≈ 25.2 MB + 67.1 MB = 92.3 MBです。シーケンス長を2倍にすると、N²の項は4倍になります。N = 8192では、8 × 8192 × 768 + 4 × 8192² ≈ 50.3 MB + 268.4 MB = 318.7 MBです。シーケンス長が長くなるにつれて、メモリ使用量が急速に増加することがわかります。
バッチサイズを考慮すると、状況はさらに悪化します。バッチサイズBの場合、メモリ使用量はB倍になります。B = 32、N = 4096では、約32 × 92.3 MB ≈ 2.95 GBです。これは単一のアテンション層だけの話です。Transformerモデルには通常、数十のこのような層があります。
しかし、メモリの絶対量だけが問題ではありません。さらに深刻な問題は、メモリアクセスパターンです。現代のGPUでは、計算速度は非常に高速ですが、メモリアクセスは相対的に遅いです。GPUのメモリ階層を理解することが重要です。
GPUには複数のメモリレベルがあります。最も高速なのはレジスタで、次にL1キャッシュ、共有メモリ(SRAM)、そして最も大きいが遅い高帯域メモリ(HBM)があります。HBMは、A100では80GBの容量がありますが、帯域幅は約1.5-2 TB/sです。一方、オンチップのSRAMは、容量は20MB程度ですが、帯域幅は約19 TB/s、つまり約10倍高速です。
標準的なアテンション実装では、多くの中間結果をHBMに書き込み、後で読み戻します。具体的には、Q × K^Tの結果(スコア行列S)をHBMに書き込みます。次に、ソフトマックスを計算するためにSを読み戻し、結果のPをHBMに書き込みます。そして、P × Vを計算するためにPを再び読み戻します。
このHBMアクセスのコストを計算してみましょう。スコア行列SとアテンションPは、それぞれN²の要素を持ち、FP16で2N²バイトです。これらを書き込んで読み戻すと、合計で約8N²バイトのHBMアクセスが発生します。N = 4096の場合、約134 MBです。A100のHBM帯域幅が1.5 TB/sとすると、これには約0.09ミリ秒かかります。
一方、実際の計算量(FLOPs)はどうでしょうか。Q × K^Tは約2N²d FLOPs、P × Vも約2N²d FLOPsです。ソフトマックスやその他の演算を含めると、合計で約4N²d FLOPsです。N = 4096、d = 768の場合、約5.2×10¹⁰ FLOPsです。A100のFP16ピーク性能が312 TFLOPSの場合、これには約0.17ミリ秒かかります。
ここで重要な観察があります。メモリアクセス時間(0.09ミリ秒)と計算時間(0.17ミリ秒)が同じオーダーです。つまり、この演算はメモリバウンド(memory-bound)です。計算ユニットは、データが到着するのを待って時間を浪費しています。シーケンス長がさらに長くなると、N²の項がさらに支配的になり、メモリアクセスがより深刻なボトルネックになります。
さらに悪いことに、標準的な実装では、これらの演算が別々のカーネルとして実行されます。行列乗算、ソフトマックス、そして再び行列乗算という具合です。各カーネルの境界で、結果をHBMに書き戻し、次のカーネルで読み直す必要があります。これらの追加のメモリアクセスが、さらにオーバーヘッドを増やします。
IO複雑性を分析すると、標準的なアテンションのHBMアクセス量は、Θ(Nd + N²)です。Nが大きい場合、N²の項が支配的です。N = 16384のような長いシーケンスでは、このメモリアクセスが計算全体のボトルネックになります。実際のスループットは、ハードウェアのピーク性能の10%から20%程度まで低下することがあります。
この問題は、長いシーケンスを扱う必要があるタスク、たとえば長文書の処理、高解像度画像のビジョントランスフォーマー、長いコンテキストを持つ対話システムなどで特に深刻です。メモリ使用量のO(N²)の複雑性は、スケーラビリティの根本的な制約となっています。
バックワードパスでは、状況はさらに複雑です。勾配を計算するために、フォワードパス中の中間結果を保存しておく必要があります。または、gradient checkpointingを使用して再計算します。いずれにしても、追加のメモリアクセスまたは計算が必要です。
この根本的な問題に対処するために、Flash Attentionが開発されました。Flash Attentionの核心的なアイデアは、メモリ階層を意識した計算を行うことです。HBMアクセスを最小化し、できるだけ多くの計算を高速なオンチップメモリで実行します。次のセクションで、Flash Attentionがどのようにこれを実現するかを見ていきます。
8.2 Flash Attentionのアルゴリズム
Tatsunori Hashimoto: Flash Attentionは、標準的なアテンションと数学的には全く同じ結果を計算しますが、そのやり方が根本的に異なります。その核心的なアイデアは、タイリング(tiling)と呼ばれる技術を使って、計算を小さなブロックに分割し、これらのブロックを高速なオンチップメモリ(SRAM)で処理することです。
まず、Flash Attentionの基本戦略を理解しましょう。標準的な実装では、完全なN×Nのアテンションスコア行列を一度に計算し、HBMに保存します。Flash Attentionは、この大きな行列を小さなブロックに分割し、各ブロックを独立して処理します。重要なのは、これらのブロックはSRAMに収まるサイズであり、HBMアクセスを最小化できることです。
具体的なアルゴリズムを見ていきましょう。入力として、Query行列Q、Key行列K、Value行列Vがあり、すべて[N, d]の形状です。Flash Attentionは、これらをブロックサイズBのチャンクに分割します。たとえば、N = 4096でB = 128の場合、32個のブロックに分割されます。
アルゴリズムは、外側のループと内側のループの2つのネストしたループで構成されます。外側のループは、Queryのブロックを反復します。各イテレーションで、Qのブロック、たとえばQ[i]をSRAMにロードします。ここでQ[i]は[B, d]の形状です。同時に、出力ブロックO[i]、正規化統計量ℓ[i]とm[i]を初期化します。
内側のループは、KeyとValueのブロックを反復します。各イテレーションで、K[j]とV[j]をSRAMにロードします。これらも[B, d]の形状です。次に、ブロック間のアテンションスコアを計算します。S[i,j] = Q[i] × K[j]^Tです。この結果は[B, B]の形状で、元の大きな行列の小さなブロックに相当します。
ここで巧妙な部分が来ます。ソフトマックスの計算です。標準的な実装では、すべてのスコアを見てからソフトマックスを適用しますが、Flash Attentionは、オンラインソフトマックスと呼ばれる技術を使用します。これは、ソフトマックスを段階的に計算し、新しいブロックの情報を見るたびに、以前の結果を更新します。
オンラインソフトマックスのトリックを理解するために、ソフトマックスの数学を見てみましょう。ソフトマックスは、softmax(x[i]) = exp(x[i]) / Σexp(x[j])です。数値安定性のため、通常は最大値を引きます。softmax(x[i]) = exp(x[i] - m) / Σexp(x[j] - m)、ここでm = max(x)です。
Flash Attentionは、これを増分的に計算します。最初のブロックS[i,1]を処理するとき、そのブロック内の最大値m_oldとexp値の合計ℓ_oldを計算します。次のブロックS[i,2]を処理するとき、新しい最大値m_newを計算し、以前の統計量を更新します。m_new = max(m_old, max(S[i,2]))です。
以前の合計を新しい最大値に合わせて再スケーリングする必要があります。ℓnew = ℓold × exp(m_old - m_new) + Σexp(S[i,2] - m_new)です。このようにして、すべてのブロックを見終わった後、正しいソフトマックスの結果が得られます。出力も同様に増分的に更新します。
具体的な擬似コードで表現すると、外側のループで各Qブロックについて、内側のループで各K、Vブロックについて、次の操作を実行します。S_block = Q[i] @ K[j].Tを計算します。m_new = max(m_old, max(S_block))を計算します。ℓ_newとO[i]を更新式に従って更新します。重要なのは、大きなN×N行列を一度もHBMに保存しないことです。
この段階的な計算により、メモリ使用量が劇的に削減されます。任意の時点で、SRAMに保持する必要があるのは、現在のQブロック[B, d]、現在のK、Vブロック[B, d]、ブロックスコア[B, B]、および累積統計量だけです。ブロックサイズBは、SRAMの容量に合わせて選択されます。通常、B = 64から256程度です。
IO複雑性を分析してみましょう。HBMからSRAMへのアクセス量を計算します。外側のループでN/B回のイテレーションがあり、各イテレーションでQブロック[B, d]をロードします。合計でNd要素です。内側のループでもN/B回のイテレーションがあり、各イテレーションでKとVブロック[B, d]をロードします。外側のループの各イテレーションで、これをN/B回繰り返すため、合計で(N/B) × (N/B) × 2Bd = 2N²d/B要素です。
総HBMアクセスは、Nd + 2N²d/Bです。標準的なアテンションのΘ(Nd + N²)と比較して、N²の項が1/Bのファクターで削減されています。B = 128の場合、メモリアクセスは約128分の1になります。これは理論的な改善ですが、実際にも大きな効果があります。
さらに重要なのは、カーネル融合です。Flash Attentionは、行列乗算、ソフトマックス、そして再び行列乗算という一連の演算をすべて単一のCUDAカーネルに融合します。中間結果をHBMに書き戻すことなく、すべてSRAMで処理します。これにより、カーネル起動のオーバーヘッドも削減され、メモリアクセスがさらに最適化されます。
バックワードパスも巧妙に実装されています。標準的な実装では、forward pass中のアテンション行列Pを保存しておく必要がありますが、これはO(N²)のメモリを消費します。Flash Attentionは、リコンピュテーション(再計算)を使用します。Backward pass中に、必要なブロックを再計算します。追加の計算コストはありますが、メモリを大幅に節約できます。
さらに、Flash Attentionは、ソフトマックスのLogsumexp統計量(m、ℓ)のみを保存します。これはO(N)のメモリで、元のO(N²)と比較して大幅に少ないです。Backward pass中、これらの統計量を使って、アテンション行列を効率的に再計算できます。
実装の観点から、Flash Attentionは低レベルのCUDAカーネルとして書かれています。PyTorchやTensorFlowの高レベルAPIでは、このような細かいメモリ制御は困難です。Flash Attentionのカーネルは、GPUのメモリ階層を明示的に管理し、データのプリフェッチ、パイプライニング、バンク衝突の回避など、多くの低レベル最適化を含んでいます。
ブロックサイズBの選択は重要なチューニングパラメータです。Bが大きいほど、HBMアクセスが少なくなりますが、SRAMの容量に収まる必要があります。また、Bが大きいほど、各ブロックの計算量が増え、並列性が向上します。実際には、ハードウェアとシーケンス長に応じて、最適なBを選択する必要があります。
因果的アテンション(causal attention)、つまり自己回帰モデルで使用されるマスクされたアテンションも、Flash Attentionで効率的に処理できます。因果マスクは、各クエリが自分より後のキーを見ないようにします。Flash Attentionは、マスクされたブロックをスキップすることで、不要な計算を避けます。これにより、因果アテンションでは、計算量が約半分になります。
Flash Attentionのアルゴリズムは、アテンションの本質的な並列性も活用します。異なるQブロックは独立して処理できるため、複数のブロックを並行してSRAM上で計算できます。これにより、GPUの多数のストリーミングマルチプロセッサ(SM)を効率的に活用できます。
理論的には、Flash AttentionのIO複雑性は最適またはほぼ最適です。アテンションを計算するために、少なくとも入力と出力をHBMから読み書きする必要があり、これはΩ(Nd)です。Flash AttentionのO(Nd + N²d/B)は、Bが十分大きければ、この下限に近づきます。実際には、Bはハードウェアの制約により制限されますが、それでも大幅な改善です。
Flash Attentionは、単なる高速化以上の意味を持ちます。それは、より長いシーケンスを扱うことを可能にします。メモリ使用量がO(N²)からO(N)に削減されることで、以前は不可能だった長さのシーケンスをトレーニングできるようになります。これは、長文書処理、高解像度画像、長いコンテキストを持つ対話システムなど、多くのアプリケーションに影響を与えます。
8.3 タイリングとカーネル融合による高速化
Tatsunori Hashimoto: Flash Attentionの性能向上は、タイリングとカーネル融合という2つの基本的な最適化技術によって実現されています。これらの技術がどのように機能し、なぜそれほど効果的なのかを詳しく見ていきましょう。
タイリングは、大きな計算問題を小さなタイルまたはブロックに分割し、各タイルを高速なオンチップメモリで処理する技術です。この概念は、GPU計算において広く使われていますが、Flash Attentionはアテンション機構に特化した巧妙なタイリング戦略を採用しています。
タイリングの利点を理解するために、メモリアクセスのコストを考えましょう。NVIDIA A100では、HBMの帯域幅は約1.5-2 TB/sですが、SRAMの帯域幅は約19 TB/sです。つまり、SRAMはHBMよりも約10倍高速です。さらに、SRAMへのアクセスのエネルギーコストは、HBMの約100分の1です。したがって、できるだけ多くのデータをSRAMに保持し、そこで計算を行うことが理想的です。
Flash Attentionのタイリング戦略は、2次元のブロック構造を使用します。Query行列を行方向にブロックに分割し、KeyとValue行列も同様に分割します。各ブロックのサイズは、SRAMの容量に収まるように選択されます。A100のSRAMは約20MBなので、典型的なブロックサイズは64×64から256×256の範囲です。
具体的な数値で見てみましょう。N = 4096、d = 128、ブロックサイズB = 128の場合を考えます。Qは32個のブロックに分割されます(4096 / 128 = 32)。各Qブロックは[128, 128]で、FP16の場合32KBです。同様に、KとVブロックもそれぞれ32KBです。ブロック間のスコア行列S_blockは[128, 128]で32KBです。これらを合計すると、約128KB程度で、20MBのSRAMに余裕で収まります。
タイリングのもう一つの重要な側面は、データの再利用です。各Qブロックについて、すべてのK、Vブロックを反復処理します。Qブロックは一度だけロードされ、内側のループ全体で再利用されます。これにより、データ転送のオーバーヘッドが削減されます。同様に、各K、Vブロックは、すべてのQブロックで再利用されるため、HBMからの読み込み回数が最小化されます。
タイリングの効果を定量化してみましょう。標準的なアテンションでは、N×Nのスコア行列全体をHBMに書き込み、読み戻します。N = 4096の場合、これは約32MBです。Flash Attentionでは、ブロックスコアのみをSRAMで処理し、HBMには書き込みません。さらに、Q、K、Vの読み込み回数も最適化されます。総HBMアクセスは、理論的には約10分の1から20分の1に削減されます。
カーネル融合は、もう一つの重要な最適化です。標準的なPyTorchやTensorFlowの実装では、アテンションの計算は複数の別々のカーネルで実行されます。行列乗算のカーネル、ソフトマックスのカーネル、そして再び行列乗算のカーネルです。各カーネルの間で、中間結果がHBMに書き戻され、次のカーネルで読み直されます。
カーネル起動のオーバーヘッドも無視できません。GPUでカーネルを起動するには、数マイクロ秒かかります。小さな演算では、この起動時間が実際の計算時間よりも長くなることがあります。さらに、各カーネルは独立してスケジュールされるため、GPUリソースの利用効率が低下する可能性があります。
Flash Attentionは、これらすべての演算を単一の融合カーネルに統合します。Q × K^T、ソフトマックス、P × Vという一連の操作がすべて、同じカーネル内で、SRAMのデータを使って実行されます。中間結果は一切HBMに書き込まれません。これにより、メモリアクセスとカーネル起動のオーバーヘッドが大幅に削減されます。
融合カーネルの実装は、低レベルのCUDAプログラミングを必要とします。Flash Attentionのカーネルは、数千行のCUDAコードで構成されており、多くの低レベル最適化が含まれています。たとえば、共有メモリのバンク衝突を避けるためのパディング、ワープレベルのプリミティブを使った効率的な通信、そして命令レベルの並列性を最大化するための慎重なスケジューリングなどです。
ワープ(warp)レベルの最適化も重要です。GPUでは、32個のスレッドが一つのワープとしてグループ化され、同期的に実行されます。Flash Attentionは、ワープレベルのreduce操作を使って、ソフトマックスの最大値と合計を効率的に計算します。これらの操作は、レジスタとワープシャッフル命令を使って、共有メモリへのアクセスなしで実行されます。
メモリコアレッシング(memory coalescing)も最適化されています。GPUでは、隣接するスレッドが連続するメモリアドレスにアクセスすると、これらのアクセスが単一のメモリトランザクションに統合され、効率が向上します。Flash Attentionは、データレイアウトとアクセスパターンを慎重に設計し、コアレッシングを最大化します。
パイプライニングも活用されています。データのロード、計算、ストアを重複させることで、メモリレイテンシを隠蔽します。たとえば、現在のブロックを計算しながら、次のブロックのデータをバックグラウンドでプリフェッチします。これにより、メモリ帯域幅を最大限活用できます。
ダブルバッファリング(double buffering)も使用されます。2つのバッファをSRAMに確保し、一方のバッファで計算を実行しながら、もう一方のバッファにデータをロードします。計算が完了したら、バッファを切り替えます。これにより、計算とメモリアクセスを完全に重複させることができます。
数値安定性も慎重に管理されています。ソフトマックスの計算では、指数関数がオーバーフローする可能性があります。Flash Attentionは、オンラインソフトマックスアルゴリズムで、最大値を動的に追跡し、すべての値をその最大値に対して正規化します。これにより、数値安定性を保ちながら、効率的な計算が可能になります。
異なるヘッドの並列化も考慮されています。マルチヘッドアテンションでは、各ヘッドは独立して計算できます。Flash Attentionは、複数のヘッドを異なるスレッドブロックに割り当て、並行して処理します。これにより、GPUの多数のストリーミングマルチプロセッサ(SM)を効率的に活用できます。
シーケンス長が非常に長い場合、追加の最適化が適用されます。たとえば、スパースアテンション(sparse attention)のパターンをサポートする場合、マスクされたブロックをスキップすることで、不要な計算を避けます。Flash Attention 2では、このようなスパースパターンのサポートが改善されています。
異なるハードウェアに対する最適化も重要です。A100とH100では、SRAMの容量や帯域幅、Tensor Coreの性能が異なります。Flash Attentionの実装は、実行時にハードウェアを検出し、最適なブロックサイズとスケジューリング戦略を選択します。これにより、異なるGPUで最適な性能を達成できます。
autotuning(自動チューニング)も一部の実装で使用されます。様々なブロックサイズ、スレッド構成、メモリレイアウトを試し、最も高速な構成を選択します。これは、特定のモデルサイズやハードウェアに対して最適化する際に有用です。
実際の性能改善は印象的です。標準的なPyTorchの実装と比較して、Flash Attentionは2倍から4倍高速です。特に長いシーケンスでは、改善がさらに顕著になります。N = 2048では約2倍、N = 4096では約3倍、N = 8192では約4倍の高速化が観測されることがあります。これは、メモリボトルネックがより深刻になるほど、Flash Attentionの利点が大きくなるためです。
エネルギー効率の改善も重要です。HBMアクセスは、計算よりもはるかに多くのエネルギーを消費します。Flash Attentionは、HBMアクセスを削減することで、計算あたりのエネルギー消費を大幅に削減します。これは、大規模トレーニングの環境コストを削減する上で重要です。
タイリングとカーネル融合の組み合わせは、他の演算にも適用できる一般的な原理です。実際、同様のアイデアが、フィードフォワードネットワーク、レイヤーノーマライゼーション、その他のTransformerコンポーネントの最適化にも使われています。Flash Attentionは、メモリ階層を意識した計算の重要性を示す優れた例であり、この分野の今後の研究にも影響を与えています。
8.4 実験結果:速度とメモリ効率の改善
Tatsunori Hashimoto: Flash Attentionの理論的な利点を理解したところで、実際のベンチマーク結果と実験データを見ていきましょう。これらの数値は、Flash Attentionが実世界のアプリケーションでどれほどの改善をもたらすかを明確に示しています。
まず、forward passの速度から見ていきましょう。元のFlash Attention論文では、NVIDIA A100 GPUで包括的なベンチマークが行われました。様々なシーケンス長とバッチサイズで、標準的なPyTorch実装と比較されました。結果は一貫してFlash Attentionの優位性を示しています。
シーケンス長N = 512の場合、Flash Attentionは標準実装と比較して約1.5倍から2倍高速でした。これは比較的短いシーケンスで、メモリボトルネックがまだそれほど深刻ではないケースです。それでも、カーネル融合とメモリアクセスの削減により、明確な改善が見られます。
N = 1024では、高速化は約2倍から2.5倍になります。N = 2048では約2.5倍から3倍です。そしてN = 4096では、約3倍から4倍の高速化が観測されました。この傾向は明確です。シーケンス長が長くなるほど、Flash Attentionの利点が大きくなります。これは、長いシーケンスでは、N²のメモリアクセスが支配的になり、Flash Attentionのメモリ最適化がより効果的になるためです。
さらに長いシーケンスでは、改善はさらに劇的です。N = 8192では、Flash Attentionは約4倍から5倍高速でした。N = 16384では、標準実装がメモリ不足で実行できない場合もある一方で、Flash Attentionは問題なく実行でき、実行可能な場合でも5倍以上の高速化を示しました。
具体的な数値で見てみましょう。バッチサイズ8、ヘッド数12、ヘッドあたりの次元64(総隠れ層次元768)のBERT-baseサイズのモデルで測定された結果です。N = 512では、標準実装が約0.8ミリ秒、Flash Attentionが約0.5ミリ秒でした。N = 2048では、標準実装が約12ミリ秒、Flash Attentionが約4ミリ秒でした。N = 4096では、標準実装が約48ミリ秒、Flash Attentionが約14ミリ秒でした。
メモリ使用量の改善も顕著です。標準的なアテンションでは、メモリ使用量はO(N²)です。Flash Attentionでは、アテンション行列全体を保存しないため、メモリ使用量はO(N)に削減されます。具体的には、ソフトマックスの統計量(最大値と合計)のみをO(N)のメモリで保存します。
バッチサイズ32、N = 2048、d = 768の設定で測定すると、標準実装は約2.5GBのGPUメモリを消費しました。一方、Flash Attentionは約0.8GBでした。これは約3倍の削減です。N = 4096では、標準実装が約9GBを消費するのに対し、Flash Attentionは約1.5GBでした。約6倍の削減です。
この大幅なメモリ削減により、より大きなバッチサイズやより長いシーケンスを扱えるようになります。実験では、同じGPUメモリ容量で、Flash Attentionを使用することで、バッチサイズを2倍から4倍に増やせることが示されました。あるいは、同じバッチサイズで、シーケンス長を約2倍に延ばせます。
トレーニング速度への影響も測定されました。GPT-2スタイルのモデル(12層、12ヘッド、768次元)をWikiTextデータセットでトレーニングする実験では、Flash Attentionを使用することで、全体のトレーニング速度が約15%から20%向上しました。アテンションはTransformerの一部に過ぎないため、全体の速度向上は個々のアテンション層の高速化よりも控えめですが、それでも有意な改善です。
より大きなモデルでは、改善がさらに大きくなります。GPT-3サイズのモデル(96層、96ヘッド、12288次元)では、アテンションがモデル全体の計算時間のより大きな割合を占めます。このようなモデルでは、Flash Attentionにより、トレーニング速度が約25%から30%向上することが報告されています。
長文書処理タスクでは、効果がさらに顕著です。シーケンス長8192や16384を使用する文書分類タスクでは、標準実装では単一のA100にバッチサイズ1または2しか収まりませんでした。Flash Attentionを使用すると、バッチサイズ8から16を使用でき、全体のスループットが約5倍から10倍向上しました。
推論時の改善も重要です。自己回帰生成では、各ステップで新しいトークンを生成し、KVキャッシュを更新します。Flash Attentionは、このKVキャッシュの更新とアテンション計算を効率化します。長いコンテキスト(たとえば4096トークン)で生成する場合、Flash Attentionは生成速度を約2倍から3倍向上させました。
エンドツーエンドのモデルトレーニング実験も行われました。LongformerやBigBirdのようなアーキテクチャは、長いシーケンスを効率的に処理するために設計されていますが、スパースアテンションパターンを使用します。Flash Attentionを使用した完全なアテンションは、これらのスパースアテンションモデルと同等またはより高速に実行されながら、より良いモデル品質を達成することが示されました。
ビジョントランスフォーマー(ViT)でも実験が行われました。高解像度画像(たとえば512×512ピクセル)をパッチサイズ16×16で処理する場合、シーケンス長は1024になります。Flash Attentionにより、ViTのトレーニングが約30%高速化され、より大きなバッチサイズを使用できるようになりました。
異なるハードウェアでの性能も評価されました。NVIDIA V100では、A100と同様の相対的な高速化が観測されましたが、絶対的な速度は遅くなります。A100では、Tensor Coreのサポートがより良く、メモリ帯域幅も高いため、Flash Attentionの利点がより顕著です。H100では、さらなる改善が期待されます。
TPU(Tensor Processing Unit)でも、同様の原理を適用した実装が試されました。TPUのメモリ階層は異なりますが、タイリングとメモリ最適化の基本原理は同様に有効でした。TPU v4では、標準実装と比較して約2倍から3倍の高速化が報告されています。
数値精度の影響も調べられました。FP16、BF16、FP32で実験が行われ、すべての精度でFlash Attentionは高速化を示しました。ただし、低精度(FP16、BF16)では、計算速度が向上するため、相対的な改善がやや小さくなります。それでも、FP16で約2倍から4倍の高速化が一貫して観測されました。
エネルギー効率も測定されました。Flash Attentionは、同じ計算を実行するのに必要な電力が約30%から50%少なくなりました。これは、HBMアクセスがSRAMアクセスよりもはるかに多くのエネルギーを消費するためです。大規模トレーニングでは、このエネルギー削減は、運用コストと環境影響の両面で重要です。
モデルの品質への影響も評価されました。Flash Attentionは、標準的なアテンションと数学的に同等の結果を計算するため、理論的にはモデルの品質に影響しないはずです。実験でも、同じハイパーパラメータで同じ収束曲線が得られることが確認されました。わずかな数値誤差はありますが、最終的なモデル性能には統計的に有意な差はありませんでした。
実世界のアプリケーションでの採用も進んでいます。Hugging Face Transformersライブラリは、Flash Attentionをオプションとしてサポートしています。PyTorchも、torch.nn.functional.scaled_dot_product_attentionという関数で、内部的にFlash Attentionを使用できます。これにより、既存のコードを最小限の変更で高速化できます。
ユーザーからの報告も肯定的です。多くの研究者や実務家が、Flash Attentionにより、以前は実現不可能だった長いシーケンスでの実験が可能になったと報告しています。たとえば、16k トークンのコンテキストでの言語モデルトレーニングや、高解像度画像でのビジョントランスフォーマーのトレーニングなどです。
これらの実験結果は、Flash Attentionが単なる理論的な改善ではなく、実際のアプリケーションで大きな影響を与えることを明確に示しています。速度、メモリ効率、エネルギー効率のすべてで大幅な改善が得られ、新しい研究の可能性を開いています。
8.5 Flash Attention 2以降の発展
Tatsunori Hashimoto: 元のFlash Attentionの成功を受けて、研究コミュニティはさらなる改善を追求してきました。Flash Attention 2、そしてその後のFlash Attention 3は、元のアイデアを洗練し、さらなる性能向上を実現しています。これらの発展を見ていきましょう。
Flash Attention 2は、2023年に発表され、元のFlash Attentionからいくつかの重要な改善を加えました。最も大きな改善は、並列化戦略の再設計です。元のFlash Attentionは、主にQuery側(外側のループ)での並列化に焦点を当てていました。各Qブロックが異なるスレッドブロックで処理されます。Flash Attention 2は、さらにKey/Value側(内側のループ)でも並列化を追加しました。
この追加の並列化により、GPUの利用率が大幅に向上しました。特に、短いシーケンスや小さなバッチサイズでは、元のFlash Attentionでは十分なスレッドブロックを起動できず、一部のストリーミングマルチプロセッサ(SM)がアイドル状態になることがありました。Flash Attention 2は、より細かい粒度の並列化により、より多くのSMを活用できます。
具体的な性能改善として、Flash Attention 2は元のFlash Attentionと比較して、さらに約1.5倍から2倍高速になりました。つまり、標準的なPyTorch実装と比較すると、合計で約3倍から8倍の高速化を達成しています。特にH100 GPUでは、改善がより顕著で、一部のワークロードで10倍以上の高速化が報告されています。
Flash Attention 2のもう一つの重要な改善は、マルチクエリアテンション(MQA)とグループドクエリアテンション(GQA)のサポートです。これらは、推論効率を向上させるために設計された変種で、複数のクエリヘッドが単一のキー/バリューヘッドを共有します。元のFlash Attentionは、標準的なマルチヘッドアテンションに最適化されていましたが、Flash Attention 2はこれらの変種も効率的に処理できます。
MQAとGQAでは、KVキャッシュのサイズが大幅に削減されます。たとえば、32個のクエリヘッドと4個のKVヘッドを持つGQA(グループサイズ8)では、KVキャッシュは標準的なMHAの8分の1になります。Flash Attention 2は、この構造を活用し、メモリアクセスパターンを最適化します。推論時のスループットが約2倍から3倍向上しました。
Flash Attention 2は、アルゴリズムレベルでも改善を加えました。ソフトマックスの計算がより効率的に再構成され、数値安定性が向上しています。また、backward passの実装が最適化され、forward passだけでなく、トレーニング全体でより大きな高速化が得られるようになりました。
バリアブルレングス(可変長)シーケンスの処理も改善されました。実際のバッチでは、すべてのシーケンスが同じ長さとは限りません。パディングを使用すると、計算が無駄になります。Flash Attention 2は、各シーケンスの実際の長さを認識し、パディング部分の計算をスキップすることで、効率を向上させます。
Flash Attention 3は、さらなる進化を遂げています。この最新バージョンは、特にH100やH200などの最新のGPUアーキテクチャに最適化されています。これらのGPUは、Tensor Memory Accelerator(TMA)やWarp Group Matrix Multiply Accumulate(WGMMA)などの新しいハードウェア機能を持っており、Flash Attention 3はこれらを最大限活用します。
TMAは、グローバルメモリと共有メモリ間のデータ転送を非同期的かつ効率的に処理するハードウェアユニットです。Flash Attention 3は、TMAを使用して、データのロードとストアを計算と完全に重複させます。これにより、メモリレイテンシをさらに隠蔽し、スループットが向上します。
WGMMAは、複数のワープグループにまたがる大きな行列演算を効率的に実行する機能です。Flash Attention 3は、WGMMAを使用して、ブロック行列乗算をより大きな粒度で実行します。これにより、Tensor Coreの利用率が向上し、計算スループットが増加します。
Flash Attention 3の性能は印象的です。H100 GPUでは、Flash Attention 2と比較してさらに約1.5倍から2倍高速になりました。つまり、標準的な実装と比較すると、合計で約5倍から15倍の高速化です。特に長いシーケンス(N > 4096)では、改善がより大きくなります。
プロダクション最適化も進んでいます。FlashInferというプロジェクトは、Flash Attentionの原理を推論ワークロードに特化して最適化しています。バッチ推論、動的バッチング、PagedAttentionとの統合など、実世界のサービングシナリオで重要な機能をサポートしています。
PagedAttentionとの統合は特に興味深いです。PagedAttentionは、KVキャッシュをページに分割し、メモリを効率的に管理する技術です。これにより、複数のリクエストで効率的にバッチ処理でき、メモリ使用量も削減されます。Flash Attentionの高速な計算とPagedAttentionの効率的なメモリ管理を組み合わせることで、推論スループットが大幅に向上します。
スパースアテンションのサポートも進化しています。Flash Attention 2と3は、ブロックスパースパターン(特定のブロックのみが非ゼロ)を効率的に処理できます。これは、Longformerのようなスライディングウィンドウアテンションや、BigBirdのようなグローバル+ローカルアテンションパターンで有用です。マスクされたブロックをスキップすることで、計算量を削減できます。
クロスアテンションのサポートも改善されました。エンコーダー・デコーダーモデルでは、デコーダーがエンコーダーの出力にアテンションを向けます。このクロスアテンションも、Flash Attentionの原理を使って最適化できます。Flash Attention 2以降は、クロスアテンション用の特殊な最適化を含んでおり、これらのモデルでも大幅な高速化が得られます。
コミュニティでの採用も加速しています。主要なディープラーニングフレームワークがFlash Attentionをネイティブにサポートしています。PyTorch 2.0以降では、torch.compile()と組み合わせることで、Flash Attentionが自動的に使用される場合があります。TensorFlowやJAXでも、同様の最適化されたアテンション実装が利用可能です。
Hugging Face Transformersライブラリでは、多くの人気モデル(BERT、GPT、LLaMA、Mistralなど)がFlash Attentionをサポートしています。from_pretrained()メソッドでattn_implementation="flash_attention_2"を指定するだけで、簡単に使用できます。これにより、既存のコードを最小限の変更で高速化できます。
vLLMやText Generation Inference(TGI)などの推論フレームワークも、Flash Attentionを標準で使用しています。これらのフレームワークは、大規模言語モデルの高スループット推論に最適化されており、Flash Attentionは不可欠なコンポーネントとなっています。
学術研究でも、Flash Attentionの原理が広く適用されています。FlashAttention-inspired最適化が、他のシーケンスモデリングアーキテクチャ(State Space ModelsやRNNなど)にも適用されています。メモリ階層を意識した計算という基本原理は、ニューラルネットワーク最適化の広範な分野で影響を与えています。
今後の方向性として、さらなるハードウェア最適化が期待されます。次世代のGPUやAIアクセラレータは、より大きなオンチップメモリ、より高速な相互接続、そして新しい命令セットを持つでしょう。Flash Attentionの後継技術は、これらの新しいハードウェア機能を活用して、さらなる性能向上を実現すると考えられます。
アルゴリズム面でも、改善の余地があります。より洗練されたタイリング戦略、動的なブロックサイズ選択、そしてワークロード特性に応じた適応的な最適化などが研究されています。また、量子化やスパース性と組み合わせた複合的な最適化も有望な方向性です。
Flash Attentionの成功は、深層学習における効率的な実装の重要性を強調しています。アルゴリズムの理論的な複雑性だけでなく、実際のハードウェアでの実行効率が、実用的な性能を決定します。メモリアクセスパターン、データレイアウト、そしてハードウェア機能の理解が、最先端の実装には不可欠です。Flash Attentionとその後継技術は、この分野のベストプラクティスを示す重要な例となっています。
9. Quantization(量子化)
9.1 量子化の基本概念
Tatsunori Hashimoto: 量子化は、ニューラルネットワークの効率を向上させるための強力な技術です。その基本的なアイデアは、モデルのパラメータや演算で使用される数値の精度を下げることで、メモリ使用量と計算コストを削減することです。量子化の基本概念から理解していきましょう。
通常、ニューラルネットワークのトレーニングは32ビット浮動小数点数(FP32)で行われます。FP32は、符号ビット1つ、指数ビット8つ、仮数ビット23個で構成されています。これにより、非常に広い数値範囲と高い精度を表現できます。しかし、多くの場合、この高い精度は必要以上であり、より低い精度でも十分に良い性能を達成できることがわかっています。
量子化の基本的な考え方は、これらの32ビットの数値を、より少ないビット数で表現することです。最も一般的なターゲットは、16ビット(FP16やBF16)、8ビット(INT8)、そして4ビット(INT4)です。ビット数を半分にすれば、メモリ使用量も半分になります。さらに、多くのハードウェアは、低精度演算をより高速に実行できます。
量子化には、大きく分けて2つのタイプがあります。対称量子化(Symmetric Quantization)と非対称量子化(Asymmetric Quantization)です。対称量子化では、浮動小数点数の範囲[-α, α]を整数の範囲[-127, 127](8ビットの場合)にマッピングします。非対称量子化では、異なる範囲、たとえば[β, γ]を[0, 255]にマッピングできます。
対称量子化の式は比較的シンプルです。浮動小数点数xを量子化するには、まずスケーリングファクターsを計算します。s = α / 127、ここでαは元の値の最大絶対値です。次に、量子化された値x_quantは、x_quant = round(x / s)となります。逆量子化は、x_dequant = x_quant × sです。
非対称量子化では、ゼロポイントzが追加されます。x_quant = round((x - β) / s) + z、ここでs = (γ - β) / 255、z = round(-β / s)です。非対称量子化は、値の分布が偏っている場合、つまり負の値がほとんどない場合などに、より効率的です。
量子化の粒度も重要な設計選択です。テンソルレベルの量子化では、テンソル全体に対して単一のスケーリングファクターを使用します。これは最もシンプルですが、テンソル内の値の範囲が広い場合、精度が低下する可能性があります。チャネルレベルまたは行レベルの量子化では、各チャネルまたは各行に対して異なるスケーリングファクターを使用します。これにより、より細かい粒度で量子化でき、精度の低下を抑えられます。
量子化には、いくつかの主要な利点があります。まず、メモリ使用量の削減です。モデルのサイズを8ビット量子化で4分の1に、4ビット量子化で8分の1に削減できます。これにより、より大きなモデルを限られたメモリに収めることができ、あるいは同じハードウェアでより大きなバッチサイズを使用できます。
次に、計算速度の向上です。多くの現代のハードウェア、特にGPUやTPU、そして専用のAIアクセラレータは、低精度演算を高速に実行できます。たとえば、NVIDIA TensorコアはINT8演算をFP32の約4倍の速度で実行できます。一部のハードウェアでは、INT4演算はさらに高速です。
メモリ帯域幅の削減も重要な利点です。低精度の数値は、メモリからロードする際のデータ量が少なくなります。メモリバウンドな演算では、これが実際の性能向上につながります。たとえば、アテンション機構のような演算では、メモリアクセスがボトルネックになることが多く、量子化によるメモリ転送の削減が大きな効果を持ちます。
エネルギー効率も向上します。低精度演算は、高精度演算よりも少ない電力を消費します。また、メモリアクセスも電力を大量に消費するため、データ転送量の削減はエネルギー効率の向上に直結します。大規模なデプロイメントでは、この電力削減は、運用コストと環境影響の両面で重要です。
しかし、量子化には課題もあります。最も大きな課題は、精度の低下です。数値の表現精度が下がるため、モデルの性能が劣化する可能性があります。特に、量子化に対して敏感なレイヤーや演算があります。たとえば、レイヤーノーマライゼーションやソフトマックスは、数値範囲が広く、量子化が困難です。
アウトライア(外れ値)の問題も重要です。ほとんどの値が小さな範囲にある中で、少数の非常に大きな値が存在する場合、スケーリングファクターがこれらのアウトライアに引きずられ、大多数の小さな値の精度が大幅に低下します。これは、量子化後の性能劣化の主要な原因の一つです。
量子化の適用タイミングも重要な選択です。Post-Training Quantization(PTQ、学習後量子化)では、既にトレーニングされたモデルを量子化します。これは実装が簡単で、既存のモデルにすぐに適用できますが、精度の低下が大きくなる可能性があります。Quantization-Aware Training(QAT、量子化を考慮した学習)では、トレーニング中に量子化をシミュレートし、モデルが量子化に適応します。これにより、精度の低下を最小化できますが、トレーニングプロセスが複雑になります。
混合精度量子化も注目されています。すべてのレイヤーを同じ精度に量子化するのではなく、レイヤーごとに異なる精度を使用します。たとえば、量子化に敏感なレイヤーは8ビット、他のレイヤーは4ビットといった具合です。これにより、精度と効率のバランスを最適化できます。
グループ量子化も重要な技術です。大きな重み行列を小さなグループに分割し、各グループに対して独立したスケーリングファクターを使用します。たとえば、4096次元のベクトルを128次元のグループに分割し、各グループを独立に量子化します。これにより、アウトライアの影響を局所化し、全体の精度を向上させることができます。
動的量子化と静的量子化の違いも理解しておくことが重要です。静的量子化では、スケーリングファクターはトレーニングまたはキャリブレーション段階で決定され、推論時は固定です。動的量子化では、スケーリングファクターが実行時に計算されます。動的量子化は、入力データの分布が変動する場合により適していますが、計算オーバーヘッドがあります。
量子化は、特に推論において非常に効果的です。推論では、forward passのみを実行すればよく、backward passや勾配の保存が不要なため、量子化をより積極的に適用できます。多くの実用システムでは、トレーニングはFP32やBF16で行い、推論時にINT8やINT4に量子化するアプローチが取られています。
トレーニング時の量子化は、より困難です。Backward passでは、勾配の精度が重要であり、過度の量子化は収束を妨げる可能性があります。それでも、最近の研究では、適切な技術を使用することで、トレーニング中でも低精度演算を効果的に使用できることが示されています。これについては、後のセクションで詳しく議論します。
9.2 量子化の種類(Post-training quantization, Quantization-aware training)
Tatsunori Hashimoto: 量子化を実装する際の最も重要な選択の一つは、いつ、どのように量子化を適用するかです。Post-Training Quantization(PTQ)とQuantization-Aware Training(QAT)という2つの主要なアプローチがあり、それぞれ異なるトレードオフを提供します。これらの詳細を見ていきましょう。
Post-Training Quantization(PTQ、学習後量子化)は、最も単純で広く使われているアプローチです。その名の通り、モデルのトレーニングが完了した後に量子化を適用します。既にトレーニングされたFP32またはFP16のモデルを取り、その重みとアクティベーションを低精度に変換します。このプロセスは、モデルの再トレーニングを必要としないため、非常に高速で実用的です。
PTQの基本的なワークフローは次の通りです。まず、完全精度でトレーニングされたモデルから始めます。次に、代表的なキャリブレーションデータセット、通常は数百から数千のサンプルを使用して、各レイヤーの活性化の統計(最小値、最大値、分布など)を収集します。これらの統計を使って、各レイヤーの最適なスケーリングファクターとゼロポイントを決定します。最後に、重みを量子化し、推論時にアクティベーションも動的に量子化します。
PTQの最も単純な形式は、ウェイトオンリー量子化(Weight-Only Quantization)です。これは、モデルの重みのみを量子化し、アクティベーションは元の精度のままにします。この方法は、メモリ使用量を削減できますが、計算速度の改善は限定的です。なぜなら、実際の演算は依然として高精度で行われるからです。しかし、実装が非常に簡単で、精度の劣化も最小限であるため、大規模言語モデルの推論で広く使用されています。
動的量子化(Dynamic Quantization)は、PTQのもう一つの形式です。重みは事前に量子化されますが、アクティベーションは実行時に動的に量子化されます。各レイヤーの入力を見て、その場でスケーリングファクターを計算し、量子化してから演算を実行します。これにより、様々な入力分布に適応できますが、スケーリングファクターの計算にオーバーヘッドがあります。
静的量子化(Static Quantization)は、より積極的なアプローチです。重みもアクティベーションも事前に決定されたスケーリングファクターで量子化されます。キャリブレーションフェーズで収集した統計を使って、各レイヤーのアクティベーションの量子化パラメータを事前に決定します。推論時は、これらの固定パラメータを使用するため、動的量子化よりも高速です。
PTQの利点は明確です。まず、実装が簡単です。既存のトレーニングパイプラインを変更する必要がなく、トレーニング済みモデルに後から適用できます。次に、高速です。再トレーニングが不要なため、数分から数時間で量子化を完了できます。また、様々なモデルに容易に適用でき、汎用性が高いです。
しかし、PTQには欠点もあります。最も大きな問題は、精度の劣化です。特に、8ビット未満の量子化(4ビットや2ビット)では、性能が大幅に低下する可能性があります。モデルは量子化に適応するようにトレーニングされていないため、量子化エラーが蓄積し、最終的な出力の品質が低下します。
PTQが特に困難なケースもあります。小さなモデルでは、パラメータ数が少ないため、各パラメータの重要性が高く、量子化の影響が大きくなります。また、タスクが複雑で、高い精度が要求される場合、PTQでは不十分なことがあります。さらに、極端なアウトライアを持つモデルでは、PTQは効果的ではありません。
Quantization-Aware Training(QAT、量子化を考慮した学習)は、これらの問題に対処するより洗練されたアプローチです。QATでは、トレーニング中に量子化をシミュレートします。Forward passでは、重みとアクティベーションを量子化してから演算を実行しますが、backward passでは、勾配は浮動小数点数で計算されます。これにより、モデルは量子化の影響を学習中に考慮し、それに適応できます。
QATの基本的なメカニズムは、「fake quantization(偽量子化)」と呼ばれます。実際には浮動小数点数で計算を行いますが、各演算の後に、値を量子化してから逆量子化します。たとえば、x_fake_quant = dequantize(quantize(x))です。この操作は、量子化エラーをシミュレートします。Backward passでは、この操作の勾配を計算する必要がありますが、quantize関数は不連続なので、直接微分できません。
この問題を解決するために、Straight-Through Estimator(STE)という技術が使われます。STEは、forward passでは量子化を適用しますが、backward passでは量子化操作を恒等関数として扱います。つまり、∂quantize(x)/∂x ≈ 1とします。これは厳密には正しくありませんが、実践的には効果的です。モデルは、量子化による離散化を考慮しながら学習できます。
QATのトレーニングプロセスは、通常のトレーニングと似ていますが、いくつかの追加のステップがあります。まず、通常の精度でモデルを事前トレーニングすることが多いです。これにより、良い初期化が得られます。次に、fake quantizationノードをモデルに挿入します。これらは、各重み行列と各アクティベーションの後に配置されます。そして、量子化を考慮したファインチューニングを行います。通常、低い学習率で数エポックトレーニングします。
QATには、いくつかの重要なハイパーパラメータがあります。量子化を開始するタイミングは重要です。トレーニングの初期から量子化を適用すると、収束が困難になる可能性があります。一般的には、モデルがある程度収束してから量子化を導入します。また、量子化を段階的に導入する戦略もあります。最初は軽い量子化(8ビット)から始め、徐々にビット数を減らしていきます。
スケーリングファクターの学習も重要な設計選択です。固定スケーリングファクター(トレーニング前にキャリブレーションで決定)を使用することもできますが、スケーリングファクター自体を学習可能なパラメータにすることもできます。学習可能なスケーリングファクターは、各レイヤーに最適な量子化パラメータを自動的に見つけることができ、一般的により良い結果をもたらします。
QATの利点は、精度の向上です。モデルが量子化に適応するため、PTQよりも大幅に精度の劣化が少なくなります。特に、4ビットやそれ以下の極端な量子化では、QATは不可欠です。実験では、QATを使用した4ビットモデルが、PTQを使用した8ビットモデルよりも良い性能を示すことがあります。
QATのもう一つの利点は、より積極的な量子化が可能になることです。PTQでは困難だった、混合精度量子化(レイヤーごとに異なるビット幅)や、2ビット量子化なども、QATでは実現可能です。モデルは、これらの極端な制約の下でも機能するように学習できます。
しかし、QATにも欠点があります。最も大きな問題は、計算コストです。モデルを再トレーニングまたはファインチューニングする必要があるため、PTQよりもはるかに時間がかかります。大規模モデルでは、QATに数日から数週間かかることがあります。また、実装も複雑です。Fake quantizationノードの挿入、STEの実装、そして適切なハイパーパラメータの調整など、技術的な専門知識が必要です。
実践的には、ハイブリッドアプローチが有効なことが多いです。まず、PTQを試して、精度がどの程度劣化するかを評価します。劣化が許容範囲内であれば、PTQを使用します。劣化が大きい場合、QATに移行します。あるいは、一部のレイヤーはPTQ、量子化に敏感なレイヤーはQATという混合アプローチも可能です。
大規模言語モデルのコンテキストでは、PTQが主流です。これらのモデルは非常に大きく、再トレーニングのコストが高いためです。しかし、GPTQ、AWQ、SmoothQuantなどの高度なPTQ技術が開発され、QATに匹敵する精度を達成できるようになっています。これらの技術は、重みの再配置、アウトライアの特別な処理、レイヤーごとの最適化などを使用します。
一方、エッジデバイス向けの小規模モデルでは、QATが依然として重要です。これらのモデルは、極端な量子化(4ビットや2ビット)が必要で、PTQだけでは不十分なことが多いです。また、モデルサイズが小さいため、QATのトレーニングコストも管理可能です。
最近の研究では、QATとPTQの長所を組み合わせた新しい手法も登場しています。たとえば、少量のデータでの短時間のQATや、PTQの後に軽いファインチューニングを行う方法などです。これらは、QATの精度向上の利点を維持しながら、計算コストを削減します。
量子化の選択は、具体的なユースケース、利用可能なリソース、精度要件、そして実装の複雑さの許容度に依存します。一般的なガイドラインとして、迅速なデプロイメント、限られた計算リソース、そして適度な精度要件がある場合はPTQを、最高の精度が必要で、トレーニングリソースが利用可能な場合はQATを選択することが推奨されます。
9.3 精度低下と推論速度のトレードオフ
Tatsunori Hashimoto: 量子化における最も重要な考慮事項は、モデルの精度と推論速度の間のトレードオフです。より積極的な量子化は、より高速な推論とより少ないメモリ使用量をもたらしますが、モデルの性能が低下するリスクも高まります。このトレードオフを定量的に理解し、適切なバランスを見つけることが重要です。
まず、異なる量子化レベルでの精度への影響を見ていきましょう。FP32からFP16またはBF16への量子化は、通常、ほとんど精度の劣化がありません。多くの実験で、FP16とFP32のモデルは、ほぼ同等の性能を示します。実際、一部のケースでは、FP16の方がわずかに良い結果を示すことさえあります。これは、低精度が一種の正則化として機能する可能性があるためです。
具体的な数値で見てみましょう。BERT-baseモデルをGLUEベンチマークで評価した研究では、FP32が84.5%の平均スコア、FP16が84.3%でした。差はわずか0.2ポイントで、統計的に有意ではありません。同様に、GPT-2モデルのパープレキシティは、FP32で29.4、FP16で29.6でした。これも事実上同等です。
8ビット量子化(INT8)になると、状況は少し複雑になります。適切に実装されたINT8量子化では、精度の劣化は通常1%から3%程度です。たとえば、ResNet-50の画像分類タスクでは、FP32が76.1%のTop-1精度、INT8が75.3%でした。約0.8ポイントの低下です。言語モデルでは、パープレキシティが約2%から5%増加することが一般的です。
しかし、INT8量子化の効果は、モデルのアーキテクチャとタスクに大きく依存します。畳み込みニューラルネットワーク(CNN)は、一般的に量子化に対してロバストです。これは、畳み込み演算が局所的な特徴を捉え、個々の重みの重要性が相対的に低いためです。一方、Transformerモデルは、より敏感です。特に、アテンション機構は量子化の影響を受けやすく、ソフトマックスの数値範囲が広いため、量子化が困難です。
4ビット量子化(INT4)では、精度の劣化がより顕著になります。単純なPTQでは、5%から15%の精度低下が観測されることがあります。たとえば、70億パラメータのLlamaモデルを4ビットに量子化すると、一部のタスクでパープレキシティが10%から20%増加することがあります。ただし、GPTQ、AWQなどの高度な量子化技術を使用すると、この劣化を2%から5%程度に抑えることができます。
2ビット量子化やそれ以下の極端な量子化では、精度の劣化は非常に大きくなります。多くの場合、20%以上の性能低下が見られ、一部のタスクでは実用的でなくなります。ただし、特殊な技術(混合精度、学習可能な量子化パラメータ、知識蒸留など)を組み合わせることで、実用可能なレベルまで性能を保つことができる場合もあります。
推論速度の改善を見てみましょう。FP16量子化は、適切なハードウェアサポートがあれば、約1.5倍から2倍の高速化をもたらします。NVIDIA のA100やH100では、FP16のTensor Core演算がFP32よりもはるかに高速です。A100では、FP16が約312 TFLOPS、FP32が約19.5 TFLOPSで、約16倍の理論的な性能差があります。ただし、実際のワークロードでは、メモリアクセスや他のオーバーヘッドにより、実効的な高速化は2倍から3倍程度です。
INT8量子化は、さらに大きな高速化をもたらします。Tensor Coreを使用するA100では、INT8が約624 TOPSで、FP16の約2倍です。実際のエンドツーエンドの推論では、INT8はFP16と比較して約1.5倍から2倍高速です。たとえば、BERT-baseモデルの推論では、FP16で約10ミリ秒/バッチ、INT8で約6ミリ秒/バッチでした。
INT4量子化は、さらなる高速化の可能性を提供します。理論的には、INT8の約2倍高速ですが、実際のハードウェアサポートはまだ限定的です。一部の特殊なハードウェアやカスタムカーネルでは、INT4演算が可能で、INT8と比較して1.5倍から2倍の高速化が報告されています。
メモリ使用量の削減も重要な利点です。FP16は、FP32と比較してメモリを半分にします。70億パラメータのモデルでは、FP32で約28GB、FP16で約14GBです。INT8では約7GB、INT4では約3.5GBになります。この大幅なメモリ削減により、より大きなモデルを限られたハードウェアにデプロイできるようになります。
メモリ帯域幅の削減も、実際の性能に大きく影響します。多くの推論ワークロード、特に小さなバッチサイズでは、メモリバウンドです。つまり、計算速度ではなく、メモリからデータをロードする速度がボトルネックです。量子化により、転送するデータ量が減るため、メモリバウンドな演算が高速化されます。
バッチサイズの影響も考慮する必要があります。大きなバッチサイズでは、計算が支配的になり、量子化による計算速度の向上が直接的に性能に反映されます。小さなバッチサイズ(バッチサイズ1など)では、メモリアクセスとオーバーヘッドが支配的で、量子化の利点がより顕著になります。実際、バッチサイズ1の推論では、INT8はFP16と比較して2倍から3倍高速になることがあります。
実用的なトレードオフの例を見てみましょう。顧客サポートチャットボットを考えます。応答時間が重要で、わずかな精度の低下は許容できます。この場合、INT8量子化は優れた選択です。推論時間を半分に削減でき、精度の低下は1%から2%程度です。ユーザーエクスペリエンスが大幅に向上し、精度の低下はほとんど気づかれません。
一方、医療診断システムを考えます。精度が最優先で、わずかなエラーも許容できません。この場合、FP32またはFP16を使用し、量子化は慎重に適用する必要があります。もしINT8を使用するなら、包括的なテストと検証が必須です。精度が基準を満たさない場合、量子化は適用すべきではありません。
コスト最適化の観点も重要です。クラウドでの推論コストは、計算時間とメモリ使用量に比例します。INT8量子化により、推論時間が半分になり、メモリ使用量も半分になれば、コストは約4分の1になる可能性があります。大規模なデプロイメントでは、これは数百万ドルの節約につながります。
混合精度量子化は、トレードオフをさらに洗練させる方法です。すべてのレイヤーを同じ精度に量子化するのではなく、レイヤーごとに異なる精度を使用します。量子化に敏感なレイヤー(たとえば、最初と最後のレイヤー、アテンション層)はFP16またはINT8に、他のレイヤーはINT4にするといった具合です。これにより、精度と速度の最適なバランスを実現できます。
実験的な評価プロセスも重要です。量子化を適用する前に、代表的なデータセットで精度を測定します。次に、量子化を適用し、精度の変化を測定します。劣化が許容範囲内であれば、推論速度とメモリ使用量を測定します。これらのメトリクスを総合的に評価し、ビジネス要件に照らして決定を下します。
A/Bテストも有用です。量子化されたモデルと元のモデルを並行して実行し、実際のユーザーメトリクス(クリック率、満足度、タスク完了率など)を比較します。技術的なメトリクス(精度、パープレキシティ)の小さな劣化が、実際のビジネスメトリクスにどれだけ影響するかは、必ずしも明確ではありません。場合によっては、技術的な劣化があっても、ビジネスへの影響は無視できることがあります。
最後に、将来のハードウェアの進化も考慮する必要があります。新しいGPUやTPUは、より低精度の演算をより効率的にサポートします。INT4やINT2のハードウェアサポートが向上すれば、これらの極端な量子化がより実用的になります。また、量子化技術自体も進化しており、精度の劣化を最小化する新しいアルゴリズムが継続的に開発されています。
トレードオフの選択は、最終的には各アプリケーションの具体的な要件に依存します。精度、速度、メモリ、コスト、そしてユーザーエクスペリエンスのバランスを慎重に評価し、データに基づいた意思決定を行うことが重要です。量子化は強力なツールですが、すべてのケースで適切とは限りません。適切に適用されれば、大幅な効率向上をもたらし、より多くのユーザーに高度なAIサービスを提供できるようになります。
10. Mixed precision training(混合精度学習)
10.1 FP32, FP16, BF16の違い
Tatsunori Hashimoto: 混合精度学習を理解するためには、まず異なる浮動小数点数フォーマットの特性を深く理解する必要があります。FP32、FP16、BF16という3つの主要なフォーマットは、それぞれ異なる設計トレードオフを持ち、異なる使用シナリオに適しています。
FP32、つまり32ビット浮動小数点数は、従来のディープラーニングトレーニングの標準フォーマットです。IEEE 754標準に従い、1ビットの符号ビット、8ビットの指数部、23ビットの仮数部(または精度部)で構成されています。この構造により、FP32は非常に広い数値範囲と高い精度を提供します。
FP32の数値範囲を具体的に見てみましょう。指数部が8ビットなので、-126から127の範囲の指数を表現できます。これにより、約10の-38乗から10の38乗までの数値を表現できます。この広大な範囲は、ディープラーニングのほとんどのシナリオで十分です。仮数部の23ビットは、約7桁の10進精度を提供します。これは、ほとんどの科学計算で十分な精度です。
FP32の具体例を見てみましょう。数値1.5は、符号ビット0(正)、指数部01111111(バイアス127で実際の指数0)、仮数部10000000000000000000000で表現されます。数値0.1は正確には表現できませんが、非常に近い近似値で表現されます。この種の表現誤差は避けられませんが、FP32では非常に小さく、実用上問題になることはまれです。
FP16、つまり16ビット浮動小数点数(half precision)は、メモリと計算の効率化のために設計されました。1ビットの符号ビット、5ビットの指数部、10ビットの仮数部で構成されています。ビット数を半分にすることで、メモリ使用量が半分になり、メモリ帯域幅の要求も半分になります。
しかし、FP16には重要な制約があります。指数部が5ビットしかないため、数値範囲は-14から15の指数、つまり約10の-4乗から10の4乗(約6×10の-5乗から65,504)に制限されます。これは、FP32の範囲と比較すると非常に狭いです。仮数部も10ビットしかないため、精度は約3桁の10進数に制限されます。
FP16の制限された範囲は、ディープラーニングで問題を引き起こす可能性があります。特に、勾配は非常に小さな値になることがあり、FP16の最小正規数(約6×10の-5乗)よりも小さくなると、アンダーフローが発生します。アンダーフローした値はゼロになり、重要な情報が失われます。逆に、非常に大きな値(65,504を超える)はオーバーフローし、無限大になります。
具体的な問題の例を見てみましょう。バッチノーマライゼーションの分散計算では、値の二乗を計算します。もし値が256を超えると、その二乗は65,536を超え、FP16でオーバーフローします。また、非常に小さな学習率(たとえば1e-5)と小さな勾配(たとえば1e-3)を掛けると、結果は1e-8となり、FP16の表現可能な範囲を下回ります。
BF16、つまりBrain Float 16は、これらの問題を解決するために、GoogleのBrain Teamによって設計されました。BF16も16ビットですが、ビットの配分が異なります。1ビットの符号ビット、8ビットの指数部、7ビットの仮数部で構成されています。重要なのは、指数部がFP32と同じ8ビットであることです。
この設計により、BF16はFP32と全く同じ数値範囲を持ちます。約10の-38乗から10の38乗まで表現できます。これは、アンダーフローとオーバーフローの問題を大幅に軽減します。FP32でトレーニングできる場合、ほとんどの場合BF16でもトレーニングできます。追加の損失スケーリングや特別な処理が不要なことが多いです。
しかし、BF16のトレードオフは精度です。仮数部が7ビットしかないため、精度は約2桁の10進数に制限されます。これはFP16の10ビット(約3桁)よりも低いです。それでも、実験では、ディープラーニングトレーニングにおいて、BF16はFP16よりも安定していることが多く、最終的なモデル性能もほぼ同等であることが示されています。
3つのフォーマットを比較してみましょう。メモリ使用量では、FP16とBF16は同じで、FP32の半分です。70億パラメータのモデルは、FP32で28GB、FP16/BF16で14GBです。計算速度では、現代のハードウェアでFP16とBF16は同等で、FP32よりも大幅に高速です。NVIDIA A100では、FP16/BF16のTensor Core性能は約312 TFLOPS、FP32は約19.5 TFLOPSです。
数値範囲では、BF16とFP32が同じで、FP16は大幅に狭いです。精度では、FP32が最高(約7桁)、FP16が中程度(約3桁)、BF16が最低(約2桁)です。しかし、実際のディープラーニングでは、BF16の低い精度はほとんど問題になりません。これは、ニューラルネットワークのトレーニングが本質的にノイズに対してロバストであるためです。
具体的な数値例で比較しましょう。数値1234.5を考えます。FP32では正確に表現できます。FP16では、最も近い表現可能な値は1234.0または1235.0です(精度の制限により)。BF16でも、同様に1234.0または1235.0になります。数値0.000001(1e-6)を考えます。FP32とBF16では表現できますが、FP16ではアンダーフローしてゼロになります。
トレーニングにおける実際の影響を見てみましょう。FP32は、すべてのケースで安定したトレーニングを提供しますが、メモリと速度の面で非効率です。FP16は、最も高速で効率的ですが、損失スケーリングなどの追加技術が必要です。それでも、一部のモデルやタスクでは、数値的な不安定性が発生する可能性があります。BF16は、速度と効率の面でFP16と同等でありながら、数値的安定性の面でFP32に近いです。多くの場合、特別な処理なしで使用できます。
ハードウェアサポートも重要な考慮事項です。FP16は、NVIDIA のV100以降、すべての主要なGPUでサポートされています。BF16は、比較的新しく、A100以降のGPUと、GoogleのTPU v2以降でサポートされています。AMDのMI250などの最新のGPUもBF16をサポートしています。古いハードウェアでは、BF16はソフトウェアでエミュレートする必要があり、性能上の利点が失われます。
実践的な推奨として、最新のハードウェア(A100、H100、TPU v3以降)では、BF16が最良の選択であることが多いです。実装がシンプルで、数値的に安定しており、FP16と同等の性能を提供します。古いハードウェア(V100など)では、FP16を使用し、適切な損失スケーリングとその他の技術を適用する必要があります。最高の精度が必要で、リソースが許容する場合は、FP32を使用します。
フォーマットの選択は、トレーニングの安定性、ハードウェアの可用性、そして実装の複雑さのバランスです。次のセクションでは、これらのフォーマットを実際の混合精度トレーニングでどのように組み合わせて使用するかを見ていきます。
10.2 混合精度学習の仕組み
Tatsunori Hashimoto: 混合精度学習は、異なる精度のフォーマットを戦略的に組み合わせて、トレーニングの速度とメモリ効率を向上させながら、モデルの品質を維持する技術です。その核心的なアイデアは、すべての演算を低精度で実行する必要はなく、異なる部分に異なる精度を使用できるということです。
混合精度学習の基本的な構造を見ていきましょう。この手法では、3つの異なる重みのコピーを保持します。まず、マスターウェイト(master weights)があります。これはFP32で保存され、オプティマイザが更新する主要なコピーです。次に、FP16またはBF16の重みコピーがあります。これは実際の計算に使用されます。そして、オプティマイザの状態(モーメンタム、分散など)もFP32で保存されます。
トレーニングステップの流れを詳しく見ていきましょう。各イテレーションの開始時、FP32のマスターウェイトをFP16/BF16にキャストします。このキャストは比較的高速な演算で、単純に精度を落とすだけです。たとえば、FP32の値3.14159265をFP16にキャストすると、約3.140になります。精度は失われますが、これは許容可能です。
次に、forward passを実行します。このとき、すべての行列乗算や畳み込みなどの主要な計算はFP16/BF16で行われます。入力データもFP16/BF16にキャストされます。アクティベーション関数(ReLU、GELUなど)も低精度で実行されます。これらの演算が、トレーニング時間の大部分を占めるため、ここでの高速化が全体の性能向上につながります。
具体的な例で見てみましょう。重み行列W(FP16)とアクティベーション行列A(FP16)の乗算を考えます。結果はR = A × Wで、これもFP16です。行列乗算の各要素は、FP16の乗算と加算の繰り返しです。Tensor Coreを使用すると、この演算は非常に高速に実行されます。A100では、FP16の行列乗算がFP32の約16倍高速です。
しかし、一部の演算は依然としてFP32で実行する必要があります。特に、数値範囲が広い演算や、累積誤差が問題になる演算です。たとえば、バッチノーマライゼーションの統計計算(平均と分散)は、通常FP32で行われます。同様に、ソフトマックスの指数計算も、数値範囲の問題を避けるためFP32で実行されることがあります。
損失関数の計算もFP32で行われることが一般的です。損失値は、通常非常に小さな値になり、FP16の精度では不十分な場合があります。また、複数のサンプルの損失を平均する際、累積誤差が問題になる可能性があるため、FP32を使用します。
Backward passも基本的に低精度で実行されます。各レイヤーの勾配は、FP16/BF16で計算されます。しかし、ここで重要な技術が登場します。損失スケーリング(loss scaling)です。これは特にFP16を使用する場合に重要です。
損失スケーリングの必要性を理解するために、勾配の大きさを考えましょう。トレーニングの過程で、勾配は非常に小さな値になることがあります。たとえば、1e-6や1e-7といった値です。これらはFP16の最小正規数(約6e-5)よりも小さく、ゼロに丸められてしまいます。これをアンダーフローと呼びます。
損失スケーリングは、この問題を解決します。基本的なアイデアは、backward passを開始する前に、損失値に大きなスケーリングファクター(たとえば1024や2048)を掛けることです。これにより、backward pass中のすべての勾配も同じファクターでスケールされ、アンダーフローの範囲から離れます。
具体的な数値で見てみましょう。元の勾配が1e-6だとします。これはFP16でアンダーフローします。損失を1024倍にスケールすると、勾配も1024倍になり、約1e-3になります。これはFP16で安全に表現できます。重要なのは、オプティマイザで重みを更新する前に、この勾配を1024で割り戻すことです。これにより、実効的な更新は元の勾配と同じになります。
動的損失スケーリング(dynamic loss scaling)は、さらに洗練された技術です。固定のスケーリングファクターを使うのではなく、トレーニング中に自動的に調整します。アルゴリズムは次のように動作します。初期スケーリングファクター(たとえば65536)から始めます。各イテレーションで、勾配にinf(無限大)やNaN(非数)が含まれているかチェックします。
もしinf/NaNが検出されなければ、そのイテレーションは成功とみなされ、重みを更新します。連続して成功したイテレーション数をカウントし、一定数(たとえば2000)に達したら、スケーリングファクターを2倍にします。これにより、より小さな勾配も捉えられるようになります。
一方、inf/NaNが検出された場合、これはオーバーフローを示します。スケーリングファクターが大きすぎて、一部の値がFP16の範囲を超えました。この場合、そのイテレーションをスキップし(重みを更新しない)、スケーリングファクターを半分にします。次のイテレーションから、より小さなスケーリングファクターで続行します。
この動的調整により、スケーリングファクターは自動的に最適なレベルに収束します。トレーニングの異なる段階で勾配の大きさが変化しても、適応的に対応できます。実装は複雑に見えますが、PyTorchのtorch.cuda.amp(Automatic Mixed Precision)やNVIDIAのApexライブラリが、これらすべてを自動的に処理してくれます。
重みの更新段階では、FP32に戻ります。スケールされた勾配をスケーリングファクターで割り、元の大きさに戻します。この逆スケールされた勾配を使って、FP32のマスターウェイトとオプティマイザの状態を更新します。たとえば、Adamオプティマイザでは、モーメンタムと二次モーメントの推定値を更新し、これらを使って重みを更新します。
なぜマスターウェイトをFP32で保持するのでしょうか。これは、累積誤差を避けるためです。トレーニングは数千から数百万のイテレーションを経ます。各イテレーションで、重みにわずかな更新が加えられます。これらの更新は、学習率を掛けた後、非常に小さな値になることがあります。たとえば、学習率1e-4と勾配1e-2を掛けると、更新は1e-6です。
もし重みがFP16で保存されていた場合、この小さな更新は丸め誤差により失われる可能性があります。数千回のイテレーションを経ると、これらの失われた更新が蓄積し、モデルの収束に影響します。FP32のマスターウェイトを保持することで、これらの小さな更新が正確に蓄積され、長期的な収束が保証されます。
メモリ使用量への影響を考えましょう。混合精度学習では、マスターウェイト(FP32)、低精度ウェイト(FP16)、そしてオプティマイザの状態(FP32)を保持します。一見、メモリが増えるように思えますが、実際には低精度で計算するため、アクティベーションのメモリが大幅に削減されます。アクティベーションは、バッチサイズやシーケンス長に依存し、しばしばモデルの重みよりも多くのメモリを消費します。
具体的な数値で見てみましょう。70億パラメータのモデルで、混合精度学習を使用しない場合(すべてFP32)、パラメータが28GB、オプティマイザの状態が56GB、合計84GBです。混合精度学習を使用する場合、FP32マスターウェイトが28GB、FP16計算用ウェイトが14GB、オプティマイザの状態が56GB、合計98GBです。一見増えているように見えますが、アクティベーションが半分になるため、実際には総メモリ使用量は減少します。
BF16を使用する場合、損失スケーリングが不要なことが多いです。BF16の広い数値範囲により、ほとんどの勾配がアンダーフローしません。これにより、実装が大幅に簡素化されます。単純にforward passとbackward passをBF16で実行し、重みの更新をFP32で行うだけです。多くの最新のコードベースでは、BF16の方が好まれています。
混合精度学習の実装は、現代のフレームワークで非常に簡単になっています。PyTorchでは、数行のコードで有効化できます。torch.cuda.amp.GradScalerとtorch.cuda.amp.autocastを使用します。autocastコンテキストマネージャーは、各演算に適切な精度を自動的に選択します。GradScalerは、損失スケーリングとその逆操作を自動的に処理します。
混合精度学習により、トレーニング速度が約1.5倍から3倍向上することが一般的です。メモリ使用量も約30%から50%削減されます。これにより、より大きなバッチサイズを使用でき、全体のスループットがさらに向上します。最終的なモデルの品質は、完全なFP32トレーニングとほぼ同等です。これらすべての利点により、混合精度学習は、現代の大規模モデルトレーニングの標準的な実践となっています。
10.3 数値安定性の確保方法
Tatsunori Hashimoto: 混合精度学習における最大の課題の一つは、数値安定性の維持です。低精度演算を使用すると、オーバーフロー、アンダーフロー、そして累積誤差など、様々な数値的問題が発生する可能性があります。これらの問題を適切に管理することが、成功する混合精度トレーニングの鍵です。
損失スケーリングについては前のセクションで触れましたが、ここではより深く掘り下げていきましょう。損失スケーリングの主要な目的は、勾配のアンダーフローを防ぐことです。FP16の最小正規数は約6×10の-5乗です。これより小さい値はゼロに丸められます。トレーニング中、特に後半の段階では、勾配は非常に小さくなることがあり、この閾値を下回る可能性があります。
適切なスケーリングファクターを選択することは重要です。ファクターが小さすぎると、多くの勾配がアンダーフローします。大きすぎると、一部の勾配がオーバーフローし、inf(無限大)になります。理想的なスケーリングファクターは、できるだけ多くの勾配を表現可能な範囲に保ちながら、オーバーフローを避けるものです。
実験的に決定された良いスケーリングファクターは、通常1024から65536の範囲です。ネットワークのアーキテクチャと深さによって最適な値は異なります。浅いネットワークでは、勾配が比較的大きいため、小さなスケーリングファクター(512や1024)で十分です。深いネットワークでは、勾配消失により勾配が非常に小さくなるため、大きなスケーリングファクター(8192や16384)が必要になることがあります。
動的損失スケーリングのアルゴリズムをより詳しく見てみましょう。初期化時、スケーリングファクターsを高い値(たとえば65536 = 2の16乗)に設定します。成長間隔(growth interval)という変数も設定します。これは、スケーリングファクターを増やすまでに必要な連続成功イテレーション数で、通常2000に設定されます。バックオフファクター(backoff factor)は、オーバーフロー時にスケーリングを減らす比率で、通常0.5(半分)です。
各トレーニングイテレーションで、次のロジックが実行されます。損失にsを掛けて、backward passを実行します。計算された勾配にinf/NaNが含まれているかチェックします。このチェックは、torch.isfinite()のような関数で実行できます。inf/NaNが検出されなければ、勾配をsで割り(逆スケール)、オプティマイザのstepを実行します。連続成功カウンターをインクリメントします。カウンターが成長間隔に達したら、sを2倍にし、カウンターをリセットします。
inf/NaNが検出された場合、そのイテレーションをスキップします。重みは更新されず、オプティマイザのstepも呼ばれません。スケーリングファクターsをバックオフファクターで割ります(通常は半分にします)。連続成功カウンターをリセットします。次のイテレーションは、新しい小さいスケーリングファクターで実行されます。
このアルゴリズムは、自己調整的です。トレーニングの初期段階では、勾配が大きいため、スケーリングファクターは徐々に減少する傾向があります。トレーニングが進み、勾配が小さくなると、スケーリングファクターは増加し始めます。最終的に、ほとんどオーバーフローせず、かつ多くの小さな勾配を保持できる適切なレベルに落ち着きます。
しかし、損失スケーリングだけでは不十分な場合があります。一部の演算は、本質的に数値的に不安定です。これらの演算には、追加の注意が必要です。ソフトマックスは典型的な例です。標準的なソフトマックスの式は、softmax(x_i) = exp(x_i) / Σexp(x_j)です。もしx_iの値が大きい場合、exp(x_i)は非常に大きくなり、FP16でオーバーフローします。
この問題を解決するために、数値安定版ソフトマックスが使用されます。まず、入力の最大値mを見つけます。m = max(x)です。次に、すべての値からmを引いてから指数関数を適用します。softmax(x_i) = exp(x_i - m) / Σexp(x_j - m)です。この変換により、すべての指数の引数は非正になり、exp()の結果は0から1の間に収まります。これにより、オーバーフローが防止されます。
対数ソフトマックス(log softmax)も注意が必要です。単純にsoftmaxを計算してからlogを取ると、数値的問題が発生します。softmaxの出力が非常に小さい場合、logは大きな負の値になり、精度が失われます。数値安定版のlog softmaxは、log(softmax(x_i)) = (x_i - m) - log(Σexp(x_j - m))と直接計算されます。
レイヤーノーマライゼーションとバッチノーマライゼーションも、数値的に敏感です。これらの演算では、平均と分散を計算し、正規化します。分散の計算では、値の二乗を計算するため、オーバーフローのリスクがあります。また、正規化時に分散の平方根で割るため、分散が非常に小さい場合、数値的不安定性が発生します。
この問題に対処するために、多くの実装ではこれらの統計計算をFP32で実行します。入力がFP16であっても、平均と分散の計算ではFP32にキャストします。正規化後、結果をFP16に戻します。これにより、わずかなオーバーヘッドで数値安定性が大幅に向上します。また、分散に小さなイプシロン(たとえば1e-5)を加えることで、ゼロ除算を防ぎます。
勾配クリッピング(gradient clipping)も数値安定性の重要な技術です。これは、勾配の大きさを制限することで、爆発的な更新を防ぎます。グローバルノルムクリッピングでは、すべてのパラメータの勾配のL2ノルムを計算します。このノルムが閾値(たとえば1.0)を超える場合、すべての勾配を比例的にスケールダウンします。これにより、更新の方向は保たれますが、大きさが制限されます。
混合精度学習では、勾配クリッピングのタイミングが重要です。クリッピングは、損失スケーリングの逆操作の後、つまり勾配が元の大きさに戻された後に適用されるべきです。スケールされた勾配にクリッピングを適用すると、誤った閾値が使用され、効果が失われます。PyTorchのAMPでは、これは自動的に正しい順序で処理されます。
初期化も数値安定性に影響します。適切な初期化スキーム(XavierまたはHeの初期化など)を使用することで、トレーニング開始時のアクティベーションと勾配の大きさを適切な範囲に保ちます。不適切な初期化は、初期段階でのオーバーフロー/アンダーフローを引き起こし、混合精度学習を困難にします。
学習率のウォームアップも重要です。トレーニングの初期段階で小さな学習率から始め、徐々に増やすことで、初期の不安定性を軽減できます。混合精度学習では、これがさらに重要です。初期段階では、モデルのパラメータがランダムに初期化されており、アクティベーションと勾配の分布が安定していません。ウォームアップにより、モデルが安定した状態に達するまでの時間を与えます。
一部のモデルコンポーネントは、特に注意が必要です。埋め込み層(embedding layer)は、しばしば非常に大きな勾配を持ちます。特に、語彙サイズが大きい場合や、一部のトークンが非常に頻繁に出現する場合です。これらの大きな勾配は、損失スケーリングと組み合わせると、オーバーフローを引き起こす可能性があります。一部の実装では、埋め込み層の勾配に対して個別のクリッピングや、より低いスケーリングファクターを適用します。
位置エンコーディング(positional encoding)も注意が必要です。Transformerで使用される三角関数ベースの位置エンコーディングは、大きな値を生成する可能性があります。特に、長いシーケンスでは、位置インデックスが大きくなり、sin()やcos()の引数が大きくなります。FP16では、これらの計算をFP32で行い、結果のみをFP16にキャストすることが推奨されます。
アテンション機構のスケーリングも重要です。スケールドドットプロダクトアテンションでは、アテンションスコアをsqrt(d_k)で割ります。この除算は、スコアの大きさを制御し、ソフトマックスが極端な値(0または1に近い値)を生成するのを防ぎます。FP16では、この適切なスケーリングがさらに重要です。スケーリングなしでは、ソフトマックスの入力が大きくなりすぎ、オーバーフローのリスクが高まります。
デバッグ時には、数値的問題を特定するためのツールが有用です。PyTorchのtorch.autograd.detect_anomaly()は、backward pass中のNaNやinfを検出し、どこで発生したかを報告します。ただし、これはパフォーマンスオーバーヘッドが大きいため、デバッグ時のみ使用すべきです。また、定期的にアクティベーションと勾配の統計(平均、標準偏差、最大値、最小値)をロギングすることで、数値的問題の兆候を早期に発見できます。
BF16を使用する場合、これらの多くの問題が軽減されます。BF16の広い数値範囲により、オーバーフローとアンダーフローのリスクが大幅に減少します。損失スケーリングは通常不要で、多くの演算をFP32にキャストする必要もありません。実装がシンプルになり、数値的安定性が向上するため、最新のハードウェアではBF16が推奨される選択となっています。
数値安定性の確保は、継続的なプロセスです。新しいモデルアーキテクチャや最適化技術を導入する際、数値的な挙動を注意深く監視する必要があります。適切な技術を組み合わせることで、混合精度学習は、FP32トレーニングと同等の安定性と収束性を達成でき、同時に大幅な速度とメモリの改善を提供します。
10.4 実験結果:学習速度とモデル性能
Tatsunori Hashimoto: 混合精度学習の理論的な利点を理解したところで、実際の実験結果を見ていきましょう。これらのデータは、混合精度学習が実世界のアプリケーションでどの程度の改善をもたらすかを明確に示しています。
まず、トレーニング速度の改善から見ていきましょう。NVIDIA の研究チームが行った包括的なベンチマークでは、様々なモデルアーキテクチャで混合精度学習を評価しました。ResNet-50をImageNetでトレーニングした実験では、FP32で約90時間かかったトレーニングが、混合精度学習(FP16)では約29時間に短縮されました。これは約3.1倍の高速化です。V100 GPUを8個使用した構成での結果です。
Transformerモデルでも同様の改善が見られます。BERT-Largeモデル(340Mパラメータ)をWikipediaとBooksCorpusでプレトレーニングする実験では、FP32で約67時間、混合精度学習では約28時間でした。これは約2.4倍の高速化です。GPT-2(1.5Bパラメータ)では、さらに大きな改善が観測され、FP32で約240時間、混合精度学習では約85時間で、約2.8倍の高速化でした。
より大規模なモデルでは、改善がさらに顕著になります。175億パラメータのGPT-3クラスのモデルでは、混合精度学習により、トレーニング時間が約3倍から3.5倍短縮されると推定されています。これは、数週間のトレーニング時間を数日短縮することに相当し、膨大なコスト削減につながります。
A100 GPUでは、改善がさらに大きくなります。A100は、Tensor Coreの性能が大幅に向上しており、FP16/BF16演算が非常に高速です。同じBERT-Largeのトレーニングで、A100では混合精度学習により約3.5倍から4倍の高速化が達成されました。これは、V100での約2.4倍と比較して、大幅な改善です。
メモリ使用量の削減も重要な利点です。70億パラメータのモデルをバッチサイズ32、シーケンス長2048でトレーニングする実験では、FP32が約78GBのGPUメモリを使用したのに対し、混合精度学習では約52GBでした。これは約33%の削減です。この削減により、同じGPUでより大きなバッチサイズを使用できるようになります。
バッチサイズを32から48に増やすと、スループットがさらに向上します。FP32では、メモリ不足でバッチサイズ48は不可能でしたが、混合精度学習では可能になりました。結果として、混合精度学習は、単純な速度向上だけでなく、より効率的なバッチサイズの使用により、全体のスループットを約4倍から5倍向上させることができました。
モデル品質への影響を見てみましょう。多くの実験で、混合精度学習で訓練されたモデルは、FP32で訓練されたモデルと統計的に同等の性能を示します。BERT-BaseをGLUEベンチマークで評価した実験では、FP32が平均84.5%、混合精度学習(FP16)が84.3%でした。この0.2ポイントの差は、複数回の実行のばらつきの範囲内です。
画像分類タスクでも同様の結果が得られています。ResNet-50のImageNet Top-1精度は、FP32で76.15%、混合精度学習で76.13%でした。この0.02ポイントの差は無視できます。EfficientNet-B7では、FP32が84.3%、混合精度学習が84.2%で、ほぼ同一です。
言語モデルのパープレキシティも同等です。GPT-2をWebTextでトレーニングした実験では、FP32のパープレキシティが29.41、混合精度学習が29.38でした。実質的に同じです。より大規模なモデルでも、この傾向は続きます。13億パラメータのモデルでは、FP32が18.2、混合精度学習が18.3で、わずか0.1の差です。
一部のタスクでは、混合精度学習が実際に性能を向上させることもあります。これは一見矛盾しているように思えますが、低精度による軽微なノイズが正則化効果を持つためと考えられています。機械翻訳タスクでは、混合精度学習がFP32よりも約0.3 BLEUポイント高いスコアを達成した例があります。ただし、これは一貫した現象ではなく、モデルとタスクに依存します。
BF16とFP16の比較も興味深いです。数値安定性の面でBF16が優れていることは前述しましたが、最終的なモデル性能では両者はほぼ同等です。BERT-LargeをSQuADでファインチューニングした実験では、FP32が93.2% F1スコア、FP16が93.1%、BF16が93.15%でした。すべて統計的に同等です。
しかし、実装の簡便性では、BF16が明らかに優れています。FP16では、損失スケーリングのチューニングが必要になることがあります。一部のモデルでは、デフォルトのスケーリング設定では収束しないことがあり、手動での調整が必要です。BF16では、ほとんどの場合、特別な調整なしで動作します。複数のプロジェクトでの経験では、FP16で数値的問題に遭遇した場合、BF16に切り替えるだけで問題が解決することが多いです。
収束速度も評価されています。同じ最終性能に達するまでのイテレーション数を比較すると、混合精度学習はFP32とほぼ同じです。つまり、混合精度学習による速度向上は、各イテレーションが高速になることによるもので、収束に必要なイテレーション数は変わりません。これは重要な観察で、混合精度学習がアルゴリズムの収束特性を変えないことを示しています。
ただし、一部のケースでは、わずかに多くのイテレーションが必要になることがあります。特に、非常に小さな学習率を使用する場合や、非常に深いネットワークでは、収束が若干遅くなることがあります。しかし、各イテレーションが高速なため、総トレーニング時間は依然として短くなります。典型的には、イテレーション数が5%から10%増加しても、全体の時間は2倍以上短縮されます。
異なるハードウェアでの性能も評価されています。V100では、混合精度学習により約2倍から3倍の高速化が達成されます。A100では約3倍から4倍です。最新のH100では、さらに大きな改善が期待され、初期のベンチマークでは4倍から5倍の高速化が報告されています。これは、新しいハードウェアが低精度演算をより効率的にサポートしているためです。
TPU(Tensor Processing Unit)でも同様の利点があります。GoogleのTPU v3とv4は、BF16を標準でサポートしており、FP32と比較して約2倍から3倍の高速化を提供します。TPUでは、BF16がデフォルトの精度として推奨されており、ほとんどのワークロードで問題なく動作します。
エネルギー効率の改善も測定されています。混合精度学習は、同じモデルを訓練するのに必要な電力を約40%から50%削減します。これは、トレーニング時間の短縮と、低精度演算の低いエネルギー消費の両方によるものです。大規模なトレーニングでは、この電力削減は数千ドルのコスト削減と、数トンのCO2排出削減に相当します。
実世界のデプロイメントでの経験も重要です。多くの組織が、混合精度学習を本番環境で採用しています。Hugging Faceは、すべての事前訓練モデルを混合精度学習でトレーニングしており、問題なく動作していると報告しています。OpenAIも、GPTシリーズのトレーニングで混合精度学習を使用しています。これらの成功例は、混合精度学習が実用的で信頼性の高い技術であることを示しています。
失敗例や注意が必要なケースも報告されています。一部の強化学習アルゴリズムでは、価値関数の推定に高い数値精度が必要で、混合精度学習が不安定になることがあります。また、非常に小さなモデル(数百万パラメータ以下)では、混合精度学習の利点が限定的です。オーバーヘッドが相対的に大きくなり、速度向上が小さくなるためです。
推奨される使用シナリオとして、中規模から大規模のモデル(数億パラメータ以上)のトレーニングでは、混合精度学習を常に使用すべきです。速度とメモリの改善が大きく、モデル品質への影響はほとんどありません。最新のハードウェア(A100、H100、TPU v3以降)では、BF16が推奨されます。古いハードウェア(V100など)では、FP16を使用し、適切な損失スケーリングを適用します。
小規模なモデルや、数値的に非常に敏感なタスクでは、まずFP32で試し、混合精度学習を慎重に評価することが推奨されます。ただし、ほとんどの標準的なディープラーニングタスクでは、混合精度学習は安全で効果的な選択です。実験結果は一貫して、大幅な速度向上、メモリ削減、そしてほぼ同等のモデル品質を示しており、混合精度学習は現代の大規模トレーニングにおいて不可欠な技術となっています。
11. Supervised finetuning(教師あり微調整)
11.1 ファインチューニングの目的と位置づけ
Tatsunori Hashimoto: ここまで、大規模言語モデルのプレトレーニングと、それを効率化するための様々な技術について議論してきました。しかし、プレトレーニングされたモデルは、まだ実用的なアプリケーションに直接使用できる状態ではありません。ここで、ファインチューニング、特に教師ありファインチューニングが重要な役割を果たします。
ファインチューニングの基本的な目的は、汎用的なプレトレーニングモデルを特定のタスクやドメインに適応させることです。プレトレーニング段階では、モデルは膨大な量のテキストから言語の一般的なパターンや知識を学習しました。しかし、このモデルは次のトークンを予測することに最適化されているだけで、特定のタスク、たとえば質問応答、文章要約、感情分析などを直接実行するようには訓練されていません。
LLMトレーニングパイプライン全体におけるファインチューニングの位置づけを理解しましょう。最初に、大規模なプレトレーニングがあります。これは何兆ものトークンを使い、数週間から数ヶ月かかる計算集約的なフェーズです。プレトレーニングの結果は、強力な言語理解能力を持つベースモデルですが、まだ「荒削り」な状態です。
次に、ファインチューニングフェーズが来ます。これは、プレトレーニングよりもはるかに小規模で高速です。通常、数千から数十万の例を使い、数時間から数日で完了します。ファインチューニングには、いくつかの異なるタイプがありますが、教師ありファインチューニングは最も基本的で重要なものです。
その後、多くの場合、アライメントフェーズがあります。これには、人間のフィードバックからの強化学習(RLHF)やその他の技術が含まれ、モデルの振る舞いを人間の価値観や期待に合わせて調整します。しかし、今日の講義では、教師ありファインチューニングとInstruction Tuningに焦点を当てます。
教師ありファインチューニングの具体的な目的を見ていきましょう。第一に、タスク特化です。プレトレーニングモデルは汎用的ですが、特定のタスクに対して最適化されていません。ファインチューニングにより、モデルは特定のタスクのパターンを学習し、性能が大幅に向上します。たとえば、感情分析タスクでファインチューニングされたモデルは、ゼロショットで同じタスクを実行するプレトレーニングモデルよりもはるかに優れた性能を示します。
第二に、ドメイン適応があります。プレトレーニングデータは一般的なウェブテキストが中心ですが、特定のドメイン、たとえば医療、法律、金融などでは、専門的な用語や文脈が異なります。ドメイン特化データでファインチューニングすることで、モデルはそのドメインの特性を学習し、より適切な応答を生成できるようになります。
第三に、振る舞いの形成があります。プレトレーニングモデルは次のトークンを予測するだけで、必ずしもユーザーが期待する形式で応答を生成するわけではありません。ファインチューニングにより、モデルに特定の応答スタイル、フォーマット、またはトーンを教えることができます。たとえば、簡潔な応答、詳細な説明、特定の構造を持つ出力などです。
具体例で考えてみましょう。GPT-3のようなプレトレーニングモデルに「この映画レビューの感情は何ですか?」と尋ねても、必ずしも明確な「ポジティブ」または「ネガティブ」という答えを返すとは限りません。モデルは、次に来そうなテキストを生成するだけなので、レビューの続きを書いたり、関連する話題について語ったりするかもしれません。しかし、感情分析タスクでファインチューニングされたモデルは、明確に「ポジティブ」または「ネガティブ」と答えるように学習されています。
ファインチューニングとプレトレーニングの違いを理解することも重要です。データ量では、プレトレーニングが数兆トークンを使用するのに対し、ファインチューニングは通常数百万から数億トークン、場合によっては数万トークンで十分です。計算コストでは、プレトレーニングが数千GPU時間を必要とするのに対し、ファインチューニングは数GPU時間から数十GPU時間で完了します。
学習率も異なります。プレトレーニングでは、比較的高い学習率(たとえば3e-4や1e-3)から始めます。ファインチューニングでは、はるかに低い学習率(たとえば1e-5や5e-5)を使用します。これは、既に良い初期化(プレトレーニングされた重み)から始まっているため、大きな変更は不要で、むしろ有害だからです。高い学習率を使用すると、catastrophic forgetting(破滅的忘却)が発生し、プレトレーニングで学習した有用な知識が失われる可能性があります。
エポック数も異なります。プレトレーニングでは、通常1エポックのみ、つまりデータを一度だけ見ます。データセットが非常に大きく、複数回見ることは実用的でないためです。ファインチューニングでは、データセットが小さいため、複数エポック(通常3から10エポック)トレーニングします。ただし、過学習に注意が必要です。
ファインチューニングの種類も多様です。タスク特化ファインチューニングでは、単一の明確なタスク、たとえば文章分類、固有表現認識、質問応答などに焦点を当てます。データセットは、そのタスクの入力と出力のペアで構成されます。このタイプのファインチューニングは、特定のアプリケーションで非常に高い性能を達成できますが、他のタスクへの汎化性は限定的です。
ドメイン適応ファインチューニングでは、特定のドメインのテキストでモデルを継続的にトレーニングします。タスクを明示的に定義せず、単にそのドメインの言語パターンに適応させます。たとえば、医療論文のコーパスでファインチューニングすることで、医療用語や文脈をより良く理解できるようになります。
マルチタスクファインチューニングでは、複数の異なるタスクで同時にトレーニングします。各バッチには、異なるタスクからの例が含まれます。これにより、モデルはより汎用的な能力を保持しながら、複数のタスクで良い性能を達成できます。ただし、タスク間のバランスを取ることが課題です。
Instruction Tuningは、ファインチューニングの特殊なタイプで、モデルが自然言語の指示に従えるようにします。これは、最近の大規模言語モデル、特にChatGPTのようなチャットボットの成功の鍵となっています。Instruction Tuningについては、次のセクションで詳しく議論します。
ファインチューニングの実践的な利点も明確です。第一に、コスト効率です。プレトレーニングには膨大なリソースが必要ですが、一度プレトレーニングされたモデルは、多くの異なるタスクにファインチューニングできます。組織は、自分でプレトレーニングする代わりに、公開されているプレトレーニングモデル(LlamaやMistralなど)を使用し、自分のタスクにファインチューニングできます。
第二に、データ効率です。Few-shot learningやゼロショット学習と比較して、ファインチューニングははるかに少ないデータでより高い性能を達成できます。数千の例でファインチューニングすることで、プロンプトエンジニアリングだけでは達成困難な性能レベルに到達できます。
第三に、カスタマイズ性です。ファインチューニングにより、組織は自分のニーズに正確に合わせてモデルを調整できます。特定の用語、スタイル、またはドメイン知識を組み込むことができます。これは、汎用的なAPIモデルでは困難です。
ファインチューニングの課題もあります。最も大きな課題は、過学習です。データセットが小さい場合、モデルはトレーニングデータを記憶し、新しいデータに汎化できなくなる可能性があります。これを防ぐために、適切な正則化、early stopping、そして検証セットでの慎重な評価が必要です。
Catastrophic forgetting(破滅的忘却)も重要な課題です。ファインチューニング中、モデルは新しいタスクを学習しながら、プレトレーニングで学習した一般的な知識を失う可能性があります。これを軽減するために、低い学習率、適切なエポック数、そして場合によっては正則化技術(たとえば、元の重みからの距離を罰する)を使用します。
データの質と多様性も重要です。ファインチューニングデータが偏っている場合、モデルもその偏りを学習します。高品質で多様なトレーニングデータを確保することは、成功するファインチューニングの鍵です。
ファインチューニングは、LLMを実用的なアプリケーションに変換する架け橋です。プレトレーニングが基礎を提供し、ファインチューニングがそれを具体的な価値に変換します。適切に実行されれば、ファインチューニングは比較的少ないリソースで、特定のタスクやドメインで非常に高性能なモデルを生み出すことができます。
11.2 事前学習との違い
Tatsunori Hashimoto: ファインチューニングとプレトレーニングは、どちらもモデルをトレーニングするプロセスですが、その目的、方法、そして特性は根本的に異なります。これらの違いを深く理解することは、効果的なファインチューニング戦略を設計する上で不可欠です。
最も基本的な違いは、目的関数です。プレトレーニングでは、次のトークン予測という単一の自己教師あり目的を使用します。モデルは、与えられたコンテキストから次に来るトークンを予測することを学習します。この目的は、言語の一般的な構造と知識を学習するのに適していますが、特定のタスクに最適化されているわけではありません。
ファインチューニングでは、目的関数はタスクに応じて異なります。分類タスクでは、クロスエントロピー損失を使用してクラスラベルを予測します。回帰タスクでは、平均二乗誤差を使用します。生成タスクでも次のトークン予測を使用しますが、プレトレーニングとは異なり、特定のフォーマットや構造を持つ出力を生成するように条件付けられます。
データの性質も大きく異なります。プレトレーニングデータは、インターネットから収集された生のテキストです。ラベルは不要で、テキスト自体が自己教師あり学習のための信号を提供します。データは非常に多様で、様々なトピック、スタイル、そして品質レベルを含みます。Common Crawl、書籍、学術論文、コードなど、あらゆるソースからのテキストが使用されます。
ファインチューニングデータは、人間によって注意深くキュレーションされた例です。各例は、明確な入力と期待される出力のペアで構成されます。たとえば、質問応答タスクでは、質問、コンテキスト、そして正解の三つ組です。感情分析では、テキストとそのラベル(ポジティブ/ネガティブ)のペアです。データの品質は、プレトレーニングよりもはるかに高く、タスクに直接関連しています。
データ量のスケールも劇的に異なります。プレトレーニングでは、何兆ものトークンが使用されます。たとえば、Llama 2は2兆トークンでトレーニングされました。GPT-3は約3000億トークンです。一方、ファインチューニングでは、数千から数百万の例が典型的です。SQuADという質問応答データセットには約100,000の例があります。多くの企業固有のファインチューニングタスクでは、数千の例しかないこともあります。
この違いを具体的な数字で見てみましょう。70億パラメータのモデルをプレトレーニングするには、2兆トークンで約8.4×10の22乗FLOPsが必要です。これをA100で実行すると、約1000 GPU日かかります。同じモデルを10万例(平均500トークン/例で5000万トークン)でファインチューニングする場合、約2.1×10の18乗FLOPsで、約1 GPU日で完了します。つまり、ファインチューニングはプレトレーニングの約1/1000の計算量です。
学習率とオプティマイザの設定も大きく異なります。プレトレーニングでは、比較的高い学習率から始めます。典型的には、ウォームアップ後に6e-4や1e-3などの学習率を使用します。学習率スケジュールは、コサイン減衰や線形減衰を使用し、トレーニング全体で徐々に学習率を下げます。
ファインチューニングでは、はるかに低い学習率を使用します。典型的には、1e-5から5e-5の範囲です。これは、プレトレーニングされた重みが既に良い初期化を提供しており、大きな変更は不要だからです。高すぎる学習率は、catastrophic forgetting(破滅的忘却)を引き起こし、プレトレーニングで学習した有用な表現を破壊してしまいます。
実験例を見てみましょう。BERTをGLUEタスクでファインチューニングする標準的な設定では、学習率2e-5、バッチサイズ16または32、3から4エポックを使用します。もし学習率を1e-3に設定すると、モデルは不安定になり、性能が大幅に低下します。検証損失が発散し、トレーニングが失敗します。
バッチサイズも異なります。プレトレーニングでは、非常に大きなバッチサイズ、たとえば数百万から数千万トークンのバッチを使用します。これは、gradient accumulationや複数のGPUにわたるデータ並列化により実現されます。大きなバッチサイズは、トレーニングを安定化し、並列化効率を向上させます。
ファインチューニングでは、小さなバッチサイズで十分なことが多いです。典型的には、8から64の範囲です。これは、データセットが小さいため、大きなバッチサイズが実用的でないことと、小さなバッチサイズが正則化効果を持ち、過学習を防ぐためです。
エポック数の扱いも異なります。プレトレーニングでは、通常1エポックのみです。データセットが非常に大きく、一度見るだけで十分な学習信号が得られます。実際、プレトレーニングデータを複数回見ることは、計算コストの面で実用的ではありません。
ファインチューニングでは、複数エポックが標準です。データセットが小さいため、モデルは数回データを見る必要があります。典型的には3から10エポックですが、タスクとデータサイズに依存します。ただし、過学習のリスクがあるため、検証セットでの性能を監視し、early stoppingを使用することが重要です。
過学習の扱いも大きく異なります。プレトレーニングでは、データセットが非常に大きく、各例を一度しか見ないため、過学習はほとんど問題になりません。主な課題は、計算リソースを効率的に使用し、十分に長くトレーニングすることです。
ファインチューニングでは、過学習が主要な課題です。データセットが小さく、複数エポックトレーニングするため、モデルはトレーニングデータを記憶してしまう可能性があります。これを防ぐために、様々な正則化技術が使用されます。ドロップアウト、重み減衰、early stopping、そしてデータ拡張などです。
検証戦略も異なります。プレトレーニングでは、定期的に小さな検証セットでパープレキシティを測定しますが、これは主にトレーニングが正しく進行しているかを確認するためです。最終的なモデルの品質は、下流タスクでの性能で評価されます。
ファインチューニングでは、検証セットが極めて重要です。各エポック後、検証セットで性能を評価し、最良の性能を達成したモデルを保存します。テストセットは、最後まで触れず、最終的な性能評価にのみ使用します。この厳格な分離により、過学習を検出し、汎化性能を正確に推定できます。
重みの更新パターンも異なります。プレトレーニングでは、モデルのすべてのレイヤーが大きく変化します。ランダム初期化から始まるため、すべての重みが意味のある値に調整される必要があります。トレーニングの過程で、モデルは徐々に言語の表現を構築していきます。
ファインチューニングでは、重みの変化ははるかに小さいです。プレトレーニングされた重みは既に良い表現を持っているため、微調整だけで十分です。実験では、ファインチューニング後の重みの変化は、元の値の数パーセント程度であることが示されています。深いレイヤー(出力に近い)は大きく変化しますが、浅いレイヤー(入力に近い)の変化は最小限です。
計算グラフの扱いも異なる場合があります。プレトレーニングでは、すべてのレイヤーが常に更新されます。ファインチューニングでは、時に層の凍結(layer freezing)が使用されます。モデルの初期レイヤーを凍結し、後のレイヤーのみをトレーニングします。これにより、計算コストが削減され、過学習も防げます。ただし、これは常に最良の戦略ではなく、タスクに依存します。
収束の速さも異なります。プレトレーニングは、数万から数十万ステップを必要とします。損失は徐々に、対数的に減少します。ファインチューニングは、はるかに速く収束します。多くの場合、数百から数千ステップで良い性能に達します。検証損失は、最初は急速に減少し、その後安定します。
ハードウェア要件も異なります。プレトレーニングは、数百から数千のGPUを必要とする大規模な分散トレーニングです。ファインチューニングは、しばしば単一のGPU、または数個のGPUで実行できます。これにより、ファインチューニングはより多くの研究者や組織にとってアクセス可能です。
これらの違いを理解することは、適切なファインチューニング戦略を選択する上で重要です。プレトレーニングの設定をファインチューニングにそのまま適用すると、catastrophic forgettingや不安定なトレーニングなどの問題が発生します。逆に、ファインチューニングの保守的な設定をプレトレーニングに使用すると、トレーニングが非常に遅くなり、リソースの無駄になります。各段階の特性を理解し、それに応じて設定を調整することが、成功する LLMトレーニングの鍵です。
11.3 データセットの準備方法
Tatsunori Hashimoto: ファインチューニングの成功は、データセットの品質に大きく依存します。適切に準備されたデータセットは、少ない例でも高い性能を達成できますが、不適切なデータは、どれだけ多くの例があっても良い結果をもたらしません。データセット準備の実践的な方法を見ていきましょう。
まず、データ収集から始めます。ファインチューニングデータは、解決したいタスクを正確に反映する必要があります。たとえば、カスタマーサポートチャットボットを構築する場合、実際の顧客からの質問と、適切な応答のペアを収集します。感情分析システムを構築する場合、実際のユーザーレビューと、そのラベル(ポジティブ、ネガティブ、ニュートラル)を収集します。
データ量について、一般的なガイドラインは、タスクの複雑さに依存します。シンプルな分類タスク(たとえば、2クラスの感情分析)では、数百から数千の例で十分な場合があります。より複雑なタスク(たとえば、多クラス分類、複雑な生成タスク)では、数万の例が必要になるかもしれません。経験則として、まず小規模なデータセット(1000例程度)から始め、性能を評価し、必要に応じてデータを追加することが推奨されます。
データの多様性も極めて重要です。トレーニングデータは、実際の使用シナリオで遭遇するであろう様々なケースをカバーする必要があります。顧客サポートの例では、異なる種類の質問(技術的な問題、請求に関する質問、製品情報など)、異なるトーンや長さの質問を含めるべきです。偏ったデータセットは、偏ったモデルを生み出します。
データフォーマットの設計も重要です。ほとんどのファインチューニングでは、入力と出力のペアが必要です。フォーマットは、タスクによって異なります。質問応答タスクでは、{"question": "...", "context": "...", "answer": "..."}のような構造が一般的です。分類タスクでは、{"text": "...", "label": "..."}です。生成タスクでは、{"input": "...", "output": "..."}のようなシンプルなペアで十分です。
JSONLinesフォーマット(1行に1つのJSON object)は、大規模データセットで広く使用されています。これは、ストリーミング処理が容易で、行ごとに独立しているためです。例えば、次のようになります。
{"input": "この製品の返品方法を教えてください", "output": "返品をご希望の場合は、購入から30日以内に..."}
{"input": "配送状況を確認したい", "output": "ご注文番号を入力いただければ、配送状況を..."}データクリーニングは、データセット準備の重要なステップです。まず、重複を除去します。同じ例が複数回含まれていると、モデルはそれらを過度に重視し、過学習の原因になります。完全一致だけでなく、ほぼ重複している例(たとえば、わずかな表現の違いのみ)も検出し、除去または統合すべきです。
ノイズの多いデータも問題です。明らかに誤ったラベル、意味不明なテキスト、フォーマットが壊れた例などは除去します。自動的なフィルタリング(たとえば、非常に短いまたは長すぎるテキスト、特殊文字が多すぎるテキスト)と、人手によるレビューを組み合わせることが効果的です。
データの正規化も考慮すべきです。テキストの前後の空白を削除、一貫した句読点の使用、一貫したケーシング(大文字/小文字)などです。ただし、過度の正規化は、実際のデータとの乖離を生む可能性があるため、注意が必要です。
データ分割は、過学習を防ぎ、正確な性能評価を行うために不可欠です。標準的な分割は、トレーニングセット80%、検証セット10%、テストセット10%です。小規模なデータセット(1000例未満)では、70/15/15や60/20/20の分割を使用することもあります。重要なのは、テストセットを完全に分離し、最終評価まで一切触れないことです。
分割時には、ランダム性を確保しつつ、データの分布を保つことが重要です。Stratified splitting(層化分割)は、各クラスの比率をすべての分割で保持します。たとえば、ポジティブとネガティブの例が70:30の比率であれば、トレーニング、検証、テストセットすべてでこの比率を維持します。
データ拡張も、小規模データセットで有用です。テキストデータでは、同義語置換、バックトランスレーション(他の言語に翻訳してから戻す)、文の並べ替えなどの技術があります。ただし、拡張がタスクの意味を変えないように注意が必要です。感情分析では、「良い」を「素晴らしい」に置き換えることは安全ですが、「良い」を「悪い」に置き換えることは明らかに問題です。
アクティブラーニングも、効率的なデータ収集戦略です。少量のデータでモデルをトレーニングし、モデルが最も不確実な例を特定します。これらの例に対して人間がラベルを付け、モデルを再トレーニングします。このプロセスを繰り返すことで、最も有益な例を優先的にラベル付けでき、アノテーションコストを削減できます。
データ品質の評価も重要です。アノテーター間一致度(Inter-Annotator Agreement)を測定します。複数のアノテーターが同じ例にラベルを付け、どの程度一致するかを確認します。Cohenのカッパ係数やFleissのカッパが一般的な指標です。一致度が低い場合、タスクの定義が不明確か、アノテーターのトレーニングが不十分である可能性があります。
難しい例やエッジケースを意図的に含めることも有用です。モデルは、簡単な例だけでなく、曖昧な例や境界的なケースからも学習します。これにより、モデルの頑健性が向上します。ただし、極端に難しい例や、人間でも判断が困難な例は、ノイズとして機能する可能性があるため、慎重に扱う必要があります。
データセットのバージョン管理も忘れてはいけません。データセットの各バージョンを追跡し、どのモデルがどのバージョンでトレーニングされたかを記録します。これにより、問題が発生した際にデバッグが容易になり、実験の再現性も確保されます。DVC(Data Version Control)やGit LFSなどのツールが有用です。
11.4 ファインチューニングのベストプラクティス
Tatsunori Hashimoto: 適切なデータセットを準備したら、次は効果的なファインチューニングを実行することです。ここでは、実践で証明されたベストプラクティスを紹介します。
学習率の選択は、最も重要なハイパーパラメータです。経験則として、1e-5から5e-5の範囲から始めます。BERTやRoBERTaのようなエンコーダーモデルでは、2e-5や3e-5がよく機能します。GPTのようなデコーダーモデルでは、やや高い5e-5が適していることがあります。大規模モデル(数十億パラメータ以上)では、より低い学習率、たとえば1e-6や5e-6が必要になることがあります。
学習率のウォームアップも推奨されます。最初の数百ステップで、学習率を0から目標値まで線形に増加させます。これにより、トレーニングの初期段階での不安定性が軽減されます。典型的なウォームアップは、総ステップ数の5%から10%です。
バッチサイズの選択も重要です。小さなバッチサイズ(8や16)は、正則化効果があり、小規模データセットでの過学習を防ぎます。大きなバッチサイズ(32や64)は、より安定したトレーニングを提供しますが、過学習のリスクが高まります。メモリが許す限り、いくつかの値を試し、検証セットでの性能を比較することが推奨されます。
Gradient accumulationは、メモリが限られている場合に有用です。小さな物理バッチサイズで複数のステップ分の勾配を蓄積し、その後重みを更新します。これにより、大きな実効バッチサイズを達成できます。たとえば、物理バッチサイズ8で4ステップ蓄積すると、実効バッチサイズは32になります。
エポック数は、データセットサイズに依存します。大規模データセット(10万例以上)では、3から5エポックで十分です。小規模データセット(数千例)では、10から20エポックが必要かもしれません。ただし、検証損失を監視し、過学習の兆候が見られたら早期に停止します。
Early stoppingは、過学習を防ぐための重要な技術です。各エポック後、検証セットで性能を評価します。性能が数エポック(通常3から5エポック)改善しない場合、トレーニングを停止し、最良の性能を達成したモデルを使用します。これにより、過学習を避けながら、最適なトレーニング期間を自動的に決定できます。
重み減衰(weight decay)は、正則化の一形態で、重みが大きくなりすぎるのを防ぎます。典型的な値は0.01です。ただし、LayerNormやバイアス項には重み減衰を適用しないことが推奨されます。PyTorchのAdamWオプティマイザは、これを自動的に処理します。
勾配クリッピングも推奨されます。勾配のグローバルノルムを1.0に制限することで、爆発的な更新を防ぎます。これは、特に生成タスクや、長いシーケンスを扱う場合に重要です。
ドロップアウトは、過学習を防ぐもう一つの技術ですが、慎重に使用する必要があります。プレトレーニングされたモデルは既にドロップアウトを含んでいることが多く、ファインチューニング時に追加のドロップアウトを適用すると、性能が低下する可能性があります。デフォルトの設定から始め、過学習が問題になる場合にのみ調整します。
層の凍結(Layer freezing)は、計算コストを削減し、過学習を防ぐ技術です。モデルの初期レイヤー(エンコーダーの下層)を凍結し、後のレイヤーのみをトレーニングします。これは、ドメインがプレトレーニングデータと類似している場合に効果的です。ただし、ドメインが大きく異なる場合、すべてのレイヤーをトレーニングする方が良い結果をもたらすことがあります。
差分学習率(Differential learning rates)も有用です。異なるレイヤーに異なる学習率を適用します。通常、深いレイヤー(タスク特化)には高い学習率、浅いレイヤー(一般的な特徴)には低い学習率を使用します。これにより、タスク特化の適応を促進しながら、一般的な知識を保持できます。
ハイパーパラメータの探索は、最適な設定を見つけるために重要です。ただし、計算コストが高いため、効率的な戦略が必要です。グリッドサーチよりも、ランダムサーチやBayesian optimizationが推奨されます。重要なハイパーパラメータ(学習率、バッチサイズ、エポック数)に焦点を当て、他は合理的なデフォルト値を使用します。
検証メトリクスの選択も重要です。タスクに適したメトリクスを使用します。分類タスクでは、精度、F1スコア、またはAUC-ROC。生成タスクでは、BLEU、ROUGE、またはタスク特化のメトリクス。メトリクスが複数ある場合、主要なメトリクスを決定し、それに基づいてモデルを選択します。
モデルのアンサンブルも性能を向上させることができます。異なるランダムシードや異なるハイパーパラメータでトレーニングされた複数のモデルの予測を平均します。これにより、個々のモデルよりも頑健で正確な予測が得られますが、推論コストが増加します。
定期的なチェックポイントの保存も重要です。各エポック後、または一定のステップごとにモデルを保存します。これにより、トレーニングが中断された場合でも、最新の状態から再開できます。また、異なるチェックポイントの性能を比較し、最良のものを選択できます。
ロギングとモニタリングも不可欠です。トレーニング損失、検証損失、そして関連するメトリクスを定期的に記録します。TensorBoard、Weights & Biases、またはMLflowなどのツールを使用して、トレーニングの進行を可視化します。これにより、問題を早期に発見し、ハイパーパラメータを調整できます。
再現性の確保も重要です。ランダムシードを固定し、すべてのハイパーパラメータと設定を記録します。コードのバージョン、データセットのバージョン、そして使用したライブラリのバージョンも文書化します。これにより、実験を再現し、結果を検証できます。
これらのベストプラクティスを適用することで、効果的なファインチューニングが可能になり、限られたデータとリソースでも高性能なモデルを構築できます。各プロジェクトの特性に応じて、これらのガイドラインを調整し、継続的に実験と評価を行うことが、成功への鍵です。
12. Instruction tuning(指示チューニング)
12.1 Instruction tuningの定義と重要性
Tatsunori Hashimoto: Instruction tuningは、大規模言語モデルを実用的な対話型AIアシスタントに変換する上で、最も重要なブレークスルーの一つです。この技術が、ChatGPTやClaude、そして他の最新の対話型AIの成功の核心にあります。その定義と、なぜこれほど重要なのかを理解していきましょう。
Instruction tuningは、ファインチューニングの特殊な形式で、モデルが自然言語の指示に従えるようにトレーニングする技術です。従来のファインチューニングが単一のタスクに特化するのに対し、instruction tuningは、モデルが様々な異なるタスクを自然言語の指示を通じて実行できるようにします。たとえば、「この文章を要約してください」「フランス語に翻訳してください」「この質問に答えてください」といった指示を理解し、適切に応答します。
Instruction tuningの重要性を理解するために、それ以前の状況を考えてみましょう。GPT-3のようなプレトレーニングモデルは、驚くべき言語理解能力を持っていましたが、ユーザーの意図を直接理解するようには設計されていませんでした。これらのモデルは、与えられたテキストの続きを生成することに最適化されており、必ずしもユーザーが求める応答を生成するわけではありませんでした。
具体例で見てみましょう。プレトレーニングモデルに「パリはどの国の首都ですか?」と尋ねると、「パリはフランスの首都です」という直接的な答えを返すこともありますが、「この質問は地理のテストでよく出ます」とか、「首都に関する他の質問としては...」といった、質問の続きを生成することもあります。モデルは、ユーザーが答えを求めていることを明確に理解していないのです。
Instruction tuningは、この問題を解決します。モデルは、ユーザーの指示の意図を理解し、それに応じた適切な応答を生成するようにトレーニングされます。同じ質問に対して、instruction tunedモデルは一貫して「パリはフランスの首都です」という直接的な答えを返します。
Instruction tuningの歴史を振り返ると、重要なマイルストーンがいくつかあります。2021年のFLAN(Finetuned Language Net)は、Googleの研究者が発表した初期の重要な研究の一つです。彼らは、60以上の異なるNLPタスクを自然言語の指示に変換し、モデルをトレーニングしました。結果は印象的で、instruction tunedモデルは、見たことのない新しいタスクでもゼロショットで優れた性能を示しました。
2022年のInstructGPTは、OpenAIによる画期的な研究でした。彼らは、instruction tuningと人間のフィードバックからの強化学習(RLHF)を組み合わせました。InstructGPTは、GPT-3よりもはるかに小さいにもかかわらず、ユーザーの意図をよりよく理解し、より有用で安全な応答を生成しました。この研究が、後のChatGPTの基礎となりました。
Instruction tuningの重要性は、複数の次元で理解できます。第一に、ユーザビリティの向上です。Instruction tunedモデルは、複雑なプロンプトエンジニアリングを必要とせず、自然な指示で動作します。ユーザーは、モデルに何をしてほしいかを直接伝えることができ、期待通りの結果を得られます。
第二に、汎化能力の向上です。多様なタスクと指示でトレーニングすることで、モデルは「指示に従う」という一般的な能力を獲得します。これにより、トレーニング中に見たことのない新しいタスクでも、適切な指示があれば実行できるようになります。FLAN研究では、トレーニングしていないタスクカテゴリでも、性能が大幅に向上することが示されました。
第三に、安全性とアライメントの向上です。Instruction tuningは、モデルの振る舞いを人間の期待に近づけます。有害な応答を避ける、不確実な場合は正直に言う、ユーザーの真の意図を理解しようとする、といった望ましい振る舞いを教え込むことができます。
第四に、効率性の向上です。単一のinstruction tunedモデルが、多数の異なるタスクを実行できます。以前は、各タスクに対して別々にファインチューニングされたモデルが必要でしたが、instruction tuningにより、一つのモデルで対応できるようになりました。これは、デプロイメントとメンテナンスのコストを大幅に削減します。
Instruction tuningの効果を定量的に見てみましょう。FLAN研究では、T5-XXLモデル(11B パラメータ)をinstruction tuningした結果、25の未見タスクでゼロショット性能が平均で約10ポイント向上しました。一部のタスクでは、20ポイント以上の改善が見られました。InstructGPTでは、1.3Bパラメータのモデルが、175BパラメータのGPT-3よりも人間の評価者に好まれました。
Instruction tuningは、few-shot learningとも相補的です。Instruction tunedモデルは、few-shot例をより効果的に活用できます。指示と例を組み合わせることで、さらに高い性能を達成できます。研究では、instruction tuningがfew-shot learningの能力を損なうことなく、むしろ向上させることが示されています。
Instruction tuningの成功は、いくつかの重要な洞察に基づいています。第一に、タスクの多様性が重要です。モデルが多様なタスクと指示フォーマットを見ることで、より汎化可能な「指示理解」能力を獲得します。第二に、指示の明確さが重要です。曖昧な指示よりも、明確で具体的な指示の方が、モデルの学習を促進します。
第三に、スケールが重要です。大規模なモデルほど、instruction tuningから大きな利益を得ます。小規模なモデル(数億パラメータ)では、instruction tuningの効果は限定的ですが、数十億パラメータ以上のモデルでは、劇的な改善が見られます。これは、大規模モデルがより柔軟な表現を学習でき、新しいタスクへの適応が容易だからと考えられています。
Instruction tuningは、現代のLLMアプリケーションの基礎となっています。チャットボット、コーディングアシスタント、執筆支援ツール、教育アプリケーションなど、ユーザーと対話するすべてのAIシステムは、instruction tuningから利益を得ます。この技術なしでは、今日私たちが知っている有用で使いやすいAIアシスタントは存在しなかったでしょう。
12.2 指示データセットの設計
Tatsunori Hashimoto: Instruction tuningの成功は、指示データセットの品質に直接依存します。適切に設計されたデータセットは、モデルに強力な指示理解能力を与えますが、不適切なデータセットは、限定的な改善しかもたらしません。効果的な指示データセットの設計方法を詳しく見ていきましょう。
指示データセットの基本的な構造は、3つの要素で構成されます。指示(instruction)、入力(input)、そして出力(output)です。指示は、モデルに何をしてほしいかを説明します。入力は、タスクの具体的なインスタンスです。出力は、期待される応答です。たとえば、指示が「次の文章を肯定文に変換してください」、入力が「彼は来ませんでした」、出力が「彼は来ました」となります。
タスクの多様性は、instruction tuningの最も重要な要素です。モデルが様々な異なるタスクを見ることで、タスク固有のパターンではなく、一般的な「指示に従う」能力を学習します。FLAN研究では、60以上の異なるタスクが使用されました。これには、分類、生成、翻訳、質問応答、推論など、幅広いカテゴリが含まれます。
タスクカテゴリの具体例を見てみましょう。自然言語推論(NLI)タスクでは、2つの文の関係(含意、矛盾、中立)を判断します。指示は「次の2つの文の関係を判断してください」です。感情分析では、「この文章の感情を分類してください(ポジティブ、ネガティブ、ニュートラル)」です。要約では、「次の文章を要約してください」です。
翻訳タスクでは、「次の文章を[目標言語]に翻訳してください」という指示を使います。質問応答では、「次の質問に答えてください」または「次の文脈に基づいて質問に答えてください」です。常識推論では、「次の状況について最も適切な結論を選んでください」などです。
指示のフォーマットの多様性も重要です。同じタスクでも、複数の異なる指示の表現を使用します。感情分析の例では、「この文章の感情を分類してください」「この文章はポジティブですか、それともネガティブですか?」「この文章の感情的なトーンを判断してください」など、様々な言い回しを使います。これにより、モデルは特定の表現に過度に依存せず、指示の本質的な意味を理解するようになります。
データセットのソースも多様化すべきです。既存の学術的ベンチマーク(GLUE、SuperGLUE、SQuADなど)を活用することは良い出発点ですが、それだけでは不十分です。実世界のユースケースを反映したデータも含めるべきです。たとえば、実際のユーザーからの質問、カスタマーサポートのログ、創造的なライティングのタスクなどです。
データの品質管理も極めて重要です。各例を検証し、指示が明確か、入力が適切か、出力が正確かを確認します。曖昧な指示、矛盾する出力、または誤ったラベルは、モデルの学習を妨げます。人間のレビュアーによる品質チェックは、時間がかかりますが、不可欠です。
データセットのバランスも考慮すべきです。特定のタスクカテゴリが支配的にならないよう、各カテゴリから適切な数の例を含めます。ただし、完全な均等分布が常に最良とは限りません。より複雑なタスクや、モデルが苦手とするタスクには、より多くの例を割り当てることが有効です。
指示のテンプレートを作成することは、大規模なデータセット生成を効率化します。たとえば、分類タスクの基本テンプレートは「次の[対象]を[カテゴリ]に分類してください」です。具体的なインスタンスとして、「次の文章を感情に分類してください」「次の記事をトピックに分類してください」などが生成できます。
Few-shot例を指示に含めることも効果的です。指示の後に、1つまたは2つの例を示すことで、タスクがより明確になります。たとえば、「次の文章の感情を分類してください。例: 「この映画は素晴らしかった」→ ポジティブ。では、次の文章を分類してください: [入力文章]」のようになります。
負の例も含めることが有用です。モデルに何をすべきでないかを教えることで、望ましくない振る舞いを避けられます。たとえば、「有害なコンテンツを生成しないでください」「不確実な情報を事実として提示しないでください」といった指示と、それに従った応答の例を含めます。
Chain-of-thought(思考の連鎖)プロンプティングを指示データセットに統合することも、最近の重要な進展です。複雑な推論タスクでは、最終的な答えだけでなく、そこに至る思考プロセスも出力に含めます。たとえば、数学の文章題では、「ステップバイステップで解いてください」という指示を使い、出力には各ステップの説明が含まれます。
多言語データも重要です。英語だけでなく、他の言語の指示とタスクも含めることで、モデルの多言語能力が向上します。ただし、各言語でのデータ品質を確保することは課題です。機械翻訳されたデータは、人間が書いた自然なデータほど効果的でない場合があります。
データ拡張技術も適用できます。既存の指示を言い換える、入力を変換する(同義語置換など)、出力の表現を変える、といった方法です。ただし、タスクの本質的な意味が変わらないよう注意が必要です。
クラウドソーシングは、大規模なデータセット作成に有用ですが、品質管理が課題です。明確なガイドライン、例、そしてトレーニングをアノテーターに提供します。複数のアノテーターに同じタスクを割り当て、一致度を測定します。低品質のアノテーターを特定し、そのデータを除外または再レビューします。
モデル生成データも、データセット拡張の手段として使用されます。既存の強力なモデル(たとえばGPT-4)を使って、新しい指示と応答のペアを生成します。ただし、モデル生成データには、元のモデルのバイアスやエラーが含まれる可能性があるため、人間のレビューと検証が不可欠です。Self-Instruct やAlpacaなどのプロジェクトは、この手法を探求しています。
データセットのイテレーションも重要です。最初のバージョンでモデルをトレーニングし、その性能を評価します。モデルが苦手とするタスクやケースを特定し、それらをカバーする追加データを収集します。このサイクルを繰り返すことで、データセットとモデルを継続的に改善できます。
データセットの文書化も忘れてはいけません。各タスクの説明、データソース、収集方法、品質管理プロセス、そして既知の制限や偏りを記録します。これにより、データセットの透明性が確保され、他の研究者がそれを理解し、改善できるようになります。
効果的な指示データセットの設計は、芸術と科学の両方です。タスクの多様性、指示の明確さ、データの品質、そしてバランスの取れた構成が、すべて重要です。適切に設計されたデータセットは、モデルに強力で汎用的な指示理解能力を与え、幅広いアプリケーションで優れた性能を発揮するAIアシスタントを生み出します。
12.3 プロンプトフォーマットの工夫
Tatsunori Hashimoto: Instruction tuningの効果を最大化するためには、プロンプトのフォーマットを慎重に設計する必要があります。フォーマットは、モデルがどのように指示を解釈し、応答を生成するかに大きな影響を与えます。実践的なフォーマット設計の技術を見ていきましょう。
最も基本的なプロンプトフォーマットは、指示と入力を明確に分離することです。多くの実装では、特殊なトークンやマーカーを使用して、各セクションを区別します。たとえば、「### Instruction: [指示]\n### Input: [入力]\n### Response: [応答]」のような構造です。この明確な分離により、モデルは各部分の役割を理解しやすくなります。
Alpacaフォーマットは、広く使用されている標準の一つです。これは次のような構造を持ちます。「Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}」入力が不要なタスクでは、Input セクションを省略できます。
チャット形式のフォーマットも重要です。これは、対話型アプリケーションで特に有用です。ChatMLフォーマットは、次のような構造を使います。「<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n{assistant_message}<|im_end|>」このフォーマットは、システムメッセージ、ユーザーメッセージ、アシスタントメッセージを明確に区別します。
システムメッセージは、アシスタントの振る舞いや役割を定義します。たとえば、「あなたは親切で知識豊富なAIアシスタントです。ユーザーの質問に正確かつ簡潔に答えてください」のようなメッセージです。システムメッセージは、モデルの全体的なトーンやスタイルを設定します。
特殊トークンの選択も重要です。モデルのトークナイザーに存在しない新しい特殊トークンを追加する場合、それらをトークナイザーに登録し、モデルの埋め込み層を拡張する必要があります。これには追加のトレーニングが必要ですが、フォーマットの明確さが向上します。既存の稀なトークンを特殊トークンとして再利用することもできますが、これは衝突のリスクがあります。
入力と出力の境界を明確にすることは、生成を制御する上で重要です。モデルは、どこで生成を停止すべきかを知る必要があります。多くの実装では、特定の終了トークン(たとえば<|im_end|>)を使用するか、改行文字の連続を停止信号として使用します。
Few-shot例を含める場合、フォーマットの一貫性が重要です。各例は、同じ構造(Instruction、Input、Response)を持つべきです。例の数は、通常1から5の範囲です。多すぎる例は、コンテキスト長を消費し、実際のタスク入力のためのスペースを減らします。
例を選択する際、タスクの多様性を示すことが有用です。単に簡単な例だけでなく、様々な難易度やケースを含めます。これにより、モデルはタスクの範囲をよりよく理解できます。
指示の明確さも重要です。曖昧な指示よりも、具体的で詳細な指示の方が、より良い結果をもたらします。「この文章を要約してください」よりも、「次の文章を2-3文で要約してください。主要なポイントに焦点を当ててください」の方が明確です。
制約や期待される出力フォーマットを指示に含めることも有効です。「箇条書きでリストしてください」「JSON形式で出力してください」「コードブロックを使用してください」といった明示的な指示により、モデルの出力を制御できます。
長いコンテキストを扱う場合、情報の構造化が重要です。セクション見出し、番号付きリスト、または明確な区切りを使用して、情報を整理します。これにより、モデルは長いコンテキストから関連情報を抽出しやすくなります。
多言語シナリオでは、指示の言語と応答の言語を明示することが重要です。「次の英語の文章を日本語に翻訳してください」のように、両方の言語を明示します。これにより、言語の混同を避けられます。
エラーハンドリングのための指示も含めるべきです。「もし質問に答えられない場合は、正直にそう伝えてください」「不確実な情報については、推測であることを明示してください」といった指示により、モデルのより誠実な振る舞いを促進できます。
テンプレートのバリエーションも、モデルの頑健性を向上させます。同じタスクでも、異なるフォーマットやテンプレートでトレーニングデータを作成します。これにより、モデルは特定のフォーマットに過度に依存せず、内容に基づいて応答できるようになります。
12.4 実験結果:タスク性能の向上
Tatsunori Hashimoto: Instruction tuningの理論と方法を理解したところで、実際の実験結果を見ていきましょう。これらのデータは、instruction tuningが実世界のタスクでどの程度の改善をもたらすかを明確に示しています。
FLAN研究の結果から始めましょう。T5-XXLモデル(11Bパラメータ)を62のテキストタスクでinstruction tuningしました。評価は、トレーニング中に見ていない25の未見タスクで行われました。ゼロショット設定では、プレトレーニングのみのT5-XXLが平均41.4%の性能だったのに対し、FLAN-T5は平均52.1%を達成しました。これは約10.7ポイントの向上です。
タスクカテゴリ別に見ると、改善の度合いは異なります。自然言語推論(NLI)タスクでは、プレトレーニングモデルが56.2%、FLANが67.8%で、11.6ポイントの向上でした。閉鎖型質問応答では、52.1%から68.9%へ16.8ポイントの大幅な向上が見られました。感情分析では、もともと性能が高く(78.3%)、FLANでは82.1%と、3.8ポイントの向上でした。
Few-shot設定では、さらに大きな改善が見られました。5-shot設定で、プレトレーニングモデルが平均49.2%、FLANが平均58.7%で、9.5ポイントの向上です。重要な観察は、instruction tuningがfew-shot学習能力を損なわず、むしろ向上させることです。
InstructGPTの結果も印象的です。人間の評価者による盲検比較では、1.3BパラメータのInstructGPTモデルが、175BパラメータのGPT-3よりも好まれました。具体的には、評価者の約85%が、InstructGPTの応答をGPT-3よりも好ましいと評価しました。これは、instruction tuningの効果がモデルサイズの差を超えることを示しています。
有害性の削減も測定されました。InstructGPTは、不適切または有害な出力を生成する頻度がGPT-3と比較して約25%減少しました。また、ユーザーの指示に従う頻度は約30%向上しました。これらの改善は、より安全で有用なAIシステムにとって重要です。
真実性の評価では、TruthfulQAベンチマークで測定されました。GPT-3は約21%の質問に真実かつ有益な答えを提供しましたが、InstructGPTは約34%に向上しました。これは13ポイントの大幅な改善です。ただし、依然として改善の余地があり、これは継続的な研究の対象です。
Alpacaの研究では、より小規模なモデルでの効果が示されました。LLaMA 7BモデルをSelf-Instructメソッドで生成された52,000の指示でファインチューニングしました。人間の評価では、Alpaca 7Bの性能がtext-davinci-003(GPT-3.5)に匹敵することが示されました。これは、適切なinstruction tuningにより、小規模モデルでも高い性能を達成できることを示しています。
Vicunaの研究では、チャット形式のデータセットでのファインチューニングが評価されました。LLaMA 13BをShareGPTの会話データでファインチューニングした結果、GPT-4による自動評価で、VicunaはChatGPTの性能の約90%に達しました。特に、創造的なライティングや複雑な推論タスクで良好な性能を示しました。
タスク転移の効果も測定されています。あるタスクカテゴリでinstruction tuningされたモデルは、関連する他のタスクカテゴリでも性能が向上します。たとえば、質問応答タスクでトレーニングされたモデルは、読解理解タスクでも改善を示します。この転移効果は、instruction tuningが一般的な理解能力を向上させることを示唆しています。
モデルサイズの影響も重要です。小規模モデル(1B未満)では、instruction tuningの効果は限定的です。中規模モデル(1B-10B)では、顕著な改善が見られます。大規模モデル(10B以上)では、最も大きな利益を得ます。これは、instruction tuningが複雑な言語理解を必要とするためと考えられています。
データ量の影響も研究されています。数千の指示例でも、ある程度の改善が見られます。しかし、数万から数十万の例でトレーニングすると、性能が大幅に向上します。ただし、収穫逓減の法則が働き、ある点を超えると追加データの利益は減少します。
データの多様性と量のトレードオフも興味深い発見です。10,000の高度に多様な例は、100,000の多様性の低い例よりも効果的であることがあります。これは、タスクカバレッジの重要性を示しています。
長期的な学習能力も評価されています。Instruction tunedモデルは、新しい情報をコンテキストから学習し、適用する能力が向上します。In-context learningのベンチマークで、instruction tunedモデルはプレトレーニングモデルを一貫して上回ります。
12.5 指示追従能力の評価方法
Tatsunori Hashimoto: Instruction tuningの効果を定量的に評価することは、モデルの改善を測定し、異なるアプローチを比較する上で不可欠です。しかし、指示追従能力の評価は、従来のNLPベンチマークよりも複雑です。効果的な評価方法を見ていきましょう。
人間評価は、最も直接的で信頼性の高い評価方法です。人間の評価者が、モデルの応答を様々な基準で評価します。典型的な評価基準には、有用性(helpfulness)、正確性(correctness)、無害性(harmlessness)、そして指示への忠実度(instruction following)が含まれます。
InstructGPTで使用された評価プロトコルを詳しく見てみましょう。評価者は、2つのモデルの応答を比較し、どちらがより好ましいかを選択します。評価基準は明確に定義されます。有用性では、「応答はユーザーの要求に効果的に対応しているか?」と問います。正確性では、「応答は事実的に正確か?」を確認します。
評価者間一致度も測定します。複数の評価者が同じ例を評価し、どの程度一致するかを確認します。Cohenのカッパ係数が通常0.6以上であれば、適度な一致とみなされます。低い一致度は、評価基準が曖昧か、タスクが本質的に主観的であることを示唆します。
しかし、人間評価にはコストと時間がかかります。大規模な評価には、数百時間の人間の時間が必要です。これは、迅速なイテレーションや、多数のモデルバリアントの比較を困難にします。そこで、自動評価メトリクスが重要になります。
自動評価の最も単純な形式は、既存のベンチマークを使用することです。GLUEやSuperGLUEのようなベンチマークは、多様なタスクをカバーしています。Instruction tunedモデルを、これらのベンチマークのゼロショット設定で評価します。ただし、これらのベンチマークは特定のタスク形式を想定しており、自由形式の指示追従能力を完全には捉えられません。
より包括的な評価のために、複数のベンチマークを組み合わせます。MMLU(Massive Multitask Language Understanding)は、57の異なる科目にわたる知識を評価します。BBH(Big-Bench Hard)は、困難な推論タスクを含みます。TruthfulQAは、真実性を評価します。これらを組み合わせることで、モデルの多面的な能力を評価できます。
モデルベース評価も最近注目されています。強力なモデル(たとえばGPT-4)を評価者として使用し、他のモデルの応答を評価させます。Alpaca研究では、このアプローチが使用されました。GPT-4に、2つのモデルの応答を比較させ、どちらがより良いかを判断させます。
モデルベース評価のプロトコルは次のようになります。「あなたは公平な裁判官です。以下の指示と2つのAIアシスタントの応答を評価してください。どちらの応答がより有用で、正確で、無害かを判断し、理由とともに説明してください。」このプロンプトを使い、モデルに評価させます。
研究では、GPT-4の評価が人間の評価と高い相関(相関係数約0.8)を示すことが報告されています。ただし、モデルベース評価にはバイアスがあります。評価モデル自身のスタイルに似た応答を好む傾向があります。また、長い応答や、特定のフォーマットを持つ応答を好むことがあります。
自動メトリクスとしては、参照ベースのメトリクスもあります。BLEUやROUGEは、生成されたテキストと参照テキストの重なりを測定します。しかし、これらは指示追従能力を評価するには不十分です。同じ意味を表現する方法は多数あり、参照テキストと異なっていても正しい応答があり得るからです。
より洗練されたメトリクスとして、BERTScoreやBARTScoreがあります。これらは、意味的類似性を評価します。参照テキストとの表面的な一致ではなく、意味の一致を測定します。これにより、パラフレーズや異なる表現を適切に評価できます。
タスク特化のメトリクスも重要です。質問応答タスクでは、F1スコアやExact Match(完全一致)を使用します。要約タスクでは、ROUGEスコアが標準です。翻訳では、BLEUやCOMETが使用されます。各タスクに適したメトリクスを選択することが重要です。
多次元評価も推奨されます。単一のスコアではなく、複数の側面を個別に評価します。たとえば、流暢性、関連性、正確性、完全性、そして創造性を別々にスコア化します。これにより、モデルの強みと弱みをより詳細に理解できます。
困難な例やエッジケースでの評価も重要です。多くのモデルは、簡単な例では良好に機能しますが、曖昧な指示、矛盾する要求、または複雑な推論を要する例では失敗します。これらの困難なケースを含む評価セットを作成することで、モデルの真の能力をより正確に測定できます。
対話的評価も有用です。単一の指示と応答だけでなく、複数ターンの対話でモデルを評価します。モデルは、前のコンテキストを理解し、一貫した応答を維持できるでしょうか?ユーザーの明確化要求に適切に対応できるでしょうか?これらは、実世界のアプリケーションで重要な能力です。
敵対的評価も考慮すべきです。モデルを「騙そう」とする入力、たとえば矛盾する指示や、トリッキーな表現を使った質問でテストします。頑健なモデルは、これらの敵対的入力に対しても適切に対応すべきです。
安全性評価も不可欠です。モデルは、有害なコンテンツの生成を拒否できるか?バイアスのある応答を避けられるか?個人情報を適切に扱えるか?これらの安全性側面を評価するための専用のベンチマークとプロトコルが必要です。
継続的評価も重要です。モデルをデプロイした後、実際のユーザーとのインタラクションから学びます。ユーザーフィードバック、満足度評価、そして問題報告を収集し、分析します。これにより、ベンチマークでは捉えられない実世界の性能を理解できます。
包括的な評価は、複数の方法を組み合わせます。自動メトリクスで大規模なスクリーニングを行い、人間評価で詳細な品質チェックを行い、モデルベース評価で中間的なフィードバックを得ます。この多層的アプローチにより、コストと精度のバランスを取りながら、モデルの指示追従能力を包括的に評価できます。
13. Parameter-efficient finetuning with LoRA
13.1 パラメータ効率的ファインチューニングの必要性
Tatsunori Hashimoto: これまで議論してきた標準的なファインチューニングには、大きな実践的課題があります。特に、大規模モデルでは、すべてのパラメータを更新するfull finetuningは、膨大な計算リソースとメモリを必要とします。パラメータ効率的ファインチューニング(Parameter-Efficient Fine-Tuning、PEFT)は、この問題に対する革新的な解決策です。
Full finetuningの課題を具体的に見てみましょう。70億パラメータのモデルをファインチューニングする場合、すべてのパラメータの勾配を計算し、保存する必要があります。混合精度学習を使用しても、パラメータ自体に14GB、勾配に14GB、オプティマイザの状態(AdamWの場合)に28GB、合計で約56GBのメモリが必要です。さらに、アクティベーションやその他の中間値を考慮すると、実際には80GB以上のメモリが必要になることがあります。
複数のタスクへの適応を考えると、問題はさらに深刻です。10個の異なるタスクにモデルをファインチューニングしたい場合、従来のアプローチでは、各タスクに対して完全なモデルのコピーを保存する必要があります。70億パラメータのモデルでは、10個のタスクで140GB(各14GB)のストレージが必要です。これは、ストレージコストだけでなく、デプロイメントの複雑さも増加させます。
計算コストも無視できません。Full finetuningでは、モデルのすべてのレイヤーでbackward passを実行する必要があります。これは、forward passと同等かそれ以上の計算を要します。限られたGPUリソースでは、このコストがボトルネックになります。
推論時の切り替えも課題です。実用システムでは、異なるユーザーや異なるタスクに対して、異なるファインチューニングされたモデルを使用したいことがあります。しかし、完全なモデルを切り替えることは、メモリのロードと初期化に時間がかかり、レイテンシを増加させます。
これらの課題は、研究コミュニティだけでなく、実務にも影響します。多くの組織は、複数のユースケースに対応するために、複数のファインチューニングされたモデルを管理する必要があります。しかし、限られたリソースでは、これは実用的ではありません。
パラメータ効率的ファインチューニングは、これらすべての問題に対処します。核心的なアイデアは、モデルのすべてのパラメータを更新するのではなく、少数の追加パラメータまたは既存パラメータのサブセットのみを更新することです。これにより、メモリ使用量、計算コスト、そしてストレージ要件が劇的に削減されます。
PEFTの利点を定量化してみましょう。LoRAを使用した場合、学習可能なパラメータは元のモデルの0.1%から1%程度になることがあります。70億パラメータのモデルでは、700万から7000万の学習可能なパラメータです。これにより、メモリ使用量が大幅に削減され、単一のGPUでもファインチューニングが可能になります。
ストレージの面では、各タスクに対して、追加された少数のパラメータのみを保存すればよくなります。LoRAでは、ランクが8の場合、追加パラメータは数十MBから数百MB程度です。10個のタスクでも、数GBのストレージで済みます。これは、full finetuningの140GBと比較して、98%以上の削減です。
計算コストも削減されます。更新するパラメータが少ないため、backward passが高速化されます。さらに、一部のPEFT手法では、元のモデルのレイヤーを凍結できるため、勾配計算をスキップできます。これにより、トレーニング時間が短縮されます。
推論時の柔軟性も向上します。ベースモデルは一度ロードすれば、異なるタスクの小さなアダプターを迅速に切り替えることができます。これにより、マルチテナントシステムやマルチタスクシステムのデプロイメントが容易になります。
PEFTには、いくつかの異なるアプローチがあります。Adapter layersは、既存のTransformerレイヤーの間に小さなボトルネックレイヤーを挿入します。Prefix tuningは、入力シーケンスの前に学習可能なプレフィックスを追加します。Prompt tuningは、ソフトプロンプト(連続的な埋め込み)を学習します。そして、LoRA(Low-Rank Adaptation)は、重み行列の更新を低ランク分解で近似します。
これらの中で、LoRAは最も効果的で広く採用されている手法の一つです。実装がシンプルで、推論時のオーバーヘッドがほとんどなく、そしてfull finetuningに匹敵する性能を達成できるからです。LoRAの詳細を次のセクションで見ていきましょう。
13.2 LoRA(Low-Rank Adaptation)の原理
Tatsunori Hashimoto: LoRA(Low-Rank Adaptation)は、Microsoftの研究者によって2021年に発表された技術で、パラメータ効率的ファインチューニングの分野に大きな影響を与えました。その原理は、エレガントで実装が容易でありながら、非常に効果的です。
LoRAの基本的なアイデアは、重み行列の更新を低ランク分解で表現することです。標準的なファインチューニングでは、プレトレーニングされた重み行列Wを、ファインチューニング中にΔWだけ更新します。更新後の重みは、W' = W + ΔWです。LoRAは、このΔWを2つの小さな行列の積として表現します。ΔW = BAです。
具体的に説明しましょう。元の重み行列Wが次元d × kを持つとします。たとえば、d = 4096、k = 4096の場合、Wは約1670万の要素を持ちます。LoRAでは、ランクrを選択します(たとえばr = 8)。そして、2つの行列を導入します。B(d × r)とA(r × k)です。
この分解により、パラメータ数が劇的に削減されます。元のΔWは4096 × 4096 = 16,777,216パラメータですが、LoRAのBとAは、4096 × 8 + 8 × 4096 = 32,768 + 32,768 = 65,536パラメータです。これは元の約0.4%です。このランクrが、LoRAの圧縮率を制御します。
Forward passでは、更新後の重みを明示的に計算する必要はありません。入力xに対する出力は、y = (W + BA)x = Wx + BAxと計算できます。つまり、元の重みWとの積と、追加のBAとの積を別々に計算し、加算します。これにより、元のモデルのforward passをほとんど変更せずに、LoRAを統合できます。
LoRAの重要な設計選択は、初期化です。行列Aはランダムに初期化され、行列Bはゼロで初期化されます。この初期化により、トレーニング開始時、ΔW = BA = 0となり、モデルはプレトレーニングされた状態から始まります。これは、安定したトレーニングを保証します。
スケーリングファクターαも導入されます。実際の更新は、ΔW = (α/r) × BAです。αはハイパーパラメータで、通常rと同じ値に設定されます。このスケーリングにより、異なるランクrを使用しても、学習率を大きく調整する必要がなくなります。
LoRAは、どのレイヤーに適用すべきでしょうか。TransformerモデルのQuery、Key、Value、そしてOutputの射影行列に適用することが一般的です。これらは、アテンション機構の中核であり、タスク適応に重要な役割を果たします。フィードフォワードネットワークの重み行列にも適用できますが、研究では、アテンションの重みのみに適用しても十分な性能が得られることが多いと報告されています。
LoRAの理論的基盤は、intrinsic dimensionality(内在次元)の概念に基づいています。この仮説によれば、ファインチューニング中のパラメータ空間での最適化は、実際には低次元の部分空間で行われています。つまり、すべてのパラメータを独立に調整する必要はなく、低次元の部分空間での調整で十分だということです。
実験的な証拠もこれを支持しています。研究では、ランクr = 1や2という非常に低いランクでも、ある程度の性能が得られることが示されています。r = 4から8では、多くのタスクでfull finetuningに近い性能が達成されます。r = 16以上では、ほとんどfull finetuningと同等の性能になります。
LoRAの実装は比較的シンプルです。PyTorchでの基本的な実装は、数十行のコードで済みます。Hugging FaceのPEFTライブラリは、LoRAの標準的な実装を提供しており、数行のコードで任意のTransformerモデルにLoRAを適用できます。
推論時の挙動も重要です。LoRAには2つのオプションがあります。第一に、BとAを別々に保持し、forward pass中にWx + BAxを計算します。これにより、複数のLoRAアダプターを簡単に切り替えられます。第二に、W' = W + BAを事前に計算し、統合された重みを使用します。これにより、推論時のオーバーヘッドがゼロになります。
13.3 低ランク行列による近似
Tatsunori Hashimoto: LoRAの効果を深く理解するためには、低ランク近似の数学的基礎を理解する必要があります。なぜ低ランクの更新で十分なのか、そしてランクの選択がどのように性能に影響するかを見ていきましょう。
行列のランクは、その線形独立な行(または列)の数です。d × k行列の最大ランクは、min(d, k)です。低ランク行列は、ランクがこの最大値よりもはるかに小さい行列です。低ランク行列は、高度に構造化されており、冗長性を持っています。
低ランク分解の基本的な定理は、任意の行列Mを、M = UΣV^Tと分解できることです(特異値分解、SVD)。ここで、UとVは直交行列、Σは特異値を対角成分に持つ対角行列です。ランクrの近似は、最大のr個の特異値のみを保持することで得られます。
LoRAの文脈では、ΔW = BAという分解を使用します。これは、SVDとは異なりますが、同様の効果を持ちます。BとAはそれぞれd × rとr × kの次元を持ち、その積はランクr以下の行列になります。トレーニング中、BとAは勾配降下により最適化され、タスクに適した低ランク更新を学習します。
ランクrの選択は、表現力と効率性のトレードオフです。低いランク(r = 1や2)は、非常に効率的ですが、表現力が限られます。複雑なタスク適応を捉えるには不十分かもしれません。高いランク(r = 64や128)は、より表現力がありますが、パラメータ数が増加し、効率性の利点が減少します。
実験的に、最適なランクはタスクに依存します。シンプルなタスク(たとえば、2クラス分類)では、r = 4や8で十分です。複雑なタスク(たとえば、複雑な生成タスク)では、r = 16や32がより良い結果をもたらすことがあります。一般的な経験則として、r = 8は多くのタスクで良好な性能と効率性のバランスを提供します。
パラメータ数の具体的な計算を見てみましょう。GPT-3のアテンション層を考えます。隠れ層次元d = 12,288です。Query、Key、Value、Outputの4つの射影があり、それぞれ12,288 × 12,288の重み行列を持ちます。Full finetuningでは、各行列に約1.5億パラメータ、合計で約6億パラメータです。
LoRAをランクr = 8で適用すると、各射影に対して2つの行列があります。B(12,288 × 8)とA(8 × 12,288)で、合計12,288 × 8 + 8 × 12,288 = 196,608パラメータです。4つの射影で約78.6万パラメータです。これは、full finetuningの約0.13%です。
この劇的な削減にもかかわらず、LoRAは驚くべき性能を達成します。RoBERTa-baseモデルをGLUEタスクでファインチューニングした実験では、LoRA(r = 8)がfull finetuningとほぼ同等の性能を示しました。一部のタスクでは、LoRAが実際にfull finetuningを上回りました。
なぜこれほど少ないパラメータで十分なのでしょうか。一つの説明は、ファインチューニング中の更新が本質的に低ランクであるということです。タスク適応に必要な変更は、パラメータ空間の低次元部分空間に集中しています。言い換えれば、タスクに関連する情報は、少数の主要な方向に沿って捉えられます。
もう一つの視点は、プレトレーニングモデルが既に豊かな表現を持っているということです。ファインチューニングは、これらの表現を微調整するだけで、根本的に変更する必要はありません。低ランク更新は、この微調整を効率的に実現します。
ランクと性能の関係を定量的に見てみましょう。ある研究では、LLaMA 7BモデルをWikipediaのデータセットでファインチューニングしました。r = 1では、パープレキシティが18.5でした。r = 4では16.2、r = 8では15.8、r = 16では15.6でした。Full finetuningは15.5でした。r = 8で、full finetuningの性能の約98%を達成しています。
異なるレイヤーに異なるランクを使用することも可能です。一部の研究では、浅いレイヤーには低いランク、深いレイヤーには高いランクを使用することが有効であることが示されています。これは、深いレイヤーがタスク特化の情報をより多く保持しているという直感と一致します。
LoRAの行列BとAを分析することで、モデルが何を学習しているかの洞察が得られます。特異値分解を行うと、学習された更新の主要な方向が明らかになります。一部の研究では、これらの方向がタスク関連のセマンティック情報に対応していることが示されています。
低ランク構造の別の利点は、正則化効果です。パラメータ空間を低次元部分空間に制限することで、過学習のリスクが減少します。これは、特に小規模データセットでファインチューニングする場合に有益です。実験では、LoRAがfull finetuningよりも汎化性能が良いことがあることが示されています。
メモリ効率の分析も重要です。トレーニング中、勾配はBとAに対してのみ計算されます。元の重みWは凍結されているため、その勾配を計算または保存する必要がありません。これにより、メモリ使用量が大幅に削減されます。70億パラメータのモデルでr = 8のLoRAを使用する場合、学習可能なパラメータは約700万で、元のモデルの0.1%です。
計算効率も向上します。Backward passでは、凍結されたレイヤーの勾配計算をスキップできます。これにより、トレーニング速度が向上します。実験では、LoRAを使用したトレーニングが、full finetuningと比較して約25%から40%高速であることが報告されています。
低ランク近似の限界も理解しておくことが重要です。非常に複雑で、高次元の適応が必要なタスクでは、低ランク制約が性能を制限する可能性があります。そのような場合、より高いランクを使用するか、full finetuningに戻る必要があるかもしれません。
実践的には、ランクの選択は実験的に決定されます。まず低いランク(r = 4や8)から始め、検証セットでの性能を評価します。性能が不十分であれば、ランクを増やします。多くの場合、r = 8や16で十分な性能が得られ、さらに高いランクは追加の利益をもたらしません。
LoRAの低ランク近似は、理論的にエレガントで、実践的に効果的です。パラメータ数を劇的に削減しながら、full finetuningに匹敵する性能を達成できます。この技術は、リソースが限られた環境での大規模モデルのファインチューニングを可能にし、マルチタスク学習やパーソナライゼーションなどの新しいアプリケーションを開きます。
13.4 実装方法と適用レイヤーの選択
Tatsunori Hashimoto: LoRAの理論を理解したところで、実際にどのように実装し、どのレイヤーに適用すべきかを見ていきましょう。適切な実装と適用戦略は、LoRAの効果を最大化する上で重要です。
LoRAの実装は、概念的にはシンプルですが、いくつかの重要な設計選択があります。まず、基本的なLoRA層の実装を考えましょう。標準的な線形層 y = Wx を、LoRA拡張版 y = Wx + (α/r)BAx に置き換えます。PyTorchでの実装例を見てみましょう。
LoRA層は、元の重みWを凍結し、追加のパラメータ行列BとAのみを学習可能にします。初期化では、Aをランダム(通常はGaussian)に初期化し、Bをゼロに初期化します。これにより、トレーニング開始時にΔW = 0となり、プレトレーニングされたモデルから出発します。
Hugging FaceのPEFTライブラリは、LoRAの標準的な実装を提供しています。使用方法は非常にシンプルです。まず、LoRA設定を定義します。lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], lora_dropout=0.1)のようになります。次に、この設定をモデルに適用します。model = get_peft_model(model, lora_config)だけで、指定されたモジュールにLoRAが適用されます。
適用するレイヤーの選択は、性能と効率のトレードオフに大きく影響します。Transformerモデルでは、主に4つの選択肢があります。Query射影のみ、Query とValueのみ、Query、Key、Value、Output の4つすべて、またはこれらに加えてフィードフォワード層です。
最も一般的なアプローチは、アテンション機構のQuery とValue射影にのみLoRAを適用することです。元のLoRA論文では、この設定が良好な性能と効率のバランスを提供することが示されました。Query とValueは、アテンション計算の中核であり、タスク適応に最も重要な役割を果たします。
実験的証拠を見てみましょう。RoBERTa-baseをGLUEタスクでファインチューニングした研究では、Queryのみの場合、平均精度が82.1%でした。Query とValueの場合は84.3%、4つすべての射影では84.8%でした。Full finetuningは85.0%でした。Query とValueだけで、full finetuningの約98.5%の性能を達成しており、十分なトレードオフです。
フィードフォワード層にもLoRAを適用すると、さらなる改善が得られることがありますが、パラメータ数が大幅に増加します。フィードフォワード層の重み行列は通常大きく(d × 4d)、LoRAを適用すると学習可能なパラメータが2倍以上になることがあります。多くの場合、この追加のコストに見合う性能向上は得られません。
ランクrの選択も、適用するレイヤー数と関連しています。少数のレイヤー(Query とValueのみ)にLoRAを適用する場合、やや高いランク(r = 16や32)が有益かもしれません。多数のレイヤーに適用する場合、低いランク(r = 4や8)で十分です。総パラメータ数を一定に保ちながら、深さと幅のバランスを取ることが重要です。
ドロップアウトの使用も考慮すべきです。LoRAに対してもドロップアウトを適用できます。lora_dropoutパラメータは、LoRA層の出力にドロップアウトを適用します。典型的な値は0.05から0.1です。ドロップアウトは、小規模データセットでの過学習を防ぐのに役立ちますが、大規模データセットでは不要な場合が多いです。
バイアス項の扱いも設計選択の一つです。LoRAは通常、重み行列にのみ適用され、バイアス項は更新しません。しかし、一部の実装では、バイアスも学習可能にするオプションがあります。実験では、バイアスを学習可能にすることで、わずかな性能向上が得られることが報告されていますが、効果は限定的です。
レイヤーノーマライゼーションのパラメータも、学習可能にするか選択できます。これらは比較的少数のパラメータ(各レイヤーあたり2d)であるため、学習可能にしてもメモリへの影響は小さいです。一部の研究では、レイヤーノームのパラメータを学習可能にすることで、特に分布シフトがあるタスクで性能が向上することが示されています。
マルチタスク設定では、各タスクに対して別々のLoRAアダプターを学習できます。ベースモデルは共有され、各タスクは独自のBとA行列を持ちます。推論時、タスクに応じて適切なアダプターを選択します。これにより、単一のベースモデルで複数のタスクをサポートできます。
実装の最適化も重要です。LoRAの計算は、y = Wx + BAxと2つの行列乗算に分解できますが、これをより効率的に実装する方法があります。まずtemp = Axを計算し(r × kとk × bの乗算)、次にBtempを計算します(d × rとr × bの乗算)。rが小さいため、これらの演算は元のWxよりもはるかに高速です。
推論時の統合も考慮すべきです。トレーニング後、W' = W + (α/r)BAを計算し、統合された重みを使用できます。これにより、推論時のオーバーヘッドがゼロになります。ただし、複数のアダプターを頻繁に切り替える場合、別々に保持する方が効率的です。
量子化との組み合わせも可能です。ベースモデルを8ビットや4ビットに量子化し、LoRAアダプターはFP16やBF16で保持します。これにより、メモリ使用量がさらに削減されます。QLoRA(次のセクションで詳しく説明)は、この組み合わせを活用しています。
13.5 実験結果:メモリ削減とパフォーマンス維持
Tatsunori Hashimoto: LoRAの理論的利点を理解したところで、実際の実験結果を見ていきましょう。メモリ削減と性能維持の両方において、LoRAは印象的な結果を示しています。
まず、メモリ使用量の削減から見ていきましょう。GPT-3 175Bモデルをfull finetuningする場合、FP16でも約1.2TBのGPUメモリが必要です(パラメータ、勾配、オプティマイザの状態を含む)。これは、最大のGPU(A100 80GB)でも15台以上必要です。LoRAをr = 4で適用すると、学習可能なパラメータは約3500万(0.02%)になり、必要なメモリは約350GBに削減されます。これは約70%の削減で、5台のA100で実行可能になります。
より小規模なモデルでも、効果は顕著です。LLaMA 7BモデルをA100 40GBでファインチューニングする実験では、full finetuningはバッチサイズ1でギリギリ収まるか、メモリ不足になります。LoRA(r = 8)では、バッチサイズ8が可能で、スループットが大幅に向上しました。
具体的な数値で見てみましょう。7Bモデル、バッチサイズ4、シーケンス長512の設定で、full finetuningは約42GBのメモリを使用しました。LoRA(r = 8、Query とValueのみ)は約18GBで、約57%の削減です。LoRA(r = 16、4つの射影すべて)でも約24GB で、約43%の削減です。
トレーニング速度も測定されています。同じハードウェアとバッチサイズで、LoRAはfull finetuningと比較して約20%から30%高速でした。これは、更新するパラメータが少なく、backward passが高速化されるためです。メモリ削減により、より大きなバッチサイズを使用できることを考慮すると、実効的なスループット向上はさらに大きくなります。
性能面では、LoRAは驚くべき結果を示しています。RoBERTa-baseをGLUEベンチマークの8タスクでファインチューニングした実験では、full finetuningが平均85.0%の精度を達成しました。LoRA(r = 8、Query とValue)は84.3%で、わずか0.7ポイントの差です。LoRA(r = 16、4つの射影)は84.8%で、差は0.2ポイントまで縮まります。
個別のタスクを見ると、いくつかの興味深いパターンがあります。MNLI(自然言語推論)では、full finetuningが87.6%、LoRAが87.2%で、ほぼ同等です。QQP(質問ペアの類似性)では、full finetuningが91.9%、LoRAが91.7%でした。CoLA(文法性判断)では、full finetuningが68.0%、LoRAが67.8%でした。ほとんどのタスクで、差は1ポイント未満です。
大規模モデルでは、LoRAの性能がさらに向上します。GPT-3 175BをSuperGLUEタスクでファインチューニングした実験では、LoRA(r = 4)がfull finetuningと統計的に区別できない性能を示しました。一部のタスクでは、LoRAが実際にfull finetuningを上回りました。これは、LoRAの正則化効果が、過学習を防いでいる可能性を示唆しています。
生成タスクでも評価されています。GPT-2をE2E NLGタスク(データからテキスト生成)でファインチューニングした実験では、full finetuningがBLEUスコア68.2を達成しました。LoRA(r = 8)は67.8で、わずか0.4の差です。人間評価では、両者の出力は区別できませんでした。
要約タスクでは、XSumデータセットでの実験が行われました。BART-largeモデルで、full finetuningがROUGE-Lスコア44.2を達成しました。LoRA(r = 8)は43.9で、0.3の差です。CNN/DailyMailデータセットでも同様の結果が得られ、LoRAとfull finetuningの差は無視できる程度でした。
マルチタスク設定での評価も重要です。単一のベースモデルに複数のLoRAアダプターを適用した実験では、各タスクで専用のfull finetuningモデルと比較して、平均で約1%の性能低下しかありませんでした。しかし、ストレージ要件は90%以上削減されました。10個のタスクで、full finetuningモデルは140GBのストレージを必要としますが、LoRAは共有ベースモデル(14GB)+ 10個のアダプター(各約50MB)で約14.5GBです。
収束速度も評価されています。同じ最終性能に達するまでのステップ数を比較すると、LoRAはfull finetuningとほぼ同じでした。つまり、LoRAは収束を遅らせることなく、メモリと計算の利点を提供します。
小規模データセットでの性能も重要です。100例、1000例、10000例のデータセットでファインチューニングを比較した実験では、LoRAはすべてのデータサイズでfull finetuningと同等以上の性能を示しました。特に、非常に小規模なデータセット(100例)では、LoRAがfull finetuningを一貫して上回りました。これは、LoRAの低ランク制約が正則化として機能し、過学習を防いでいるためと考えられます。
13.6 Full finetuningとの比較
Tatsunori Hashimoto: LoRAとfull finetuningの包括的な比較は、どちらのアプローチを選択すべきかを理解する上で重要です。各アプローチには、特定の状況で有利な点と不利な点があります。
パラメータ数の比較から始めましょう。Full finetuningでは、モデルのすべてのパラメータが学習可能です。7Bモデルでは70億、175Bモデルでは1750億のパラメータです。LoRAでは、学習可能なパラメータは通常0.1%から1%です。7Bモデルでr = 8の場合、約700万から1400万パラメータです。これは1000倍から10000倍の削減です。
メモリ使用量では、full finetuningが勾配とオプティマイザの状態をすべてのパラメータに対して保持する必要があります。LoRAは、少数のLoRAパラメータに対してのみこれらを保持します。7Bモデルの場合、full finetuningが約42GB、LoRAが約18GBで、約57%の削減です。
計算コストでは、LoRAがbackward passでの勾配計算を少数のパラメータに制限できるため、約20%から30%高速です。ただし、forward passの計算量は同等です(BAx項の追加計算は無視できる程度)。したがって、推論速度はほぼ同じです(アダプターを統合した場合は完全に同じ)。
ストレージ要件では、LoRAの利点が最も顕著です。各タスクまたはユーザーに対して、full finetuningでは完全なモデルを保存する必要がありますが、LoRAは小さなアダプターのみを保存します。10個のタスクで、full finetuningが140GB、LoRAが約15GBです。これは90%以上の削減です。
性能面では、LoRAとfull finetuningはほぼ同等です。多くのベンチマークで、差は1%未満です。一部のタスク、特に小規模データセットや、ドメインがプレトレーニングデータと近い場合、LoRAがfull finetuningを上回ることさえあります。ただし、非常に複雑なタスクや、大きなドメインシフトがある場合、full finetuningがわずかに優れていることがあります。
柔軟性の面では、LoRAが優れています。複数のアダプターを迅速に切り替えられるため、マルチタスクシステムやパーソナライゼーションが容易です。Full finetuningでは、完全なモデルを切り替える必要があり、これは時間とメモリを消費します。
実装の複雑さでは、両者は同程度です。Full finetuningは概念的にシンプルですが、大規模モデルでは並列化やメモリ管理が複雑になります。LoRAは追加のレイヤーを導入しますが、PEFTライブラリなどのツールにより、実装は簡単になっています。
トレーニングの安定性では、両者ともほぼ同等です。適切な学習率とハイパーパラメータを使用すれば、両方とも安定してトレーニングできます。LoRAは、低ランク制約により軽い正則化効果を持ち、小規模データセットでの安定性がやや向上する可能性があります。
Catastrophic forgetting(破滅的忘却)の観点では、LoRAがより保守的です。元の重みを凍結し、追加の低ランク更新のみを学習するため、プレトレーニングで学習した知識がより保持されます。Full finetuningでは、すべての重みが変更されるため、元の知識が部分的に失われる可能性があります。
コストの面では、LoRAが明確に有利です。トレーニングコスト、ストレージコスト、そしてデプロイメントコストのすべてで、LoRAは大幅な節約をもたらします。組織が複数のタスクやユーザー向けにモデルをカスタマイズする場合、コスト削減は数十万ドルに達する可能性があります。
適用シナリオの推奨として、LoRAは次のような場合に最適です。リソースが限られている、複数のタスクに適応する必要がある、迅速なイテレーションが重要、小規模から中規模のデータセット、ドメインがプレトレーニングデータと比較的近い場合です。
Full finetuningが推奨される場合は、最高の性能が絶対的に必要、十分な計算リソースがある、単一タスクに特化、非常に大規模なデータセット(数百万例以上)、ドメインがプレトレーニングデータと大きく異なる場合です。
実践的なアプローチとして、多くの組織はまずLoRAを試し、性能が要件を満たすか評価します。ほとんどの場合、LoRAで十分です。性能が不十分な場合にのみ、full finetuningに移行します。これにより、コストを最小化しながら、必要な性能を達成できます。
LoRAとfull finetuningは、排他的ではありません。ハイブリッドアプローチも可能です。たとえば、モデルの初期レイヤーを凍結またはLoRAで更新し、深いレイヤーはfull finetuningするといった戦略です。これにより、両方の利点を組み合わせることができます。
総じて、LoRAは大規模言語モデルのファインチューニングにおいて、ゲームチェンジャーとなっています。大幅なリソース削減を提供しながら、性能をほぼ完全に維持します。これにより、以前はリソース的に不可能だったアプリケーション、たとえば大規模なパーソナライゼーションやマルチタスク学習が、実現可能になりました。LoRAは、効率的なAI開発の未来における重要な技術です。
14. QLoRA
14.1 QLoRAの概要
Tatsunori Hashimoto: QLoRA(Quantized Low-Rank Adaptation)は、2023年にワシントン大学の研究者によって発表された画期的な技術で、LoRAと量子化を組み合わせることで、さらに劇的なメモリ削減を実現します。この技術により、単一の消費者向けGPU、たとえば24GBのRTX 4090やRTX 3090でも、650億パラメータのモデルをファインチューニングできるようになりました。
QLoRAの核心的なアイデアは、ベースモデルを4ビットに量子化し、メモリに保存しながら、LoRAアダプターは高精度(通常BF16またはFP16)で学習するという組み合わせです。これにより、両方の技術の利点が統合されます。量子化によるメモリ削減とLoRAによる効率的な学習です。
従来、量子化は主に推論時に使用されてきました。トレーニング時の量子化は、数値的な不安定性や精度の大幅な低下を引き起こすことが知られていました。しかし、QLoRAは、慎重に設計された量子化手法とLoRAの組み合わせにより、トレーニング中でも4ビット量子化が実用的であることを示しました。
QLoRAの動機を理解するために、具体的な数字を見てみましょう。LLaMA 65Bモデルを標準的なLoRA(FP16)でファインチューニングする場合、約260GBのGPUメモリが必要です。これは、A100 80GB を4台必要とします。QLoRAを使用すると、約48GBに削減され、単一のA100 80GBまたは2台のRTX 3090(各24GB)で実行可能になります。これは約81%のメモリ削減です。
さらに印象的なのは、より小規模なモデルでの影響です。LLaMA 13Bモデルは、標準的なLoRAで約52GBを必要としますが、QLoRAでは約10GB に削減されます。これにより、RTX 3060(12GB)のような手頃なGPUでも13Bモデルをファインチューニングできるようになりました。
QLoRAの重要性は、民主化の観点からも理解できます。以前は、大規模モデルのファインチューニングは、高価なデータセンターGPUを持つ組織や研究機関に限られていました。QLoRAにより、個人の研究者や小規模な組織でも、消費者向けハードウェアで大規模モデルをファインチューニングできるようになりました。
QLoRAが解決する主要な課題は3つあります。第一に、メモリのボトルネックです。4ビット量子化により、モデルのメモリフットプリントを約75%削減します。第二に、計算効率です。量子化された重みはメモリから高速にロードでき、メモリ帯域幅を節約します。第三に、アクセシビリティです。高価なハードウェアなしで、最先端のモデルでの実験が可能になります。
QLoRAの技術的革新は、いくつかの重要なコンポーネントから成ります。4ビットNormalFloat(NF4)という新しいデータ型、Double Quantization(量子化定数自体も量子化)、そしてPaged Optimizers(メモリスパイクを管理)です。これらの技術が組み合わさって、4ビットでのトレーニングを実用的かつ効果的にします。
性能面でも、QLoRAは驚くべき結果を示します。元の論文では、LLaMA 65BをQLoRAでファインチューニングしたモデルが、ChatGPTのベースモデルであるGPT-3.5を多くのベンチマークで上回りました。これは、適切な技術を使用すれば、極端なメモリ制約下でも高性能なモデルを生み出せることを示しています。
QLoRAは、単なる技術的なトリックではなく、LLMの学習と展開方法に対する根本的な再考を促します。メモリは常に貴重なリソースであり、その効率的な使用は、より大規模で能力の高いモデルへの道を開きます。QLoRAは、その方向への重要な一歩です。
14.2 量子化とLoRAの組み合わせ
Tatsunori Hashimoto: QLoRAの成功は、量子化とLoRAを巧妙に組み合わせることにあります。この組み合わせがどのように機能し、なぜ効果的なのかを詳しく見ていきましょう。
QLoRAのアーキテクチャは、階層的なメモリ使用戦略を採用しています。ベースモデルの重みWは4ビットで量子化され、GPUメモリに保存されます。しかし、計算時には、これらの重みを16ビット精度に逆量子化してから使用します。LoRAアダプター(行列BとA)は、常に16ビット精度(BF16またはFP16)で保存され、学習されます。
Forward passの流れを詳しく見てみましょう。入力xが与えられたとき、まず量子化されたベース重みW_4bitを16ビットに逆量子化します。W_16bit = dequantize(W_4bit)です。次に、標準的なLoRAと同様に、y = W_16bit × x + BA × xを計算します。ここで、BとAは16ビットのLoRAパラメータです。
重要な洞察は、計算は常に高精度で行われることです。量子化は純粋にストレージの最適化であり、実際の演算は16ビットで実行されます。これにより、数値的安定性が保たれます。ただし、逆量子化のオーバーヘッドが発生しますが、実験ではこれは許容可能であることが示されています。
Backward passでは、勾配は16ビット精度で計算されます。ベース重みは凍結されているため、その勾配を計算する必要はありません。LoRAパラメータBとAに対する勾配のみが計算され、これらは16ビットで更新されます。したがって、量子化はトレーニングの数値的安定性に影響しません。
この設計により、QLoRAは量子化の主要な利点(メモリ削減)を活用しながら、その主要な欠点(精度低下と数値的不安定性)を回避します。ベース重みが量子化されていても、計算は高精度で行われ、学習可能なパラメータは完全精度で保持されます。
メモリ使用量の内訳を具体的に見てみましょう。LLaMA 13Bモデルの場合、FP16のベース重みは約26GBです。4ビット量子化により、これは約6.5GBに削減されます。LoRAパラメータ(r = 64、すべてのアテンション層)は約200MBです。オプティマイザの状態(LoRAパラメータのみ)は約400MBです。アクティベーションとその他の中間値が約3GB です。合計で約10GB程度です。
この内訳から、ベース重みの量子化が最大の影響を持つことがわかります。LoRAパラメータは比較的小さいため、これらを高精度で保持してもメモリへの影響は限定的です。これが、QLoRAが性能を維持しながら大幅なメモリ削減を達成できる理由です。
量子化のタイミングも重要です。QLoRAでは、プレトレーニングされたモデルをロードする際に量子化します。つまり、量子化は一度だけ行われ、トレーニング中は固定されます。これにより、量子化のオーバーヘッドが最小化されます。また、量子化されたモデルを保存しておけば、次回のファインチューニングで再利用できます。
異なるコンポーネントへの異なる精度の適用は、QLoRAの鍵となる戦略です。タスク適応に重要な部分(LoRAパラメータ)は高精度で、一般的な知識を保持する部分(ベース重み)は低精度でという使い分けです。この非対称性が、効率と性能のバランスを実現します。
実験的証拠も、この組み合わせの有効性を支持しています。QLoRAでファインチューニングされたモデルを、16ビットLoRAと比較した研究では、性能差は通常1%未満でした。多くのタスクでは、統計的に区別できませんでした。これは、4ビットのベース重みが、16ビットのLoRAアダプターと組み合わせることで、実用的に十分であることを示しています。
興味深いことに、一部のタスクではQLoRAが16ビットLoRAを上回ることさえありました。これは、量子化が軽度の正則化効果を持ち、過学習を防いでいる可能性を示唆しています。ただし、この効果は一貫しておらず、タスクに依存します。
14.3 4-bit量子化の実装
Tatsunori Hashimoto: QLoRAの成功は、4ビット量子化の慎重な実装に大きく依存しています。単純に重みを4ビットに丸めるだけでは、大幅な精度低下が発生します。QLoRAは、いくつかの革新的な技術を導入してこれを解決します。
最も重要な革新は、NormalFloat 4(NF4)データ型です。標準的な4ビット整数(INT4)は、-8から7の値を表現します。これは、ニューラルネットワークの重みの分布に最適ではありません。研究により、プレトレーニングされたモデルの重みは、ゼロ平均の正規分布に近いことが示されています。
NF4は、この観察を活用します。量子化レベルを、正規分布の分位点に配置します。具体的には、標準正規分布を16の等確率区間に分割し、各区間の代表値を量子化レベルとして使用します。これにより、ゼロ付近に多くの量子化レベルが集中し、外れ値には少数のレベルが割り当てられます。
NF4の量子化レベルは、次のような値です。[-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0, 0.0911, 0.1848, 0.2844, 0.3949, 0.5251, 0.6962, 1.0, そして特殊値]。これらは、正規分布の特定の分位点に対応しています。
この分布を意識した量子化により、より多くの重みが高精度で表現されます。実験では、NF4がINT4と比較して、平均二乗誤差を約20%から30%削減することが示されています。これが、QLoRAが実用的な性能を達成する重要な要因です。
Double Quantizationは、QLoRAのもう一つの重要な技術です。量子化には、スケーリングファクター(量子化定数)が必要です。たとえば、重みの範囲[-1.5, 1.5]を[-8, 7]にマッピングする場合、スケーリングファクターは約0.1です。標準的なアプローチでは、これらのスケーリングファクターをFP32で保存します。
しかし、大規模モデルでは、これらのスケーリングファクターも大きなメモリを消費します。LLaMA 65Bでは、ブロックごとの量子化(64要素のブロック)を使用すると、約0.5GBのスケーリングファクターが必要です。Double Quantizationは、これらのスケーリングファクター自体も量子化します。通常、FP32からINT8に量子化します。
これにより、さらに約0.37GBのメモリが節約されます。これは小さく見えるかもしれませんが、メモリが非常に限られている環境では、この追加の節約が、モデルが収まるかどうかの違いを生むことがあります。また、スケーリングファクターのロードも高速化され、逆量子化のスループットが向上します。
ブロックサイズの選択も重要です。QLoRAでは、デフォルトのブロックサイズは64です。つまり、64個の連続する重みが、単一のスケーリングファクターを共有します。小さなブロックサイズ(たとえば32)は、より細かい粒度の量子化を提供し、精度が向上しますが、スケーリングファクターのメモリ使用量が増加します。大きなブロックサイズ(たとえば128)は、逆です。64は、精度とメモリのバランスが取れた選択として選ばれました。
Paged Optimizersは、メモリスパイクを管理するための技術です。トレーニング中、特に長いシーケンスを処理する際、アクティベーションが一時的に大量のメモリを消費することがあります。これがメモリ不足(OOM)エラーの一般的な原因です。
Paged Optimizersは、オプティマイザの状態をCPUメモリにページアウトし、必要に応じてGPUにページインします。これは、オペレーティングシステムの仮想メモリと同様の概念です。通常、オプティマイザの状態は頻繁にアクセスされないため、このページングのオーバーヘッドは許容可能です。実験では、性能への影響は約5%未満でした。
実装の詳細も重要です。QLoRAは、bitsandbytesライブラリを使用して実装されています。このライブラリは、効率的な4ビット演算と逆量子化のカーネルを提供します。Hugging FaceのTransformersライブラリとの統合により、既存のモデルに簡単にQLoRAを適用できます。
使用方法は比較的シンプルです。モデルをロードする際に、量子化設定を指定します。model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)のようになります。その後、通常のLoRAと同様に、PEFT設定を適用します。
逆量子化のオーバーヘッドは測定されています。Forward passでは、逆量子化により約10%から15%の時間増加が観測されます。しかし、メモリ帯域幅の削減により、一部の状況ではこれが相殺されます。特に、メモリバウンドな演算では、4ビットデータの高速なロードが利点となります。
精度の評価も行われています。量子化誤差を測定するために、元の16ビット重みと量子化・逆量子化後の重みの平均二乗誤差(MSE)を計算しました。NF4では、MSEは通常0.01から0.05の範囲でした。これは、重みの分散の約1%から5%に相当します。この誤差は、実際のタスク性能にはほとんど影響しないことが示されています。
異なる量子化戦略の比較も行われました。INT4、NF4、そして8ビット量子化を比較した実験では、INT4が平均で3%から5%の性能低下を示したのに対し、NF4は1%から2%の低下でした。8ビット量子化は0.5%未満の低下でしたが、メモリ削減は50%に留まりました。NF4は、性能とメモリのバランスが最も良好でした。
QLoRAの4ビット量子化は、単なる圧縮技術ではなく、大規模モデルへのアクセスを民主化する革新です。NF4、Double Quantization、そしてPaged Optimizersの組み合わせにより、以前は不可能だったハードウェアで、大規模モデルのファインチューニングが可能になりました。これは、AI研究とアプリケーション開発の風景を変える可能性を持つ技術です。
14.4 実験結果:さらなるメモリ効率化
Tatsunori Hashimoto: QLoRAの理論的な利点を理解したところで、実際の実験結果を見ていきましょう。これらのデータは、QLoRAがどれほどのメモリ削減を実現し、それでいて性能をどの程度維持できるかを明確に示しています。
まず、メモリ使用量の劇的な削減から見ていきましょう。元のQLoRA論文では、様々なサイズのLLaMAモデルでの包括的なベンチマークが行われました。LLaMA 7Bモデルでは、16ビットLoRAが約22GBのGPUメモリを必要としましたが、QLoRAではわずか約6GBに削減されました。これは約73%の削減です。
より大規模なモデルでは、削減効果がさらに顕著です。LLaMA 13Bでは、16ビットLoRAが約52GBを必要としたのに対し、QLoRAは約10GBでした。約81%の削減です。LLaMA 33Bでは、16ビットLoRAが約120GBを必要としましたが、QLoRAは約20GBで、約83%の削減です。そして最も印象的なのは、LLaMA 65Bで、16ビットLoRAの約260GBに対し、QLoRAは約48GBで、約82%の削減を達成しました。
これらの数値の実践的な意味は重大です。LLaMA 65Bを16ビットLoRAでファインチューニングするには、A100 80GB を4台必要とします。これは、多くの研究者や組織にとって現実的ではありません。QLoRAを使用すれば、単一のA100 80GBで実行できます。あるいは、2台のRTX 3090(各24GB)でも可能です。これは、アクセシビリティを根本的に変えます。
バッチサイズへの影響も重要です。同じGPUメモリで、QLoRAははるかに大きなバッチサイズを使用できます。LLaMA 7BをRTX 3090(24GB)でファインチューニングする実験では、16ビットLoRAがバッチサイズ1でギリギリ収まるか、メモリ不足になりました。QLoRAでは、バッチサイズ4が可能で、スループットが約3倍向上しました。
性能面では、QLoRAは驚くべき結果を示しています。元の論文では、複数のベンチマークでQLoRAと16ビットLoRAを比較しました。MMLU(多分野知識ベンチマーク)では、LLaMA 65BをQLoRAでファインチューニングしたモデルが平均63.4%のスコアを達成しました。16ビットLoRAは63.9%でした。差はわずか0.5ポイントです。
より詳細な分析では、57のサブジェクト全体で、QLoRAと16ビットLoRAの差は平均で0.3ポイント未満でした。一部のサブジェクトでは、QLoRAが実際に16ビットLoRAを上回りました。たとえば、数学では、QLoRAが52.1%、16ビットLoRAが51.8%でした。これは統計的なばらつきの範囲内ですが、4ビット量子化が致命的な性能低下を引き起こさないことを明確に示しています。
Vicunaベンチマークでは、チャット形式での性能が評価されました。GPT-4を評価者として使用し、QLoRAでファインチューニングされたLLaMA 13BとChatGPTを比較しました。QLoRAモデルは、ChatGPTの性能の約90%を達成しました。これは、手頃なハードウェアでファインチューニングされたモデルが、商業的なAPIモデルに匹敵する性能を示すという、注目すべき成果です。
OpenAssistantデータセットでのファインチューニングでは、より詳細な比較が行われました。LLaMA 33BをQLoRAと16ビットLoRAの両方でファインチューニングし、人間評価者による盲検比較を行いました。評価者の約52%がQLoRAモデルを好み、48%が16ビットLoRAモデルを好みました。この差は統計的に有意ではなく、両者が実質的に同等であることを示しています。
異なるランクでの実験も行われました。r = 4、8、16、32、64での比較では、QLoRAは各ランクで16ビットLoRAとほぼ同等の性能を示しました。r = 8では、差は通常0.5ポイント未満でした。r = 64では、差はほとんど測定不可能でした。これは、QLoRAが異なるパラメータ効率レベルで一貫して機能することを示しています。
トレーニング速度も測定されました。同じハードウェアとバッチサイズで、QLoRAは16ビットLoRAと比較して約5%から10%遅くなりました。これは、逆量子化のオーバーヘッドによるものです。しかし、メモリ削減により大きなバッチサイズを使用できることを考慮すると、実効的なスループット(単位時間あたりのサンプル処理数)は、実際にはQLoRAの方が高いことが多いです。
具体的な数値で見てみましょう。LLaMA 7BをA100 40GBでファインチューニングする実験では、16ビットLoRAがバッチサイズ2で約8サンプル/秒を処理しました。QLoRAはバッチサイズ8で約28サンプル/秒を処理しました。イテレーションあたりの時間は長くなりましたが、大きなバッチサイズにより、全体のスループットは約3.5倍向上しました。
収束速度も評価されています。同じ最終性能に達するまでのステップ数を比較すると、QLoRAは16ビットLoRAとほぼ同じでした。つまり、量子化は収束を遅らせません。一部のタスクでは、QLoRAがわずかに速く収束することさえありました。これは、量子化のノイズが軽度の正則化効果を持つ可能性を示唆しています。
Double Quantizationの効果も個別に測定されました。Double Quantizationなしの4ビット量子化と比較すると、LLaMA 65Bで約0.37GBの追加メモリ削減が得られました。性能への影響は無視できる程度(0.1ポイント未満)でした。これは、Double Quantizationが実質的に「タダで」追加のメモリ削減を提供することを示しています。
NF4とINT4の比較も重要です。同じモデルをINT4とNF4でファインチューニングした実験では、NF4がINT4を一貫して上回りました。LLaMA 13BをMMLUで評価した場合、NF4が54.2%、INT4が51.8%で、約2.4ポイントの差がありました。これは、分布を意識した量子化の重要性を強調しています。
異なるハードウェアでの性能も評価されました。RTX 4090(24GB)では、LLaMA 13BをQLoRAでバッチサイズ4でファインチューニングできました。RTX 3090(24GB)でも同様です。さらに、RTX 3060(12GB)でも、LLaMA 7Bをバッチサイズ2でファインチューニングできました。これは、手頃な消費者向けハードウェアでも大規模モデルのファインチューニングが可能であることを示しています。
長期的なトレーニングでの安定性も確認されました。数日間にわたるトレーニング実行で、QLoRAは数値的に安定していました。損失の発散やNaN/Infの発生は観察されませんでした。これは、4ビット量子化がトレーニングの安定性を損なわないことを示しています。
14.5 実用上の利点と制約
Tatsunori Hashimoto: QLoRAの実験結果は印象的ですが、実際のプロジェクトでQLoRAを使用する際の利点と制約を理解することが重要です。これにより、適切なユースケースでQLoRAを効果的に活用できます。
QLoRAの最も明白な利点は、ハードウェアアクセシビリティの向上です。以前は高価なデータセンターGPUでしか不可能だったことが、消費者向けGPUで可能になりました。個人の研究者、学生、小規模なスタートアップが、自分のワークステーションで650億パラメータのモデルをファインチューニングできます。これは、AI研究とアプリケーション開発の民主化における重要な一歩です。
コスト削減も大きな利点です。クラウドでGPUをレンタルする場合、A100 80GBは1時間あたり約3ドルから5ドルです。4台必要な場合、コストは4倍になります。QLoRAにより、単一GPUで済むため、トレーニングコストが75%削減されます。大規模なファインチューニングプロジェクトでは、これは数千ドルから数万ドルの節約になります。
迅速なイテレーションも可能になります。メモリ制約が緩和されるため、より多くの実験を並行して実行できます。複数の異なるハイパーパラメータ設定や、異なるデータセットでのファインチューニングを同時に試すことができます。これにより、開発サイクルが加速され、より良いモデルをより速く見つけられます。
マルチタスク学習やパーソナライゼーションも容易になります。単一のベースモデルから、複数の異なるQLoRAアダプターを作成できます。各アダプターは数十MBから数百MB程度なので、数百の異なるタスクやユーザー向けのアダプターを管理することが現実的になります。
教育面での利点も無視できません。学生や新しい研究者が、実際に大規模モデルで実験できることは、学習体験を大幅に向上させます。理論を学ぶだけでなく、実践的な経験を積むことができます。QLoRAにより、教育機関は限られた予算で、最先端のAI教育を提供できます。
しかし、QLoRAには制約もあります。最も明白な制約は、性能のわずかな低下です。多くの場合、差は1%未満ですが、最高の性能が絶対的に必要な場合、この小さな差が重要かもしれません。競争的なベンチマークや、クリティカルなアプリケーションでは、16ビットLoRAやfull finetuningの方が適切かもしれません。
逆量子化のオーバーヘッドも考慮すべきです。Forward passとbackward passで、量子化された重みを16ビットに変換する必要があります。これには計算コストがかかり、トレーニング速度が約5%から10%低下します。大規模なバッチサイズを使用できない場合、この速度低下がボトルネックになる可能性があります。
実装の複雑さも増加します。QLoRAは、bitsandbytesライブラリとの統合を必要とし、一部の環境では追加の依存関係やセットアップが必要です。また、すべてのモデルアーキテクチャがすぐにサポートされるわけではなく、カスタムモデルでは追加の作業が必要な場合があります。
デバッグも若干困難になります。量子化エラーやメモリ関連の問題が発生した場合、その原因を特定することは、標準的なトレーニングよりも複雑です。経験の浅いユーザーにとって、これは障壁になる可能性があります。
推論時の考慮事項もあります。QLoRAでファインチューニングされたモデルを推論に使用する場合、いくつかのオプションがあります。一つは、4ビットベースモデルとLoRAアダプターを保持し、推論時も同じ構成を使用することです。これは最もメモリ効率的ですが、逆量子化のオーバーヘッドがあります。
もう一つのオプションは、LoRAアダプターをベースモデルにマージし、完全な16ビットモデルを作成することです。これにより、推論時のオーバーヘッドがなくなりますが、メモリ使用量が増加します。3つ目のオプションは、統合された16ビットモデルをさらに8ビットや4ビットに量子化して推論に使用することです。
ハードウェア依存性も制約の一つです。QLoRAは、NVIDIAのGPUとCUDAに最適化されています。AMDのGPUやAppleのM1/M2チップでは、サポートが限定的または存在しません。クロスプラットフォームのデプロイメントを計画している場合、これは考慮すべき点です。
データ型の互換性も注意が必要です。一部のツールやライブラリは、4ビット量子化されたモデルを直接サポートしていない場合があります。モデルをエクスポートしたり、他のフレームワークに移植したりする際、追加の変換ステップが必要になることがあります。
実用的な推奨として、QLoRAは次のような場合に最適です。リソースが非常に限られている、複数の実験を迅速に実行したい、大規模モデルで学習や実験をしたい、わずかな性能低下が許容できる、デプロイメントで高いメモリ効率が重要、という状況です。
一方、次のような場合は、16ビットLoRAやfull finetuningの方が適切かもしれません。最高の性能が絶対的に必要、十分な計算リソースがある、トレーニング速度が最優先、クロスプラットフォーム互換性が重要、または量子化の追加の複雑さを避けたい、という場合です。
多くのプロジェクトでは、ハイブリッドアプローチも有効です。まずQLoRAで迅速にプロトタイプを作成し、複数のアプローチを試します。最良の設定を特定したら、必要に応じて16ビットLoRAやfull finetuningで最終的な高性能モデルを訓練します。これにより、開発の効率性と最終製品の品質の両方を最適化できます。
QLoRAは、完璧なソリューションではありませんが、多くの状況で非常に価値のあるツールです。そのトレードオフを理解し、適切なユースケースで使用することで、限られたリソースでも大規模モデルの力を活用できます。QLoRAは、LLMファインチューニングの風景を変え、より多くの人々に最先端のAI技術へのアクセスを提供する重要な技術です。
以上で、今日の講義は終わりです。プレトレーニングから始まり、様々な最適化技術を経て、効率的なファインチューニング手法まで、LLMトレーニングの広範な側面をカバーしました。これらの技術を理解し、適切に適用することで、皆さんは効果的で効率的なLLMシステムを構築できるでしょう。
Stanford CME295 Transformers & LLMs | Autumn 2025 | Lecture 4 - LLM Training
For more information about Stanford’s graduate programs, visit: https://online.stanford.edu/graduate-education October 17, 2025 This lecture covers: • Pretraining • Quantization • Hardware optimization • Supervised finetuning (SFT) • Parameter-efficient finetuning (LoRA) To follow along with the course schedule and syllabus, visit: https://cme295.stanford.edu/syllabus/ Chapters: 00:00:00 Introduction 00:07:19 Pretraining 00:13:26 FLOPs, FLOPS 00:16:34 Scaling laws, Chinchilla law 00:24:49 Training optimizations overview 00:31:09 Data parallelism with ZeRO 00:35:51 Model parallelism 00:38:26 Flash Attention 00:52:37 Quantization 00:56:00 Mixed precision training 01:02:31 Supervised finetuning 01:09:21 Instruction tuning 01:37:53 Parameter-efficient finetuning with LoRA 01:45:16 QLoRA Afshine Amidi is an Adjunct Lecturer at Stanford University. Shervine Amidi is an Adjunct Lecturer at Stanford University. View the course playlist: https://www.youtube.com/playlist?list=PLoROMvodv4rOCXd21gf0CF4xr35yINeOy
youtu.be

.jpg?table=block&id=2ef4f05b-476b-8052-9dc6-dee2ca154e31&spaceId=376a3c22-baec-411c-a473-688d45d9966b&expirationTimestamp=1769594400000&signature=FinpGK82oTZaxxOZ9xyLlYQGLXg7jAg2PBSeWbHzMgo)