Skip to content

Commit

Permalink
Add string names to waza_p asset files
Browse files Browse the repository at this point in the history
  • Loading branch information
AnonymousRandomPerson committed Jul 29, 2024
1 parent aa0aeb7 commit e8315f7
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 59 deletions.
5 changes: 4 additions & 1 deletion skytemple_files/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,9 @@ def deserialize_enum_or_default(enum_class: type[Enum], value: str | int) -> Enu
If the value is an integer, returns the integer back.
"""
if isinstance(value, str):
return enum_class[value].value
try:
return enum_class[value].value
except NameError:
raise NameError(f"Invalid value {value} for type {enum_class.name}")
else:
return value
147 changes: 105 additions & 42 deletions skytemple_files/data/waza_p/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
from typing import TYPE_CHECKING, Sequence
import json

from range_typed_integers import u32

from skytemple_files.common.impl_cfg import get_implementation_type, ImplementationType
from skytemple_files.common.ppmdu_config.data import Pmd2StringBlock
from skytemple_files.common.types.file_storage import AssetSpec, Asset
from skytemple_files.common.types.hybrid_data_handler import (
WriterProtocol,
HybridSir0DataHandler,
)
from skytemple_files.common.util import OptionalKwargs, serialize_enum_or_default, deserialize_enum_or_default
from skytemple_files.data.md.protocol import PokeType
from skytemple_files.data.str.handler import StrHandler
from skytemple_files.data.str.model import Str
from skytemple_files.data.waza_p.protocol import (
WazaPProtocol,
LevelUpMoveProtocol,
Expand Down Expand Up @@ -157,51 +162,71 @@ def serialize_asset(
if spec.category == MOVES:
moves_list = []
move: WazaMoveProtocol

strings, string_blocks = cls._get_strings_from_rom(["Move Names"], **kwargs)

for move in data.moves:
moves_list.append(
{
"base_power": move.base_power,
"type": serialize_enum_or_default(PokeType, move.type),
"category": serialize_enum_or_default(WazaMoveCategory, move.category),
"settings_range": cls._serialize_move_range_settings(move.settings_range),
"settings_range_ai": cls._serialize_move_range_settings(move.settings_range_ai),
"base_pp": move.base_pp,
"ai_weight": move.ai_weight,
"miss_accuracy": move.miss_accuracy,
"accuracy": move.accuracy,
"ai_condition1_chance": move.ai_condition1_chance,
"number_chained_hits": move.number_chained_hits,
"max_upgrade_level": move.max_upgrade_level,
"crit_chance": move.crit_chance,
"affected_by_magic_coat": move.affected_by_magic_coat,
"is_snatchable": move.is_snatchable,
"uses_mouth": move.uses_mouth,
"ai_frozen_check": move.ai_frozen_check,
"ignores_taunted": move.ignores_taunted,
"range_check_text": move.range_check_text,
"move_id": move.move_id,
"message_id": move.message_id,
}
)
serialized_move = {
"base_power": move.base_power,
"type": serialize_enum_or_default(PokeType, move.type),
"category": serialize_enum_or_default(WazaMoveCategory, move.category),
"settings_range": cls._serialize_move_range_settings(move.settings_range),
"settings_range_ai": cls._serialize_move_range_settings(move.settings_range_ai),
"base_pp": move.base_pp,
"ai_weight": move.ai_weight,
"miss_accuracy": move.miss_accuracy,
"accuracy": move.accuracy,
"ai_condition1_chance": move.ai_condition1_chance,
"number_chained_hits": move.number_chained_hits,
"max_upgrade_level": move.max_upgrade_level,
"crit_chance": move.crit_chance,
"affected_by_magic_coat": move.affected_by_magic_coat,
"is_snatchable": move.is_snatchable,
"uses_mouth": move.uses_mouth,
"ai_frozen_check": move.ai_frozen_check,
"ignores_taunted": move.ignores_taunted,
"range_check_text": move.range_check_text,
"move_id": move.move_id,
"message_id": move.message_id,
}
if strings is not None:
move_name_index = string_blocks["Move Names"].begin + move.move_id
if move_name_index <= string_blocks["Move Names"].end:
serialized_move["name"] = strings.strings[move_name_index]
moves_list.append(serialized_move)

return Asset(spec, None, None, None, None, bytes(json.dumps(moves_list, indent=4), "utf-8"))
elif spec.category == LEARNSETS:
learnsets = []
learnset: MoveLearnsetProtocol
for learnset in data.learnsets:
learnsets.append(
{
"level_up_moves": [
{
"move_id": level_up_move.move_id,
"level_id": level_up_move.level_id,
}
for level_up_move in learnset.level_up_moves
],
"tm_hm_moves": learnset.tm_hm_moves,
"egg_moves": learnset.egg_moves,
}
)

strings, string_blocks = cls._get_strings_from_rom(["Move Names", "Pokemon Names"], **kwargs)

def get_move_value(move_id: u32) -> str | u32:
if strings is None:
return move_id
else:
move_name_idx = string_blocks["Move Names"].begin + move_id
if move_name_idx <= string_blocks["Move Names"].end:
return strings.strings[move_name_idx]

for i, learnset in enumerate(data.learnsets):
serialized_learnset = {
"level_up_moves": [
{
"move_id": get_move_value(level_up_move.move_id),
"level_id": level_up_move.level_id,
}
for level_up_move in learnset.level_up_moves
],
"tm_hm_moves": [get_move_value(move_id) for move_id in learnset.tm_hm_moves],
"egg_moves": [get_move_value(move_id) for move_id in learnset.egg_moves],
}

if strings is not None:
serialized_learnset["pokemon_name"] = strings.strings[string_blocks["Pokemon Names"].begin + i]

learnsets.append(serialized_learnset)

return Asset(spec, None, None, None, None, bytes(json.dumps(learnsets, indent=4), "utf-8"))
else:
Expand All @@ -216,6 +241,9 @@ def deserialize_from_assets(
protocol: WazaPProtocol = cls.get_model_cls()(bytes(), 0)

assets_by_category = {asset.spec.category: asset for asset in assets}

# Keep track of move names/IDs to deserialize learnsets later.
move_names_to_ids: dict[str, u32] = {}
if MOVES in assets_by_category:
move_asset = assets_by_category[MOVES]
moves = json.loads(move_asset.data)
Expand Down Expand Up @@ -247,18 +275,31 @@ def deserialize_from_assets(
move.move_id = move_json["move_id"]
move.message_id = move_json["message_id"]

if "name" in move_json:
move_names_to_ids[move_json["name"]] = move_json["move_id"]

if LEARNSETS in assets_by_category:
learnset_asset = assets_by_category[LEARNSETS]
learnsets = json.loads(learnset_asset.data)
protocol.learnsets = []

def get_move_id(move_value: str | u32) -> u32:
if isinstance(move_value, str):
if move_value in move_names_to_ids:
return move_names_to_ids[move_value]
else:
raise KeyError(f"Invalid move name {move_value} in learnset.")

return move_value

for learnset_json in learnsets:
learnset = cls.get_learnset_model()(
[
cls.get_level_up_model()(level_up["move_id"], level_up["level_id"])
cls.get_level_up_model()(get_move_id(level_up["move_id"]), level_up["level_id"])
for level_up in learnset_json["level_up_moves"]
],
learnset_json["tm_hm_moves"],
learnset_json["egg_moves"],
[get_move_id(move_value) for move_value in learnset_json["tm_hm_moves"]],
[get_move_id(move_value) for move_value in learnset_json["egg_moves"]],
)
protocol.learnsets.append(learnset)

Expand All @@ -281,3 +322,25 @@ def _deserialize_move_range_settings(cls, settings_json: dict) -> WazaMoveRangeS
settings.condition = settings_json["condition"]
settings.unused = settings_json["unused"]
return settings

@staticmethod
def _get_strings_from_rom(
block_names: list[str], **kwargs: OptionalKwargs
) -> tuple[Str | None, dict[str, Pmd2StringBlock] | None]:
"""
Fetches blocks of strings from the ROM if it is specified in kwargs.
"""
if "rom_project" in kwargs:
rom_project = kwargs["rom_project"]
# Default to English strings (NA/EU). Use Japanese strings if this is the JP ROM.
try:
strings_file = rom_project.rom.getFileByName("MESSAGE/text_e.str")
except ValueError:
strings_file = rom_project.rom.getFileByName("MESSAGE/text_j.str")
strings = StrHandler.deserialize(strings_file)
string_blocks = {
block_name: rom_project.static_data.string_index_data.string_blocks[block_name]
for block_name in block_names
}
return strings, string_blocks
return None, None
2 changes: 1 addition & 1 deletion skytemple_files/transfer_asset_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def extract_rom_files_to_project(rom_path: Path, asset_dir: Path):
for file_path, data_handler in rom_files.items():
file_data = project.open_file(data_handler, file_path, force=True, load_from_rom=True)

project.save_file(data_handler, file_path, file_data, skip_save_to_rom=True)
project.save_file(data_handler, file_path, file_data, skip_save_to_rom=True, rom_project=project)


def save_project_to_rom(rom_path: Path, asset_dir: Path, extracted_rom_dir: Path):
Expand Down
14 changes: 12 additions & 2 deletions test/skytemple_files_test/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import sys
import typing
from abc import ABC
from pathlib import Path
from tempfile import TemporaryFile
from typing import Any, Generic, Mapping, Optional, Protocol, Type, TypeVar
from unittest import SkipTest

from skytemple_files.common.util import (
OptionalKwargs,
Expand All @@ -31,6 +33,7 @@
)
from skytemple_files_test.image import ImageTestCaseAbc

SKYTEMPLE_TEST_ROM_ENV = "SKYTEMPLE_TEST_ROM"
U = TypeVar("U")


Expand Down Expand Up @@ -106,8 +109,8 @@ def _outer_wrapper(wrapped_function):
from parameterized import parameterized

rom = None
if "SKYTEMPLE_TEST_ROM" in os.environ and os.environ["SKYTEMPLE_TEST_ROM"] != "":
rom = NintendoDSRom.fromFile(os.environ["SKYTEMPLE_TEST_ROM"])
if SKYTEMPLE_TEST_ROM_ENV in os.environ and os.environ[SKYTEMPLE_TEST_ROM_ENV] != "":
rom = NintendoDSRom.fromFile(os.environ[SKYTEMPLE_TEST_ROM_ENV])

if rom:

Expand Down Expand Up @@ -191,3 +194,10 @@ def dataset_name_func(testcase_func, _, param):
frame_locals[local_name] = local

return _outer_wrapper


def load_rom_path() -> Path:
if SKYTEMPLE_TEST_ROM_ENV in os.environ and os.environ[SKYTEMPLE_TEST_ROM_ENV] != "":
return Path(os.environ[SKYTEMPLE_TEST_ROM_ENV])
else:
raise SkipTest("No ROM file provided or ROM not found.")
19 changes: 9 additions & 10 deletions test/skytemple_files_test/common/file_api_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,12 @@
from skytemple_files.common.types.file_storage import Asset, AssetSpec
from skytemple_files.common.file_api_v2 import RomProject, SkyTempleProjectFileStorage, ALLOW_EXTRA_SKYPATCHES
from skytemple_files.common.types.file_types import FileType
from skytemple_files_test.case import load_rom_path

SKYTEMPLE_TEST_ROM_ENV = "SKYTEMPLE_TEST_ROM"
ASSET_PROJECT_PATH = Path("skytemple_files_test", "common", "fixtures", "asset_project")
ROM_COPY_PATH = Path("skytemple_files_test", "common", "fixtures", "rom_copy.nds")


def load_rom_path() -> Path:
if SKYTEMPLE_TEST_ROM_ENV in os.environ and os.environ[SKYTEMPLE_TEST_ROM_ENV] != "":
return Path(os.environ[SKYTEMPLE_TEST_ROM_ENV])
else:
raise SkipTest("No ROM file provided or ROM not found.")


def copy_rom_to_temp_file() -> Path:
"""
Copies the provided ROM to a temporary file for testing.
Expand Down Expand Up @@ -74,8 +67,14 @@ def test_save_file_extracted_rom_dir(self):
extracted_rom_dir = Path(ASSET_PROJECT_PATH, "extracted_rom")
expected_file_path = Path(extracted_rom_dir, rom_path)
try:
project.save_file(FileType.WAZA_P, rom_path, file_data,
skip_save_to_rom=True, skip_save_to_project_dir=True, extracted_rom_dir=extracted_rom_dir)
project.save_file(
FileType.WAZA_P,
rom_path,
file_data,
skip_save_to_rom=True,
skip_save_to_project_dir=True,
extracted_rom_dir=extracted_rom_dir,
)

self.assertTrue(expected_file_path.exists())
finally:
Expand Down
Loading

0 comments on commit e8315f7

Please sign in to comment.