※本記事は、Percy Liang氏による Stanford CS336「Language Modeling from Scratch」講義の内容を基に作成されています。講義の詳細情報や全コースプレイリストは https://stanford-cs336.github.io/spri... でご覧いただけます。本記事では、講義の内容を要約しております。なお、本記事の内容は原講義の見解を正確に反映するよう努めていますが、要約や解釈による誤りがある可能性もありますので、正確な情報や文脈については、オリジナルの講義動画をご視聴いただくことをお勧めいたします。
登壇者紹介
Percy Liang氏 スタンフォード大学コンピュータサイエンス学部准教授、Foundation Models研究センター(CRFM)所長
Tatsunori Hashimoto氏 スタンフォード大学コンピュータサイエンス学部助教授
本講義は、Stanford Onlineを通じて提供されており、スタンフォード大学の学術・専門教育へのアクセスを世界中に拡大する取り組みの一環として、Stanford Engineering Center for Global & Online Education(CGOE)により運営されています。Stanford Onlineの詳細は https://online.stanford.edu/ でご確認いただけます。
1. 推論の基礎と重要性
1.1 推論とは何か
Percy Liang氏:推論の質問は非常にシンプルなものです。訓練済みの固定されたモデルが与えられた時、プロンプトに対する応答を生成することです。まず、推論が何を意味するのか、そしてそれが伴うワークロードの含意を理解することから始めましょう。
推論とは、既に学習が完了したモデルを使用して、入力されたプロンプトに基づいて新しいトークンを生成するプロセスを指します。これは訓練段階とは根本的に異なる性質を持つ作業であり、モデルのパラメータを更新するのではなく、固定されたパラメータを使用して出力を生成することに焦点を当てています。
推論は今回の講義で初めて扱うトピックです。昨年の講義では推論について扱いませんでしたが、実際には複数の講義にまたがって扱うことができるほど深いトピックであり、今回は1つの講義に凝縮して説明することになります。推論は実際には非常に深い領域であり、多くのサブトピックが存在することを理解しておくことが重要です。
1.2 推論が必要となる場面
Percy Liang氏:推論は複数の異なる場面で登場します。最も明確な場面は、実際にモデルを使用したい時です。チャットを行いたい場合、Cursorやその他のツールを使用してコード補完を行う場合、言語モデルを使用してバッチデータ処理ジョブを実行する場合など、これらすべてのケースでは実際のモデルからトークンを生成する必要があるため、推論が求められます。
しかし、推論は他の文脈でも登場します。モデルを評価したい場合、例えば命令追従について評価する際にも推論が必要になります。また、テスト時計算に大きな関心が寄せられていますが、これは実際に最終的な答えを出力する前により多くの思考を行うことを意味し、これもまた推論の一部です。なぜなら、思考は基本的にトークンを生成することだからです。
さらに、訓練そのものにおいても、強化学習を使用している場合、応答をサンプリングしてから何らかの報酬に基づいてそれらを評価する必要があり、これもまた推論を必要とします。
つまり、推論は単に「チャットボットのデモを立ち上げたい」というだけのものではありません。推論は実際には言語モデルの多くの基本機能の基盤となるものです。この講義は1つの講義ですが、推論が実際にいかに多くのことにとって重要であるかを強調したいと思います。そして、クラスの後半でアライメントについて話す際にも、おそらくこのトピックに戻ってくることになるでしょう。
1.3 推論の重要性と実用例
Percy Liang氏:推論が重要である理由について説明しましょう。このクラスのテーマは効率性であり、効率性が明らかに重要である理由があります。訓練は一回限りのコストですが、推論は複数回繰り返されます。
推論がいかに大きな取り組みであるかを示すいくつかの逸話的な統計データを紹介します。Samによると、OpenAIは1日に1000億ワードを生成しているとのことで、これは非常に多い量です。さらに、それほど新しくない製品であるCursorでさえ、毎日10億行の受け入れられたコードを生成していると言われています。
これらの数値は、推論がどれほどの量を占めているかのアイデアを与えてくれ、訓練と比較した推論のコストは確実に増加しています。推論の重要性は、これらの実用例からも明らかであり、実際のプロダクションシステムにおいて推論処理が占める割合とその経済的インパクトの大きさを物語っています。
現在、推論は単なる研究上の関心事ではなく、実際のビジネスにおいて重要なコスト要因となっており、その効率化は企業にとって直接的な経済効果をもたらす重要な技術課題となっています。
2. 推論の評価指標
2.1 Time to First Token (TTFT)
Percy Liang氏:良い推論がどのように見えるかを測定する方法について説明します。まず、Time to First Token(TTFT)があります。これは、個々のユーザーが最初の生成が行われるまでに待つ必要がある時間を指します。
TTFTは、インタラクティブなアプリケーションにとって明らかに重要です。大きなプロンプトがあり、そこで10秒間待たなければならない場合、それは良いユーザーエクスペリエンスとは言えないでしょう。
TTFTは特にリアルタイムでユーザーとやり取りするアプリケーションにおいて重要な指標です。ユーザーが質問やプロンプトを入力してから、システムが最初のトークンを返すまでの時間が長すぎると、ユーザーはシステムが応答していないと感じ、体験の質が大幅に低下します。この指標は、システムの応答性を直接的に表現する重要な測定基準となっています。
2.2 レイテンシ
Percy Liang氏:レイテンシは、最初のトークンの後にトークンがどれだけ速く到着するかを示します。これもまたインタラクティブなアプリケーションにとって重要です。
レイテンシは、ユーザーが最初のトークンを受け取った後の継続的な応答速度を測定する指標です。これは、システムが一度応答を開始した後、残りの回答をどれだけスムーズに提供できるかを表しています。インタラクティブなアプリケーションでは、ユーザーは回答が途切れ途切れに表示されるよりも、一定の速度で連続的に表示されることを期待しています。
この指標は、個々のユーザーの体験に直接影響を与えるものであり、特にリアルタイムでの対話や長い回答を生成する際の品質を決定する重要な要素となります。レイテンシが高い場合、ユーザーは各トークンの間で不自然な待機時間を経験することになり、全体的な使用感が損なわれます。
2.3 スループット
Percy Liang氏:スループットは少し異なるものです。スループットは、個々のユーザーではなく、全体のユーザーに対して一般的に単位時間あたりに生成されるトークンの数を指します。これは特にバッチ処理アプリケーションで有用です。
スループットについて考える際、高いスループットが低いレイテンシを意味するわけではないことを理解する必要があります。なぜなら、一部のリクエストは非常に長い時間がかかる可能性があり、それでも高いスループットを維持できるからです。レイテンシは、個々のユーザーに対する最悪ケースのようなものと考えることができます。
この区別は重要です。スループットはシステム全体の処理能力を表しており、複数のユーザーからのリクエストを同時に処理する際の効率性を測定します。一方で、レイテンシは個々のユーザーの体験に焦点を当てています。システム設計者は、これらの指標間でトレードオフを考慮する必要があり、用途に応じてどちらを優先するかを決定しなければなりません。バッチ処理システムでは一般的にスループットが重視され、インタラクティブシステムではレイテンシがより重要になります。
3. 推論ワークロードの特性
3.1 訓練と推進の違い
Percy Liang氏:推論の効率性について考える際に何を考慮する必要があるでしょうか。訓練では、少なくとも教師ありの訓練において、重要なアイデアは全てのトークンを見ることができるということです。これは、シーケンス全体にわたって並列化できることを意味します。
これはTransformerで大いに活用されています。Transformerの訓練を行ったことがある方なら、基本的にシーケンス全体にわたってこれらのテンソルを構築し、それがテンソル、テンソル、テンソル、行列乗算の繰り返しであり、そして出力を得ることを知っているでしょう。
しかし、少なくともTransformerにおける推論の重要な定義的特徴は、逐次的に生成しなければならないということです。並列化することができません。なぜなら、トークンの生成は過去のすべてに依存するからです。
この違いが推論を大幅に困難にする重要な要素となります。特に、利用可能なすべての計算リソースを活用することが難しくなり、後で詳しく見るように、メモリ制限を受けることになります。この根本的な違いが、推論を訓練とは全く異なる技術的挑戦にしているのです。訓練では効率的だったアプローチが、推論では必ずしも有効ではない理由がここにあります。
3.2 逐次生成の制約
Percy Liang氏:推論における重要な定義的特徴は、逐次的に生成しなければならないということです。並列化することができません。なぜなら、トークンの生成は過去のすべてに依存するからです。
この逐次生成の制約は、Transformerアーキテクチャにおいて特に顕著に現れます。訓練時には、入力シーケンス全体が既知であるため、各位置でのアテンション計算を並列に実行できました。しかし推論時には、次のトークンを生成するために現在までに生成されたすべてのトークンの情報が必要になります。
つまり、n番目のトークンを生成するためには、1番目からn-1番目までのすべてのトークンが既に生成されている必要があり、これらを並列に処理することは不可能です。各トークンの生成は前のトークンの完了を待つ必要があり、この依存関係の連鎖が推論プロセス全体を本質的に逐次的なものにしています。
この制約により、たとえ大量の計算リソースが利用可能であっても、それらを効果的に活用することが困難になります。GPU上の多数のコアが利用可能であっても、逐次的な依存関係により、一度に実行できる有意な並列計算の量が制限されてしまうのです。
3.3 計算とメモリの制限
Percy Liang氏:この逐次生成の制約が推論を大幅に困難にする重要な要素となり、特に、利用可能なすべての計算リソースを活用することが難しくなり、後で詳しく見るように、メモリ制限を受けることになります。
推論において、逐次的な性質により、大量の計算リソースが存在しても、それらを効率的に活用することができません。訓練時のように大きなテンソル演算を並列実行することができないため、GPUやTPUの計算能力を十分に利用できない状況が生じます。
さらに重要なのは、推論プロセスがメモリ制限を受けるということです。各トークンの生成において、過去のすべてのトークンの情報を保持し続ける必要があるため、シーケンスが長くなるにつれてメモリ使用量が増加し続けます。このメモリの要件は、特にKVキャッシュの形で顕在化し、推論の効率性における主要なボトルネックとなります。
計算リソースは豊富にあっても、メモリの制約により実際の推論速度が制限されるという状況が生まれ、これが推論を訓練とは根本的に異なる技術的課題にしています。この理解が、後に説明する様々な最適化手法の動機となっています。
4. 算術強度(Arithmetic Intensity)の分析
4.1 行列乗算の算術強度
Percy Liang氏:算術強度について復習しましょう。これは、何かが計算制限か、メモリ制限かを特徴付けるのに役立ちます。基本的なマトリックス乗算から始めましょう。
行列XをB×D、行列WをD×Fとして、この計算に色を付けるために、BはバッチサイズD は隠れ次元、Fはゲート付きMLPでのアップ投影行列とします。X×Wを行う際のFLOPSの数とメモリ読み書きの数を数えてみましょう。
まず、ゼロで初期化します。この計算を行うために、HBMからXを読み取る必要があります。これは、すべてがBF16であると仮定して、2×B×Dのメモリコストが発生します。また、Wも読み取るので、それは2×D×Fです。次に行列乗算を実行し、これは2×B×D×Fのフロップを必要とします。これは最初の講義からの内容なので、復習のはずです。そして、結果を書き戻す必要があり、これは別の転送を支払うことになります。
総フロップ数は行列乗算のみで、転送されるバイト数は基本的に読み書きされるすべての行列のサイズです。算術強度は基本的にその比率です。
この比率は特定の式になります。一般的に、バッチサイズはDやFよりもはるかに小さく、Bは数百かもしれませんが、DやFは数千または数万かもしれません。ここで私はSimpを使用して、間違いを避けるために物事を少し単純化しています。基本的にCを無限大にし、DをC×Bとしてスケールし、FをC×Bとしてスケールすると、Bの簡略化された方程式が得られます。
つまり、この特定の行列乗算における算術強度はBです。これをどのように解釈するかというと、転送されたバイトあたりにどれだけのフロップが実行されるかということです。
4.2 H100での計算限界とメモリ限界の境界
Percy Liang氏:次に、アクセラレータを見てみましょう。H100の場合、毎秒のフロップは989テラフロップ、メモリ帯域幅は毎秒3.3テラバイトです。これを割ると、アクセラレータ強度と呼ばれるものが得られます。
計算強度(Bです)を見ると、これがアクセラレータ強度より大きい場合、計算制限されていることを意味します。つまり、すべてのGPUまたはTPUを使用できることを意味します。それより小さい場合、メモリ制限されており、これは良くありません。
この行列乗算の場合、BがH100の場合の295より大きい場合に計算制限されます。これらすべては少し理想化されており、実際の詳細はもう少し複雑ですが、これは一次近似を与えてくれます。
極端なケースでは、バッチサイズが300のバッチを使用すれば、GPUを飽和させることができることを意味します。しかし、バッチが本当に小さい場合はどうなるでしょうか。特に、B=1の場合、これは基本的に行列ベクトル積に対応しますが、算術強度は基本的に1になり、これは本当に悪いです。
これは、メモリ制限されることを意味し、直感的に理解できます。基本的にこのD×F行列を読み書きしているだけで、基本的に同じ数のフロップを実行しているからです。フロップと読み取りの比率は同じで、これは1を与え、1は悪いです。メモリ読み取りが遅いため、メモリ読み取りに対して多くのフロップが実行されることを望むからです。
しかし、これは本質的に生成で起こることです。トークンごとに進行するからです。基本的に、算術強度が1のようになり、それが生成がメモリ制限され、計算制限されない理由です。
4.3 バッチサイズの影響
Percy Liang氏:バッチサイズがなぜ1より大きくないのかという質問について答えましょう。バッチサイズは、後でバッチサイズ×シーケンス長を意味することになります。
この非常にシンプルな例が、生成が遅くなる理由の核心を捉えていると思います。行列乗算がコアな計算であることを学び、行列乗算を研究して、必要なフロップ数を読み書きの数で割って数え、その比率である算術強度が次元の一つ(この場合はバッチ次元)に依存することを示しました。
大きな行列が良いのは、計算を飽和させることができるからです。一方、薄い行列、例えばB=1の場合、メモリから読み取ることに多くの時間を費やし、それほど多くの計算を行わないため、本当に悪いです。
この理解は、なぜ生成において効率が悪くなるかの根本的な理由を説明します。生成時には、各トークンを個別に処理する必要があるため、バッチサイズが効果的に1になってしまいます。これにより算術強度が1になり、GPUの計算能力を十分に活用できない状況が生まれます。
逆に、複数のリクエストを同時に処理できる場合、バッチサイズを大きくすることで算術強度を改善し、より効率的にハードウェアを活用することが可能になります。これが、後で説明する様々なバッチング戦略の重要性の基礎となっています。
5. Transformerにおける推論の詳細分析
5.1 KVキャッシュの仕組み
Percy Liang氏:推論の算術強度について話す前に、推論がどのように見えるかをより詳しく理解しましょう。
最も素朴にできることを想像してみてください。これらの素晴らしい図はこの本から取られています。Transformerがあり、プロンプトを与え、次のトークンの語彙に対するロジットを与え、そこからサンプリングします。そのトークンを取得したら、それをプロンプトに付加し、Transformerに通し、再びロジットを見てサンプリングし、これを繰り返します。
これが最も素朴に行うことです。ここでの複雑さはかなり悪く、生成する各トークンがTransformerを通るn平方またはt平方の計算のようなものになります。これは良くありません。
しかし、これを注意深く見ると、多くの冗長な作業を行っていることに気づくでしょう。基本的に、プレフィックスをエンコードするすべての作業は同じままです。これは双方向Transformerの場合は異なりますが、少なくとも自己回帰的因果Transformerの場合、プレフィックス間で多くを共有できるはずです。
そこで解決策は、キャッシュすることです。HBMにキャッシュします。なぜなら、そこに物を保存するのに十分なスペースがあるからです。
KVキャッシュがある場合、概略的にはこのように見えます。プロンプトを取り、プリフィルステップでTransformerに通し、このKVキャッシュを計算します。そして、次のトークンに対するロジットを生成します。次に、それを生成されたトークンとキャッシュに入れ、Transformerに通すことができます。
しかし、これらはすでに計算済みなので、再度行う必要はありません。この新しいトークンに対する新しいKVベクトルだけを計算すればよく、これにより次のトークンをより迅速に生成できます。基本的に、生成したトークンまたはプリフィルしたトークンに対応するこのKVキャッシュを埋めているのです。
トークンあたりt平方ではなく、tのようなものになります。
5.2 Prefillステージと生成ステージ
Percy Liang氏:具体的には、KVキャッシュは、バッチ内のすべてのシーケンスについて、シーケンス内のすべてのトークンについて、Transformerのすべての層について、すべてのヘッドについて、h次元のベクトルを保存することになります。これが多くのメモリを取ることになると思うかもしれませんが、それは間違いではありません。
推論には2つのステージがあります。プリフィルは、与えられたプロンプトをベクトルにエンコードすることです。これはちょうど訓練で行うことと同じです。並列化可能で、高速で、計算制限され、生活は良好です。
そして生成があり、これは応答トークンを一つずつ逐次的に生成することです。これが効率の観点から多くの問題を引き起こす部分です。
プリフィルステージでは、入力プロンプト全体が一度に処理されるため、Transformerの並列処理能力を最大限に活用できます。すべてのトークンの表現を同時に計算し、KVキャッシュを一度に構築することができます。これは訓練時の順伝播と非常に似ており、効率的に実行できます。
一方、生成ステージでは、各新しいトークンが前のすべてのトークンに依存するため、並列化が不可能です。各ステップで一つのトークンのみを生成し、そのトークンの情報をKVキャッシュに追加して、次のトークンの生成に備える必要があります。この逐次的な性質が、生成ステージを推論における主要なボトルネックにしています。
5.3 MLPレイヤーの算術強度
Percy Liang氏:Transformerのフロップとメモリ入出力を計算してみましょう。MLPレイヤーとアテンションレイヤーに分けて説明します。記法として、条件付けしているトークン数をs(プロンプトの長さと考えてください)、生成またはクエリに使用しているトークン数をtとします。プリフィルではtはsと等しくなります。なぜなら、tトークンを生成しているわけではありませんが、これらの各トークンを使ってクエリしているような感じだからです。生成ではtは単に1です。
行列乗算がまだ頭の中に新鮮に残っていることを願います。なぜなら、これは基本的にそれですが、Transformerなので少し複雑になります。
フロップとバイト生成を数えてみましょう。まず、b×t×d行列であるxを取ります。これらのTはSであるべきかもしれませんが、とにかく、これは一連の転送を伴います。基本的にBF16のため、その行列のサイズ×2です。
次に、3つの方法の行列があります:アップ投影、ゲート、ダウン投影です。これらはすべて転置まで同じサイズです。それらを転送する必要があります。次に、アップ投影を行います。これはいくつかのフロップ、B×D×C×Fです。2つのテンソルを掛ける場合、基本的に縮約次元は一度だけカウントされ、他の次元は単に集められます。
それを書き出す必要があります。ゲートも同じことで、書き出します。非線形性を計算します。いくつかのものを掛け合わせ、ダウン投影します。それはb×t×d×fで、基本的に同じフロップ数です。結果を書き出します。
カウントを見ると、これをチェックする必要はないかもしれません。実際にはSimpを使っているので、正しいことが保証されています。しかし、再び、B×DがDやFよりもはるかに小さいと仮定し、強度はB×Tになります。
これは、算術強度が高くなりたい場合に、バッチがどれだけ大きいか、そして基本的に生成しているトークン数に依存する行列乗算ケースに類似しています。
5.4 アテンションレイヤーの算術強度
Percy Liang氏:次にアテンションについて説明しましょう。アテンションは、私が説明する理由でさらに悪いことが判明します。
フロップとバイト転送をカウントしてみましょう。HBMからQKV行列を読み取ります。アテンションを計算します。これはQ×Kの行列です。フロップ数はB×S×T×Dです。SとTはプリフィル中は同じであることを覚えておいてください。つまり、シーケンス長の2乗×B×Dです。
行列乗算のみを見ています。なぜなら、他のステップからのフロップは実際には重要ではないからです。次に、これとvの組み合わせを取ります。つまり、実際には数学的に間違っています。なぜなら、そこにソフトマックスがあるからですが、行列乗算の本質は同じです。
同じフロップ数で、HBMに書き込みます。ここでは、Flash Attentionを使用していない場合により多くのバイトが転送されることを想定しています。Flash Attentionは、中間ステップでHBMに書き戻し続ける必要がないことを意味します。しかし、実際にはオーダーは影響を受けません。定性的には、Flash Attentionを使用するかどうかは実際には重要ではありませんが、ここでの数学は定数が重要になります。
フロップと転送されたバイトを見て、割って簡略化すると、このかなり素晴らしい表現が得られます。素晴らしいというのは、それがシンプルだという意味で、効率が良いという意味ではありません。S×T / (S+T)です。
これを少し解釈してみましょう。プリフィルではT = Sなので、プリフィル強度はSのオーダーです。これは良いことです。なぜなら、十分に長いシーケンスがある限り、順調に進むことができるからです。一般的に、シーケンスは十分に長いと仮定できます。
しかし、生成中は、実体は基本的に1 s/(s+1)ですが、これは基本的に1です。1は本当に悪いことを覚えておいてください。
しかし、Bに依存がまったくないことに注目してください。MLPでは、プリフィル算術強度はBTで素晴らしく、生成算術強度はBで、これはユーザーやワークロードの気まぐれに依存するため素晴らしくありませんが、それでも1より大きくなる可能性があります。
一方、アテンションでは、実際には常に1未満です。シーケンスがどれだけ長くても、ユーザーがどれだけいても、常に1です。
6. 実例による性能分析
6.1 LLaMA 2 13BモデルでのH100性能
Percy Liang氏:H100上でLLaMA 2 13Bのレイテンシとスループットを実例化してみましょう。13Bに対して、以下の値があります。シーケンス長を1000、隠れ次元をモデル次元を5000、Fは4倍ではありませんが、とにかくFはその数の何倍かの数、ヘッド数、キー・バリューヘッドの数(LLaMA 2では同じです、この点については後で説明します)などとし、H100のメモリ帯域幅もこれです。
この設定に基づいて、メモリレイテンシとスループットを計算することになります。まず、パラメータ数を素早く取得しましょう。これは課題1で行ったので、詳しく説明しませんが、すべての異なる項に依存する何らかの式です。パラメータを保存するために、推論は一般的に32ビットではなく16ビットになるため、BF16を使用します。そのため2を掛けます。
これがパラメータが取るメモリです。勾配は必要ありません。訓練していないのでオプティマイザーの状態も必要ありません。しかし、KVキャッシュを保存する必要があります。これは活性化の一部ですが、すべてではありません。
長さsのすべてのシーケンスに対して、シーケンスあたりに保存する必要がある量は、基本的にシーケンス長×キー・バリューヘッド数×そのヘッドの次元×層数×基本的にキーとバリューの両方のため2×BF16のため2です。
これがキャッシュサイズの取る量です。総メモリは、バッチサイズ×シーケンスあたりのキャッシュ+パラメータサイズです。
レイテンシはメモリによって決定されます。メモリ制限であることを覚えておいてください。そのため、この計算を行うためにGPUに転送する必要があるメモリの量を計算するだけです。これは単純にメモリをメモリ帯域幅で割ったものです。スループットは基本的にレイテンシの逆数ですが、並列でBトークンを生成しているため、Bでスケールアップされます。
LLaMA 2の設定を代入すると、パラメータ数は約130億でチェックアウトし、メモリレイテンシとスループットにはこれらの式があります。メモリは明らかに増加し、これがパラメータサイズ、これがキー・バリューキャッシュサイズ×Bです。レイテンシもBの関数として上がり、スループットは増加しますが、ある点まで増加することがわかります。Bは分子と分母の両方に現れるため、すべてをメモリに収められたとしても、スループットをどれだけ引き伸ばせるかには限界があります。
6.2 バッチサイズとレイテンシ・スループットの関係
Percy Liang氏:異なるバッチサイズで実例化してみましょう。B=1の場合、レイテンシは約8ミリ秒です。つまり、8ミリ秒ごとにトークンを生成し、スループットは124トークン/秒です。これがH100上でのバッチサイズ1を使用した13Bです。
次に、バッチサイズ16を使用するとどうなるでしょうか。64個のシーケンスすべてのKVキャッシュを保存する必要があるため、メモリ使用量が増加します。レイテンシは上がります。1つだけを処理するのではなく、すべてが完了するまで待つ必要があるためです。しかし、スループットも実際にかなり大幅に上がります。
レイテンシとスループット間の即座のトレードオフが見られます。低レイテンシが必要な場合は、B=1を使用します。しかし、高スループットが必要な場合は、一般的により大きなBを使用したいでしょう。
さらに大きなバッチサイズ、256を使用するとどうなるでしょうか。レイテンシは上がり、スループットは上がりますが、しばらくすると収穫逓減が生じるため、スループットはそれほど上がりません。
しかし、最も重要なことは、実際にはH100でこれを行うことはできません。メモリを見ると240ギガバイトです。つまり、収まりません。バッチサイズはメモリのためにある点まで増加させることしかできません。
要約すると、レイテンシとスループット間にはトレードオフがあります。小さなバッチサイズはより良いレイテンシを生み、大きなバッチサイズはより良いスループットを生みます。この関係は、システム設計において重要な考慮事項となり、用途に応じてどちらを優先するかを決定する必要があります。
6.3 メモリ使用量の制約
Percy Liang氏:バッチサイズを256まで増加させると、メモリ使用量が240ギガバイトになりますが、これは実際にはH100では実行できません。メモリに収まらないからです。バッチサイズはメモリのためにある点まで増加させることしかできません。
この制約は、推論システム設計における根本的な課題を浮き彫りにします。理論的にはより大きなバッチサイズでスループットを向上させたいのですが、実際にはハードウェアのメモリ容量によって制限されます。
メモリ使用量は主に2つの要素で構成されます。まず、モデルパラメータのメモリです。13Bモデルの場合、BF16精度で約26GBが必要です。次に、より動的で問題となるのがKVキャッシュのメモリです。これはバッチサイズとシーケンス長に比例して増加します。
KVキャッシュのメモリ要件は、バッチ内の各シーケンス、各トークン位置、各層、各アテンションヘッドについてキーとバリューのベクトルを保存する必要があるため、急速に蓄積されます。バッチサイズを16から256に増加させると、KVキャッシュだけで数百ギガバイトのメモリが必要になる可能性があります。
この制約により、実際の推論システムでは、利用可能なメモリ量とターゲットとするレイテンシ・スループットのバランスを慎重に考慮する必要があります。メモリ効率的なアーキテクチャや最適化手法の開発が、実用的な推論システムにとって極めて重要である理由がここにあります。
7. KVキャッシュ削減手法
7.1 Group Query Attention (GQA)
Percy Liang氏:大きなボトルネックはKVキャッシュです。メモリ制限であることを覚えておいてください。つまり、メモリに必要な容量が少なければ少ないほど高速になります。フロップのためだけでなく、主にメモリ転送のためです。
KVキャッシュを削減し始めると、精度を失う可能性があります。そこで、精度をあまり失わずに、KVキャッシュを小さく保つ方法はあるでしょうか。基本的にアーキテクチャを変更してKVキャッシュを削減しようとする一連のアイデアを説明します。
Group Query Attentionというアイデアがあります。バニラTransformerのマルチヘッドアテンションは、基本的にヘッド数を保持し、それぞれについて同じ数のキー、バリュー、クエリを持ちます。
かつてマルチクエリアテンションというものがあり、基本的に1つのキーと1つのバリューしか持ちませんでした。つまり、基本的に1つのキー・バリューヘッドです。これはあまり表現力がないことが判明しました。そこで、キーとバリューの数を減らし、より多くのクエリを持つ中間点があります。
なぜこれを行うのでしょうか。KVキャッシュサイズを削減したいことを覚えておいてください。キーとバリューが少なければ少ないほど良いのです。バッチサイズとシーケンス長は変更されませんが、これらのベクトルの次元も変更されませんが、削減しているのはキー・バリューヘッドの数です。
これが基本的なアイデアで、この論文では実際にレイテンシとスループットの改善が得られることを示しています。グループ数を増やすにつれて、8程度まで基本的にフルアテンションと比較して非常に高速で、無視できる違いがあります。グループ数を増やすにつれて、明らかに元のものになります。
より厳密に行うために、LLaMA 2 13Bモデルがあります。統計を計算すると、バッチサイズ64を使用して、これが得られたものです。レイテンシもここに印刷すべきでした。まあいいでしょう。GQAで実行すると、メモリが削減され、スループットが大幅に向上することがわかります。これは実際に素晴らしいです。
LLaMA 2 13Bアーキテクチャを取り、すべてのクエリヘッドに対して、すべてのキー・バリューヘッドに対して5つのクエリヘッドを持つとします。つまり、1:5の比率です。これはまた、前回256を試した時にH100のメモリに収まらなかったことを覚えているでしょうが、今は実際にH100メモリに快適に収まることができ、より大きなバッチサイズを使用することでスループットをさらに改善できます。
7.2 Multi-Head Latent Attention (MLA)
Percy Liang氏:KVキャッシュを削減する別の方法があり、これはDeepSeekから来ています。これは実際にはDeepSeek V2の論文からのもので、Multi-Head Latent Attentionと呼ばれ、Tatsuが以前に講義しましたが、推論の文脈とその含意について話してみようと思います。
基本的なアイデアは、フルアテンションとGQAです。GQAは「より少ないキーとバリューを使用します」と言います。MLAは「キーとバリューの数を変更しません。これらをより低次元の空間に投影します」と言います。
つまり、KVサイズを縮小する別の方法ですが、異なる次元で行います。各トークンのKVキャッシュに対してN×H次元を使用する代わりに、C次元に投影します。DeepSeekが行ったことは、実際にはかなり積極的な削減で、16,000から512への削減です。唯一の問題は、これがROPEと互換性がないことです。そのため、ROPEを戻すためにいくつかの追加の次元を追加する必要があります。
しかし、全体的に見ると、これは実際にKV削減の観点からかなり有望です。数学は行いませんが、KVキャッシュがどのように大幅に削減されるかを見ることができ、同じ種類のレイテンシとスループットの利点を得ることができると信じることができます。
精度の観点から見ると、GQAと比較して、MLAが実際に改善することを示しました。この表はそれを示していないかもしれません。後で掘り下げる必要があります。しかし、とにかくMLAは精度も保持します。
MLAのアプローチは特に興味深いものです。なぜなら、アテンションヘッドの数を減らすのではなく、各ヘッドが使用する表現の次元を削減するからです。これにより、モデルの表現能力をより保持しながら、メモリ使用量を大幅に削減できる可能性があります。DeepSeekの実装では、この手法により大幅なメモリ削減を実現し、推論効率を向上させています。
7.3 Cross Layer Attention (CLA)
Percy Liang氏:Cross Layer Attentionと呼ばれる別のアイデアがあります。この論文がありますが、多くの人がこれについて考え、実行していると思うので、これが実際に最初の論文かどうかはわかりませんが、基本的にTransformerダイアグラムを見ると、1つの層のキー・バリュー投影があり、次に次の層があり、これらのキー・バリューベクトルは通常は別々です。
CLAのアイデアは、レイヤー間で同じキー・バリュー投影を使用することです。それがクロスレイヤーアテンションと呼ばれる理由です。GQAがヘッド間で共有するように、CLAはレイヤー間で共有します。
ここで彼らは、精度とKVキャッシュサイズのパレート最前線を実証的に改善することを示しています。スループットとレイテンシに関連するKVキャッシュサイズは小さくしたく、パープレキシティも小さくしたいです。
彼らは、このトレードオフを改善できることを示しています。例えば、H6 64ヘッドの場合、キャッシュサイズは削減されますが、検証パープレキシティは少し上がりますが、全体的にそのトレードオフを行う利点があります。
Cross Layer Attentionは、レイヤー間でのパラメータ共有によってメモリ効率を改善する興味深いアプローチです。各レイヤーが独自のキー・バリュー投影行列を持つ従来のアプローチとは異なり、CLAでは複数のレイヤーで同じ投影を共有します。これにより、KVキャッシュに必要なメモリが削減され、推論時のメモリ使用量とそれに伴う転送コストが削減されます。
この手法の利点は、パラメータ数の削減とメモリ使用量の削減の両方を同時に実現できることです。ただし、レイヤー間での表現の多様性が制限される可能性があるため、精度への影響を慎重に評価する必要があります。
7.4 Local Attention
Percy Liang氏:物事を行う別の方法があります。Local Attentionです。これは実際にLongformer、OpenAIの論文以来、かなり探求されており、その後MistralやI think多くの他の人がこれを使用しています。これは非常に自然なアイデアだと思います。
フルアテンションダイアグラムを見ると、それは密なn²です。そして、そこから多くの複雑さが来ます。基本的なアイデアは、過去のKトークンのみにアテンションすることです。
これは、KVキャッシュにおいて、シーケンスを生成している間、すべてを覚えておく必要がないことを意味します。トークンがアテンションのウィンドウの外に落ちるとすぐに、それを捨てることができます。Local Attentionは、シーケンス長とともに成長するのではなく、KVキャッシュサイズが一定に保たれるということです。これは本当に良いことです。なぜなら、長いシーケンスでも、かなり小さなキャッシュを持つことができるからです。
しかし、問題は、これが依然として精度を損なうことです。なぜTransformerではなくRNNを行うのかを考えれば、長距離依存関係をモデル化する必要があったからです。これはある意味で、これをアテンションと呼ぶのは少し誇大広告です。これはローカルコンテキストのみを見ているため、あまり表現力がありません。
そこで、ローカルアテンションをフルグローバルアテンションハイブリッドレイヤーと組み合わせることができます。例えば、Character AIは6レイヤーごとに1つのグローバルレイヤーと5つのローカルレイヤーを使用しました。
Cross Layer Attentionに加えて、このようなものになります。フルアテンション、すべてのレイヤーで、KVキャッシュを保存する必要があります。彼らが行ったことは、6レイヤーごとにフルアテンションを持ちますが、その間にローカルアテンションがあり、さらにローカルアテンションとグローバルアテンションの両方でKVキャッシュ共有があります。これは、すべてのトリックではありませんが、多くのトリックが組み合わされたようなものです。
Local Attentionは、長いシーケンスに対するメモリ効率を大幅に改善する強力な手法です。従来のアテンションが全てのトークンにアテンションする必要があるのに対し、ローカルアテンションは固定サイズのウィンドウ内のトークンのみを考慮します。これにより、シーケンス長に関係なくKVキャッシュサイズを一定に保つことができます。
8. 新しいアーキテクチャによる高速化
8.1 State Space Models
Percy Liang氏:Transformerを変更することによってさらに急進的な方法で推論を高速化する方法について話しましょう。KVキャッシュこれらは基本的にTransformerの変種でした。しかし、もしかするとTransformerの外に出て、より良いことができるかもしれません。なぜなら、Transformerは本当に重い推論ワークロードを念頭に置いて設計されたわけではないからです。
彼らは単に良いモデルを効率的に訓練しようとしていました。それは主に訓練効率についてでした。そして、私たちが指摘したように、自己回帰と完全なアテンションの組み合わせが、ここでこの種のボトルネックを本当に引き起こしています。
State Space ModelsとDiffusion Modelsという2つの方向について話しましょう。これはかなり簡潔に行います。
State Space Modelsのアイデアは、実際に信号処理と制御理論からアイデアを引き出しています。最初の動機は、n²の爆発を被ることなく、長いコンテキストシーケンスをモデル化しようとすることでした。つまり、必ずしも推論速度についてではありませんでしたが、その問題を解決すると、より高速な推論も得られることが判明しました。
S4という初期の論文があります。これは、基本的にこれらの種類の線形動的システムである古典的なState Space Modelsを使用し、長いコンテキストをモデル化し、現代のニューラル設定に無理やり押し込むために使用されています。この研究は、線形性構造のためにRNNのような解釈と畳み込み解釈の両方を持つという点で素晴らしいです。
彼らはこの論文を発表し、これらの長いコンテキストの合成タスクで本当にうまく機能することを示したと思います。しかし、発見されたことは、これらが言語モデリングにはあまりうまく機能しないということでした。そして、それは明らかに失望でした。なぜなら、Transformerの価値の多くは言語をうまく行えることだからです。
一連の論文で、彼らはこれらのモデルがうまく機能しない理由の本質を捉えた一連の合成タスクを特定しました。それは基本的にこれらの連想記憶タスクです。ここに、基本的にキー・バリューペアのシーケンスが与えられる合成タスクがあり、目標は基本的にキーを調べて値を出力することです。
ある意味で、これは論理的に些細なタスクですが、キー・バリューペアをたくさん持つことができるため、長いシーケンスです。任意に長い依存関係を持つことができ、ローカルアテンションはうまく機能しないことがわかります。なぜなら、最後の数シーケンスしか覚えないからです。State Space Modelsの問題は、これらの種類の信号処理タスクには良いが、実際にこれは特定のキー・バリューペアを分離して答えを引き出す必要があり、そのタイプのタスクには実際にうまく機能しなかったということです。
8.2 Linear Attention
Percy Liang氏:HyenaやH3、そしてMambaなど、一連の研究があります。これらは基本的にSSMを調整または変更して、これらの連想記憶タスクを処理し、最終的に1Bスケールまでは、Transformerとマッチするまでうまく機能しました。MambaのアイデアはMambaは人気になり、AI21の人々によって52Bまでスケールアップされました。
この場合、彼らはまだTransformerを使用しなければならなかったことに注目してください。Transformerですが、8層ごとに1つのTransformerがありました。残りはMamba層でした。それでも、かなり大きな節約と速度向上につながりました。
しかし、最近、Linear Attentionと呼ばれるこの古いアイデアの復活があります。これを大きくできるかどうか見てみましょう。実際には非常にシンプルなアイデアです。Local AttentionやSliding Window Attentionが何かを知っていますか?Linear Attentionは、基本的にアテンション計算にキーとクエリがあり、それらをドット積し、そのexpを取るというアイデアです。これは基本的にexpカーネルを与えています。
そのテイラー展開を取り、その計算を基本的に何らかの非線形マップのドット積として書くことができます。つまり、本質的に、すべてのキー・バリュー位置に対して、基本的に何らかの非線形性を適用して何らかの空間に押し上げ、その上で線形計算を行うことです。
Linear Attentionであるため、実際にRNNのように動作し、二次ではなくシーケンス長に対して線形です。これは少し速かったと知っていますが、味を与えたかっただけです。
このアイデアは実際にかなり成功裏にスケールアップされています。Minimaxと呼ばれる組織があり、4560億パラメータまでかなり正当なモデルを訓練しています。彼らは基本的にこのLinear Attentionアイデアを使用しています。今、彼らは時々フルアテンションを使用しなければなりません。人々がフルアテンションを持たないことを回避できたとは思いませんが、少なくとも、ほとんどの層がもはやフルアテンションではないようです。
それらは線形層またはローカルアテンション層のいずれかであり、はるかに効率的です。Linear AttentionとLocal Attentionを組み合わせることで、実際に深刻な最先端モデルを生み出し、おそらく安全に言えることは、少なくともこれと同程度に効率的であり、スパース性を活用していることです。
8.3 Diffusion Models
Percy Liang氏:まったく異なるスタイルの生成モデルであるDiffusion Modelsについて話しましょう。Diffusion Modelsは画像生成で非常に人気がありますが、テキストで動作させるのは非常に難しいことが判明していますが、最近ここでいくつかの進歩がありました。
Diffusionのアイデアは、自己回帰的に生成する代わりに、すべてのトークンを並列に生成することです。明らかに、何らかのシンプルな層を介してそれだけを行うと、あまり良くないでしょう。すべての単語を並列に生成して、それが一貫性があることを期待することはできません。
しかし、行うことは反復し、最終的に出力する最終生成に到達するまで、この生成を継続的に洗練することです。並列に生成するというアイデアは、もはや自己回帰的に束縛されていないということで、すべてのトークンを並列に生成することは、コンテキスト長が十分に大きい限り、並列化可能です。
そのため、GPUまたはTPUを比較的簡単に飽和させることができます。最近、Nous Research Labsがかなり興味深いモデルを製作しました。それらについてはあまり書かれていませんが、生成プロセスのデモを見ることができます。コードを瞬時に生成しますが、明らかに壊れたコードで、その後時間をかけて洗練します。
これは彼らのベンチマークの1つで、少なくともコーディングについて、他のタスクについてはわかりませんが、毎秒トークン数を見ると、これらのモデルはTransformerである何よりもはるかに高速です。Jambaを覚えていますか?これはハイブリッドMamba Transformerアーキテクチャのようなものですが、これらのDiffusion Modelsと比較するとかなり遅いです。
Diffusion ModelsがGeneral Purposeで、すべてのタスクで十分に強力になるかどうかは、まだ見守る必要がありますが、トークン速度でこのようなリードがあるため、たとえ精度の損失があったとしても、より多くの計算を投入して、必要に応じて精度の損失の一部を回復できると思います。
この並列生成アプローチは、推論における根本的なパラダイムシフトを表しています。自己回帰的な制約を取り除くことで、利用可能な計算リソースをより効果的に活用でき、大幅な速度向上を実現できる可能性があります。
9. モデル圧縮技術
9.1 量子化(Quantization)
Percy Liang氏:量子化とモデル枝刈りについて話しましょう。量子化については、重要なアイデアは数値の精度を下げることです。非常に簡単に行うことができ、考え方は、メモリが少ないということは転送されるバイト数が少ないということで、より低いレイテンシ、より高いスループットを意味します。もちろん、精度について心配する必要があります。それがトレードオフです。
異なるタイプのフォーマットを見ると、FP32は訓練に使用されますが、推論には実際には使用されません。BF-16は推論のデフォルトのようなものです。FP8やINT8まで下げることができ、これはより精度が低いですが、FP8よりもはるかに安価です。人々は実際にINT8でかなりの推論を行っており、範囲を見ると、これは127から-128の間の整数で、それほど高精度ではありません。人々はINT4まで下げており、これは非常に低いです。
量子化したいと決めたら、いくつかのことができます。量子化で訓練することもできますが、明らかにモデルを再訓練する必要があり、一般的には、既存のモデルを取り、量子化して、あまり台無しにしないようにする訓練後量子化を行います。
LLM.int8と呼ばれる論文について簡単に説明します。量子化では、基本的にFP16のベクトルを取り、INT8にパックしたい場合、動的範囲を把握する必要があります。最大値が何かを把握し、それを把握したら、それで割り、128を掛けることで整数を得ることができます。逆量子化する必要がある場合は、逆の方向に進みます。
基本的に量子化は、メモリ帯域幅がボトルネックであることを覚えておいてください。すべての転送はINT8で行われますが、実際に算術を行う際は、浮動小数点にアップキャストする必要がある場合があります。
INT8の問題は、すべてがうまく収まるわけではなく、より大きなネットワークに現れる外れ値があり、物事を台無しにすることです。この論文が行ったことは、この行列を取り、本当に大きな外れ値を特定し、それらを別々に処理し、完全な16ビット精度を使用し、そして大部分をINT8で行うことです。
これはうまく機能しますが、実際には少し遅いです。ここでの動機は推論速度ではなく、モデルをメモリに収めることができることでした。
Activation-Aware Quantization(AWQ)と呼ばれる別の論文があります。ここでのアイデアは、基本的に重みを量子化するが、活性化に基づいてどの重みを量子化するかを把握することです。非常に迅速にこれを説明すると、実際にINT3まで下げており、これは明らかにメモリをかなり削減し、3倍の速度向上につながります。
9.2 モデル枝刈り(Pruning)
Percy Liang氏:モデル枝刈りについて話しましょう。量子化と同様に、アイデアは非常にシンプルです。高価なモデルの部分を取り除いて安くし、それを修正することです。
このNvidiaの論文では、まず小さなキャリブレーションサイズを使用して、重要な層、ヘッド、または隠れ次元を特定します。いくつかのシンプルなスコアを使用してそれを計算し、次に重要でない層、隠れユニット、またはヘッドを除去します。
そのモデルを取るだけでは、明らかに悪くなりますよね。そこで最後のステップは、元のモデルを枝刈りされたモデルに蒸留することです。つまり、ゼロから始めるのではなく、悪化しているがうまくいけばそれほど悪くなく、うまくいけば元のモデルと同じ構造的特性の多くを保持している初期化から始めています。ただ、ある意味でキャリブレートされていないだけです。
結果はかなり良いです。彼らは150億パラメータのモデルを持っており、それを8Bまでほとんど低下なしに削減でき、そして4Bまでいくらかの低下で削減できます。しかし、4Bモデルまで大幅に下げてもいます。
モデル枝刈りの魅力は、単純に不要な部分を除去することでモデルサイズを削減し、推論速度を向上させることができる点にあります。重要なのは、ランダムに除去するのではなく、モデルの性能に与える影響が最小限になるように、系統的に重要でない要素を特定して除去することです。
蒸留プロセスは特に重要です。枝刈り後のモデルは、元のモデルの「教師」として機能し、削減されたモデルが可能な限り元の性能を維持できるように訓練されます。これにより、大幅なサイズ削減にもかかわらず、実用的な性能を保持することが可能になります。
10. 投機的復号(Speculative Decoding)
10.1 投機的サンプリングの原理
Percy Liang氏:これらすべてのアプローチは少し不満足です。なぜなら、それらは損失を伴うからです。大幅な速度向上を得られますが、常に元のモデルが実際に元と同じくらい良いのか疑問に思います。Speculative DecodingまたはSpeculative Samplingは、基本的にケーキを食べて、それも持つことを可能にします。
推論には2つのステージがあることを思い出してください。与えられたシーケンスをエンコードし、すべてのトークンを並列に行うPrefillがあります。これは計算制限で素晴らしいです。各トークンの対数確率も与えることに注目してください。そして、一度に1トークンずつ生成があります。これはメモリ制限で遅いです。
言い換えれば、チェックは生成よりも高速です。直感的には理にかなっていますが、うまくいけば、なぜこれが真実であるかの数学も理解できるでしょう。
Speculative Samplingのアイデアは、実際に非常にシンプルです。これは、Googleの2つの独立したチームによって並行して提案されました。アイデアは、安価なドラフトモデルPを使用して先に進んでいくつかのトークンを生成し、次にそれらのトークンをターゲットモデルで評価することです。与えられたトークンの評価は単なるprefillなので、並列に行うことができ、高速です。そして、良く見える場合はそれを受け入れます。
これが実際の生活でどのように見えるかです。大きなモデルを使用して一度に1トークンずつ生成している場合、それは遅いです。しかし、speculative decodingでは、先に進んでたくさんのトークンを生成しているドラフトモデルがあり、大きなモデルを使用して本質的に検証し、時々拒否し、時々受け入れます。受け入れ率が基本的に得られる速度向上を決定します。
この原理の美しさは、正確性を犠牲にすることなく速度を向上させることができることです。ドラフトモデルが生成した候補トークンが適切である場合、それらを受け入れることで複数のトークンを一度に進めることができます。不適切な場合は、ターゲットモデルによる正確な生成にフォールバックします。これにより、平均的により高速な生成を実現しながら、出力品質を保証することができます。
10.2 アルゴリズムの詳細
Percy Liang氏:より形式的なアルゴリズムがここにあります。Kの先読みを持つことになります。ドラフトモデルを使用してK個のトークンを自己回帰的に生成します。これはドラフトモデルが小さいため、うまくいけば高速です。
次に、これらの生成したK個のトークンが与えられ、ターゲットモデルQに基づいてそれらをスコア化します。各トークンを見て、それを受け入れるかどうかを決定します。確率QをPで割ったもので受け入れ、1は確率が0と1の間にあることを確認するためのものです。
これは、サンプリングに精通している人には、Metropolis-Hastingsのようなものに見えるかもしれません。これは基本的にそこから来ています。直感的には、Pでサンプリングしています。Pを割り出したくないのでQが欲しいのです。つまり、これはある種の重要度重みです。
受け入れる場合は、次のドラフトトークンを見て続行します。受け入れない場合は、ターゲットモデル(遅いモデル)からサンプリングしますが、既にPを使用してサンプリングを試みたので、この補正を行います。もうそれを行う必要はありません。それを差し引いてQからサンプリングします。
これは基本的に、提案P、ターゲットQでの棄却サンプリングのようなものです。棄却サンプリングでは、拒否した場合、拒否して再試行し、再試行します。ここでは、永遠にループし続けたくありません。拒否した場合は、「わかりました。諦めて、より高価なモデルからサンプリングします」と言います。
ここでの素晴らしいことは、ターゲットモデルから正確なサンプルを得ることが保証されていることです。
私は導出をスキップします。これは実際には証明ではありません。語彙2の場合に、これらの公式がなぜ正しい偏りのないサンプリング手順を与えるのかを示すための、ある種の導出です。
このアルゴリズムの数学的な美しさは、棄却サンプリングの理論に基づいていることです。提案分布P(ドラフトモデル)を使用してサンプルを生成し、ターゲット分布Q(実際のモデル)から正確にサンプリングするための補正を行います。重要度重み Q(x)/P(x) を使用することで、ドラフトモデルが生成した候補の受け入れ確率を調整し、最終的にターゲットモデルからの正確な分布を保証します。
10.3 実装と性能向上
Percy Liang氏:これはかなりうまく機能します。精度は、同じモデルなので実際に同じであるべきです。ただし、そこにはある程度のランダム性があるかもしれません。しかし、速度向上は基本的に2倍の速度向上を得ています。
実際には、70Bモデルのようなものを持ち、ドラフトモデルははるかに小さくします。ターゲットモデルが70Bの場合、ドラフトモデルは8Bかもしれません。そして、ターゲットモデルが8Bの場合、ドラフトモデルは1Bかもしれません。一般的に、ドラフトモデルをターゲットにできるだけ近づけたいと思います。蒸留を行っている場合、それはさらに良くなる可能性があります。
これは推論においてかなり熱い研究分野です。このプロセスを改善する多くの方法があります。Medusaを使用することができます。これは、ドラフトモデルが自己回帰的に生成する代わりに、複数のトークンを並列にサンプリングする方法です。
または、実際にターゲットモデルの高レベル機能を取り、それらをドラフトモデルに送り込んで生成を支援するEagleがあります。つまり、ドラフトモデルは実際には単独で立つ必要はありません。ターゲットモデルにくっついて、それが生成するのを助けることができます。
要約すると、数学のおかげでターゲットモデルからの正確なサンプリングです。これは、チェックと生成の非対称性、またはprefillと生成を利用し、ドラフトモデルには多くの革新の余地があります。量子化の異なる方法、私たちが前に話したすべてのことが適用されます。唯一の違いは、基本的に正確なサンプルを得ることが保証されていることです。
Speculative Decodingの実装における重要な考慮事項は、ドラフトモデルとターゲットモデル間のバランスです。ドラフトモデルが高速であることは重要ですが、同時にターゲットモデルに十分に近い品質を持つ必要があります。受け入れ率が低すぎると、速度向上の利益が減少してしまいます。
さらに、MedusaやEagleのような改良手法は、従来の自己回帰的な生成を超えて、より効率的な候補生成方法を探求しています。これらの手法により、ドラフトモデルの効率性を向上させ、全体的な速度向上をさらに増大させることが可能になります。
11. 実用的な推論システム
11.1 動的バッチング
Percy Liang氏:時間がないので、先ほど出てきた質問について簡単に説明します。実際に提供する際には、ライブトラフィックがあり、リクエストは異なる時間に到着し、異なる時間に終了し、一部は共有プレフィックスを持ち、一部は持たず、異なる長さを持ちます。訓練では基本的に密なトークンブロックを取得し、それを全速力でGPUに押し通すのと比較して、非常に異質です。
この場合はどうすればよいでしょうか?これを探求する一連の論文があり、基本的なアイデアは、最後の2つの部分がより多くのシステムレベルの貢献のようなものです。
アイデアは、バッチが電車が出発するのを待たないということです。電車はあなたを待ちません。新しいバッチが来ると、それを入れるだけです。つまり、トークンを生成しているワーカーは、毎ステップスケジューラーに制御を戻す必要があります。
トークンを生成し、スケジューラーに戻って「新しいリクエストがあるか」と言い、新しいリクエストがあればそれらを詰め込んで続行します。つまり、リクエストを待つ時間を無駄にしていません。
バッチングには問題があると思います。これは質問の背後にあるものです。バッチングは、すべてが同じ次元性である場合に機能しますが、すべてのリクエストが異なる長さかもしれません。
そこで、選択的バッチングというアイデアがあります。基本的に、アテンションのための計算を分解します。すべてを別々に処理する必要があります。しかし、MLPについて覚えておいてください。これは計算の大部分ですが、実際に異なるサイズのテンソルを取り、それらを平坦化することができます。それらは相互作用しないので、バッチ次元で基本的に同乗できます。
動的バッチングの核心は、静的なバッチ処理とは根本的に異なる、より流動的なアプローチを取ることです。従来のバッチ処理では事前に決められたサイズのバッチを待つ必要がありましたが、動的バッチングでは到着したリクエストを即座に処理に組み込むことができます。
選択的バッチングは特に巧妙な解決策です。アテンション計算では各シーケンスが独自のKVキャッシュを必要とするため個別処理が必要ですが、MLP計算では異なる長さのシーケンスでも効率的にバッチ処理できます。これにより、ハードウェアの利用効率を最大化しながら、異なる長さのリクエストを同時に処理することが可能になります。
11.2 ページアテンション
Percy Liang氏:ページアテンションについて簡単に説明します。これは、あなた方の一部が使用したことがあるかもしれないvLLMの背後にある論文です。これはメモリ使用の問題に対処しています。
KVキャッシュがあり、プロンプトが入ってきて終了している場合、キャッシュが断片化されることになります。リクエストに対して大量のスペースを割り当てることになりますが、何個のトークンが生成されるかわからないので、内部断片化が発生し、リクエストと応答の間にパディングがある外部断片化も発生します。これは良くありません。
ページアテンションは基本的に、オペレーティングシステムと仮想メモリがどのように機能するかを覚えていると言います。KVキャッシュを連続するブロックのシーケンスに分割し、空白スペースを見つけた場所にそれらを配置するだけです。
2つのリクエストが入ってくる場合、最初のリクエストはここ、ここ、ここにあり、2番目のリクエストはここ、ここにあるかもしれません。ブロックは連続性を保つものであり、それによりメモリを結合することができます。
プレフィックスの共有がある場合、オペレーティングシステムからのもう1つのアイデアであるcopy-on-writeを使用することもできます。基本的に、この特定のブロックを使用しているシーケンスの数について参照カウンターを維持し、異なる方向に進む必要があり、ブロックが異なる方向に進む必要がある場合は、コピーして参照カウントを減らします。
他にもvLLMの最適化はたくさんありますが、基本的な要約は、オペレーティングシステムのクラスを覚えていれば、それらを推論にも適用できるということです。
ページアテンションの革新性は、メモリ管理の古典的な問題を言語モデル推論に適用したことにあります。従来のKVキャッシュ管理では、各シーケンスに対して連続したメモリブロックを事前に割り当てる必要がありましたが、実際の生成長は予測できないため、大量のメモリが無駄になっていました。
ページングシステムにより、メモリを小さなブロック単位で管理し、必要に応じて動的に割り当てることができます。copy-on-writeメカニズムは、共通のプレフィックスを持つ複数のリクエストがある場合に特に効果的で、メモリ使用量をさらに削減できます。これらの技術により、限られたGPUメモリでより多くの同時リクエストを処理することが可能になります。
11.3 メモリ管理の最適化
Percy Liang氏:vLLMには他にも多くの最適化がありますが、詳しくは説明しませんが、基本的な要約は、オペレーティングシステムのクラスを覚えていれば、それらを推論にも適用できるということです。
メモリ管理の最適化において重要なのは、断片化の問題を根本的に解決することです。従来のアプローチでは、各リクエストに対して最大可能長のメモリを事前に割り当てる必要がありましたが、実際の生成長は大きく異なるため、大量のメモリが未使用のまま残されていました。
ページアテンションシステムでは、メモリを固定サイズの小さなページに分割し、これらのページを仮想的に連結することで論理的に連続したKVキャッシュを構築します。物理的にはページが散在していても、論理的には連続したアドレス空間として扱うことができます。
参照カウンティングメカニズムは、複数のシーケンスが同じプレフィックスを共有する場合に特に有効です。同一のプレフィックス部分は単一のメモリページセットで共有され、各シーケンスが独自の経路を取る必要がある場合にのみcopy-on-writeが発生します。これにより、チャットボットのような用途で同一のシステムプロンプトから始まる多数のリクエストを効率的に処理できます。
これらの最適化技術は、単純なメモリ割り当てアプローチと比較して、メモリ利用効率を大幅に向上させ、同じハードウェアリソースでより多くの同時リクエストを処理することを可能にします。
12. まとめと今後の展望
12.1 推論の重要性の再確認
Percy Liang氏:簡単にまとめると、推論は本当に重要です。特性は訓練とは異なります。メモリ制限されており、動的でもあります。これが多くの新しい課題につながります。
推論の重要性は、実用的なAIシステムの展開における中心的な役割から来ています。訓練は一度行えば完了しますが、推論は継続的に、しかも大規模に実行されます。OpenAIの1日1000億ワード、Cursorの1日10億行のコード生成といった数値が示すように、推論は現代のAIインフラストラクチャにおいて膨大な計算リソースを消費しています。
訓練と推論の根本的な違いを理解することが重要です。訓練では、すべてのトークンが既知であるため並列処理が可能で、計算リソースを効率的に活用できます。しかし推論では、各トークンが前のすべてのトークンに依存する逐次的な性質により、並列化が制限され、メモリ帯域幅がボトルネックとなります。
この特性の違いが、推論固有の技術的課題を生み出しています。算術強度の分析で示したように、特に生成段階では計算能力を十分に活用できず、メモリ制限により性能が制約されます。さらに、実際のサービス環境では、リクエストが動的に到着し、異なる長さと要件を持つため、静的なバッチ処理とは異なるアプローチが必要になります。
これらの課題により、推論は単なる工学的な問題を超えて、アーキテクチャ設計から根本的に見直す必要がある領域となっています。
12.2 アーキテクチャ変更の可能性
Percy Liang氏:新しいアーキテクチャ、量子化、枝刈り、蒸留、投機的復号などのさまざまな技術の全体を見てきました。システムからのアイデアもあり、メモリをより良く使用し、通信と計算を重複させるなどのことができますが、モデリングとアーキテクチャにはおそらくさらに多くの機会があると言いたいと思います。
推論を狭く考えると、推論は特定のモデルでの推論です。この特定のモデルをどのように実行するかということです。しかし、その特定のモデルを実行することについて誰が気にするでしょうか。あなたが気にするのは、リソース予算が与えられた時に良い精度を提供することです。
KVキャッシュを削減し、Transformerを変更しようとするこれらのアイデアの多くは、基本的に問題を回避し、「まあ、より効率的な何かがあり、それをより良い精度を得る方法で訓練できるなら、そうすれば勝利です」と言う方法です。
このアーキテクチャ革新の視点は、推論効率化における最も有望なアプローチの一つです。従来のアプローチが既存のTransformerアーキテクチャを所与として、その上で最適化を行うのに対し、アーキテクチャレベルでの変更は根本的なボトルネックを解決する可能性があります。
例えば、State Space ModelsやLinear Attentionのような手法は、Transformerの二次的な複雑さを回避し、KVキャッシュの制約を根本的に改善します。同様に、Diffusion Modelsは自己回帰的な生成制約を完全に取り除くことで、並列処理の可能性を大幅に拡大します。
これらのアーキテクチャ変更は、単なる効率性の改善を超えて、新しい能力の可能性をも開きます。より長いコンテキストの効率的な処理、より高速な推論、さらにはより表現力豊かなモデルの実現など、従来のTransformerでは困難だった課題に対する解決策を提供する可能性があります。
重要なのは、アーキテクチャの選択が推論効率に与える影響の大きさを認識することです。システムレベルの最適化も重要ですが、アーキテクチャの根本的な変更により得られる改善幅は、しばしば桁違いの効果をもたらします。
12.3 効率性と精度のトレードオフ
Percy Liang氏:これらのKVキャッシュを削減し、Transformerを変更しようとするアイデアの多くは、基本的に問題を回避する方法です。「まあ、より効率的な何かがあり、それをより良い精度を得る方法で訓練できるなら、そうすれば勝利です」と言うのです。
推論を狭く考えると、特定のモデルでの推論、つまりこの特定のモデルをどのように実行するかということになります。しかし、その特定のモデルを実行することについて誰が気にするでしょうか。あなたが気にするのは、リソース予算が与えられた時に良い精度を提供することです。
この視点の転換が重要です。従来のアプローチでは、「この70Bモデルをいかに高速に実行するか」という問題設定でしたが、実際に重要なのは「与えられた計算予算内で最高の性能を達成する」ことです。この観点から見ると、効率的なアーキテクチャで訓練された小さなモデルが、非効率的なアーキテクチャの大きなモデルよりも優れた結果を提供する可能性があります。
量子化や枝刈りのような手法は、既存のモデルに対する後処理として適用されるため、必然的に元の性能から劣化します。一方、アーキテクチャレベルでの効率化は、最初から効率性を考慮して設計されるため、同じ計算予算でより良い性能を実現できる可能性があります。
Speculative Decodingのような手法は、この課題に対する興味深い解決策を提供します。正確性を保証しながら速度向上を実現するため、精度と効率性の間でトレードオフを強いられることがありません。
最終的に、推論効率化の目標は、特定のモデルを高速化することではなく、限られたリソース内で最良の結果を達成することです。この目標を念頭に置くと、アーキテクチャ革新こそが最も有望なアプローチであることが明らかになります。効率的なアーキテクチャの開発により、従来は不可能だった計算効率と精度の両立が実現される可能性があります。