Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
AnonymousRandomPerson committed Aug 16, 2024
1 parent 9010b39 commit 74ca388
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 18 deletions.
21 changes: 13 additions & 8 deletions skytemple_files/common/file_api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,12 @@ def get_files_in_folder(folder_path, folder: Folder) -> list[Path]:
folder_files.extend(get_files_in_folder(Path(folder_path, subfolder[0]), subfolder[1]))
return folder_files

rom_files = get_files_in_folder(Path(), self.rom.filenames)
if self.rom is None:
rom_files = []
else:
rom_files = get_files_in_folder(Path(), self.rom.filenames)

data_handler: type[DataHandler[T]]
for data_handler in self._load_data_handlers():
for rom_file in rom_files:
asset_specs = data_handler.asset_specs(rom_file)
Expand Down Expand Up @@ -281,7 +285,7 @@ def open_file(
if assets is None and not load_from_rom:
assets = extract_assets(self.load_assets(handler, path_to_rom_obj))

if load_from_rom or len(assets) < 1:
if load_from_rom or assets is None or len(assets) < 1:
# Force ROM deserialization if no assets exist.
return handler.deserialize(self._file_storage.get_from_rom(path_to_rom_obj), **kwargs)

Expand All @@ -304,7 +308,7 @@ def save_file(
*,
skip_save_to_rom: bool = False,
skip_save_to_project_dir: bool = False,
extracted_rom_dir: Path = None,
extracted_rom_dir: Path | None = None,
**kwargs: OptionalKwargs,
):
"""
Expand Down Expand Up @@ -385,7 +389,8 @@ def _load_data_handlers() -> list[type[DataHandler[T]]]:
]

def _enrich_static_data(self):
RomDataLoader(self.rom).load_into(self.static_data)
if self.rom is not None:
RomDataLoader(self.rom).load_into(self.static_data)

def _load_extra_skypatches(self):
raise NotImplementedError()
Expand All @@ -410,8 +415,8 @@ class SkyTempleProjectFileStorage(FileStorage):
rom_path: Path
project_dir: Path
rom: NintendoDSRom
rom_hashes: dict[Path, AssetHash]
asset_hashes: dict[Path, AssetHash]
rom_hashes: dict[Path, AssetHash | None]
asset_hashes: dict[Path, AssetHash | None]

def __init__(self, rom_path: Path, project_dir: Path):
self.rom_path = rom_path
Expand Down Expand Up @@ -484,7 +489,7 @@ def _save_asset_hash(self, path: Path, data: bytes):
self.asset_hashes[path] = self.hash_from_bytes(data)
self._save_hash_file(ASSET_HASHES_FILE, self.asset_hashes)

def _read_hash_file(self, hash_file_name: str) -> dict[Path, AssetHash]:
def _read_hash_file(self, hash_file_name: str) -> dict[Path, AssetHash | None]:
hashes: dict[Path, AssetHash | None] = defaultdict(lambda: None)
hash_file_path = Path(self.project_dir, hash_file_name)
if hash_file_path.exists():
Expand All @@ -497,7 +502,7 @@ def _read_hash_file(self, hash_file_name: str) -> dict[Path, AssetHash]:
print(f"Malformed hash file {hash_file_name} detected. Skipping line: {line}")
return hashes

def _save_hash_file(self, hash_file_name: str, hashes: dict[Path, AssetHash]):
def _save_hash_file(self, hash_file_name: str, hashes: dict[Path, AssetHash | None]):
lines = []
for file_name in sorted(hashes.keys()):
lines.append(f"{hashes[file_name]} {file_name}\n")
Expand Down
4 changes: 3 additions & 1 deletion skytemple_files/common/types/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def get_asset_from_spec(self, spec: AssetSpec) -> Asset:
return asset

@abc.abstractmethod
def store_asset(self, path: Path, for_rom_path: Path, data_asset: bytes, custom_project_dir: Path = None) -> bytes:
def store_asset(
self, path: Path, for_rom_path: Path, data_asset: bytes, custom_project_dir: Path | None = None
) -> bytes:
"""Store an asset file."""
...

Expand Down
2 changes: 1 addition & 1 deletion skytemple_files/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ def serialize_enum_or_default(enum_class: type[Enum], value: int) -> str | int:
return value


def deserialize_enum_or_default(enum_class: type[Enum], value: str | int) -> Enum | int:
def deserialize_enum_or_default(enum_class: type[Enum], value: str | int) -> int:
"""
If the value is a string, returns the corresponding enum value.
If the value is an integer, returns the integer back.
Expand Down
19 changes: 11 additions & 8 deletions skytemple_files/data/waza_p/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import TYPE_CHECKING, Sequence
import json

from range_typed_integers import u32
from range_typed_integers import u8, u32, u16

from skytemple_files.common.impl_cfg import get_implementation_type, ImplementationType
from skytemple_files.common.ppmdu_config.data import Pmd2StringBlock
Expand All @@ -41,6 +41,8 @@
WazaMoveRangeSettingsProtocol,
MoveLearnsetProtocol,
WazaMoveCategory,
_PokeType,
_WazaMoveCategory,
)


