Flood Sung已經在最前沿:深度解讀Soft Actor-Critic 演算法比較完整的解讀了SAC的程式碼,本文章主要是對SAC一些細節的補充。

如何計算policy的entropy?

SAC中使用了Gaussian 函式作為policy,policy

\pi

的 entropy的定義如下:

[1]

H^{\pi}(\bm{s}_t)=E_{\bm{a}_{t}\sim\pi}[-\log \pi(\bm{a}_{t}|\bm{s}_{t}) ]

在程式中,我們可以根據當前的policy

\pi

的機率密度函式取樣得到

\bm{a}_{t}

,然後計算取樣後的

-\log \pi(\bm{a}_{t}|\bm{s}_{t})

即可。取樣的程式碼如下:

normal = Normal(mean, std)

x_t = normal。rsample()

在不加限制的條件下,求

-\log \pi(\bm{x}_{t}|\bm{s}_{t})

很簡單:

log_prob = normal。log_prob(x_t)

但是action

\bm{a}_{t}

有界,因此需要用tanh函式加以限制:

[2]

\bm{a}_{t}=\tanh(\bm{x}_t)

接下來就變成了我們已知

\bm{x}_{t}

的機率密度函式

f_x(\bm{x}_{t})

,以及action

\bm{a}_{t}

\bm{x}_{t}

的關係,如何求

\bm{a}_{t}

的機率密度函式

f_a(\bm{a}_{t})

的問題。

首先介紹一個機率密度函式變換的定理:【證明過程連結】

強化學習論文解讀1:Soft Actor-Critic

因為tanh函式單調遞增,且有反函式

\bm{x}_{t}=v(\bm{a}_{t})=0.5\ln(\frac{1+\bm{a}_{t}}{1-\bm{a}_{t}})

,因此我們可以計算

f_a(\bm{a}_{t})

[3]

f_a(\bm{a}_{t}) = f_x(v(\bm{a}_{t}))| v

[4]

v

根據[3]和[4]我們可以計算得到

f_a(\bm{a}_{t})

[5]

f_a(\bm{a}_{t})= \frac{f_x(v(\bm{a}_{t}))}{1-\bm{a}^2_{t}} = \frac{f_x(\bm{x}_{t})}{1-\tanh^2(\bm{x}_{t})}

[6]

\log(f_a(\bm{a}_{t})) = \log( f_x(\bm{x}_{t}))-\log({1-\tanh^2(\bm{x}_{t})})

計算程式碼如下, epsilon 是一個小量,防止除以0:

log_prob = normal。log_prob(x_t)

# Enforcing Action Bound

y_t = torch。tanh(x_t)

log_prob -= torch。log(self。action_scale * (1 - y_t。pow(2)) + epsilon)

SAC的GaussianPolicy的完整程式碼如下,接下來有時間再補充其他問題。

class GaussianPolicy(nn。Module):

def __init__(self, state_dim, action_dim, hidden_dim, max_action=None):

super(GaussianPolicy, self)。__init__()

self。linear1 = nn。Linear(state_dim, hidden_dim)

self。linear2 = nn。Linear(hidden_dim, hidden_dim)

self。mean_linear = nn。Linear(hidden_dim, action_dim)

self。log_std_linear = nn。Linear(hidden_dim, action_dim)

self。apply(weights_init_)

# action rescaling

if max_action is None:

self。action_scale = torch。tensor(1。)

self。action_bias = torch。tensor(0。)

else:

self。action_scale = torch。tensor(max_action)

self。action_bias = torch。tensor(0。)

def forward(self, state):

x = F。relu(self。linear1(state))

x = F。relu(self。linear2(x))

mean = self。mean_linear(x)

log_std = self。log_std_linear(x)

log_std = torch。clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)

return mean, log_std

def sample(self, state):

mean, log_std = self。forward(state)

std = log_std。exp()

normal = Normal(mean, std)

x_t = normal。rsample() # for reparameterization trick (mean + std * N(0,1))

y_t = torch。tanh(x_t)

action = y_t * self。action_scale + self。action_bias

log_prob = normal。log_prob(x_t)

# Enforcing Action Bound

log_prob -= torch。log(self。action_scale * (1 - y_t。pow(2)) + epsilon)

log_prob = log_prob。sum(1, keepdim=True)

mean = torch。tanh(mean) * self。action_scale + self。action_bias

return action, log_prob, mean

著作權歸作者所有。商業轉載請聯絡作者獲得授權,非商業轉載請註明作者、出處、及原文連結。