GoogleとIntelが共同開発、Intel GPUでJAXモデルを高速化するPJRT API登場:AIフレームワークへのハードウェア最適化が容易に
GoogleはPJRTの初のプラグインPJRT APIをIntelと共同開発したことを発表した。PJRT APIを実装することでIntel GPU上でのJAXモデルの実行や、XLAアクセラレーションを使用したTensorFlowおよびPyTorchモデルの初期Intel GPUサポートも可能になる。
Googleは2023年6月1日(米国時間)、Intel GPU上でJAXモデルをシームレスに実行する「Intel Extension for TensorFlow」で、初のPJRT(Performance JIT Runtime)プラグインの実装を発表した。PJRT APIは別途開発されたIntel GPUプラグインとの統合を簡素化し、JAXへの迅速な統合を可能にする。また、PJRTの実装により、XLAアクセラレーションを備えたTensorFlowとPyTorchモデルの初期Intel GPUサポートも有効になる。
IntelとGoogleは共同でTensorFlow PluggableDeviceメカニズムを開発した。このメカニズムはTensorFlowを新しいデバイスに拡張するためにサポートされた手法であり、ハードウェアベンダーはこれを使用して、独自のプラグインバイナリのリリースが可能になる。Intelは、XLAコンパイラ用のモジュラーインタフェースの構築や、IntelGPU上でJAXワークロードを実行するためのPJRTプラグインの開発でGoogleと協力を続けている。
JAX
JAXは、GPUやTPUなどのハイパフォーマンスコンピューティングデバイスでの複雑な数値計算のために設計されたオープンソースのPythonライブラリである。NumPy関数をサポートし、自動微分や、ニューラルネットワークを構築、訓練するためのコンポーザブル関数変換システムを提供する。
JAXは、コンパイルと実行のバックエンドとしてXLAを使用し、特にAIハードウェアアクセラレータ上での計算を最適化し並列化する。JAXプログラムが実行されると、PythonコードはOpenXLAのStableHLOオペレーションに変換され、コンパイルと実行のためにPJRTに渡される。その下で、StableHLOオペレーションがXLAコンパイラによってマシンコードにコンパイルされ、ターゲットのハードウェアアクセラレータ上で実行することができる。
PJRT
PJRT(OpenXLAのStableHLOと組み合わせて使用)は、コンパイラとランタイム向けのハードウェアおよびフレームワークに依存しないインタフェースを提供する。PJRTインタフェースは、新しいデバイスバックエンドからのプラグインをサポートしている。このインタフェースは、JAXをIntelのシステムへストレートに統合するための手段を提供し、Intel GPU上でのJAXのワークロードを可能にする。さまざまなAIフレームワークとPJRTとの統合により、IntelのGPUプラグインはIntel GPUを使用する幅広い開発者にハードウェアアクセラレーションとoneAPIの最適化を提供できる。
PJRT APIは、上位のAIフレームワークがStableHLOで表現された数値計算をAIハードウェア/アクセラレータ上でコンパイル、実行するためのフレームワーク非依存型APIだ。このAPIは、JAX、TensorFlow(TF-XLA経由)、PyTorch(PyTorch-XLA経由)などの一般的なAIフレームワークと統合されており、ハードウェアベンダーは新しいAIハードウェアのプラグインを提供するだけで、これら全ての一般的AIフレームワークをサポートできるようになる。また、ゼロコピーバッファーの提供、依存関係管理が軽量で効率的なことなど、上位のAIフレームワークとの効率的なやりとりを可能にする低レベルプリミティブを提供し、AIフレームワークがハードウェアリソースを最大限に活用し、高性能な実行を実現できるようにした。
Intel GPU用PJRTプラグイン
Intel GPUプラグインは、StableHLOをコンパイルし、実行ファイルをIntel GPUにディスパッチすることで、PJRT APIを実装する。コンパイルはXLA実装をベースに、Intel GPU用のターゲット固有のパスを追加し、oneAPIパフォーマンスライブラリを活用してアクセラレーションを実現している。デバイスの実行は、SYCLランタイムを使用してサポートされている。また、Intel GPUプラグインにはデバイスの登録、列挙、SPMD(Single Program, Multiple Data)実行モードも実装されている。
PJRTの高レベルのランタイム抽象化により、プラグインは独自の低レベルのデバイス管理モジュールを開発し、新しいデバイスが提供する高度なランタイム機能を使用することが可能だ。
FlaxやT5XのようなJAXベースのフレームワークを含むJAXプログラムを実行するために、Intel GPUプラグインは簡単に使い始めることができる。ドキュメントであれば、プラグインをビルドし、環境変数と依存ライブラリのパスを設定するだけである。JAXは自動的にプラグインライブラリを探し、現在のプロセスにロードする。
以下は、Intel GPU上でJAXを実行するコードスニペット例である。
$ export PJRT_NAMES_AND_LIBRARY_PATHS='xpu:Your_itex_library/libitex_xla_extension.so' $ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:Your_Python_site-packages/jaxlib $ python >>> import numpy as np >>> import jax >>> jax.local_devices() # PJRT Intel GPU plugin loaded [IntelXpuDevice(id=0, process_index=0), IntelXpuDevice(id=1, process_index=0)] >>> x = np.random.rand(2,2).astype(np.float32) >>> y = np.random.rand(2,2).astype(np.float32) >>> z = jax.numpy.add(x, y) # Runs on Intel XPU
今後の取り組み
Intel GPU用のPJRTプラグインは、IntelのGPUで動作するようにTensorFlowにも統合されている。これにより、TensorFlowモデル内のXLAサポートされた操作を実行することができる。しかし、XLAのオペレーションセットはTensorFlowよりも少ないため、XLAがTensorFlowで提供されている全てのオペレーションをサポートしていない。そこでTensorFlowモデルでは、モデルグラフの部分はPJRTで実行し、他の部分はTensorFlow OpKernelを使用して初期TensorFlowランタイムで実行している。
次のステップとして、GoogleとIntelは全てのTensorFlowモデルをサポートするIntel GPUで非XLAOpsを実装するために、NextPluggableDevice APIを採用し協力を続けていく予定だという。
Copyright © ITmedia, Inc. All Rights Reserved.
関連記事
Google Cloud、クラウド開発環境「Cloud Workstations」正式リリース OSS版VS Codeも利用可能
Google Cloudは、クラウド上で完全に管理された安全な開発環境を提供する「Cloud Workstations」の一般提供を開始した。Metaが次世代AIインフラ構築計画の進捗状況を発表 カスタムAIアクセラレータチップ、次世代DCなど
Metaは、次世代AIインフラを構築する計画の最近の進捗状況を発表した。発表の目玉は、AIモデルを実行するための同社初のカスタムシリコンチップ、AIに最適化された新しいデータセンター設計、1万6000個のGPUを搭載するAI研究用スーパーコンピュータの第2フェーズだ。Google、分散アプリケーションを構築・デプロイするOSSフレームワーク「Service Weaver」発表
Googleは、分散アプリケーションを構築、デプロイ(展開)するためのオープンソースフレームワーク「Service Weaver」を発表した。