Skip to content

Commit

Permalink
Add scaling for PositionBonus (#433)
Browse files Browse the repository at this point in the history
Co-authored-by: Mark Towers <[email protected]>
  • Loading branch information
Mahrkeenerh and pseudo-rnd-thoughts authored Jan 13, 2025
1 parent 4986bf5 commit 8710e91
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions minigrid/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def step(self, action):

class PositionBonus(Wrapper):
"""
Adds an exploration bonus based on which positions
Adds a scaled exploration bonus based on which positions
are visited on the grid.
Note:
Expand All @@ -142,7 +142,7 @@ class PositionBonus(Wrapper):
>>> _, reward, _, _, _ = env.step(1)
>>> print(reward)
0
>>> env_bonus = PositionBonus(env)
>>> env_bonus = PositionBonus(env, scale=1)
>>> obs, _ = env_bonus.reset(seed=0)
>>> obs, reward, terminated, truncated, info = env_bonus.step(1)
>>> print(reward)
Expand All @@ -152,14 +152,15 @@ class PositionBonus(Wrapper):
0.7071067811865475
"""

def __init__(self, env):
def __init__(self, env, scale=1):
    """A wrapper that adds a scaled exploration bonus to less visited positions.

    Args:
        env: The environment to apply the wrapper
        scale: Multiplier applied to the exploration bonus before it is
            added to the reward. Defaults to 1, i.e. the unscaled bonus.
    """
    super().__init__(env)
    # Visit counts, keyed by the agent position tuple (filled in by `step`).
    self.counts = {}
    # Keep the caller-supplied factor; hard-coding 1 here would silently
    # ignore the `scale` argument.
    self.scale = scale

def step(self, action):
"""Steps through the environment with `action`."""
Expand All @@ -171,16 +172,14 @@ def step(self, action):
tup = tuple(env.agent_pos)

# Get the count for this key
pre_count = 0
if tup in self.counts:
pre_count = self.counts[tup]
pre_count = self.counts.get(tup, 0)

# Update the count for this key
new_count = pre_count + 1
self.counts[tup] = new_count

bonus = 1 / math.sqrt(new_count)
reward += bonus
reward += bonus * self.scale

return obs, reward, terminated, truncated, info

Expand Down

0 comments on commit 8710e91

Please sign in to comment.