Torch7で最近傍探索のベンチマーク

最近、Torch7で最近傍探索を繰り返し行いたかったけど、すごく遅いのでは??という不安があったのでk-NN(k=1)でベンチマークしてみた。

設定

  • MNISTをk-NN(k=1)で評価する
  • 尺度はコサイン類似度とする
  • テストを全件評価してかかった時間を計測する

環境

  • Intel(R) Core(TM) i7-3770K CPU @ 3.50GHz
  • 32GB RAM
  • GeForce GTX 760

パッと思いついた実装

初めは難しいことは考えず、パッと思いついた方法を試してみる。

工夫としては、

  • コサイン類似度を求める際にベクトルのノルムを毎回計算したくないので、最初にノルムが1になるように正規化しておく(内積=コサイン類似度になる)
  • 各テストデータと各学習データの比較は、gemv(torch.mv)で一度に計算すれば速いのではないか
require 'optim'
require 'xlua'

torch.setdefaulttensortype('torch.FloatTensor')
local EPSILON = 1.0e-6

-- normを1に正規化
local function normalize_l2(x)
   local norm = torch.pow(x, 2):sum(2):sqrt():add(EPSILON)
   x:cdiv(torch.expand(norm, x:size(1), x:size(2)))
end

function main()
   local mnist = require 'mnist'
   local trainset = mnist.traindataset()
   local testset = mnist.testdataset()
   local train_x, train_y = trainset.data:float(), trainset.label
   local test_x, test_y = testset.data:float(), testset.label
   local classes = {0,1,2,3,4,5,6,7,8,9}
   local confusion = optim.ConfusionMatrix(classes)
   local t = sys.clock()
   
   -- データを整形
   train_x = train_x:reshape(train_x:size(1), 28 * 28)
   test_x = test_x:reshape(test_x:size(1), 28 * 28)
   train_y:add(1)
   test_y:add(1)
   
   -- L2 normを1に正規化
   normalize_l2(train_x)
   normalize_l2(test_x)
   
   -- 各テストデータについて
   for i = 1, test_x:size(1) do
      -- コサイン類似度が最も大きいインスタンスを選択
      local _, nn_index = torch.mv(train_x, test_x[i]):max(1)
      -- 結果を評価
      local y = train_y[nn_index[1]]
      confusion:add(y, test_y[i])
      if i % 100 == 0 then
	 xlua.progress(i, test_x:size(1))
      end
   end
   -- 結果を表示
   print(confusion)
   print(string.format("*** %.2fs", sys.clock() - t))
end
main()
結果
ConfusionMatrix:
[[     978       1       0       0       0       0       0       1       0       0]   99.796% 	[class: 0]
 [       0    1129       3       1       0       1       1       0       0       0]   99.471% 	[class: 1]
 [       9       0    1003       4       0       0       2      10       3       1]   97.190% 	[class: 2]
 [       0       0       1     977       0      13       0       5       9       5]   96.733% 	[class: 3]
 [       1       3       0       0     940       0       6       3       1      28]   95.723% 	[class: 4]
 [       1       1       0      17       1     852      10       1       4       5]   95.516% 	[class: 5]
 [       4       3       0       0       2       3     946       0       0       0]   98.747% 	[class: 6]
 [       2      11       5       2       2       0       0     995       0      11]   96.790% 	[class: 7]
 [       6       1       1      13       2       3       5       4     935       4]   95.996% 	[class: 8]
 [       5       6       1       4       9       3       1       8       4     968]]  95.937% 	[class: 9]
 + average row correct: 97.189832925797% 
 + average rowUcol correct (VOC measure): 94.583150148392% 
 + global correct: 97.23%
*** 95.11s	

95.11秒だった。正解率は97.23%。世の中にはMNISTをk-NNしてみたら数時間かかったとか言っている人も散見されるので、そう考えると速い気もする。

gemvではなくgemmで計算する

インスタンスひとつづずgemvしてたけど、全体をgemm(torch.mm, 行列の積)で計算すればもっと速くなるだろうと思ったので変更してみた。

require 'optim'
require 'xlua'

torch.setdefaulttensortype('torch.FloatTensor')
local EPSILON = 1.0e-6

-- normを1に正規化
local function normalize_l2(x)
   local norm = torch.pow(x, 2):sum(2):sqrt()
   norm:add(EPSILON)
   x:cdiv(torch.expand(norm, x:size(1), x:size(2)))
end

function main()
   local mnist = require 'mnist'
   local trainset = mnist.traindataset()
   local testset = mnist.testdataset()
   local train_x, train_y = trainset.data:float(), trainset.label
   local test_x, test_y = testset.data:float(), testset.label
   local classes = {0,1,2,3,4,5,6,7,8,9}
   local confusion = optim.ConfusionMatrix(classes)
   local t = sys.clock()
   
   -- データを整形
   train_x = train_x:reshape(train_x:size(1), 28 * 28)
   test_x = test_x:reshape(test_x:size(1), 28 * 28)
   train_y:add(1)
   test_y:add(1)
   
   -- L2 normを1に正規化
   normalize_l2(train_x)
   normalize_l2(test_x)
   -- 全てのテストデータについてコサイン類似度が最も大きいインスタンスを選択
   local cosine = torch.mm(train_x, test_x:t())
   local _, nn_index = cosine:max(1)
   -- 結果を評価
   for i = 1, test_x:size(1) do
      local y = train_y[nn_index[1][i]]
      confusion:add(y, test_y[i])
      if i % 100 == 0 then
	 xlua.progress(i, test_x:size(1))
      end
   end
   -- 結果を表示
   print(confusion)
   print(string.format("*** %.2fs", sys.clock() - t))
