PyTorch における Deformable Convolution の ONNX 変換
PyTorch では、何年か前から Deformable Convolution が DeformConv2d としてサポートされていました。
しかし、 ONNX 変換はサポートされておらず、 DeformConv2d を含むネットワークを ONNX にしようとするとエラーになっていました。
今回は、 DeformConv2d を PyTorch で ONNX に変換できるようにするモジュール deform_conv2d_onnx_exporter を作ったので、その説明です。
が、諸事情により、一旦ライブラリの公開は取りやめました。すみません。 → 再度公開しました。 PyPI のページからダウンロードできます。
Deformable Convolution の説明自体はここでは触れません。
ONNX 変換に必要なこと
DeformConv2d を ONNX に変換できないのは、内部で使われている deform_conv2d 関数を ONNX に変換する方法が定義されていないためです。
最終的に呼ばれる torch.ops.torchvision.deform_conv2d は、 CPU, CUDA 向けなどに C++ で実装されています。
そのため、 ONNX に変換するためには専用の変換関数を用意してやる必要があるのですが、現時点では実装されていません。
似たようなものとしては、 torch.ops.torchvision.roi_pool があるのですが、こちらは torchvision/ops/_register_onnx_ops.py の中で、以下のように ONNX 変換用の関数が登録されています。
こちらにあるように、 Graph object g を用いて、 ONNX の Operator である MaxRoiPool を使って ONNX のグラフを構築していくことになります。
ONNX Operator
ONNX の Operator 一覧は GitHub にあります。
今回やる必要があるのは、ここに定義されている Operator の一覧の中から適切なものを選び、組み合わせることで、 deform_conv2d の挙動を構築することになります。
なお、 ONNX の Operator は通常の手続き的な書き方ができるわけではなく、基本的には Tensor に対する処理の組み合わせで書いていくことになります。
deform_conv2d の ONNX Operator での実装
やることは、 Deformable ConvNets v2: More Deformable, Better Results の論文に書かれた内容を愚直に実装することになります。
注意点としては、 deform_conv2d の引数として渡されてくる Tensor のデータの格納順序が一部不明瞭なところがあるので、それを実動作ベースで確認しながら進めることになります。
大まかな流れとしては、下記になります。
- 各 Convolution の対象となる画素の位置を求める。
- その画素の位置座標が、画像外を指さないように丸める。
- 位置座標は小数になっているので、 Bilinear 補完により、周囲 4画素の値から重み付け平均を取り、画素の値とする。
なお、このとき、画素の値はGatherNDOperator を用いて取得をする。 - 得られた画素の値を基に、必要に応じて mask をかけ、 Convolution を行う。
- 必要に応じて bias を足す。
「愚直に実装する」と書きましたが、実際にはここはかなり苦労しました。
上記の各処理を ONNX の Operator のみで記述していくのは、 deform_conv2d の内部実装ドキュメントの少なさも相俟って、なかなか大変なものがありました。
気を付けたこと
Transpose を多用しない
上記の流れを実装する際に、 (論文に沿った愚直な方法では) GatherND を使って画素値を集めてくることになるかと思います。
似たような Operator の GatherElements でも実現はできますが、入力画像のチャネル数が大きくなった場合に、 (おそらく) とても非効率になります。
一方で、現時点の GatherND Operator の仕様上、複数のチャネルにまたがってデータを一括して集めてくる場合、そのチャネルがテンソルのランクの末尾の方に位置している必要があります。
つまり、 Tensor[batch, height, width, channel] のように、座標として指定される height, width よりも末尾に channel が位置している必要があります。
そのため、 Transpose Operator (PyTorch の Tensor の permute) で軸の順序を入れかえることになりますが、この処理は状況によっては負荷がかかるようです。
当初は Transpose を多用するコードになっていたのですが、なぜか速度がまったく出ず、 ONNX の profiling をすることで、無駄な Transpose が遅い要因の一つであることが分かりました。
そのため、必要最小限な部分に絞って Transpose を呼ぶようにすることで、効率化を図りました。
ただ、この辺りは ONNX のランタイムの実装にも強く依存する話なので、環境が変わると結果も異なるかもしれません。
不要な Operator を呼ばないようにするのはよいと思いますが。
GatherND で範囲外の座標値にアクセスしない
当然ですが、 GatherND では範囲外の座標値にアクセスしてはいけません。
当初は、これを防ぐために、範囲外なのかどうかを Greater, Less Operator を使って判定した mask を作り、これを活用することで範囲外へのアクセスを防いでいました。
しかし、この方法ではそれなりに時間がかかることから、 Clip Operator を使った方法に変更しました。
これは、まず入力画像の周囲に必ず 0 クリアされた Padding 領域が入るようにしておくことで、 Clip された座標値が Padding 領域を指すようにしておき、 Bilinear 補完の計算に入らないようにすることで実現しました。
これはそれなりに効率化がされたようで、処理速度も速くなりました。
これらを気を付けることで、 (遅いことは遅いのですが) それなりに使える範囲での実行速度を実現することができました。
よさそうなら PyTorch 本体にも Pull Request を出してみようかと思います。