Skip to content

Commit

Permalink
Fixed fighting style sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpalms committed Apr 26, 2024
1 parent 332a9a0 commit 3db805d
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions diambra/arena/env_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,11 @@ def _sanity_check(self):
check_val_in_list("characters[{}]".format(idx), self.characters[idx], char_list)
check_num_in_range("outfits", self.outfits, self.games_dict[self.game_id]["outfits"])
check_val_in_list("super_art", self.super_art, [None, 1, 2, 3])
check_val_in_list("fighting_style", self.fighting_style, [None, 1, 2, 3])
if self.game_id == "kof98umh":
check_val_in_list("fighting_style", self.fighting_style, [None, 1, 2, 3])
else:
check_val_in_list("fighting_style", self.fighting_style, [None, 1, 2])

if self.ultimate_style is not None:
for idx in range(3):
check_val_in_list("ultimate_style[{}]".format(idx), self.ultimate_style[idx], [1, 2])
Expand All @@ -268,7 +272,10 @@ def _process_random_values(self):
if self.super_art is None:
self.super_art = random.choice(list(range(1, 4)))
if self.fighting_style is None:
self.fighting_style = random.choice(list(range(1, 4)))
maxFightingStyle = 3
if self.game_id == "kof98umh":
maxFightingStyle = 4
self.fighting_style = random.choice(list(range(1, maxFightingStyle)))
if self.ultimate_style is None:
self.ultimate_style = tuple([random.choice(list(range(1, 3))) for _ in range(3)])
if self.speed_mode is None:
Expand Down Expand Up @@ -331,7 +338,10 @@ def _sanity_check(self):
check_val_in_list("characters[{}][{}]".format(idx, jdx), self.characters[idx][jdx], char_list)
check_num_in_range("outfits[{}]".format(idx), self.outfits[idx], self.games_dict[self.game_id]["outfits"])
check_val_in_list("super_art[{}]".format(idx), self.super_art[idx], [None, 1, 2, 3])
check_val_in_list("fighting_style[{}]".format(idx), self.fighting_style[idx], [None, 1, 2, 3])
if self.game_id == "kof98umh":
check_val_in_list("fighting_style[{}]".format(idx), self.fighting_style[idx], [None, 1, 2, 3])
else:
check_val_in_list("fighting_style[{}]".format(idx), self.fighting_style[idx], [None, 1, 2])
if self.ultimate_style[idx] is not None:
for jdx in range(3):
check_val_in_list("ultimate_style[{}][{}]".format(idx, jdx), self.ultimate_style[idx][jdx], [1, 2])
Expand Down Expand Up @@ -363,7 +373,10 @@ def _process_random_values(self):
self.role = (self.role[0], Roles.P1 if self.role[0] == Roles.P2 else Roles.P2)

self.super_art = tuple([random.choice(list(range(1, 4))) if self.super_art[idx] is None else self.super_art[idx] for idx in range(2)])
self.fighting_style = tuple([random.choice(list(range(1, 4))) if self.fighting_style[idx] is None else self.fighting_style[idx] for idx in range(2)])
maxFightingStyle = 3
if self.game_id == "kof98umh":
maxFightingStyle = 4
self.fighting_style = tuple([random.choice(list(range(1, maxFightingStyle))) if self.fighting_style[idx] is None else self.fighting_style[idx] for idx in range(2)])
self.ultimate_style = tuple([[random.choice(list(range(1, 3))) for _ in range(3)] if self.ultimate_style[idx] is None else self.ultimate_style[idx] for idx in range(2)])
self.speed_mode = tuple([random.choice(list(range(1, 3))) if self.speed_mode[idx] is None else self.speed_mode[idx] for idx in range(2)])

Expand Down

0 comments on commit 3db805d

Please sign in to comment.