end
main()
結果
ConfusionMatrix:
[[     978       1       0       0       0       0       0       1       0       0]   99.796% 	[class: 0]
 [       0    1129       3       1       0       1       1       0       0       0]   99.471% 	[class: 1]
 [       9       0    1003       4       0       0       2      10       3       1]   97.190% 	[class: 2]
 [       0       0       1     977       0      13       0       5       9       5]   96.733% 	[class: 3]
 [       1       3       0       0     940       0       6       3       1      28]   95.723% 	[class: 4]
 [       1       1       0      17       1     852      10       1       4       5]   95.516% 	[class: 5]
 [       4       3       0       0       2       3     946       0       0       0]   98.747% 	[class: 6]
 [       2      11       5       2       2       0       0     995       0      11]   96.790% 	[class: 7]
 [       6       1       1      13       2       3       5       4     935       4]   95.996% 	[class: 8]
 [       5       6       1       4       9       3       1       8       4     968]]  95.937% 	[class: 9]
 + average row correct: 97.189832925797% 
 + average rowUcol correct (VOC measure): 94.583150148392% 
 + global correct: 97.23%
*** 12.41s

12.41秒だった。正解率は当然同じ。かなり速くなった。

CUDAでやってみる

Torch7はBLASでできるような計算なら簡単にCUDA化できるのでやってみた。

工夫として、

  • gemmは使用メモリが多すぎてGPUにメモリが確保できなかったので分割して計算するようにした
require 'cutorch'
require 'optim'
require 'xlua'

torch.setdefaulttensortype('torch.FloatTensor')
local EPSILON = 1.0e-6

-- normを1に正規化
local function normalize_l2(x)
   local norm = torch.pow(x, 2):sum(2):sqrt()
   norm:add(EPSILON)
   x:cdiv(torch.expand(norm, x:size(1), x:size(2)))
end

-- 行列の積を直接計算しようとするとGPUメモリに載らなかったので16分割して計算
local function split_mm(a, b)
   local BLOCKS = 16 -- 分割数
   local step = math.floor(b:size(1) / BLOCKS)
   local results = torch.Tensor(a:size(1), b:size(1))
   for i = 1, b:size(1), step do
      local n = step
      if i + n > b:size(1) then
	 n = b:size(1) - i
      end
      if n > 0 then
	 results:narrow(2, i, n):copy(torch.mm(a, b:narrow(1, i, n):t()))
      end
      collectgarbage()
   end
   return results
end

function main()
   local mnist = require 'mnist'
   local trainset = mnist.traindataset()
   local testset = mnist.testdataset()
   local train_x, train_y = trainset.data:float(), trainset.label
   local test_x, test_y = testset.data:float(), testset.label
   local classes = {0,1,2,3,4,5,6,7,8,9}
   local confusion = optim.ConfusionMatrix(classes)
   local t = sys.clock()

   -- データを整形
   train_x = train_x:reshape(train_x:size(1), 28 * 28)
   test_x = test_x:reshape(test_x:size(1), 28 * 28)
   train_y:add(1)
   test_y:add(1)
   
   -- L2 normを1に正規化
   normalize_l2(train_x)
   normalize_l2(test_x)
   
   -- 計算用のデータをCudaTensorに変換(GPUのデバイスメモリに転送)
   train_x = train_x:cuda()
   test_x = test_x:cuda()
   
   -- 全てのテストデータについてコサイン類似度が最も大きいインスタンスを選択
   local cosine = split_mm(train_x, test_x)
   local _, nn_index = cosine:max(1)
   -- 結果を評価
   for i = 1, test_x:size(1) do
      local y = train_y[nn_index[1][i]]
      confusion:add(y, test_y[i])
      if i % 100 == 0 then
	 xlua.progress(i, test_x:size(1))
      end
   end
   print(confusion)
   print(string.format("*** %.2fs", sys.clock() - t))
end
main()
結果
ConfusionMatrix:
[[     978       1       0       0       0       0       0       1       0       0]   99.796% 	[class: 0]
 [       0    1129       3       1       0       1       1       0       0       0]   99.471% 	[class: 1]
 [       9       0    1003       4       0       0       2      10       3       1]   97.190% 	[class: 2]
 [       0       0       1     977       0      13       0       5       9       5]   96.733% 	[class: 3]
 [       1       3       0       0     940       0       6       3       1      28]   95.723% 	[class: 4]
 [       1       1       0      17       1     852      10       1       4       5]   95.516% 	[class: 5]
 [       4       3       0       0       2       4     945       0       0       0]   98.643% 	[class: 6]
 [       2      11       5       2       2       0       0     995       0      11]   96.790% 	[class: 7]
 [       6       1       1      13       2       3       5       4     935       4]   95.996% 	[class: 8]
 [       5       6       1       4       9       3       1       8       4     968]]  95.937% 	[class: 9]
 + average row correct: 97.179394364357% 
 + average rowUcol correct (VOC measure): 94.562811851501% 
 + global correct: 97.22%
*** 5.12s	

5.12秒だった。爆速!