データ並列分散学習

DataParallelCommunicator は、複数のデバイスを使ってニューラルネットワークを学習することができます。 DataParallelCommunicator は通常、データ並列分散学習において勾配交換に使われます。基本的に、ニューラルネットワークには、 データ並列とモデル並列、 2 つのタイプの分散学習があります。ここでは、前者のデータ並列学習のみに焦点を当てます。データ並列分散学習は、 ( ミニバッチ ) 確率的勾配降下法と呼ばれるニューラルネットワークの最適化に使われる非常に簡単な式に基づいています。

最適化プロセスにおいて、最小化を試みる目的関数は以下となります。

\[f(\mathbf{w}; X) = \frac{1}{B \times N} \sum_{i=1}^{B \times N} \ell(\mathbf{w}, \mathbf{x}_i),\]

ここで、 \(f\) はニューラルネットワーク、 \(B \times N\) はバッチサイズ、 \(\ell\) はそれぞれのデータポイント \(\mathbf{x} \in X\) に対する loss 関数、 \(\mathbf{w}\) はニューラルネットワークの学習可能なパラメータを表しています。

この関数を微分すると、以下が得られます。

\[\nabla_{\mathbf{w}} f(\mathbf{w}; X) = \frac{1}{B \times N} \sum_{i=1}^{B \times N} \nabla_{\mathbf{w}} \ell (\mathbf{w}, \mathbf{x}_i).\]

この導関数は線形なので、上記の目的関数を各 \(B\) データポイントにおける導関数の合計の総和に変えることができます。

\[\nabla_{\mathbf{w}} f(\mathbf{w}; X) = \frac{1}{N} \left( \frac{1}{B} \sum_{i=1}^{B} \nabla_{\mathbf{w}} \ell (\mathbf{w}, \mathbf{x}_i) \ + \frac{1}{B} \sum_{i=B+1}^{B \times 2} \nabla_{\mathbf{w}} \ell (\mathbf{w}, \mathbf{x}_i) \ + \ldots \ + \frac{1}{B} \sum_{i=B \times (N-1) + 1}^{B \times N} \nabla_{\mathbf{w}} \ell (\mathbf{w}, \mathbf{x}_i) \right)\]

データ並列分散学習では、上記の式に従って次のステップが実行されます。

  1. 各項ごとに、導関数 ( 勾配 ) の合計をバッチサイズ \(B\) で割る計算を個別のデバイス ( 一般的には GPU ) で行い、

  2. それらのデバイスの結果を総和し、

  3. その結果をデバイスの個数 \(N\) で割ります。

これはデータ並列分散学習の基礎となります。

このチュートリアルでは、とても簡単なサンプルを使って、データ並列分散学習に対する Multi Process Data Parallel Communicator の使い方を示します。

注意

このチュートリアルは IPython Cluster に依拠しているため、次のような Jupyter Notebook のスクリプトの抜粋を実行する場合は、 ここ に従い mpiexec/mpirun モードを有効し、Ipython Clusters タブで対応する Ipython Cluster を起動します。

クライアントの起動

以下のコードは、 Jupyter Notebook を用いるこのチュートリアル のみ で必要となります。

import ipyparallel as ipp
rc = ipp.Client(profile='mpi')

依存関係の準備

%%px
import os
import time

import nnabla as nn
import nnabla.communicators as C
from nnabla.ext_utils import get_extension_context
import nnabla.functions as F
from nnabla.initializer import (
    calc_uniform_lim_glorot,
    UniformInitializer)
import nnabla.parametric_functions as PF
import nnabla.solvers as S
import numpy as np

勾配交換のためのコミュニケーターの定義

%%px
extension_module = "cudnn"
ctx = get_extension_context(extension_module)
comm = C.MultiProcessCommunicator(ctx)
comm.init()
n_devices = comm.size
mpi_rank = comm.rank
device_id = mpi_rank
ctx = get_extension_context(extension_module, device_id=device_id)

異なるランクが異なるデバイスに割り当てられていることを確認します

%%px
print("n_devices={}".format(n_devices))
print("mpi_rank={}".format(mpi_rank))
[stdout:0]
n_devices=2
mpi_rank=1
[stdout:1]
n_devices=2
mpi_rank=0

データポイントととても簡単なニューラルネットワークの作成

%%px
# Data points setting
n_class = 2
b, c, h, w = 4, 1, 32, 32

# Data points
x_data = np.random.rand(b, c, h, w)
y_data = np.random.choice(n_class, b).reshape((b, 1))
x = nn.Variable(x_data.shape)
y = nn.Variable(y_data.shape)
x.d = x_data
y.d = y_data

# Network setting
C = 1
kernel = (3, 3)
pad = (1, 1)
stride = (1, 1)
%%px
rng = np.random.RandomState(0)
w_init = UniformInitializer(
                    calc_uniform_lim_glorot(C, C/2, kernel=(1, 1)),
                    rng=rng)
%%px
# Network
with nn.context_scope(ctx):
    h = PF.convolution(x, C, kernel, pad, stride, w_init=w_init)
    pred = PF.affine(h, n_class, w_init=w_init)
    loss = F.mean(F.softmax_cross_entropy(pred, y))

