diff --git a/src/block_hash/state_diff_hash.rs b/src/block_hash/state_diff_hash.rs index b3cfd19..b868966 100644 --- a/src/block_hash/state_diff_hash.rs +++ b/src/block_hash/state_diff_hash.rs @@ -39,8 +39,8 @@ fn chain_deployed_contracts( mut hash_chain: HashChain, ) -> HashChain { hash_chain = hash_chain.chain(&deployed_contracts.len().into()); - for (address, class_hash) in deployed_contracts { - hash_chain = hash_chain.chain(address).chain(class_hash); + for (address, class_hash) in sorted_index_map(deployed_contracts) { + hash_chain = hash_chain.chain(&address.0).chain(&class_hash); } hash_chain } @@ -52,8 +52,8 @@ fn chain_declared_classes( mut hash_chain: HashChain, ) -> HashChain { hash_chain = hash_chain.chain(&declared_classes.len().into()); - for (class_hash, compiled_class_hash) in declared_classes { - hash_chain = hash_chain.chain(class_hash).chain(&compiled_class_hash.0) + for (class_hash, compiled_class_hash) in sorted_index_map(declared_classes) { + hash_chain = hash_chain.chain(&class_hash).chain(&compiled_class_hash.0) } hash_chain } @@ -63,9 +63,11 @@ fn chain_deprecated_declared_classes( deprecated_declared_classes: &[ClassHash], hash_chain: HashChain, ) -> HashChain { + let mut sorted_deprecated_declared_classes = deprecated_declared_classes.to_vec(); + sorted_deprecated_declared_classes.sort_unstable(); hash_chain - .chain(&deprecated_declared_classes.len().into()) - .chain_iter(deprecated_declared_classes.iter().map(|class_hash| &class_hash.0)) + .chain(&sorted_deprecated_declared_classes.len().into()) + .chain_iter(sorted_deprecated_declared_classes.iter().map(|class_hash| &class_hash.0)) } // Chains: [number_of_updated_contracts, @@ -77,11 +79,11 @@ fn chain_storage_diffs( mut hash_chain: HashChain, ) -> HashChain { hash_chain = hash_chain.chain(&storage_diffs.len().into()); - for (contract_address, key_value_map) in storage_diffs { - hash_chain = hash_chain.chain(contract_address); + for (contract_address, key_value_map) in sorted_index_map(storage_diffs) { + hash_chain = hash_chain.chain(&contract_address); hash_chain = hash_chain.chain(&key_value_map.len().into()); - for (key, value) in key_value_map { - hash_chain = hash_chain.chain(key).chain(value); + for (key, value) in sorted_index_map(&key_value_map) { + hash_chain = hash_chain.chain(&key).chain(&value); } } hash_chain @@ -92,9 +94,16 @@ fn chain_storage_diffs( // ] fn chain_nonces(nonces: &IndexMap, mut hash_chain: HashChain) -> HashChain { hash_chain = hash_chain.chain(&nonces.len().into()); - for (contract_address, nonce) in nonces { - hash_chain = hash_chain.chain(contract_address); - hash_chain = hash_chain.chain(nonce); + for (contract_address, nonce) in sorted_index_map(nonces) { + hash_chain = hash_chain.chain(&contract_address); + hash_chain = hash_chain.chain(&nonce); } hash_chain } + +// Returns a clone of the map, sorted by keys. +fn sorted_index_map(map: &IndexMap) -> IndexMap { + let mut sorted_map = map.clone(); + sorted_map.sort_unstable_keys(); + sorted_map +} diff --git a/src/block_hash/state_diff_hash_test.rs b/src/block_hash/state_diff_hash_test.rs index a938a16..9bf795f 100644 --- a/src/block_hash/state_diff_hash_test.rs +++ b/src/block_hash/state_diff_hash_test.rs @@ -1,7 +1,11 @@ use indexmap::indexmap; -use crate::block_hash::state_diff_hash::calculate_state_diff_hash; +use crate::block_hash::state_diff_hash::{ + calculate_state_diff_hash, chain_declared_classes, chain_deployed_contracts, + chain_deprecated_declared_classes, chain_nonces, chain_storage_diffs, +}; use crate::core::{ClassHash, CompiledClassHash, Nonce, StateDiffCommitment}; +use crate::crypto::utils::HashChain; use crate::hash::{PoseidonHash, StarkFelt}; use crate::state::ThinStateDiff; @@ -41,3 +45,89 @@ fn test_state_diff_hash_regression() { assert_eq!(expected_hash, calculate_state_diff_hash(&state_diff)); } + +#[test] +fn test_sorting_deployed_contracts() { + let deployed_contracts_0 = indexmap! { + 0u64.into() => ClassHash(3u64.into()), + 1u64.into() => ClassHash(2u64.into()), + }; + let deployed_contracts_1 = indexmap! { + 1u64.into() => ClassHash(2u64.into()), + 0u64.into() => ClassHash(3u64.into()), + }; + assert_eq!( + chain_deployed_contracts(&deployed_contracts_0, HashChain::new()).get_poseidon_hash(), + chain_deployed_contracts(&deployed_contracts_1, HashChain::new()).get_poseidon_hash(), + ); +} + +#[test] +fn test_sorting_declared_classes() { + let declared_classes_0 = indexmap! { + ClassHash(0u64.into()) => CompiledClassHash(3u64.into()), + ClassHash(1u64.into()) => CompiledClassHash(2u64.into()), + }; + let declared_classes_1 = indexmap! { + ClassHash(1u64.into()) => CompiledClassHash(2u64.into()), + ClassHash(0u64.into()) => CompiledClassHash(3u64.into()), + }; + assert_eq!( + chain_declared_classes(&declared_classes_0, HashChain::new()).get_poseidon_hash(), + chain_declared_classes(&declared_classes_1, HashChain::new()).get_poseidon_hash(), + ); +} + +#[test] +fn test_sorting_deprecated_declared_classes() { + let deprecated_declared_classes_0 = vec![ClassHash(0u64.into()), ClassHash(1u64.into())]; + let deprecated_declared_classes_1 = vec![ClassHash(1u64.into()), ClassHash(0u64.into())]; + assert_eq!( + chain_deprecated_declared_classes(&deprecated_declared_classes_0, HashChain::new()) + .get_poseidon_hash(), + chain_deprecated_declared_classes(&deprecated_declared_classes_1, HashChain::new()) + .get_poseidon_hash(), + ); +} + +#[test] +fn test_sorting_storage_diffs() { + let storage_diffs_0 = indexmap! { + 0u64.into() => indexmap! { + 1u64.into() => 2u64.into(), + 3u64.into() => 4u64.into(), + }, + 5u64.into() => indexmap! { + 6u64.into() => 7u64.into(), + }, + }; + let storage_diffs_1 = indexmap! { + 5u64.into() => indexmap! { + 6u64.into() => 7u64.into(), + }, + 0u64.into() => indexmap! { + 3u64.into() => 4u64.into(), + 1u64.into() => 2u64.into(), + }, + }; + assert_eq!( + chain_storage_diffs(&storage_diffs_0, HashChain::new()).get_poseidon_hash(), + chain_storage_diffs(&storage_diffs_1, HashChain::new()).get_poseidon_hash(), + ); +} + +#[test] +fn test_sorting_nonces() { + let nonces_0 = indexmap! { + 0u64.into() => Nonce(3u64.into()), + 1u64.into() => Nonce(2u64.into()), + }; + let nonces_1 = indexmap! { + 1u64.into() => Nonce(2u64.into()), + 0u64.into() => Nonce(3u64.into()), + }; + assert_eq!( + chain_nonces(&nonces_0, HashChain::new()).get_poseidon_hash(), + chain_nonces(&nonces_1, HashChain::new()).get_poseidon_hash(), + ); +}