np.sumよりnp.dotするほうが速い
TL;DR
- np.dotはBLASを呼び出すので、np.sumするよりnp.dotを使うほうが速い(float型に限る)
- 条件付きでsumを取る場合は中間結果を保持する必要がないため更に速くなる
- (理由)np.sumはnumpyネイティブな処理なのに比べて、np.dotはBLASを呼び出しているためより最適化された処理が行われているから?
BLASの活用
御存じの通りNumPyはPythonにおいて数値計算を効率的に行うための拡張モジュールで、NumPyで書けば大抵の計算はPythonの遅さを気にすることなく実行できます。 しかしNumPyでも呼び出す関数や書き方によって処理速度が異なってくることがあります。
実はNumPy自体もその内部で更にOpenBLASやMKLといった高速な線形代数ライブラリ(BLAS:Basic Linear Algebra Subprograms)を呼び出しています。OpenBLASやMKL内部での処理は並列化やメモリアクセスの最適化がなされているため、できるだけBLASを呼び出すような処理を書くと同じNumPyを使う場合でもより速く処理ができることが期待できます。
今回は例としてnp.sum処理をより線形代数っぽい操作に置き換えることでより高速化してみます。
np.sumをnp.dotに置き換える
適当な2次元配列Aに対して2次元目の軸について和を取る操作
A.sum(axis=1)
を考えます。この操作は線形代数的な演算で書き直すことができます。Aを$M×N$行列と見なし、すべての要素が$1$である$N$次元ベクトル$b$を用意すると
行列積$Ab$は$M$次元ベクトルとなり各要素は行列積の定義よりAの各行の列和となるのでこれはA.sum(axis=1)
と同じ結果となります。実際に以下のようなテストをしてみるとちゃんと通ることがわかります。
A = np.random.randint(0, 1000, size=(5000, 5000)) b = np.ones(5000) np.testing.assert_almost_equal(A.sum(axis=1), np.dot(A, b))
A.sum
とnp.dot
で計算量は変わらないですが、np.dotのほうは線形代数の演算なのでBLASが呼び出されます。
条件付きのnp.sumをnp.dotで置き換える
np.sumを使うとき、特定の要素だけを足し合わせたい場合もあります。この場合、sum関数のwhere引数を作るか、またはあらかじめフラグ配列を掛けてからnp.sum
することで実現できます。例えば、足し合わせる箇所を表すベクトルをcond
とすると
A=np.random.randint(0,1000,size=(5000,5000)) cond=np.random.randint(0,2,5000).astype(np.bool_) (A*cond).sum(axis=1) A.sum(axis=1,where=cond)
この場合も線形代数的な演算で書き直すことができます。先ほどのすべての要素が$1$である$N$次元ベクトル$b$の代わりに条件を満たす箇所だけ1でそうでない箇所を0にした配列cond
を用いて、行列積$A\cdot \rm{cond}$を考えることでBLASを呼び出しての処理を行うことが可能です。更にこの場合、行列積で求めるやり方だと(A*cond).sum(axis=1)
に比べて中間変数を保持する必要がないため、その点でも効率化されます。
速度計測
以下の環境で実際に各処理の実行時間を測定してみました。
条件なしのsum処理
実行コード | 実行時間(np.int32) | 実行時間(np.float32) | 高速化倍率(np.int32) | 高速化倍率(np.float32) |
---|---|---|---|---|
A.sum(axis=1) | 5.62 ms | 7.54 ms | 1.0x | 1.0x |
np.dot(A,b) | 8.53 ms | 2.66 ms | 0.65x | 2.83x |
int32型での計算速度はnp,sumのほうが速い結果となっていますが、float32型ではnp.dotのほうが2.83倍も高速になっています。実際にCPU使用率を見てもnp.dot計算中のほうが全てのコアが満遍なく使用されていてマシンリソースを最大限活用できているようでした。int32型の計算が遅いのはBLASが本来数値計算用に用いられるため、float型での最適化しか想定してないからでしょうか?(詳しくはわかりませんでした。)
条件付きのsum処理
処理内容 | 実行時間(np.int32) | 実行時間(np.float32) | 高速化倍率(np.int32) | 高速化倍率(np.float32) |
---|---|---|---|---|
(A*cond).sum(axis=1) | 28.0 ms | 28.8 ms | 1.0x | 1.0x |
np.dot(A,cond) | 8.39 ms | 2.32 ms | 3.34x | 12.4x |
中間変数の保持が省かれる分、単純なsumに比べてより高速化できています。今回に限らず中間変数の保持をなるべく省くことがNumPyの高速化のカギですね。*1
まとめ
NumPyは速いが、その中でもBLASを呼び出す関数を上手く使うとより高速化できる。ただし、float型を使わないとそこまで恩恵は得らない(かえって遅くなる) 行ないたい処理が線形代数的な演算で書けるかを常に気をつける。