※本記事は、Stanford大学のオンラインコース「Stanford CS336: Language Modeling from Scratch」の講義内容を基に作成されています。このコースに関する詳細情報はhttps://stanford-cs336.github.io/ でご覧いただけます。Stanford大学のAIプログラムについてはhttps://stanford.io/ai をご参照ください。本記事では、講義の内容を要約しております。なお、本記事の内容は原講義の内容を正確に反映するよう努めていますが、要約や解釈による誤りがある可能性もありますので、正確な情報や文脈については、オリジナルの講義をご視聴いただくことをお勧めいたします。
【登壇者紹介】 ・Percy Liang: Computer Science准教授、Center for Research on Foundation Models (CRFM)ディレクター
1. イントロダクション
1.1 前回の講義の概要
Percy Liang: 前回の講義では、言語モデルの概要について説明し、言語モデルを「スクラッチから構築する」とはどういう意味なのか、そしてなぜそれを行いたいのかという点について話しました。また、トークン化についても取り上げました。これは最初の課題の前半部分になります。トークン化は、生のテキストデータを言語モデルが処理できる形式に変換する重要なプロセスです。
今日の講義では、実際にモデルを構築していく過程に進みます。PyTorchで必要となる基本要素について説明し、効率性、特にリソース(メモリや計算能力)の使用方法に注目していきます。
1.2 今回の講義の焦点(PyTorchの基本要素と効率性)
今日の講義では、実際にモデルを構築するために必要なPyTorchの基本要素について説明します。具体的には、テンソル、モデル、オプティマイザー、そしてトレーニングループについて取り上げます。特に重要なのは効率性です。メモリと計算リソースの両方をどのように使用するかに注目していきます。
効率性はこの分野で非常に重要です。なぜなら、モデルが大きくなるにつれて、使用するリソースの量は直接的にコスト(金銭的なコスト)に変換されるからです。この数字をできるだけ小さくすることが望ましいのです。今日はトランスフォーマーについては詳しく説明しません。その概念的な概要については次回Tatsuが説明する予定です。トランスフォーマーについて学ぶ方法は多くありますが、課題1に取り組めば、確実にトランスフォーマーについて理解できるでしょう。
この講義で扱う知識の種類としては、「メカニクス」(PyTorchの基本的な仕組み)、「マインドセット」(リソース計算の考え方)、そして少しだけ「直感」にも触れていきます。ただし、今回の講義は主にメカニクスとマインドセットに焦点を当てています。
2. リソース会計の重要性
2.1 ナプキン計算の事例
リソース会計の重要性を理解するために、いくつかの質問から始めましょう。これらの質問は「ナプキン計算」で答えられるものです。つまり、簡単な概算で答えが出せる計算です。ナプキンを取り出して計算してみましょう。
例えば、700億パラメータの密なトランスフォーマーモデルを15兆トークンでトレーニングする場合、1,024台のH100 GPUを使用するとどれくらいの時間がかかるでしょうか?
このような問題の考え方を簡単にスケッチしてみます。まず、トレーニングに必要な総フロップ数を計算します。これは「6×パラメータ数×トークン数」という式で求められます。この「6」という係数がどこから来るのかについては、この講義で説明していきます。
次に、H100が提供する1秒あたりのフロップ数を調べます。MFU(モデルフロップ使用率)を0.5と仮定しましょう。これを用いて、1日あたりのハードウェアが提供するフロップ数を計算できます。1,024台のH100を1日使用した場合のフロップ数です。
最後に、モデルトレーニングに必要な総フロップ数を、ハードウェアが1日で提供するフロップ数で割ります。結果は約144日となります。
このように、最終的にはとても単純な計算になります。今日の講義では、この「6×パラメータ数×トークン数」という式がどこから来るのかを詳しく説明していきます。
もう一つの質問として、もしあまり工夫をしない場合、8台のH100でAtom Wを使用して最大どのくらいの大きさのモデルをトレーニングできるでしょうか?
H100には80ギガバイトのHBMメモリがあります。パラメータ1つあたりに必要なバイト数は、パラメータ自体、勾配、オプティマイザの状態を合わせて16バイトです。なぜそうなるのかも後で説明します。パラメータ数は、総メモリ量をパラメータあたりに必要なバイト数で割ることで計算でき、約400億パラメータになります。
これは非常に大まかな計算であり、バッチサイズやシーケンス長に依存する活性化関数は考慮していません。これらは課題1で重要になってきますが、今回は詳しく説明しません。このような概算計算は、効率性を考える上で非常に重要です。大規模なモデルでは、フロップ数が直接的にコストに変換されるからです。
2.2 フロップ計算と実行時間の見積もり
フロップ計算と実行時間の見積もりは、モデルトレーニングのコストを理解する上で重要です。先ほどの例で使用した「6×パラメータ数×トークン数」という式の詳細については、この講義の中で説明していきます。特に、勾配計算のコストを分析する際に、この「6」という係数がどこから来るのかが明らかになります。
フロップ(浮動小数点演算)の計算方法を知ることは非常に重要です。効率性がこの分野の重要な要素であり、実際に使用しているフロップ数を正確に把握する必要があります。数字が大きくなると、これらは直接ドルに変換されるため、できるだけ小さくしたいものです。
例えば、GPT-3は約3.23×10^23フロップを要し、GPT-4は約2×10^25フロップを必要としたと推測されています。また、米国の行政命令では、1.26×10^26フロップを超える基盤モデルは政府に報告する必要があるとされていましたが、これは現在取り消されています。ただし、EUのAI法では依然として1×10^25という基準が維持されています。
ハードウェアのパフォーマンスについても理解しておく必要があります。A100は1秒あたり312テラフロップのピークパフォーマンスを持ち、H100はスパーシティを活用した場合1,979テラフロップ、スパーシティなしでは約50%のパフォーマンスを発揮します。
NVIDIAの仕様書を見ると、フロップ性能はどのような計算を行うかによって大きく異なります。FP32を使用すると、FP16やBF16と比較して性能が大幅に低下します。さらに、FP8を使用すると、より高速になります。ただし、「スパーシティあり」という注釈がある点に注意が必要です。この講義で扱う行列は多くの場合密であるため、実際にはこの数値の半分程度の性能しか得られません。
これらの知識を基に、例えば8台のH100を2週間使用した場合のフロップ数を計算できます。これは、8×フロップ/秒×1週間の秒数で、約4.7×10^21フロップになります。この数値を他のモデルトレーニングのフロップ数と比較することで、自分のプロジェクトの規模を把握することができます。
2.3 メモリ要件の見積もり
メモリ要件の見積もりも、モデルトレーニングにおいて重要な側面です。先ほど触れた質問「8台のH100で最大どのくらいの大きさのモデルをトレーニングできるか」を詳しく見ていきましょう。
H100には80ギガバイトのHBMメモリがあります。パラメータ1つあたりに必要なバイト数は、パラメータ自体、勾配、オプティマイザの状態を合わせて16バイトになります。この「16」という数字の由来については後ほど詳しく説明します。
総パラメータ数は、利用可能な総メモリ量をパラメータあたりに必要なバイト数で割ることで計算できます。つまり、約400億パラメータということになります。ただし、これは非常に大まかな計算であり、活性化関数のためのメモリは考慮していません。活性化関数のメモリ要件はバッチサイズとシーケンス長に依存するため、これは課題1で重要になってきます。
この種の概算計算は、効率性を追求する上で非常に重要です。大規模なモデルでは、メモリと計算リソースの効率的な使用が直接的にコストに影響するからです。大規模言語モデルのトレーニングでは、利用可能なハードウェアリソースの制約内で最大限のモデルサイズを実現することが常に課題となります。
効率性を追求するためには、こうした「ナプキン計算」をマスターし、モデルのメモリフットプリントと計算要件を正確に予測できることが重要です。これにより、リソースを最大限に活用し、コストを最小限に抑えながら最も効果的なモデルをトレーニングすることが可能になります。
3. メモリ会計
3.1 テンソルの基本概念
メモリ会計の話から始めて、その後計算会計について説明します。まずは基本的な構成要素であるテンソルから見ていきましょう。
テンソルは、ディープラーニングにおけるすべてのデータを保存するための基本的な構成要素です。パラメータ、勾配、オプティマイザの状態、データ、活性化関数の値など、すべてがテンソルとして保存されます。これらは言わば原子のようなものです。テンソルについては多くのドキュメントがあり、恐らく皆さんもテンソルの作成方法には慣れているでしょう。
テンソルを作成する方法はいくつかあります。また、テンソルを作成して初期化しない方法もあります。特にパラメータの場合、特別な初期化方法を使用したい場合があるでしょう。
テンソルがどれだけのメモリを使用するかについて考えてみましょう。私たちが興味を持つテンソルはほとんどの場合、浮動小数点数として保存されます。浮動小数点数を表現する方法はいくつかあります。最もデフォルトの方法はfloat32です。
float32は32ビットを持っており、1ビットは符号用、8ビットは指数部用、そして23ビットは仮数部用に割り当てられています。指数部は動的範囲を与え、仮数部は異なる値を指定します。float32はFP32または単精度とも呼ばれ、コンピューティングにおけるゴールドスタンダードと言えます。一部の人々はfloat32をフル精度と呼ぶこともありますが、これは少し紛らわしいです。科学計算の専門家に話すと、float32をフル精度と呼ぶとバカにされるかもしれません。彼らはfloat64やさらに高い精度を使用するからです。しかし、機械学習の世界ではfloat32が最大限必要な精度でしょう。ディープラーニングはそういった意味でかなり厳密さを欠くものだからです。
テンソルのメモリ使用量は非常に単純に、テンソル内の値の数とそれぞれの値のデータ型によって決まります。例えば、4×8の行列を作成した場合、デフォルトではfloat32型になります。サイズは4×8で、要素数は32です。各要素のサイズは4バイト(32ビット)で、メモリ使用量は単純に要素数×要素サイズで、128バイトになります。
これは非常に簡単な計算ですが、直感的に理解するために、例えばGPT-3のF4層にある一つの行列は、約2.3ギガバイトのサイズになります。つまり、一つの行列でもかなり大きくなる可能性があるのです。
3.2 浮動小数点表現の種類(FP32, FP16, BF16, FP8)
float32がデフォルトですが、行列は大きくなるため、当然小さくしたいと考えるでしょう。メモリ使用量を減らすだけでなく、小さくすることで計算も高速化できます。
もう一つの浮動小数点表現はfloat16と呼ばれ、名前が示す通り16ビットです。指数部と仮数部の両方が縮小され、指数部は8から5に、仮数部は23から10に減少しています。これは半精度(half precision)と呼ばれ、メモリ使用量を半分に削減します。
しかし、float16の問題点は動的範囲が十分でないことです。例えば、1e-8(10の-8乗)のような小さな数値をfloat16で表そうとすると、基本的に0に丸められてアンダーフローが発生します。つまり、float16は非常に小さな数値や非常に大きな数値を表現するのには適していません。
小規模なモデルのトレーニングではfloat16を使用しても問題ないかもしれませんが、大規模なモデルでは多くの行列を扱うため、不安定さやアンダーフロー、オーバーフローなどの問題が発生する可能性があります。
この問題に対処するために、BF16(Brain Float)という表現が2018年に開発されました。ディープラーニングでは、仮数部よりも動的範囲の方が重要であるという認識から生まれたものです。BF16は指数部により多くのビットを割り当て、仮数部を減らしています。
BF16はfloat16と同じメモリ使用量ですが、float32と同じ動的範囲を持っています。これは非常に有利に聞こえますね。デメリットとしては、仮数部によって決まる精度が低下することですが、ディープラーニングではそれほど重要ではありません。例えば、1e-8をBF16で表現すると、0ではない値が得られます。
詳細は省略しますが、すべての浮動小数点表現の完全な仕様を調べることができます。
BF16は通常、計算に使用される標準的な形式です。しかし、オプティマイザの状態やパラメータの保存にはfloat32が必要です。そうしないと、トレーニングが乱れてしまう可能性があります。
さらに小さいFP8(8ビット浮動小数点)も2022年にNVIDIAによって開発されました。FP8を見ると、FP16やBF16と比較して本当に少ないビット数しかありません。より解像度を重視するか、より動的範囲を重視するかによって、2つのバリアントがあります。
FP8はH100でサポートされていますが、以前の世代のGPUでは利用できません。
高レベルな視点から見ると、float32はオプティマイズを追求していない場合に使用される安全な選択肢ですが、より多くのメモリを必要とします。FP8やBF16に下げると、不安定さが生じる可能性があります。現時点では、ディープラーニングにfloat16を使用するメリットはあまりないでしょう。
さらに洗練されたアプローチとしては、パイプラインの特定の部分(順伝播、逆伝播、オプティマイザ、勾配の累積など)で必要な最小精度を検討し、混合精度トレーニングを行う方法があります。例えば、一部の人々はアテンションメカニズムにfloat32を使用して安定性を確保し、単純な行列乗算にはBF16を使用するという選択をしています。
3.3 精度とメモリ使用量のトレードオフ
精度とメモリ使用量のトレードオフについて、もう少し詳しく説明します。先ほど述べたように、浮動小数点数の表現方法によって、メモリ使用量と計算効率が大きく変わります。
float32を使用すると、最も高い精度が得られますが、最もメモリを消費します。一方、BF16やFP8のような低精度表現を使用すると、メモリ使用量を大幅に削減でき、計算も高速化できますが、数値的安定性が犠牲になる可能性があります。
混合精度トレーニングの考え方は、パイプラインの各部分で最適な精度を選択することです。例えば、アテンションメカニズムのような複雑な計算にはfloat32を使用して数値的安定性を確保し、単純な行列乗算にはBF16を使用するという方法があります。
一般的な推奨事項としては、パラメータとオプティマイザの状態にはfloat32を使用し、計算にはBF16を使用するというアプローチがあります。BF16は一時的なものとして考えることができます。つまり、パラメータをBF16にキャストしてモデルを実行し、時間をかけて蓄積するものには高い精度を使用するという方法です。
このトレードオフは、モデルのサイズや要件によって異なります。小規模なモデルでは、float32を使用しても問題ないかもしれませんが、大規模なモデルでは、メモリと計算効率を最適化するために、慎重に精度を選択する必要があります。
重要なのは、使用する精度によって、モデルのトレーニング安定性とパフォーマンスが大きく影響を受けるということです。適切なバランスを見つけることが、効率的なディープラーニングシステムの構築に不可欠です。混合精度トレーニングは、このバランスを実現するための有効なアプローチですが、実装は複雑になる可能性があります。幸いなことに、PyTorchには混合精度トレーニングを容易にするツールが用意されています。
4. テンソル操作
4.1 CPUとGPUのメモリ管理
では次に計算について話しましょう。計算の効率はハードウェアに大きく依存します。デフォルトでは、テンソルはCPUに保存されます。例えば、PyTorchでX = torch.zeros(32, 32)
と記述すると、そのテンソルはCPUに配置されます。つまり、CPUメモリに格納されることになります。
しかし、GPUを使用しない場合、計算速度は桁違いに遅くなってしまいます。そのため、テンソルを明示的にGPUに移動させる必要があります。これを図で説明すると、CPUにはRAMがあり、そのデータをGPUに移動させる必要があります。このデータ転送には時間がかかります。
テンソルをPyTorchで扱う際には、常にそれがどこに存在しているかを意識する必要があります。変数やコードを見ただけでは、テンソルがCPUにあるのかGPUにあるのか判断できない場合があります。計算とデータ移動を慎重に管理したい場合は、テンソルの場所を常に把握しておく必要があります。コード内の様々な場所で、テンソルの位置をアサート(確認)することで、ドキュメント化したり確認したりするのも良い方法です。
では、現在使用しているハードウェアを見てみましょう。この例では、H100クラスタを使用しています。これは皆さんがアクセスできるものです。このGPUはH100で、80GBの高帯域幅メモリを持っています。キャッシュサイズなどの情報も表示されています。
先ほど作成したテンソルXがCPUにある場合、それをGPUに移動するには、PyTorchの一般的な関数である.to()
を使用します。または、最初からGPU上にテンソルを直接作成することもできます。その場合、移動する必要はありません。
メモリ割り当てを確認してみると、この操作の前後でのメモリ使用量の差は、正確に2つの32×32行列(4バイトのfloat型)のサイズになるはずです。計算すると192バイトになります。これは、コードが説明通りに動作していることを確認するための簡単なチェックです。
4.2 テンソルのビューとストレージの共有
テンソルをGPUに移動したら、次は何をするでしょうか?課題1や一般的なディープラーニングアプリケーションで必要となる様々なテンソル操作があります。ほとんどのテンソルは、他のテンソルに対して操作を実行することで作成され、各操作にはメモリと計算のフットプリントがあります。これらを理解しておきましょう。
まず、PyTorchのテンソルとは実際に何なのかを考えてみましょう。テンソルは数学的なオブジェクトですが、PyTorchでは実際には割り当てられたメモリへのポインタです。例えば、4×4の行列がある場合、実際には長い配列のように見え、テンソルはその配列にアドレスするための方法を指定するメタデータを持っています。
このメタデータは、テンソルの各次元につき1つのストライド値を持ちます。2次元テンソルの場合、ストライド0とストライド1があります。ストライド0は、次の行に移動するために何個スキップする必要があるかを指定します。行を下に移動する場合、4個スキップするのでストライド0は4です。次の列に移動するには1個スキップするので、ストライド1は1です。
このストライド情報を使って、例えば位置(1,2)の要素を見つけるには、インデックスにストライドを掛けるだけです。1×4 + 2×1 = 6となり、これが配列内のインデックスになります。これがテンソルの内部動作の基本的な仕組みです。
この理解が重要な理由は、複数のテンソルが同じストレージを使用できるからです。同じデータを複数の場所にコピーしたくない場合、これは非常に便利です。例えば、2×3の行列があり、多くの操作ではコピーを作成せずに異なるビュー(view)を作成します。
例を見てみましょう。テンソルxが[1,2,3,4,5,6]であり、yはx[0](最初の行)を指しているとします。テンソルのストレージが同じかどうかを確認する関数を使うと、これらが同じストレージを共有していることがわかります。つまり、コピーは作成されていません。
列1を取得する場合も同様に、コピーは作成されません。view
関数を使って、テンソルを異なる次元として見ることもできます。例えば、2×3テンソルを3×2テンソルとして見ることができます。これもコピーを作成しません。
転置操作も同様にコピーを作成しません。そして重要なことに、Xを変更するとYも変更されます。これは、XとYが同じ基礎となるストレージへのポインタにすぎないからです。
注意すべき点は、一部のビューは「連続的(contiguous)」であり、テンソルを通過すると配列をスライドするようにストレージを進むことができますが、一部はそうではありません。特に転置すると、テンソルを通過する際に飛び飛びになります。非連続テンソルを持っている場合、さらに別の方法で見ようとすると、うまく機能しない場合があります。
場合によっては、非連続テンソルをまず連続的にしてから、望む操作を適用する必要があります。この場合、xとyは同じストレージを共有しなくなります。なぜなら、contiguous
操作がコピーを作成するからです。
要約すると、ビューはテンソルをスライスしたり分割したりする無料の方法です。コードを読みやすくするために異なる変数を定義しても、メモリは割り当てられないので自由に使用できます。ただし、contiguous
やreshape
のような操作はコピーを作成する可能性があるので、何をしているのか注意する必要があります。
4.3 連続性と非連続性テンソル
テンソルの連続性(contiguity)について詳しく説明しましょう。先ほど触れたように、連続的なテンソルとは、テンソルを通過する際にストレージ内の配列をスライドするように進むことができるものです。一方、非連続テンソルはそうではありません。
特に転置操作を行うと、非連続テンソルになります。例えば、2×3のテンソルを転置して3×2にすると、ストレージ内を移動する際にスキップが必要になります。元々は行方向に進んでいたものが、転置後は列方向に進む必要があるからです。このような非連続テンソルに対して、さらに別のビュー操作を適用しようとすると、うまく機能しない場合があります。
PyTorchでは、非連続テンソルかどうかを確認するためのis_contiguous()
メソッドが提供されています。非連続テンソルを扱う必要がある場合は、まずcontiguous()
メソッドを使って連続的なテンソルに変換し、その後で望む操作を適用するとよいでしょう。
ただし、このcontiguous()
操作はメモリにコピーを作成することを意味します。これは、元のテンソルのストレージとは別の新しいストレージにデータをコピーするからです。そのため、メモリ効率を考慮する場合は、可能な限りテンソルの連続性を維持するように操作を設計することが重要です。
連続的なテンソルは、多くのGPU操作で最適なパフォーマンスを発揮します。特に大規模な行列計算では、連続的なメモリレイアウトがキャッシュ効率を向上させ、計算速度を上げることができます。ただし、テンソルをコピーするコストと、非連続テンソルで操作を行うコストのトレードオフを考慮する必要があります。
大規模な言語モデルのトレーニングでは、こうした細かい最適化が積み重なって全体のパフォーマンスに大きな影響を与えます。テンソルのストレージパターンを理解し、連続性を意識した操作を行うことで、より効率的なモデルトレーニングが可能になります。
5. Einsum表記法
5.1 Einsumの利点と使用方法
少し話題を変えて、Einsumについて説明します。Einsumを使う動機は以下の通りです。通常、PyTorchでテンソルを定義し、x @ y.transpose(-2, -1)
のような演算を行います。このコードを見たとき、「-2とは何か?」と考えるでしょう。恐らく、これはシーケンス次元を指し、-1は隠れ次元を指していると推測できます。ただし、これらのインデックスで検索すると混乱しやすいです。
コードを見て-1、-2という数字があると、何を意味しているのか理解するのが難しくなります。優れたコードを書く場合はコメントを追加するでしょうが、コメントはコードと不一致になる可能性があり、デバッグが困難になります。
この問題の解決策がEinsumです。Einsumはアインシュタインの縮約記法に着想を得たもので、インデックスの代わりに次元に名前を付ける考え方です。
Jax typingと呼ばれるライブラリがあり、次元と型を指定するのに役立ちます。通常、PyTorchではコードを書いて、コメントで次元を示します。Jax typingを使うと、文字列として次元を書き下げることができ、より自然なドキュメント化が可能です。
ただし、これは強制力を持たないことに注意してください。PyTorchの型は少し不正確なので、チェッカーを使用して強制することも可能ですが、デフォルトでは強制されません。
Einsumは基本的に、良い記録管理機能を持った行列乗算のようなものです。例えば、バッチ次元、シーケンス次元、隠れ次元を持つテンソルXとYがあるとします。以前は複雑な演算を行っていましたが、Einsumを使うと、単に次元の名前を書き下げるだけです。
入力テンソルの次元名を「バッチ、シーケンス1、隠れ次元」と「バッチ、シーケンス2、隠れ次元」のように書き、出力に現れるべき次元を指定します。ここではバッチ、シーケンス1、シーケンス2を指定しています。出力に名前のない次元(隠れ次元)は合計され、名前のある次元は反復されます。
この記法に慣れると、-2
や-1
を使うよりも便利になります。より洗練された使い方としては、...
(省略記号)を使って任意の数の次元をブロードキャストすることもできます。これにより、例えばバッチの代わりに「バッチ1、バッチ2」などのより複雑な次元構造にも対応できます。
Einsumは最良の方法で次元を縮約する順序を自動的に見つけ出し、最も効率的な実装を使用します。PyTorch compileと共に使用すると、一度だけ最適化された実装を生成し、その後は同じ実装を再利用します。これは手作業で設計したものよりも優れた性能を発揮します。
5.2 行列乗算とテンソル操作
Einsumを使った行列乗算とテンソル操作について詳しく見ていきましょう。行列乗算はディープラーニングの基本中の基本ですが、Einsumを使うとより直感的に表現できます。
標準的な行列乗算の例を考えてみましょう。16×32の行列と32×2の行列を掛け合わせると、16×2の行列が得られます。しかし、機械学習アプリケーションでは、通常すべての操作をバッチで行いたいと考えます。言語モデルの場合、これは通常、バッチ内の各例とバッチ内の各シーケンスに対して何かを行うことを意味します。
したがって、単なる行列ではなく、典型的には次元がバッチ、シーケンス、そして行いたい処理によって構成されるテンソルを持つことになります。この場合、それは各トークンのための行列です。
PyTorchはこれをうまく処理してくれます。4次元テンソルと行列を掛け合わせる場合、実際には各バッチと各トークンに対してこれら2つの行列を掛け合わせていることになります。結果として、最初の2つの次元に対応する結果の行列が得られます。
これは特に複雑なことではありませんが、Einsumを使うとより明確に表現できます。Einsumでは、次元の名前を明示的に指定し、どの次元が出力に現れるべきかを指定します。これにより、コードがより読みやすく、エラーが少なくなります。
特に複雑なテンソル操作では、Einsumの利点が顕著になります。例えば、あるテンソルの1つの次元が実際には複数の次元の平坦化表現である場合、それを展開して操作し、再び平坦化したい場合があります。
具体的な例として、バッチ、シーケンス、そして8次元のベクトルを持つテンソルがあるとします。この8次元のベクトルは実際には、ヘッド数×隠れ次元の平坦化表現です。そして、隠れ次元に対して操作を行いたい重みベクトルがあります。
Einsumを使うと、この操作を非常にエレガントに表現できます。まずrearrange
を呼び出して、この次元が実際にはヘッド×隠れ次元1であることを明示します。ヘッドの数を指定する必要があります。なぜなら、数を2つに分割する方法は複数あるからです。
そして、変換されたテンソルxに対してEinsumを使って変換を実行します。「隠れ次元1」はxに対応し、「隠れ次元1、隠れ次元2」はwに対応し、出力には「隠れ次元2」を含めます。これにより、望みの変換が得られます。
最後に、再びrearrange
を使って2つの次元を1つにグループ化します。これは、他のすべての次元をそのままにしながら、平坦化操作を行うだけです。
このようなEinsumの使い方に慣れると、コードが非常に読みやすくなり、次元の管理が容易になります。特に複雑な次元の操作が必要な場合や、テンソルの形状が変更される場合に効果的です。この表記法についてのチュートリアルがありますので、興味があればそちらを参照することをお勧めします。
課題1では、この表記法を使用する必要はありませんが、指針として提供していますので、習得する価値はあるでしょう。
5.3 次元名による明確なコード
次元名を使ったコードの明確化について、もう少し詳しく説明しましょう。Einsumの最大の利点の一つは、次元に明示的な名前をつけることでコードの可読性と保守性が大幅に向上することです。
従来のPyTorchコードでは、x.transpose(-2, -1)
のような表現を使いますが、これは読者にとって理解しづらいものです。-2や-1が何を意味するのか、コードを見ただけでは判断できません。コメントを追加したとしても、コードが変更された際にコメントが更新されない可能性があります。
一方、Einsumを使用すると、次元に「batch」「sequence」「hidden」といった意味のある名前をつけることができます。これにより、コードの意図が明確になり、誤りを防ぐことができます。例えば:
# 従来の方法
output = x @ w.transpose(-2, -1)
# Einsumを使用した方法
output = einsum('b s h1, h1 h2 -> b s h2', x, w)
従来の方法では、行列乗算と転置の組み合わせが何を意味するのか理解するのに時間がかかりますが、Einsumを使用した方法では、「バッチ、シーケンス、隠れ次元1」の入力テンソルと「隠れ次元1、隠れ次元2」の重みを掛け合わせ、「バッチ、シーケンス、隠れ次元2」の出力を得ることが一目でわかります。
特に複雑な次元の操作が必要な場合、Einsumの利点は顕著になります。例えば、複数のヘッドを持つアテンション機構を実装する場合、バッチ、シーケンス、ヘッド数、隠れ次元など多くの次元を管理する必要があります。Einsumを使用すると、これらの次元間の関係を明確に表現できます。
また、rearrange
関数と組み合わせることで、次元のリシェイプや再編成も直感的に表現できます。例えば、隠れ次元を複数のヘッドと小さな隠れ次元に分割する操作は次のようになります:
# 隠れ次元をヘッドと小さな隠れ次元に分割
x_reshaped = rearrange(x, 'b s (h d) -> b s h d', h=num_heads)
このコードは、元のテンソルの形状が「バッチ、シーケンス、(ヘッド×小さな隠れ次元)」であり、これを「バッチ、シーケンス、ヘッド、小さな隠れ次元」に変換することを明確に示しています。
Einsumと次元名を使用することで、テンソル操作の意図が明確になり、コードの読みやすさと保守性が向上します。また、次元に関する誤りを減らし、デバッグを容易にする効果もあります。これは特に複雑なモデルを構築する際に非常に価値のある利点です。
6. 計算コスト分析
6.1 フロップの定義と単位
では、テンソル操作の計算コストについて話しましょう。様々な操作を導入しましたが、それらにどれだけのコストがかかるのでしょうか?
浮動小数点演算(floating-point operation)とは、加算や乗算などの浮動小数点数に対する操作のことです。フロップカウントにおいて主に重要になるのは、以下のような基本的な演算です。
ここで一つ私の個人的な不満を述べておきたいのですが、「フロップス(flops)」という言葉を使う場合、実際には何を意味しているのか不明確になることがあります。小文字の「s」で「flops」と書く場合、これは「floating point operations」(浮動小数点演算の数)を意味し、実行した計算量を測定します。一方、大文字の「S」で「FLOPS」と書く場合、これは「floating points per second」(1秒あたりの浮動小数点演算数)を意味し、ハードウェアの速度を測定するために使用されます。
このクラスでは混乱を避けるため、大文字の「S」は使用せず、代わりに「/s」を付けて1秒あたりを示すことにします。
フロップスについてのいくつかの直感を得るために、いくつかの数字を見てみましょう。GPT-3のトレーニングには約3.23×10^23フロップスを要しました。GPT-4は2×10^25フロップスだったと推測されています。米国の行政命令では、1.26×10^26フロップスを超える基盤モデル(foundation model)は政府に報告する必要があるとされていましたが、これは現在取り消されています。しかし、EUのAI法ではまだ1×10^25という基準が維持されており、これは取り消されていません。
ハードウェアのパフォーマンスを見ると、A100は1秒あたり312テラフロップスのピークパフォーマンスを持ち、H100はスパーシティ(疎性)を活用した場合1秒あたり1,979テラフロップス、スパーシティなしでは約50%の性能を発揮します。
NVIDIAの仕様書を見ると、フロップス性能は実行する計算の種類によって大きく異なることがわかります。FP32(単精度浮動小数点)を使用すると、性能は著しく低下します。FP16やBF16を使用すると大幅に改善し、FP8ならさらに高速になります。ただし、ここには「スパーシティあり」という注釈があることに注意してください。このクラスで扱う行列は多くの場合密(dense)なので、実際には記載されている数値の正確に半分の性能しか得られません。
これらの知識を基に、例えば8台のH100を2週間使用した場合のフロップス数を計算してみましょう。これは「8 × フロップス/秒 × 2週間の秒数」で、約4.7×10^21フロップスになります。この数値を他のモデルのフロップスカウントと比較することで、計算の規模を把握することができます。
6.2 行列乗算のコスト計算(2×次元の積)
簡単な例を通して、行列乗算のフロップ数を計算する方法を説明しましょう。トランスフォーマーモデルには触れませんが、線形モデルでも多くの基本的な構成要素と直感を得ることができます。
n個のデータポイントがあり、各ポイントはd次元のベクトルであると仮定しましょう。線形モデルは単にこのd次元ベクトルをk次元ベクトルに変換します。具体的には、データポイント数をB、次元をD、出力の次元数をKとしましょう。そして、データ行列X、重み行列W、そして線形モデルを定義します。
さて、この計算にはどれだけのフロップが必要でしょうか?行列乗算を行う際、各j,k,i(インデックス)の組み合わせに対して、2つの数値を掛け合わせ、その結果を合計に加える必要があります。つまり、関与するすべての次元の積の2倍のフロップが必要になります。左側の次元、中間の次元、右側の次元の積の2倍です。
行列乗算のフロップ数を計算する一般的なルールは、「2 × 3つの次元の積」です。この場合、「2 × B × D × K」となります。このルールは覚えておくと便利でしょう。
他の操作のフロップ数は通常、行列やテンソルのサイズに対して線形(比例)です。一般に、十分大きな行列では、ディープラーニングで遭遇する他のどの操作も行列乗算ほど計算コストが高くなることはありません。
もちろん、行列が十分に小さい場合は、他の操作のコストが支配的になる可能性がありますが、それは通常、良いレジームではありません。ハードウェアは大きな行列乗算向けに設計されているからです。少し循環的な論理かもしれませんが、実質的には行列乗算が支配的なコストになるようなモデルのみを考慮することになります。
この「3つの次元の積の2倍」という数字は非常に役立つものです。アルゴリズムや行列の形状によって多少最適化の余地はありますが、フロップ数の概算としては適切な大きさと言えるでしょう。
また、この計算では加算と乗算は同等とみなされています。行列乗算は単なる数学的な計算ですが、機械学習の文脈で解釈するなら、Bはデータポイント(トークン)の数、D×Kはパラメータ数と考えることができます。このモデルの順伝播に必要なフロップ数は「2 × トークン数 × パラメータ数」となります。
この関係は、実はトランスフォーマーモデルにも一般化できます。シーケンス長やその他の要素によって多少の違いはありますが、基本的な関係は維持されます。このシンプルな関係性を理解することは、モデルのコンピュテーショナルコストを概算する上で非常に有用です。
6.3 GPUの理論的ピーク性能
行列乗算のフロップ数を計算できるようになりましたが、これが実際の実行時間にどのように変換されるのかを理解する必要があります。結局のところ、実際に気にするのは、モデルの実行にどれだけ待たなければならないかということです。
実際に行列乗算の処理時間を測定してみましょう。この関数は単に行列乗算を5回実行し、その実行時間を測定します。この例では、行列乗算に約0.16秒かかっています。1秒あたりのフロップ数(実際のフロップ数を実行時間で割ったもの)は約5.4×10^13フロップス/秒です。
これをA100やH100のマーケティング資料と比較してみましょう。仕様書を見ると、フロップス/秒はデータ型に依存することがわかります。H100のFP32(単精度浮動小数点)での約束されたフロップス/秒は約67テラフロップス/秒です。
ここで役立つ概念がMFU(Model FLOPs Utilization、モデルフロップ使用率)です。これは、実際のフロップス/秒を約束されたフロップス/秒で割ったものです。つまり、モデルに有用な浮動小数点演算の実際の数を、実行にかかった時間で割り、これを約束されたフロップス/秒(光沢のあるパンフレットからの数値)で割ります。
この例では、MFUは約0.8(80%)です。通常、0.5(50%)以上のMFUは良好と見なされ、MFUが5%程度だと非常に悪いと考えられます。通常、通信やオーバーヘッドなどのすべての要素を考慮すると、90%や100%に近づくことはできません。これは、フロップの純粋な計算のみを考慮しているからです。
MFUは通常、行列乗算が計算全体を支配している場合に高くなります。より小さな行列や他の操作が多い場合、効率は低下することがあります。
また、BF16を使用して同じ計算を行うと、処理時間が0.16秒から0.03秒に短縮され、実際のフロップス/秒が大幅に向上します。しかし、スパーシティを考慮しても約束されたフロップス/秒は依然として非常に高いため、BF16のMFUは実際には低くなっています。これは意外に低いように見えるかもしれませんが、約束されたフロップス/秒は時に楽観的すぎることがあります。
このように、コードのベンチマークを常に行い、特定のパフォーマンスレベルを当然のように期待しないことが重要です。実際のハードウェア性能は、理論値と大きく異なる場合があります。
要約すると、行列乗算は計算コストの大部分を占め、一般的な経験則としては、フロップ数は次元の積の2倍です。1秒あたりのフロップ数(フロップス/秒)はハードウェアとデータ型に依存し、高性能なハードウェアを使用するほど、また小さなデータ型を使用するほど、通常は高速になります。MFUは、ハードウェアをどれだけ効率的に使用しているかを評価するのに役立つ概念です。
7. モデルフロップ使用率(MFU)
7.1 実際のフロップ数と理論的フロップ数の比率
モデルフロップ使用率(MFU)は、GPUの実際の性能を評価するための重要な指標です。MFUは、実際に達成されたフロップス/秒を、ハードウェアの理論的なピークパフォーマンス(約束されたフロップス/秒)で割った値です。
MFUの計算方法を具体的に見てみましょう:
MFU = 実際のフロップス/秒 ÷ 約束されたフロップス/秒
ここで「実際のフロップス/秒」とは、モデルが実際に実行した有用な浮動小数点演算の数を実行時間で割ったものです。例えば、先ほどの行列乗算の例では、計算に必要なフロップ数を実行にかかった時間(0.16秒)で割ると、約5.4×10^13フロップス/秒になりました。
「約束されたフロップス/秒」とは、ハードウェアの仕様書に記載されている理論的なピーク性能です。H100のFP32での約束されたフロップス/秒は約67テラフロップス/秒です。
これらの値を使ってMFUを計算すると:
MFU = 5.4×10^13 ÷ 67×10^12 ≈ 0.8
つまり、この特定の計算では、H100の理論的なピーク性能の約80%を達成しています。これは非常に良い利用率です。一般的に、MFUが0.5(50%)以上であれば良好と見なされます。0.5未満、特に0.1(10%)未満の場合は、ハードウェアを効率的に使用できていないことを示しています。
MFUが重要な理由は、単純にハードウェアの仕様を見るだけでは、実際のパフォーマンスを予測するのが難しいからです。理論的なピーク性能は、理想的な条件下でのみ達成できる数値であり、実際のアプリケーションでは通信オーバーヘッド、メモリアクセスのレイテンシ、およびその他の要因により、完全に達成することはほぼ不可能です。
また、MFUは同じハードウェア上で異なるモデルアーキテクチャやトレーニング手法のパフォーマンスを比較するのにも役立ちます。例えば、あるアーキテクチャが0.7のMFUを達成し、別のアーキテクチャが0.4のMFUしか達成できない場合、最初のアーキテクチャの方がハードウェアをより効率的に使用していることがわかります。
ただし、MFUは単なる効率の指標であり、タスクに対するモデルの実際の有効性を示すものではないことに注意が必要です。より効率的なモデルが常により良い結果を生み出すわけではありません。しかし、大規模なモデルトレーニングでは、計算効率が直接的にコストに影響するため、高いMFUを達成することは重要な目標となります。
7.2 データ型によるパフォーマンスの違い
データ型の選択がパフォーマンスに与える影響は非常に大きいです。先ほど見たように、FP32(単精度浮動小数点)とBF16(Brain Float 16)では実行時間に大きな違いがありました。同じ行列乗算操作をBF16で実行すると、処理時間が0.16秒から0.03秒に短縮され、約5倍の高速化が実現しました。
この違いはなぜ生じるのでしょうか?それは、現代のGPUアーキテクチャがより小さなデータ型に最適化されているからです。NVIDIAの仕様書を見てみると、H100やA100などのGPUが提供するピークパフォーマンスはデータ型によって大きく異なることがわかります。
例えば、H100のFP32でのピーク性能は約67テラフロップス/秒ですが、BF16やFP16では約束されたフロップス/秒が1,000テラフロップス/秒を超えます。FP8を使用するとさらに高速になり、場合によっては2,000テラフロップス/秒に近づきます。
このように、データ型を変更するだけで、理論的なピーク性能が10倍以上向上することもあります。ただし、これはスパーシティを考慮した場合の数値であり、実際の密な行列演算では約半分の性能になることを忘れないでください。
また、実際のMFUの計算においても、データ型によって大きな違いが生じます。例えば、BF16を使用した場合のMFUは、FP32を使用した場合よりも低くなる傾向があります。これは、BF16の約束されたフロップス/秒が非常に高いため、実際の計算速度がその理論値に追いつくのが難しいからです。
一般的な傾向として、より小さなデータ型(BF16、FP8など)を使用すると、計算は高速化されますが、数値的な精度は低下します。このトレードオフをどう扱うかは、モデルの要件によって異なります。トレーニングの初期段階では高い精度が必要かもしれませんが、後半では低精度で十分かもしれません。
また、モデルの異なる部分で異なるデータ型を使用する「混合精度トレーニング」も一般的なアプローチです。例えば、数値的に敏感なアテンション層ではFP32を使用し、その他の層ではBF16を使用するという方法があります。
データ型の選択は、計算パフォーマンス、メモリ使用量、数値的安定性のバランスを取る重要な決断です。効率的なディープラーニングシステムを構築するためには、これらのトレードオフを理解し、適切なデータ型を選択することが不可欠です。
7.3 ベンチマークの重要性
理論的なパフォーマンス指標だけに頼らず、実際にコードのベンチマークを取ることの重要性について強調しておきたいと思います。これまで見てきたように、理論値と実際のパフォーマンスには大きな乖離がある場合があります。
BF16を使用した例では、MFUが予想よりも低くなりました。これは「驚くほど低い」結果かもしれませんが、仕様書に記載されている約束されたフロップス/秒が時として楽観的すぎることを示しています。マーケティング資料と実際のパフォーマンスには差があるのが現実です。
なぜベンチマークが重要なのでしょうか?いくつかの理由があります:
- ハードウェアの特性の違い: 同じモデル名のGPUでも、製造ロットやマイクロアーキテクチャの微妙な違いにより、パフォーマンスに差が出ることがあります。
- システム構成の影響: GPUだけでなく、CPUの性能、メモリ帯域幅、PCIeレーンの数など、システム全体の構成がパフォーマンスに影響します。
- ソフトウェアスタックの最適化: CUDA、cuDNN、PyTorchなどのバージョンによって、同じハードウェア上でも性能が異なります。
- モデル固有の特性: 同じパラメータ数でも、モデルのアーキテクチャによって計算パターンが異なり、GPUの利用効率に影響を与えます。
- バッチサイズとシーケンス長の影響: これらの設定が変わると、メモリアクセスパターンやキャッシュ効率が変化し、パフォーマンスに大きな影響を与えます。
したがって、理論的な計算や仕様書の数値はあくまで出発点と考え、実際のワークロードで必ずベンチマークを取ることが重要です。特に大規模なモデルトレーニングでは、小さな効率の違いが大きなコスト差につながるため、この作業は非常に価値があります。
ベンチマークを取る際には、以下のようなポイントに注意すると良いでしょう:
- 十分な回数(通常は数回から数十回)測定し、平均と分散を計算する
- ウォームアップ実行を含めて、キャッシュ効果を排除する
- 実際のワークロードに近い条件で測定する
- 可能であれば、異なる設定(バッチサイズ、シーケンス長など)で測定する
最終的に、理論値だけでなく実際のベンチマーク結果に基づいて、モデルのトレーニング時間や必要なリソースを見積もることで、より正確な計画が可能になります。効率的なリソース利用につながるだけでなく、予期せぬパフォーマンスの問題を事前に特定することができます。
8. 勾配計算のコスト
8.1 順伝播と逆伝播のフロップ計算
これまでは主に行列乗算や前方向のパスのフロップ数について検討してきましたが、勾配を計算するコストも理解する必要があります。なぜなら、トレーニング時には順伝播だけでなく、逆伝播も行うからです。
簡単な例として、線形モデルの予測と平均二乗誤差(MSE)を考えてみましょう。非常に興味深い損失関数ではありませんが、勾配計算のフロップ数を理解するには十分です。
順伝播では、入力テンソルXと重みWを用いて、線形積(行列乗算)によって予測を計算し、その後、損失を計算します。逆伝播では、単にloss.backward()
を呼び出すと、自動的に勾配が計算されます。この場合、テンソルに付随する勾配変数(grad
)が、私たちが求めるものになります。
PyTorchでの勾配計算は皆さんも経験があるはずですが、勾配計算に必要なフロップ数を詳しく見ていきましょう。
もう少し複雑なモデル、2層の線形ネットワークを考えてみます。入力Xは次元B×Dで、最初の重み行列W1はD×D、次の重み行列W2はD×Kです。まず隠れ層の活性化H1を計算し、それを使って次の層の出力H2を計算し、最後に損失を計算します。
順伝播のフロップ数を計算すると、W1に関連する行列乗算は「2×B×D×D」フロップが必要で、W2に関する行列乗算は「2×B×D×K」フロップが必要です。つまり、合計で「2×(パラメータ数)×(データポイント数)」フロップが必要になります。
逆伝播はもう少し複雑です。モデルはX→H1→H2→損失という流れですが、逆伝播では様々な勾配(∂損失/∂H2、∂損失/∂W2、∂損失/∂H1、∂損失/∂W1など)を計算する必要があります。
まず、W2に関する勾配を見てみましょう。連鎖律によると、∂損失/∂W2は基本的にH1と∂損失/∂H2の積になります。これは行列乗算に似た演算で、「2×B×D×K」フロップが必要です。
次に、H1に関する勾配も計算する必要があります。これはW2と∂損失/∂H2の積になり、同様に「2×B×D×K」フロップが必要です。
W2に関する計算だけで、合計「4×B×D×K」フロップが必要です。同様に、W1に関する勾配計算も行うと、「4×B×D×D」フロップが必要になります。
全体を合わせると、逆伝播のフロップ数は「4×(パラメータ数)×(データポイント数)」となります。順伝播と逆伝播を合わせると、「6×(パラメータ数)×(データポイント数)」というトータルのフロップ数になります。
これが、冒頭で示した「6×パラメータ数×トークン数」という公式の由来です。単純な線形モデルでこの関係が成り立ち、多くのモデルでもこの比率が大まかに当てはまります。特に、パラメータシェアリングなどの特殊なケースを除けば、ほとんどのモデルでこの関係が成り立ちます。
8.2 勾配計算の分解(2層線形ネットワークの例)
2層線形ネットワークの勾配計算を詳細に分解していきましょう。この部分はやや複雑になりますが、勾配計算のフロップ数がどこから来るのかを理解するために重要です。
まず、モデルを再確認しましょう。入力Xは次元B×Dで、これに第1層の重み行列W1(D×D)を掛けて隠れ層の活性化H1を得ます。次にH1に第2層の重み行列W2(D×K)を掛けて出力H2を得て、最後に損失を計算します。
逆伝播では、以下の勾配を計算する必要があります:
- ∂損失/∂H2(出力に対する損失の勾配)
- ∂損失/∂W2(第2層の重みに対する損失の勾配)
- ∂損失/∂H1(隠れ層の活性化に対する損失の勾配)
- ∂損失/∂W1(第1層の重みに対する損失の勾配)
まず、W2に対する勾配を計算してみましょう。連鎖律によると: ∂損失/∂W2 = ∑(H1 × ∂損失/∂H2)
ここで、H1はB×D次元のテンソルで、∂損失/∂H2はB×K次元のテンソルです。この計算は行列乗算に似ており、「2×B×D×K」フロップが必要です。
次に、H1に対する勾配を計算します。これは逆伝播を続けるために必要です: ∂損失/∂H1 = ∂損失/∂H2 × W2^T
ここで、∂損失/∂H2はB×K次元、W2^TはK×D次元です。これも行列乗算のような演算で、同じく「2×B×D×K」フロップが必要です。
したがって、W2に関連する勾配計算だけで、合計「4×B×D×K」フロップが必要になります。
同様の計算をW1についても行います。W1に対する勾配は: ∂損失/∂W1 = ∑(X × ∂損失/∂H1)
XはB×D次元、∂損失/∂H1はB×D次元です。この計算には「2×B×D×D」フロップが必要です。
また、入力Xに対する勾配も計算する必要があります(ただし、これは必ずしも保存される必要はありません): ∂損失/∂X = ∂損失/∂H1 × W1^T
これにも「2×B×D×D」フロップが必要です。
よって、W1に関連する勾配計算には合計「4×B×D×D」フロップが必要です。
全体を合わせると:
- W2関連:4×B×D×K フロップ
- W1関連:4×B×D×D フロップ
この合計は「4×B×(D×K + D×D)」となり、これは「4×B×(パラメータ数)」、つまり「4×(データポイント数)×(パラメータ数)」と表現できます。
これを順伝播のフロップ数「2×(データポイント数)×(パラメータ数)」と合わせると、トータルで「6×(データポイント数)×(パラメータ数)」のフロップが必要になります。
より視覚的に理解するためのアニメーションもありますが、要点は、順伝播は2倍の係数、逆伝播は4倍の係数を持ち、合計で6倍の係数になるということです。これは連鎖律を用いた勾配計算の詳細な分析から導かれます。
「6×パラメータ数×データポイント数」というこの関係は、単純な線形モデルだけでなく、より複雑なモデル(トランスフォーマーなど)にも概ね当てはまります。これが、モデルトレーニングのコスト見積もりに使用される基本的な公式です。
8.3 合計コスト:6×パラメータ数×トークン数
これまでの計算を要約すると、モデルトレーニングの総計算コストは「6×パラメータ数×トークン数」という式で表すことができます。この「6」という係数はどこから来るのでしょうか?
先ほど見たように、この係数は順伝播と逆伝播の計算コストを合わせたものです:
- 順伝播:2×パラメータ数×トークン数
- 逆伝播:4×パラメータ数×トークン数
- 合計:6×パラメータ数×トークン数
この関係は単純な線形モデルに基づいていますが、実際には多くのモデルでこの比率が概ね当てはまります。これは、ほとんどのニューラルネットワークの計算の大部分が行列乗算で占められ、各計算が基本的に新しいパラメータに触れる構造になっているためです。
もちろん、この関係が常に厳密に成り立つわけではありません。例えば、パラメータシェアリングを使用して1つのパラメータで数十億のフロップを実行するようなモデルも設計できますが、一般的な深層学習モデルではそのような極端なケースは稀です。
この「6×」という係数は、特にトランスフォーマーベースの言語モデルのようなアーキテクチャにも適用できます。トランスフォーマーはより複雑ですが、基本的な計算パターンは同様で、主要な計算コストは行列乗算から来ています。
ただし、トランスフォーマーの場合は注意点があります。シーケンス長に依存する計算(特にアテンションメカニズム)があるため、シーケンス長が非常に長い場合、この関係は完全には当てはまらなくなります。アテンションのコストはシーケンス長の2乗に比例するため、長いシーケンスでは追加のオーバーヘッドが生じます。しかし、多くの実用的なケースでは、この「6×」という近似は依然として有効です。
この関係式は、モデルトレーニングのリソース要件を素早く見積もるための強力なツールです。例えば、あるモデルのトレーニングにかかる時間やコストを概算したり、特定のハードウェア構成で訓練可能な最大モデルサイズを推定したりするのに役立ちます。
また、この関係式は効率化の余地を特定するのにも役立ちます。例えば、勾配計算(4×の部分)がトレーニングコストの大部分を占めていることから、勾配圧縮や効率的な勾配通信などの技術が特に重要であることがわかります。
最終的に、「6×パラメータ数×トークン数」という式は、ディープラーニングモデルのトレーニングコストを理解するための基本的な目安となります。この関係を理解することで、リソースを効率的に計画し、コストを最小限に抑えながら最大の成果を得ることができます。
9. モデル構築の基礎
9.1 パラメータ初期化(Xavier初期化)
次に、モデル構築の基礎に移りましょう。PyTorchではパラメータはnn.Parameter
オブジェクトとして保存されます。
パラメータ初期化について考えてみましょう。例えば、入力次元と隠れ次元を持つパラメータWがあるとします。これは線形モデルの場合です。入力を生成し、それをモデルに通してみましょう。
ランダムな正規分布(ガウス分布)からパラメータを初期化すると、一見問題なさそうに見えますが、出力を見ると非常に大きな値が出てきます。これは、重みの値が隠れ次元の平方根に比例して大きくなるためです。大規模なモデルでは、これにより値が爆発し、トレーニングが非常に不安定になる可能性があります。
そこで、隠れ次元に依存せず、少なくとも爆発しないことが保証される方法でパラメータを初期化したいと考えます。一つの簡単な方法は、単に入力次元の平方根で割ることです。
パラメータWを入力次元の平方根で割って再定義し、モデルに通してみると、出力は安定して0周辺に集中します。これは実際には正規分布N(0,1)に近づきます。
これはディープラーニング文献で広く研究されており、定数を除けばXavier初期化として知られています。さらに安全を期すため、正規分布は無限の裾野を持つため、多くの場合は値を-3から3に切り詰めます。これにより、極端に大きな値が現れるのを防ぎます。
このように適切な初期化は、特に大規模なモデルでの安定したトレーニングに不可欠です。初期値が適切でないと、勾配消失や勾配爆発などの問題が発生し、モデルが収束しない、あるいは非常に遅い収束を示す可能性があります。
Xavier初期化の背後にある直感は、各層の出力の分散を一定に保つことです。これにより、ネットワークの深さに関わらず、信号が安定して伝播します。より具体的には、順伝播と逆伝播の両方で分散を保存するように重みを初期化します。
大規模言語モデルなどの深いネットワークでは、こうした初期化の選択が特に重要になります。適切な初期化は、トレーニングの初期段階での収束を大幅に速め、場合によっては収束そのものを可能にします。
また、モデルアーキテクチャによって最適な初期化方法が異なる場合があります。例えば、ReLUアクティベーションを使用するネットワークではKaiming初期化が一般的ですが、トランスフォーマーではまた異なる初期化方法が使われることもあります。
重要なのは、初期化が単なる実装の詳細ではなく、モデルの性能と安定性に大きく影響する重要な設計選択だということです。
9.2 カスタムモデルの実装
では、シンプルなモデルを構築してみましょう。このモデルは次元とレイヤー数を持ち、各レイヤーは行列乗算を行う線形モデルです。私はこれを「Cruncher」と名付けましたが、これは単なる深層線形ネットワークです。
このカスタムモデルは以下のような構造を持ちます:
- 指定された数のレイヤーを持つ
- 各レイヤーは線形変換(行列乗算)を行う
- すべての中間レイヤーはD×D次元の行列
- 最終レイヤー(ヘッド)はD次元の出力を生成
具体的なPyTorchでの実装例を見てみましょう。まず、モデルのパラメータを定義します:
- 最初のレイヤー:D×D次元の行列
- 2番目のレイヤー:D×D次元の行列
- ヘッド(最終レイヤー):D次元の出力を生成する行列
モデルのパラメータ数を計算すると、D^2 + D^2 + D = 2D^2 + D となります。これは予想通りの結果です。
このモデルをGPUに移動させることも重要です。計算を高速化するためです。ランダムなデータを生成し、モデルに通してみましょう。フォワードパスは単に各レイヤーを順番に適用し、最後にヘッドを適用するだけです。
ここで一般的な注意点として、乱数の取り扱いについて触れておきます。乱数は初期化、ドロップアウト、データの順序付けなど、多くの場所で使用されます。バグを再現しやすくするためにも、常に固定の乱数シードを使用することをお勧めします。
これにより、モデルや少なくともできる限り多くの要素を再現可能にします。特に、乱数の各ソースに異なるシードを使用すると便利です。これにより、例えば初期化を固定したままデータの順序を変えるなど、特定の要素だけを変更することができます。
デバッグ時には決定論的な振る舞いが非常に役立ちます。コード内では多くの場所で乱数が使用される可能性があるため、どの乱数生成器を使用しているかを常に意識し、安全のためにすべての乱数生成器のシードを設定することをお勧めします。
このようにして、基本的なカスタムモデルを実装することができます。実際のアプリケーションでは、より複雑な構造やレイヤー、活性化関数などを追加することになりますが、基本的な設計原則は同じです。特に、パラメータの初期化、GPUへの移動、乱数の管理などの側面は、どのようなモデルを構築する場合でも重要です。
9.3 ランダム性の管理
ランダム性の管理はモデル構築において重要な側面です。ランダム性はさまざまな場所に現れます:パラメータの初期化、ドロップアウト、データの順序付けなどです。特にバグを再現しようとしている場合、この不確定性が問題になることがあります。
ベストプラクティスとして、常に固定の乱数シードを使用することを強く推奨します。これにより、モデルの動作をできる限り再現可能にすることができます。特に、異なるランダム性のソースに対して異なるシードを使用することが効果的です。例えば:
# 一般的なPyTorchの乱数シードを設定
torch.manual_seed(42)
# CUDA操作の乱数シードを設定
torch.cuda.manual_seed(42)
# NumPyの乱数シードを設定
np.random.seed(43)
# データローダーの乱数シードを設定
dataloader = DataLoader(dataset, batch_size=32, shuffle=True,
generator=torch.Generator().manual_seed(44))
このように異なるソースに異なるシードを使用すると、例えば初期化を固定したままデータの順序だけを変えるなど、特定の要素のみを変更することができます。これは実験や分析において非常に有用です。
決定論的な振る舞いはデバッグ時に特に役立ちます。エラーが発生した場合、同じ条件を再現して問題を特定しやすくなります。また、モデルの性能評価においても、ランダム性を制御することで結果の信頼性が向上します。
コード内では様々な場所で乱数が使用される可能性があります。PyTorchにはtorch.random
、torch.cuda.random
、Pythonのrandom
モジュール、NumPyのnp.random
など、複数の乱数生成器があります。どの乱数生成器を使用しているかを常に意識し、安全のためにすべての乱数生成器のシードを設定することをお勧めします。
ディープラーニングの実験では、同じモデルアーキテクチャと同じデータセットであっても、異なる乱数シードで初期化すると異なる結果が得られることがあります。複数の実行で結果を平均化したり、異なるシードでの振る舞いの変動を分析したりすることで、モデルの堅牢性を評価することができます。
また、一部の操作(特にGPU上での並列計算)では完全な決定性を保証するのが難しい場合があります。PyTorchではtorch.backends.cudnn.deterministic = True
を設定することで、CUDNNのアルゴリズムを決定論的なものに制限できますが、これはパフォーマンスを犠牲にする可能性があります。
ランダム性の管理は単なる技術的な詳細ではなく、再現可能な科学と効率的なデバッグの基盤となる重要な側面です。特に大規模モデルのトレーニングでは、時間とリソースの制約から、問題を効率的に診断し解決するために決定論的な振る舞いが不可欠です。
10. 最適化アルゴリズム
10.1 SGD、Momentum、AdaGrad、RMSProp、Adamの比較
最適化アルゴリズムについて説明しましょう。モデルを定義したら、次はそれをトレーニングするためのオプティマイザが必要になります。様々なオプティマイザがありますので、それぞれの直感的な理解を提供します。
まず、確率的勾配降下法(SGD)があります。これは最も基本的なアプローチで、バッチの勾配を計算し、その方向に単純にステップを進めます。疑問なく、直接的なアプローチです。
次に、モーメンタム付きSGDがあります。これは古典的な最適化からの考え方で、Nesterovに由来します。勾配の実行平均(running average)を保持し、瞬間的な勾配ではなく、その実行平均に対して更新を行います。これにより、最適化の軌跡がより滑らかになり、局所的な最小値を回避しやすくなります。
AdaGradは、勾配の大きさに基づいてパラメータごとに学習率を調整します。具体的には、勾配の二乗の累積値で学習率をスケーリングします。これにより、頻繁に更新されるパラメータには小さな学習率が、めったに更新されないパラメータには大きな学習率が適用されます。
RMSPropはAdaGradの改良版で、単純な平均ではなく、指数移動平均(exponential averaging)を使用します。これにより、最近の勾配に重みを置くことができ、学習率の減衰が速すぎるというAdaGradの問題を軽減します。
最後に、2014年に登場したAdamは、基本的にRMSPropとモーメンタムを組み合わせたものです。勾配の一次モーメント(平均)と二次モーメント(分散)の両方を維持し、それに基づいてパラメータを更新します。つまり、勾配の実行平均とその二乗の実行平均の両方を保持します。
これらのオプティマイザの進化を見ると、一貫したパターンがあります。基本的なSGDから始まり、安定性と収束速度を向上させるために、より洗練された統計的手法が導入されてきました。Adamは現在、多くのディープラーニングアプリケーションでデフォルトのオプティマイザとなっていますが、特定のタスクや状況では他のオプティマイザが有利な場合もあります。
例えば、SGDは計算コストが低く、メモリ要件も少ないため、非常に大規模なモデルに適しています。一方、Adamはより高速な収束を提供しますが、より多くのメモリを必要とします。RMSPropは、再発的なニューラルネットワークなど、勾配が非常に不安定な場合に特に有効です。
最適なオプティマイザの選択は、モデルのアーキテクチャ、データセットの性質、計算リソースの制約など、多くの要因に依存します。よく使われる経験則としては、まずAdamで試し、問題が発生した場合、または特定の要件がある場合に他のオプションを検討するというものがあります。
課題1ではAdamを実装することになりますが、オプティマイザの仕組みをより深く理解するために、まずはAdaGradの実装を見ていきましょう。
10.2 最適化アルゴリズムの実装(AdGradの例)
課題1でAdamオプティマイザを実装することになっていますが、ここではより単純なAdaGradオプティマイザの実装を通じて、PyTorchでオプティマイザを実装する方法を見ていきましょう。
PyTorchでオプティマイザを実装するには、torch.optim.Optimizer
クラスを継承し、必要なメソッドをオーバーライドします。特に重要なのはstep
メソッドで、ここに最適化アルゴリズムの核心部分を実装します。
まずデータを定義し、フォワードパスで損失を計算し、勾配を計算した後、optimizer.step()
を呼び出します。このstep
メソッドの中で、オプティマイザが実際に働きます。
AdaGradオプティマイザの実装を見てみましょう。実装の構造は以下のようになります:
class AdaGrad(torch.optim.Optimizer):
def __init__(self, params, lr=1e-2, eps=1e-10):
defaults = dict(lr=lr, eps=eps)
super(AdaGrad, self).__init__(params, defaults)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
# 勾配の取得
grad = p.grad.data
# オプティマイザ状態の取得または初期化
state = self.state[p]
if len(state) == 0:
state['G'] = torch.zeros_like(p.data)
# 勾配の二乗を累積
G = state['G']
G.add_(grad * grad)
# パラメータの更新
p.data.addcdiv_(value=-group['lr'], tensor1=grad,
tensor2=torch.sqrt(G) + group['eps'])
return loss
この実装の中で、パラメータはグループ化されています。例えば、レイヤー0、レイヤー1、最終的な重みなどのグループがあります。各パラメータには状態が関連付けられており、これはパラメータからオプティマイザの状態へのディクショナリとして保存されます。
AdaGradでは、勾配の二乗の累積値(G
)を状態として保存します。この値は各ステップで更新され、勾配の二乗を累積していきます。そして、この累積値の平方根で学習率をスケーリングすることで、パラメータごとに適応的な学習率を実現します。
パラメータの更新式は以下のようになります:
パラメータ -= 学習率 * 勾配 / (√(G) + eps)
ここでeps
は小さな値で、ゼロ除算を防ぐために使用されます。
この実装では、オプティマイザの状態は複数のステップにわたって保持されます。これにより、過去の勾配情報を活用して、より効果的な最適化が可能になります。
また、オプティマイザのステップ終了時にはメモリを解放することも可能です。これは特にモデル並列化について話すときに重要になります。
この実装例から、PyTorchでのオプティマイザの基本的な構造と、特に状態の管理方法がわかります。Adamなど他のオプティマイザも同様の構造を持ちますが、状態の更新方法やパラメータの更新式が異なります。
10.3 オプティマイザの状態管理
オプティマイザの状態管理は、効率的かつ効果的なモデルトレーニングにおいて重要な側面です。先ほどのAdaGradの例で見たように、最適化アルゴリズムは単にパラメータを更新するだけでなく、様々な状態情報を保持し管理する必要があります。
PyTorchでは、オプティマイザの状態はself.state
ディクショナリに保存されます。このディクショナリはパラメータテンソルをキーとし、そのパラメータに関連する状態情報を値として持ちます。例えば、AdaGradでは各パラメータに対して勾配の二乗の累積値を保存します。
state = self.state[p]
if len(state) == 0:
state['G'] = torch.zeros_like(p.data)
この状態管理の重要な点は、状態がステップ間で持続することです。つまり、状態は一度初期化されると、モデルのトレーニング全体を通じて更新され続けます。これにより、オプティマイザは過去の勾配情報を「記憶」し、より賢明なパラメータ更新を行うことができます。
異なるオプティマイザは異なる状態情報を保持します:
- SGDはモーメンタムを使用する場合、速度(velocity)を保存します
- AdaGradは勾配の二乗の累積値を保存します
- RMSPropは勾配の二乗の移動平均を保存します
- Adamは一次モーメント(勾配の移動平均)と二次モーメント(勾配の二乗の移動平均)の両方を保存します
これらの状態情報はメモリ使用量に直接影響します。例えば、Adamは各パラメータに対して2つのテンソルを保存するため、SGDよりも2倍のメモリを必要とします。大規模モデルの場合、これは重要な考慮事項になります。
オプティマイザの状態を管理する上で考慮すべき重要な側面がいくつかあります:
- メモリ効率: 状態テンソルはパラメータと同じサイズを持つため、特に大規模モデルでは大量のメモリを消費します。必要に応じて、混合精度トレーニングを使用して状態テンソルの精度を下げることができます。
- チェックポイント保存と読み込み: トレーニングを中断して後で再開する場合、モデルパラメータだけでなく、オプティマイザの状態も保存する必要があります。そうしないと、最適化の軌跡が乱れ、パフォーマンスに影響する可能性があります。
# チェックポイントの保存
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'iteration': iteration
}, 'checkpoint.pth')
# チェックポイントの読み込み
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
iteration = checkpoint['iteration']
- デバイス間の移動: 計算デバイス(CPU/GPU)を切り替える場合、オプティマイザの状態も適切に移動する必要があります。
- 勾配のクリッピングと正規化: 状態更新の前に勾配をクリッピングまたは正規化することで、トレーニングの安定性を向上させることができます。
- 学習率スケジューリング: 多くの場合、オプティマイザの基本的な動作を補完するために、学習率スケジューラを使用します。これにより、トレーニングの進行に応じて学習率を調整できます。
効果的なオプティマイザの状態管理は、トレーニングの安定性、収束速度、およびメモリ効率に直接影響します。特に大規模言語モデルのようなリソース集約型のモデルでは、これらの考慮事項が成功と失敗の差を分けることがあります。
11. 総合的リソース要件
11.1 パラメータ、活性化、勾配、オプティマイザ状態のメモリ要件
ここまで個別の要素について説明してきましたが、ここではモデルトレーニングの総合的なメモリ要件について考えていきましょう。特に、オプティマイザの状態に関するメモリ要件も含めて、全体像を把握することが重要です。
モデルトレーニングにおけるメモリ使用量は主に以下の4つの要素から構成されます:
- パラメータ: モデルの重みとバイアス
- 活性化: フォワードパス中の中間出力
- 勾配: バックワードパス中に計算される勾配
- オプティマイザ状態: オプティマイザが保持する追加情報
先ほど構築した2層線形ネットワークモデルの例で、これらのメモリ要件を具体的に計算してみましょう。
パラメータのメモリ要件: このモデルのパラメータ数は次の通りです:
- 各層のサイズがD×Dの行列(レイヤー数をLとする)
- 最終ヘッドがDの出力を生成
したがって、総パラメータ数は: パラメータ数 = D² × L + D
活性化のメモリ要件: 活性化のメモリ要件は、バッチサイズ(B)、入力次元、および層の数に依存します:
- 各層の活性化は次元B×D
- 層の数はL
したがって、総活性化数は: 活性化数 = B × D × L
バッチ内の各データポイントの各次元について、各層の活性化値を保存する必要があります。
勾配のメモリ要件: 勾配のメモリ要件はパラメータ数と同じです: 勾配数 = D² × L + D
各パラメータに対して1つの勾配値が必要です。
オプティマイザ状態のメモリ要件: オプティマイザ状態のメモリ要件はオプティマイザの種類によって異なります。AdaGradの場合: オプティマイザ状態数 = D² × L + D
AdaGradでは、各パラメータに対して勾配の二乗の累積値(G)を1つ保存します。
合計メモリ要件: FP32(float32)を使用した場合、各値は4バイトです。したがって、合計メモリ要件は:
合計メモリ(バイト) = 4 × (パラメータ数 + 活性化数 + 勾配数 + オプティマイザ状態数)
= 4 × [(D² × L + D) + (B × D × L) + (D² × L + D) + (D² × L + D)]
具体的な数値(D=32、L=2、B=64など)を代入すると、必要なメモリ量を正確に計算できます。
Adamを使用する場合は、オプティマイザ状態が2倍になることを考慮する必要があります。Adamは一次モーメントと二次モーメントの両方を保存するためです。
このように、トレーニング時のメモリ要件を正確に把握することは、特に限られたGPUメモリで大規模モデルをトレーニングする際に非常に重要です。リソース要件を事前に計算することで、メモリ不足エラーを避け、利用可能なハードウェアに合わせてバッチサイズや精度を調整することができます。
課題1では、トランスフォーマーモデルに対して同様の計算を行うことになります。トランスフォーマーはより複雑で、様々な行列乗算やアテンションメカニズムなど多くの要素がありますが、計算の一般的な形式は同じです。パラメータ、活性化、勾配、オプティマイザ状態のメモリ要件を特定し、それらを合計します。
11.2 トータルメモリ使用量の計算
先ほど説明した各コンポーネントのメモリ要件を具体的な数値で計算し、トータルのメモリ使用量を算出してみましょう。これにより、モデルが実際にどれだけのメモリを必要とするのかを正確に把握できます。
私たちの2層線形ネットワークモデルの例で、具体的なパラメータを設定します:
- 次元 D = 8
- レイヤー数 L = 2
- バッチサイズ B = 4
このモデルのメモリ要件を計算していきます。
パラメータのメモリ要件: パラメータ数 = D² × L + D = 8² × 2 + 8 = 128 + 8 = 136
活性化のメモリ要件: 活性化数 = B × D × L = 4 × 8 × 2 = 64
勾配のメモリ要件: 勾配数 = パラメータ数 = 136
オプティマイザ状態のメモリ要件: AdaGradの場合、オプティマイザ状態数 = パラメータ数 = 136
合計メモリ要件: FP32(float32)を使用する場合、各値は4バイトです。
合計メモリ(バイト) = 4 × (パラメータ数 + 活性化数 + 勾配数 + オプティマイザ状態数) = 4 × (136 + 64 + 136 + 136) = 4 × 472 = 1,888バイト(約1.9KB)
この計算から、このシンプルなモデルは約1.9KBのメモリを必要とすることがわかります。これは非常に小さな値です。
しかし、実際のディープラーニングモデルでは、次元とレイヤー数が桁違いに大きくなります。例えば:
- 次元 D = 1024
- レイヤー数 L = 12
- バッチサイズ B = 32
この場合、計算は次のようになります:
パラメータ数 = 1024² × 12 + 1024 = 12,583,936 活性化数 = 32 × 1024 × 12 = 393,216 勾配数 = 12,583,936 オプティマイザ状態数 = 12,583,936(AdaGrad)
合計メモリ(バイト) = 4 × (12,583,936 + 393,216 + 12,583,936 + 12,583,936) = 4 × 38,145,024 = 152,580,096バイト(約145.5MB)
このケースでも、まだ比較的小さなモデルですが、必要なメモリは約145.5MBです。
実際の大規模言語モデルでは、次元が数千から数万、レイヤー数が数十、バッチサイズも大きくなるため、必要なメモリは数十GBから数百GBに達することもあります。例えば:
- 次元 D = 4096
- レイヤー数 L = 24
- バッチサイズ B = 256
このような設定では、メモリ要件は簡単に10GB以上になります。
さらに、このシンプルな計算では考慮していない追加のメモリオーバーヘッドも存在します:
- アテンションメカニズム(トランスフォーマーモデルの場合)
- キャッシュと一時的なテンソル
- PyTorchやCUDAのランタイムオーバーヘッド
- フラグメンテーション(メモリの断片化)
実際のモデルトレーニングでは、これらすべての要素を考慮する必要があります。そのため、理論的な計算よりも20-30%多くのメモリを予備として確保することが賢明です。
また、混合精度トレーニングを使用すると、一部のコンポーネントで低精度(例:BF16)を使用できるため、メモリ要件を大幅に削減できます。例えば、活性化とバックワード計算にBF16を使用し、パラメータとオプティマイザ状態にFP32を使用すると、メモリ使用量を30-40%削減できる可能性があります。
正確なメモリ使用量の計算により、限られたGPUリソースを最大限に活用し、メモリエラーなしに可能な限り大きなモデルをトレーニングすることができます。
11.3 トレーニングループの実装
ここまでモデルの構築とメモリ要件について説明しましたが、実際にモデルをトレーニングするためのトレーニングループの実装に進みましょう。典型的なトレーニングループは非常にシンプルで、以下の基本的なステップで構成されています。
# モデルの定義
model = Cruncher(dim=8, num_layers=2)
model.to(device) # GPUに移動
# オプティマイザの定義
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# データローダーの準備
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# トレーニングループ
num_epochs = 10
for epoch in range(num_epochs):
for batch_idx, (data, target) in enumerate(data_loader):
# データをGPUに移動
data, target = data.to(device), target.to(device)
# 勾配をゼロに初期化
optimizer.zero_grad()
# フォワードパス
output = model(data)
# 損失計算
loss = criterion(output, target)
# バックワードパス(勾配計算)
loss.backward()
# パラメータ更新
optimizer.step()
# 進捗表示(例:100バッチごと)
if batch_idx % 100 == 0:
print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.6f}')
このトレーニングループには、効率的なトレーニングのための重要な要素がすべて含まれています:
- モデルとオプティマイザの定義: トレーニングを開始する前に、モデルとオプティマイザを適切に設定します。モデルはGPUに移動し、オプティマイザは適切な学習率で初期化されます。
- バッチ処理: データローダーはデータセットをバッチに分割し、各イテレーションでバッチを提供します。このアプローチにより、巨大なデータセット全体を一度にメモリに読み込む必要がなくなります。
- GPU使用: データとターゲットは計算前にGPUに移動されます。これにより、すべての計算が同じデバイス上で効率的に実行されます。
- 勾配の初期化:
各バッチの前に
optimizer.zero_grad()
を呼び出して勾配をクリアします。これにより、前のバッチからの勾配が蓄積されることを防ぎます。 - フォワードパスとロス計算: モデルにデータを通し、予測を生成し、損失を計算します。
- バックワードパス:
loss.backward()
を呼び出して、モデル全体を通じて勾配を計算します。これはPyTorchの自動微分機能を使用して行われます。 - パラメータ更新:
optimizer.step()
を呼び出して、計算された勾配に基づいてパラメータを更新します。 - 進捗モニタリング: 定期的に損失や他のメトリクスを表示して、トレーニングの進行状況を監視します。
実際のトレーニングループには、以下のようなさらなる拡張を含めることがよくあります:
- 検証ループ:一定のエポックごとに検証セットでモデルを評価
- 学習率スケジューラ:トレーニングの進行に応じて学習率を調整
- 勾配クリッピング:勾配爆発を防ぐために勾配のノルムを制限
- チェックポイント保存:定期的にモデルとオプティマイザの状態を保存
- 早期停止:検証損失が改善しなくなったらトレーニングを停止
- テンソルボードやその他のロギング:トレーニングメトリクスの視覚化と記録
言語モデルのトレーニングには長い時間がかかり、途中でクラッシュする可能性があるため、チェックポイントを定期的に保存することは特に重要です。すべての進捗を失わないようにするために、モデル、オプティマイザ、現在のイテレーション数を保存する必要があります。
言語モデルのような大規模モデルでは、トレーニングループに分散トレーニングやメモリ最適化などの追加の複雑さが加わることがありますが、基本的な構造はここで示したものと同じままです。
12. 効率化テクニック
12.1 混合精度トレーニング
モデルトレーニングの効率を向上させるためのテクニックとして、混合精度トレーニングは非常に重要です。データ型の選択はトレーニングの精度と安定性、コストのトレードオフに直接影響します。
高精度(float32)は精度が高く安定していますが、メモリと計算コストが高くなります。一方、低精度(BF16、FP16、FP8など)は効率的ですが、数値的な不安定さをもたらす可能性があります。
混合精度トレーニングの基本的な考え方は、モデルのさまざまな部分に異なる精度を使用することです。一般的な推奨事項としては:
- フォワードパスにはBF16(または他の低精度)を使用する
- モデルパラメータとオプティマイザの状態にはfloat32を使用する
- 計算中の中間結果には低精度を使用し、最終的な結果には高精度を使用する
この方法は、2017年の「混合精度トレーニングの探求」という論文に遡ります。PyTorchには混合精度トレーニングを自動的に扱うためのツールが用意されています。手動で精度を指定するのは面倒な場合があるためです。
通常、モデルはクリーンでモジュール化された方法で定義されますが、精度の指定はその設計を横断するものです。PyTorchのtorch.cuda.amp
(Automatic Mixed Precision)モジュールは、この問題に対処するための便利なツールを提供しています。
# 混合精度トレーニングの例
from torch.cuda.amp import autocast, GradScaler
# モデルとオプティマイザの定義
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler() # 勾配スケーラー
# トレーニングループ
for epoch in range(num_epochs):
for batch in data_loader:
# 入力をGPUに移動
inputs = inputs.cuda()
targets = targets.cuda()
# 混合精度の自動キャストを使用
with autocast():
outputs = model(inputs)
loss = criterion(outputs, targets)
# 勾配のスケーリングとバックプロパゲーション
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
このコードでは、autocast()
コンテキストマネージャがフォワードパスとロス計算中に自動的に精度を低下させ、GradScaler
がアンダーフローを防ぐために勾配をスケーリングします。
混合精度トレーニングの利点は多岐にわたります:
- メモリ使用量の削減:低精度のテンソルは高精度の半分のメモリしか使用しないため、より大きなバッチサイズやモデルサイズが可能になります。
- 計算速度の向上:特に最新のGPUはFP16やBF16演算のために最適化されており、2-3倍の高速化が可能です。
- バンド幅の改善:小さなデータ型によりメモリバンド幅の使用が減少し、データ転送が高速化します。
ただし、混合精度トレーニングには注意点もあります:
- 数値的安定性:一部の演算(特に小さな勾配や大きな活性化)は低精度で不安定になる可能性があります。
- 精度の損失:一部のケースでは、混合精度が最終的なモデルの精度に悪影響を与える可能性があります。
- 実装の複雑さ:自動ツールを使用しない場合、混合精度の実装はコードの複雑さを増加させます。
最近のトレンドとして、研究者はより低い精度(FP8など)でもトレーニングを安定させる方法を模索しています。例えば、FP8をトレーニング全体で使用する手法が発表されています。しかし、低精度になるほど数値的不安定性の制御が難しくなります。
トレーニング中に低精度を使用することが難しい一方で、事前トレーニング済みのモデルを推論用に量子化するのは比較的容易です。つまり、トレーニングには安定した精度を使用し、モデルが訓練された後に低精度に変換することで、推論時の効率を向上させることができます。
混合精度トレーニングは、大規模言語モデルのトレーニングにおいて事実上の標準となっており、リソース効率と計算性能を大幅に向上させることができます。
12.2 チェックポイントの保存と読み込み
言語モデルのトレーニングには長い時間がかかり、途中でクラッシュする可能性が常にあります。トレーニングの進捗を失わないためには、チェックポイントを定期的に保存することが非常に重要です。チェックポイントとは、モデルの現在の状態を保存したスナップショットであり、トレーニングを再開するために必要なすべての情報を含んでいます。
チェックポイントに保存すべき重要な要素は以下の通りです:
- モデルのパラメータ: モデルの重みとバイアス
- オプティマイザの状態: モーメンタムや適応学習率情報など
- 現在のイテレーション/エポック番号: トレーニングの進行状況
- その他の必要なメタデータ: 学習率スケジューラの状態、乱数生成器の状態など
PyTorchでのチェックポイントの保存と読み込みは非常に簡単です:
# チェックポイントの保存
def save_checkpoint(model, optimizer, epoch, iteration, path="checkpoint.pt"):
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
'iteration': iteration,
# 必要に応じて他のメタデータを追加
}, path)
print(f"Checkpoint saved at epoch {epoch}, iteration {iteration}")
# チェックポイントの読み込み
def load_checkpoint(model, optimizer, path="checkpoint.pt"):
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
iteration = checkpoint['iteration']
# 必要に応じて他のメタデータを取得
print(f"Checkpoint loaded from epoch {epoch}, iteration {iteration}")
return epoch, iteration
チェックポイントの頻度を決める際の考慮事項:
- トレーニングの不安定さ: 不安定なトレーニングではより頻繁なチェックポイントが必要
- 保存にかかる時間: 大きなモデルほど保存に時間がかかる
- ディスク容量: 頻繁なチェックポイントは大量のディスク容量を消費する
- 復旧に許容できる最大損失時間: クラッシュした場合に失っても構わない最大トレーニング時間
実用的なアプローチとしては、以下のような戦略が有効です:
- 初期エポックでは頻繁に(例:10,000イテレーションごと)チェックポイントを保存
- トレーニングが安定してきたら頻度を下げる(例:50,000イテレーションごと)
- 複数の「最近の」チェックポイントと、いくつかの「マイルストーン」チェックポイントを保持
- 時間ベースの間隔(例:2時間ごと)でチェックポイントを保存することも検討
大規模モデルのチェックポイントサイズを管理するためのテクニック:
- パラメータ共有: シャーディングやその他のテクニックを使用して、チェックポイントの読み込みと保存を効率化
- 増分チェックポイント: 前回のチェックポイントからの変更のみを保存
- 低精度での保存: チェックポイントを低精度(例:float16)で保存してサイズを半減
- 分散チェックポイント: マルチGPU設定での効率的な保存と読み込み
特に大規模な言語モデルでは、チェックポイントが数百GBに達することもあり、効率的な保存と読み込みが非常に重要になります。例えば、一部の大規模モデルでは、チェックポイントの保存に1時間以上かかることもあります。
また、定期的なチェックポイントに加えて、「最良」のモデル(例えば、検証損失が最小のモデル)を別に保存することも良い習慣です。これにより、最終的な評価やデプロイメントのために最適なモデルを選択できます。
チェックポイントの保存と読み込みは、単なる安全策ではなく、効率的なモデル開発のための重要なツールです。特に長時間のトレーニングを必要とする大規模言語モデルでは、堅牢なチェックポイント戦略が不可欠です。
12.3 活性化チェックポイント(概要)
先ほど、なぜ活性化(アクティベーション)を保存する必要があるのかという質問がありました。逆伝播時には、各層の勾配計算に層の活性化値が必要になるからです。例えば、i層の勾配は、その層の活性化に依存します。
しかし、すべての活性化を保存する必要はありません。これが「活性化チェックポイント」(Activation Checkpointing)と呼ばれる技術です。この技術を使用すると、メモリ使用量を大幅に削減できますが、計算コストが増加するというトレードオフがあります。
活性化チェックポイントの基本的な考え方は非常にシンプルです:
- フォワードパス中に、すべての層の活性化を保存するのではなく、選択した「チェックポイント層」の活性化のみを保存する
- バックワードパス中に、保存されていない活性化が必要になった場合、チェックポイントから再計算する
これにより、メモリと計算時間のトレードオフが生まれます。活性化の保存に使用されるメモリが減少する一方で、保存されなかった活性化を再計算するための追加の計算時間が発生します。
PyTorchでは、torch.utils.checkpoint
モジュールを使用して活性化チェックポイントを実装できます:
import torch.utils.checkpoint as checkpoint# 通常のフォワードパスdef forward(self, x): layer1_output = self.layer1(x) layer2_output = self.layer2(layer1_output) layer3_output = self.layer3(layer2_output) return layer3_output# チェックポイントを使用したフォワードパスdef forward_with_checkpoint(self, x): # layer1の出力はチェックポイントされる layer1_output = self.layer1(x) # layer2とlayer3の計算をチェックポイント関数でラップ def custom_forward(x): layer2_output = self.layer2(x) layer3_output = self.layer3(layer2_output) return layer3_output # layer2と3の活性化は保存されず、必要なときに再計算される output = checkpoint.checkpoint(custom_forward, layer1_output) return output
活性化チェックポイントは特に深いネットワークで効果的です。例えば、24層のトランスフォーマーモデルでは、4層ごとにチェックポイントを設定することで、活性化のメモリ要件を約6分の1に削減できます。
この技術のメリットとデメリットを以下にまとめます:
メリット:
- メモリ使用量の大幅な削減(多くの場合50%以上)
- より大きなモデルやバッチサイズのトレーニングが可能
- 同じハードウェアでより深いネットワークを構築可能
デメリット:
- 追加の計算コスト(通常20-30%の計算オーバーヘッド)
- トレーニング時間の増加
- 実装の複雑さの増加
活性化チェックポイントの最適な戦略は、モデルのアーキテクチャとハードウェアの制約によって異なります。一般的なアプローチには以下があります:
- 均等間隔チェックポイント: モデルの層を均等に分割し、各区切りにチェックポイントを配置
- メモリ使用量に基づくチェックポイント: メモリ使用量の多い層に優先的にチェックポイントを配置
- 計算コストに基づくチェックポイント: 再計算コストの低い層に優先的にチェックポイントを配置
大規模言語モデルのトレーニングでは、活性化チェックポイントはほぼ必須の技術となっています。GPT-3やLLaMAなどのモデルでは、この技術がなければ利用可能なハードウェアでトレーニングすることは不可能だったでしょう。
活性化チェックポイントは、伝統的なメモリと計算のトレードオフを示す優れた例です。限られたリソースで大規模モデルをトレーニングするためには、こうしたトレードオフを慎重に考慮し、最適な戦略を選択する必要があります。
Stanford CS336: Language Modeling from Scratch | Spring 2025 | Pytorch, Resource Accounting
For more information about Stanford's online Artificial Intelligence programs visit: https://stanford.io/ai To learn more about enrolling in this course visit: https://online.stanford.edu/courses/cs336-language-modeling-scratch To follow along with the course schedule and syllabus visit: https://stanford-cs336.github.io/spring2025/ Percy Liang Associate Professor of Computer Science Director of Center for Research on Foundation Models (CRFM) Tatsunori Hashimoto Assistant Professor of Computer Science For more information about Stanford's online Artificial Intelligence programs visit: https://stanford.io/ai To learn more about enrolling in this course visit: https://online.stanford.edu/courses/cs336-language-modeling-scratch To follow along with the course schedule and syllabus visit: https://stanford-cs336.github.io/spring2025/ Percy Liang Associate Professor of Computer Science Director of Center for Research on Foundation Models (CRFM) Tatsunori Hashimoto Assistant Professor of Computer Science View the entire course playlist: https://www.youtube.com/playlist?list=PLoROMvodv4rOY23Y0BoGoBGgQ1zmU_MT_
youtu.be