やってみた!

やってみた!

試したことを中心に、書評や興味のあること、思ったこととか

PyBullet-HumanoidFlagrunHarderBulletEnv-v0(3)

 今回はSoft Actor-Critic(SAC)について備忘録ということで解説します。以下の論文は初期のSACをさらに改良したものです。

1.深層強化学習の理解に必要な事項の整理

 まずは、各記号、考え方の整理など。

 f:id:akifukka:20200125111703j:plain

 行動価値関数が最大になる行動を出力する方策関数を学習するのが目的です。
行動価値関数と方策関数をニューラルネットワークで表現します。

①実際に乱数で試行錯誤した経験(状態st、行動at、報酬r、次の状態st+1)から、Bellman方程式を使って行動価値関数Qを学習させます。この時にBellman方程式中で経験記録に無い値、すなわち次回の行動at+1が出てくるため、これはその時最新の方策関数を使って推定します。

②行動価値関数を更新したら、その時の経験を使ってQを最大化する方策関数を学習させます。

①と②を繰り返して行動価値関数Qと方策関数πのニューラルネットワークを学習させます。

2.Soft Actor Critic(SAC)

 SACは初期の論文のもの(Open AIのSpining Upのやり方)と、後から発表された改良型のものがありますが、ここでは改良型について説明します。

  • 報酬と方策のエントロピー(ばらつきの大きさ)の将来に渡る総合計を最大化する方策を求めます。エントロピーが加わるところが従来と異なるところです。より大きなエントロピーを持つ方策(方策がばらついても結果(報酬)が良い方策分布)を学習させます。不安定な極所的最適解が排除され、学習が安定して進むと思われます。

    f:id:akifukka:20200125113831j:plain 

  • 行動価値関数Qを報酬の将来総和と次回ステップ以降のエントロピー将来総和で定義すると、確率的なBellman方程式は次のように書けます。

    f:id:akifukka:20200125120343j:plain

     学習時は0.5×(左辺-右辺)の2乗をLoss関数として定義し、Loss関数が最小になる様、ネットワークの重みを更新します。ソースリストの次の箇所で計算しています。なお後述するテクニックを使っています。

    ①右辺の計算 next_actions:as+1、next_logp_pis:t+1のエントロピー×ー1
    #Qターゲット

        #Qターゲット
        next_mus,next_actions,next_logp_pis = policy_net(next_state_batch)
        next_q1s = q1_target_net(torch.cat([next_state_batch, next_actions], 1))
        next_q2s = q2_target_net(torch.cat([next_state_batch, next_actions], 1))
        next_qs = torch.min(next_q1s, next_q2s)
    
        q_targets = reward_batch + GAMMA * (1.0-done_batch) * (next_qs - alpha * next_logp_pis)
    

    ②Loss関数の計算

    
        #Q1ネット
        #loss関数
    #    q1_loss = 0.5 * F.mse_loss(q1_net(torch.cat([state_batch, action_batch],1)),q_targets) 中身がわかるような記述に変更
        q1_loss = 0.5 * torch.mean((q1_net(torch.cat([state_batch, action_batch],1)) - q_targets)**2)
        #ネットの学習
        q1_optimizer.zero_grad()
        #誤差逆伝搬
        q1_loss.backward(retain_graph=True)
        #重み更新
        q1_optimizer.step()
    
  • DDPGの改良型であるTD3で使われたテクニックclipped double-Qを使います。Qネットワークを2つ(Q1、Q2)それぞれ学習させて、小さい方を使うことでQの過大学習によるpolicyの破壊を予防します。
  • Bellman方程式を使ってQを学習する際Qが不安定にならないように、右辺の計算に使うQは別に定義(ターゲット関数)したものを使います。Qのターゲット関数はQに遅れて追従するようにします。
       tau = 0.005
        #q1,2 targetネットのソフトアップデート
        #学習の収束を安定させるためゆっくり学習するtarget netを作りloss関数の計算に使う。
        #学習後のネット重み×tau分を反映させ、ゆっくり追従させる。
        for target_param, local_param in zip(q1_target_net.parameters(), q1_net.parameters()):
          target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
    
        for target_param, local_param in zip(q2_target_net.parameters(), q2_net.parameters()):
          target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
  • 方策(policy)はSquashed Gaussian Policyを使います(tanhで-1~1の範囲に押しつぶしたGaussian Policy関数の意味です)。最初に通常のGaussian policiesを計算し、その後squashします。

    ある状態stを入力値として、行動の平均 μ(st)、log分散 log σ(st)をニューラルネットワークで推定、平均にノイズを加算して行動を導きます。

    f:id:akifukka:20200125112832j:plain
    エントロピーを次の式で計算した後、squashします。

    f:id:akifukka:20200125144804j:plain

  • Qの学習結果をpolicyに反映します。Qは最初のエントロピーHを含んでいないので、Qにエントロピーを加えたものが最大になるようなpolicy πを学習させます。実際はLoss関数 Jπ =-1×(Qにエントロピーを加えたもの)と定義し、Loss関数が最小になるようにpolicy πを学習させて上記と同じことをしています。

    f:id:akifukka:20200125160941j:plain
        #policyネットのloss関数
        mus,actions,logp_pis = policy_net(state_batch)
        q1s = q1_net(torch.cat([state_batch,actions],1))
        q2s = q2_net(torch.cat([state_batch,actions],1))
        qs = torch.min(q1s, q2s)
        p_loss = torch.mean(alpha * logp_pis - qs)
    
        #policyネットの学習
        p_optimizer.zero_grad()
        #誤差逆伝搬
        p_loss.backward(retain_graph=True)
        #重み更新
        p_optimizer.step()
    
  • あらかじめエントロピーの目標値を決め、その値になるようにtemperature parameter α(alpha)を自動調整します。エントロピーの目標値は、以下の論文によると単純にactionの次元1つあたりー1として、-1×actionの要素数(次元)としたとのことです。

    [1812.11103] Learning to Walk via Deep Reinforcement Learning

    Loss関数Jαを次のように定義し、Jαが小さくなるようαを更新します。

    f:id:akifukka:20200125175648j:plain

    ①Ht > H (目標エントロピーの方が大)
    αを増やすとJαが減るのでαの更新でαが増加します。αが増加すると、Qの学習時にエントロピーの影響が大きくなり、エントロピーは増加します。
    ②Ht < H (目標エントロピーの方が小)
    αを減らすとJαが減るのでαの更新でαが減少します。αが減少するとQにおけるエントロピーの影響が小さくなり、学習が進むとエントロピーは減少していきます。
        alpha_loss = -torch.mean(log_alpha*(target_entropy + logp_pis))
        #log_alpha更新
        alpha_optimizer.zero_grad()
        #誤差逆伝搬
        alpha_loss.backward(retain_graph=True)
        #log_alpha更新
        alpha_optimizer.step()
        #alpha更新
        alpha = log_alpha.exp()

3.パラメータについて

 パラメータについて気付いたことをまとめておきます。

  • BATCH_SIZE = 256
    学習に使う経験数。以外に影響が大きく、少ないと学習が安定せず進まなく(報酬が飽和)します。humanoidの場合、128では少なすぎでした。少ないとQが滑らかに更新されず、いびつになってしまうのではないかと想像しています。
  • GAMMA = 0.99
    報酬に対する割引率。具体的な影響については未経験。
  • lr=3e-4(optim.Adam)
    ニューラルネットワーク更新時の大きさの係数。大きいと更新後のネットワークがいびつになり、学習が不安定になりやすい。小さいと学習に時間がかかる。
  • tau = 0.005
    ターゲット関数の更新係数。この係数分だけ最新のネットワークの重みに変わる。変更したことがないので影響については何とも言えないが、大きすぎると学習が不安定になりやすいと考えられます。

以上、今回はSACについてまとめてみました。

つづく

 強化学習 カテゴリーの記事一覧 - やってみた!