Expand Down Expand Up @@ -189,7 +191,7 @@ def serialize_asset(
"move_id": move.move_id,
"message_id": move.message_id,
}
if strings is not None:
if strings is not None and string_blocks is not None:
move_name_index = string_blocks["Move Names"].begin + move.move_id
if move_name_index <= string_blocks["Move Names"].end:
serialized_move["name"] = strings.strings[move_name_index]
Expand All @@ -203,12 +205,13 @@ def serialize_asset(
strings, string_blocks = cls._get_strings_from_rom(["Move Names", "Pokemon Names"], **kwargs)

def get_move_value(move_id: u32) -> str | u32:
if strings is None:
if strings is None or string_blocks is None:
return move_id
else:
move_name_idx = string_blocks["Move Names"].begin + move_id
if move_name_idx <= string_blocks["Move Names"].end:
return strings.strings[move_name_idx]
return move_id

for i, learnset in enumerate(data.learnsets):
serialized_learnset = {
Expand All @@ -223,7 +226,7 @@ def get_move_value(move_id: u32) -> str | u32:
"egg_moves": [get_move_value(move_id) for move_id in learnset.egg_moves],
}

if strings is not None:
if strings is not None and string_blocks is not None:
serialized_learnset["pokemon_name"] = strings.strings[string_blocks["Pokemon Names"].begin + i]

learnsets.append(serialized_learnset)
Expand Down Expand Up @@ -253,8 +256,8 @@ def deserialize_from_assets(
protocol.moves.append(move)

move.base_power = move_json["base_power"]
move.type = deserialize_enum_or_default(PokeType, move_json["type"])
move.category = deserialize_enum_or_default(WazaMoveCategory, move_json["category"])
move.type = _PokeType(deserialize_enum_or_default(PokeType, move_json["type"]))
move.category = _WazaMoveCategory(deserialize_enum_or_default(WazaMoveCategory, move_json["category"]))
move.settings_range = cls.get_range_settings_model()(bytes())
move.settings_range = cls._deserialize_move_range_settings(move_json["settings_range"])
move.settings_range_ai = cls._deserialize_move_range_settings(move_json["settings_range_ai"])
Expand Down Expand Up @@ -295,7 +298,7 @@ def get_move_id(move_value: str | u32) -> u32:
for learnset_json in learnsets:
learnset = cls.get_learnset_model()(
[
cls.get_level_up_model()(get_move_id(level_up["move_id"]), level_up["level_id"])
cls.get_level_up_model()(u16(get_move_id(level_up["move_id"])), level_up["level_id"])
for level_up in learnset_json["level_up_moves"]
],
[get_move_id(move_value) for move_value in learnset_json["tm_hm_moves"]],
Expand Down Expand Up @@ -330,7 +333,7 @@ def _get_strings_from_rom(
"""
Fetches blocks of strings from the ROM if it is specified in kwargs.
"""
if "rom_project" in kwargs:
if "rom_project" in kwargs and kwargs["rom_project"] is not None:
rom_project = kwargs["rom_project"]
# Default to English strings (NA/EU). Use Japanese strings if this is the JP ROM.
try:
Expand Down

0 comments on commit 74ca388

Please sign in to comment.