Communicator

Communicator は、計算グラフに含まれているパラメータを転送するAPIです。

このドキュメントは、communicator.py のエイリアスです。

Communicator のインターフェイス

class nnabla.communicators.Communicator

Communicator のインターフェイスクラスです。

Communicator は、MPI のようにグループ通信を利用して、データ(例:勾配など)をやり取りします。このクラスは、分散学習を行う際に使用されます。

abort(self)

MPI 実行環境を終了します。

add_context_and_parameters(self, ctx_param_dict)

コンテキストとパラメータを追加します。

パラメータ:

ctx_param_dict (tuple of Context, dict) -- キーが string 、値が Variable の辞書です。

all_gather(self, ndarray, ndarray_list, string group='world')

異なるデバイスのデータに対して、All gether を実行します。

パラメータ:
  • ndarray (NdArray) -- 収集するデータ。

  • ndarray_list (NdArray) -- 保存するデータ 。

  • group (string) -- グループの名前です。このグループは、集合が呼び出されるときに使用されます。

例:

# Run like `mpirun -n 2 python <code_snippet.py>`
# note: the order of the output to stdout are stochastic because of multiprocesses.

# Communicator and Context
import numpy as np
import nnabla as nn
import nnabla.communicators as C
from nnabla.ext_utils import get_extension_context

extension_module = "cudnn"
ctx = get_extension_context(extension_module)
comm = C.MultiProcessCommunicator(ctx)
comm.init()

# Data
x = nn.Variable([2, 2])
x.d = np.random.rand(*x.shape)
y_list = [nn.Variable([2, 2]), nn.Variable([2, 2])]
print("Before the collective ({}-th)".format(comm.rank))
print(x.d)

# AllGather
comm.all_gather(x.data, [y.data for y in y_list])

# Check
print("After the collective ({}-th)".format(comm.rank))
for y in y_list:
    print(y.d)
all_reduce(self, data, bool division=False, bool inplace=False, string group='world')

異なるデバイスのデータに対して、 All reduce を実行します。

パラメータ:
  • data (NdArray or list of NdArray) --

  • division (bool) -- all_reduce した値を 与えられた contexts 数、または通信を行うデバイス数で割るかどうかを決定するフラグです。

  • inplace (bool) -- パック配列を使用するためのフラグです。デフォルトは false です。true の場合、メモリ効率は良いですが、処理が遅くなります。false の場合、メモリ効率は良くありませんが高速です。どちらの場合も、最終的な処理結果は同じメモリ領域に得られます。

  • group (string) -- グループの名前です。このグループは、集合が呼び出されるときに使用されます。

例:

# Run like `mpirun -n 2 python <code_snippet.py>`
# note: the order of the output to stdout are stochastic because of multiprocesses.

# Communicator and Context
import numpy as np
import nnabla as nn
import nnabla.communicators as C
from nnabla.ext_utils import get_extension_context

extension_module = "cudnn"
ctx = get_extension_context(extension_module)
comm = C.MultiProcessCommunicator(ctx)
comm.init()

# Data
x_list = [nn.Variable([2, 2]), nn.Variable([2, 2])]
print("Before the collective ({}-th)".format(comm.rank))
for x in x_list:
    x.d = np.random.rand(*x.shape)
    print(x.d)

# AllReduce
comm.all_reduce([x.data for x in x_list], inplace=True)

# Check
print("After the collective ({}-th)".format(comm.rank))
for x in x_list:
    print(x.d)
all_reduce_callback(self, data, size_t pack_size, bool division=False, string group='world', float scale_grad=1, bool keep_dtype=False)

異なるデバイスのデータに対して、 All reduce を実行します。

注釈

この関数は現在、共有パラメータ ( RNN など) をサポートしていません。

パラメータ:
  • data (NdArray or list of NdArray) --

  • pack_size (int) -- パックデータに含まれる値の数。

  • division (bool) -- all_reduce した値を 与えられた contexts 数、または通信を行うデバイス数で割るかどうかを決定するフラグです。

  • group (string) -- グループの名前です。このグループは、集合が呼び出されるときに使用されます。

  • scale_grad (float) -- Apply scaling by the specified factor before performing all-reduce. This is useful when you apply loss scaling in mixed precision training and cancel it for gradient arrays before all-reduce.

  • keep_dtype (bool) -- If True, the dtype of arrays is kept the same regardless of communicator's dtype used in all-reduce operation. This is useful when you use the all-reduce callback in mixed precision training and when any of gradient NdArray`s is narrowed by :py:meth:`NdArray.narrow. In this case, you will get an error unless you specify True because a narrowed array prohibits dtype casting.

