(本記事は、2016年インターンシップを経て現在はアルバイトとして勤務されている包さんによる寄稿です)
はじめまして。Preferred Networksの分散深層学習チームでアルバイトをしている包です。私は分散深層学習の中でも主にモデル並列に関する機能実装を行っています。今回はモデル並列性の概要と、ChainerMNにおいてどのようにモデル並列性を実現しているのかについて紹介します。
分散深層学習: データ並列性とモデル並列性
深層学習における各種フレームワークは目覚ましい発展を遂げ続けており、最近では一般ユーザーでも簡単に複数GPUを用いたニューラルネットの訓練ができるようになってきました。たとえば、ChainerMNではoptimizerの定義にほんの数行加えるだけでニューラルネットを複数GPUで訓練できます[1]。これにより1024GPU上でImageNetによるResNet-50の学習を15分で行うなどの実績を上げています[2]。このような複数プロセス、複数ノードを用いた分散深層学習によってニューラルネットの訓練は高速に行えるようになっており、分散深層学習は現在の深層学習の基盤を支えているといえます。
![]()
ところで、「分散深層学習」にはデータ並列とモデル並列という2通りのアプローチがあることが知られています[3]。データ並列では、全プロセスに同じモデルのコピーして訓練することでバッチサイズをプロセス数倍し、学習を高速化させる手法です。先程お話したImageNetの並列訓練もデータ並列による高速化の一例です。一方でモデル並列とは、1つのモデルを分割して複数のプロセスに配置し、全プロセスで協調して1つのモデルを訓練する手法です。主なユースケースとしては超解像度を入力とするCNNやMixture of Experts[4]など、1プロセス上に載りきらないサイズのモデルを訓練したい場合に用いられます。最近ではMesh-Tensorflow[5]というTensorflow用のモデル並列ライブラリが公開されましたが、現状ではモデル並列をサポートしているフレームワークは非常に少ないです。
この記事では、ChainerMNに実装されているモデル並列APIを、実例を交えて紹介します。特に、Define-by-Runとともにモデル並列を実現する際に発生する問題と、その解決方法について重点的にお話をします。
ChainerMNにおけるモデル並列性の実現
ChainerMNでは、通信をChainerの関数呼び出しによって定義 します。これにより非常に柔軟な通信パターンを実現することができます。
![]()
図1: 関数呼び出しによる通信の定義の例
ChainerMNにおける通信はMPIを用いて実現されており、モデル並列でも基本的にMPIの通信スタイルを踏襲しています。MPIでは大きく分けて MPI_Send
を始めとした1対1通信と、 MPI_Bcast
のような集団通信向けのAPIが提供されています。ChainerMNでは、これらの通信APIと対応するように chainermn.functions.send
や chainermn.functions.bcast
のように、Chainerの関数を提供しています。通信用の関数は、それぞれbackwardにおいて「勾配を逆向きに通信」するように設計されています。例えば bcast
の場合、forward計算ではmasterからslaveに対して入力変数がbroadcast通信されます。一方で、backward計算ではslaveからmasterに対して勾配をallreduceします。
ChainerMNに実装されているforward通信に対応するbackwardの通信パターンは以下のようになります。
表: forward と backward における通信パターンの対応
forward |
backward |
allgather |
allgather |
alltoall |
alltoall |
bcast |
allgather |
gather |
scatter |
scatter |
gather |
send |
recv |
recv |
send |
次に、モデル並列APIの具体的な使い方について見ていきます。まず、データ並列の際と同様に、通信を行うためのコミュニケータを作成します。
comm = chainermn.create_communicator()
例えば、図1のようなモデルの実装イメージは次のようになります(図1のモデルに特に意味はありません)。
class ExampleModel(chainer.Chain):
def __init__(self):
self.comm = chainermn.create_communicator()
self.conv = L.Convolution2D(...)
def forward(self, x):
x = chainermn.functions.bcast(self.comm, x)
h = self.conv(x)
y = F.relu(h)
ys = chainermn.functions.gather(self.comm, y)
...
この例では、masterからブロードキャストされた変数が Convolution2D
の入力になります。一方で、backward計算の際には、 Convolution2D
の勾配が自動的に Bcast
のbackwardによってmasterへ集約されます。
ChainerMNに用意されているAPIの詳細については、ドキュメントを参照してください[6]。なお、モデル並列関連のAPIに関しては現状では実験段階なので、将来的に後方互換でないAPIの変更が起こる可能性があります。
Define-by-Runにおける注意点(その1)
ChainerをはじめとしたDefine-by-Runによる計算グラフの定義はモデルを直感的に記述することができる点で優れているといえます。backward計算時には、出力変数からグラフのバックトラックを行うことによってパラメータの更新を行うことができます。しかし、モデル並列を実現するために上述のように通信を関数として定義すると、計算グラフが正しくバックトラックできない状況が発生します。
例えば、下記のような2つのプロセス間におけるシンプルな1対1通信の例を考えます。
![]()
図2: 1対1通信の例
「Process #0」に注目してみると、出力変数 y
からバックトラックを行ったときに、 recv
から send
へ戻ることができません。その結果、「Process #1」は recv
のbackward(すなわち勾配のsend)を呼んでいるにもかかわらず、「Process #0」は send
のbackward(すなわち勾配のrecv
)を呼ぶことができず、デッドロックが発生します。このような状況は、1つのプロセス上における計算グラフが非連結になっているときに生じます。そのため、 send
関数が戻り値として返す特別な変数を recv
に渡すことによって、 計算グラフが連結になるようにモデルの定義を行います 。
![]()
図3: delegate variableによる計算グラフの連結化
このような send
と recv
を繋ぐような send
関数の戻り値を、便宜的に「delegate variable」と呼ぶことにします。Delegate variableは「Process #0」においてグラフを連結にする役割を果たす他に、「Process #1」でもバックトラックの起点となるダミーの出力変数として振る舞います。図3をコードで記述すると以下のようになります。
class ExampleModel_0(chainer.Chain):
def forward(self, x):
# first component
z = f(x)
phi = chainermn.functions.send(z, comm, rank=1)
# second component
z = chainermn.functions.recv(comm, rank=1, delegate_variable=phi)
y = h(z)
return y
class ExampleModel_1(chainer.Chain):
def forward(self, _):
z = chainermn.functions.recv(comm, rank=0)
z = g(z)
phi = chainermn.functions.send(z, comm, rank=0)
return phi
Define-by-Runにおける注意点 (その2)
先程の節ではグラフが非連結になると計算グラフのバックトラックができない例を1つ挙げました。このような例は他にも存在します。
![]()
図4: 1対1通信を2回呼ぶ例
図4では、1対1通信が2回発生しています。この場合、「Process #0」における send が返す2つのdelegate variableを適切に処理する必要があります。そこで、以下のように2つの変数を1つにまとめる処理を行います。
![]()
図5: pseudo_connectを用いた例
chainermn.functions.pseudo_connect
という関数は、「delegate variableがあたかも別の変数であるかのように振る舞うような変数」を返す関数です。図5の例では、 \( \phi_1 \) というdelegate variableが実際には \( \phi_2 \) という別の変数として振る舞うような変数 \( \psi \) を返します。 \( \psi \) をバックトラックする際には、まず \( \phi_1 \) のバックトラックを行い、次に \( \phi_2 \) のバックトラックを行います。このようにして、backward計算の際に2つのdelegate variableを正しくトラックバックすることができます。図5をコードで記述すると次のようになります。
class ExampleModel_0(chainer.Chain):
def forward(self, x):
z1, z2 = f(x)
phi1 = chainermn.functions.send(z1, comm, rank=1)
phi2 = chainermn.functions.send(z2, comm, rank=1)
psi = chainermn.functions.pseudo_connect(phi1, phi2)
return psi
class ExampleModel_1(chainer.Chain):
def forward(self, _):
z1 = chainermn.functions.recv(comm, rank=0)
z2 = chainermn.functions.recv(comm, rank=0)
y = g(z1, z2)
return y
図5では pseudo_connect
で2つのdelegate variableをまとめましたが、次の図6のように通常の変数にdelegate variableを結合することも可能です。
![]()
図6: delegate variableと通常の変数を結合する例
y_ = chainermn.functions.pseudo_connect(phi, y)
以上がChainerMNにおけるモデル並列の概要になります。次に、実際のモデルの例を見てみます。
1対1通信を用いた例: encoder-decoderモデル
Encoder-decoderモデル[7]は可変長の入力を可変長の出力に変換することを目的としたモデルで、自然言語処理をはじめとした応用分野で広く用いられています。
Chainerのexampleにも機械翻訳の例があります[8]。Encoder-decoderモデルの入力や出力に画像を用いるようなモデルの場合、CNNをencoderやdecoderに用いることになりますが、層数やパラメータ数が膨大なencoderやdecoderになると、全体のモデルが1GPUに載らないケースが発生します。その場合、モデルをいくつかに分割して複数プロセスでモデル並列学習を行うことによって学習できます。例えば、下図のようにencoderとdecoderにそれぞれ1プロセスずつ割り当てるような分割が考えられます。
![]()
図7: encoder-decoderのモデル並列化
ここでは、はじめのプロセスでencodeしたcontext vectorをdecoderへ送信して、decoder側のプロセスでdecodeするように分割を行っています。例えばLSTMの場合はcontext vectorが2つあるので、2回の1対1通信を行うことで実現できます。ただし、図5の例と同様に、encoder側では pseudo_connect
を用いてdelegate variableを1つにまとめる必要があることに注意してください。基本的には send
、 recv
と pseudo_connect
を用いれば実装することができますが、encoder-decoderモデルの分割は実装が煩雑になるので、専用のLinkを用意しています。
rnn = chainermn.links.create_multi_node_n_step_rnn(
L.NStepLSTM(n_layers, n_units, n_units, 0.1),
comm, rank_in=None, rank_out=1)
create_multi_node_n_step_rnn
は、Chainerで提供されている NStepRNN
[9](可変長系列をまとめて入出力するAPI)をラップして、内部で別のプロセスと自動的にcontext vectorを送受信します。rank_in
に指定したプロセスからcontext vectorを受信し、 rank_out に指定したプロセスに対してcontext vectorを送信します。これを用いると、次のようにモデル並列なencoder-decoderモデルを簡単に実装することができます。
class Encoder(chainer.Chain):
def __init__(self, comm, n_layers, n_units):
super(Encoder, self).__init__(
# Corresponding decoder LSTM will be invoked on process 1.
mn_encoder=chainermn.links.create_multi_node_n_step_rnn(
L.NStepLSTM(n_layers, n_units, n_units, 0.1),
comm, rank_in=None, rank_out=1
),
)
self.comm = comm
self.n_layers = n_layers
self.n_units = n_units
def forward(self, *xs):
exs = f(xs)
c, h, _, phi = self.mn_encoder(exs)
return phi
class Decoder(chainer.Chain):
def __init__(self, comm, n_layers, n_units):
super(Decoder, self).__init__(
# Corresponding encoder LSTM will be invoked on process 0.
mn_decoder=chainermn.links.create_multi_node_n_step_rnn(
L.NStepLSTM(n_layers, n_units, n_units, 0.1),
comm, rank_in=0, rank_out=None),
)
self.comm = comm
self.n_layers = n_layers
self.n_units = n_units
def forward(self, *ys):
c, h, os, _ = self.mn_decoder(ys)
...
この例はChainerMNのexampleに公開されています[10]。
集団通信を用いた例: チャネル方向の並列化
集団通信を用いると、下図のようにCNNのチャネル方向の並列化が実現できます。この並列化は高解像度画像を扱う際や、バッチサイズを大きくする際に有用です。
![]()
図8: チャネル方向の並列化
各プロセスはCNNのチャネルのうち一部だけを入力としてとって畳み込みを行うので、各々のプロセス上のCNNのパラメータ数を減らすことができます。CNNの出力に対して “allgather“ を用いることで、全チャネルを集約することができます。実装のイメージは以下のようになります。
class ParallelConvolution(chainer.Chain):
def __init__(self, comm, in_channels, out_channels):
self.comm = comm
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = L.Convolution2D(...)
@property
def _indices(self):
# index % comm.size == comm.rankとなるインデックスのチャネルを担当
# 例 (size=4, rank=1の場合): _indices = [1, 5, 9, ...]
idx = numpy.arange(self.in_channels)
return idx[idx % self.comm.size == self.comm.rank]
def forward(self, x):
# 当該プロセスの担当チャネルをスライス
x = x[:, self._indices, :, :]
y = self.conv(x)
# 全チャネルを集約
ys = chainermn.functions.allgather(self.comm, y)
return F.concat(ys, axis=1)
この例はchainerMNのexampleに公開されています[11]。
まとめ
本記事では、ChainerMNにおけるモデル並列の実現と、実際の例をいくつか紹介しました。特に、Defined-by-Runの下では計算グラフが連結でなければならないため、delegate variableやpseudo_connect
などのテクニックが必要になります。今回はスペースの都合で紹介がかないませんでしたが、特定のタイプのモデル向けによりシンプルにモデルを定義できるようなAPI( MultiNodeChainList
, MultiNodeNStepRNN
)も用意されているので、お手軽に試してみたい方はぜひドキュメント[6]をご覧ください。
参考文献