強化學習論文解讀1:Soft Actor-Critic
Flood Sung已經在最前沿:深度解讀Soft Actor-Critic 演算法比較完整的解讀了SAC的程式碼,本文章主要是對SAC一些細節的補充。
如何計算policy的entropy?
SAC中使用了Gaussian 函式作為policy,policy
的 entropy的定義如下:
[1]
在程式中,我們可以根據當前的policy
的機率密度函式取樣得到
,然後計算取樣後的
即可。取樣的程式碼如下:
normal = Normal(mean, std)
x_t = normal。rsample()
在不加限制的條件下,求
很簡單:
log_prob = normal。log_prob(x_t)
但是action
有界,因此需要用tanh函式加以限制:
[2]
接下來就變成了我們已知
的機率密度函式
,以及action
與
的關係,如何求
的機率密度函式
的問題。
首先介紹一個機率密度函式變換的定理:【證明過程連結】
因為tanh函式單調遞增,且有反函式
,因此我們可以計算
:
[3]
[4]
根據[3]和[4]我們可以計算得到
:
[5]
[6]
計算程式碼如下, epsilon 是一個小量,防止除以0:
log_prob = normal。log_prob(x_t)
# Enforcing Action Bound
y_t = torch。tanh(x_t)
log_prob -= torch。log(self。action_scale * (1 - y_t。pow(2)) + epsilon)
SAC的GaussianPolicy的完整程式碼如下,接下來有時間再補充其他問題。
class GaussianPolicy(nn。Module):
def __init__(self, state_dim, action_dim, hidden_dim, max_action=None):
super(GaussianPolicy, self)。__init__()
self。linear1 = nn。Linear(state_dim, hidden_dim)
self。linear2 = nn。Linear(hidden_dim, hidden_dim)
self。mean_linear = nn。Linear(hidden_dim, action_dim)
self。log_std_linear = nn。Linear(hidden_dim, action_dim)
self。apply(weights_init_)
# action rescaling
if max_action is None:
self。action_scale = torch。tensor(1。)
self。action_bias = torch。tensor(0。)
else:
self。action_scale = torch。tensor(max_action)
self。action_bias = torch。tensor(0。)
def forward(self, state):
x = F。relu(self。linear1(state))
x = F。relu(self。linear2(x))
mean = self。mean_linear(x)
log_std = self。log_std_linear(x)
log_std = torch。clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
return mean, log_std
def sample(self, state):
mean, log_std = self。forward(state)
std = log_std。exp()
normal = Normal(mean, std)
x_t = normal。rsample() # for reparameterization trick (mean + std * N(0,1))
y_t = torch。tanh(x_t)
action = y_t * self。action_scale + self。action_bias
log_prob = normal。log_prob(x_t)
# Enforcing Action Bound
log_prob -= torch。log(self。action_scale * (1 - y_t。pow(2)) + epsilon)
log_prob = log_prob。sum(1, keepdim=True)
mean = torch。tanh(mean) * self。action_scale + self。action_bias
return action, log_prob, mean
著作權歸作者所有。商業轉載請聯絡作者獲得授權,非商業轉載請註明作者、出處、及原文連結。