Commit ef58697

Refactoring changes to efr.py, primarily to _regret_matching.
Jamesflynn1 committed Aug 20, 2024
1 parent 6a5fa79 commit ef58697
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions open_spiel/python/algorithms/efr.py
@@ -121,7 +121,7 @@ def __init__(self, game, deviation_gen):
   def return_cumulative_regret(self):
     """Returns a dictionary mapping.
-    The mapping is fromevery information state to its associated regret
+    The mapping is from every information state to its associated regret
     (accumulated over all iterations).
     """
     return {
@@ -491,8 +491,8 @@ def __init__(self, game, deviations_name):
       deviation_sets = return_behavourial
     else:
       raise ValueError(
-          "Unsupported Deviation Set Passed As "
-          " Constructor Argument"
+          "Unsupported Deviation Set Passed\
+          As Constructor Argument"
       )
     super(EFRSolver, self).__init__(game, deviation_sets)
     self._external_only = external_only
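
A behavioural note on the new message (an observation, not part of the diff): a backslash continuation inside a string literal keeps the next source line's leading whitespace in the string, so the raised message will contain a run of interior spaces. A tiny demonstration:

s = "Unsupported Deviation Set Passed\
          As Constructor Argument"
# The continuation keeps the indentation, so s is:
# "Unsupported Deviation Set Passed          As Constructor Argument"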
@@ -507,32 +507,32 @@ def _regret_matching(self, info_set_node):
       info_set_node: the info state node to compute the policy for.
     Returns:
-      A dict of action -> prob for all legal actions of the info_set_node.
+      A dict of action -> prob for all legal actions of the
+      info_set_node.
     """
     legal_actions = info_set_node.legal_actions
-    info_state_policy = None
+    num_actions = len(legal_actions)
     z = sum(info_set_node.y_values.values())
+    info_state_policy = {}
 
     # The fixed point solution can be directly obtained through the
     # weighted regret matrix if only external deviations are used.
     if self._external_only and z > 0:
       weighted_deviation_matrix = np.zeros(
-          (len(legal_actions), len(legal_actions))
+          (num_actions, num_actions)
       )
       for dev in list(info_set_node.y_values.keys()):
         weighted_deviation_matrix += (
             info_set_node.y_values[dev] / z
         ) * dev.return_transform_matrix()
       new_strategy = weighted_deviation_matrix[:, 0]
-      for index in range(len(legal_actions)):
-        info_state_policy[legal_actions[index]] = new_strategy[index]
+      info_state_policy = dict(zip(legal_actions, new_strategy))
 
     # Full regret matching by finding the least squares solution to the
     # fixed point of the EFR regret matching function.
     # Last row of matrix and the column entry minimises the solution
     # towards a strategy.
     elif z > 0:
-      num_actions = len(legal_actions)
       weighted_deviation_matrix = -np.eye(num_actions)
 
       for dev in list(info_set_node.y_values.keys()):
@@ -552,17 +552,16 @@ def _regret_matching(self, info_set_node):

       # Adopt same clipping strategy as paper author's code.
       np.clip(strategy, a_min=0, a_max=1, out=strategy)
 
       strategy = strategy / np.sum(strategy)
-      for index in range(len(strategy)):
-        info_state_policy[legal_actions[index]] = strategy[index, 0]
-
+      info_state_policy = dict(zip(legal_actions, strategy[:,0]))
     # Use a uniform strategy as sum of all regrets is negative.
     else:
-      for index in range(len(legal_actions)):
-        info_state_policy[legal_actions[index]] = 1.0 / len(legal_actions)
+      unif_policy_value = 1.0 / num_actions
+      info_state_policy = {legal_actions[index]:unif_policy_value
+                           for index in range(num_actions)}
     return info_state_policy
 
 
 def _update_average_policy(average_policy, info_state_nodes):
   """Updates in place `average_policy` to the average of all policies iterated.
@@ -616,10 +615,7 @@ def array_to_strat_dict(strategy_array, legal_actions):
   Returns:
     strategy_dictionary: a dictionary action -> prob value.
   """
-  strategy_dictionary = {}
-  for action in legal_actions:
-    strategy_dictionary[action] = strategy_array[action]
-  return strategy_dictionary
+  return dict(zip(legal_actions, strategy_array))
 
 
 def create_probs_from_index(indices, current_policy):
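
An aside on the recurring dict(zip(...)) refactor in this commit: pairing keys and values positionally is equivalent to the removed index loops whenever the two sequences line up. A toy check (not from the commit):

legal_actions = [0, 1, 2]
strategy_array = [0.5, 0.25, 0.25]

# Before: explicit index loop, as in the removed code.
by_loop = {}
for index in range(len(legal_actions)):
  by_loop[legal_actions[index]] = strategy_array[index]

# After: one expression, same mapping.
assert by_loop == dict(zip(legal_actions, strategy_array))

One subtlety worth flagging: the removed array_to_strat_dict body indexed strategy_array[action], whereas dict(zip(...)) pairs by position, so the two behave identically only when legal_actions enumerates 0..len(strategy_array)-1 in order, which the call sites here presumably guarantee.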
