
Commit cb1b478

comments for reinforce
1 parent 6fd5a23 commit cb1b478


6 files changed: +25 -22 lines

1-grid-world/4-sarsa/environment.py

Lines changed: 0 additions & 2 deletions
@@ -135,8 +135,6 @@ def step(self, action):
 
         next_state = self.coords_to_state(next_state)
 
-
-
         return next_state, reward, done
 
     def render(self):

1-grid-world/4-sarsa/sarsa_agent.py

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 import numpy as np
 import random
-import time
 from collections import defaultdict
 from environment import Env
 

1-grid-world/6-deep-sarsa/deep_sarsa_agent.py

Lines changed: 0 additions & 4 deletions
@@ -115,7 +115,3 @@ def train_model(self, state, action, reward, next_state, next_action, done):
 
         if e % 100 == 0:
             agent.model.save_weights("./save_model/deep_sarsa.h5")
-
-    # end of game
-    print('game over')
-    env.destroy()

1-grid-world/6-deep-sarsa/environment.py

Lines changed: 1 addition & 1 deletion
@@ -236,5 +236,5 @@ def move(self, target, action):
         return s_
 
     def render(self):
-        time.sleep(0.05)
+        time.sleep(0.07)
         self.update()

1-grid-world/7-reinforce/environment.py

Lines changed: 1 addition & 2 deletions
@@ -235,6 +235,5 @@ def move(self, target, action):
         return s_
 
     def render(self):
-        # time.sleep(0.1)
-        time.sleep(0.01)
+        time.sleep(0.07)
         self.update()

1-grid-world/7-reinforce/reinforce_agent.py

Lines changed: 23 additions & 12 deletions
@@ -10,13 +10,16 @@
 EPISODES = 2500
 
 
+# this is REINFORCE Agent for GridWorld
 class ReinforceAgent:
     def __init__(self):
-        self.load_model = False
+        self.load_model = True
+        # actions which agent can do
         self.action_space = [0, 1, 2, 3, 4]
+        # get size of state and action
         self.action_size = len(self.action_space)
         self.state_size = 15
-        self.discount_factor = 0.99 # decay rate
+        self.discount_factor = 0.99
         self.learning_rate = 0.001
 
         self.model = self.build_model()

@@ -26,6 +29,7 @@ def __init__(self):
         if self.load_model:
             self.model.load_weights('./save_model/reinforce_trained.h5')
 
+    # state is input and probability of each action(policy) is output of network
     def build_model(self):
         model = Sequential()
         model.add(Dense(24, input_dim=self.state_size, activation='relu'))

@@ -34,25 +38,31 @@ def build_model(self):
         model.summary()
         return model
 
+    # create error function and training function to update policy network
     def optimizer(self):
         action = K.placeholder(shape=[None, 5])
         discounted_rewards = K.placeholder(shape=[None, ])
-        good_prob = K.sum(action * self.model.output, axis=1)
-        eligibility = K.log(good_prob) * K.stop_gradient(discounted_rewards)
-        loss = -K.sum(eligibility)
 
+        # Calculate cross entropy error function
+        action_prob = K.sum(action * self.model.output, axis=1)
+        cross_entropy = K.log(action_prob) * discounted_rewards
+        loss = -K.sum(cross_entropy)
+
+        # create training function
         optimizer = Adam(lr=self.learning_rate)
-        updates = optimizer.get_updates(self.model.trainable_weights,[],
+        updates = optimizer.get_updates(self.model.trainable_weights, [],
                                         loss)
         train = K.function([self.model.input, action, discounted_rewards], [],
                            updates=updates)
 
         return train
 
+    # get action from policy network
     def get_action(self, state):
         policy = self.model.predict(state)[0]
         return np.random.choice(self.action_size, 1, p=policy)[0]
 
+    # calculate discounted rewards
     def discount_rewards(self, rewards):
         discounted_rewards = np.zeros_like(rewards)
         running_add = 0
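For reference, the loss assembled in optimizer() above is the usual REINFORCE objective: the log-probability of each chosen action, weighted by its discounted return, summed and negated. A minimal NumPy sketch of that same computation (not part of the repository code; the policy outputs, actions, and returns below are made-up toy values):

import numpy as np

# Softmax policy outputs for two time steps (5 actions each) and the one-hot
# actions actually taken -- toy values for illustration only.
policy_output = np.array([[0.1, 0.2, 0.4, 0.2, 0.1],
                          [0.3, 0.3, 0.2, 0.1, 0.1]])
actions = np.array([[0, 0, 1, 0, 0],
                    [1, 0, 0, 0, 0]])
discounted_rewards = np.array([1.5, -0.5])  # returns G_t for each step

# Same computation as the Keras backend version in optimizer():
action_prob = np.sum(actions * policy_output, axis=1)      # pi(a_t | s_t)
cross_entropy = np.log(action_prob) * discounted_rewards   # log pi * G_t
loss = -np.sum(cross_entropy)
print(loss)  # minimizing this raises the probability of high-return actions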
@@ -61,13 +71,15 @@ def discount_rewards(self, rewards):
             discounted_rewards[t] = running_add
         return discounted_rewards
 
-    def remember_episode(self, state, action, reward):
+    # save states, actions and rewards for an episode
+    def append_sample(self, state, action, reward):
         self.states.append(state[0])
         self.rewards.append(reward)
         act = np.zeros(self.action_size)
         act[action] = 1
         self.actions.append(act)
 
+    # update policy neural network
     def train_model(self):
         discounted_rewards = np.float32(self.discount_rewards(self.rewards))
         discounted_rewards -= np.mean(discounted_rewards)

@@ -87,21 +99,23 @@ def train_model(self):
     for e in range(EPISODES):
         done = False
         score = 0
+        # fresh env
         state = env.reset()
         state = np.reshape(state, [1, 15])
 
         while not done:
             global_step += 1
-
+            # get action for the current state and go one step in environment
             action = agent.get_action(state)
             next_state, reward, done = env.step(action)
             next_state = np.reshape(next_state, [1, 15])
 
-            agent.remember_episode(state, action, reward)
+            agent.append_sample(state, action, reward)
             score += reward
             state = copy.deepcopy(next_state)
 
             if done:
+                # update policy neural network for each episode
                 agent.train_model()
                 scores.append(score)
                 episodes.append(e)

@@ -113,6 +127,3 @@ def train_model(self):
             pylab.plot(episodes, scores, 'b')
             pylab.savefig("./save_graph/reinforce.png")
             agent.model.save_weights("./save_model/reinforce.h5")
-
-    print('game over')
-    env.destroy()
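The discount_rewards() context visible above accumulates a running_add value over the episode and then, in train_model(), subtracts the mean before using the returns as weights. A self-contained sketch of that calculation, assuming the standard reverse-order accumulation G_t = r_t + gamma * G_{t+1} that the visible running_add lines suggest (the example rewards are made up):

import numpy as np

def discount_rewards(rewards, discount_factor=0.99):
    # Walk the episode backwards so each step's return includes all later rewards.
    discounted_rewards = np.zeros_like(rewards, dtype=np.float32)
    running_add = 0.0
    for t in reversed(range(len(rewards))):
        running_add = running_add * discount_factor + rewards[t]
        discounted_rewards[t] = running_add
    return discounted_rewards

# Example: a four-step episode that only rewards the final step.
returns = discount_rewards([0.0, 0.0, 0.0, 1.0])
# As in train_model(), center the returns before using them as weights.
returns -= np.mean(returns)
print(returns)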

0 commit comments
