Fixup for 'Training An Agent' page (#1281)
Co-authored-by: chr0nikler <[email protected]>
chr0nikler and chr0nikler authored Jan 6, 2025
1 parent 87cc458 commit fc74bb8
Showing 1 changed file with 33 additions and 22 deletions.
55 changes: 33 additions & 22 deletions docs/introduction/train_agent.md
@@ -155,37 +155,48 @@ You can use `matplotlib` to visualize the training reward and length.

```python
from matplotlib import pyplot as plt
# visualize the episode rewards, episode length and training error in one figure
fig, axs = plt.subplots(1, 3, figsize=(20, 8))

# np.convolve will compute the rolling mean for 100 episodes

axs[0].plot(np.convolve(env.return_queue, np.ones(100)/100))
axs[0].set_title("Episode Rewards")
axs[0].set_xlabel("Episode")
axs[0].set_ylabel("Reward")
def get_moving_avgs(arr, window, convolution_mode):
    return np.convolve(
        np.array(arr).flatten(),
        np.ones(window),
        mode=convolution_mode
    ) / window

# Smooth over a 500 episode window
rolling_length = 500
fig, axs = plt.subplots(ncols=3, figsize=(12, 5))

axs[0].set_title("Episode rewards")
reward_moving_average = get_moving_avgs(
    env.return_queue,
    rolling_length,
    "valid"
)
axs[0].plot(range(len(reward_moving_average)), reward_moving_average)

axs[1].plot(np.convolve(env.length_queue, np.ones(100)/100))
axs[1].set_title("Episode Lengths")
axs[1].set_xlabel("Episode")
axs[1].set_ylabel("Length")
axs[1].set_title("Episode lengths")
length_moving_average = get_moving_avgs(
    env.length_queue,
    rolling_length,
    "valid"
)
axs[1].plot(range(len(length_moving_average)), length_moving_average)

axs[2].plot(np.convolve(agent.training_error, np.ones(100)/100))
axs[2].set_title("Training Error")
axs[2].set_xlabel("Episode")
axs[2].set_ylabel("Temporal Difference")

training_error_moving_average = get_moving_avgs(
    agent.training_error,
    rolling_length,
    "same"
)
axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)
plt.tight_layout()
plt.show()
```
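
The `get_moving_avgs` helper is called with `"valid"` for rewards and lengths but `"same"` for the training error, and the two modes return different numbers of points. The following is a minimal sketch of that difference, assuming made-up toy data in place of `env.return_queue` and a small window of 3 (both are illustrative choices, not values from the tutorial):

```python
import numpy as np


def get_moving_avgs(arr, window, convolution_mode):
    """Moving average of `arr` over `window` samples via np.convolve."""
    return np.convolve(
        np.array(arr).flatten(),
        np.ones(window),
        mode=convolution_mode,
    ) / window


# Hypothetical toy data standing in for env.return_queue
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
window = 3

valid = get_moving_avgs(values, window, "valid")
same = get_moving_avgs(values, window, "same")

# "valid" keeps only positions where the window fully overlaps the data
print(len(valid))  # 4, i.e. len(values) - window + 1
print(valid)       # [2. 3. 4. 5.]

# "same" keeps the input length; edge windows overlap implicit zeros
print(len(same))   # 6, i.e. len(values)
```

With `"valid"`, the curve is shorter than the input by `window - 1` points. With `"same"`, the length is preserved, but the first and last few values are averaged against implicit zero padding, so they are pulled toward zero.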

![](../_static/img/tutorials/blackjack_training_plots.png "Training Plot")

## Visualising the policy

![](../_static/img/tutorials/blackjack_with_usable_ace.png "With a usable ace")
```

![](../_static/img/tutorials/blackjack_without_usable_ace.png "Without a usable ace")
![](../_static/img/tutorials/blackjack_training_plots.png "Training Plot")

Hopefully this tutorial helped you get a grip of how to interact with Gymnasium environments and sets you on a journey to solve many more RL challenges.
