diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py index a281e2d18..57cf63ca3 100644 --- a/tcn_hpl/models/ptg_module.py +++ b/tcn_hpl/models/ptg_module.py @@ -100,12 +100,13 @@ def __init__( # Get Action Names mapping_file = f"{self.hparams.data_dir}/{mapping_file_name}" - file_ptr = open(mapping_file, "r") - actions = file_ptr.read().split("\n")[:-1] - file_ptr.close() actions_dict = dict() - for a in actions: - actions_dict[a.split()[1]] = int(a.split()[0]) + with open(mapping_file, "r") as file_ptr: + actions = file_ptr.readlines() + actions = [a.strip() for a in actions] # drop leading/trailing whitespace + for a in actions: + parts = a.split() # split on any number of whitespace + actions_dict[parts[1]] = int(parts[0]) self.class_ids = list(actions_dict.values()) self.classes = list(actions_dict.keys())