Approximate Q-Learning#

Drawbacks of Q-Learning#

Q-Learning has several drawbacks:

  • Q-Learning only works with finite state and action spaces (we have to assign a number to every state-action pair)

  • Q-Learning performs poorly in finite but large environments: each state is visited only a few times, so \(q\) is learned badly

  • For large environments, Q-Learning has to store \(|\mathcal{A}| \times |\mathcal{S}|\) numbers, which is an enormous amount (for chess, \(|\mathcal{S}| \approx 10^{120}\)); see the sketch below
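
To get a feel for the last point, here is a tiny back-of-the-envelope sketch (the state and action counts below are hypothetical, not taken from any particular environment): a tabular \(q\) stores one float per state-action pair, so its memory grows as \(|\mathcal{S}| \times |\mathcal{A}|\).

# Back-of-the-envelope: memory needed by a tabular Q-function (hypothetical sizes)
for n_s in [10**2, 10**6, 10**9]:      # hypothetical number of states
    n_a = 10                           # hypothetical number of actions
    table_gb = n_s * n_a * 8 / 1e9     # 8 bytes per float64 entry
    print(f"{n_s:>13,} states x {n_a} actions -> {table_gb:10.2f} GB")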

Neural networks#

As we already know, neural networks can approximate practically any dependency in data. Let's train a neural network to approximate \(q(s, a)\).

To simplify the task, let's feed the network the state \(s\) and have it output \(q_\theta(s, a)\) for all \(a \in \mathcal{A}\) at once (\(\theta\) denotes the network's parameters):

[figure: a network that takes the state \(s\) as input and outputs \(q_\theta(s, a)\) for every action \(a \in \mathcal{A}\)]

Environment and architecture#

We will use Approximate Q-Learning to solve the CartPole-v0 environment from the gym library.

import gym
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
env = gym.make("CartPole-v0",  render_mode="rgb_array").env
env.reset()
n_actions = env.action_space.n             # number of possible actions in the environment
state_dim = env.observation_space.shape[0] # dimensionality of the observed state

plt.imshow(env.render())
env.close()
[rendered frame of the CartPole-v0 environment]

In this environment we control a cart with a pole attached to it by a hinge. The cart can move left and right. The state information we receive from the environment is the position and velocity of the cart and the angle and angular velocity of the pole. We receive a reward for every step the pole stays upright and does not fall.
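
As a quick sanity check (this only inspects the env object created above), we can look at the observation and action spaces:

# The observation is a 4-dimensional vector, the action is one of two discrete pushes
print("observation space:", env.observation_space)   # a Box with 4 components
print("action space:     ", env.action_space)        # Discrete(2)
print("state_dim =", state_dim, ", n_actions =", n_actions)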

Let's define the neural network:

import torch
import torch.nn as nn
import torch.nn.functional as F
network = nn.Sequential(
    nn.Linear(state_dim, 20), # input size = state dimensionality
    nn.ReLU(),
    nn.Linear(20, 20),
    nn.ReLU(),
    nn.Linear(20, n_actions), # output size = number of actions
)

Let's write the code for \(\varepsilon\)-greedy action selection:

def get_action(state, epsilon=0):
    """
    sample actions with epsilon-greedy policy
    recap: with p = epsilon pick random action, else pick action with highest Q(s,a)
    """
    state = torch.tensor(state[None], dtype=torch.float32)
    q_values = network(state).detach().numpy()

    if np.random.rand() < epsilon:
      ch_action = np.random.randint(n_actions)
    else:
      ch_action = np.argmax(q_values)

    return int(ch_action)

Tests for get_action

s = env.reset()[0]
assert tuple(network(torch.tensor(np.array([s] * 3), dtype=torch.float32)).size()) == (
    3, n_actions), "please make sure your model maps state s -> [Q(s,a0), ..., Q(s, a_last)]"
assert isinstance(list(network.modules(
))[-1], nn.Linear), "please make sure you predict q-values without nonlinearity (ignore if you know what you're doing)"
assert isinstance(get_action(
    s), int), "get_action(s) must return int, not %s. try int(action)" % (type(get_action(s)))

# test epsilon-greedy exploration
for eps in [0., 0.1, 0.5, 1.0]:
    action_frequencies = np.bincount(
        [get_action(s, epsilon=eps) for i in range(10000)], minlength=n_actions)
    best_action = action_frequencies.argmax()
    assert abs(action_frequencies[best_action] -
               10000 * (1 - eps + eps / n_actions)) < 200
    for other_action in range(n_actions):
        if other_action != best_action:
            assert abs(action_frequencies[other_action] -
                       10000 * (eps / n_actions)) < 200
    print('e=%.1f tests passed' % eps)
e=0.0 tests passed
e=0.1 tests passed
e=0.5 tests passed
e=1.0 tests passed

Training#

We train the neural network by minimizing the difference between the estimate \(q_{\theta}(s,a)\) and the "improved value" \(r(s,a) + \gamma \cdot \max_{a'} q_{-}(s', a')\):

\[ L = \frac{1}{N} \sum_i \left(q_{\theta}(s,a) - \left[r(s,a) + \gamma \cdot \max_{a'} q_{-}(s', a')\right]\right)^2 \]

where

  • \(s, a, r, s'\) are the current state, action, reward, and next state

  • \(\gamma\) is the reward discount factor

  • \(q_{-} = q_{\theta}\)

Note that, for the sake of training stability, we do not propagate the loss gradient through \(q_{-}\), even though gradients could otherwise flow through it to the parameters of our neural network.

To achieve this we use x.detach(), which marks the tensor so that the loss is not backpropagated through it.
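
A minimal standalone illustration of what x.detach() does (independent of the network above): the detached branch contributes to the value of the loss, but not to the gradient.

x = torch.tensor([2.0], requires_grad=True)
y = x ** 2                     # gradient flows through this branch: dy/dx = 2x = 4
y_frozen = (x ** 2).detach()   # same value, but cut out of the computation graph
loss = (y + y_frozen).sum()
loss.backward()
print(x.grad)                  # tensor([4.]) — only the non-detached branch contributes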

def compute_td_loss(states, actions, rewards, next_states, is_done, gamma=0.99, check_shapes=False):
    """ Compute td loss using torch operations only. Use the formula above. """
    states = torch.tensor(
        states, dtype=torch.float32)                                  # shape: [batch_size, state_size]
    actions = torch.tensor(actions, dtype=torch.long)                 # shape: [batch_size]
    rewards = torch.tensor(rewards, dtype=torch.float32)              # shape: [batch_size]
    next_states = torch.tensor(next_states, dtype=torch.float32)      # shape: [batch_size, state_size]
    is_done = torch.tensor(is_done, dtype=torch.bool)                 # shape: [batch_size]

    # get q-values for all actions in current states
    predicted_qvalues = network(states)                               # shape: [batch_size, n_actions]

    # select q-values for chosen actions
    predicted_qvalues_for_actions = predicted_qvalues[                # shape: [batch_size]
      range(states.shape[0]), actions
    ]

    # compute q-values for all actions in next states
    predicted_next_qvalues = network(next_states).detach()

    # compute gamma * V*(next_states) using the predicted next q-values
    next_state_values = gamma * predicted_next_qvalues.max(dim=-1).values
    assert next_state_values.dtype == torch.float32

    # compute "target q-values" for loss - it's what's inside square parentheses in the above formula.
    target_qvalues_for_actions = rewards + next_state_values

    # at the last state we shall use simplified formula: Q(s,a) = r(s,a) since s' doesn't exist
    target_qvalues_for_actions = torch.where(
        is_done, rewards, target_qvalues_for_actions)

    # mean squared error loss to minimize
    loss = torch.mean((predicted_qvalues_for_actions -
                       target_qvalues_for_actions.detach()) ** 2)

    if check_shapes:
        assert predicted_next_qvalues.data.dim(
        ) == 2, "make sure you predicted q-values for all actions in next state"
        assert next_state_values.data.dim(
        ) == 1, "make sure you computed V(s') as maximum over just the actions axis and not all axes"
        assert target_qvalues_for_actions.data.dim(
        ) == 1, "there's something wrong with target q-values, they must be a vector"

    return loss

Let's create an optimizer:

opt = torch.optim.Adam(network.parameters(), lr=1e-4)

Let's write a function that interacts with the environment and updates the neural network:

def generate_session(env, t_max=10000, epsilon=0, train=False):
    """play env with approximate q-learning agent and train it at the same time"""
    total_reward = 0
    s = env.reset()[0]

    for t in range(t_max):
        a = get_action(s, epsilon=epsilon)
        next_s, r, done, *_ = env.step(a)

        if train:
            opt.zero_grad()
            compute_td_loss([s], [a], [r], [next_s], [done]).backward()
            opt.step()

        total_reward += r
        s = next_s
        if done:
            break

    return total_reward
epsilon = 0.9

Let's train the network:

for i in range(100):
    session_rewards = [generate_session(env, epsilon=epsilon, train=True) for _ in range(100)]
    print("epoch #{}\tmean reward = {:.3f}\tepsilon = {:.3f}".format(i, np.mean(session_rewards), epsilon))

    epsilon *= 0.98
    assert epsilon >= 1e-4, "Make sure epsilon is always nonzero during training"

    if np.mean(session_rewards) > 300:
        print("You Win!")
        break
epoch #0	mean reward = 21.500	epsilon = 0.900
epoch #1	mean reward = 18.630	epsilon = 0.882
epoch #2	mean reward = 22.610	epsilon = 0.864
epoch #3	mean reward = 23.380	epsilon = 0.847
epoch #4	mean reward = 20.780	epsilon = 0.830
epoch #5	mean reward = 20.010	epsilon = 0.814
epoch #6	mean reward = 21.720	epsilon = 0.797
epoch #7	mean reward = 22.690	epsilon = 0.781
epoch #8	mean reward = 21.340	epsilon = 0.766
epoch #9	mean reward = 19.600	epsilon = 0.750
epoch #10	mean reward = 18.190	epsilon = 0.735
epoch #11	mean reward = 20.430	epsilon = 0.721
epoch #12	mean reward = 17.310	epsilon = 0.706
epoch #13	mean reward = 20.320	epsilon = 0.692
epoch #14	mean reward = 27.100	epsilon = 0.678
epoch #15	mean reward = 19.760	epsilon = 0.665
epoch #16	mean reward = 20.210	epsilon = 0.651
epoch #17	mean reward = 19.660	epsilon = 0.638
epoch #18	mean reward = 19.030	epsilon = 0.626
epoch #19	mean reward = 22.730	epsilon = 0.613
epoch #20	mean reward = 20.370	epsilon = 0.601
epoch #21	mean reward = 27.320	epsilon = 0.589
epoch #22	mean reward = 16.540	epsilon = 0.577
epoch #23	mean reward = 19.610	epsilon = 0.566
epoch #24	mean reward = 28.420	epsilon = 0.554
epoch #25	mean reward = 31.600	epsilon = 0.543
epoch #26	mean reward = 31.080	epsilon = 0.532
epoch #27	mean reward = 22.650	epsilon = 0.522
epoch #28	mean reward = 44.050	epsilon = 0.511
epoch #29	mean reward = 27.290	epsilon = 0.501
epoch #30	mean reward = 31.870	epsilon = 0.491
epoch #31	mean reward = 38.300	epsilon = 0.481
epoch #32	mean reward = 34.270	epsilon = 0.471
epoch #33	mean reward = 51.250	epsilon = 0.462
epoch #34	mean reward = 41.910	epsilon = 0.453
epoch #35	mean reward = 37.790	epsilon = 0.444
epoch #36	mean reward = 42.110	epsilon = 0.435
epoch #37	mean reward = 43.700	epsilon = 0.426
epoch #38	mean reward = 47.180	epsilon = 0.418
epoch #39	mean reward = 45.390	epsilon = 0.409
epoch #40	mean reward = 54.630	epsilon = 0.401
epoch #41	mean reward = 58.470	epsilon = 0.393
epoch #42	mean reward = 39.850	epsilon = 0.385
epoch #43	mean reward = 62.650	epsilon = 0.378
epoch #44	mean reward = 58.570	epsilon = 0.370
epoch #45	mean reward = 59.480	epsilon = 0.363
epoch #46	mean reward = 72.460	epsilon = 0.355
epoch #47	mean reward = 91.380	epsilon = 0.348
epoch #48	mean reward = 103.490	epsilon = 0.341
epoch #49	mean reward = 93.850	epsilon = 0.334
epoch #50	mean reward = 90.360	epsilon = 0.328
epoch #51	mean reward = 103.180	epsilon = 0.321
epoch #52	mean reward = 123.590	epsilon = 0.315
epoch #53	mean reward = 165.910	epsilon = 0.308
epoch #54	mean reward = 141.670	epsilon = 0.302
epoch #55	mean reward = 183.550	epsilon = 0.296
epoch #56	mean reward = 293.860	epsilon = 0.290
epoch #57	mean reward = 229.750	epsilon = 0.285
epoch #58	mean reward = 15.330	epsilon = 0.279
epoch #59	mean reward = 227.680	epsilon = 0.273
epoch #60	mean reward = 176.110	epsilon = 0.268
epoch #61	mean reward = 19.440	epsilon = 0.262
epoch #62	mean reward = 97.220	epsilon = 0.257
epoch #63	mean reward = 168.950	epsilon = 0.252
epoch #64	mean reward = 122.550	epsilon = 0.247
epoch #65	mean reward = 118.090	epsilon = 0.242
epoch #66	mean reward = 125.130	epsilon = 0.237
epoch #67	mean reward = 132.210	epsilon = 0.232
epoch #68	mean reward = 144.840	epsilon = 0.228
epoch #69	mean reward = 159.740	epsilon = 0.223
epoch #70	mean reward = 195.700	epsilon = 0.219
epoch #71	mean reward = 183.010	epsilon = 0.214
epoch #72	mean reward = 165.740	epsilon = 0.210
epoch #73	mean reward = 270.900	epsilon = 0.206
epoch #74	mean reward = 192.200	epsilon = 0.202
epoch #75	mean reward = 36.600	epsilon = 0.198
epoch #76	mean reward = 10.140	epsilon = 0.194
epoch #77	mean reward = 20.670	epsilon = 0.190
epoch #78	mean reward = 12.310	epsilon = 0.186
epoch #79	mean reward = 134.020	epsilon = 0.182
epoch #80	mean reward = 92.060	epsilon = 0.179
epoch #81	mean reward = 109.500	epsilon = 0.175
epoch #82	mean reward = 203.660	epsilon = 0.172
epoch #83	mean reward = 329.670	epsilon = 0.168
You Win!

Visualization#

# Record sessions

from gym.wrappers.monitoring.video_recorder import VideoRecorder

env = gym.make("CartPole-v0", render_mode="rgb_array")
video = VideoRecorder(env, "after_training.mp4")

def capture_session(env, video, t_max=10000, epsilon=0):
    """play env with approximate q-learning agent and capture the video"""
    s = env.reset()[0]

    for t in range(t_max):
        a = get_action(s, epsilon=epsilon)
        next_s, r, done, *_ = env.step(a)
        video.capture_frame()
        s = next_s
        if done:
            print(t)
            break

capture_session(env, video)
video.close()
271
Moviepy - Building video after_training.mp4.
Moviepy - Writing video after_training.mp4
Moviepy - Done !
Moviepy - video ready after_training.mp4

# Show video. This may not work in some setups. If it doesn't
# work for you, you can download the videos and view them locally.

import sys, os
from pathlib import Path
from base64 import b64encode
from IPython.display import HTML

video_path = Path("after_training.mp4")


with video_path.open('rb') as fp:
    mp4 = fp.read()
data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()


HTML("""
<video width="640" height="480" controls>
  <source src="{}" type="video/mp4">
</video>
""".format(data_url))