Torch7で複雑なモデルを書くときに便利な技
Torch7はnnに用意されている部品を組み合わせることで複雑なモデルを作れて便利なのですが、複雑なモデルは入力ベクトル(orテンソル)を様々形に変換しながら実行するので、どの時点でどのサイズになっているか分からなくて書くのが難しいというのがあります。
たとえば、伝統的なCNNは、Torch7で以下のように書くのですが
require 'nn' function cnn_model() local model = nn.Sequential() -- convolution layers model:add(nn.SpatialConvolutionMM(3, 128, 5, 5, 1, 1)) model:add(nn.ReLU()) model:add(nn.SpatialMaxPooling(2, 2, 2, 2)) model:add(nn.SpatialConvolutionMM(128, 256, 5, 5, 1, 1)) model:add(nn.ReLU()) model:add(nn.SpatialMaxPooling(2, 2, 2, 2)) model:add(nn.SpatialZeroPadding(1, 1, 1, 1)) model:add(nn.SpatialConvolutionMM(256, 512, 4, 4, 1, 1)) model:add(nn.ReLU()) -- fully connected layers model:add(nn.SpatialConvolutionMM(512, 1024, 2, 2, 1, 1)) model:add(nn.ReLU()) model:add(nn.Dropout(0.5)) model:add(nn.SpatialConvolutionMM(1024, 10, 1, 1, 1, 1)) model:add(nn.Reshape(10)) model:add(nn.SoftMax()) return model end
このSpatialConvolutionMMの畳み込みカーネルが入力サイズをはみ出していないかとか、SpatialMaxPoolingの入力が奇数になっていて端っこが処理されていないのではないかとか気になります。
こういう場合、以下のようなデバッグプリントを入れると、各層でどのようなサイズになっているか分かります。
require 'nn' function cnn_model() local model = nn.Sequential() local debug_input = torch.Tensor(3, 24, 24):uniform() -- convolution layers model:add(nn.SpatialConvolutionMM(3, 128, 5, 5, 1, 1)) model:add(nn.ReLU()) -- model:cuda() -- 必要ならcudaにする print(model:forward(debug_input):size()) model:add(nn.SpatialMaxPooling(2, 2, 2, 2)) print(model:forward(debug_input):size()) model:add(nn.SpatialConvolutionMM(128, 256, 5, 5, 1, 1)) model:add(nn.ReLU()) print(model:forward(debug_input):size()) model:add(nn.SpatialMaxPooling(2, 2, 2, 2)) print(model:forward(debug_input):size()) model:add(nn.SpatialZeroPadding(1, 1, 1, 1)) model:add(nn.SpatialConvolutionMM(256, 512, 4, 4, 1, 1)) model:add(nn.ReLU()) print(model:forward(debug_input):size()) -- fully connected layers model:add(nn.SpatialConvolutionMM(512, 1024, 2, 2, 1, 1)) model:add(nn.ReLU()) model:add(nn.Dropout(0.5)) print(model:forward(debug_input):size()) model:add(nn.SpatialConvolutionMM(1024, 10, 1, 1, 1, 1)) print(model:forward(debug_input):size()) model:add(nn.Reshape(10)) model:add(nn.SoftMax()) return model end cnn_model()
実行結果
% th t.lua 128 20 20 [torch.LongStorage of size 3] 128 10 10 [torch.LongStorage of size 3] 256 6 6 [torch.LongStorage of size 3] 256 3 3 [torch.LongStorage of size 3] 512 2 2 [torch.LongStorage of size 3] 1024 1 1 [torch.LongStorage of size 3] 10 1 1 [torch.LongStorage of size 3]
各層で出力がどのようなサイズになるかは、ちゃんと計算すれば分かるのですが、僕のように暗算を得意としない人間には計算するのが非常にだるいのでデバッグプリントを入れる技がとても便利です。