From 5e6ae75b0cb1b4af28f8f639eba64b81f0eeb863 Mon Sep 17 00:00:00 2001 From: Jimmy Chu <898091+jimmychu0807@users.noreply.github.com> Date: Sun, 12 Jan 2025 17:23:11 +0800 Subject: [PATCH] test: added test on addMembers() (#17) * completed addMembers() * completed removeMember() --- src/SemaphoreMSAValidator.sol | 53 +++++++++++++++++--------------- test/SemaphoreMSAValidator.t.sol | 41 ++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 25 deletions(-) diff --git a/src/SemaphoreMSAValidator.sol b/src/SemaphoreMSAValidator.sol index c755199..36ae198 100644 --- a/src/SemaphoreMSAValidator.sol +++ b/src/SemaphoreMSAValidator.sol @@ -31,10 +31,10 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase { } // Errors - error CannotRemoveOwner(); - error InvalidCommitment(); - error InvalidThreshold(); - error MaxMemberReached(); + error MemberCntReachesThreshold(address account); + error InvalidCommitment(address account); + error InvalidThreshold(address account); + error MaxMemberReached(address account); error CommitmentsNotUnique(); error MemberNotExists(address account, uint256 cmt); error IsMemberAlready(address acount, uint256 cmt); @@ -54,7 +54,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase { // Events event ModuleInitialized(address indexed account); event ModuleUninitialized(address indexed account); - event AddedMember(address indexed, uint256 indexed commitment); + event AddedMembers(address indexed, uint256 indexed length); event RemovedMember(address indexed, uint256 indexed commitment); event ThresholdSet(address indexed account, uint8 indexed threshold); event InitiatedTx(address indexed account, uint256 indexed seq, bytes32 indexed txHash); @@ -116,14 +116,14 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase { uint256[] memory cmts = cmtBytes.convertToCmts(); // Check the relation between threshold and ownersLen are valid - if (cmts.length > MAX_MEMBERS) revert MaxMemberReached(); - if (cmts.length < threshold) revert InvalidThreshold(); + if (cmts.length > MAX_MEMBERS) revert MaxMemberReached(account); + if (cmts.length < threshold) revert InvalidThreshold(account); // Check no duplicate commitment and no `0` cmts.insertionSort(); if (!cmts.isSortedAndUniquified()) revert CommitmentsNotUnique(); (bool found,) = cmts.searchSorted(uint256(0)); - if (found) revert InvalidCommitment(); + if (found) revert InvalidCommitment(account); // Completed all checks by this point. Write to the storage. thresholds[account] = threshold; @@ -159,41 +159,44 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase { function setThreshold(uint8 newThreshold) external moduleInstalled { address account = msg.sender; - if (newThreshold == 0 || newThreshold > memberCount(account)) revert InvalidThreshold(); + if (newThreshold == 0 || newThreshold > memberCount(account)) { + revert InvalidThreshold(account); + } thresholds[account] = newThreshold; emit ThresholdSet(account, newThreshold); } - function addMember(uint256 cmt) external moduleInstalled { + function addMembers(uint256[] calldata cmts) external moduleInstalled { address account = msg.sender; - // 0. check the module is initialized for the acct - // 1. check newOwner != 0 - // 2. check ownerCount < MAX_MEMBERS - // 3. cehck owner not existed yet - if (cmt == uint256(0)) revert InvalidCommitment(); - if (memberCount(account) == MAX_MEMBERS) revert MaxMemberReached(); - uint256 groupId = groupMapping[account]; - if (groups.hasMember(groupId, cmt)) revert IsMemberAlready(account, cmt); + if (memberCount(account) + cmts.length > MAX_MEMBERS) revert MaxMemberReached(account); - semaphore.addMember(groupId, cmt); + for (uint256 i = 0; i < cmts.length; ++i) { + if (cmts[i] == uint256(0)) revert InvalidCommitment(account); + if (groups.hasMember(groupId, cmts[i])) revert IsMemberAlready(account, cmts[i]); + } - emit AddedMember(account, cmt); + semaphore.addMembers(groupId, cmts); + emit AddedMembers(account, cmts.length); } - function removeMember(uint256 cmt) external moduleInstalled { + function removeMember( + uint256 cmt, + uint256[] calldata merkleProofSiblings + ) + external + moduleInstalled + { address account = msg.sender; - if (memberCount(account) == thresholds[account]) revert CannotRemoveOwner(); + if (memberCount(account) == thresholds[account]) revert MemberCntReachesThreshold(account); uint256 groupId = groupMapping[account]; if (!groups.hasMember(groupId, cmt)) revert MemberNotExists(account, cmt); - //TODO: add the 3rd param: merkleProofSiblings. Now I set it to 0 to make it passes the - // compiler - semaphore.removeMember(groupId, cmt, new uint256[](0)); + semaphore.removeMember(groupId, cmt, merkleProofSiblings); emit RemovedMember(account, cmt); } diff --git a/test/SemaphoreMSAValidator.t.sol b/test/SemaphoreMSAValidator.t.sol index 6d0ae30..25a892e 100644 --- a/test/SemaphoreMSAValidator.t.sol +++ b/test/SemaphoreMSAValidator.t.sol @@ -209,6 +209,47 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test { assertEq(validationData, VALIDATION_SUCCESS); } + function test_addMembers() public setupSmartAcctWithMembersThreshold(1, 1) { + Identity newIdentity = $users[1].identity; + uint256 newCommitment = newIdentity.commitment(); + + // Compose the userOp + PackedUserOperation memory userOp = getEmptyUserOperation(); + userOp.sender = smartAcct.account; + userOp.callData = getTestUserOpCallData( + 0, + address(semaphoreValidator), + abi.encodeWithSelector(SemaphoreMSAValidator.initiateTx.selector) + ); + bytes32 userOpHash = bytes32(keccak256("userOpHash")); + userOp.signature = newIdentity.signHash(userOpHash); + + // expecting the vm to revert + vm.expectRevert( + abi.encodeWithSelector( + SemaphoreMSAValidator.MemberNotExists.selector, smartAcct.account, newCommitment + ) + ); + semaphoreValidator.validateUserOp(userOp, userOpHash); + + // Now we add the new member + uint256[] memory newMembers = new uint256[](1); + newMembers[0] = newCommitment; + + vm.prank(smartAcct.account); + semaphoreValidator.addMembers(newMembers); + + // Test: the userOp should pass + uint256 validationData = ERC7579ValidatorBase.ValidationData.unwrap( + semaphoreValidator.validateUserOp(userOp, userOpHash) + ); + assertEq(validationData, VALIDATION_SUCCESS); + } + + function test_removeMember() public setupSmartAcctWithMembersThreshold(2, 1) { + revert("to be implemented"); + } + function _getSemaphoreValidatorUserOpData( Identity id, bytes memory callData,