diff --git a/crates/gateway/src/compilation_test.rs b/crates/gateway/src/compilation_test.rs index 9ebe4ab4e42..162e04f9f5e 100644 --- a/crates/gateway/src/compilation_test.rs +++ b/crates/gateway/src/compilation_test.rs @@ -1,10 +1,14 @@ use assert_matches::assert_matches; use blockifier::execution::contract_class::ContractClass; use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError; -use mempool_test_utils::starknet_api_test_utils::declare_tx; +use mempool_test_utils::starknet_api_test_utils::declare_tx as rpc_declare_tx; use rstest::{fixture, rstest}; use starknet_api::core::CompiledClassHash; -use starknet_api::rpc_transaction::{RpcDeclareTransaction, RpcTransaction}; +use starknet_api::rpc_transaction::{ + RpcDeclareTransaction, + RpcDeclareTransactionV3, + RpcTransaction, +}; use starknet_sierra_compile::errors::CompilationUtilError; use crate::compilation::GatewayCompiler; @@ -15,17 +19,24 @@ fn gateway_compiler() -> GatewayCompiler { GatewayCompiler { config: Default::default() } } +#[fixture] +fn declare_tx_v3() -> RpcDeclareTransactionV3 { + assert_matches!( + rpc_declare_tx(), + RpcTransaction::Declare(RpcDeclareTransaction::V3(declare_tx)) => declare_tx + ) +} + // TODO(Arni): Redesign this test once the compiler is passed with dependancy injection. #[rstest] -fn test_compile_contract_class_compiled_class_hash_mismatch(gateway_compiler: GatewayCompiler) { - let mut tx = assert_matches!( - declare_tx(), - RpcTransaction::Declare(RpcDeclareTransaction::V3(tx)) => tx - ); - let expected_hash = tx.compiled_class_hash; +fn test_compile_contract_class_compiled_class_hash_mismatch( + gateway_compiler: GatewayCompiler, + mut declare_tx_v3: RpcDeclareTransactionV3, +) { + let expected_hash = declare_tx_v3.compiled_class_hash; let wrong_supplied_hash = CompiledClassHash::default(); - tx.compiled_class_hash = wrong_supplied_hash; - let declare_tx = RpcDeclareTransaction::V3(tx); + declare_tx_v3.compiled_class_hash = wrong_supplied_hash; + let declare_tx = RpcDeclareTransaction::V3(declare_tx_v3); let result = gateway_compiler.process_declare_tx(&declare_tx); assert_matches!( @@ -36,14 +47,14 @@ fn test_compile_contract_class_compiled_class_hash_mismatch(gateway_compiler: Ga } #[rstest] -fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) { - let mut tx = assert_matches!( - declare_tx(), - RpcTransaction::Declare(RpcDeclareTransaction::V3(tx)) => tx - ); +fn test_compile_contract_class_bad_sierra( + gateway_compiler: GatewayCompiler, + mut declare_tx_v3: RpcDeclareTransactionV3, +) { // Truncate the sierra program to trigger an error. - tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec(); - let declare_tx = RpcDeclareTransaction::V3(tx); + declare_tx_v3.contract_class.sierra_program = + declare_tx_v3.contract_class.sierra_program[..100].to_vec(); + let declare_tx = RpcDeclareTransaction::V3(declare_tx_v3); let result = gateway_compiler.process_declare_tx(&declare_tx); assert_matches!( @@ -55,16 +66,17 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) { } #[rstest] -fn test_process_declare_tx_success(gateway_compiler: GatewayCompiler) { - let declare_tx = assert_matches!( - declare_tx(), - RpcTransaction::Declare(declare_tx) => declare_tx - ); - let RpcDeclareTransaction::V3(declare_tx_v3) = &declare_tx; +fn test_process_declare_tx_success( + gateway_compiler: GatewayCompiler, + declare_tx_v3: RpcDeclareTransactionV3, +) { let contract_class = &declare_tx_v3.contract_class; + let sierra_program_length = contract_class.sierra_program.len(); + let abi_length = contract_class.abi.len(); + let declare_tx = RpcDeclareTransaction::V3(declare_tx_v3); let class_info = gateway_compiler.process_declare_tx(&declare_tx).unwrap(); assert_matches!(class_info.contract_class(), ContractClass::V1(_)); - assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len()); - assert_eq!(class_info.abi_length(), contract_class.abi.len()); + assert_eq!(class_info.sierra_program_length(), sierra_program_length); + assert_eq!(class_info.abi_length(), abi_length); }