例:

マルチプロセスデータ並列分散学習を行う例を以下に示します。

# Run like `mpirun -n 2 python <code_snippet.py>`

# Communicator and Context
import numpy as np
import nnabla as nn
import nnabla.communicators as C
from nnabla.ext_utils import get_extension_context

extension_module = "cudnn"
ctx = get_extension_context(extension_module)
comm = C.MultiProcessCommunicator(ctx)
comm.init()

n_class = 2
b, c, h, w = 4, 1, 32, 32

# Data
x = nn.Variable([b, c, h, w])
y = nn.Variable([b, 1])

# Network setting
h = PF.convolution(x, 1, (3, 3), (1, 1), (1, 1))
pred = PF.affine(h, 2)
loss = F.mean(F.softmax_cross_entropy(pred, y))

loss.forward()
# AllReduce during backward
loss.backward(communicator_callbacks = comm.all_reduce_callback([v.grad for v in nn.get_parameters().values()], 1024 * 1024 * 2))
allreduce(self, bool division=False, bool inplace=False)

この関数の利用は非推奨です。代わりに all_reduce を参照してください。

与えられたパラメータに対して、 All reduce を実行します。現在、allreduce は勾配領域に適用されます。

パラメータ:
  • division (bool) -- all_reduce した値を 与えられた contexts 数、または通信を行うデバイス数で割るかどうかを決定するフラグです。

  • inplace (bool) -- パック配列を使用するためのフラグです。デフォルトは false です。true の場合、メモリ効率は良いですが、処理が遅くなります。false の場合、メモリ効率は良くありませんが高速です。どちらの場合も、最終的な処理結果は同じメモリ領域に得られます。

barrier(self)

Communicator のすべてのプロセスがこのルーチンに到達するまで、それぞれのプロセスをブロックします。

bcast(self, data, int src, bool inplace=False, string group='world')

異なるデバイスにデータをブロードキャストします。

パラメータ:
  • data (NdArray or list of NdArray) --

  • src (int) -- データがブロードキャストされるソースランク。

  • inplace (bool) -- パック配列を使用するためのフラグです。デフォルトは false です。true の場合、メモリ効率は良いですが、処理が遅くなります。false の場合、メモリ効率は良くありませんが高速です。どちらの場合も、最終的な処理結果は同じメモリ領域に得られます。

  • group (string) -- グループの名前です。このグループは、集合が呼び出されるときに使用されます。

例:

# Run like `mpirun -n 2 python <code_snippet.py>`
# note: the order of the output to stdout are stochastic because of multiprocesses.

# Communicator and Context
import numpy as np
import nnabla as nn
import nnabla.communicators as C
from nnabla.ext_utils import get_extension_context

extension_module = "cudnn"
ctx = get_extension_context(extension_module)
comm = C.MultiProcessCommunicator(ctx)
comm.init()

# Data
x_list = [nn.Variable([2, 2]), nn.Variable([2, 2])]
print("Before the collective ({}-th)".format(comm.rank))
for x in x_list:
    x.d = np.random.rand(*x.shape)
    print(x.d)

# Bcast
comm.bcast([x.data for x in x_list], src=0, inplace=True)

# Check
print("After the collective ({}-th)".format(comm.rank))
for x in x_list:
    print(x.d)
clear_context_parameters(self)

登録されているすべてのコンテキストとパラメータを消去します。

find_group(self, group)

グループ内のランクのリストを返します。グループが存在しない場合は、空のリストが返されます。

パラメータ:

group (str) -- グループの名前です。

戻り値:

ランクのリストです (int) 。

戻り値の型:

ranks (list)

init(self)

Communicator を初期化します。

マルチスレッドまたはマルチプロセスにより、 Initall または initrank となります。必ず 先にadd_context_and_parameters によって通信を行う全てのパラメータを追加してから、この関数を実行してください。

list_groups(self)
戻り値:

グループ (str) のランク (list)。

戻り値の型:

groups (dict)

local_rank

Communicator のローカルランクを取得します。

name

Communicator の名前を取得します。

new_group(self, name_ranks)
パラメータ:

