Open AI Gym Box2D BipedalWalkerをColaboratoryで動かしてみる(5)
前回はDDPG(Deep Deterministic Policy Gradient)でMountainCarContinuousに挑戦し、無事学習して山登りに成功しました。(BipedalWalkerは手強いので後回しです・・・)
今回は中身について、ざっくりですが解説してみます。
1.DDPG(Deep Deterministic Policy Gradient)
DQNでは、ある環境の時に行動a1,a2・・・を取った時の価値Q1,Q2・・(未来に得られる報酬の合計)を推定するQネットワークを学習します。ここで環境やQxは連続値を取ります。学習が完了したら、Q1,Q2・・・を比較して最も大きな価値の行動axを選択します。ネットワークの構造上、行動a1,a2・・・は離散的になります。点数を限りなく増やせば連続値を模擬できますが、あまり現実的ではありません。
これに対しDDPGでは行動aもQ値を計算する入力にします。ここでの行動aは連続値です。ただ、このままだと学習後にどの行動aを選ぶかといった時に、aの値を変化させて価値Qが最大になるところを探し出す必要があります。
じゃあということで、最初から価値Qが最大になる行動aを推定するpolicyネットワークを追加しちゃえばというのがDDPGです。
policyネットワークはQネットワークの値が最大になる様、Qネットワークを損失関数に使って学習させます。
2.Qネットワーク(Q(s,a))の学習
'Bellman equation'という考え方を使います。
価値関数Qを、これからもらえる報酬の総和で定義します。Qの引数は1項の図を見てわかるようにobservationの状態sと行動aです。
ただし、そのまま報酬を未来永劫に渡って足し続けるとQが無限大に発散してしまうので、報酬に割引率を乗じて意味のある数字に収束するようにします(①)。すると、Qは次のステップQ(st+1,at+1)と割引率γを使って書くことができます(②)。
次に左辺=0となるように変形します(③)。で、2乗して損失関数Lを定義して(④)、損失関数が最小になるように学習すればQを学習できることになります。
ちなみに時間tでゴールした場合は、それ以降は報酬をもらえない、すなわちQt+1=0になるはずなので、その模擬も含めて損失関数は最終的に次のように書けます。
赤字はtarget関数です。これは価値関数Qをそのまま使うと、学習で自身の更新の影響を受け安定して収束しなくなるのを防ぐため、価値関数Qより少し遅れてゆっくり変化するようにしたものです。
割引率は0.999等、1に近い値を使います。1に近ければ近いほど遠い将来の報酬の影響を受け、値が小さければ比較的短い時間範囲の報酬までしか影響を受けません。
Qネットワークは上記の損失関数(loss関数)が最小になるように学習させます。
環境st、その時に取った行動at、その結果の環境st+1は実際に1ステップ動かした時の値(経験値)を使えます。1ステップの経験値からat+1の情報は得られないので、後で説明する行動を予測するpolicyネットワークを使って計算します。この時使用するpolicyネットワークも実際のpolicyネットワークに遅れて学習するtarget関数です。
ソースリストのQネットワークの学習はdef optimize_model()中で行っています。該当する箇所を次に示します。
#Qネットのloss関数計算
next_actions = policy_target_net(next_state_batch)
next_Q_values =q_target_net(next_state_batch ,next_actions)
expected_Q_values = (next_Q_values * GAMMA)*(1.0-done_batch) + reward_batch
Q_values = q_net(state_batch ,action_batch)
#Qネットのloss関数
q_loss = F.mse_loss(Q_values,expected_Q_values)
#Qネットの学習
q_optimizer.zero_grad()
#誤差逆伝搬
q_loss.backward()
#重み更新
q_optimizer.step()
F.mse_loss()は差の2乗平均を算出する関数で、q_optimizer.zero_grad()~q_optimizer.step()で学習します。
なお、複数の経験(ステップ)を一括して学習するため、環境s、行動a等は複数の経験分をstack(テンソルを複数まとめる)しています。
3. policyネットワーク(policy(s))の学習
policyネットワークはQが最大となる行動を推測するよう学習させます。pytorchは損失関数が最小になるよう学習させるため、Q値に-1を乗じた損失関数を定義、値が最小になるよう学習させます。
該当する箇所を次に示します。policyネットワーク(Target)はQネットの学習、経験の取得に使用するため、Qネットと交互に学習させ、並行して学習させます。
#policyネットのloss関数
actions = policy_net(state_batch)
p_loss = -q_net(state_batch,actions).mean()
#policyネットの学習
p_optimizer.zero_grad()
#誤差逆伝搬
p_loss.backward()
#重み更新
p_optimizer.step()
4.Target関数のソフトアップデート
Target関数は前述の通り実際のQネット、policyネットにゆっくり追従するネットワークで、それぞれQネット、policyネットを学習した後、次のように更新します。
tau = 0.001
#targetネットのソフトアップデート
#学習の収束を安定させるためゆっくり学習するtarget netを作りloss関数の計算に使う。
#学習後のネット重み×tau分を反映させ、ゆっくり追従させる。
for target_param, local_param in zip(policy_target_net.parameters(), policy_net.parameters()):
target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
for target_param, local_param in zip(q_target_net.parameters(), q_net.parameters()):
target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
ネットワークの重みをそのままコピーするのではなく、tau(この場合は1/1000)だけ学習後の値を使い、残りの割合分は現状の値をそのまま使うことで、学習の効果がゆっくり浸透するようにしています(1次遅れと同じイメージ)。
5.リプレイメモリ
直前の経験だけを使って学習すると、直近のpolicyが出す行動に偏りQネットの学習もその領域に偏ってしまうため、過去の経験をメモリにためておき満遍なく学習するようにしています。リプレイメモリは有限の大きさにしておき、あまりに古くて使えない経験は新しい経験に上書きされます。
6.ノイズ
経験を取得するのにpolicyの出力する行動を使うと、行動が偏ってしまい、学習もpolicyが出力する行動のごく狭い範囲に留まってしまいます。そのため極所最適解に陥り、他に良い行動があっても発見できまないまま終わってしまいます。
DDPGでは、より良い行動を探すため、policyが出力する行動にノイズを混ぜて幅広く経験させる様、工夫されています。
7.ネットワーク定義
今回はQネット、policyネット共に4層、隠れ要素128のネットワークにしてみました。層数、隠れ要素数ともに特にこの値にした理由があるわけではないので、変更して試してみるのもいいと思います。
8.まとめ
DDPGのアルゴリズムの概要を纏めました。学習を効率良く行うには、報酬の与え方(今回はOpen AI Gymなので変更しない)、割引率、ノイズの大きさ、ネットワーク構造といったあたりが鍵になると思われます。
次回は再びBipedalWalkerを動かします。