From 3db805d22101c3a7496ca93700c272ac9c5bb476 Mon Sep 17 00:00:00 2001 From: Alessandro Palmas Date: Wed, 24 Apr 2024 18:54:55 -0400 Subject: [PATCH] Fixed fighting style sampling --- diambra/arena/env_settings.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/diambra/arena/env_settings.py b/diambra/arena/env_settings.py index c1dc8a4..50262af 100644 --- a/diambra/arena/env_settings.py +++ b/diambra/arena/env_settings.py @@ -242,7 +242,11 @@ def _sanity_check(self): check_val_in_list("characters[{}]".format(idx), self.characters[idx], char_list) check_num_in_range("outfits", self.outfits, self.games_dict[self.game_id]["outfits"]) check_val_in_list("super_art", self.super_art, [None, 1, 2, 3]) - check_val_in_list("fighting_style", self.fighting_style, [None, 1, 2, 3]) + if self.game_id == "kof98umh": + check_val_in_list("fighting_style", self.fighting_style, [None, 1, 2, 3]) + else: + check_val_in_list("fighting_style", self.fighting_style, [None, 1, 2]) + if self.ultimate_style is not None: for idx in range(3): check_val_in_list("ultimate_style[{}]".format(idx), self.ultimate_style[idx], [1, 2]) @@ -268,7 +272,10 @@ def _process_random_values(self): if self.super_art is None: self.super_art = random.choice(list(range(1, 4))) if self.fighting_style is None: - self.fighting_style = random.choice(list(range(1, 4))) + maxFightingStyle = 3 + if self.game_id == "kof98umh": + maxFightingStyle = 4 + self.fighting_style = random.choice(list(range(1, maxFightingStyle))) if self.ultimate_style is None: self.ultimate_style = tuple([random.choice(list(range(1, 3))) for _ in range(3)]) if self.speed_mode is None: @@ -331,7 +338,10 @@ def _sanity_check(self): check_val_in_list("characters[{}][{}]".format(idx, jdx), self.characters[idx][jdx], char_list) check_num_in_range("outfits[{}]".format(idx), self.outfits[idx], self.games_dict[self.game_id]["outfits"]) check_val_in_list("super_art[{}]".format(idx), self.super_art[idx], [None, 1, 2, 3]) - check_val_in_list("fighting_style[{}]".format(idx), self.fighting_style[idx], [None, 1, 2, 3]) + if self.game_id == "kof98umh": + check_val_in_list("fighting_style[{}]".format(idx), self.fighting_style[idx], [None, 1, 2, 3]) + else: + check_val_in_list("fighting_style[{}]".format(idx), self.fighting_style[idx], [None, 1, 2]) if self.ultimate_style[idx] is not None: for jdx in range(3): check_val_in_list("ultimate_style[{}][{}]".format(idx, jdx), self.ultimate_style[idx][jdx], [1, 2]) @@ -363,7 +373,10 @@ def _process_random_values(self): self.role = (self.role[0], Roles.P1 if self.role[0] == Roles.P2 else Roles.P2) self.super_art = tuple([random.choice(list(range(1, 4))) if self.super_art[idx] is None else self.super_art[idx] for idx in range(2)]) - self.fighting_style = tuple([random.choice(list(range(1, 4))) if self.fighting_style[idx] is None else self.fighting_style[idx] for idx in range(2)]) + maxFightingStyle = 3 + if self.game_id == "kof98umh": + maxFightingStyle = 4 + self.fighting_style = tuple([random.choice(list(range(1, maxFightingStyle))) if self.fighting_style[idx] is None else self.fighting_style[idx] for idx in range(2)]) self.ultimate_style = tuple([[random.choice(list(range(1, 3))) for _ in range(3)] if self.ultimate_style[idx] is None else self.ultimate_style[idx] for idx in range(2)]) self.speed_mode = tuple([random.choice(list(range(1, 3))) if self.speed_mode[idx] is None else self.speed_mode[idx] for idx in range(2)])