policy.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def normal_entropy(std):
    """
    Compute the entropy of a diagonal Gaussian from its standard deviation.
    Per dimension, the entropy of N(mu, sigma^2) is
        0.5 * log(2 * pi * e * sigma^2) = 0.5 + 0.5 * log(2 * pi * sigma^2);
    see https://math.stackexchange.com/questions/1804805/how-is-the-entropy-of-the-normal-distribution-derived
    for the derivation.
    :param std: [b, a_dim]
    :return: [b, 1], summed over the action dimensions
    """
    var = std.pow(2)
    entropy = 0.5 + 0.5 * torch.log(2 * var * np.pi)
    return entropy.sum(1, keepdim=True)
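
# Cross-check (a sketch, not part of the original module): each row of
# normal_entropy(std) should equal the summed per-dimension entropies from
# torch.distributions, e.g.
#   torch.distributions.Normal(0.0, std).entropy().sum(1, keepdim=True)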

def normal_log_density(x, mean, log_std):
    """
    Log-density of x under a diagonal Gaussian N(mean, std), where
    std = exp(log_std). Per dimension:
        log p(x) = -(x - mean)^2 / (2 * var) - 0.5 * log(2 * pi) - log_std
    :param x: [b, a_dim]
    :param mean: [b, a_dim]
    :param log_std: [b, a_dim]
    :return: [b, 1], summed over the action dimensions
    """
    std = torch.exp(log_std)
    var = std.pow(2)
    log_density = -torch.pow(x - mean, 2) / (2 * var) - 0.5 * np.log(2 * np.pi) - log_std
    return log_density.sum(1, keepdim=True)
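
# Equivalently (a sketch), torch.distributions produces the same numbers:
#   torch.distributions.Normal(mean, log_std.exp()).log_prob(x).sum(1, keepdim=True)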

class Policy(nn.Module):

    def __init__(self, s_dim, a_dim):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(s_dim, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, a_dim)
        # State-independent log standard deviation of the action distribution.
        # As an nn.Parameter it is registered on the module automatically and
        # optimized together with the other weights. Shape [1, a_dim] so it
        # broadcasts to [b, a_dim] in forward().
        self.a_log_std = nn.Parameter(torch.zeros(1, a_dim))

    def forward(self, s):
        # [b, s_dim] => [b, a_dim]
        a = F.relu(self.fc1(s))
        a = F.relu(self.fc2(a))
        a_mean = self.fc3(a)
        # Expand the learned log_std to the batch: [1, a_dim] => [b, a_dim].
        a_log_std = self.a_log_std.expand_as(a_mean)
        return a_mean, a_log_std

    def select_action(self, s):
        """
        Sample an action from the current Gaussian policy.
        :param s: [b, s_dim]
        :return: [b, a_dim]
        """
        # Cast the state to the network's parameter dtype before the forward
        # pass: [b, s_dim] => [b, a_dim].
        s = s.to(self.fc1.weight.dtype)
        a_mean, a_log_std = self.forward(s)
        # Randomly sample from the normal distribution whose mean and log_std
        # come from the policy network: [b, a_dim].
        a = torch.normal(a_mean, torch.exp(a_log_std))
        return a

    def get_log_prob(self, s, a):
        """
        Log-probability of actions a under the current policy at states s.
        :param s: [b, s_dim]
        :param a: [b, a_dim]
        :return: [b, 1]
        """
        # Forward to get the action mean and log_std: [b, s_dim] => [b, a_dim].
        a_mean, a_log_std = self.forward(s)
        # [b, a_dim] => [b, 1]
        log_prob = normal_log_density(a, a_mean, a_log_std)
        return log_prob
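

if __name__ == "__main__":
    # Minimal usage sketch: the dimensions below are arbitrary demo values,
    # not from the original repo, and the assertions use
    # torch.distributions.Normal as an independent reference for the
    # hand-written density and entropy above.
    torch.manual_seed(0)
    s_dim, a_dim, batch = 3, 2, 4
    policy = Policy(s_dim, a_dim)

    s = torch.randn(batch, s_dim)
    a = policy.select_action(s)           # [b, a_dim]
    log_prob = policy.get_log_prob(s, a)  # [b, 1]

    a_mean, a_log_std = policy(s)
    dist = torch.distributions.Normal(a_mean, a_log_std.exp())
    assert torch.allclose(log_prob, dist.log_prob(a).sum(1, keepdim=True), atol=1e-5)
    assert torch.allclose(normal_entropy(a_log_std.exp()),
                          dist.entropy().sum(1, keepdim=True), atol=1e-5)
    print("log_prob:", log_prob.squeeze(1).tolist())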