※本記事は、Stanford CS336 Language Modeling from Scratch | Spring 2025 | Kernels, Tritonの講義動画(https://www.youtube.com/watch?v=E8Mju53VB00 )の内容を基に作成されています。本記事では、講義の内容を詳細に要約・解説しております。なお、本記事の内容は原講義の見解を正確に反映するよう努めていますが、要約や解釈による誤りがある可能性もありますので、正確な情報や文脈については、オリジナルの講義動画をご視聴いただくことをお勧めいたします。
登壇者について
本講義は、Stanford大学コンピュータサイエンス学部の教授陣により実施されています。
Percy Liang氏 Stanford大学コンピュータサイエンス准教授、Center for Research on Foundation Models (CRFM) ディレクター
Tatsunori Hashimoto氏 Stanford大学コンピュータサイエンス助教授
コース情報
- Stanford CS336コース全体のプレイリスト:Stanford CS336 Language Modeling from Scratch
- コーススケジュールとシラバス:https://stanford-cs336.github.io/spri...
- 受講に関する詳細情報:https://online.stanford.edu/courses/c...
- Stanfordオンライン人工知能プログラム:https://stanford.io/ai
本コースは、Stanford Online(https://online.stanford.edu/ )を通じて提供されており、Stanford School of Engineeringのポータルとして、Stanford大学全体の学術・専門教育を提供しています。Stanford Onlineは、Stanford Engineering Center for Global & Online Education (CGOE)により運営・管理されています。
1. 講義の概要と目標
1.1 Assignment 2の紹介(プロファイリング、Tritonカーネル、Flash Attention 2)
今日は、GPUのための高性能コードを書くことの詳細について説明していきます。Assignment 2の一部として、皆さんには多くのプロファイリング作業を行ってもらうことになります。また、Flash Attention 2のための独自のTritonカーネルを書く必要があり、これらすべてのコードを非常に高性能にする必要があります。
この講義では、言語モデルの標準的なコンポーネントのための高性能コードを書くことに焦点を当てて、詳細に掘り下げていきます。Assignment 1は終了していますが、まだリーダーボードが残っており、そこでの提出やアップデートは引き続き可能です。遅延日数を使用している学生もいるかもしれませんので、Assignment 1を完了させてください。
Assignment 2が既に公開されており、前述の通り、多くのシステム関連の作業を行う必要があります。現在取り組めるGPUカーネルに関わる楽しい部分があり、来週は並列処理について話し、それがAssignmentのもう半分となります。データ並列処理などの高速な並列コードの記述について、来週扱う予定です。
1.2 高性能GPUコードの書き方
この講義で皆さんに覚えておいてもらいたい高レベルな重要事項は、高性能コードを書きたいなら、コードをベンチマークしプロファイリングすることを忘れてはいけないということです。これは非常に明白に聞こえるかもしれませんが、学生や人々が「これがボトルネックだと思うから、3時間かけて最適化しよう」と考える場面を多く見てきました。しかし結果的に、それは全くボトルネックではなかったということがよくあります。楽しい作業だったかもしれませんが、時間の配分が間違っていたのです。
実際に高性能な詳細プロファイラを使用すれば、ボトルネックが正確にどこにあるか、マシンが実際に何をしているのかを正確に見ることができます。それが分かれば、コードの実行の最も重要な部分に労力を集中できます。これが私が伝えたい高レベルな考え方です。なぜなら、GPU実行に関する詳細や、softmaxカーネルの書き方などの具体的な内容は変わる可能性があり、torch compileの自動JIT機能に頼りたいと思うかもしれないからです。
しかし、プロファイリングをすべきだという事実は、ツールが何であれ、本当に変わることはないでしょう。ですから、高性能コードを書く際には常にプロファイリングをすべきであるという考えを内在化してほしいのです。実際のところ、理論には限界があります。システムはこのコースの一部であり、かなりよく推論できる部分もあります。アーキテクチャについてはある程度推論が困難で、ルーフラインモデルなどについて考えることはできますが、行列乗算がどれくらい速いかということになると、それはライブラリのバージョンやハードウェア、どの部分がどのような理由でボトルネックになっているかなど、あらゆる種類のマイクロコードの事柄があり、完全には分からないのです。
そのため、最終的にはこれらのものを開発する際には、エンドツーエンドのベンチマークを行わなければなりません。
1.3 講義の構成(GPU復習、ベンチマーク・プロファイリング、カーネル実装、比較検証)
この講義の計画として、まずGPUについての簡単な復習を行います。これは、講義の残りの部分を理解するために必要な、GPUの基本コンポーネントについて確実に理解してもらうためです。
次に、ベンチマークとプロファイリングに関する非常に基本的な事柄を紹介します。これらはAssignmentにとって有用であり、一般的に高性能PyTorchや深層学習コードを書きたい場合にも役立ちます。
その後、実際にカーネルを書いていきます。C++でCUDAカーネルを書き、次に同じことをTritonで行います。最後に、PyTorchの既存のJITコンパイラを使用して自動最適化させるという簡単ですが非常に良い方法を試し、これらすべてを比較してプロファイリングとベンチマークを行います。
全体を通して、本当に深く掘り下げていきます。PTXまで、つまりマシンコードに非常に近いレベルまで下がって、私たちがこのコードを書く際に、GPUが実際にフードの下で何をしているのかを理解していきます。時間があれば、そして時間はあると思いますが、最後にsoftmaxの高速なTriton実装を書くことで締めくくります。
2. GPU実行モデルの復習
2.1 GPU アーキテクチャ(A100、H100のSM、メモリ階層)
GPUがどのように動作するかを思い出してみましょう。A100やH100のようなGPUを使用する場合、多数のSM(Streaming Multiprocessor)を持つことになります。各SM内には、計算を行うことができる多数のユニットがあります。INT32のものやFP32のものがあり、各SMは大量のスレッドを起動することになります。
そして、メモリ階層があります。これは、DRAMまたはグローバルメモリがあり、これは大きいが遅いものです。そして、はるかに高速なキャッシュがあります。実際、ここにレジスタファイルと呼ばれるものが見えますが、これは各スレッドがアクセスできる非常に高速なメモリです。今日、GPU用の高性能コードを書く際には、これらのレジスタを多用することになります。
基本的な実行モデルの構造として、スレッドブロックのコレクションがあり、ブロックは単一のSM上でスケジュールされます。これは、特にTritonのようなもので代码を書く際に考える原子単位のようなものです。各ブロック内には多数のスレッドがあり、スレッドが実際に計算を行います。
ベクトルがあり、そのベクトルの要素に対して操作を行う場合、各スレッドが入ってベクトルのいくつかの要素に対して一度に操作するようなコードを書くことになります。すべてのスレッドが一緒になって、ベクトルを完全に処理することになります。
2.2 スレッドブロックとスレッドの実行モデル
なぜスレッドブロックというものがあるのでしょうか。なぜ単にスレッドと大きなグローバルコンテキストだけではないのでしょうか。スレッドブロックは互いに通信することができます。SM内にはかなり高速な共有メモリがあります。行列乗算のような処理を行う必要がある場合、スレッドからスレッドへ情報を渡す必要があります。
スレッドブロック内では、これは非常に高速です。スレッドブロック間、つまりこれらのグループ間では、非常に高コストになります。必要なデータはすべて、同じスレッドブロック内、または同じグループ内に保持したいと思うでしょう。これにより、すべてが非常に高速に保たれます。これはL1キャッシュと同程度に高速になります。
これは素晴らしい場所にいることになります。これを使用してスレッド間で同期することができます。しかし、例えばブロック間で同期することはできませんし、何が起こるかを実際に制御することもできません。
先週お話しした重要な概念として、wavesと呼ばれるものがあります。wavesは通常考える固有のものではありませんが、性能にとっては重要なコンポーネントです。これらを実際に実行する際、スレッドは32スレッドの連続ブロックにグループ化され、これがwaveと呼ばれ、SM内で一度に実行されます。
私たちがやりたいことの一つは、すべてのwaveが等量の計算を持つことを確実にすることです。常にそれができるわけではありませんが、できる場合はそうしたいものです。理想的には、スレッドブロックの数がSMの数で割り切れるようにし、各waveが等量の作業を持つことを確実にしたいのです。
理想的には、SMよりもはるかに多くのスレッドブロックを持ち、高性能コードを書く際にはそれが実現するようにしようとします。
2.3 Waveの概念と性能への影響
先週お話しした重要な概念として、wavesと呼ばれるものがあります。wavesは通常考える固有のものではありませんが、性能にとっては重要なコンポーネントです。これらを実際に実行する際、スレッドは32スレッドの連続ブロックにグループ化され、これがwaveと呼ばれ、SM内で一度に実行されます。
講義中にwarpについて質問があり、warpは本質的に一緒に実行されるスレッドのグループです。warpが存在する理由は、必要な制御機構の量を減らすためです。これらすべてのスレッドを同時に実行するため、各スレッドに制御機構を持つ必要がありません。32個のブロックに対して制御機構があればよいのです。例えば、計算ユニットの数がwarpスケジューラーの数よりもはるかに多いことが分かります。
これにより、制御について心配することなく、はるかに多くの並列作業を行うことができます。これはCPUとのトレードオフの一つです。CPUは、制御や分岐予測などにより多くのシリコン面積を割いています。一方、GPUは、より単純な制御で計算により多くの重点を置いています。
私たちがやりたいことの一つは、すべてのwaveが等量の計算を持つことを確実にすることです。常にそれができるわけではありませんが、できる場合はそうしたいものです。理想的には、スレッドブロックの数がSMの数で割り切れるようにし、各waveが等量の作業を持つことを確実にしたいのです。理想的には、SMよりもはるかに多くのスレッドブロックを持ち、高性能コードを書く際にはそれが実現するようにしようとします。
2.4 算術強度(Arithmetic Intensity)の重要性
最後の概念で、おそらく最も重要な概念の中の一つが算術強度(Arithmetic Intensity)です。私たちは算術強度を高く保ちたいと考えています。メモリ移動のバイト数よりも多くのFLOPSを持ちたいのです。
これは、前回の講義のスケーリングプロットを覚えているなら、計算のスケーリングがメモリスケーリングよりもはるかに速いからです。そのため、多くの場合、計算はメモリバウンドになってしまい、実際にはすべての作業を完了させることができません。
一般的なルールとして、巧妙に行えば行列乗算は計算バウンドになります。他のすべてはメモリバウンドになりがちで、メモリバウンドである物事の量を巧妙に削減しようとするか、どれだけひどくメモリバウンドであるかを軽減しようとします。
3. ベンチマークとプロファイリングの基礎
3.1 高性能コード開発における測定の重要性
高性能コードを書きたいなら、コードをベンチマークしプロファイリングすることを忘れてはいけません。これは非常に明白に聞こえるかもしれませんが、学生や人々が「これがボトルネックだと思うから、3時間かけて最適化しよう」と考える場面を多く見てきました。しかし結果的に、それは全くボトルネックではなかったということがよくあります。楽しい作業だったかもしれませんが、時間の配分が間違っていたのです。
実際に高性能な詳細プロファイラを使用すれば、ボトルネックが正確にどこにあるか、マシンが実際に何をしているのかを正確に見ることができます。それが分かれば、コードの実行の最も重要な部分に労力を集中できます。これが私が伝えたい高レベルな考え方です。なぜなら、GPU実行に関する詳細や、softmaxカーネルの書き方などの具体的な内容は変わる可能性があり、torch compileの自動JIT機能に頼りたいと思うかもしれないからです。
しかし、プロファイリングをすべきだという事実は、ツールが何であれ、本当に変わることはないでしょう。システムはこのコースの一部であり、かなりよく推論できる部分もあります。アーキテクチャについてはある程度推論が困難で、ルーフラインモデルなどについて考えることはできますが、行列乗算がどれくらい速いかということになると、それはライブラリのバージョンやハードウェア、どの部分がどのような理由でボトルネックになっているかなど、あらゆる種類のマイクロコードの事柄があり、完全には分からないのです。そのため、最終的にはこれらのものを開発する際には、エンドツーエンドのベンチマークを行わなければなりません。
3.2 ベンチマーク関数の実装(ウォームアップ、CUDA同期の必要性)
今から、実例となる計算を使います。皆さんがAssignment 1で行っているすべてのことと比べると、これは最もシンプルなものです。非常にシンプルなMLPを実行します。128次元、16層、あるバッチサイズで、5つのステップを実行します。5つの異なるステップについて、単純にフォワードとバックワードを行います。
この講義を通して、このbenchmark関数を使用します。これはラッパー関数です。benchmarkは次のことを行います。ベンチマークしたい関数runがあり、いくつかのウォームアップ反復を行い、その後いくつかのトライアルを実行します。
このウォームアップとは何でしょうか。非常に重要なことの一つは、PyTorchコードを初めて実行し、それがGPUに何かを送る場合、非常に高速で透明に見えるかもしれませんが、何かが初めて実行される際、バックグラウンドでマシンコードがコンパイルされているということです。そのコード命令がGPUに送られている可能性があります。コードを初期化するためのあらゆる種類のことが起こっています。
そのため、起動速度ではなく、定常状態の速度を測定したいので、常にいくつかのウォームアップ反復を行いたいのです。何千、何千もの反復を実行する場合、興味があるのはその部分であり、CUDAコードのオンザフライコンパイルがどれくらい速いかではありません。そのためウォームアップがあり、常に少しのウォームアップを持つべきです。
もう一つの非常に重要なことは、torch.cuda.synchronizeというものを呼び出すことです。これは何でしょうか。GPUとCPUは基本的にコンピューター内の2つの独立した計算ユニットです。基本的に独立して実行できます。
実行モデルは、ここにあるPythonコードがCPUに存在するということです。何かを実行する際、多くのCUDAカーネルをGPUに送ります。「これらを実行してください」と言うのです。GPUはそれらを実行しに行きます。CPUは実際に続けて実行し、これらのCUDA実行が停止するのを待ちません。
これは高性能コードを書くには素晴らしいことですが、ベンチマークを行いたい場合の即座の問題が見えるでしょう。ベンチマークを行い、GPUが横で実行し、CPUが何か違うことをしているモデルがある場合、実際にはGPU実行時間を測定していないのです。
torch.cuda.synchronizeは基本的に、GPUとCPUが同じ状態にあることを確認し、実行中のキューされたものがなく、実行されているコードの観点で同じポイントにいることを確認します。このsleep例では、50ミリ秒間スリープしようとしているので、最終的に得られる時間がそれになります。私は3回時間を測定し、もちろんここでも、GPUとCPUの状態が同じであることを確認するため、runの最後でtorch.cuda.synchronizeを呼び出しています。CPUが先行して実行している場合、GPU実行が実際に終了するのを待ちます。
これで私は終了し、複数の測定を平均化します。なぜなら、各単一の測定はGPUの熱特性などのために変動する可能性があるからです。複数の複製を取り、平均を取ってそれを返します。これが私たちのベンチマークコードです。非常にシンプルですが、ここで重要な2つの部分を覚えておいてください。常にウォームアップを行うこと、CUDA synchronizeを呼び出すことを確実にすることです。これらを行えば、非常にシンプルです。これらを忘れると、大きな行列乗算が瞬時に終了したというような、明らかに真実ではないかなりクレイジーな数値を得ることになります。
3.3 行列乗算のスケーリング実験結果
行列乗算のベンチマークを行ってみましょう。これらのいくつかを見ていきます。これらは私たちが既に知っていることに数値を与えるだけですが、これを見ていき、同じページにいることを確認したいと思います。
クラスのH100 GPUでこれを実行しました。これらのサイズで行列乗算を行います。各次元についてこの種のベンチマーク結果を踏まえて、多くの行列乗算のタイミングを収集しました。ご覧のように、期待通り、行列サイズを増加させるにつれて、実行時間の超線形スケーリングが見られます。
もちろん、1024や2048のような最小サイズでは、実際には時間が全く増加しないことが分かります。これは、これらの行列乗算を行うだけで定数因子のオーバーヘッドがあるためです。これらの数値をCPUからGPUに運ぶ必要があり、カーネルを起動することにオーバーヘッドがあります。そのため、ゼロまでずっと超線形であるということはありません。
しかし、行列が十分に大きくなると、行列乗算で期待する種類のスケーリングが正確に見られます。
3.4 MLPベンチマークの線形スケーリング検証
次に、MLPをベンチマークしてみましょう。何をするかというと、MLPを大きくします。256次元、4層、バッチサイズ256、2ステップを取ります。これを行うのにかかる時間はどうでしょうか。それは6.2秒かかります。
ここで基本的なことができます。ステップ数を2から5まで拡張し、それらすべてをベンチマークできます。2、3、4、そして5ステップを得ます。行列乗算の場合とは異なり、MLPでのフォワードパスとバックワードパスの数であるステップ数を拡張する場合、実行時間はどのように振る舞うことを期待するでしょうか。線形スケーリングを期待し、それが私たちが見ているものです。
MLP実行ごとに約5秒あり、エンドツーエンドオブジェクトの実行時間について、約n倍5秒が見られます。
層数を2、3、4から5まで拡張することもでき、これも実行時間の増加をもたらします。再び層数に対して線形で、今回は1層が約5秒よりも少し少なくかかり、約4倍の層数倍を得て、線形スケーリングが再び現れます。当然のことながら、ステップと層の両方は明らかに実行時間と線形関係を持ち、それが最終的にここで見られるものです。
バッチサイズについてはスキップします。追跡されているものの量が少し扱いにくくなっているからです。
4. PyTorchプロファイラによる詳細分析
4.1 基本的なプロファイリング手法
これでベンチマークビットの終わりです。このようなnice関数を作ることができ、少しウォームアップを行い、CUDA synchronizeを行い、何でも測定したいもののランタイムを測定できます。これは良いことで、コードで常にこれを行うべきです。新しい派手なアーキテクチャの実行にどれくらい時間がかかるかを測定できます。
しかし、いくつかの問題を修正したい場合、ベンチマークは非常に粗い粒度のツールです。コードが遅いことは教えてくれますが、時間がどこに費やされているかは教えてくれません。そこで私たちがやりたいことは、代わりにプロファイリングを行うことです。これははるかに細かい粒度のオブジェクトとなります。
プロファイリングは本当に素晴らしいものです。なぜなら、時間がどこに費やされているか、どの関数に費やされているかを見るのに役立つだけでなく、何を呼び出しているかを見る際、通常はPyTorchインターフェースと相互作用しますが、PyTorchの下には呼び出されているCUDA関連のものの全宇宙があります。プロファイラを実行すると、実際に低レベルの呼び出しまで、実際に何が呼び出されているかをすべて見ることができます。
これにより、プログラムが実際にハードウェア上でどのように実行されているかについて、はるかに良い直感を得ることができます。いくつかの簡単な関数をプロファイリングして、何が起こっているかについてのいくつかの直感を得ていきます。
素晴らしいことの一つは、基本的なプロファイリングが必要な場合、PyTorchには使用できる非常に優れた内蔵プロファイラがあることです。これにより、Python PyTorchの世界を離れることなく、かなり合理的に見える出力を得ることができます。
4.2 加算操作のプロファイル分析(A10インターフェース、カーネル起動コスト)
sleepの例からプロファイリングに移りましょう。いくつかの関数をプロファイリングし、ここでその出力も見ることができます。
sleepの例を取りました。ここにsleep関数があり、sleep関数をプロファイリングする際、プロファイル関数は次のようになります。再びウォームアップがあり、torch.cuda.synchronizeがあり、プロファイラを呼び出してCPUとGPUの時間を両方追跡しています。何かを実行し、再び同期し、すべての時間にわたって平均テーブルを印刷します。
今、sleep関数をプロファイリングします。何が起こっているか見てみると、100%の時間がCUDA device synchronizeと呼ばれるものに費やされています。これは、GPU作業が行われていないためです。これは単なるno-opです。プロファイリングするには少しばかげたものです。
今度は自明でないものを見てみましょう。2つの行列を加算するこの基本操作を見てみましょう。AとBを取ってそれらを足し合わせるadd関数を定義しました。これは2つのランダムガウス行列をインスタンス化し、operationの引数にあるものを呼び出すヘルパー関数です。これは2つの2048サイズの行列を加算しています。
今、これをプロファイリングし、プロファイラを呼び出すと、このブロックのように見えるものが返されます。
PythonでaddA関数を呼び出す際、これはa plus bだけで、それが私たちが相互作用するすべてです。しかし実際に、いわば氷山の下では、はるかに多くのことが起こっています。これはGPUに送られ、最初にA10と呼ばれるものがあり、これはPyTorchのCインターフェースです。このラッパーが呼び出され、「いくつかの数値を加算します」と言います。これが呼び出されている外側のラッパーです。
次に、それは特定のカーネルに送られます。vectorized_elementwise_kernel_for_native_CUDA_functor_addと呼ばれるものです。これが実際に加算を行っているものです。また、CUDA launch kernelと呼ばれる他のものもあり、これも時間を取っています。これは実際に、CPUがコマンドを取ってGPUに送信している、カーネル起動です。それには時間がかかります。
最後に、CUDA device synchronizeがあります。GPUが終了して私たちに物を送り返すのを待っています。それも時間がかかります。同期バリアを持つという単なる行為が、私たちに時間を要することになります。基本的に、最終的にここで総時間は、CPU上で1.44ミリ秒、CUDA上で17マイクロ秒です。
それらは本当にGPU上で高速で、CPU上では遅いです。費やされているCPU時間を見ると、これはself CPU timeですが、C++インターフェースまたはCインターフェースが実際に多くのCPU時間を消費しているものであることが分かります。GPUに物を送る何かを行う際には、オーバーヘッドがあります。
4.3 行列乗算の分析(CuBLAS、XMMA GEMMカーネルの使い分け)
同じ話で、行列乗算を行いたい場合について見てみましょう。AにBを掛けるので、これはAとBの行列乗算です。再び2048の行列を使用しています。
プロファイリングを行うと、今度はa10_mmと見ることができます。これは行列乗算を行うための低レベルインターフェースです。これはCuBLASに送られ、CuBLASはNvidiaの高性能行列乗算CUDAライブラリです。そして、これは非常に特定のCuBLASカーネルに送られ、何らかのタイルサイズを持ちます。
ここでは名前が切り捨てられています。すぐにより詳細なバージョンをお見せします。これは基本的に、何らかのタイルサイズやブロック数などの非常に特定のセットを指しています。このものはパラメーター化されています。それが実際に行列乗算を行っているものです。再び、下部に同じ2つのもの、カーネル起動とCUDAデバイスの同期が見られます。
CPU時間とCUDA時間の分割を再び確認できます。行列乗算は実際により多くの時間がかかるため、CUDAでははるかに多くの時間を費やしており、ベクトルを2つ加算するよりも時間がかかります。
別の行列乗算の例があります。これは異なる次元なので、128次元の行列を掛け合わせています。128 × 128で、はるかに小さいです。実際に今度は、この異なるコマンドを直接実行していることが分かります。XMMA GEMMを実行しています。
GEMMは行列乗算のタイプで、これはfloat 32 float 32です。このカーネルの命名から何が実際に起こっているかを何となく見ることができ、これは何らかの種類のタイル化行列乗算で、CuBLASを通らずに、この特定のコマンドを直接実行しています。小さな行列乗算に対しては、異なるカーネルに送っていることが分かります。
この高レベル抽象化で動作している際、行列乗算を単一のものとして考えます。a @ bを呼び出して終わりです。しかし、フードの下では、持っている次元に応じて、持っているハードウェアに応じて、実際に非常に異なる行列乗算プリミティブにフードの下で送られます。
これは実際に、非常に異なる性能特性として現れます。面白いtipの一つは、後で話すtorch compileには実際に、ハードウェア上で行列乗算性能をマイクロベンチマークするオプションがあり、その後、モデル用の最高性能の行列乗算サブルーチンを実際に選択することです。過去に、これらのことを最適化することで実際に10%程度の無料の速度向上が得られることを発見しました。
このようなことを最適化することで、実際に現実世界で無料の利得が得られるのは非常にクールです。
4.4 複合操作の分析(C距離計算、GLU、Softmax)
これまでの操作は、ある意味で非常に退屈でした。行列乗算や加算のように、基本的には一対一です。CPU側で一つの操作があり、それがGPU操作に変換され、単に送られるだけです。これらすべてで、GPU上で何かを行う単一の操作だけがあります。
より複雑な操作、より複合的な動作を持つ2つの操作を見てみたいと思います。今やりたいことは、torch.cdistと呼ばれるこの操作を見ることです。これは2つの行列のセットについて、2つのベクトルのセット間のペアワイズユークリッド距離を計算します。これは、AとBの間の大きな距離行列計算になります。
これは明らかにはるかに複雑な操作です。ユークリッド距離を計算したい場合、内積を計算し、平方根を計算する必要があり、cdistのプロファイル出力がより複雑になることが分かります。
このtorch Pythonコマンドは、Cインターフェースで何らかの低レベルcdistにマップされることが分かります。これはa10_cdistで、それはa10_euclidean_distにマップされます。そして、これは多くのものに分解されます。a10_mm、a10_pow、そしてsumのように、ユークリッド距離を実際に計算するために必要なすべてのプリミティブだからです。
これらの各々について、行列乗算や連結、べき乗の取得のように、ここで呼び出されている対応するCUDAコマンドがあります。gmmがあり、これは行列乗算で、私たちが慣れ親しんだものです。これはGPU時間の78%の計算時間を取っています。配列のコピーと連結があり、これは実行時間の6%を取り、べき乗を取るこのvectorized elementwise kernelは、GPU時間の5%を取り、3%がsumに行きます。
今、GPU上でどこに時間が費やされているかの非常に良い低レベル内訳が得られます。これから、最適化に時間を費やすべき場所についてのいくつかの感覚を得ることができます。行列乗算を最適化できると思うなら、それは素晴らしいでしょう。なぜなら、それは70%以上の時間が費やされている場所だからです。
最後の例、最後の2つの例として話したいのは、GLUとsoftmaxです。これらは講義を通して私たちの実例となります。GLUは非線形性になります。覚えているなら、それはGaussian error unit、Gaussian error linear unitです。それはtanhと指数の積になると思います。
すべての種類の操作があります。AとBを加算し、次にgeluを呼び出して、MLPで持つかもしれない線形プラス非線形構造をシミュレートします。再び、基本的に同じ種類のマッピングが見られます。a plus bに対応するa10_addがあり、CUDAの同等物があります。実際に、CUDAで実装されたGELU関数があり、ここの下の方にあり、計算の約33%を取っています。かなり合理的です。
最後にsoftmaxがあります。すべてをいちいち詳細に説明するつもりはありません。しばらくするとすべて同じに見え始めるからですが、指摘したい本当にクールなことは、softmaxやGELUのような多くの本当にコアなプリミティブに対して、それらのためのカーネルが書かれているということです。
GPUが基本的なプリミティブを実行しているのではありません。これらすべてを計算する融合演算子があります。そのため、これらすべてについてCPUとGPU間で行ったり来たりすることはありません。
5. NVIDIA Nsight Systemsによる高度なプロファイリング
5.1 CPU-GPU間の非同期実行モデルの可視化
CPUが何をしていたかという質問に答えるために、より洗練されたものについて話しましょう。ベンチマーク用に最初に始めたMLPの例を取り上げ、そのMLPを最適化したい、本当に高速に実行させたいとしましょう。理想的には、これを良い細かい粒度でプロファイリングしたいでしょう。
torch profilerを使用すると、このようなものが得られます。MLPを覚えているなら、積み重ねられた線形層があります。フォワードとバックワードがあります。backwardが起こっている、行列乗算がある、linearがある、そして、バックワード用のaccumulate grad操作があることが大体分かります。
ここに行列乗算カーネルがあります。10個しかここに収まらないので、これは特定の時点で切り取られると思います。これは良いものです。時間の大部分が行列乗算に費やされていることを教えてくれます。しかし、残りの時間がどこに行くのか、なぜここで31%だけが私の時間にとどまり、60%はどこにあるのか疑問に思います。これはA10_mmですが、対応するカーネルがありません。
これは少し神秘的で、非常に複雑なモジュールに対しては、これは非常に良い視覚化ではありません。そのため、実際の大人向けプロファイラを取り出す必要があると思います。NvidiaのNsight Systemsという、GPU動作とパフォーマンスを見るNvidiaの詳細な方法を使用する必要があります。実際にこのMLPを実行している際に正確に何が起こっているかを見ることができます。
ここで複数の異なるものが見えます。ここにCUDA HWがあり、次にthreadsがあります。この上半分、このCUDA部分は、GPUが何をしているかです。このthreads部分では、CPUが何をしているかを見ることができます。
プロファイルした際に、コードにいくつかのアノテーションを追加しました。NVTXというもので基本的にコードにマーカーでアノテーションを付けます。プロファイラがここに来た時、このコードの部分がdefine modelと呼ばれるブロックに属することを知るでしょう。
range pushとrange popを言うこの範囲は、77行目から55行目まで、step_stepと呼ばれるもので注釈されるべきです。コードにこれらすべてのアノテーションを追加してからプロファイラを呼び出しました。
nvtxと言うこの行に戻ると、define modelが見えます。これは私がモデル構築呼び出しを包んだものです。そして、step zero、step one、step two、step three、step four、step fiveが見えます。各ステップが今このプロファイラで素晴らしく注釈されており、モデルが進行するにつれてモデルが行っているすべてのことを見ることができます。
この側から始めます。見ることの一つは、このコードの部分、これは実際には多くの作業を行いません。これは14秒しかかかりません。実際に、プロファイラの時間の大部分はオーバーヘッドに費やされています。大体ここまでの部分は、ライブラリを読み込むだけのようなものであり、それには長い時間がかかります。これは明らかに7.5秒かかります。すべてを初期化するだけで、その後、プログラムの7.5秒頃にGPU上で実際にモデルの構築を開始し、メモリフットプリントでここを見ると、これがメモリが割り当てられている場所で、GPUメモリでメモリ使用量が増加し始めます。
この時点でモデルが構築され、step zeroが実際にアクションが始まる場所です。
5.2 プリント文によるCPU-GPU同期への影響の実験
先ほど、CPU-GPU間で何が起こっているかについて質問がありました。この実行モデルの動作方法を説明しましょう。ここはCPU上のstep zeroで、ちょうどここから始まり、ここがforward passで、これがlayer zeroです。何が起こっているかを考えてみましょう。
前に言ったように、PyTorchでコードの一部に初めて遭遇する際、直接実行されるのではありません。実際に、オンザフライでのコンパイルなどを行います。runtime triggered module loadingは、layerと計算を初期化し、様々なコードをGPUに移動するために行われるオーバーヘッド作業のようなものです。これには長い時間がかかります。
layer zeroが完了した後、ここの任意のスライスを見ると、これらの各layerが本当に本当に素早いことが分かります。ここで何が起こっているかというと、CPU側でlayer oneをハイライトすると、それがGPU側のlayer 1がある場所ではないことに注意してください。前に言ったように、CPUとGPUは2つの異なる実行デバイスなので、layer zeroから始めて、layer zeroが完了し、layer oneを開始します。
今、CPUは実際に既にすべてのCUDAコマンド、CUDAカーネルをGPUに送信しています。この時点で、CPUがGPUにキューイングコマンドを送っているのです。「次にこれを実行しろ。次にこれを実行しろ。次にこれを実行しろ。」と言っているのです。
CPUはGPUよりもはるかに先行して実行しています。layer 1がGPU上で実行を開始する頃には、実際に、既にCPU上ではlayer 9にいるのです。CPUははるかに先行して実行しており、CPUが維持するキューが基本的にあり、固定数のCUDAカーネルをGPUに送信しています。そのキューの深度に達すると、先行して実行するのを停止します。しかし、その時点まで、できる限り先に先に進み続けます。
この場合、これは少し極端になります。再びズームアウトすると、これらのステップで、はるかに先行して実行していることに注意してください。step zeroがここ、step twoがここです。これがstep oneで、基本的に全く時間がかかりませんでした。CPUは基本的にGPUよりも完全に1つのstep forward and backwardを先行して実行している。
興味深いことの一つは、言語モデルを訓練するための様々なコードを書いている場合、反復間でlossをプリントするような通常のことを行うかもしれません。これはGPUが行っていることに影響を与えないはずに見えます。print文です。どれほど影響を与えるでしょうか?
しかし、少し考えてみると、これはGPU上の実行レイアウトに大きな影響を与えるでしょう。なぜなら、このprint文をプリントするために、このprint文はCPU上で起こり、CPUはlossを取得する必要があります。つまり、GPUがそのlossを計算するのを待つ必要があります。
何が起こるか見てみましょう。ここで、言ったように、CPU上のstep fourはGPUの同等物よりもはるかに前に起こります。今、切り替えてみましょう。今度は、print文があるバージョンをプロファイルしたものです。今、この選択範囲にズームインします。
今、step oneとstep twoが基本的に同期されていることが分かります。lossが計算されるのを待つ必要があるからです。これを見ると、「まだ少しオフセットがありますね。step two、step oneが完全に互いに整列していない」と言うでしょう。今、ズームバックインして、CPU上のstep oneに何が起こったか見てみましょう。
基本的に、CPU上のstep oneの終点も、optimizer stepが始まる場所でもあります。forward が完了する頃には、このCUDA stream synchronizeというものがあります。これは基本的に、CPU上で「GPUを待っているだけです。先行して実行できません。このlossが計算されて私に送り返されるのを待っています」と言っています。
これは、GPUを待つCPUの待機、待機、待機、待機、待機、待機のようなダミー操作です。backward stepが完了します。今、lossをプリントできます。lossをプリントしました。OK、今CPUは先行して実行を開始できます。先行して実行し、step twoのものを送信し始めます。これがここに達すると、コマンドが尽きます。再びlossを待っています。CUDA synchronize。待機、待機、待機、待機、待機。backward stepが完了しました。
今、lossをプリントできます。再び先行して実行します。
この場合、GPUは両方のケースで基本的にまだ完全な利用率です。しかし、常に大量のものをプリントしているような極端なケースでは、実際にCPUボトルネックを導入することになります。GPUがCPUを待ち続けなければならず、カーネルを事前に起動できないからです。
これは、プロファイラで見ることができる本当にクールなことの一種です。このCPU対GPUで、実際には互いに通信する異なるデバイスです。単一の統一されたオブジェクトではなく、これらのより高度なプロファイラを見始めなければ見ることができないでしょう。
5.3 メモリ使用量とカーネル実行タイミングの詳細分析
プロファイラで示したい他のことは、前に遊んでいたプロファイラのようなものと同様のビューを生成することもできるということです。測定したいもののいくつかの範囲を選択してみましょう。ウォームアップがあると言ったので、最初のいくつかのステップを除外すべきです。
step threeから始めて、いくつかのステップを測定します。この範囲でカーネルを取ることができます。これが計算を行っているものです。実際に多くの異なる種類の行列乗算があることが分かります。これは一つの行列乗算カーネルです。これは異なる行列乗算カーネルです。異なる種類のvectorized element kernelがあります。
これらすべてが異なる量の計算を取っています。これを取って、events viewで起こっているすべてのものを見せてもらうことができます。stats viewですべての時間がかかっているものも見ることができます。待ってください、平均時間が欲しいです。いや、CUDAカーネル実行サマリーが欲しいです。
カーネルの総持続時間を見たいので、どのカーネルが最も時間を取っているかを見て、これらのビュー間で集約できます。これは実際に非常に強力なツールで、何が遅くて何が速いかの集約ビューと、個々のカーネルが起動される時期、CPUコマンドがどこから来たかの両方を提供できます。
最後の脇道として、これがPythonでプログラミングしていても関係ない理由の一つです。Pythonはあまり高性能な言語ではありませんが、CPUが先行して実行でき、GPUにコマンドをキューイングできるため、CPUがボトルネックになることは決してありません。
このGPUとCPU間の切り離し、または切断の側面は、この素晴らしい高レベルプログラミング言語を使用できるのに、GPUから完全な利用率を得ることができる主要な理由の一つです。
5.4 アノテーション機能による実行フローの追跡
プロファイルした際に、コードにいくつかのアノテーションを追加しました。NVTXというもので基本的にコードにマーカーでアノテーションを付けます。プロファイラがここに来た時、このコードの部分がdefine modelと呼ばれるブロックに属することを知るでしょう。
例えば、range pushとrange popを言うこの範囲は、77行目から55行目まで、step_stepと呼ばれるもので注釈されるべきです。コードにこれらすべてのアノテーションを追加してからプロファイラを呼び出しました。
nvtxと言うこの行に戻ると、define modelが見えます。これは私がモデル構築呼び出しを包んだものです。そして、step zero、step one、step two、step three、step four、step fiveが見えます。各ステップが今このプロファイラで素晴らしく注釈されており、モデルが進行するにつれてモデルが行っているすべてのことを見ることができます。
これらのアノテーション機能により、実行の複雑な流れを構造化して理解することができます。特に、大規模なニューラルネットワークの訓練のような複雑な処理では、どのコード部分がどの時点で実行されているかを正確に把握することが、パフォーマンス最適化において非常に重要になります。
6. CUDAカーネルの実装
6.1 カーネル融合の概念と重要性
カーネル融合を思い出してください。これは講義で示した画像で、小さな工場があります。操作を行う必要がある度に、倉庫から工場へ、そして戻すために配送する必要があります。考えなしに一連の操作を素朴に順次行うと、倉庫からの多くの配送コストを支払うことになります。
私がすべきことは、すべての操作を一度に行う一つの工場を持つことです。そうすれば、このコストを複数回支払うことはありません。これは非常に重要です。
今、GLUを実行します。GLU用のカーネルを書き、いくつかの異なる方法でそのカーネルを書いて、それを行うことの性能への影響を見ていきます。GLUのPyTorch実装があり、それは次のようになります。torch.nn.functional.gelu。私が次に行う素朴なことと正確に一致させたいので、approximate equals tanhを呼び出します。
これは実際にガウシアンのCDFを掛けるのではなく、計算しやすいその近似になります。これがPyTorch GLUです。
今、愚かなことをするつもりです。このコードを見て、これは低性能になると言うでしょう。PyTorchで GLU を 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³))) として書くつもりです。魔法の公式ですが、これはGELUの良い近似です。これを調べるか、これが真実であることを自分で確信できます。
しかし、これを行うと、多くの操作が起こることが分かります。tanhがあり、x cubedがあり、定数による乗算があり、加算があり、0.5とxによる乗算があります。これが複数の異なるCUDAカーネルを含む場合、これはおそらく遅くなるでしょう。これが融合からの私たちの直感であるべきです。
6.2 GLU演算の手動実装 vs PyTorch実装の性能比較(8.1ms vs 1.1ms)
これが真実かどうか見てみましょう。これら二つは同じで、左上で同じ数値を計算することが分かり、ランダムガウシアンでこれを体系的にチェックできます。
今、この二つをベンチマークしてみましょう。manual timeは、本当に大きなGLUに対して8.1ミリ秒です。PyTorch timeは1.1ミリ秒です。融合されたバージョンは大幅に高速になります。実際、8倍高速です。
すごいですね。シンプルなカーネルを書くことによる大きな違いです。もちろん、おそらく行列乗算がまだボトルネックでしょうが、その8ミリ秒から1ミリ秒にできたら本当にクールでしょう。それは非常に満足のいくことでしょう。次の講義の数部分で、その1.1ミリ秒に近づこうとします。
今、フードの下で何が起こっているかを見てみましょう。NSISを見る必要はありません。なぜなら、manual GLUに対して知りたいのは非常に高レベルなことだけだからです。言ったように、多くの操作を行います。多くの乗算を行います。それはベクトル化されていますが、ここで起動されている多くのCUDAカーネルの束です。
右側を見ると、このCUDAカーネルが3回呼び出されていることに注意してください。ここにはたくさんの乗算が浮いているからです。加算もあります。tanhもあります。これらのそれぞれがおそらく遅く、最終的にこれを行うためにかなり大きなオーバーヘッドを被っています。
今、PyTorchで同じことをしてみましょう。これは本当に素晴らしいです。単一のCUDAカーネル起動があります。一度発生し、全体を処理するだけです。これが見たいものです。もちろん、これは単一のCUDAカーネルなので非常に高速で、私たちが望むものです。
6.3 CUDAカーネルの基本構造(グリッド、ブロック、スレッド)
これは本当に素晴らしく、何らかの方法でCUDAカーネルに到達したいと思います。最初に考えるかもしれないことは、どれくらいGPU効率的なコードを書くことについて知っているかに応じて、PyTorchの人々は可能な限り最低レベルの言語でこれを書いたに違いないということです。同じことをするつもりです。可能な限り最低レベルではありませんが、C++ APIに行き、C++でCUDAカーネルを書くつもりです。
CUDAカーネルを開いて、独自のものを書いてみましょう。どのように動作するのでしょうか。
C++版の全体を作成しました。私たちがCUDAと言う際、CUDAは実際にGPUをインターフェースし、プログラミングするためのC++ APIです。説明したGPUの論理モデルと同じように、何らかの関数fを書きます。このCUDAカーネルを呼び出す際、ベクトルや行列のすべての要素でFを自動的に呼び出します。そして、欲しいものすべてを並列計算できます。
命名法として、グリッドを持ちます。これはスレッドブロックのコレクションです。これをタスクがあると考えてください。それを断片に切り分けます。多くのブロックがあります。これは例えば2Dグリッドでの行座標で、列座標があります。これは行列で作業している場合に非常に有用です。
これらの各ブロックのサイズ、つまりスレッドブロックの数の観点でどれくらい大きいかがあります。これがブロックの次元です。そして、これらのブロック内にスレッドのコレクションがあり、これは例えば一つのスレッドブロックが存在する座標で、各スレッドは各ブロック内にあります。ここに階層構造があります。グリッドがあり、グリッド内にスレッドがあります。
基本的に各関数は3つのものを取ります。どのスレッドブロックに属しているかのブロックインデックス、ブロック次元とは何か、そして私のスレッドインデックスとは何か、私のインデックスとは何かです。これらで、行列やベクトルのどの座標にいるかを知ることができ、欲しいロジックを決定できます。
CUDAをデバッグしようとする際の最後の一つは、CUDA_LAUNCH_BLOCKING=1で起動したいということです。これにより実際にCUDAカーネルをデバッグできます。実行時のコストでエラーメッセージを返してくれます。これを行わなければ、CUDAコードを書いてデバッグする必要がある場合、ひどい時間を過ごすことになります。
6.4 GLU CUDAカーネルの詳細実装
ここに私のGLUコードがあり、少しずつ見ていきましょう。これらの部分が何をしているかを話しましょう。これは歩いて行く中で最も時間がかかるものになるでしょう。マシンコード以外は。一度これを理解すれば、他のすべての部分を理解できるはずです。少しゆっくりと進めていきます。
このコードには2つの部分があります。最初の部分、上のGLU kernel部分は実際のカーネルです。これが計算を行い、結果を返します。この部分、下のGLU関数は、ラッパーです。これはCPU上に存在し、実際にGPUに出て行って存在するカーネルの起動を調整します。
多分、このラッパー部分、このGLU関数から始めましょう。私たちは常に2つのことをチェックします。基本的にTritonやCUDAコードでは、常にチェックします。最初の一つは、XがGPUデバイス、何らかのCUDAテンソルに存在することを確認することです。そうでなければ、それは問題になるでしょう。GPU上で何もできません。
2番目のこと、これはあまり明白ではないかもしれませんが、Xが連続であることを確認したいのです。つまり、連続したメモリブロックに存在するということです。なぜなら、Xにインデックスを付ける際、多くのインデックス算術を行い、Xがメモリブロックに存在すると仮定するからです。そうでなければ、任意のレベルの一般性でこれを行うことは基本的に不可能です。
GLUを計算する際、入力Xを取り、Yを出力します。出力を割り当てる必要があります。torch.tensor Y equals torch.empty_like X。これは、Xの次元と同じような出力テンソル空間または出力テンソルへのポインタを与えてくださいと言っているだけです。zerosを呼び出していないことに注意してください。これは追加の操作を節約します。とにかくそれらに書き込むので、これらのyをゼロにする必要はありません。これは小さいですが、やってもよい最適化です。
基本的に私たちが書くすべてのコードで、グリッドを把握する必要があります。持っている要素の総数は何か?各ブロックのサイズは何か?ブロック内に持っているスレッドの数は何か?そして、総ブロック数は何か?ブロック数を把握する必要がある際、cdivを呼び出します。これは基本的にnum_elementsをblock_sizeで割った比率を取り、天井を取ります。なぜなら、block_sizeで割り切れない最後の要素のセットがまだ計算されることを確実にするため、切り上げる必要があるからです。切り下げではなく天井を取ります。
これはすべて非常にシンプルな簿記作業です。その後、カーネルを起動しましょうと言います。GLU kernelが起動されます。この角括弧は、与えられたブロック数と各ブロックのサイズでこれを言っています。これがカーネルコマンドに渡されます。そして、xとyへのポインタを渡します。実際にxとyの値を渡すのではなく、要素の総数を渡します。これを、カーネルの境界条件を基本的に計算するために必要です。
今、実際のカーネル自体に行きましょう。global void gelu kernelがあり、inとoutのポインタ、アイテム数の要素を取得します。このキーワードglobal、レンダリングが少し台無しにしていますが、これを_globalと考えるべきで、これはCUDAカーネル関数として区別するキーワードです。
何をしているのでしょうか?このスレッドは実際に単一の要素iに対して操作することになっています。しかし、iを入力として得られません。コードは実際に「あなたはベクトルの座標iにいます」と教えてくれません。どこにいるかを計算する必要があります。
どのようにそれを行うのでしょうか?ブロックインデックスを取ります。一次元しかないので、block_idx.xです。最初の座標だけです。それに各ブロックのサイズを掛けます。block_dim.x。これは基本的に現在のブロック内の開始点を教えてくれます。そして今、thread_idxを加えます。現在のブロックの開始がどこかを知っており、ブロック内のどこにいるかのオフセットを加えると、グローバル座標iが得られます。座標を得るためのいくつかの簿記計算です。
これも重要です。基本的に人々が書くすべてのCUDAコードでこのパターンを見ます。自然に境界チェックはありません。私が行うことは、座標を持っており、境界内の何かを処理することになっていることを確認するためにチェックします。ブロックの最後のスレッドのいくつかは、メモリ内の境界外のものを処理することになります。
それらにそれらに触れてほしくありません。基本的にi < num_elementsで条件付けます。その外側にいる場合は何もしません。
その中で、計算を行います。入力inがあります。i番目の要素にインデックスを付け、前と同じようにGLUを計算し、out[i]に割り当てて完了です。これがする必要があることのすべてです。これはすべてポインタのものなので、実際にここで何が起こっているかについてあまり心配する必要はありません。
6.5 境界条件チェックとメモリアクセスパターン
境界条件のチェックについて詳しく説明しましょう。これも重要です。基本的に人々が書くすべてのCUDAコードでこのパターンを見ます。自然に境界チェックはありません。私が行うことは、座標を持っており、境界内の何かを処理することになっていることを確認するためにチェックします。ブロックの最後のスレッドのいくつかは、メモリ内の境界外のものを処理することになります。
それらにそれらに触れてほしくありません。基本的にi < num_elementsで条件付けます。その外側にいる場合は何もしません。
講義中に質問がありました:連続でない場合はどうなるでしょうか。それはエラーを投げますか、カーネルを呼び出しますか。少なくとも私たちが書いたコードでは、assertなのでエラーを投げるでしょう。潜在的にそれを処理するコードを書くことはできますが、メモリが断片化される理由はほとんどありません。なぜなら連続して割り当てられるからです。本当にトリッキーなことをしていない限り、中間のメモリを割り当て解除することはないでしょう。
本当に高度なことをしていない限り、連続したメモリを持つことを期待すべきです。時々、転置やジャンプ操作を行って、メモリが連続でなくなることがあります。転置する場合、もはや連続でなくなります。列保存されているものを行トラバースしている場合、要素間でジャンプが発生します。
しかし、これは外側のラッパー部分で処理可能です。連続してインデックス付けされた何かを渡すことができます。多くの行列では、実際には気にしません。
もう一つの質問は、異なるブロックサイズを選択するとどうなるかでした。GPU関連の懸念が作用するでしょう。SMを飽和させるのに十分なブロックがあるか、各ブロック内に十分な作業があるかが、ここで重要になり得る2つのことです。しかし、1024のような比較的大きなブロックサイズでは、特定のポイントを過ぎればおそらく重要ではないと思います。なぜなら、高度なことは何もしていないからです。この非常にシンプルな例では、すべてエントリワイズ操作だからです。
6.6 性能結果:1.8msの達成
基本的にそれだけです。持っているCUDAのGELUコードを取り、このC++コードをインラインで読み込み、すべてをPython内でモジュールにコンパイルできます。すべて非常に素晴らしく便利です。実際にコマンドラインに出て何かを行う必要はありません。
これで、CUDA geluが定義されました。これはこれのコンパイルです。Python内からこれを呼び出すことができ、CバインディングでこのguyGuy呼び出すために使用します。
CUDA GLUを呼び出し終わりました。manual GLUとCUDA GLUが同じであることを確認できます。今、この2つをベンチマークしてみましょう。
PyTorchを実行するのにかかる時間があり、前と同じように約1.1ミリ秒です。manual timeは、覚えているように8.1ミリ秒です。それで、ドラムロール、私たちのCUDA timeは何でしょうか?それを1.8まで下げました。PyTorchの実装ほど良くはありませんが、PyTorch timeにかなり近づいています。8ミリ秒から1.8ミリ秒になったので、悪くありません。そのCコードはそれほど書くのが困難ではありませんでした。
今、いくつかのプロファイリングも行います。ここで何が起こっているかを見ることができます。GLU kernelと呼ばれ、これがGPUに送られたコードです。それからempty_likeがあり、これが初期化です。それからempty_stridedがあり、それからCUDA launch kernelとCUDA device synchronizeがあります。基本的にそれが起こっていることのすべてです。
再び、この単一のCUDAカーネルがGPU時間の100%を消費していることに注意してください。私たちが望んでいたものと同じです。
さらなる最適化を行うことはできますが、これは既にカーネル融合の問題を実際に解決しています。すべての演算子を一緒に融合させました。かなり良いです。この種の要素ワイズ操作はCUDAで書くのが簡単です。新しい種類の非線形性がある場合、本当に望むなら自分でCUDAカーネルを簡単に書くことができます。
より興味深い操作は、複数の値を読み取る必要がある、リダクションを行うようなものです。これらはもう少し複雑になります。Flash attentionは少し複雑になりますが、Assignment でそれを行う必要がある際、それほど複雑ではありません。
7. Tritonによるカーネル開発
7.1 Tritonの概要(OpenAI 2021年開発のDSL)
CUDAカーネルについて、それほど痛くありませんでしたが、GPUカーネルを書くためのより良いPython抽象化があればとても良いでしょう。これがTritonであり、Tritonは非常に素晴らしいものです。GPUについて文字通りすべてを管理する必要がない良い中間地点があります。
TritonはOpenAIによって2021年に開発されたドメイン特化言語で、GPUプログラミングをはるかにアクセシブルにします。すべてをPythonのような形で書き、もはやスレッドについて考えません。スレッドブロックについて考えます。Tritonは、面倒だが自動的に最適化できる多くのものを管理します。
メモリのコアレッシングを管理できます。VRAMから一度に4つの隣接する値をバーストモードと呼ばれるもので取得することを覚えているでしょう。メモリの取得が隣接する4要素以上の呼び出しにグループ化されることを本当に確実にしたいのです。それらを自動的に処理します。それらをグループ化します。
SM内で複数のスレッドがある際に、どのようなメモリに書き込んでいるかのような共有メモリ管理を行います。SM内の異なるスレッドがする必要があるかもしれないスレッドの停止や開始は、すべて自動的に管理されます。しかし、SM間でのスケジューリング、または異なるSMが行うことは手動です。
プログラミングモデルは、SM中心レベルで考え、コンパイラがより多くの低レベルの詳細を処理するということです。Tritonは、多くのPyTorch実装をかなり上回ることができるので非常に素晴らしいです。CUDAをすべて書くようなものですが、まだ非常に馴染みのあるPythonの土地にいます。
私が思うに非常に過小評価されている利点は、ここに書かれている通りです。すべてPythonです。ステップスルーできます。かなり素晴らしくデバッグできます。
7.2 CUDAとTritonの比較(Python風記法、自動最適化)
Tritonカーネルを見ていきましょう。再び、GLUを書きます。これをTritonで行います。私は、他のコードとできるだけ類似した構造になるように、このコードを書きました。これは、いわばCPU側のコード、ラッパーコードです。
torch テンソルであるXを取り、上部に2つのassertがあります。empty_likeを使用して出力テンソルYを割り当て、まったく同じ座標計算コンポーネントがあります。カーネルの起動も非常に似ています。num_blocksの注釈があり、block_sizeは括弧の一部ではなく最後にありますが、基本的にカーネルに同じ情報を渡しています。
今、Tritonカーネルは、ここのこのコードです。これは前にやっていたのと同じことを行いますが、今度はPythonで素晴らしく書かれています。ここでのメンタルモデルは、入力がx_pointer、y_pointerが出力ベクトル、開始座標、block_sizeが各ブロックがどのくらい大きいか、num_elementsが配列の最後になります。
今、557行目から561行目を取得する必要があります。これは私のインデックスの計算を行っています。前にi equals何らかの式を行いました。ここで同じ計算を行っています。現在のブロックの開始がどこかを計算しています。それは私のブロックIDにブロックのサイズを掛けたものです。
たとえば、ブロック1に住んでいるとしましょう、それはここの中間のこの点を取得します。その後、ブロック内のどこに住んでいるかを知る必要があります。それはオフセットになります。しかし、今、一つの違いに注意してください。オフセットを取得しません。なぜなら、スレッドをプログラミングしているのではなく、ブロックをプログラミングしているからです。
それはどういう意味でしょうか?私のオフセットは実際にはベクトルで、単一の値ではありません。なぜなら、これは基本的にベクトル化された操作を行うことになるからで、ベクトル化された操作は異なるスレッドによって処理されるでしょう。ここで私のオフセットは、ブロックの開始プラス、block_sizeオフセットのこの範囲のベクトルです。
私のオフセットは、ブロック1内のこれらすべての座標を一度にです。もちろん、最後にいる場合、端から外れる可能性があります。ベクトルの境界から外れて住んでいる可能性のあるものを処理するためにマスクが必要です。
今、単一のベクトル化された操作で、x_pointer plus offsetsですべてを一度にロードします。これらは私が担当する値で、マスクアップされており、Xにロードされます。これは私の内部値、必要な内部の一時ベクトルです。この一時ベクトルで、まったく古いGLU計算を行います。
tanhがないので、手動で計算しますが、この式は私たちがここに持っているものと同じであることを自分で確信できます。yは、ここで計算された式になります。完了したら、出力バッファまたは出力ベクトルに書き戻す必要があります。ターゲットを計算します。これはy_pointer plus offsetsです。一時的な値yを取り、それを保存します。
これは前にあったものと非常に似ていますが、これはベクトル化されたバージョンです。一度に全体のブロックで操作できます。スレッドの視点から考える代わりに、ブロックの視点から考えていますが、それほど異なってはいません。これはすべてかなり似たようなものです。
7.3 GLU Tritonカーネルの実装
Tritonカーネルの実装について詳しく見ていきましょう。@triton.jitデコレータを使用してカーネル関数を定義します。この関数は、x_pointer(入力)、y_pointer(出力)、block_size(各ブロックのサイズ)、num_elements(総要素数)を引数として受け取ります。
最初に、現在のブロックの開始位置を計算します。これはblock_id = tl.program_id(0)とblock_size_を掛け合わせることで求められます。次に、オフセットベクトルを作成します。これはoffsets = block_start + tl.arange(0, block_size)
という形で、ブロック内の各要素の位置を表します。
重要な点は、Tritonではスレッド単位ではなくブロック単位でプログラミングを行うということです。つまり、offsetsは単一の値ではなく、ベクトルです。これにより、一度に複数の要素を処理できます。
境界チェックのために、マスクを作成します。mask = offsets < num_elements
により、有効な要素のみを処理することを保証します。データの読み込みはx = tl.load(x_pointer + offsets, mask=mask)
で行い、ベクトル化された操作として一度に複数の要素を読み込みます。
GLUの計算では、tanhが直接利用できないため、手動で計算します。具体的には、gelu_part = 0.5 * (1.0 + tl.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x * x * x)))
という式を使用し、最終的にy = x * gelu_part
でGLUの結果を得ます。
最後に、計算結果を出力バッファに書き戻します。tl.store(y_pointer + offsets, y, mask=mask)
により、マスクを適用しながら結果を保存します。
このように、Tritonを使用することで、CUDAの複雑さを回避しながら、効率的なGPUカーネルを実装できます。Python風の記法により、コードの可読性と保守性が大幅に向上し、同時に高い性能を実現できます。
7.4 PTXアセンブリコードの解析(レジスタ使用、メモリコアレス、ベクトル化)
最後に行いたいクールなことの一つは、もちろんTritonはGPU用の低レベルのほぼマシンコードにコンパイルされるということです。私たちのTriton Geluカーネルが生成されたこの非常に低レベルなPTXコードと呼ばれるものを見ることができ、実際に非常にクールです。GPUが実際にスレッドレベルでどのように動作するかを見ることができます。
これはTritonコンパイラによって生成されたTriton Geluカーネルです。最初に、本当に基本的なものをいくつか行います。何をしているのでしょうか?値を保存する必要があると言っています。中間計算を保存する必要があります。
Bは実際に、基本的にバイトのような型なしのものを意味します。32ビットサイズのバイトが必要です。計算を行うためのfloatが必要で、fと呼ばれます。64ビットの別のレジスタセットが必要です。これは別のレジスタセットです。これらすべてのレジスタが、一時的な計算のために必要です。
ここから、基本的に座標を計算し始めます。申し訳ありませんが、この部分は関数への様々な引数をロードしています。x_pointerやy_pointerのようなものがここでロードされます。ここから、Tritonカーネルの座標オフセットを計算し始めます。
ここまで来ると、このld.global、これはx_pointerから私の一時的なレジスタに値をロードするために使用されるコードです。基本的に、RD1のメモリ位置を使用してR2、R3、R4、R5をロードすると言っています。一度に4つのものをロードしていることに注意してください。なぜなら、コアレッシングを巧妙に処理しているからです。4つの値を無料で取得できることを知っています。
これら4つの値すべてを一度に操作すべきです。なぜなら、それらを取得するからです。同じことを再び行います。ここで再び同じことを行います。そして、基本的に浮動小数点演算mul.f32を取得し始めます。これは基本的にtanh計算を通過します。
すべての異なる部分を説明するつもりはありませんが、ここで定数を掛けています。x cubeのようなことを行っています。同じ数値を複数回掛けます。ここで、2^xを計算しますが、e^xが欲しいです。指数化されたベースを取得するためにlog 2を掛けます。GPUが最終結果を得るために行う、すべての異なる文字通りのステップバイステップの操作を本当に見ることができます。
すべてをスキップして最後に行きます。これは、必要な浮動小数点計算のすべてです。最後に、R38からR41の値を、出力のメモリ位置であるRD4に保存します。これは低レベルで実際に起こっていることの一種です。各スレッドが一度に4つの値で動作し、その一時的な保存がレジスタであり、本当に本当に高速なローカルに持っている保存であることが分かります。
これを見て、これはおそらくかなり高速なコードになるだろうと分かります。
7.5 性能結果:1.848ms(CUDA実装と同等)
これでPTXは終わり、あらゆる種類のものについて何をしているかを見て行くことができます。しかし、今度は実際に戻ってベンチマークを行いましょう。
manual GLU 8.1秒、PyTorch time 1.1秒、CUDA time 1.84秒、Triton time 1.848秒を得ました。より速くはなりませんでしたが、Tritonコードを書くのがはるかに簡単でした。Pythonで書きました。ブロックについて考えました。ベクトル化された加算を行うことができました。より洗練されたことを行っている場合、基本的にTritonは多くのメモリ関連のことを処理してくれます。
実際にかなり良いです。プロファイリングを再び行うと、すべてのGPU時間を消費する単一のカーネル起動が見られます。それは素晴らしいです。それがTritonカーネルです。
私が話したいこの最後のものは、少なくともこの種の部分で、torch compileです。もちろん、CUDAカーネルを書くのはクールで、とても良い気分にさせてくれます。しかし、それを行う必要がないかもしれません。ここで行っていたことは非常にシンプルでした。x cubedや指数演算のようなものを取って、それらすべてを単一のCUDAカーネルに押し込んでいただけです。
多くのことをせずに、それを行うことができるかもしれません。
8. Torch Compileによる自動最適化
8.1 JITコンパイルによる自動最適化
私が話したいこの最後のものは、少なくともこの種の部分で、torch compileです。もちろん、CUDAカーネルを書くのはクールで、とても良い気分にさせてくれます。しかし、それを行う必要がないかもしれません。ここで行っていたことは非常にシンプルでした。x cubedや指数演算のようなものを取って、それらすべてを単一のCUDAカーネルに押し込んでいただけです。
多くのことをせずに、それを行うことができるかもしれません。いくつかの異なる方法を示しましたが、最後に話したいのは、torch compileと呼ばれるもので、最適化されていないPyTorchコードを取り、より最適化されたコードを書くものです。
ここで、カーネル融合のような最適化を自動で試行することになります。このcompiled GLUは、生成する実際の出力において同等になります。
8.2 性能結果:1.47ms(手動実装を上回る性能)
今、実行時間を見てみましょう。いくつかの実行時間の変動がありますが、基本的に同じ種類の数値です。8.1秒のmanual、1.1秒のPyTorch、1.8秒、そしてtorch compileで1.47秒です。
ここでのパンチラインは、現代のJITコンパイラは非常に優秀だということです。多くのことを行うことなく、演算融合のような最適化を行うことができます。フードの下を見ると、基本的に再び起こるのは一つのことです。これは、基本的に私たちのTritonコードと同様の種類のことを行っているfused add multiply tanh Tritonコードです。しかし、実際には私たちが行ったものよりもわずかに最適化されています。
そのため、私たちのコードよりもわずかに良い性能を得ています。torch compileは非常に素晴らしいです。
8.3 自動生成されたTritonコードの分析
フードの下を見ると、基本的に再び起こるのは一つのことです。これは、fused add multiply tanh Tritonコードです。これは基本的に私たちのTritonコードと同様の種類のことを行っていますが、実際には私たちが行ったものよりもわずかに最適化されています。
torch compileは、フードの下でTritonを生成しています。私たちが手動で書いたTritonカーネルと本質的に同じような操作をしていますが、コンパイラによってより洗練された最適化が適用されています。自動生成されたコードは、私たちが手動で実装したものを上回る性能を示しており、これは現代のJITコンパイラの高度な最適化能力を証明しています。
そのため、私たちのコードよりもわずかに良い性能を得ています。これは、コンパイラが私たちが見落としていた最適化機会を発見し、それを適用した結果と言えるでしょう。
8.4 実用的な最適化戦略の議論
講義中に質問がありました。torch compileがうまく機能しない場合をどのように感じるか、実装したい独自バージョンのように、flash attentionなどはできないのではないかという質問でした。
質問を言い換えると、torch compileより良くできることをいつ知るかということだと思います。演算融合や最適化すべき行列乗算のような単純なものに対して、torch compileは非常に優れています。前に述べたように、torch compileは行列の形状を知っている場合、どのカーネルを送るかを把握できるような最適化を行うことができます。それらのことについて、それより良くできるとは疑問です。
しかし、Flash Attention 1、2、3のようなものがあります。これらは非常に非自明な最適化です。最近、torch compileやJaxのXLAコンパイラのようなものでもそれらを行うことができますが、それは後知恵でそれらが行うべき正しい最適化であることを知っているからです。
これらの最適化のいくつかは、JITコンパイラで把握するのは少し非自明だと思います。Flash Attention 3には、H100ハードウェアを活用する追加のハードウェアレベル最適化があり、JITコンパイラでは明白ではありません。
そのため、torch compileでは非常に困難だと思うものもありますが、それでも良くできると思います。しかし一般的に、ここでのポイントは、家に帰って、「言語モデルのすべての部分にCUDAカーネルを書くぞ」と言うべきではないということです。それはおそらく時間の良い使い方ではないでしょう。
しかし、複雑な部分を持つ新しいアーキテクチャを書いていて、利用率を得られていないが得られると思う場合、それはTritonを本当に取り出す時かもしれません。
9. Softmax実装による応用例
9.1 リダクション操作を含むより複雑なカーネルの実装
時間的には基本的に終わりですが、Tritonの最後の例を素早く見ることができます。Assignment 2で役に立つかもしれない、softmaxを行うことです。
一つの違いは、今まで基本的な要素ワイズ操作だけを行っていたということです。これは本当に簡単です。なぜなら、各要素に対して操作するだけで、そのような種類のものには複雑さがないからです。今度は、リダクション操作を持つsoftmaxを行いましょう。すべての要素を加算する必要があります。
どのようにそれを行うのでしょうか?やりたいことは、行列の各行を正規化することです。これを高速にしたいと思います。これの素朴なバージョンはかなり遅くなるでしょう。
今、Tritonカーネルを書きます。怠惰になりたい場合、これを行う最も簡単な方法について、少し考えてみてください。softmaxを書きたいとしましょう。行列の各行を正規化し、これらの行列がかなり小さいと想像してください。小さな行列のためのカーネルを書いているだけです。
これを行っている場合、正しいブロック設計は何でしょうか?おそらく私たちがすべきことは、グリッドを実際に行にすることです。各SMが単一の行を処理します。行全体をSMに収めることができれば、SM内でその行を合計し、除算するだけです。それは素晴らしいでしょう。
これが私たちの非常に素朴なsoftmaxカーネルのシンプルな設計になります。私たちが行うことは、各ブロックを行にすることです。ブロックサイズは基本的に列数プラス、すべての列を収めることができるように少しのバッファにすべきです。これはTriton next power of two of nです。これは列をパディングする良い方法です。
各ブロックを行にします。ブロック数は正確に行数です。
9.2 ブロック設計戦略(行単位での処理)
これを行っている場合、正しいブロック設計は何でしょうか?おそらく私たちがすべきことは、グリッドを実際に行にすることです。各SMが単一の行を処理します。行全体をSMに収めることができれば、SM内でその行を合計し、除算するだけです。それは素晴らしいでしょう。
これが私たちの非常に素朴なsoftmaxカーネルのシンプルな設計になります。私たちが行うことは、各ブロックを行にすることです。ブロックサイズは基本的に列数プラス、すべての列を収めることができるように少しのバッファにすべきです。これはTriton next power of two of nです。これは列をパディングする良い方法です。
各ブロックを行にします。ブロック数は正確に行数です。
この設計の利点は、行のデータがすべて一つのSM内の共有メモリに収まることです。これにより、行内の要素間でのリダクション操作(最大値の計算、合計の計算)を効率的に行うことができます。SM内でのメモリアクセスは非常に高速であり、グローバルメモリへの頻繁なアクセスを避けることができます。
各SMが一つの行を担当するため、並列性も確保されます。複数の行を同時に処理でき、行列全体のsoftmax計算を効率的に実行できます。この設計パターンは、行単位での処理が必要な多くの操作に適用できる基本的なアプローチです。
9.3 Triton Softmaxカーネルの詳細実装
Tritonのsoftmaxカーネルがあり、期待される方法で書かれています。今度は、ベクトルではなく行列があります。x_pointersとy_pointersがあり、行列のストライドが必要です。そして、基本的に前と同じ種類のコードです。実際、行オフセットを取得するのはより簡単です。なぜなら、各行がブロックだからです。
どの行インデックスにいるかを把握できます。列オフセットを取得できます。これは前と同じ種類のコードになります。各行を私のSMのローカルメモリに読み込みます。そして、softmaxのように見える方法で計算を行います。
行があります。最大値を減算します。指数を取ります。合計し、除算します。これにより、softmaxで正規化された行が得られ、グローバルメモリに書き戻します。複雑さは全くありません。
計算がSMに適切に収まる際はいつでも、Tritonコードを書くことは、少しのロードとストア、ブロックがどこにあるかを追跡することを除いて、通常のPythonコードを書くことと非常に似ています。生活はかなりシンプルです。
具体的には、まずrow_idx = tl.program_id(0)
で現在処理している行のインデックスを取得します。列オフセットはcol_offsets = tl.arange(0, BLOCK_SIZE)
で計算します。行の開始ポインタをx_pointer + row_idx * x_row_stride
で求め、tl.load()
を使用して行データをロードします。
softmaxの計算では、まずtl.max()
で行の最大値を求め、数値安定性のためにすべての要素から減算します。次にtl.exp()
で指数を計算し、tl.sum()
で合計を求めます。最後に各要素を合計で除算して正規化し、tl.store()
で結果を書き戻します。
9.4 性能比較:手動実装(3.7ms)vs Torch Compile(1.3ms)vs PyTorch(1.5ms)vs Triton(1.9ms)
戻って、すべての異なるコードがどれくらい速いかを見ることができます。再び、ズームアウトして確認します。manual timeは3.7秒かかります。compile timeは1.3秒でtorch compileです。PyTorch timeは1.5秒です。Triton timeは1.9秒です。まだ少し遅いです。
torch compileは、特に操作の形状とサイズを知っている場合、実際にネイティブのPyTorch実装よりも良くできます。
最終的に、プロファイラーで見ることができます。manual softmaxは災害のようなものです。あちこちで起こっているあらゆる種類のクレイジーな操作が見られます。戻って上に行くと、あらゆる操作が起こっています。x、max、sumがあります。なぜなら、物事を素朴に実装し、あちこちでメモリの読み書きがあるからです。
compiled softmaxは、非常に速い一つの融合softmax操作になります。そして、pytorch softmaxもあり、これも一つのCUDA kernel呼び出しです。私たちのTriton softmaxと同じです。すべてのために素晴らしい単一の融合カーネルがあります。
この性能比較から分かることは、手動で複数の操作を組み合わせた実装は、複数のカーネル起動とメモリ転送のオーバーヘッドにより大幅に性能が劣化することです。一方、torch compileは形状情報を活用した高度な最適化により、ネイティブ実装を上回る性能を実現しています。Tritonによる手動実装も単一の融合カーネルとして動作し、合理的な性能を示していますが、torch compileの自動最適化には及ばない結果となっています。
10. 実験結果と知見
10.1 カーネル融合の劇的な効果:単純な手動実装で8倍の高速化
この講義を通して得られた最も重要な知見の一つは、カーネル融合が性能に与える劇的な影響です。GLUの実装において、PyTorchでの素朴な手動実装(0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³))))は8.1ミリ秒かかりましたが、PyTorchのネイティブ実装では1.1ミリ秒と、8倍もの性能差が生じました。
この性能差の根本的な原因は、手動実装では複数の個別のCUDAカーネルが起動されることです。tanh、x cubed、定数との乗算、加算、0.5とxとの乗算など、それぞれが独立したカーネルとして実行され、各操作の間でGPUのグローバルメモリとの間でデータの読み込みと書き込みが発生します。プロファイリング結果からも、同じCUDAカーネルが3回呼び出されるなど、大量のメモリ転送オーバーヘッドが確認できました。
一方、PyTorchのネイティブ実装や私たちが書いたCUDA/Tritonカーネルでは、すべての操作が単一のカーネル内で実行されます。これにより、中間結果はGPUのレジスタや共有メモリに保持され、グローバルメモリへの読み書きが最小限に抑えられます。
この結果は、前に説明した工場と倉庫の比喩を完璧に実証しています。操作ごとに倉庫(グローバルメモリ)から工場(計算ユニット)への配送コストを支払うのではなく、すべての操作を一つの工場で行うことで、配送コストを一度だけ支払えばよいのです。この基本原理は、言語モデルの最適化における最も重要な概念の一つであり、理論的な理解と実際の性能向上が直結した結果と言えるでしょう。
10.2 Torch Compileの実用性:多くの場合で手動実装を上回る性能
この講義で最も印象的だった発見の一つは、torch compileの優秀さです。GLUの実装において、私たちの手動CUDA実装が1.8ミリ秒、Triton実装が1.848ミリ秒だったのに対し、torch compileは1.47ミリ秒を達成し、手動実装を上回る最高性能を示しました。
Softmaxの実装でも同様の傾向が見られました。手動実装が3.7ミリ秒、Triton実装が1.9ミリ秒、PyTorchネイティブが1.5ミリ秒だったのに対し、torch compileは1.3ミリ秒で最速でした。
torch compileがこのような優秀な結果を出す理由は、フードの下で自動的にTritonコードを生成し、私たちが手動で書いたものよりもさらに最適化されたコードを作成するからです。実際に生成されたコードを見ると、fused add multiply tanh Tritonカーネルとして、私たちのTriton実装と本質的に同じことを行いながら、より洗練された最適化が適用されていることが分かりました。
現代のJITコンパイラの能力は非常に高く、演算融合のような最適化を開発者が多くの労力をかけることなく自動的に実行できます。特に、行列の形状やサイズが既知の場合、torch compileは適切なカーネルを選択し、ハードウェア特性に応じた最適化を行うことができます。
この結果が示す実用的な教訓は明確です。家に帰って「言語モデルのすべての部分にCUDAカーネルを書こう」と考えるべきではありません。それはおそらく時間の良い使い方ではないでしょう。代わりに、torch compileのようなツールを第一選択肢として使用し、手動実装は本当に特殊な最適化が必要な場合に限定すべきです。開発効率と性能の両方を考慮すると、torch compileは非常に実用的で効果的な最適化手法と言えるでしょう。
10.3 プロファイリングの必要性:推測ではなく測定に基づく最適化
この講義を通して一貫して強調された最も重要な原則は、高性能コードを書く際には推測ではなく測定に基づいて最適化を行うべきだということです。高性能コードを書きたいなら、コードをベンチマークしプロファイリングすることを忘れてはいけません。これは非常に明白に聞こえるかもしれませんが、学生や人々が「これがボトルネックだと思うから、3時間かけて最適化しよう」と考える場面を多く見てきました。しかし結果的に、それは全くボトルネックではなかったということがよくありました。
実際に高性能な詳細プロファイラを使用すれば、ボトルネックが正確にどこにあるか、マシンが実際に何をしているのかを正確に見ることができます。MLPの例では、torch profilerの結果だけでは31%の時間しか説明できず、残りの60%がどこに費やされているかが不明でした。しかし、NVIDIA Nsight Systemsを使用することで、CPU-GPU間の非同期実行、初期化コスト、メモリ割り当てパターンなど、詳細な実行フローを可視化できました。
プロファイリングによって発見された興味深い事実の一つは、単純なprint文がCPU-GPU同期に与える影響でした。print文がない場合、CPUはGPUより最大1ステップ分先行して実行していましたが、print文を追加すると、lossの値を取得するためにCPU-GPU間の同期が必要となり、実行パターンが劇的に変化しました。このような微細な影響は、プロファイラなしには発見できませんでした。
C距離計算のプロファイル分析では、GPU時間の78%が行列乗算、6%がコピー、5%がべき乗、3%が和に費やされていることが明確に示されました。このような詳細な内訳により、最適化の優先順位を正しく設定できます。行列乗算を最適化できれば大きな効果が期待できますが、和の最適化に時間を費やしても限定的な効果しか得られないでしょう。
理論には限界があります。ルーフラインモデルなどについて考えることはできますが、行列乗算がどれくらい速いかということになると、ライブラリのバージョンやハードウェア、どの部分がどのような理由でボトルネックになっているかなど、あらゆる種類のマイクロコードの事柄があり、完全には分からないのです。そのため、最終的にはこれらのものを開発する際には、エンドツーエンドのベンチマークを行わなければなりません。
プロファイリングをすべきだという事実は、ツールが何であれ、本当に変わることはないでしょう。システム開発において、測定駆動のアプローチを内在化することが、効果的な最適化を行うための基本原則なのです。
10.4 開発効率 vs 性能のトレードオフ:Tritonの中間的な位置づけ
この講義を通して、異なる実装手法の開発効率と性能のトレードオフが明確に示されました。Tritonは、CUDA(高性能・高難度)とPyTorch(低性能・簡単)の間の理想的な中間解として位置づけられることが実証されました。
CUDAカーネルの実装では、グリッド、ブロック、スレッドの概念を理解し、座標計算、境界チェック、メモリアクセスパターンなど、多くの低レベル詳細を手動で管理する必要がありました。GLU CUDAカーネルでは、blockIdx.x * blockDim.x + threadIdx.x
による座標計算、i < num_elements
による境界チェック、連続メモリレイアウトの確認など、エラーが発生しやすい多くの要素を扱わなければなりませんでした。
一方、Tritonでは同じ機能をPython風の記法で実装できました。tl.program_id(0)
、tl.arange(0, block_size)
、tl.load()
、tl.store()
といった直感的なAPIにより、低レベルの詳細を抽象化しながら高性能を実現できました。特に重要なのは、Tritonがメモリコアレッシング、共有メモリ管理、スレッド同期などを自動的に処理してくれることです。
性能面では、GLU実装においてCUDA(1.8ms)とTriton(1.848ms)がほぼ同等の結果を示しました。Softmax実装でも、Triton(1.9ms)は手動実装(3.7ms)を大幅に上回り、PyTorch(1.5ms)に近い性能を達成しました。これは、開発効率を大幅に向上させながら、性能の妥協を最小限に抑えられることを意味します。
Tritonの真価は、より複雑な最適化が必要な場面で発揮されます。Flash Attentionのような非自明な最適化では、CUDAレベルの制御が必要になることもありますが、多くの実用的なカーネルではTritonで十分な性能を得られます。デバッグの容易さも重要な利点で、Pythonでステップ実行でき、従来のCUDAデバッグの困難さを回避できます。
この結果は、新しいアーキテクチャで複雑な部分を持ち、利用率を得られていないが得られると思う場合、Tritonを本当に取り出す時だということを示しています。開発者は、プロトタイピングから本格的な最適化まで、段階的にアプローチを選択できる柔軟性を得られるのです。
10.5 実装戦略の指針:一般的な操作はTorch Compile、特殊な最適化が必要な場合のみ手動実装
この講義で得られた実験結果から、実用的な最適化戦略についての明確な指針が浮かび上がりました。torch compileの卓越した性能(GLUで1.47ms、Softmaxで1.3ms)は、多くの場合において手動実装(CUDA 1.8ms、Triton 1.848ms)を上回り、開発コストゼロで最高の性能を実現することを実証しました。
torch compileが特に優秀な領域は、演算融合や既知の形状に対する最適化です。行列乗算については、torch compileには実際にハードウェア上で行列乗算性能をマイクロベンチマークするオプションがあり、その後モデル用の最高性能の行列乗算サブルーチンを実際に選択します。これにより、10%程度の無料の速度向上が得られることも確認されています。このようなことを最適化することで、実際に現実世界で無料の利得が得られるのは非常にクールです。
一方、手動実装が依然として必要な領域も存在します。Flash Attention 1、2、3のような非常に非自明な最適化がその例です。最近では、torch compileやJaxのXLAコンパイラのようなものでもそれらを行うことができますが、それは後知恵でそれらが行うべき正しい最適化であることを知っているからです。Flash Attention 3には、H100ハードウェアを活用する追加のハードウェアレベル最適化があり、JITコンパイラでは明白ではありません。
実用的な判断基準として、複雑な部分を持つ新しいアーキテクチャを書いていて、利用率を得られていないが得られると思う場合、それはTritonを本当に取り出す時です。しかし、家に帰って「言語モデルのすべての部分にCUDAカーネルを書くぞ」と言うべきではありません。それはおそらく時間の良い使い方ではないでしょう。
効果的なリソース配分戦略は以下のようになります:
第一選択肢として、torch compileによる自動最適化を試す。これにより、多くの場合で手動実装を上回る性能を開発工数ゼロで得られます。
torch compileで十分な性能が得られない場合、または非常に特殊な最適化パターンが必要な場合のみ、Tritonによる手動実装を検討する。
CUDAレベルの実装は、Tritonでも表現できない極めて特殊な最適化が必要な場合に限定する。
この階層的なアプローチにより、開発効率と性能の両方を最適化し、限られた開発リソースを最も効果的に活用できるのです。