GoogleはPJRTの初のプラグインPJRT APIをIntelと共同開発したことを発表した。PJRT APIを実装することでIntel GPU上でのJAXモデルの実行や、XLAアクセラレーションを使用したTensorFlowおよびPyTorchモデルの初期Intel GPUサポートも可能になる。
この記事は会員限定です。会員登録(無料)すると全てご覧いただけます。
Googleは2023年6月1日(米国時間)、Intel GPU上でJAXモデルをシームレスに実行する「Intel Extension for TensorFlow」で、初のPJRT(Performance JIT Runtime)プラグインの実装を発表した。PJRT APIは別途開発されたIntel GPUプラグインとの統合を簡素化し、JAXへの迅速な統合を可能にする。また、PJRTの実装により、XLAアクセラレーションを備えたTensorFlowとPyTorchモデルの初期Intel GPUサポートも有効になる。
IntelとGoogleは共同でTensorFlow PluggableDeviceメカニズムを開発した。このメカニズムはTensorFlowを新しいデバイスに拡張するためにサポートされた手法であり、ハードウェアベンダーはこれを使用して、独自のプラグインバイナリのリリースが可能になる。Intelは、XLAコンパイラ用のモジュラーインタフェースの構築や、IntelGPU上でJAXワークロードを実行するためのPJRTプラグインの開発でGoogleと協力を続けている。
JAXは、GPUやTPUなどのハイパフォーマンスコンピューティングデバイスでの複雑な数値計算のために設計されたオープンソースのPythonライブラリである。NumPy関数をサポートし、自動微分や、ニューラルネットワークを構築、訓練するためのコンポーザブル関数変換システムを提供する。
JAXは、コンパイルと実行のバックエンドとしてXLAを使用し、特にAIハードウェアアクセラレータ上での計算を最適化し並列化する。JAXプログラムが実行されると、PythonコードはOpenXLAのStableHLOオペレーションに変換され、コンパイルと実行のためにPJRTに渡される。その下で、StableHLOオペレーションがXLAコンパイラによってマシンコードにコンパイルされ、ターゲットのハードウェアアクセラレータ上で実行することができる。
PJRT(OpenXLAのStableHLOと組み合わせて使用)は、コンパイラとランタイム向けのハードウェアおよびフレームワークに依存しないインタフェースを提供する。PJRTインタフェースは、新しいデバイスバックエンドからのプラグインをサポートしている。このインタフェースは、JAXをIntelのシステムへストレートに統合するための手段を提供し、Intel GPU上でのJAXのワークロードを可能にする。さまざまなAIフレームワークとPJRTとの統合により、IntelのGPUプラグインはIntel GPUを使用する幅広い開発者にハードウェアアクセラレーションとoneAPIの最適化を提供できる。
PJRT APIは、上位のAIフレームワークがStableHLOで表現された数値計算をAIハードウェア/アクセラレータ上でコンパイル、実行するためのフレームワーク非依存型APIだ。このAPIは、JAX、TensorFlow(TF-XLA経由)、PyTorch(PyTorch-XLA経由)などの一般的なAIフレームワークと統合されており、ハードウェアベンダーは新しいAIハードウェアのプラグインを提供するだけで、これら全ての一般的AIフレームワークをサポートできるようになる。また、ゼロコピーバッファーの提供、依存関係管理が軽量で効率的なことなど、上位のAIフレームワークとの効率的なやりとりを可能にする低レベルプリミティブを提供し、AIフレームワークがハードウェアリソースを最大限に活用し、高性能な実行を実現できるようにした。
Intel GPUプラグインは、StableHLOをコンパイルし、実行ファイルをIntel GPUにディスパッチすることで、PJRT APIを実装する。コンパイルはXLA実装をベースに、Intel GPU用のターゲット固有のパスを追加し、oneAPIパフォーマンスライブラリを活用してアクセラレーションを実現している。デバイスの実行は、SYCLランタイムを使用してサポートされている。また、Intel GPUプラグインにはデバイスの登録、列挙、SPMD(Single Program, Multiple Data)実行モードも実装されている。
PJRTの高レベルのランタイム抽象化により、プラグインは独自の低レベルのデバイス管理モジュールを開発し、新しいデバイスが提供する高度なランタイム機能を使用することが可能だ。
FlaxやT5XのようなJAXベースのフレームワークを含むJAXプログラムを実行するために、Intel GPUプラグインは簡単に使い始めることができる。ドキュメントであれば、プラグインをビルドし、環境変数と依存ライブラリのパスを設定するだけである。JAXは自動的にプラグインライブラリを探し、現在のプロセスにロードする。
以下は、Intel GPU上でJAXを実行するコードスニペット例である。
$ export PJRT_NAMES_AND_LIBRARY_PATHS='xpu:Your_itex_library/libitex_xla_extension.so' $ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:Your_Python_site-packages/jaxlib $ python >>> import numpy as np >>> import jax >>> jax.local_devices() # PJRT Intel GPU plugin loaded [IntelXpuDevice(id=0, process_index=0), IntelXpuDevice(id=1, process_index=0)] >>> x = np.random.rand(2,2).astype(np.float32) >>> y = np.random.rand(2,2).astype(np.float32) >>> z = jax.numpy.add(x, y) # Runs on Intel XPU
Intel GPU用のPJRTプラグインは、IntelのGPUで動作するようにTensorFlowにも統合されている。これにより、TensorFlowモデル内のXLAサポートされた操作を実行することができる。しかし、XLAのオペレーションセットはTensorFlowよりも少ないため、XLAがTensorFlowで提供されている全てのオペレーションをサポートしていない。そこでTensorFlowモデルでは、モデルグラフの部分はPJRTで実行し、他の部分はTensorFlow OpKernelを使用して初期TensorFlowランタイムで実行している。
次のステップとして、GoogleとIntelは全てのTensorFlowモデルをサポートするIntel GPUで非XLAOpsを実装するために、NextPluggableDevice APIを採用し協力を続けていく予定だという。
Copyright © ITmedia, Inc. All Rights Reserved.