強化学習

Stable Baselinesとは

概要

Open AIのBaselinesの強化学習を扱い安くしたしたライブラリです。Baselinesは研究として利用する上では問題なく利用できますが、例えば画像から特徴量を学習するときなどはコメントアウトされている部分をコメント消すような修正が必要となったりと地獄を見ることがあります。こちらのStable Baselinesは地獄に落ちないためのライブラリとなります。(ベイマックスがかわいいです)

最新のPytorchを利用したStable Baselines3がこちら。

学習/推論

例えばPPOの実装はこのような簡単なコードで学習と推論まで実装することができます。

# from https://github.com/hill-a/stable-baselines
import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines import PPO2

env = gym.make('CartPole-v1')

model = PPO2(MlpPolicy, env, verbose=1)
# Train the agent
model.learn(total_timesteps=10000)

# Enjoy trained agent
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=False)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()
env.close()

評価

以下がtutorialを参照した評価関数です。env.stepを実行して得たrewardの平均で評価できるとわかります。

def evaluate(model, num_episodes=100):
    """
    Evaluate a RL agent
    :param model: (BaseRLModel object) the RL Agent
    :param num_episodes: (int) number of episodes to evaluate it
    :return: (float) Mean reward for the last num_episodes
    """
    # This function will only work for a single Environment
    env = model.get_env()
    all_episode_rewards = []
    for i in range(num_episodes):
        episode_rewards = []
        done = False
        obs = env.reset()
        while not done:
            # _states are only useful when using LSTM policies
            action, _states = model.predict(obs)
            # here, action, rewards and dones are arrays
            # because we are using vectorized env
            obs, reward, done, info = env.step(action)
            episode_rewards.append(reward)

        all_episode_rewards.append(sum(episode_rewards))

    mean_episode_reward = np.mean(all_episode_rewards)
    print("Mean reward:", mean_episode_reward, "Num episodes:", num_episodes)

    return mean_episode_reward

このような関数を自作しなくてもstable baselinesに評価関数evaluate_policyが既に存在します。学習済みのmodelを入れることでmodelを評価可能です。

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100)
print(f"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")

モデルの保存/読み込み

医科のように簡単に学習済みモデルの保存と読み込みができます。

import os
from stable_baselines3.common.vec_env import DummyVecEnv

# Create save dir
save_dir = "/tmp/gym/"
os.makedirs(save_dir, exist_ok=True)

model = A2C('MlpPolicy', 'Pendulum-v0', verbose=0, gamma=0.9, n_steps=20).learn(8000)
# The model will be saved under A2C_tutorial.zip
model.save(save_dir + "/A2C_tutorial")

del model # delete trained model to demonstrate loading

# load the model, and when loading set verbose to 1
loaded_model = A2C.load(save_dir + "/A2C_tutorial", verbose=1)

# show the save hyperparameters
print("loaded:", "gamma =", loaded_model.gamma, "n_steps =", loaded_model.n_steps)

# as the environment is not serializable, we need to set a new instance of the environment
loaded_model.set_env(DummyVecEnv([lambda: gym.make('Pendulum-v0')]))
# and continue training
loaded_model.learn(8000)

独自環境の生成

Gymによって生成された環境も簡単にカスタムすることができます。

class CustomWrapper(gym.Wrapper):
  """
  :param env: (gym.Env) Gym environment that will be wrapped
  """
  def __init__(self, env):
    # Call the parent constructor, so we can access self.env later
    super(CustomWrapper, self).__init__(env)

  def reset(self):
    """
    Reset the environment 
    """
    obs = self.env.reset()
    return obs

  def step(self, action):
    """
    :param action: ([float] or int) Action taken by the agent
    :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations
    """
    obs, reward, done, info = self.env.step(action)
    return obs, reward, done, info

学習データと推論データ

実際に利用する時に少し戸惑うのは、学習データはenvで学習したとしてどこに推論データを入れればいいのかというところです。上述のコードでは環境に利用するデータは同じで推論前に一度resetしています。この部分を推論対象のenvにすり替えれば学習ができるはずです。

obs = env.reset()

本ページのコードは学習と推論を同じファイルに記載していますが、実際は学習/モデル保存とモデル読み込み/推論で別のファイルに作ればよいかなと思います。

参考

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です