留意事項 ここでは、最適化プロセスにおいて、各 GPU のネットワークが学習可能なパラメータ同士の同じ値から開始できるように w_init をパラメトリック関数に渡しています。

Solver の作成

%%px
# Solver and add parameters
solver = S.Adam()
solver.set_parameters(nn.get_parameters())

学習

ニューラルネットワークを学習するための nnabla API の基本的な使い方を思い出しましょう。

  1. loss.forward()

  2. solver.zero_grad()

  3. loss.backward()

  4. solver.update()

C.MultiProcessCommunicator を使用する場合、これらのステップは異なる GPU で実行され、これらのステップと 唯一異なる のは comm.all_reduce() です。従って、 C.MultiProcessCommunicator の場合、学習のステップは次のようになります。

  1. loss.forward()

  2. solver.zero_grad()

  3. loss.backward()

  4. comm.all_reduce([x.grad for x in nn.get_parameters().values()])

  5. solver.update()

まず、順方向、 zero_grad 、そして逆方向。

%%px
# Training steps
loss.forward()
solver.zero_grad()
loss.backward()

一度、重みの勾配を確かめます。

%%px
for n, v in nn.get_parameters().items():
    print(n, v.g)
[stdout:0]
('conv/W', array([[[[ 5.0180483,  0.457942 , -2.8701296],
         [ 2.0715926,  3.0698593, -1.6650047],
         [-2.5591214,  6.4248834,  9.881935 ]]]], dtype=float32))
('conv/b', array([8.658947], dtype=float32))
('affine/W', array([[-0.93160367,  0.9316036 ],
       [-1.376812  ,  1.376812  ],
       [-1.8957546 ,  1.8957543 ],
       ...,
       [-0.33000934,  0.33000934],
       [-0.7211893 ,  0.72118926],
       [-0.25237036,  0.25237036]], dtype=float32))
('affine/b', array([-0.48865744,  0.48865741], dtype=float32))
[stdout:1]
('conv/W', array([[[[ -1.2505884 ,  -0.87151337,  -8.685524  ],
         [ 10.738419  ,  14.676786  ,   7.483423  ],
         [  5.612471  , -12.880402  ,  19.141157  ]]]], dtype=float32))
('conv/b', array([13.196114], dtype=float32))
('affine/W', array([[-1.6865108 ,  1.6865108 ],
       [-0.938529  ,  0.938529  ],
       [-1.028422  ,  1.028422  ],
       ...,
       [-0.98217344,  0.98217344],
       [-0.97528917,  0.97528917],
       [-0.413546  ,  0.413546  ]], dtype=float32))
('affine/b', array([-0.7447065,  0.7447065], dtype=float32))

それぞれのデバイスで異なる値を確認できたら、 all_reduce を呼び出します。

%%px
comm.all_reduce([x.grad for x in nn.get_parameters().values()], division=True)

一般に、 all_reduce は合計を意味するだけですが、 comm.all_reduce は合計と加算除算どちらの場合にも対応します。

再度、重みの勾配を確かめます。

%%px
for n, v in nn.get_parameters().items():
    print(n, v.g)
[stdout:0]
('conv/W', array([[[[ 1.8837299 , -0.20678568, -5.777827  ],
         [ 6.4050055 ,  8.8733225 ,  2.9092093 ],
         [ 1.5266749 , -3.2277591 , 14.511546  ]]]], dtype=float32))
('conv/b', array([21.85506], dtype=float32))
('affine/W', array([[-2.6181145,  2.6181145],
       [-2.315341 ,  2.315341 ],
       [-2.9241767,  2.9241762],
       ...,
       [-1.3121828,  1.3121828],
       [-1.6964785,  1.6964784],
       [-0.6659163,  0.6659163]], dtype=float32))
('affine/b', array([-1.233364 ,  1.2333639], dtype=float32))
[stdout:1]
('conv/W', array([[[[ 1.8837299 , -0.20678568, -5.777827  ],
         [ 6.4050055 ,  8.8733225 ,  2.9092093 ],
         [ 1.5266749 , -3.2277591 , 14.511546  ]]]], dtype=float32))
('conv/b', array([21.85506], dtype=float32))
('affine/W', array([[-2.6181145,  2.6181145],
       [-2.315341 ,  2.315341 ],
       [-2.9241767,  2.9241762],
       ...,
       [-1.3121828,  1.3121828],
       [-1.6964785,  1.6964784],
       [-0.6659163,  0.6659163]], dtype=float32))
('affine/b', array([-1.233364 ,  1.2333639], dtype=float32))

all_reduce を使うことで、これらのデバイス上で同じ値を確認できます。

重みを更新します。

%%px
solver.update()

これで、データ並列分散学習に対する C.MultiProcessDataCommunicator の使用については終了です。

さて、 C.MultiProcessCommunicator の使い方について理解できたら、さらに詳細については 以下の cifar10 example を参照してください。

  1. multi_device_multi_process_classification.sh

  2. multi_device_multi_process_classification.py

​​​