※本記事は、Stanford CS336 Language Modeling from Scratch | Spring 2025 | Lecture 10: Inferenceの講義内容を基に作成されています。講義の詳細情報やコースの履修については https://online.stanford.edu/courses/ および https://stanford-cs336.github.io/spri... でご覧いただけます。本記事では、講義の内容を要約・整理しております。なお、本記事の内容は原講義の見解を正確に反映するよう努めていますが、要約や解釈による誤りがある可能性もありますので、正確な情報や文脈については、オリジナルの講義動画(https://www.youtube.com/watch?v=fcgPYo3OtV0 )をご視聴いただくことをお勧めいたします。Stanford大学のオンライン人工知能プログラムについては https://stanford.io/ai をご参照ください。
登壇者紹介
Percy Liang スタンフォード大学コンピューターサイエンス学部准教授 Center for Research on Foundation Models (CRFM) ディレクター
Tatsunori Hashimoto スタンフォード大学コンピューターサイエンス学部助教授
本講義は、Stanford Onlineを通じて提供されており、Stanford Engineering Center for Global & Online Education (CGOE)の運営・管理の下、スタンフォード大学全体の学部・部門と連携し、世界的な規模でオンライン教育を展開する取り組みの一環として実施されています。
1. イントロダクション:推論の重要性
1.1 推論とは何か
これは第10回講義となります。私たちはスケーリング則から少し離れて、推論について話をしていきます。推論の問題は非常にシンプルなものです。訓練済みの固定されたモデルが与えられた時に、プロンプトに対してレスポンスを生成するというものです。
まず、推論が何を意味するのか、そしてそれが伴うワークロードについて理解することから始めます。その後、推論をより高速化する方法について話をしていきます。この講義を通して、推論は実際に非常に深いトピックであることがお分かりいただけるでしょう。実際、昨年は講義で推論を扱いませんでした。今年が初めて推論を扱う年です。しかし実際には、複数の講義にまたがるほど多くのトピックがあり、それを一つの講義に圧縮しようと思います。
推論は複数の異なる場所で現れます。最も明白な場所は、実際にモデルを使用したい場合です。チャットに使用したり、Cursorなどを使ってコード補完を行ったり、言語モデルを使用してバッチデータ処理ジョブを実行したりする場合です。これらすべてのケースでは、実際のモデルからトークンを生成する必要があるため、推論を要求します。
しかし、推論は他の文脈でも現れます。モデルを評価したい場合、例えば指示に従うことについて評価する場合でも、推論を行う必要があります。テスト時計算への多くの関心があります。これは、実際に最終的な答えを出力する前により多く考えるということを意味します。そしてこれもまた、より多くの推論です。なぜなら、考えることは基本的にトークンを生成することだからです。
最後に、訓練そのものでさえも推論を必要とします。強化学習を使用している場合、レスポンスをサンプリングし、それを何らかの報酬に基づいて評価する必要があり、それもまた推論を必要とします。そのため、推論はチャットボットのデモを立ち上げたいというだけではありません。推論は実際に言語モデルの多くの基本的な機能の基盤となるものです。
一つの講義であっても、多くのことにとって推論がいかに重要であるかを強調したいと思います。そして、クラスの後半でアライメントについて話す際に、おそらくこのトピックに戻ってくることでしょう。
1.2 推論が必要な場面
推論が重要であることは明らかです。そこで、このクラスのテーマは効率性です。そして効率性は明らかに重要です。訓練は一回限りのコストですが、推論は複数回繰り返すものです。
推論が大きな問題である理由について、いくつかの逸話的な統計をお話ししましょう。サムは、OpenAIが1日に1000億語を生成していると述べています。これは非常に多い数字です。そして、それほど新しい製品ではないCursorでさえ、1日に10億行の承認されたコードを生成していると言われています。これは、推論がどれだけ多くを占めているかという概念を与えてくれます。そして、訓練と比較した推論のコストは間違いなく増加しています。
推論は複数の異なる場所で現れます。最も明白な場所は、実際にモデルを使用したい場合です。チャットに使用する場合、Cursorなどを使ってコード補完を行う場合、言語モデルを使用してバッチデータ処理ジョブを実行する場合などです。これらすべてのケースでは、実際のモデルからトークンを生成する必要があるため、推論を要求します。
しかし、推論は他の文脈でも現れます。モデルを評価したい場合、例えば指示に従うことについて評価する場合でも、推論を行う必要があります。テスト時計算への多くの関心があります。これは、実際に最終的な答えを出力する前により多く考えるということを意味します。そしてこれもまた、より多くの推論です。なぜなら、考えることは基本的にトークンを生成することだからです。
最後に、訓練そのものでさえも推論を必要とします。強化学習を使用している場合、レスポンスをサンプリングし、それを何らかの報酬に基づいて評価する必要があり、それもまた推論を必要とします。そのため、推論はチャットボットのデモを立ち上げたいというだけではありません。推論は実際に言語モデルの多くの基本的な機能の基盤となるものです。
1.3 推論の重要性とコスト
推論が重要であることは明らかです。そこで、このクラスのテーマは効率性です。そして効率性は明らかに重要です。訓練は一回限りのコストですが、推論は複数回繰り返すものです。
推論が大きな問題である理由について、いくつかの逸話的な統計をお話ししましょう。サムは、OpenAIが1日に1000億語を生成していると述べています。これは非常に多い数字です。そして、それほど新しい製品ではないCursorでさえ、1日に10億行の承認されたコードを生成していると言われています。これは、推論がどれだけ多くを占めているかという概念を与えてくれます。そして、訓練と比較した推論のコストは間違いなく増加しています。
多くの人々が推論を行っています。実際に製品やプラットフォームを持っている人は誰でも、大型モデルでこれらのコストを実行することが増加することを素早く認識し、その時間を短縮するために多くの時間とエンジニアリング努力を費やします。そのため、クローズドモデルを提供するプロバイダーとオープンウェイトモデルを提供するプロバイダーの両方が、推論に多くの注意を払います。これは、実際にモデルを提供していない平均的な学術研究者よりもはるかに多いと思います。私たちは単に訓練を行い、スコアを得て、論文に掲載するだけです。しかし、実際にモデルを提供している人々は推論に多くの注意を払います。
また、見てみると興味深いオープンソースパッケージも多数あります。
2. 推論の評価指標
2.1 Time-To-First-Token (TTFT)
では、良い推論がどのようなものかをどのように測定するのでしょうか。まず、Time-To-First-Token(TTFT)があります。これは、個々のユーザーが何らかの生成が全く起こる前に待つ必要がある時間です。そして、これは明らかにインタラクティブなアプリケーションにとって重要です。大きなプロンプトがあって、そこで10秒間待たなければならない場合、それは良いユーザー体験ではないかもしれません。
2.2 レイテンシとスループット
レイテンシは、おそらく最初のトークンの後にトークンがどれくらい速く到着するかということです。これもまたインタラクティブなアプリケーションにとって重要です。スループットは少し異なるものです。スループットは、全体的にどれだけ多くのトークンが生成されるかということで、全体のユーザーに対してではありません。そのため、これは特にバッチ処理アプリケーションで有用です。
スループットについて考えると、高いスループットは低いレイテンシを意味しないということが言えます。なぜなら、いくつかのリクエストは非常に長い時間がかかる可能性があり、それでも高いスループットを持つことができるからです。レイテンシは個々のユーザーの最悪のケースのようなものです。
2.3 推論と訓練の違い
推論の効率性について考える時に何を考慮する必要があるでしょうか。訓練においては、少なくとも教師あり訓練では、すべてのトークンを見ることができるという重要なアイデアがあります。これは、シーケンス全体で並列化できることを意味します。これはTransformerで大いに活用されています。
皆さんはTransformerの訓練を行ったことがあるでしょう。基本的に、シーケンス全体でこれらのテンソルを構築することを知っています。そして、それはテンソルの行列乗算のようなもので、その後出力を得ます。しかし、推論、少なくともTransformerにとっての重要な定義的特徴は、順次生成しなければならないということです。並列化することはできません。なぜなら、トークンの生成は過去のすべてに依存するからです。
これが推論をはるかに困難にする重要な要素となります。特に、利用可能なすべての計算を活用することが困難になります。そして、後で詳しく見るように、メモリ制限になります。
多くの人々が推論を行っています。実際に製品やプラットフォームを持っている人は誰でも、大型モデルでこれらのコストを実行することが増加することを素早く認識し、その時間を短縮するために多くの時間とエンジニアリング努力を費やします。そのため、クローズドモデルを提供するプロバイダーとオープンウェイトモデルを提供するプロバイダーの両方が、推論に多くの注意を払います。これは、実際にモデルを提供していない平均的な学術研究者よりもはるかに多いと思います。私たちは単に訓練を行い、スコアを得て、論文に掲載するだけです。しかし、実際にモデルを提供している人々は推論に多くの注意を払います。
また、見てみると興味深いオープンソースパッケージも多数あります。
3. 推論ワークロードの理解
3.1 Transformerの数学的復習
推論ワークロードを詳細に理解したいと思います。そこで、課題1で行い、クラスの最初の週で少し話したこのTransformerの数学を簡単に復習します。
これはScaling jax-mlブックからのもので、皆さんには本当に見ていただきたいものです。ここで多くの重要な概念を優れた方法で概説していると思います。そして、入力を取り、それをattentionとMLPレイヤーを通して処理する計算グラフを本質的に示すこの非常に良い図があります。
特に、この記法を使用しますので、これを簡単に復習しましょう。Bはバッチ内のシーケンス数です。Lはレイヤー数です。Tはシーケンス長です。これは、生成するトークン数やクエリで使用するトークン数と考えることができます。Sもシーケンス長ですが、プロンプトで条件付けするトークン数です。
Vは語彙です。Dはモデルの次元です。FはMLP隠れ次元で、通常はDの4倍です。Hはattentionヘッド次元です。Nはクエリヘッド数です。一般的に、N×HはDに等しく、GQA(Group Query Attention)では、クエリヘッドとは異なる数のキー・バリューヘッドがあります。通常、KはNより小さく、Gはグループ数です。K×GはNに等しくなります。
この図は、Xを取り、QKV行列を通して供給し、多くのことを行うことを示しています。フィードフォワードパスに必要なFLOPsは、6×トークン数(B×T)×パラメータ数であることを覚えておいてください。さらに、attentionにはTのもう一つの次数があります。T×TはT²の依存性です。
3.2 算術強度(Arithmetic Intensity)の概念
また、算術強度も復習しましょう。これは、何かが計算制限なのかメモリ制限なのかを特徴づけるのに役立ちます。基本的な行列乗算から始めましょう。行列XをB×D、行列WをD×Fとします。この計算に色をつけるために、Bはバッチサイズ、Dは隠れ次元、FはMLPの投影行列です。
では、X×Wを行うためのFLOPs数とメモリの読み書き数を数えてみましょう。0で初期化することから始めます。これに対して行う必要があることは、HBMからXを読み取ることです。つまり、すべてがbf16であると仮定すると、2×B×Dのメモリコストが発生します。また、Wも読み取ります。つまり、2×D×Fです。
その後、行列乗算を行い、2×B×D×FのFLOPsが発生します。覚えておいてください、これは最初の講義からのものです。これが復習であることを願っています。そして、それを書き戻す必要があり、これは別の転送を支払う必要があります。
FLOPsの総数は行列乗算だけです。転送されるバイト数は、基本的に読み書きされるすべての行列のサイズです。算術強度は基本的にその比率です。つまり、比率はこの式になります。
一般的に、物事を少し単純化するために、通常、バッチサイズはDやFよりもはるかに小さいです。Bは100かもしれませんが、DやFは数千や数万かもしれません。ここでは、愚かな間違いを避けるためにSymPyを使用しています。基本的に、Cを無限大にし、DをC×Bとしてスケールし、FをC×Bとしてスケールします。そして、それはBの簡略化された式を得ます。つまり、算術強度はこの特定の行列乗算に対してBです。
これを解釈する方法は、転送された1バイトあたりどれだけのFLOPsが実行されるかということです。
3.3 行列乗算の例:メモリ制限 vs 計算制限
次に、第二部では加速器を見ます。H100の場合、秒あたりのFLOPsは989テラフロップス、メモリ帯域幅は秒あたり3.3バイトです。これを割ると、加速器強度と呼ばれるものが得られます。
計算強度、つまりBを見ると、それが加速器強度より大きい場合、それは計算制限であることを意味します。つまり、すべてのGPUやTPUを使用できるということです。それより小さい場合は、メモリ制限となり、これは悪いことです。この行列乗算の場合、H100でBが295より大きい場合、計算制限になります。
これはすべて少し理想化されています。実際の詳細については、これは一次近似を与えています。極端なケースでは、一般的に、バッチサイズが例えば300の場合、GPUを飽和させることができるということです。しかし、バッチが本当に小さい場合はどうなるでしょうか。
特に、B=1の場合、これは基本的に行列ベクトル積に対応します。その時、算術強度は基本的に1です。そして、それは本当に、本当に悪いことです。つまり、メモリ制限になるということで、これは基本的に理にかなっています。なぜなら、基本的にこのD×F行列を読み書きしているのに、本質的に同じ数のFLOPsを実行しているからです。
FLOPsと読み取りの比率は同じで、それが1になり、1は悪いのです。メモリ読み取りは遅いので、メモリ読み取りあたり多くのFLOPsが実行されることを望みます。しかし、これは本質的に生成で起こることです。トークンごとに進むため、基本的に算術強度は1になることがわかります。
そして、それが生成がメモリ制限になり、計算制限にならない理由です。これは、生成が遅くなる理由の核心に迫る非常に単純な例だと思います。
4. 推論における算術強度の分析
4.1 MLPレイヤーの算術強度
要約すると、行列乗算は中核的な計算です。私たちは行列乗算を研究し、それが必要とするFLOPs数と読み書き数を数えました。そして、その比率、つまり加速器強度が、この場合はバッチ次元である次元の一つに依存することを示しました。
そして、それが大きな行列が良い理由です。なぜなら、それは計算を飽和させることができるからです。一方、たとえ薄い行列、B=1でも、それは本当に悪いです。なぜなら、メモリから読み取るのに多くの時間を費やし、それほど多くの計算を行わないからです。
では、推論の算術強度について話しましょう。推論がどのようなものかについて、より詳細に入っていきましょう。
まず、TransformerのFLOPsとメモリI/Oの両方を計算しましょう。MLPレイヤーとattentionレイヤーに分けて考えます。記法として、条件付けしているトークン数、つまりプロンプトの長さをSとして、この計算を行います。そして、生成またはクエリで使用しているトークン数をTとします。プリフィルでは、TはSになります。なぜなら、Tトークンを生成しているのではなく、これらの各トークンを使ってクエリしているからです。生成では、Tは単に1です。
行列乗算が頭に新鮮に残っていることを願います。なぜなら、これは本質的にそれですが、Transformerなので少し複雑になるからです。
FLOPsと生成されるバイト数を数えていきます。まず、T×D行列であるXを取ります。多分これらのTはSであるべきだと思いますが、とにかく、それは多くの転送を含み、基本的にその行列のサイズに2をかけたもので、bf16だからです。
その後、3つの行列があります。up投影、gate、down投影です。これらはすべて転置まで同じサイズです。だから、それらを転送する必要があります。その後、up投影を行います。それはいくつかのFLOPs数です。B×D×T×Fです。2つのテンソルを掛ける時は基本的に、縮約次元は一度だけカウントされますが、他の次元は単に集まります。
それを書き出す必要があります。また、同じことであるgateもあります。それを書き出します。非線形性を計算します。いくつかのものを掛け合わせ、down投影を行います。そして、それはB×T×D×Fで、基本的に同じFLOPs数で、結果を書き出します。
カウントを見ると、多分結果をチェックできます。実際、これはSymPyなので正しいことが保証されているため、チェックする必要はありません。しかし、再び、B×DがDやFよりもはるかに小さいと仮定します。そして、強度がB×Tであることが分かります。これは、算術強度、高くしたいものが、バッチの大きさと本質的に生成しているトークン数に依存する行列乗算の場合と類似しています。
2つの段階を見ると、プリフィルでは、生活は良好であることを覚えておいてください。なぜなら、BTを十分大きくすることができるからです。バッチサイズを使用します。バッチサイズが1でも、十分長いシーケンスがあれば実際に大丈夫かもしれません。それは問題ではありません。
生成では、一度に一つのトークンを生成しているため、少し困難になります。つまり、Tは1です。Tが1の場合、BTが大きくなるためにはBが大きくなる必要があります。Bは本質的に同時リクエスト数です。これは興味深いことです。なぜなら、効率は大きなバッチサイズを持つことに依存するからです。直感的には理にかなっています。多くのリクエストを取り、それらをまとめてバッチ処理できれば、少なくともスループットでより良い効率を得ることができます。
しかし、これはBが何であるかにも依存します。一度に少数のリクエストしか取得していない場合、ハードウェアを非常に効率的に使用することはできません。そして、これは推論の非常に動的な側面について語っており、講義の後半で戻ってきます。
4.2 Attentionレイヤーの算術強度
では、attentionはどうでしょうか。実はattentionは、これから説明する理由でさらに悪いことが判明します。
FLOPsと転送されるバイト数をカウントしましょう。HBMからQKV行列を読み取ります。attentionを計算します。これはQ×Kの行列で、FLOPs数はB×S×T×Dです。SとTはプリフィル中は同じであることを覚えておいてください。つまり、シーケンス長の2乗×B×Dです。
その後、他のステップからのFLOPsは実際にはあまり重要ではないため、行列乗算のみを見ています。その後、これとVの組み合わせに投影します。実際、ソフトマックスがあるため、これは数学的に間違っています。しかし、数学モデルの本質は同じです。
それは同じFLOPs数です。その後、HBMに書き込みます。Flash attentionを使用していない場合、より多くのバイトが転送されると仮定しています。Flash attentionは、中間ステップでHBMに書き戻し続ける必要がないことを意味します。しかし、次数は実際に影響を受けません。定性的には、flash attentionを使用するかどうかは実際には関係ありません。
ここでの数学は依存します。定数は重要ですが、FLOPsと転送されるバイト数を見てみましょう。割って単純化すると、この非常に良い式が得られます。単純であるという意味で良いのであって、効率が良いという意味で良いのではありません。それはS×T割るS+Tです。
これを少し解釈してみましょう。プリフィルでは、T=Sです。つまり、プリフィル強度は次数Sです。これは良いことです。十分に長いシーケンスがある限り、順調に進みます。一般的に、仮定するシーケンスは十分に長いものです。
しかし、生成中には、密度が本質的に1であることがわかります。S割るS+1です。しかし、それは基本的に1です。そして、1は本当に悪いということを覚えておいてください。
しかし、Bへの依存性が全くないことに注目してください。MLPでは、覚えておいてください、MLPでは、プリフィルはBTで、これは素晴らしいものでした。そして、算術強度はBでした。これは、ユーザーやワークロードの気まぐれに依存するため、それほど素晴らしくありませんでしたが、それでも1より大きくなる可能性がありました。
一方、attentionでは、実際にシーケンスがどれほど長くても、ユーザーが何人いても関係なく、常に1未満です。常に1なのです。
4.3 プリフィル vs 生成の比較
なぜ直感的にバッチ次元Bへの依存性がないのでしょうか。その理由は、MLPレイヤーでは、直感的に、すべてのシーケンスが同じMLP重みにヒットするのに対し、attentionレイヤーでは、各シーケンスが独自のKVキャッシュを持つからです。
KVキャッシュはシーケンス固有であるため、MLPの場合のように、すべての重みを読み取って、バッチを直感的に処理することはできません。一方、attentionの場合、すべてのシーケンスが追加のメモリを必要とします。それらをバッチアップしても、実際には何の節約も得られません。
数学的には、FLOPs数にBがあることで、これは期待されることですが、転送されるバイト数はB×(Bでのスケーリングがある)なので、それを割ると、Bが相殺されます。一方、こちらでは、Bがあります。しかし、DFが支配的であると仮定しているため、割ると、基本的に分母にBが残りません。
数学的にそれを見ることもできますし、attentionについて直感的に推論することもできます。KVキャッシュはすべてのシーケンスが独自の特別な雪片なのです。
要約すると、プリフィルは計算制限であるのに対し、生成はメモリ制限です。MLPの算術強度はBで、それを十分に良くするためには、多くの同時リクエストが必要です。しかし、attentionの強度は1で、それを改善することは不可能です。
5. スループットとレイテンシの理論分析
5.1 Llama 2-13Bでの具体例
今、推論が生成のおかげでメモリ制限であることが分かりました。少なくとも理論上、スループットとレイテンシを研究してみましょう。
これらすべては概算計算で、少し様式化されていますが、正しい種類のスケーリングと物事の考え方を与えてくれます。通信と計算が完全に重複できると仮定します。これは明らかに間違いですが、これらの定性的な見積もりを行うには十分良いものです。
H100でのLlama 2-13bのレイテンシとスループットを具体化します。13bについては、これらの値があります。シーケンス長を1,000、隠れ次元をモデル次元を5,000にしましょう。Fは4倍です。実際、それが4倍かどうかは分かりませんが、とにかく、Fはその数の何倍かです。ヘッド数、キー・バリュー数、クエリヘッド数、キー・バリューヘッド数で、Llama 2では同じです。この点については後で触れます。その他諸々です。そして、H100のメモリ帯域幅は、その数値です。
これが設定です。メモリレイテンシとスループットを計算します。
まず、パラメータ数を簡単に取得しましょう。皆さんは課題1でこれを行ったので、詳しくは述べませんが、すべての異なる変数に依存するいくつかの式があり、パラメータを保存するために使用します。推論は一般的に32ビットではなく16ビットになるため、bf16を使用します。つまり、2を掛けます。それがパラメータが取るメモリです。
勾配は必要ありません。訓練していないため、オプティマイザの状態も必要ありません。しかし、KVキャッシュを保存する必要があります。これは活性化の一部です。すべての活性化ではなく、長さSのすべてのシーケンスについてのそれらの一部です。
シーケンスごとに保存する必要がある量は、基本的にシーケンス長×キー・バリューヘッド数×そのヘッドの次元×レイヤー数×基本的にキーとバリューの両方について2、そしてbf16について2です。
それがキャッシュサイズの取る量です。総メモリは、バッチサイズ×シーケンスあたりのキャッシュ+メモリ+パラメータサイズです。
レイテンシはメモリI/Oによって決定されます。メモリ制限であることを覚えておいてください。そのため、この計算を行うためにGPUに転送される必要があるメモリの量を計算し、メモリ帯域幅で単純に割ります。スループットは本質的にレイテンシの逆数ですが、並列でBトークンを生成しているため、Bでスケールアップされます。
Llama設定を代入すると、パラメータ数が確認できます。これは約130億です。メモリレイテンシとスループットはこれらの式を持ちます。
メモリは明らかに増加します。これはパラメータサイズです。これはキー・バリューキャッシュサイズ×Bです。レイテンシもBの関数として上昇します。スループットは増加しますが、ある点まで増加することがわかります。Bは分子と分母の両方に現れます。そのため、すべてをメモリに収めることができたとしても、スループットをどれだけ伸ばすことができるかには限界があります。
これらが、この特定のモデルのレイテンシ、スループット、メモリの式です。
5.2 バッチサイズの影響
では、異なるバッチサイズで具体化してみましょう。B=1の場合、レイテンシは約8ミリ秒です。つまり、8ミリ秒ごとにトークンを生成し、スループットは124トークン/秒です。これがH100でバッチサイズ1を使用している場合の13bです。
では、バッチサイズ16を使用するとどうなるでしょうか。メモリ使用量が増加することがわかります。なぜなら、現在64シーケンスすべてのKVキャッシュを保存する必要があるからです。レイテンシは上昇します。なぜなら、1つだけを処理する代わりに、すべてが終了するまで待つ必要があるからです。しかし、スループットも実際にかなり上昇します。
つまり、低いレイテンシが欲しい場合はB=1を使用し、高いスループットが欲しい場合は一般的により大きなBを使用するという、レイテンシとスループットの間の即座のトレードオフが見えています。
さらに大きなバッチサイズ、例えば256を使用するとどうなるでしょうか。レイテンシは上昇し、スループットは上昇しますが、しばらくすると収穫逓減があるため、スループットはそれほど上昇していないことがわかります。
しかし、最も重要なことは、メモリを見ると240ギガバイトなので、実際にはH100で実行できないということです。メモリに収まらないのです。そのため、バッチサイズはメモリのためにある点までしか増加させることができません。
要約すると、レイテンシとスループットの間にはトレードオフがあります。小さなバッチサイズは、より良いレイテンシをもたらします。大きなバッチサイズは、より良いスループットをもたらします。
5.3 並列化戦略
最後に、先週、訓練の並列化について話しました。そして、それは複雑で、面倒でした。少なくとも推論の並列化の一つのタイプは、本当に、本当に良くて単純です。単純にモデルのM個のコピーを起動するだけです。通信は必要ありません。なぜなら、モデルを更新する必要がないからです。レイテンシは同じで、スループットはM倍に増加します。これはかなり良いことです。常に簡単なことを忘れないようにしましょう。
十分に大きなモデルがある場合、単一のGPUに収まらないケースもあります。その場合、モデルをシャードする必要があります。この場合、より良い効率を得るために、場合によってはKVキャッシュもシャードしたいことがあります。詳細については、このブック章を確認してください。
Time-to-first-token、これは先ほど言及したメトリックですが、本質的にプリフィルの関数です。基本的に、プロンプトをエンコードするのにどれくらい時間がかかるかということです。通常、これは計算制限なので、基本的に可能な限り速く進んでいます。固定されたアーキテクチャが与えられた場合、それについてできることはあまりありません。
申し訳ありませんが、バッチサイズを減らすことで改善することはできます。しかし、スループットを改善したい場合は、バッチサイズを増加させる必要があります。
これについて何か質問はありますか。これはスループットとレイテンシの計算についてでした。前の部分で与えたメモリ制限の議論のため、メモリに焦点を当て、何バイトが送信される必要があるかを計算しました。そして、それはレイテンシの大まかな境界を与えてくれます。実際には、計算が重要になるレジームもありますが、物事を単純に保つためにそれを無視しています。
6. 推論の高速化:KVキャッシュの削減
6.1 Grouped-Query Attention (GQA)
今、推論ワークロードがどのようなものかを良く理解できました。算術強度を見ました。算術強度に関してTransformerの推論を見ました。KVキャッシュがすべてのシーケンスに特別でなければならないため、attentionでメモリ制限であることがわかりました。
そして、それを使って、私たちが気にする主要な推論メトリックであるスループットとレイテンシを計算することができました。では、どのように物事を良くするのでしょうか。
無損失でできることがいくつかあります。より良いカーネルを書くことができます。システムを改善することができます。しかし、私は、ショートカットを取ることを厭わなければ、できることがたくさんあると言いたいと思います。そして、これらは技術的にはこの講義は推論についてですが、密かにモデルアーキテクチャについてでもあるため、本当に興味深いものです。なぜなら、モデルアーキテクチャの多くの変更が推論に直接的な影響を与えることがわかり、実際に推論を迅速に行う必要性に触発されているからです。
ここでの大きなボトルネックはKVキャッシュです。メモリ制限であることを覚えておいてください。これは、メモリの使用量が少なければ少ないほど、高速になることを意味します。FLOPsの一部でもありますが、主にメモリのためです。なぜなら、それは主にメモリ転送についてだからです。この講義から一つのことを持ち帰るとすれば、それは速度のためのメモリがすべてだということです。
問題は、KVキャッシュを単純に削除し始めると、精度を失う可能性があることです。では、あまり多くの精度を失わずに、KVキャッシュを小さく保つことをどのように確実にできるでしょうか。
KVキャッシュを減らそうとするすべてのアイデアをいくつか紹介します。これらのアイデアのいくつかは見たことがあると思いますが、このより体系的な方法で説明していきます。
grouped-query attentionというアイデアがあります。multi-headed attentionは、バニラTransformerで、基本的にヘッド数を保持します。そして、その数について、同じ数のキー、バリュー、クエリがあります。
一時期、multi-query attentionがありました。これは、1つのキーと1つのバリューしかありません。基本的に、1つのキー・バリューヘッドです。これはあまり表現力がないことがわかりました。そこで、減らされた数のキーとバリューを持ち、より多くのクエリを持つ中間点がありました。
なぜこれを行うのでしょうか。KVキャッシュサイズを減らしたいことを覚えておいてください。キーとバリューが少なければ少ないほど良いのです。バッチサイズとシーケンス長は変更されませんが、これらのベクトルの次元も変更されませんが、減らしているのはキー・バリューヘッド数です。それが基本的なアイデアです。
この論文は、実際にレイテンシとスループットの改善を得ることを示しています。サンプルあたりの時間です。8つ程度までのグループ数を増やすにつれて、基本的にfull attentionと比較してこれは本当に高速になります。グループ数を増やすにつれて、明らかに元に戻ることになります。これがレイテンシとスループットの改善です。
実際にこれをより厳密に行うために、Llama 2-13bモデルがあります。統計を計算すると、これはバッチサイズ64を使用しています。覚えておいてください、これが得られたものです。ここでもレイテンシを印刷すべきでした。そして、GQAで実行すると、メモリが減少し、スループットが大幅に上昇することがわかります。これは実際に素晴らしいことです。
Llama 2-13bアーキテクチャを取り、単純に減らした場合です。キー・バリューヘッドごとに5つのクエリヘッドがあります。それが1対5の比率の意味です。
これは、より大きなバッチサイズを使用できることも意味します。最後に256を試したとき、H100のメモリにも収まらなかったことを覚えておいてください。今では、実際にH100メモリに快適に収まり、より大きなバッチサイズを使用することでスループットをさらに改善できます。
ここで多くの異なる効果を見ることができます。キー・バリューペアの数を減らすことで、KVキャッシュのメモリが減少します。メモリ転送が少なくなるため、スループットとレイテンシが自動的に上昇します。さらに、二次的効果として、GPU内でバッチサイズを増加させることができ、それがスループットをさらに改善します。
それは素晴らしいことです。精度が下がらないことも確認する必要があります。これは、これがfull attentionであることを示す元の論文です。これがGQAです。時間ははるかに少ないですが、精度は基本的に同じです。
実際に何が起こったのでしょうか。Llama 2はこの比率を使用しませんでした。しかし、Llama 3は実際にGQAを取り上げ、おそらく推論コストに動機づけられました。実際、Llama 2では、70の大きなモデルにはGQAがありましたが、小さなものにはありませんでした。
6.2 Multi-Head Latent Attention (MLA)
キー・バリューキャッシュを減らす別の方法があります。これはDeepSeekから来ています。これは実際にDeepSeek V2論文からのもので、multi-head latent attentionと呼ばれ、Tatsuが以前に講義で話しましたが、推論の文脈とその意味について話してみます。
基本的なアイデアは、ここにfull attentionがあります。GQAは、より少ないキーとバリューを使用すると言っています。MLAは、キーとバリューの数を変更するのではなく、これらをより低次元の空間に投影すると言っています。つまり、KVサイズを縮小する別の方法ですが、異なる次元で行うのです。
各トークンのKVキャッシュのN×H次元を使用する代わりに、C次元に投影します。これがDeepSeekが行ったことです。16,000から512への非常に積極的な削減です。唯一の問題は、これがropeと互換性がないことです。そのため、ropeを戻すためにいくつかの次元を追加する必要がありますが、全体的には、これはKV削減の観点から実際に非常に有望です。
数学は行いませんが、KVキャッシュが大幅に削減されることがわかると信じてください。そして、同じ種類のレイテンシとスループットの利点を得ることができます。精度の面では、実際にGQAと比較してMLAが改善することを示しましたが、この表はそれを示していません。
後でそれを掘り起こす必要があります。とにかく、MLAは精度も維持します。
6.3 Cross-Layer Attention (CLA)
cross-layer attentionと呼ばれる別のアイデアがあります。この論文がありますが、多くの人がこれを考え、実行していると思います。これが実際に最初の論文かどうかはわかりません。
基本的に、Transformerの図を見ると、一つのレイヤーのキー・バリュー投影があり、次のレイヤーがあり、これらのキー・バリューベクトルは通常別々です。しかし、CLAでのアイデアは、レイヤー間で同じキー・バリュー投影を使用することです。
それがCross Layer Attentionと呼ばれる理由です。GQAがヘッド間で共有するのと同じように、CLAはレイヤー間で共有します。
ここで、精度とKVキャッシュサイズのパレートフロンティアを経験的に改善することを示しています。スループットとレイテンシに関連するKVキャッシュサイズは小さくしたいものです。そして、パープレキシティも小さくしたいものです。
彼らはそれを改善することができました。例えば、64ヘッドでは、キャッシュサイズは削減されますが、検証パープレキシティは少し上昇することに注意してください。しかし、全体的に、そのトレードオフを行うことに利点があります。
6.4 Local Attention
さらに別の方法があります。local attentionです。これは、Longformer、OpenAIの論文があり、その後MistralやI think多くの他の人がこれを使用しているなど、実際にかなり多く探索されています。
これは非常に自然なアイデアだと思います。full attentionの図を見ると、それは密なN²で、多くの複雑性がそこから来ています。基本的に、アイデアは過去のKトークンのみに注意を向けるということです。これは、KVキャッシュにおいて、シーケンスを生成している時に、すべてを覚えておく必要がないことを意味します。
attentionを持つウィンドウの外にトークンが落ちるとすぐに、それを捨てることができます。local attentionは、シーケンス長に伴って成長するのではなく、KVキャッシュサイズが一定のままであると非常に言えるでしょう。
これは本当に良いことです。なぜなら、長いシーケンスでさえも、非常に小さなキャッシュを持つことができることを意味するからです。
しかし、問題は、これがまだ精度を損なうことです。なぜなら、考えてみれば、なぜRNNの代わりにattentionを行っているのかというと、長距離モデルと長距離依存関係が必要だったからです。そして、これはある意味でattentionと呼ぶのは少し誇大宣伝かもしれません。これはローカルコンテキストのみを見ているので、あまり表現力がありません。
ここでできることは、local attentionをfull global attentionハイブリッドレイヤーと織り交ぜることです。例えば、キャラクターは6レイヤーごとに1つのglobal attentionグローバルレイヤーと5つのローカルレイヤーを持っていました。cross layer attentionに加えて、このようなものに見えます。
full attentionでは、すべてのレイヤーでKVキャッシュを保存する必要があります。彼らが行ったことは、6レイヤーごとにfull attentionを持ちますが、その間にlocal attentionがあるということです。そして、それに加えて、local attentionとglobal attentionの両方でKVキャッシュ共有をローカルに行っています。これは、すべてのトリックではありませんが、多くのトリックが組み合わされているようなものです。
7. 新しいアーキテクチャによる根本的解決
7.1 State Space Models
要約すると、KVキャッシュサイズを減らすためのいくつかの方法があります。推論はメモリ制限であることを覚えておいてください。そのため、キャッシュサイズを減らしたいのですが、精度をあまり損なわないようにしたいものです。それを行う多くの方法があります。KVキャッシュの次元を下げることができます。KVキャッシュベクトルを少なくすることができます。KVベクトルの次元を減らすことができます。レイヤー間でKVキャッシュを共有することができ、また、一部のレイヤーでlocal attentionを使用することもできます。
KVキャッシュを減らすためのこの一連のトリックについて何か質問はありますか。
はい、品質について質問があります。レイヤー間で共有される重みがすべてあると感じます。1セットの重みだけを持っているのか、それともKVのようなものでレイヤー間で共有されているのでしょうか。質問は、重みが共有されているかということです。KVキャッシュが共有されているだけでなく、重みも共有されています。つまり、投影を行うための重みが共有される必要があります。そのため、いくつかの一貫性があります。
別の質問がありました。コンテキストサイズが大きすぎて、KVキャッシュも増加する場合、トランスレーションモデルに与えられるコンテキストでプロンプトが与えられる時のコンテキストサイズです。それが長すぎる場合、サイズを増加させます。そのため、コンテキストを要約してサイズを減らそうとします。
質問は、本当に長いコンテキストがある場合、プロンプトが巨大だとしましょう、それは本質的に多くのKVキャッシュを取ることになります。これらすべてのトリックはそれを減らそうとすることができます。gistトークンや、この授業では話さないプロンプトを要約する方法のアイデアがありますが、長いプロンプト状況にも対処する方法があります。
では、Transformerを変更することによって推論をさらに高速化するさらに根本的な方法について話します。
KVキャッシュ、これらは基本的にTransformerの変種です。しかし、実際にTransformerの外に出て、より良いことができるかもしれません。なぜなら、Transformerは重い推論ワークロードを念頭に置いて実際に設計されたわけではないからです。それらは効率的に良いモデルを訓練しようとしていただけでした。それは主に訓練効率についてでした。
自己回帰は、私たちが指摘したように、自己回帰にfull attentionを加えたものが、ここでこのボトルネックを本当に引き起こしています。state space modelsとdiffusion modelsという2つの方向について話します。これはかなり簡単に行います。
state space modelsのアイデアは、実際に信号処理と制御理論からアイデアを引き出しています。当初、動機はN²の膨張を被ることなく長いコンテキストシーケンスをモデル化しようとすることでした。そのため、それは必ずしも推論速度についてではありませんでした。しかし、その問題を解決すれば、より高速な推論も得られることがわかります。
S4に関する一種の初期の論文があり、これは古典的なstate-space modelsを使用しています。これらは基本的にこれらの線形力学系で、長いコンテキストをモデル化するために使用され、現代のニューラル設定に押し込まれています。
この研究は、線形性構造のためにこのRNNの種類の解釈を持ち、畳み込み解釈も持っている点で良いものです。彼らはこの論文を発表し、これらの長いコンテキスト合成タスクで本当にうまく機能することを示したと思います。
しかし、発見されたことは、まあ、推測ですが、発見されたことは、言語モデリングではそれらが実際にうまく機能しないということでした。そして、それは明らかに失望でした。なぜなら、Transformerの価値の多くは言語をうまく行うことができることだからです。
一連の論文で、彼らはこれらのモデルがうまく機能していない理由の本質を捉えた一連の合成タスクを特定しました。それは基本的にこれらの連想記憶タスクです。
ここに、基本的にキー・バリューペアのシーケンスが与えられる合成タスクがあります。目標は、基本的にキーを検索して値を出力することです。ある意味では、論理的に些細なタスクです。しかし、多くのキー・バリューペアを持つことができるため、長いシーケンスです。そして、任意に遠くを見返さなければなりません。
任意に長い依存関係である可能性があります。local attentionは、最後のいくつかのシーケンスを覚えているだけなので、あまりうまく機能しないことがわかります。state-space modelsの問題は、これらの種類の信号処理タスクには良いということです。しかし、特定のキー・バリューペアを分離し、それらのタイプのタスクに対する答えを引き出す必要がある場合、実際にはうまく機能しませんでした。
引用していない多くの研究があります。hyena、H3、そしてMambaのようなものがあり、基本的にHSSMを調整または変更して、これらの連想記憶タスクを基本的に処理します。そして、最終的に、それはより良く機能しました。1Bスケールまでは、Transformerに匹敵していました。
Mambaのアイデアは人気があり、2.1の人々によって52B MoEまでスケールアップされました。この場合、彼らはまだTransformerを使用しなければならなかったことに注意してください。Transformerでしたが、8レイヤーだけだと思います。彼らはTransformerを持っていました。残りはMambaレイヤーでした。それでも、かなり大きな節約と速度向上につながりました。
7.2 Linear Attention
しかし、最近では、linear attentionと呼ばれるこの古いアイデアの復活があります。大きくできるかどうか見てみましょう。これは実際に非常にシンプルなアイデアです。local attentionやsliding window attentionが何であるかを知っています。Linear attentionは、基本的にattentionの計算において、キーとクエリがあり、それらをドット積し、その指数を取るというアイデアです。これは基本的にexp kernelを与えています。
そのため、基本的にそのTaylor展開を取り、その計算を基本的にある非線形写像のドット積として書くことができます。その後、本質的に持っているのは、すべてのキー・バリュー位置について、基本的に何らかの非線形性を適用し、それをある空間に拡張してから、その上で線形計算を行うことです。
そして、それがlinear attentionであるため、実際にはRNNのように振る舞い、シーケンス長に対して2次ではなく線形です。それは少し速かったですが、味だけを与えたかったのです。
そして、このアイデアは実際にかなりうまくスケールアップされています。MiniMaxと呼ばれるこの組織があり、456億パラメータのMoEまで、かなり合法的なモデルを訓練しています。
彼らは基本的にこのlinear attentionのアイデアを使用しています。時々full attentionを使用する必要があるようです。人々がfull attentionを全く回避することができたとは思いません。しかし、少なくとも、レイヤーの大部分がもはやfull attentionではなく、linear レイヤーかlocal attentionレイヤーのどちらかであり、これははるかに、はるかに効率的であるようです。
linear + local attentionは現在、実際に真剣な最先端モデルを生み出しています。そして、クローズドモデルプロバイダーが実際に何をしているのかは分かりませんが、少なくともこれと同じくらい効率的で、スパース性を活用していることを期待すると言うのは、おそらく安全だと思います。
人々が「attention is all you need」やTransformerについて尋ねる時、興味深い質問です。まあ、イエスでもありノーでもあります。つまり、ある意味では、まだその感覚があると思います。多分それを取り除くことができるかもしれません。しかし、Transformerの大部分は、他のはるかに軽量なコンポーネントを持つことによって、かなり根本的に変更されています。
そして、まだ同じ種類の精度の多くを得ることができます。そして、これはすべて推論にとって本当に役立ちます。なぜなら、これらの非full attentionレイヤーでは、基本的にシーケンス長とともに成長するオーダーTのKVキャッシュを一定のものに置き換えているからです。
フォローアップ研究があります。BASEDペーパーにあると思います。どこに行ったのでしょうか。このペーパーまたはフォローアップ研究で、基本的にKVサイズと様々なタイプのリコールタスクを実行する能力の間のトレードオフを分析しています。これは理にかなっています。なぜなら、あまり多くを保存しない場合、特定のタスクを解決することができないからです。しかし、遊ぶことができるこのトレードオフ曲線があります。
State-space modelsについて言うのはこれだけです。
7.3 Diffusion Models
では、完全に異なるスタイルの生成モデルについて話しましょう。Diffusion modelsです。Diffusion modelsは画像生成で非常に人気がありますが、テキストで動作させるのはかなり困難であることがわかります。ただし、最近ここでいくつかの進歩がありました。
Diffusionのアイデアは、自己回帰的に生成する代わりに、すべてのトークンを並列で生成するということです。明らかに、単純なレイヤーを通してそれを行うだけでは、あまり良くないでしょう。すべての単語を並列で生成し、それが一貫していることを期待することはできません。しかし、行うことは反復し、最終的に出力する最終生成に到達するまで、この生成を洗練し続けることです。
並列で生成するというアイデアの背後にあるのは、もはや自己回帰的に制約されないということで、すべてのトークンを並列で生成することは並列で行うことができます。そのため、コンテキスト長が十分に大きい限り、GPUを比較的簡単に飽和させることができます。
最近、Inception Labsがいくつかの非常に興味深いモデルを制作しました。それらについてはあまり書かれていませんが、生成と処理のデモを見ることができます。それは明らかに壊れたコードを瞬時に生成するだけですが、時間をかけて洗練されていきます。
これは、少なくともコーディングにおいて、他のタスクについては分からないという彼らのベンチマークの一つです。秒あたりのトークン数を見ると、これらのモデルは速度の面でTransformerの何よりもはるかに上にあります。
Jambaでさえ、覚えておいてください、それはハイブリッドMamba-Transformerアーキテクチャのようなものでしたが、これらのdiffusion modelsと比較するとかなり遅いです。
Diffusion modelsが他のすべてのタスクで十分に汎用的で強力かどうかは、まだ分からないままです。しかし、ここでのトークン速度でこのようなリードを持っているので、精度の損失を回復するために、より多くの計算を投入することができたとしても、必要であればそれを考えます。
ここでの要約は、この全体の新しいアーキテクチャのことが推論にとって実際に非常にエキサイティングだということです。なぜなら、根本的な障害を回避することができるからです。Attentionを扱っている場合、単に量子化し、最適化することができるこの根本的なKVキャッシュ障害がありますが、それはまだそこにあります。
State-space modelを作ることによって、それを定数サイズに縮小しています。精度を維持できる限り、これは大きなifですが、そうすれば大勝利します。Diffusion modelsでも同様です。自己回帰生成は重要なボトルネックです。物事を並列で生成するだけであれば、ゲームを完全に変えることになります。
推論を改善することで、ここでやるべき仕事がはるかに多くあります。ご覧のように、推論ゲームは最初に見えるよりもはるかに広いものです。システムの最適化を行って高速化することについては必ずしもありませんが、明らかにそれらは必要です。しかし、本当の利益はアーキテクチャの根本的な変化から来ていると思います。
8. 量子化とモデル剪定
8.1 量子化の基本概念
残り約10分です。これらを簡単に説明していきます。量子化とモデル剪定です。
量子化の重要なアイデアは、数値の精度を減らすということです。とても簡単に行うことができます。そして、考え方は、メモリが少ないということは、転送されるバイトが少なく、レイテンシが低く、スループットが高いということです。
申し訳ありません、これは低いレイテンシ、高いスループットであるべきです。そして、もちろん精度を心配する必要があります。それがトレードオフです。
異なるタイプのフォーマットを見ると、FP32は訓練に使用されますが、推論には使用されません。実際に、BF16が推論のデフォルトです。FP8やINT8まで下げることができ、これはより正確性が低くなりますが、FP8よりもはるかに安価です。人々は実際にINT8で多くの推論を行います。範囲を見ると、127から負の128の間の整数で、それほど多くありません。かなり低精度です。
人々はINT4まで下げることさえあります。これらはありません。INT4はかなり低いです。また、他の方法でも行うことができます。
量子化を行うことを決めたら、いくつかのことができると思います。量子化で訓練することもできますが、明らかにそれはモデルを再訓練する必要があることを意味し、より一般的には、既存のモデルを取って量子化しようとし、あまり物事を台無しにしないようにする訓練後量子化を行います。
8.2 LLM-int8()手法
LLM-int8()と呼ばれる論文があり、これを簡単に説明します。量子化では、基本的に起こることは、FP16であるベクトルを取り、それをINT8にパックしたい場合、動的範囲を把握する必要があるということです。最大値が何であるかを把握する必要があります。それを把握したら、それで割って128を掛けることができます。
そして、整数が得られます。そして、逆量子化する必要がある場合、逆の方向に進みます。基本的に量子化は、メモリが帯域幅のボトルネックであることを覚えておいてください。つまり、すべての転送がデータで起こっています。しかし、実際に行う時、時々算術を実際に行うために浮動小数点にアップキャストする必要があります。
INT8の問題は、すべてがうまく収まるわけではないということです。そして、大きなネットワークに現れる外れ値があり、物事を台無しにします。この論文が行ったことは、この行列を取り、本当に大きな外れ値を特定することです。そして、それらを完全な16ビット精度を使用して別々に処理し、その後、大部分をINT8で行うということです。
これはうまく機能しますが、実際には少し遅いです。ここでの動機は推論速度ではなく、モデルをメモリに収めることができることでした。
8.3 Activation-aware量子化
Activation-aware quantizationと呼ばれる別の論文があります。ここでのアイデアは、重みを量子化しているが、活性化に基づいてどの重みを量子化するかを把握するということです。
本当に簡単に、実際にINTまで下がり、これは明らかにメモリをかなり削減し、3倍の速度向上につながります。ここでの一般的なアイデアは、訓練済みモデルを得るということです。そして、重みや活性化の一部が異常に大きくなることがたまたま起こります。それらについては別々に処理し、その他すべてについては低精度で作業できます。
8.4 モデル剪定技術
モデル剪定のアイデアについて話します。非常に軽い量子化です。基本的なアイデアは非常にシンプルです。高価なモデルの部分をちぎり取って安くし、その後修正するのです。
このNVIDIAの論文では、まず小さなキャリブレーションサイズを使用して、重要なレイヤーまたはヘッドまたは隠れ次元を特定します。それらを計算するためにいくつかのシンプルなスコアを使用します。その後、重要でないレイヤーや隠れユニットやヘッドを単純に削除します。
その後、そのモデルを取ると、明らかに悪くなります。そこで最後のステップは、元のモデルを剪定されたモデルに蒸留することです。つまり、ゼロから始めるのではなく、初期化である剪定されたものから始めて、モデルを修復します。
ゼロから始めるのではありません。より悪いが、元のモデルと同じ構造的特性の多くを保持していることを願っているものから始めます。単にある意味で調整されていないだけです。そして、その結果はかなり良いものです。
彼らは、150億パラメータモデルを持っており、それを8Bに削減することができ、少なくともMLUによると、ほとんど何の低下もありません。そして、4Bまで削減すると、いくらかの低下がありますが、4Bモデルまでかなり下がっています。
9. Speculative Decoding
9.1 基本アイデア:チェックは生成より高速
このショートカットを取るアイデアをまとめると、精度を損なうことなく推論の複雑さを減らすことができます。ゼロから行うことができ、構造上高速で、単にそれを訓練する新しいアーキテクチャを定義します。または、蒸留を行うことができます。そのアーキテクチャを定義し、遅いモデルを取り、古いモデルで新しいモデルを初期化するスキームを考え出し、その後基本的に蒸留を行います。
これらすべては、損失があるため少し不満足です。大幅な速度向上を得ますが、このモデルが実際に元のモデルと同じくらい良いのかといつも疑問に思います。Speculative decodingまたはspeculative samplingは、基本的にケーキを食べて、それも持つことを可能にします。
推論には2つの段階があることを思い出してください。シーケンスが与えられ、すべてのトークンを並列でエンコードするプリフィルがあります。これは計算制限で、素晴らしいものです。また、これは各トークンの対数確率も与えることに注意してください。その後、一度に一つのトークンを生成する生成があります。これはメモリ制限です。遅いのです。
言い換えれば、チェックは生成よりも高速です。直感的にはこれは理にかなっています。しかし、今はなぜこれが真実であるかの数学も理解できることを願っています。
Speculative samplingのアイデアは実際に本当に、本当にエレガントです。これはGoogleの2つの独立したチームによって並行して提案されました。
アイデアは、安価なドラフトモデルpを使用していくつかのトークンを先に実行して生成することです。その後、ターゲットモデルでそれらのトークンを評価します。与えられたトークンの評価は単にプリフィルなので、並列で行うことができ、これは高速です。そして、良く見える場合はそれを受け入れます。
これが実際の生活でどのように見えるかです。大きなモデルを使用して一度に一つのトークンを生成している場合、それは遅いです。しかし、speculative decodingでは、先に進んで多くのトークンを生成するドラフトモデルがあり、大きなモデルを使用して基本的に検証します。そして、時々拒否し、時々受け入れます。受け入れ率が基本的にどれだけの速度向上を得るかを決定します。
9.2 アルゴリズムの詳細
より正式なアルゴリズムは次のとおりです。Kの先読みを持つことになります。ドラフトモデルを使用してKトークンを自己回帰的に生成します。ドラフトモデルが小さいため、これが高速であることを期待しています。その後、生成したこれらのKトークンが与えられます。ターゲットモデルQの下でそれらをスコア付けします。
この時点で、これを受け入れるかどうか決定します。各トークンを通して進み、基本的に確率QをPで割った値で受け入れます。1は、この確率が0と1の間にあることを確認します。これは、Metropolis-Hastingsに馴染みのある人にはそのように見えるかもしれません。ここから来ています。
直感的に、Pでサンプリングしています。Pは欲しくないので、それを割って除去する必要があります。Qが欲しいのです。これが重要な重み付けです。受け入れる場合は、素晴らしいです。次のドラフトトークンに移動して、そのように続けます。受け入れない場合は、ターゲットモデル、遅いモデルからサンプリングしますが、既にpを使ってサンプリングを試みているため、この修正を行います。
そのため、もうそれを行う必要はありません。それを引いて、Qからサンプリングします。これは基本的に提案PとターゲットQを持つ棄却サンプリングです。唯一の違いは、相互作用サンプリングでは、棄却すると、棄却して、再試行して、再試行することです。ここでは、永続的にループし続けたくないため、棄却した場合は、「わかりました、諦めて、より高価なモデルからサンプリングします」と言います。
ここでの素晴らしい点は、ターゲットモデルから正確なサンプルを得ることが保証されていることです。サンプリングに馴染みのある人にとって、これはそれほど驚くべきことではないはずです。事前情報を使用してサンプリングを高速化することができます。しかし、言語モデリングの文脈では、これはかなり良いものです。
9.3 数学的保証と実験結果
これをスキップします。これは実際には証明ではありません。語彙が2の場合について、なぜこれらの式が正しい偏りのないサンプリング手順を与えるのかを示すいくつかの導出です。
そして、これはかなりうまく機能します。同じモデルなので、精度は実際に同じであるべきです。しかし、そこにはいくらかのランダム性があるかもしれません。しかし、速度向上は本質的に2倍の速度向上を得ています。
実際には、70Bモデルのようなものを持ち、ドラフトモデルははるかに、はるかに小さいです。ターゲットモデルが8Bの場合、ドラフトモデルは1Bかもしれません。一般的に、ドラフトモデルをターゲットにできるだけ近づけたいのです。蒸留を行っている場合、それはさらに良くなる可能性があります。
これは推論における非常にホットな研究分野です。このプロセスを改善する多くの方法があります。Medusaを使用することができます。これは、ドラフトモデルが自己回帰的に生成する代わりに、並列で複数のトークンをサンプリングする方法です。または、EAGLEでは、実際にターゲットモデルの高レベル特徴を取り、それらをドラフトモデルに送り込んで生成させます。
つまり、ドラフトモデルは実際に単独で立つ必要はありません。生成を助けるためにターゲットモデルに接続することができます。
要約すると、数学のおかげでターゲットモデルからの厳密なサンプリングです。そして、これはチェックと生成の間の対称性を利用しています。またはプリフィルと生成です。そして、ドラフトモデルに多くの革新の余地が実際にあります。これまで話したすべて、異なる根本的なアーキテクチャ、量子化の異なる方法を持つことができ、これらすべてが適用されます。
唯一のことは、基本的に正確なサンプルを得ることを保証できることです。
10. 実用的な推論システム
10.1 動的バッチング
時間がありませんが、先ほど出てきた質問を簡単に説明します。それは、実際にサービスを提供する際には、ライブトラフィックがあるということです。リクエストは異なる時間に到着します。異なる時間に終了します。一部は共有プレフィックス を持ちます。一部は持ちません。異なる長さを持ちます。つまり、基本的に密なトークンのブロックを取得し、それをフルスピードでGPUに押し込む訓練と比較して、非常に異質です。
この場合、どうするのでしょうか。これを探求する一連の論文があります。基本的なアイデアは、これは最後の部分がよりシステムレベルの貢献だということです。
アイデアは、バッチが終了するのを待たないということです。電車は出発します。電車はあなたを待ちません。新しいバッチが来ると、単純にそれを入れるのです。これは、トークンを生成しているワーカーが、すべてのステップでスケジューラーに制御を戻す必要があることを意味します。
トークンを生成し、スケジューラーに戻って、新しいリクエストがあるかどうかを確認し、それらを詰め込み、続行します。つまり、リクエストを待ち回って時間を無駄にすることはありません。
バッチングには質問の背後にある問題があると思います。すべてが同じ次元にある場合、バッチングは機能しますが、すべてのリクエストが異なる長さである可能性があります。
選択的バッチングというアイデアがあります。基本的に、attentionの計算を分解します。すべてを別々に処理する必要があります。しかし、MLPについては、覚えておいてください、これは計算の大部分ですが、実際には異なるサイズのテンソルを取り、それらを平坦化することができます。相互作用しないため、バッチ次元で単純に同行できます。
10.2 PagedAttention
PagedAttentionについて簡単に説明します。これはvLLMの背後にある論文で、皆さんの一部が使用したことがあるかもしれません。これはメモリ使用量の問題に対処します。
KVキャッシュがあり、プロンプトが入ってきて終了する場合、キャッシュが断片化されることになります。リクエストに対してスペースを割り当てることになりますが、何個のトークンを生成するかわからないのです。そのため、内部断片化が発生します。また、リクエストとレスポンスの間にパディングがある外部断片化も発生します。これは良くありません。
PagedAttentionは基本的に、オペレーティングシステムを覚えていますか?そして仮想メモリがどのように動作するかを覚えていますか?KVキャッシュを連続するブロックのシーケンスに分割します。
その後、空白スペースを見つけた場所にそれらを置くだけです。2つのリクエストが入ってくる場合、最初のリクエストはここ、ここ、ここにあり、2番目のリクエストはここ、ここにあるかもしれません。ブロックは連続性を保つものです。そして、それがメモリを結合することを可能にします。
プレフィックスの共有がある場合、オペレーティングシステムからのもう一つのアイデアであるcopy on writeを使用してこれらのトリックを実行することもできます。基本的に、特定のブロックを使用している基本的にシーケンスの数について参照カウンターを維持します。そして、ブロックが異なる方向に分岐する必要がある場合、コピーして参照カウントを減らします。
10.3 メモリ断片化の解決
多くの他のvLLM最適化がありますが、それらについては説明しませんが、基本的な要約は、オペレーティングシステムのクラスを覚えておいてください。それらを推論にも適用することができます。
PagedAttentionの核心的なアイデアは、メモリ断片化の問題を解決することです。従来のアプローチでは、各リクエストに対して連続したメモリ空間を割り当てる必要がありましたが、リクエストの長さが事前に分からないため、メモリが無駄になったり、断片化が発生したりしていました。
PagedAttentionは、オペレーティングシステムの仮想メモリ管理の概念を借用しています。KVキャッシュを固定サイズのブロックに分割し、これらのブロックを物理メモリの任意の場所に配置できるようにします。各シーケンスは、必要に応じてブロックを動的に割り当てることができ、ブロックは物理的に連続している必要がありません。
さらに、プレフィックス共有の場合、copy-on-writeメカニズムを使用します。複数のシーケンスが同じプレフィックスを共有している場合、同じ物理ブロックを参照することができます。参照カウンターを使用して、何個のシーケンスが特定のブロックを使用しているかを追跡します。シーケンスが異なる方向に分岐する必要がある場合にのみ、新しいブロックをコピーして割り当て、参照カウンターを更新します。
このアプローチにより、メモリ使用量を大幅に最適化し、より多くのリクエストを同時に処理できるようになります。
11. まとめと今後の展望
11.1 推論の特徴:メモリ制限と動的性
簡単な要約です。推論は本当に、本当に重要です。そして、その特徴は訓練とは異なっています。メモリ制限であり、また動的でもあり、これが多くの新しい課題をもたらします。
推論の根本的な特徴を理解することが重要です。まず、推論はメモリ制限であるということです。これは、特にTransformerアーキテクチャにおけるKVキャッシュの性質に起因します。各シーケンスが独自のKVキャッシュを必要とするため、attentionレイヤーの算術強度は常に1程度となり、メモリ帯域幅がボトルネックになります。
訓練では、すべてのトークンを同時に見ることができ、シーケンス全体で並列化が可能でした。しかし、推論では、特に生成段階において、トークンを順次生成する必要があります。これは、過去のすべてのトークンに依存するためです。この順次性が、利用可能な計算リソースを十分に活用することを困難にしています。
さらに、推論は動的な性質を持ちます。実際のサービス環境では、リクエストが異なる時間に到着し、異なる長さを持ち、異なる時間に完了します。これは、訓練時の密なトークンブロックを一定速度で処理する状況とは大きく異なります。このような動的な特性は、効率的なバッチ処理や資源管理において新たな課題を生み出します。
メモリ帯域幅が主要なボトルネックであるため、推論の最適化においては、計算量の削減よりも、メモリ転送量の削減が重要になります。これが、なぜ多くの最適化手法がKVキャッシュサイズの削減に焦点を当てているかの理由です。
11.2 アーキテクチャ変更の重要性
新しいアーキテクチャ、量子化、剪定、蒸留、speculative decodingに関する様々な技術の全体像を見てきました。通信と計算をオーバーラップさせ、メモリをより良く使用するシステムからのアイデアもあります。
しかし、私は、モデリングとアーキテクチャにおそらくさらに多くの機会があると言いたいと思います。なぜなら、考えてみると、すべて推論を狭く捉えれば、特定のモデルにおける推論です。この特定のモデルをどのように実行するかということです。しかし、その特定のモデルなど誰が気にするでしょうか。あなたがケアしているのは、リソース予算が与えられた時に良い精度を提供することです。
KVキャッシュを減らそうとしているこれらのアイデア、Transformerを変更することの多くは、基本的に問題を回避して、「まあ、より効率的なものがあります。そして、より良い精度を得る方法でそれを訓練できれば、私の勝ちです」と言う方法です。
推論における最も重要な洞察の一つは、システムレベルの最適化だけでは限界があるということです。真の飛躍的な改善は、根本的なアーキテクチャの変更から来ています。従来のTransformerアーキテクチャは、主に訓練効率を念頭に置いて設計されており、推論時の制約は十分に考慮されていませんでした。
KVキャッシュの問題は、Transformerアーキテクチャに内在する根本的な制約です。いくら最適化やシステム改善を行っても、この制約は残り続けます。一方で、state-space modelsやlinear attention、diffusion modelsのような新しいアーキテクチャは、この制約を完全に回避することができます。
これらの新しいアーキテクチャアプローチは、問題を根本から解決します。State-space modelsは、KVキャッシュを定数サイズに削減します。Diffusion modelsは、自己回帰生成の制約を完全に取り除き、並列生成を可能にします。これらは、推論効率において桁違いの改善をもたらす可能性があります。
重要なのは、特定のモデルアーキテクチャに固執することではなく、与えられたリソース制約の下で最高の精度を達成することです。そのため、アーキテクチャの革新こそが、推論効率における次の大きなブレークスルーをもたらすと考えられます。
11.3 研究の方向性
それが私が持っているすべてで、次回またお会いしましょう。そして、スケーリング則に戻ります。
推論の研究における今後の方向性を考える時、いくつかの重要な点が浮かび上がります。まず、推論は単なる技術的な問題ではなく、実際のアプリケーションにおいて極めて重要な要素であることが明らかになりました。OpenAIが1日に1000億語を生成し、Cursorが1日に10億行のコードを生成しているという事実は、推論効率の改善が実世界に与える影響の大きさを示しています。
今後の研究においては、システムレベルの最適化も重要ですが、真の革新はアーキテクチャレベルの変更から生まれると予想されます。State-space models、linear attention、diffusion modelsといった新しいアプローチは、従来のTransformerの根本的な制約を回避する道筋を示しています。これらの手法は現在も急速に発展しており、さらなる改善が期待されます。
特に注目すべきは、推論効率を考慮したアーキテクチャ設計の重要性です。従来は訓練効率が主要な関心事でしたが、実用的なアプリケーションでは推論効率こそが決定的な要因となります。今後は、推論効率を最初から考慮に入れたアーキテクチャ設計が主流になると考えられます。
また、speculative decodingのような数学的に保証された高速化手法は、既存のモデルに適用可能であり、immediate impactを持つ重要な研究領域です。これらの手法と新しいアーキテクチャを組み合わせることで、さらなる効率向上が期待できます。
最終的に、推論研究の目標は特定のモデルを高速化することではなく、与えられたリソース制約の下で最高の性能を達成することです。この観点から、アーキテクチャ、システム、アルゴリズムの全てのレベルでの継続的な革新が必要となります。