
Commit

Merge pull request #1271 from Jamesflynn1:master
PiperOrigin-RevId: 665977302
Change-Id: I5b5486ce89e53b990dcbb4c8952389a3cd9d4382
lanctot committed Aug 27, 2024
2 parents 5a1f76f + ef58697 commit 42ff9ba
Showing 1 changed file with 21 additions and 27 deletions: open_spiel/python/algorithms/efr.py
@@ -121,7 +121,7 @@ def __init__(self, game, deviation_gen):
   def return_cumulative_regret(self):
     """Returns a dictionary mapping.
-    The mapping is fromevery information state to its associated regret
+    The mapping is from every information state to its associated regret
     (accumulated over all iterations).
     """
     return {
@@ -274,9 +274,7 @@ def _update_current_policy(self, state, current_policy):
         )

       state_policy = current_policy.policy_for_key(info_state)
-      for action, value in self._regret_matching(
-          info_state_node.legal_actions, info_state_node
-      ).items():
+      for action, value in self._regret_matching(info_state_node).items():
         state_policy[action] = value

       for action in info_state_node.legal_actions:
@@ -491,48 +489,46 @@ def __init__(self, game, deviations_name):
       deviation_sets = return_behavourial
     else:
       raise ValueError(
-          "Unsupported Deviation Set Passed As "
-          " Constructor Argument"
+          "Unsupported Deviation Set Passed As Constructor"
+          " Argument"
       )
     super(EFRSolver, self).__init__(game, deviation_sets)
     self._external_only = external_only

-  def _regret_matching(self, legal_actions, info_set_node):
+  def _regret_matching(self, info_set_node):
     """Returns an info state policy.
     The info state policy returned is the one obtained by applying
     regret-matching function over all deviations and time selection functions.
     Args:
-      legal_actions: the list of legal actions at this state.
       info_set_node: the info state node to compute the policy for.
     Returns:
-      A dict of action -> prob for all legal actions.
+      A dict of action -> prob for all legal actions of the
+      info_set_node.
     """
+    legal_actions = info_set_node.legal_actions
+    num_actions = len(legal_actions)
-    info_state_policy = None
     z = sum(info_set_node.y_values.values())
+    info_state_policy = {}

     # The fixed point solution can be directly obtained through the
     # weighted regret matrix if only external deviations are used.
     if self._external_only and z > 0:
-      weighted_deviation_matrix = np.zeros(
-          (len(legal_actions), len(legal_actions))
-      )
+      weighted_deviation_matrix = np.zeros((num_actions, num_actions))
       for dev in list(info_set_node.y_values.keys()):
         weighted_deviation_matrix += (
             info_set_node.y_values[dev] / z
         ) * dev.return_transform_matrix()
       new_strategy = weighted_deviation_matrix[:, 0]
-      for index in range(len(legal_actions)):
-        info_state_policy[legal_actions[index]] = new_strategy[index]
+      info_state_policy = dict(zip(legal_actions, new_strategy))

     # Full regret matching by finding the least squares solution to the
     # fixed point of the EFR regret matching function.
     # Last row of matrix and the column entry minimises the solution
     # towards a strategy.
     elif z > 0:
-      num_actions = len(info_set_node.legal_actions)
       weighted_deviation_matrix = -np.eye(num_actions)

       for dev in list(info_set_node.y_values.keys()):
@@ -551,16 +547,17 @@ def _regret_matching(self, legal_actions, info_set_node):
       strategy = linalg.lstsq(weighted_deviation_matrix, b)[0]

       # Adopt same clipping strategy as paper author's code.
-      strategy[np.where(strategy < 0)] = 0
-      strategy[np.where(strategy > 1)] = 1
+      np.clip(strategy, a_min=0, a_max=1, out=strategy)
+      strategy = strategy / np.sum(strategy)

-      strategy = strategy / sum(strategy)
-      for index in range(len(strategy)):
-        info_state_policy[info_set_node.legal_actions[index]] = strategy[index]
+      info_state_policy = dict(zip(legal_actions, strategy[:, 0]))
     # Use a uniform strategy as sum of all regrets is negative.
     else:
-      for index in range(len(legal_actions)):
-        info_state_policy[legal_actions[index]] = 1.0 / len(legal_actions)
+      unif_policy_value = 1.0 / num_actions
+      info_state_policy = {
+          legal_actions[index]: unif_policy_value
+          for index in range(num_actions)
+      }
     return info_state_policy


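The refactored _regret_matching above computes a fixed point of the weight-averaged deviation transforms in two ways: when only external deviations are tracked it reads the policy straight off a column of the weighted transform matrix, otherwise it solves a least-squares system and renormalises. Below is a minimal, self-contained sketch of that logic in plain NumPy/SciPy. It is not OpenSpiel code: the transform matrices and weights stand in for dev.return_transform_matrix() and info_set_node.y_values, and the appended row of ones (enforcing that the strategy sums to one) is an assumption about the matrix rows not shown in this hunk.

import numpy as np
from scipy import linalg


def external_only_policy(legal_actions, transforms, weights):
  """External deviations: mix transforms by weight and read off column 0."""
  num_actions = len(legal_actions)
  z = sum(weights)
  if z <= 0:
    # Uniform fallback, as in the final else branch above.
    return {a: 1.0 / num_actions for a in legal_actions}
  weighted = np.zeros((num_actions, num_actions))
  for w, t in zip(weights, transforms):
    weighted += (w / z) * t
  return dict(zip(legal_actions, weighted[:, 0]))


def least_squares_policy(legal_actions, transforms, weights):
  """General deviations: least-squares solve of M sigma = sigma with sum(sigma) = 1."""
  num_actions = len(legal_actions)
  z = sum(weights)  # assumed positive, mirroring the elif z > 0 branch
  weighted = -np.eye(num_actions)
  for w, t in zip(weights, transforms):
    weighted += (w / z) * t  # weighted now holds (M - I)
  a = np.vstack([weighted, np.ones((1, num_actions))])  # assumed sum-to-one row
  b = np.zeros((num_actions + 1, 1))
  b[-1] = 1.0
  strategy = linalg.lstsq(a, b)[0]
  # Same clipping and renormalisation as the code above.
  np.clip(strategy, a_min=0, a_max=1, out=strategy)
  strategy = strategy / np.sum(strategy)
  return dict(zip(legal_actions, strategy[:, 0]))


# Example: two external deviations over three actions ("always play action 0"
# and "always play action 1") with accumulated weights 3 and 1. Both branches
# return the same policy here: {0: 0.75, 1: 0.25, 2: 0.0}.
always_0 = np.tile([[1.0], [0.0], [0.0]], (1, 3))
always_1 = np.tile([[0.0], [1.0], [0.0]], (1, 3))
print(external_only_policy([0, 1, 2], [always_0, always_1], [3.0, 1.0]))
print(least_squares_policy([0, 1, 2], [always_0, always_1], [3.0, 1.0]))
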
@@ -617,10 +614,7 @@ def array_to_strat_dict(strategy_array, legal_actions):
   Returns:
     strategy_dictionary: a dictionary action -> prob value.
   """
-  strategy_dictionary = {}
-  for action in legal_actions:
-    strategy_dictionary[action] = strategy_array[action]
-  return strategy_dictionary
+  return dict(zip(legal_actions, strategy_array))


 def create_probs_from_index(indices, current_policy):
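A side note on the dict(zip(...)) idiom adopted in array_to_strat_dict above: zip pairs strategy values with legal actions positionally, which coincides with the removed index-based loop whenever strategy_array is ordered the same way as legal_actions (for instance when legal_actions is range(len(strategy_array))). A tiny illustrative check with made-up values:

import numpy as np

strategy_array = np.array([0.2, 0.5, 0.3])
legal_actions = [0, 1, 2]
new_way = dict(zip(legal_actions, strategy_array))
old_way = {action: strategy_array[action] for action in legal_actions}
assert new_way == old_way  # identical here because the actions are 0..n-1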
