※本記事は、Stanford CS336「Language Modeling from Scratch」Spring 2025コースの「Parallelism 2」講義の内容を基に作成されています。この講義は、Stanford Onlineを通じて提供されているコンピュータサイエンス教育プログラムの一部です。講義の詳細情報およびコース全体のプレイリストは、Stanford CS336の公式サイト(https://stanford-cs336.github.io/spri... )でご覧いただけます。本記事では、講義の内容を要約しております。なお、本記事の内容は講義内容を正確に反映するよう努めていますが、要約や解釈による誤りがある可能性もありますので、正確な情報や文脈については、オリジナルの講義動画をご視聴いただくことをお勧めいたします。
講師紹介
- Percy Liang: スタンフォード大学コンピュータサイエンス准教授、Center for Research on Foundation Models (CRFM) ディレクター
- Tatsunori Hashimoto: スタンフォード大学コンピュータサイエンス助教授
Stanford Onlineは、スタンフォード大学工学部のCenter for Global & Online Education (CGOE)によって運営・管理されており、学位プログラム、単位認定教育、専門資格プログラム、無料・オープンコンテンツなど、幅広い教育機会を世界中に提供しています。詳細については https://online.stanford.edu/ をご参照ください。
1. イントロダクション
1.1 複数GPU並列処理の概要
Percy Liang:これは今週2回目のシステム講義で、我々が持っているハードウェアを最大限活用してモデルの学習を高速化する方法について話します。先週は単一GPU内での並列処理について話しましたが、今週は複数GPU間での並列処理について扱います。
皆さんには、この図を頭に描いてもらいたいと思います。我々は複数のノードを持っており、これらは基本的にそれぞれ複数のGPUを持つコンピュータで、通常は8つのGPUを搭載しています。各GPU内には、実際に作業を行う多数のStreaming Multiprocessor(SM)があります。ここで緑色で示されているのは、本質的にメモリと通信部分です。
各SM内には非常に小さなL1キャッシュがあり、GPU上にはより大きなHigh Bandwidth Memory(HBM)があります。そして、異なるGPU間を接続するNVLinkがあります。重要な点は、計算はこれらのALU上のSM内で実行される必要があるということです。計算には入力が必要で、出力を書き込む必要があります。一般的に、入力と出力は比較的遠くに配置される可能性があります。運が良ければL1キャッシュにあり、そうでなければHBMにあります。
今週我々が話している複数GPUおよび複数ノード学習では、必要なデータが別のGPU上にある可能性があります。ゲームの名前は、データ転送のボトルネックを避けるために、すべての計算をどのように構造化するかということです。なぜなら、算術強度を高く保ち、GPUを飽和状態にして効率的に動作させたいからです。一般的に、データ転送は大幅に遅くなるため、それがボトルネックになってしまいます。
先週、我々はfusionやtilingを含む、GPU内でこれを実現するための様々な技術を見ました。基本的なアイデアは、HBMから読み書きする代わりに、L1キャッシュや共有メモリ(同じ速度を持つ)にロードし、ローカルのスクラッチパッド上で作業してから、慎重にHBMに書き出すということでした。
今週は、GPU間およびノード間での通信を見始めており、モデルとパラメータ、最適化状態を複製および分割する必要があります。これをどのように行うかがコストを決定します。
1.2 ハードウェア階層とメモリ構造
Percy Liang:ここで、私は少し自由を取って、すべてを一種の階層に配置しています。小さく高速なものから大きく低速なものまでと考えることができます。
最も小さく最も高速なのは、単一ノード単一GPU上のL1キャッシュで、これは非常に高速ですが非常に小さいものです。次に、単一GPU上のHBMがあります。そして、同一ノード上のGPU間にはNVLinkがあり、最後にNVSwitchがあります。もちろん、これはすべてNvidiaエコシステム内での話です。
このアイデアは、データ転送を最小化する多くの中核概念が実際には同じであるということです。しかし、今やメカニクスが少し異なります。なぜなら、L1はこれらのNVスイッチとは異なる動作をするからです。
この講義は主に、前回の講義の概念をコード内で具体化することについてです。いくつかの新しいことがありますが、Tatsuは異なるタイプの並列処理についての優れた概要を提供してくれました。私はそれをコードに定着させて、何が起こっているのかをより深く理解しようと思います。そして、この標準出力ファイルを参照します。これは、この講義を実行した出力です。
マルチプロセッシングを使用する場合、このフレームワークがうまく機能しないという軽微な問題がいくつかありましたが、それについては省略します。この講義には2つの部分があります。パート1では、構成要素である集合操作について見ていきます。これらは前回議論したもので、NCCLとPyTorchでどのように実装されているか、そしていくつかのベンチマークを行います。パート2では、実際の分散学習であるデータ、テンソル、パイプライン並列処理について見ていきます。
1.3 データ転送ボトルネックの課題
Percy Liang:データ転送のボトルネックを避けるための課題について説明しましょう。我々が記憶しておきたいのは、算術強度を高く保つということです。GPUを飽和状態にして、効率的に動作させたいのです。一般的に、データ転送は大幅に遅くなるため、それがボトルネックになってしまいます。
先週、我々はGPU内でこれを実現するためのfusionやtilingを含む様々な技術を見ました。基本的なアイデアは、HBMから読み書きする代わりに、L1キャッシュや共有メモリ(これは同じタイプの速度を持つ)にロードし、ローカルのスクラッチパッド上で作業してから、慎重にHBMに書き出すということでした。
今週は、GPU間およびノード間での通信を見始めており、モデルとパラメータ、最適化状態を複製および分割する必要があります。これをどのように行うかがコストを決定することになります。
この階層において、データ転送を最小化する多くの中核概念は実際には同じですが、メカニクスが少し異なります。なぜなら、L1はこれらのNVスイッチとは異なる動作をするからです。この理解が、効率的な並列処理を実現するための基盤となります。
2. 集合通信操作(Collective Operations)
2.1 集合通信の基本概念と用語
Percy Liang:集合操作について始めましょう。集合操作は、分散プログラミングに一般的に使用されるプリミティブです。集合的というのは、多くのノードを持っているということを意味します。これらは実際には、少なくとも1980年代の並列プログラミング文献から存在する、かなり古いものです。
一般的に、これらは自分でpoint-to-point通信を管理しようとするよりも、より良い抽象化を提供します。これらは時の試練に耐えた、本当に試され信頼されたプリミティブです。
少し用語について説明します。World sizeは本質的にデバイスの数を指します。例えば4つです。Rankは、線形代数に慣れている場合は混乱するかもしれませんが、実際にはデバイスを指します。4つのデバイスがある場合、rank 0、rank 1、rank 2、rank 3があります。
集合操作は以下の通りです。Broadcastは、一つのrank上にt0があり、それを他のすべてのrankまたはすべてのrank上に配置したい場合です。これは非常に分かりやすいです。
Scatterは似ていますが、4つの値があり、それぞれの値を異なるrank上に配置したい場合です。各rankは同じ値ではなく、異なる値を取得します。Gatherは、各rankが異なる値を持っており、それらをすべて一つのrank上にまとめるという、scatterの逆のようなものです。
Reduceは、連結する代わりに加算することを除いて、gatherと同じです。All gatherは、すべてのdestinationに対して行うことを除いて、gatherと同じです。Gatherは単にrank 0、rank 1、rank 2のような個別のrankでしたが、all gatherはすべてに対して行います。
最後に、reduce scatterについては良い図が見つからなかったので、前回のものを再利用しています。これは、多くの異なる値を取り、加算や他の可換演算を実行して、一つのrank上に配置するreduceのようなものです。しかし、scatterのように、ベクトルやテンソルの異なる部分を異なるrank上に配置します。
覚えておいてください、all reduceはreduce plus all gatherと等価です。この用語を覚える方法は以下の通りです。all gatherがどれで、reduce scatterがどれかというのは混乱しがちですが、reduceは単に合計、最小値、最大値、平均などの結合的で可換的な演算を実行することを意味します。Broadcast、scatterはgatherの逆であり、allは単にすべてのdestination、つまりすべてのデバイスを意味します。
これは前回の復習になることを願っています。これらのプリミティブに基づいて構築していくので、質問はありますか?みんなが理解していると有用です。
2.2 基本的な集合通信操作の種類
Percy Liang:集合通信操作の種類について詳しく説明しましょう。まず、Broadcastから始めます。これは一つのrank上にt0があり、それを他のすべてのrankまたはすべてのrank上に配置したいというものです。これは非常に単純明快です。
Scatterは似ていますが、4つの値があり、それぞれの値を異なるrank上に配置したい場合です。各rankは同じ値ではなく、異なる値を取得します。
Gatherは、Scatterの逆のようなもので、各rankが異なる値を持っており、それらをすべて一つのrank上にまとめるものです。
Reduceは、連結する代わりに加算することを除いて、Gatherと同じです。
All gatherは、すべてのdestinationに対して行うことを除いて、Gatherと同じです。Gatherは単にrank 0、rank 1、rank 2などの個別のrankでしたが、All gatherはすべてのrankに対して行います。
最後に、Reduce scatterについては良い図が見つからなかったので、前回のものを再利用しています。これは、多くの異なる値を取り、加算や他の可換演算を実行するReduceのようなものですが、一つのrank上に配置します。しかし、Scatterのように、ベクトルやテンソルの異なる部分を異なるrank上に配置します。
重要なことを覚えておいてください。All reduceは、Reduce plus All gatherと等価です。
この用語を覚える方法は以下の通りです。All gatherがどれで、Reduce scatterがどれかというのは混乱しがちですが、Reduceは単に合計、最小値、最大値、平均などの結合的で可換的な演算を実行することを意味します。Broadcast、ScatterはGatherの逆であり、Allは単にすべてのdestination、つまりすべてのデバイスを意味します。
これが前回の復習になることを願っています。これらのプリミティブに基づいて構築していくので、みんなが理解していると有用だからです。
2.3 PyTorchでの集合通信の実装例
Percy Liang:PyTorchでの集合通信の実装例を見てみましょう。私が書いたユーティリティ関数があります。コードを見たい場合は確認できますが、これは関数を取り、基本的にPython multiprocessingのラッパーで、この関数を実行する4つのプロセスを実行するものです。
この関数内にいるとき、実際にはworld size個のプロセスがこの同一の関数を実行していると考えるべきです。rankは0から1、world size - 1まで順番に索引付けされています。今、私は講義が並列ではないので、rank の一つを順を追って説明しています。
一般的に最初に行うことは、プロセス自体を初期化する必要があることです。本質的に、複数のプロセスを実行しているので、それらが互いを見つける必要があります。単一のホストに接続して、互いの存在を知る必要があります。これはすべてのデータが行く場所ではないことに注意してください。データはNCCLを通って行きますが、これは調整用です。GPUがあるのでNCCLを使用できます。そうでなければGlooを使用します。
設定後、いくつかのことを行います。Barrierという便利な関数があります。これは基本的に、プロセスグループ内のすべてのプロセスがこの点に到達するまで待機します。すべてが非同期で実行されているので、場合によっては同期点が必要で、Barrierがそれを行います。私がここに配置した理由は、実際には些細な理由で、これらのprint文をグループ化したいからですが、後で説明するように、Barrierを使用する他の理由もあります。
各グループに対してテンソルを構築します。テンソルは0、1、2、3プラスrankです。各rankについて、All reduceの前に何のようになっているかを出力します。
ここで結果を見てみましょう。後ろの人は読めますか?はい、良いです。Rank 0では0、1、2、3です。Rank 1では1、2、3、4といった具合です。非同期なので、順序は印刷される順序によってバラバラになっていることに注意してください。
各rankは異なるテンソルを持っており、それからAll reduceを行います。All reduceでは、そのテンソルを渡し、合計したいと言います。この場合、非同期にはしませんが、通信と計算をオーバーラップするのに有用な非同期も可能です。
その後、All reduceの後に何が起こるかというと、広告通り、最初のコンポーネントについて、それらを加算すると6を得ます。次に10、14、18を得ます。All reduceの後、このテンソルは対応する合計で上書きされます。使用するのは非常に素晴らしくシンプルです。
次に、Reduce scatterを行いましょう。Reduce scatterでは、world sizeの次元を持つ入力を作成します。この場合は4です。出力を割り当てます。なぜなら、Reduce scatterはインプレースで操作せず、これは単なるスカラーになるからです。
Reduce scatterの前に、これが見た目です。以前のように入力があり、出力は初期化していないので偶然0ですが、任意の値である可能性があります。Reduce scatterの後、入力と出力を渡し、合計します。すると、基本的に最初のコンポーネントについて合計し、それがrank 0に行き、2番目のコンポーネントについて合計し、それがrank 1に行く、といった具合になります。
ご覧のように、これはAll reduceと同じ操作を生成していますが、出力が異なるrank間で分散されている点が異なります。
次に、All gatherを行いましょう。Reduce scatterの出力を直接入力として使用し、出力用に空の配列を割り当てます。All gatherの前に、入力はこれで、出力は任意の値です。All gatherを行った後、これらのテンソルがすべてのデバイス上に現れます。
これも一例です。うまくいけば、Reduce scatter plus All gatherがAll reduceと同じであることを確信してもらえるでしょう。なぜなら、All reduceで計算したのとまったく同じ量を計算したからです。
質疑者:Reduce scatterでは、どのGPUを追跡していますか?
Percy Liang:Reduce scatterで、どのインデックスがどのGPUに行くかを追跡するかという質問ですね。慣例により、次元性は基本的にworld sizeである必要があります。一般的なテンソルも可能ですが、次元の一つがworld sizeで、基本的に対応する計算が各出力に行くことを推測します。次元性が適切に整合するように少し注意が必要です。このような小さな例を通して行うことが役立ちます。
この実行中のプロセスで、完了したら単にクリーンアップします。
3. ハードウェアとソフトウェアの実装
3.1 従来のハードウェア構成とその限界
Percy Liang:これまで集合操作について話し、PyTorchでの実装について説明してきました。次に、実際にこれがハードウェアからどのように実装されているかを見てみましょう。
ここに、GPUのハードウェアが古典的にどのように見えるかを示します。これは一般的な家庭のコンピューターのようなもので、CPUがあり、通常はPCIeバス経由で通信する一つのノード上にGPUがあります。
異なるノード間で通信する必要がある場合、これはすべてEthernetに接続されています。これは、ゲーム用などでGPUを購入した場合の、典型的なマシンの構築方法です。
ご覧のように、これは最適ではありません。なぜなら、データをGPUからGPUに送信する必要がある場合、大きなオーバーヘッドがあるからです。データはカーネルを通過し、バッファにコピーされ、その後Ethernet上でのこの種の転送を通過する必要があり、これが多くのオーバーヘッドを導入します。
現代の科学計算と深層学習において起こったことは、複数のGPUを一緒に接続して何かを一緒に行うことが分かっている場合、基本的にGPUを直接接続することです。Nvidiaエコシステムでは、GPUを直接接続するNVLinkがあり、これによりCPUをバイパスします。ホストマシンのカーネルを通る必要がありません。
ノード間でも、NVSwitch経由でGPUを直接接続できます。したがって、Ethernetをバイパスしています。Ethernetは長い間前に開発されたもので、明らかにこのタイプのアプリケーション用ではありませんでした。NVSwitchとNVLinkは、それらすべてをスキップして、我々が興味を持っているワークロードのタイプに直接最適化されています。
H100を見ると、各GPUには18個のNVLink第4世代があり、これで合計900ギガバイトの帯域幅が得られます。これらと比較すると、確実にPCIeよりもはるかに高速で、Ethernetと比較してもはるかに高速です。
SMから高帯域幅メモリへの読み取りのコストと比較すると、それでも約4倍程度高速です。もちろん、これらの数値は新しいBlackwellで常に変化しており、この数値は2倍から3倍程度多いと思います。
質疑者:PCIeについて、CPUを経由してから他のGPUに行くのか、それともGPUと直接使用するのですか?
Percy Liang:PCIeでデータがどのように転送されるかという質問ですね。CPUを通る必要があると思います。他に質問はありますか?PCIeは、サウンドカードやSSDハードドライブなど、他のものも接続されているように開発されました。したがって、デバイス通信のための汎用バスのようなものです。
質疑者:NVLinkもCPUとの接続がありますね。
Percy Liang:NVLinkもCPUに接続するという質問ですね。後でスライドで物事がどのように接続されているかを少し見ることになります。もちろんCPUと話す必要もありますからね。
3.2 NVLink とNVSwitchによる最適化
Percy Liang:実行できるコマンドがあります。これはいくつかの出力を生成し、GPUが実際にどのように接続されているかを見ることができます。私はこれを我々のクラスター上で実行しました。8つのGPUがあります。皆さんは8つのGPUを取得することはできないと思いますが、もしできたとしたら、このような感じになります。
すべてのGPUのペア間に、NVL18が接続されているのが見えます。また、これらの種類のネットワークカードや他のものもあります。ネットワークカードは基本的にPCIe接続とCPUを提供するものです。
これがハードウェアです。では、このハードウェアをどのように使用するのでしょうか?Nvidiaは、本当に優れたハードウェア上に、本当に優れたソフトウェアを開発するために多くの時間を費やしました。
NvidiaにはNCCLと呼ばれる集合通信ライブラリがあります。これは本質的に、All-reduceのような前に見た集合操作を、GPU間で送信される必要がある低レベルパケットに変換します。
このライブラリは実際に多くの作業を行います。なぜなら、プログラマーが「このテンソルをすべてのマシンに表示させる必要がある」というレベルで操作でき、それが単に実現されるからです。
NCCLを設定する際に何が起こるかを少し説明すると、多くのデバイスを立ち上げ、ハードウェアのトポロジーを把握するためのいくつかの通信が行われます。GPU間のパスを最適化し、実際にこれらの集合通信操作を呼び出すと、データを送受信するCUDAカーネルを起動します。
これがNCCLです。ライブラリとして提供されています。しかし、NCCLは我々にとってまだ少し低レベルすぎます。なぜなら、我々が行っていることの大部分はPythonだからです。そこで、PyTorchにはtorch distributedライブラリがあり、これらの集合操作に対してクリーンなインターフェースを本質的に提供しています。
PyTorchプログラムの快適さから、テンソル上でall_gather_into_tensorと書くだけで、それが異なるrank上に表示されます。また、異なるハードウェア用の複数のバックエンドをサポートする便利で有用な機能もあります。特に、NCCLはGPU用でしたが、集合操作を実行することもできます。
これはGPU特有のものではないことを覚えておいてください。任意のデバイスセット用です。したがって、Glooと呼ばれるバックエンドを使用してCPU用でも実行できます。例えば、課題のためにラップトップでデバッグしている場合、Glooを使用して、GPUなしでも実行できます。
3.3 NCCLライブラリとPyTorch Distributedの役割
Percy Liang:NCCLについてより詳しく説明しましょう。NCCLは本質的に、All-reduceのような前に見た集合操作を、GPU間で送信される必要がある低レベルパケットに変換します。
このライブラリは実際に多くの作業を行います。なぜなら、プログラマーが「このテンソルをすべてのマシンに表示させる必要がある」というレベルで操作でき、それが単に実現されるからです。
NCCLを設定する際に何が起こるかを少し説明すると、多くのデバイスを立ち上げ、ハードウェアのトポロジーを把握するためのいくつかの通信が行われます。GPU間のパスを最適化し、実際にこれらの集合通信操作を呼び出すと、データを送受信するCUDAカーネルを起動します。これがNCCLで、ライブラリとして提供されています。
しかし、NCCLは我々にとってまだ少し低レベルすぎます。なぜなら、我々が行っていることの大部分はPythonだからです。
そこで、PyTorchにはtorch distributedライブラリがあり、これらの集合操作に対してクリーンなインターフェースを本質的に提供しています。PyTorchプログラムの快適さから、テンソル上でall_gather_into_tensorと書くだけで、それが異なるrank上に表示されます。
また、異なるハードウェア用の複数のバックエンドをサポートする便利で有用な機能もあります。特に、NCCLはGPU用でしたが、集合操作を実行することもできます。これはGPU特有のものではないことを覚えておいてください。任意のデバイスセット用です。
したがって、Glooと呼ばれるバックエンドを使用してCPU用でも実行できます。例えば、課題のためにラップトップでデバッグしている場合、Glooを使用して、GPUなしでも実行できます。
これらの高レベルプリミティブを持つ他の利点は、GPU特有のもののみを持つよりもはるかにポータブルであることです。もちろん、パフォーマンスは実際にハードウェアに依存しますが、少なくとも論理的にはコードが実行されることを確認できます。
PyTorch distributedは、Tatsuが前回の講義で話したFSDPなどの他の高レベルな機能もサポートしていますが、ゼロから開発するという精神で、このクラスではこれを使用しません。それが我々がやろうとしていることです。
4. 集合通信のベンチマーキング
4.1 All-Reduceのベンチマーク実験
Percy Liang:これまで集合操作について話し、PyTorchとNCCLでの実装について説明してきました。次に、課題や第2講義で行ったのと同じ精神で、少しベンチマークを行いましょう。今のところ、単一ノードに焦点を当てます。
All-reduceを行いましょう。1億要素のテンソルとworld size 4を使用します。テンソルを割り当てます。一般的に、ベンチマークを行う際には、ある意味でパレットをきれいにすることに本当に注意する必要があります。この場合、基本的にウォームアップを行います。操作を一度実行し、同期してからbarrierを実行します。これの一部は少し防御的だと思いますが、安全のために、すべてのカーネルがロードされ、計算される必要があるものが計算されるようにします。
その後、時計を開始し、All-reduceを実行し、再び同期して時計を停止します。これにかかった時間を見ることができます。スクロールダウンすると、これはあまり情報的ではありません。マイクロ秒で印刷すべきでした。非常に高速で、何秒かの数値でした。
次に、実際に集約で1秒あたりに転送されたギガバイト数である帯域幅を測定しましょう。これを行う方法は、ここで実際に何が転送されるかを考える必要があることです。
要素サイズを持つテンソルがあり、各要素のサイズは、これはfloat 32だと思うので4バイトです。それがバイト単位のサイズです。
これは少し微妙です。実際に送信または転送された、送受信されたバイト数はいくつでしょうか?各rankに置かれているテンソルはsize bytesです。それをworld size - 1の他のマシンまたはrankに送信する必要があります。しかし、2倍の係数があります。なぜ2倍の係数があるのでしょうか?All-reduceを行っているからです。
すべての異なる要素を基本的に一箇所に送信する必要があることを覚えておいてください。それが合計され、その後みんなに戻る必要があります。したがって、rankは入力を送信し、出力を受信する必要があります。だから2倍の係数があるのです。
総継続時間は、world sizeに実際に経過した継続時間を掛けたものです。4つのプロセッサがある場合、それは4倍の壁時計時間が発生したようなものだと仮定しています。帯域幅は単に継続時間に対するバイト数です。
ここで何を得ているでしょうか?約277ギガバイト毎秒です。H100について、上で900ギガバイト毎秒のようなものがあると主張したと思います。
もちろん、知っているように、テンソルのサイズと正確なデバイス数、天気などの様々な要因によって、結果は変わります。天気ではありませんが、様々な要因によって変わります。したがって、実際に得られる1秒あたりのギガバイト数を確認するために、常にベンチマークを行うことが良いです。
4.2 Reduce-Scatterのベンチマーク実験
Percy Liang:Reduce-scatterは非常に似ているので、これを非常に素早く見ていきましょう。入力を作成します。これはworld size × 要素数です。各rankはこのマトリックスを持つことになります。ウォームアップしてから時計を開始し、Reduce-scatterを実行し、時計を停止して、どのくらい時間がかかったかを確認します。まあ、それは役に立ちません。
次に帯域幅を見てみましょう。送信されたバイト数には、ここでは2倍の係数はありません。なぜなら、Reduce-scatterでは、覚えておいてください、あなたが行っているのは入力を一箇所に送信することだけだからです。Reduceについて考えてみると、すべての要素が一箇所に行くだけで、それで終わりです。
Scatterは、テンソルの異なるコンポーネントが異なる場所に行くことを意味するだけですが、実質的にはReduceのようなものです。同じ計算を行うと、この場合70を得ます。
なぜ他の数値ではなく正確に70なのかは正確にはわかりません。All-reduceでは一般的により多くのトラフィックが発生し、All-reduceは潜在的により最適化されていると推測できます。Nvidiaハードウェアには、実際にネットワーク内でこれらの計算の一部を行うSHARP加速のようなものがあり、2倍の係数を削減しますが、ここでの違いを完全に説明するかどうかはわかりません。
NCCLで起こる多くのことがあり、パフォーマンスについて正確に推論するのは少し困難です。したがって、ベンチマークが重要です。
質疑者:送信バイトとデータバイトの計算方法について、特に質問があります。出力に送信されるデータだけを計算しているように見えますが、入力についてはどうでしょうか?削減ステップはどうですか?削減を行うために入力をどのように取得するのか疑問に思います。
Percy Liang:これは出力のバイトだけで、入力についてはどうかという質問ですね。明確にするために、入力はすでにデバイス上にあると仮定しています。したがって、その時間をカウントしておらず、Reduce-scatterを実行するために何が起こる必要があるかをカウントしているだけです。
質疑者:これは単なるscatter操作ですか?
Percy Liang:これはReduce-scatter操作です。削減ステップが必要です。この関数はReduce-scatterを行います。それは一つの操作です。前回のAll-reduceでは2倍があったのは、削減してから再び展開する必要があったからだと言っているのですね。
Reduce-scatterについては、名前はReduce-scatterですが、実際には削減だけです。Reduce-scatterとAll-gatherを行うと、それぞれに2倍の係数がないので、それらを加算すると2倍の係数が得られ、これがAll-reduceが2倍であることを示す別の方法です。
これらの集合操作をベンチマークする方法について読むことができる参考文献もあります。
4.3 パフォーマンス分析と実験結果の考察
Percy Liang:実験結果を詳しく分析してみましょう。All-reduceでは約277ギガバイト毎秒を得ましたが、H100について先ほど900ギガバイト毎秒のようなものがあると主張しました。もちろん、知っているように、テンソルのサイズと正確なデバイス数、そして様々な要因によって、結果は変わります。天気ではありませんが、様々な要因によって変わります。
送信バイト数の計算における2倍の係数について説明しましょう。各rankに置かれているテンソルはsize bytesです。それをworld size - 1の他のマシンまたはrankに送信する必要があります。しかし、2倍の係数があります。なぜ2倍の係数があるのでしょうか?All-reduceを行っているからです。
すべての異なる要素を基本的に一箇所に送信する必要があることを覚えておいてください。それが合計され、その後みんなに戻る必要があります。したがって、rankは入力を送信し、出力を受信する必要があります。だから2倍の係数があるのです。
一方、Reduce-scatterでは70ギガバイト毎秒を得ましたが、送信されたバイト数には2倍の係数はありません。なぜなら、Reduce-scatterでは、あなたが行っているのは入力を一箇所に送信することだけだからです。Reduceについて考えてみると、すべての要素が一箇所に行くだけで、それで終わりです。
なぜAll-reduceが277で、Reduce-scatterが正確に70なのかは正確にはわかりません。All-reduceでは一般的により多くのトラフィックが発生し、All-reduceは潜在的により最適化されていると推測できます。Nvidiaハードウェアには、実際にネットワーク内でこれらの計算の一部を行うSHARP加速のようなものがあり、2倍の係数を削減しますが、ここでの違いを完全に説明するかどうかはわかりません。
NCCLで起こる多くのことがあり、パフォーマンスについて正確に推論するのは少し困難です。したがって、ベンチマークが重要なのです。
また、Reduce-scatterとAll-gatherを行うと、それぞれに2倍の係数がないので、それらを加算すると2倍の係数が得られ、これがAll-reduceが2倍であることを示す別の方法でもあります。
実際に得られる1秒あたりのギガバイト数を確認するために、常にベンチマークを行うことが良いということが、この実験から明確になります。
5. 分散学習戦略の実装
5.1 データ並列処理(DDP)の実装
Percy Liang:次に分散学習について話しましょう。我々の一般的なアプローチは、深層MLPで各戦略の最低限の実装を説明することです。MLPは一般的に計算のボトルネックであり、transformersでもattentionではないことを思い出してください。ある意味で、これは非常にシンプルなアーキテクチャですが、見るであろうワークロードのタイプを十分に代表しています。
データ、テンソル、パイプライン並列処理は、モデルまたはデータを切り分ける異なる方法として考えることができます。これを視覚的に描写したいと思います。
データ並列処理では、4つの層を持つモデルがあると仮定します。MLPの各層は行列乗算で、これが隠れ次元です。データも行列で、バッチ次元と隠れ次元があります。データ並列処理は、バッチ次元に沿って小さな部分に切り分けます。
各rankは異なるデータのスライスを取得します。例を見てみましょう。バッチサイズ128、隠れ次元1,024のサンプルデータを生成します。ランダムなデータを生成します。バッチサイズと次元は前述の通りです。
バッチサイズをworld sizeで割って、ローカルバッチサイズを取得します。これは与えられたrank上でのバッチサイズです。rankに基づいて、ローカルバッチサイズのサイズの開始および終了インデックスを決定し、対応するデータを取得します。基本的に、rank に基づいて行のサブセットを取得しています。
MLPをセットアップします。これは非常に基本的に行われています。MLPパラメータを作成します。各層は本質的に、次元数×次元数の行列で、次元数は1024です。オプティマイザを作成します。
この関数は、すべての異なるrankで非同期に実行されていることを覚えておいてください。4つのrankはそれぞれrank 0、1、2、3でこれを実行しています。
学習を開始します。ステップ数について、各層を通じて前方パスを行います。行列乗算、非線形性、行列乗算、非線形性。ここには4つの層があります。ロスを計算します。ロスが何であるかは実際には気にしません。作り物です。何か作ったものです。後方パスを行います。
これまでのところ、これはSGDを実装しているだけに見えますよね?それがポイントです。DDPを実装するための唯一の違いは、ワーカー間で勾配を同期するこの行を注入することです。
各層について、all-reduceを呼び出し、平均化し、平均化するものはparam.gradです。誰かのSGDコードを乗っ取って、「待って、後方パス後にすべての勾配を実際に混合するつもりです」と言っているだけです。
その後、通常通りパラメータを更新します。SGDの観点からは、何も起こっていないように見えます。SGDを実行しているだけですが、誰かが勾配を混合したのです。
いくつかのことを出力します。データ並列処理では、ロスを出力しています。注目すべき点は、すべての異なるrank間でロスが異なることです。なぜなら、異なるデータを持っているからです。しかし、all-reduce後、すべてのパラメータは同じです。
これは、MLセットアップでのall-reduceの教科書的な応用です。
質疑者:各rankがこれを実行するとき、all-reduceで同じステップにいることをどのように確実にしますか?
Percy Liang:すべてのプロセスが非同期で実行されている場合、例えば同じステップにいることをどのように確認するかという質問ですね。これは、all-reduceが同期ポイントだからです。みんなを停止してall-reduceを実行します。rankの一つにall-reduceが欠けている場合、ハングしてしまうので注意が必要です。
質疑者:初期パラメータの取得がrankに依存するのはなぜですか?
Percy Liang:初期パラメータの取得がrankに依存するのはなぜかという質問ですね。同じであるべきです。理由は、このコードが基本的に適切なGPU上に配置するからだと思います。
他に質問はありますか?DDPは課題2で実装するものです。おそらく一部の人は見たことがあるかもしれませんし、そうでないかもしれません。transformerのコンテキストで行われますが、これは最も基本的なバージョンなので、何が起こっているかを非常に明確に見ることができます。
それがDDPです。ロスはrank間で異なりますが、勾配はすべて同じになるように削減されます。したがって、すべてのrankのパラメータは同じです。実際には、world size個のSGD実行を行っていますが、同期されているため、同じことを行っています。
これは、時々物事を保存したくないために余分な計算を行う、活性化チェックポイントの類推のようなものと考えることができます。この場合、例えばオプティマイザ状態を送信することもできますが、オプティマイザパラメータを実際に移動するよりも、オプティマイザ状態を更新する方がはるかに高速なので、それは悪いアイデアでしょう。
5.2 テンソル並列処理の実装
Percy Liang:昨年はFSDPを試しましたが、それは少し複雑だったので、それをスキップしてテンソル並列処理を行います。ここでの図は、データは同じままにして、隠れ次元に沿ってモデルを切り分けることです。
各rankはすべての層を取得しますが、各層の一部のみを取得します。結果として、すべてのデータと活性化を転送することになります。
同じサンプルデータを生成し、テンソル並列処理を見てみましょう。以前と同様にバッチサイズと次元数がありますが、今度は以前はバッチサイズを切り分けていましたが、今度はnum_dimを切り分けます。local_num_dim = 1024 / 4 / world_sizeで、これは256です。
各モデルは本質的に、各rankがworld_sizeの1分の1のパラメータを取得するモデルの一部を取得します。我々が並列処理を行っている理由全体を思い出してください。これは、モデルが単一のGPUに収まらないからです。そのため、複数のGPU間で分割しています。
パラメータ行列は今やnum_dim × local_num_dimです。各rankは、全体の学習ループではなく、前方パスのみを実装します。
すべての層を通って行きます。活性化を最初に計算します。これは非常に普通に見えますが、活性化は実際に、各rankが活性化の一部のみを持っているため、batch_size × local_num_dimではなくnum_dimであることを覚えておいてください。
しかし、活性化を取得した後、通信する必要があります。ここで行う必要があることは、すべての活性化用のメモリを割り当てることです。この時点で、誰もがXを持っていますが、そのXは活性化の異なる部分を表しています。
batch_size × local_num_dim × world_size個を割り当てます。基本的に、各rankは、world_size個のbatch_size × local_num_dim行列を持つことになります。そして、all_gatherを行います。
すべての活性化を送信します。これは非常にシンプルです。Xは、覚えておいてください、batch_size × local_num_dimですが、Xは各rankで異なります。all_gatherを行うと、activationsに配置されます。これは本質的に、Xと同じ形状のworld_size個を持っています。
今、すべてのrankが同じ活性化を持っています。今、すべてのrankがモデル全体の活性化を持っています。それらを連結してXを取得します。XはbatchSize × num_dimです。これを繰り返します。
ご覧のように、かなりの通信が発生します。これが、テンソル並列処理には非常に高い相互接続が必要だとTatsuが言った理由です。さもなければ、これらの活性化を多く渡すことになります。
次の層、次の層でこれを行い、アイデアを理解してもらいます。出力を印刷します。テンソル並列処理の前方パスは、基本的にフルサイズの活性化を生成し、最終的にすべての人が同じ活性化を持ちます。
後方パスはスキップします。なぜなら、それは少し面倒だからです。
質疑者:なぜ後方パスが難しいのですか?
Percy Liang:なぜ後方パスが難しいのかという質問ですね。必ずしも難しいとは思いませんが、限られた時間とスペースでは、難しくありません。ただ、もう少し作業が必要です。
5.3 パイプライン並列処理の実装
Percy Liang:次にパイプライン並列処理について説明します。この場合、モデルを層別に切り分けています。すべてのrankがすべてのデータを取得し、各rankがすべての一つの層を取得しますが、異なる層を取得します。
データをサンプリングして、すべてのrankに対してこのプログラムのこの関数を実行します。ここで、各rankに何層が入るかを決定します。この場合は2つです。4層のネットワークがあり、2つのrankがあるので、各rankは2つの層を取得します。この図のようにです。
必要な層のパラメータのみを割り当てます。前方パスを行います。素朴に行うとTatsuが前に話したパイプラインバブルが発生するという、さらなる最適化があることを覚えておいてください。それを緩和する一つの方法は、バッチをマイクロバッチに分割することです。
ここでは、このバッチをサイズ32のバッチに分割します。32サイズの4つのバッチです。アイデアは、すべてのrankが本質的に前のrankが活性化を渡すのを待ち、その層を適用して、次のrankに転送するということです。
基本的なケースから始めて、rank 0があります。それは単にデータです。データを多くのマイクロバッチにチャンク化します。各マイクロバッチを通って、まずテンソルを受信します。今度は集合プリミティブの代わりにpoint-to-pointプリミティブを使用しています。本質的にテンソルXを受信し、このrankに割り当てられた層を計算します。
この場合、それらは2つだけです。次のrankに送信します。sendはpoint-to-point操作です。次のバッチでも同じことを行います。これをスキップします。
基本的にそれだけです。パイプライン並列処理の非常に素朴なバージョンは、概念的に比較的シンプルです。Tatsuが前回述べたように、この基本実装から欠けているものがたくさんあります。
通信と計算のオーバーラップは、ここでは全く行っていません。例えば、receiveとsendは同期的ですが、実際には非同期にすべきです。また、前方を行う順序について、これは前方のみですが、後方もある場合、前方と後方のステップをどのようにインターリーブするかを考える必要があります。
質疑者:先ほど言及した非同期について、実際にはGPUが他のものから何かが渡されるのを聞いているようなもので、イベント駆動のようなもので、前の層がそれを通すときにのみ処理を開始するということでしょうか?
Percy Liang:これがイベント駆動プログラミングのようなものかという質問ですね。イベント駆動プログラミングでは、基本的にハンドラーを書き、何かが起こったとき、マウスクリックやファイル準備イベントが発生したときに、コードの一部が実行されます。
これは、私が思うに、このスタイルのコーディングとは大きく異なります。すべてが歩調を合わせて動作する必要があります。確かに前のrankから情報を送ってもらうのを待っているのは事実ですが、少なくともこの実装では、どこから来るかの柔軟性はありません。任意のデータがどこからでも来るのを待っているわけではありません。
非同期学習を行う方法があると思います。これは10年以上前にかなり人気がありました。データを送信するサーバーがあり、勾配が準備できるといつでもアップロードし、勾配が蓄積され、ワーカーが死んでもより堅牢に処理されるというものです。
しかし、現代の学習では、かなりスケールアップしているにも関わらず、すべてが同期パラダイムにあるようです。
質疑者:この通信と計算のオーバーラップのためにこのプログラムをどのように変更しますか?
Percy Liang:通信と計算をオーバーラップするためにこれをどのように変更するかという質問ですね。例えば、これを送信するとき、データが送信されるのを待つ理由はありません。送信を開始するだけです。送信は実際にはカーネル起動を通じてGPU上で実行されることを覚えておいてください。それは独立です。すぐに他のマイクロバッチを処理できます。
私が思う方法は、非同期のsendという別の関数があり、ハンドルを返すので、基本的にすべての送信を行い、最後にすべての送信が完了するのを待つということです。後方ステップがある場合の重複については、基本的にここでスケジュールする必要があります。
質疑者:複数の送信、複数の受信がある場合、どれがどれかをどのように知るのですか?
Percy Liang:複数の送信と受信がある場合、どれがどれかをどのように知るかという質問ですね。ここでテンソル名は重要ではなく、そこにある変数で、指定しているのはソースです。ノードにいて受信している場合、そのrankから来る次のメッセージを、このxに入れて実行を続けます。
質疑者:同じrankから2つの送信を行いたい場合はどうですか?
Percy Liang:同じrankから2つの送信を行いたい場合ですね。確信はありませんが、2つの送信がある場合、ストリームに入れられると思います。送信のペアがある場合、その順序は保持されますが、他のrankが別のrankに送信している順序は、いつでも起こり得ます。
質疑者:送信してもそれを受信する人がいない場合はどうなりますか?
Percy Liang:送信しても誰も受信しない場合、停止して待機すると思います。プロセスは実行され続け、そこに到達しないのか、時間の問題なのかわからないからです。
質疑者:最後のrankはどうなりますか?
Percy Liang:最後のrankには、最終的にすべての活性化があります。それが基本的に完全な前方パスの結果です。後方パスを実装すると、実際に損失に関する勾配を計算し、rank n-1に送信していくことになります。
6. 高度なトピックと今後の展望
6.1 JAXとTPUエコシステムの紹介
Percy Liang:時間が足りなくなると思っていましたが、実際には時間があります。来年は後方パスもやるべきかもしれません。これまで、深層MLPに対するデータ、テンソル、パイプライン並列処理の3つのシンプルな例を見てきました。
もちろん、これはシンプルなMLPのためのものです。実際にはTransformerのような、より高度なモデルでこれを行いたいでしょう。少なくとも中核的なアイデアは、MLPを通じて理解できると主張しました。しかし、もちろん学習したい場合は、深層MLPではなくTransformerを学習したいでしょう。したがって、完全な複雑さを実装する必要があります。
また、欠けているのは通信と計算のオーバーラップで、これはここでは非常に慎重に扱われていません。一般的に、より複雑なブックキーピングを伴うコードがあります。Megatron LMやPyTorchのFSDPなどをチェックすることをお勧めします。それはかなり複雑になります。
ブックキーピングの少なくともFSDPについて、課題A2で少し触れることになりますが、複雑にする要因の一つは、任意のアーキテクチャを処理するものが必要な場合、パラメータを把握し、それらがどの層であるかを把握するためのブックキーピングを行う必要があることです。一方、MLPの場合、モデルをこの特定のシンプルな方法で分割するという決定をしました。
脇道として言及したいもう一つのことは、このコースで行っていることはすべてPyTorchですが、JAXとTPUを中心とした、この全く別のエコシステムを知っておくことが有用です。これは実際にある意味で素晴らしいものです。
ここでのアイデアは、JAXがモデルを定義し、シャーディング戦略を定義すると、JAXコンパイラが残りを処理するということです。我々が開発したLevanter と呼ばれるJAXベースのツールキットがあります。FSDPを10行のコードで示すスニペットをお見せします。
基本的にモデルがあり、この特定の方法でシャードすると言います。正確に読むことは期待しませんが、基本的にどの次元でシャードするかを定義し、それで終わりです。テンソル並列処理でも同様に、attentionのhead次元でモデルをシャードし、モデル次元でもシャードできると言うだけです。
ある意味で、これは行おうとしていることの概念的なシンプルさを提供します。基本的に計算グラフがありますが、モデル次元、埋め込み次元、attention sequence次元というこれらの次元があります。JAXは、どの次元で切り分けたいかを指定し、実際のTPUへのマッピングを定義することができ、その後JAXコンパイラが魔法のように、物事を移動するプリミティブにコンパイルする方法を把握します。
これは、集合通信で操作するよりもはるかに高レベルです。しかし、我々はPyTorchに留まっています。なぜなら、実際に何が起こっているかをフードの下で見ることができるからです。しかし、実際の世界でこれを行っている場合、明らかにすべてをゼロから実装する必要はなく、おそらくすべきではありません。
これがJAXの脱線でした。まとめると、これまで並列化する多くの方法を見てきました。これらの並列化方法はそれぞれ、データのバッチ次元、幅次元、深さ次元、またはコンテキスト長次元のいずれかに沿って、モデルまたはデータを分割する方法として考えることができます。
6.2 通信と計算のオーバーラップ最適化
Percy Liang:通信と計算のオーバーラップについて詳しく説明しましょう。これは我々の実装では非常に慎重に扱われていない部分です。
パイプライン並列処理での例を考えてみましょう。送信を行うとき、データが送信されるのを単純に待つ理由はありません。送信を開始するだけです。送信は実際にはカーネル起動を通じてGPU上で実行されることを覚えておいてください。それは独立しているので、すぐに他のマイクロバッチを処理できます。
私が思う実装方法は、非同期のsendという別の関数があり、ハンドルを返すので、基本的にすべての送信を行い、最後にすべての送信が完了するのを待つということです。そして、後方ステップがある場合のオーバーラップについては、基本的にここでスケジュールする必要があります。
質疑者:複数の送信、複数の受信がある場合、どれがどれかをどのように知るのですか?
Percy Liang:複数の送信と受信がある場合、どれがどれかをどのように知るかという質問ですね。ここでテンソル名は重要ではなく、そこにある変数で、指定しているのはソースです。ノードにいて受信している場合、そのrankから来る次のメッセージを、このxに入れて実行を続けます。
質疑者:同じrankから2つの送信を行いたい場合はどうですか?
Percy Liang:同じrankから2つの送信を行いたい場合ですね。確信はありませんが、2つの送信がある場合、ストリームに入れられると思います。ペアで2つの送信を行う場合、その順序は保持されますが、他のrankが別のrankに送信している順序については、いつでも起こり得ます。
実際に後方ステップがある場合、前方と後方のステップをどのようにインターリーブするかを考える必要があります。receiveとsendは現在同期的ですが、実際には非同期にすべきです。これにより、通信が進行している間に計算を続行でき、全体的なパフォーマンスが向上します。
この最適化が欠けていることが、この基本実装から欠けている多くのもののうちの一つです。Megatron LMやPyTorchのFSDPなどの実際の実装では、これらの最適化が実装されており、それがコードをかなり複雑にする理由の一つでもあります。
6.3 専用ハードウェアの発展とその影響
Percy Liang:我々は、これらの並列化の様々な方法を見てきました。また、再計算という繰り返しのテーマも見てきます。何かを最初から再計算するか、メモリに保存してデータ転送コストに苦しむか、または今やマルチGPU、マルチノード設定では、実際に別のGPUのメモリに保存してから通信することもでき、これはさらに遅くなります。これらのトレードオフがあります。
多くの場合、再計算は実際により良い選択肢ですが、明らかに全体を再計算することはできません。多くの場合、通信制限またはメモリ制限のいずれかです。
最後に言いたいのは、ハードウェアが良くなっているということです。5年後にはすべてがL1 HBMに収まるので、これらはすべて必要ないかもしれないと思うかもしれません。これは起こりません。なぜなら、それらはかなり成長するかもしれませんが、物理的な限界があるからです。我々は常に、ハードウェアができることの限界にあるより大きなモデルを最終的に扱うことになります。
質疑者:GPUは特殊なハードウェアやより専門化されたものに置き換えられるのでしょうか?
Percy Liang:GPUがtransformer特化ハードウェアに置き換えられるかという質問ですね。推論スペースでは、GroqやCerebrasが推論、そして訓練もできる専用ハードウェアを持って、すでにかなり見られています。Cerebrasは訓練も行います。
基本的に、これらのハードウェアは本質的にはるかに多くのオンチップメモリを提供します。それが基本的にゲームの名前です。Cerebrasには巨大な、本質的に効果的にL1キャッシュのようなものがあるので、物事をオフチップに移動する必要がありません。多くの簡素化が可能になると思います。
GPUには実際に多くの歴史的負担があります。分岐や様々なタイプのアドホック計算を多く行う必要があった時代に設計されたためで、これらは深層学習の領域では実際には必要ありません。したがって、ハードウェアも改善する多くの機会があると思います。
質疑者:前の質問について、モデル特化ハードウェアに関する前の質問で、おそらく技術的な理由でノードをそれほど大きくできない物理的な技術的理由があるのでしょう。どんな進歩が話されているのでしょうか?
Percy Liang:GPU用の物理的制限が確実にあるという質問ですね。GPUを明らかに無限に大きくしたり、無限に密にしたりすることはできません。電力の問題もあり、すべての熱を取り除く必要があり、収まる帯域幅にも限界があります。
正確な詳細はわかりませんが、少なくともCerebrasの場合、基本的にメモリをチップ上に配置する製造方法を持っています。それを配置する方法だと思います。柔軟性がないというコストを伴う明らかなトレードオフがあります。
しかし、一般的に、より広く考える方法は、GPUはまだCPU時代に開発されたもので、はるかに制御重視で、実行しているコードが第一級市民のようなもので、データはコードを処理するために移動される必要があるということです。
しかし、深層学習ワークロードとの大きな違いは、これらはすべてデータフローのようなもので、計算グラフを見ると静的です。基本的に訓練の終わりまで、最初から行われるすべての計算を正確に知っています。
その知識を使用して、すべてのアドホック計算の不確実性に対処するよりも、はるかに賢い方法で計算をレイアウトできるはずです。この階層構造は、コンピュータシステムが存在して以来、常に我々と共にあり、常にそこにあるでしょう。
7. 質疑応答セッション
7.1 実装に関する技術的質問
質疑者:同じパラメータセットでも、正規化がデータセット全体の関数である可能性があるため、例えばBatchNormでコストが異なる可能性があります。
Percy Liang:データ並列処理において、パラメータがすべて同期されていても、BatchNormのようにデータに依存する他のものがあるため、異なる結果が生じる可能性があるという質問ですね。BatchNormは常に少し厄介なので、正確にどのように行うかは、正直なところわかりません。
少なくとも言語モデルの世界では、それは実際には現れません。Layer normが使用されるからです。すべてのパラメータを初期化し、同じランダムシードを使用している限り、問題ありません。GPU上で非決定性の問題があるかもしれませんが、それらは軽微であることを願っています。
質疑者:PyTorchにもJAXが提供するような機能はありますか?
Percy Liang:PyTorchにもJAXが提供するような機能があるかという質問ですね。PyTorchには、このクラスを受講していない場合は絶対に使用すべきFSDPライブラリがあります。これは基本的にラッパーです。任意のモデルを定義すると、それに対してFSDPを実行します。
より多くのカスタムシャーディングを可能にするものがあるかどうかについては、いくつかのものが来ていると思いますが、それほど開発されていないと思います。JAXの世界では、宣言的に物事を定義し、JAX TPUシステム内に留まる場合、Googleのインフラストラクチャは非常によく開発されていると思います。
しかし、実際に悪い相互接続を持つGPUを持つDeepSeekのような反対の端を見ると、彼らは実際にNCCLレベルに入り、パフォーマンスを引き出すために理解していない多くのことを行う必要があります。一方、JAXを書いている場合、高いレベルからモデルを宣言し、その後物事が起こります。
ハードウェアを活用する方法は、操作しているエコシステムに本当に依存していると思います。
質疑者:アプリケーションで活性化を再計算できる部分はありますか?どの部分を再計算するかを指定できるAPIがありますか?
Percy Liang:活性化チェックポイントについて、どの部分を再計算するかを指定できるAPIがあるかという質問ですね。PyTorchとJAXの両方で、どの部分を再計算したいかを指定できるAPIがあります。明らかにすべてを再計算したくも、何も再計算したくもありません。
おそらく数層ごと、大きな行列乗算の直後のように、例えば行列乗算と点的線形性がある場合、些細に2つのコピーを保存する必要はないと思います。2つのもので、一方から他方に到達するのが些細であれば、一つのバージョンを保存するだけでよいかもしれません。
7.2 ハードウェアの物理的制約について
質疑者:モデル特化ハードウェアについて、おそらくノードをそれほど大きくできない物理的な技術的理由があるのでしょう。実際にどのような技術的進歩が話されているのでしょうか?
Percy Liang:GPU用の物理的制限が確実にあるという質問ですね。GPUを明らかに無限に大きくしたり、無限に密にしたりすることはできません。電力の問題もあり、すべての熱を取り除く必要があり、収まる帯域幅にも限界があります。
正確な詳細はわかりませんが、少なくともCerebrasの場合について説明します。彼らは基本的にメモリをチップ上に配置する製造方法を持っています。それを配置する方法だと思います。明らかなトレードオフがあり、柔軟性がないというコストを伴います。
しかし、一般的に、より広く考える方法があります。GPUはまだCPU時代に開発されたもので、はるかに制御重視でした。実行しているコードが第一級市民のようなもので、データはコードを処理するために移動される必要がありました。
しかし、深層学習ワークロードとの大きな違いは、これらはすべてデータフローのようなものだということです。計算グラフを見ると静的です。基本的に訓練の終わりまで、最初から行われるすべての計算を正確に知っています。
その知識を使用して、すべてのアドホック計算の不確実性に対処するよりも、はるかに賢い方法で計算をレイアウトできるはずです。
質疑者:計算グラフは通常CPUまたはGPUに保存されますか?
Percy Liang:計算グラフがどこに保存されるかという質問ですね。このコードはすべてCPU上で動作しています。しかし、PyTorch関数のようなものを呼び出すとき、それがGPU上で実行される必要がある場合、内部でカーネルを起動し、カーネルはGPU上で実行されるコードです。
計算グラフはより概念的なものだと思います。グラフが文字通りGPUに配置されるわけではありません。
質疑者:これらの通信プリミティブは、実際にはCPU命令なのか、それともGPUを使用するプログラムなのでしょうか?
Percy Liang:通信プリミティブがCPUかGPUかという質問ですね。これらの集合操作は、ある意味で何種類の操作が起こる必要があるかの抽象的な仕様です。PyTorch distributedが異なるバックエンドを持っていることを覚えておいてください。GPU上でもCPU上でも起こり得ます。
しかし、CPUで起こっているとき、CPUがそれらをスケジューリングしているのか、それとも独立したカーネルなのでしょうか?CPUは基本的にまだマスターであり、集合操作を行うとNCCLライブラリを呼び出し、これはCPUですが、データを移動するいくつかのカーネルを起動します。
7.3 将来のハードウェア発展の可能性
Percy Liang:最終的に言いたいのは、ハードウェアが良くなっているということです。5年後にはすべてがL1 HBMに収まるので、これらはすべて必要ないかもしれないと思うかもしれません。しかし、これは起こりません。なぜなら、それらはかなり成長するかもしれませんが、物理的な限界があるからです。我々は常に、ハードウェアができることの限界にあるより大きなモデルを最終的に扱うことになります。
この階層構造は、コンピュータシステムが存在して以来、常に我々と共にあり、常にそこにあるでしょう。
質疑者:GPUは特殊なハードウェアやより専門化されたものに置き換えられるのでしょうか?
Percy Liang:GPUがtransformer特化ハードウェアに置き換えられるかという質問ですね。推論スペースでは、GroqやCerebrasが推論、そして訓練もできる専用ハードウェアを持って、すでにかなり見られています。Cerebrasは訓練も行います。
基本的に、これらのハードウェアは本質的にはるかに多くのオンチップメモリを提供します。それが基本的にゲームの名前です。Cerebrasには巨大な、本質的に効果的にL1キャッシュのようなものがあるので、物事をオフチップに移動する必要がありません。多くの簡素化が可能になると思います。
GPUには実際に多くの歴史的負担があります。分岐や様々なタイプのアドホック計算を多く行う必要があった時代に設計されたためで、これらは深層学習の領域では実際には必要ありません。したがって、ハードウェアも改善する多くの機会があると思います。
質疑者:このコンテキストで、増分的にモデルを訓練するために、新しい訓練データを取得したときに、ファインチューンするだけでなく、すべてを再計算することなく実際にすべてを再統計するために、我々が話している技術のいずれかを使用できますか?
Percy Liang:これらの技術を継続訓練に使用できるかという質問ですね。絶対にできます。我々が扱っている単位は、勾配ステップを行うことだからです。ハーフトレインチェックポイントを取得した場合、これを続けることができます。ゼロから始めることに特有なものはありません。