name_ranks (tuple) -- 名前のタプル (str) とランク (list)。

戻り値:

グループ名 (str)。

例:

# Communicator and Context
extension_module = "cudnn"
ctx = get_extension_context(extension_module)
comm = C.MultiProcessCommunicator(ctx)
comm.init()

# New group
group = comm.new_group("node0", [0, 1, 2, 3])
rank

Communicator のランクを取得します。

reduce(self, data, int dst, bool division=False, bool inplace=False, string group='world')

異なるデバイスのデータに対して、 reduce を実行します。

パラメータ:
  • data (NdArray or list of NdArray) --

  • dst (int) -- 結果が保存される送り先のランク。

  • division (bool) -- all_reduce した値を 与えられた contexts 数、または通信を行うデバイス数で割るかどうかを決定するフラグです。

  • inplace (bool) -- パック配列を使用するためのフラグです。デフォルトは false です。true の場合、メモリ効率は良いですが、処理が遅くなります。false の場合、メモリ効率は良くありませんが高速です。どちらの場合も、最終的な処理結果は同じメモリ領域に得られます。

  • group (string) -- グループの名前です。このグループは、集合が呼び出されるときに使用されます。

例:

# Run like `mpirun -n 2 python <code_snippet.py>`
# note: the order of the output to stdout are stochastic because of multiprocesses.

# Communicator and Context
import numpy as np
import nnabla as nn
import nnabla.communicators as C
from nnabla.ext_utils import get_extension_context

extension_module = "cudnn"
ctx = get_extension_context(extension_module)
comm = C.MultiProcessCommunicator(ctx)
comm.init()

# Data
x_list = [nn.Variable([2, 2]), nn.Variable([2, 2])]
print("Before the collective ({}-th)".format(comm.rank))
for x in x_list:
    x.d = np.random.rand(*x.shape)
    print(x.d)

# Reduce
comm.reduce([x.data for x in x_list], dst=0, inplace=True)

# Check
print("After the collective ({}-th)".format(comm.rank))
for x in x_list:
    print(x.d)
reduce_scatter(self, ndarray_list, ndarray, bool division=False, string group='world')

異なるデバイスのデータに対して、 Reduce scatter を実行します。

パラメータ:
  • ndarray_list (NdArray) -- さまざまなデバイスで削減されるデータのリスト。

  • ndarray (NdArray) -- 保存するデータ 。

  • group (string) -- グループの名前です。このグループは、集合が呼び出されるときに使用されます。

例:

# Run like `mpirun -n 2 python <code_snippet.py>`
# note: the order of the output to stdout are stochastic because of multiprocesses.

# Communicator and Context
import numpy as np
import nnabla as nn
import nnabla.communicators as C
from nnabla.ext_utils import get_extension_context

extension_module = "cudnn"
ctx = get_extension_context(extension_module)
comm = C.MultiProcessCommunicator(ctx)
comm.init()

# Data
x_list = [nn.Variable([2, 2]), nn.Variable([2, 2])]
y = nn.Variable([2, 2])
print("Before the collective ({}-th)".format(comm.rank))
for x in x_list:
    x.d = np.random.rand(*x.shape)
    print(x.d)

# ReduceScatter
comm.reduce_scatter([x.data for x in x_list], y.data)

# Check
print("After the collective ({}-th)".format(comm.rank))
print(y.d)
size

Communicator のサイズを取得します。

Communicator の一覧

nnabla.communicators.MultiProcessDataParalellCommunicator()

MultiProcessDataParallelCommunicator(CContext ctx)

分散学習のためのマルチプロセスデータ並列 Communicator。

パラメータ:

context (Context) -- この Communicator で使用されるコンテキスト。

例:

マルチプロセスデータ並列分散学習を行う例を以下に示します。

# Communicator and Context
extension_module = "cudnn"
ctx = get_extension_context(extension_module)
comm = C.MultiProcessCommunicator(ctx)
comm.init()
n_devices = comm.size
mpi_rank = comm.rank
device_id = comm.local_rank
ctx.device_id = str(device_id)
nn.set_default_context(ctx)

# Network and Solver created here

...


# Training loop
for itr in range(num_itr):
    # Forward, zerograd, backward
    loss.forward()
    solver.zero_grad()
    loss.backward()

    # Allreduce
    comm.all_reduce([v.grad for v in nn.get_parameters().values()])

    # Update
    solver.update()