
Commit fc8e272
Change _reward_fn to take action taken
In standard RL, the reward function often takes the action taken as input.
This change allows that by adjusting the signature of _reward_fn.

See issue #155 for more details
RobertKirk authored and Heinrich Kuttler committed Jun 16, 2021
1 parent 16cc9d6 commit fc8e272
Showing 2 changed files with 17 additions and 8 deletions.
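For context, the new argument flows through the standard Gym loop unchanged from the user's perspective. A minimal sketch, assuming the registered NLE task id "NetHackScore-v0" and the obs/reward/done/info step API of the Gym version in use at the time:

    import gym
    import nle  # importing nle registers the NetHack tasks with Gym

    env = gym.make("NetHackScore-v0")
    obs = env.reset()
    action = env.action_space.sample()
    # step() now forwards `action` into _reward_fn together with the
    # previous observation, the new observation, and the end status.
    obs, reward, done, info = env.step(action)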
7 changes: 5 additions & 2 deletions nle/env/base.py
@@ -384,7 +384,9 @@ def step(self, action: int):
         end_status = self._is_episode_end(observation)
         end_status = self.StepStatus(done or end_status)
 
-        reward = float(self._reward_fn(last_observation, observation, end_status))
+        reward = float(
+            self._reward_fn(last_observation, action, observation, end_status)
+        )
 
         if end_status and not done:
             # Try to end the game nicely.
@@ -605,14 +607,15 @@ def _is_episode_end(self, observation):
         """
         return self.StepStatus.RUNNING
 
-    def _reward_fn(self, last_observation, observation, end_status):
+    def _reward_fn(self, last_observation, action, observation, end_status):
         """Reward function. Difference between previous score and new score."""
         if not self.env.in_normal_game():
             # Before game started or after it ended stats are zero.
             return 0.0
         old_score = last_observation[self._blstats_index][BLSTATS_SCORE_INDEX]
         score = observation[self._blstats_index][BLSTATS_SCORE_INDEX]
         del end_status  # Unused for "score" reward.
+        del action  # Unused for "score" reward.
         return score - old_score
 
     def _perform_known_steps(self, observation, done, exceptions=True):
18 changes: 12 additions & 6 deletions nle/env/tasks.py
@@ -82,9 +82,11 @@ def _get_time_penalty(self, last_observation, observation):
         penalty += (new_time - old_time) * self.penalty_time
         return penalty
 
-    def _reward_fn(self, last_observation, observation, end_status):
+    def _reward_fn(self, last_observation, action, observation, end_status):
         """Score delta, but with an added state loop penalty."""
-        score_diff = super()._reward_fn(last_observation, observation, end_status)
+        score_diff = super()._reward_fn(
+            last_observation, action, observation, end_status
+        )
         time_penalty = self._get_time_penalty(last_observation, observation)
         return score_diff + time_penalty
 
@@ -111,7 +113,8 @@ def _is_episode_end(self, observation):
             return self.StepStatus.TASK_SUCCESSFUL
         return self.StepStatus.RUNNING
 
-    def _reward_fn(self, last_observation, observation, end_status):
+    def _reward_fn(self, last_observation, action, observation, end_status):
+        del action  # Unused
         time_penalty = self._get_time_penalty(last_observation, observation)
         if end_status == self.StepStatus.TASK_SUCCESSFUL:
             reward = 1
@@ -194,9 +197,10 @@ def __init__(self, *args, **kwargs):
 
         super().__init__(*args, options=options, **kwargs)
 
-    def _reward_fn(self, last_observation, observation, end_status):
+    def _reward_fn(self, last_observation, action, observation, end_status):
         """Difference between previous gold and new gold."""
         del end_status  # Unused
+        del action  # Unused
         if not self.env.in_normal_game():
             # Before game started or after it ended stats are zero.
             return 0.0
@@ -224,9 +228,10 @@ class NetHackEat(NetHackScore):
     comestibles or monster corpses), rather than the score.
     """
 
-    def _reward_fn(self, last_observation, observation, end_status):
+    def _reward_fn(self, last_observation, action, observation, end_status):
         """Difference between previous hunger and new hunger."""
         del end_status  # Unused
+        del action  # Unused
 
         if not self.env.in_normal_game():
             # Before game started or after it ended stats are zero.
@@ -256,8 +261,9 @@ def reset(self, *args, **kwargs):
         self.dungeon_explored = {}
         return super().reset(*args, **kwargs)
 
-    def _reward_fn(self, last_observation, observation, end_status):
+    def _reward_fn(self, last_observation, action, observation, end_status):
         del end_status  # Unused
+        del action  # Unused
 
         if not self.env.in_normal_game():
             # Before game started or after it ended stats are zero.
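All of the built-in tasks above discard the new argument with "del action"; the widened signature exists so that downstream tasks can use it. A minimal sketch of a hypothetical subclass (the class name, the penalty value, and the assumption that the base env exposes its action list as self.actions are illustrative, not part of this commit):

    from nle import nethack
    from nle.env.tasks import NetHackScore


    class NetHackScoreSearchPenalty(NetHackScore):
        """Hypothetical task: score delta minus a small penalty for searching."""

        SEARCH_PENALTY = 0.1  # illustrative constant

        def _reward_fn(self, last_observation, action, observation, end_status):
            reward = super()._reward_fn(
                last_observation, action, observation, end_status
            )
            # `action` indexes into the env's action list; assuming it is
            # exposed as `self.actions`, map the index back to a command.
            if self.actions[action] == nethack.Command.SEARCH:
                reward -= self.SEARCH_PENALTY
            return reward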
