Torch7で最近傍探索のベンチマーク
最近、Torch7で最近傍探索を繰り返し行いたかったけど、すごく遅いのでは??という不安があったのでk-NN(k=1)でベンチマークしてみた。
設定
- MNISTをk-NN(k=1)で評価する
- 尺度はコサイン類似度とする
- テストを全件評価してかかった時間を計測する
パッと思いついた実装
初めは難しいことは考えず、パッと思いついた方法を試してみる。
工夫としては、
- コサイン類似度を求める際にベクトルのノルムを毎回計算したくないので、最初にノルムが1になるように正規化しておく(内積=コサイン類似度になる)
- 各テストデータと各学習データの比較は、gemv(torch.mv)で一度に計算すれば速いのではないか
require 'optim' require 'xlua' torch.setdefaulttensortype('torch.FloatTensor') local EPSILON = 1.0e-6 -- normを1に正規化 local function normalize_l2(x) local norm = torch.pow(x, 2):sum(2):sqrt():add(EPSILON) x:cdiv(torch.expand(norm, x:size(1), x:size(2))) end function main() local mnist = require 'mnist' local trainset = mnist.traindataset() local testset = mnist.testdataset() local train_x, train_y = trainset.data:float(), trainset.label local test_x, test_y = testset.data:float(), testset.label local classes = {0,1,2,3,4,5,6,7,8,9} local confusion = optim.ConfusionMatrix(classes) local t = sys.clock() -- データを整形 train_x = train_x:reshape(train_x:size(1), 28 * 28) test_x = test_x:reshape(test_x:size(1), 28 * 28) train_y:add(1) test_y:add(1) -- L2 normを1に正規化 normalize_l2(train_x) normalize_l2(test_x) -- 各テストデータについて for i = 1, test_x:size(1) do -- コサイン類似度が最も大きいインスタンスを選択 local _, nn_index = torch.mv(train_x, test_x[i]):max(1) -- 結果を評価 local y = train_y[nn_index[1]] confusion:add(y, test_y[i]) if i % 100 == 0 then xlua.progress(i, test_x:size(1)) end end -- 結果を表示 print(confusion) print(string.format("*** %.2fs", sys.clock() - t)) end main()
結果
ConfusionMatrix: [[ 978 1 0 0 0 0 0 1 0 0] 99.796% [class: 0] [ 0 1129 3 1 0 1 1 0 0 0] 99.471% [class: 1] [ 9 0 1003 4 0 0 2 10 3 1] 97.190% [class: 2] [ 0 0 1 977 0 13 0 5 9 5] 96.733% [class: 3] [ 1 3 0 0 940 0 6 3 1 28] 95.723% [class: 4] [ 1 1 0 17 1 852 10 1 4 5] 95.516% [class: 5] [ 4 3 0 0 2 3 946 0 0 0] 98.747% [class: 6] [ 2 11 5 2 2 0 0 995 0 11] 96.790% [class: 7] [ 6 1 1 13 2 3 5 4 935 4] 95.996% [class: 8] [ 5 6 1 4 9 3 1 8 4 968]] 95.937% [class: 9] + average row correct: 97.189832925797% + average rowUcol correct (VOC measure): 94.583150148392% + global correct: 97.23% *** 95.11s
95.11秒だった。正解率は97.23%。世の中にはMNISTをk-NNしてみたら数時間かかったとか言っている人も散見されるので、そう考えると速い気もする。
gemvではなくgemmで計算する
インスタンスひとつづずgemvしてたけど、全体をgemm(torch.mm, 行列の積)で計算すればもっと速くなるだろうと思ったので変更してみた。
require 'optim' require 'xlua' torch.setdefaulttensortype('torch.FloatTensor') local EPSILON = 1.0e-6 -- normを1に正規化 local function normalize_l2(x) local norm = torch.pow(x, 2):sum(2):sqrt() norm:add(EPSILON) x:cdiv(torch.expand(norm, x:size(1), x:size(2))) end function main() local mnist = require 'mnist' local trainset = mnist.traindataset() local testset = mnist.testdataset() local train_x, train_y = trainset.data:float(), trainset.label local test_x, test_y = testset.data:float(), testset.label local classes = {0,1,2,3,4,5,6,7,8,9} local confusion = optim.ConfusionMatrix(classes) local t = sys.clock() -- データを整形 train_x = train_x:reshape(train_x:size(1), 28 * 28) test_x = test_x:reshape(test_x:size(1), 28 * 28) train_y:add(1) test_y:add(1) -- L2 normを1に正規化 normalize_l2(train_x) normalize_l2(test_x) -- 全てのテストデータについてコサイン類似度が最も大きいインスタンスを選択 local cosine = torch.mm(train_x, test_x:t()) local _, nn_index = cosine:max(1) -- 結果を評価 for i = 1, test_x:size(1) do local y = train_y[nn_index[1][i]] confusion:add(y, test_y[i]) if i % 100 == 0 then xlua.progress(i, test_x:size(1)) end end -- 結果を表示 print(confusion) print(string.format("*** %.2fs", sys.clock() - t)) end main()
結果
ConfusionMatrix: [[ 978 1 0 0 0 0 0 1 0 0] 99.796% [class: 0] [ 0 1129 3 1 0 1 1 0 0 0] 99.471% [class: 1] [ 9 0 1003 4 0 0 2 10 3 1] 97.190% [class: 2] [ 0 0 1 977 0 13 0 5 9 5] 96.733% [class: 3] [ 1 3 0 0 940 0 6 3 1 28] 95.723% [class: 4] [ 1 1 0 17 1 852 10 1 4 5] 95.516% [class: 5] [ 4 3 0 0 2 3 946 0 0 0] 98.747% [class: 6] [ 2 11 5 2 2 0 0 995 0 11] 96.790% [class: 7] [ 6 1 1 13 2 3 5 4 935 4] 95.996% [class: 8] [ 5 6 1 4 9 3 1 8 4 968]] 95.937% [class: 9] + average row correct: 97.189832925797% + average rowUcol correct (VOC measure): 94.583150148392% + global correct: 97.23% *** 12.41s
12.41秒だった。正解率は当然同じ。かなり速くなった。
CUDAでやってみる
Torch7はBLASでできるような計算なら簡単にCUDA化できるのでやってみた。
工夫として、
- gemmは使用メモリが多すぎてGPUにメモリが確保できなかったので分割して計算するようにした
require 'cutorch' require 'optim' require 'xlua' torch.setdefaulttensortype('torch.FloatTensor') local EPSILON = 1.0e-6 -- normを1に正規化 local function normalize_l2(x) local norm = torch.pow(x, 2):sum(2):sqrt() norm:add(EPSILON) x:cdiv(torch.expand(norm, x:size(1), x:size(2))) end -- 行列の積を直接計算しようとするとGPUメモリに載らなかったので16分割して計算 local function split_mm(a, b) local BLOCKS = 16 -- 分割数 local step = math.floor(b:size(1) / BLOCKS) local results = torch.Tensor(a:size(1), b:size(1)) for i = 1, b:size(1), step do local n = step if i + n > b:size(1) then n = b:size(1) - i end if n > 0 then results:narrow(2, i, n):copy(torch.mm(a, b:narrow(1, i, n):t())) end collectgarbage() end return results end function main() local mnist = require 'mnist' local trainset = mnist.traindataset() local testset = mnist.testdataset() local train_x, train_y = trainset.data:float(), trainset.label local test_x, test_y = testset.data:float(), testset.label local classes = {0,1,2,3,4,5,6,7,8,9} local confusion = optim.ConfusionMatrix(classes) local t = sys.clock() -- データを整形 train_x = train_x:reshape(train_x:size(1), 28 * 28) test_x = test_x:reshape(test_x:size(1), 28 * 28) train_y:add(1) test_y:add(1) -- L2 normを1に正規化 normalize_l2(train_x) normalize_l2(test_x) -- 計算用のデータをCudaTensorに変換(GPUのデバイスメモリに転送) train_x = train_x:cuda() test_x = test_x:cuda() -- 全てのテストデータについてコサイン類似度が最も大きいインスタンスを選択 local cosine = split_mm(train_x, test_x) local _, nn_index = cosine:max(1) -- 結果を評価 for i = 1, test_x:size(1) do local y = train_y[nn_index[1][i]] confusion:add(y, test_y[i]) if i % 100 == 0 then xlua.progress(i, test_x:size(1)) end end print(confusion) print(string.format("*** %.2fs", sys.clock() - t)) end main()
結果
ConfusionMatrix: [[ 978 1 0 0 0 0 0 1 0 0] 99.796% [class: 0] [ 0 1129 3 1 0 1 1 0 0 0] 99.471% [class: 1] [ 9 0 1003 4 0 0 2 10 3 1] 97.190% [class: 2] [ 0 0 1 977 0 13 0 5 9 5] 96.733% [class: 3] [ 1 3 0 0 940 0 6 3 1 28] 95.723% [class: 4] [ 1 1 0 17 1 852 10 1 4 5] 95.516% [class: 5] [ 4 3 0 0 2 4 945 0 0 0] 98.643% [class: 6] [ 2 11 5 2 2 0 0 995 0 11] 96.790% [class: 7] [ 6 1 1 13 2 3 5 4 935 4] 95.996% [class: 8] [ 5 6 1 4 9 3 1 8 4 968]] 95.937% [class: 9] + average row correct: 97.179394364357% + average rowUcol correct (VOC measure): 94.562811851501% + global correct: 97.22% *** 5.12s
5.12秒だった。爆速!