diff --git a/client/src/ledger_app_clients/ethereum/eip712/InputData.py b/client/src/ledger_app_clients/ethereum/eip712/InputData.py index 966c47b6c..136de2852 100644 --- a/client/src/ledger_app_clients/ethereum/eip712/InputData.py +++ b/client/src/ledger_app_clients/ethereum/eip712/InputData.py @@ -17,6 +17,7 @@ # global variables app_client: EthAppClient = None filtering_paths: dict = {} +filtering_tokens: list[dict] = list() current_path: list[str] = list() sig_ctx: dict[str, Any] = {} @@ -194,6 +195,18 @@ def encode_bytes_dyn(value: str, typesize: int) -> bytes: encoding_functions[EIP712FieldType.DYN_BYTES] = encode_bytes_dyn +def send_filtering_token(token_idx: int): + assert token_idx < len(filtering_tokens) + if len(filtering_tokens[token_idx]) > 0: + token = filtering_tokens[token_idx] + if not token["sent"]: + app_client.provide_token_metadata(token["ticker"], + bytes.fromhex(token["addr"][2:]), + token["decimals"], + token["chain_id"]) + token["sent"] = True + + def send_struct_impl_field(value, field): # Something wrong happened if this triggers if isinstance(value, list) or (field["enum"] == EIP712FieldType.CUSTOM): @@ -204,16 +217,19 @@ def send_struct_impl_field(value, field): if filtering_paths: path = ".".join(current_path) if path in filtering_paths.keys(): - if filtering_paths[path]["type"] == "amount_join_token": - send_filtering_amount_join_token(filtering_paths[path]["token"]) - elif filtering_paths[path]["type"] == "amount_join_value": + if filtering_paths[path]["type"].startswith("amount_join_"): if "token" in filtering_paths[path].keys(): - token = filtering_paths[path]["token"] + token_idx = filtering_paths[path]["token"] + send_filtering_token(token_idx) else: # Permit (ERC-2612) - token = 0xff - send_filtering_amount_join_value(token, - filtering_paths[path]["name"]) + send_filtering_token(0) + token_idx = 0xff + if filtering_paths[path]["type"].endswith("_token"): + send_filtering_amount_join_token(token_idx) + elif filtering_paths[path]["type"].endswith("_value"): + send_filtering_amount_join_value(token_idx, + filtering_paths[path]["name"]) elif filtering_paths[path]["type"] == "datetime": send_filtering_datetime(filtering_paths[path]["name"]) elif filtering_paths[path]["type"] == "raw": @@ -351,17 +367,20 @@ def send_filtering_raw(display_name): def prepare_filtering(filtr_data, message): global filtering_paths + global filtering_tokens if "fields" in filtr_data: filtering_paths = filtr_data["fields"] else: filtering_paths = {} + if "tokens" in filtr_data: - for token in filtr_data["tokens"]: - app_client.provide_token_metadata(token["ticker"], - bytes.fromhex(token["addr"][2:]), - token["decimals"], - token["chain_id"]) + filtering_tokens = filtr_data["tokens"] + for token in filtering_tokens: + if len(token) > 0: + token["sent"] = False + else: + filtering_tokens = [] def handle_optional_domain_values(domain):