@@ -10,13 +10,16 @@
EPISODES = 2500


+# this is REINFORCE Agent for GridWorld
class ReinforceAgent:
    def __init__(self):
-        self.load_model = False
+        self.load_model = True
+        # actions which agent can do
        self.action_space = [0, 1, 2, 3, 4]
+        # get size of state and action
        self.action_size = len(self.action_space)
        self.state_size = 15
-        self.discount_factor = 0.99  # decay rate
+        self.discount_factor = 0.99
        self.learning_rate = 0.001

        self.model = self.build_model()
@@ -26,6 +29,7 @@ def __init__(self):
        if self.load_model:
            self.model.load_weights('./save_model/reinforce_trained.h5')

+    # state is input and probability of each action(policy) is output of network
    def build_model(self):
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation='relu'))
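The hunk boundary here hides the rest of build_model, including its output layer. As a rough, self-contained sketch of what the complete network presumably looks like, with the second hidden layer and the softmax output written in as assumptions rather than lines visible in this diff:

from keras.layers import Dense
from keras.models import Sequential

def build_model(state_size=15, action_size=5):
    # input: the 15-dim state; output: one probability per action (the policy)
    model = Sequential()
    model.add(Dense(24, input_dim=state_size, activation='relu'))
    model.add(Dense(24, activation='relu'))  # assumed second hidden layer
    model.add(Dense(action_size, activation='softmax'))  # assumed output layer
    model.summary()
    return model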
@@ -34,25 +38,31 @@ def build_model(self):
        model.summary()
        return model

+    # create error function and training function to update policy network
    def optimizer(self):
        action = K.placeholder(shape=[None, 5])
        discounted_rewards = K.placeholder(shape=[None, ])
-        good_prob = K.sum(action * self.model.output, axis=1)
-        eligibility = K.log(good_prob) * K.stop_gradient(discounted_rewards)
-        loss = -K.sum(eligibility)

+        # Calculate cross entropy error function
+        action_prob = K.sum(action * self.model.output, axis=1)
+        cross_entropy = K.log(action_prob) * discounted_rewards
+        loss = -K.sum(cross_entropy)
+
+        # create training function
        optimizer = Adam(lr=self.learning_rate)
-        updates = optimizer.get_updates(self.model.trainable_weights,[],
+        updates = optimizer.get_updates(self.model.trainable_weights, [],
                                        loss)
        train = K.function([self.model.input, action, discounted_rewards], [],
                           updates=updates)

        return train

+    # get action from policy network
    def get_action(self, state):
        policy = self.model.predict(state)[0]
        return np.random.choice(self.action_size, 1, p=policy)[0]

+    # calculate discounted rewards
    def discount_rewards(self, rewards):
        discounted_rewards = np.zeros_like(rewards)
        running_add = 0
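Both the removed good_prob/eligibility lines and the new cross_entropy lines compute the same REINFORCE loss, loss = -sum_t log(pi(a_t|s_t)) * G_t. Dropping K.stop_gradient is behavior-preserving here, because discounted_rewards is a placeholder and no gradient flows through it in the first place. A minimal NumPy sketch of the quantity being computed, on made-up numbers (the values are illustrative only):

import numpy as np

# two timesteps: one-hot actions, the network's action probabilities, returns
actions = np.array([[0., 1., 0., 0., 0.],
                    [1., 0., 0., 0., 0.]])
policy = np.array([[0.1, 0.6, 0.1, 0.1, 0.1],
                   [0.3, 0.2, 0.2, 0.2, 0.1]])
returns = np.array([1.5, -0.5])

action_prob = np.sum(actions * policy, axis=1)  # pi(a_t|s_t) -> [0.6, 0.3]
loss = -np.sum(np.log(action_prob) * returns)   # REINFORCE loss
print(loss)  # ~0.164 = -(1.5*log(0.6) + (-0.5)*log(0.3))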
@@ -61,13 +71,15 @@ def discount_rewards(self, rewards):
            discounted_rewards[t] = running_add
        return discounted_rewards

-    def remember_episode(self, state, action, reward):
+    # save states, actions and rewards for an episode
+    def append_sample(self, state, action, reward):
        self.states.append(state[0])
        self.rewards.append(reward)
        act = np.zeros(self.action_size)
        act[action] = 1
        self.actions.append(act)

+    # update policy neural network
    def train_model(self):
        discounted_rewards = np.float32(self.discount_rewards(self.rewards))
        discounted_rewards -= np.mean(discounted_rewards)
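The loop body of discount_rewards sits mostly between the hunks, but the visible pieces (running_add, the write into discounted_rewards[t]) match the standard backward recursion G_t = r_t + gamma * G_{t+1}. A self-contained sketch under that assumption:

import numpy as np

def discount_rewards(rewards, discount_factor=0.99):
    # accumulate from the end of the episode: G_t = r_t + gamma * G_{t+1}
    discounted = np.zeros_like(rewards, dtype=np.float32)
    running_add = 0.0
    for t in reversed(range(len(rewards))):
        running_add = running_add * discount_factor + rewards[t]
        discounted[t] = running_add
    return discounted

print(discount_rewards([0., 0., 1.]))  # approximately [0.9801, 0.99, 1.0]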
@@ -87,21 +99,23 @@ def train_model(self):
    for e in range(EPISODES):
        done = False
        score = 0
+        # fresh env
        state = env.reset()
        state = np.reshape(state, [1, 15])

        while not done:
            global_step += 1
-
+            # get action for the current state and go one step in environment
            action = agent.get_action(state)
            next_state, reward, done = env.step(action)
            next_state = np.reshape(next_state, [1, 15])

-            agent.remember_episode(state, action, reward)
+            agent.append_sample(state, action, reward)
            score += reward
            state = copy.deepcopy(next_state)

            if done:
+                # update policy neural network for each episode
                agent.train_model()
                scores.append(score)
                episodes.append(e)
@@ -113,6 +127,3 @@ def train_model(self):
            pylab.plot(episodes, scores, 'b')
            pylab.savefig("./save_graph/reinforce.png")
            agent.model.save_weights("./save_model/reinforce.h5")
-
-    print('game over')
-    env.destroy()
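One more piece the hunks skip is how train_model feeds the Keras function that optimizer() returns. Given the visible signature K.function([self.model.input, action, discounted_rewards], [], updates=updates), the update step is presumably along these lines; train_fn is a hypothetical attribute name for that stored function, and the std-division is an assumption (only the mean subtraction appears in the diff):

import numpy as np

def train_model(self):
    # standardize the returns so the gradient step has lower variance
    discounted_rewards = np.float32(self.discount_rewards(self.rewards))
    discounted_rewards -= np.mean(discounted_rewards)
    discounted_rewards /= np.std(discounted_rewards)  # assumption: not shown in the diff

    # one update over the whole episode, then clear the episode buffers
    self.train_fn([np.array(self.states), np.array(self.actions),
                   discounted_rewards])
    self.states, self.actions, self.rewards = [], [], []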