Skip to content

Commit

Permalink
test: added test on addMembers() (#17)
Browse files Browse the repository at this point in the history
* completed addMembers()

* completed removeMember()
  • Loading branch information
jimmychu0807 authored Jan 12, 2025
1 parent 903681d commit 5e6ae75
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 25 deletions.
53 changes: 28 additions & 25 deletions src/SemaphoreMSAValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
}

// Errors
error CannotRemoveOwner();
error InvalidCommitment();
error InvalidThreshold();
error MaxMemberReached();
error MemberCntReachesThreshold(address account);
error InvalidCommitment(address account);
error InvalidThreshold(address account);
error MaxMemberReached(address account);
error CommitmentsNotUnique();
error MemberNotExists(address account, uint256 cmt);
error IsMemberAlready(address acount, uint256 cmt);
Expand All @@ -54,7 +54,7 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
// Events
event ModuleInitialized(address indexed account);
event ModuleUninitialized(address indexed account);
event AddedMember(address indexed, uint256 indexed commitment);
event AddedMembers(address indexed, uint256 indexed length);
event RemovedMember(address indexed, uint256 indexed commitment);
event ThresholdSet(address indexed account, uint8 indexed threshold);
event InitiatedTx(address indexed account, uint256 indexed seq, bytes32 indexed txHash);
Expand Down Expand Up @@ -116,14 +116,14 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {
uint256[] memory cmts = cmtBytes.convertToCmts();

// Check the relation between threshold and ownersLen are valid
if (cmts.length > MAX_MEMBERS) revert MaxMemberReached();
if (cmts.length < threshold) revert InvalidThreshold();
if (cmts.length > MAX_MEMBERS) revert MaxMemberReached(account);
if (cmts.length < threshold) revert InvalidThreshold(account);

// Check no duplicate commitment and no `0`
cmts.insertionSort();
if (!cmts.isSortedAndUniquified()) revert CommitmentsNotUnique();
(bool found,) = cmts.searchSorted(uint256(0));
if (found) revert InvalidCommitment();
if (found) revert InvalidCommitment(account);

// Completed all checks by this point. Write to the storage.
thresholds[account] = threshold;
Expand Down Expand Up @@ -159,41 +159,44 @@ contract SemaphoreMSAValidator is ERC7579ValidatorBase {

function setThreshold(uint8 newThreshold) external moduleInstalled {
address account = msg.sender;
if (newThreshold == 0 || newThreshold > memberCount(account)) revert InvalidThreshold();
if (newThreshold == 0 || newThreshold > memberCount(account)) {
revert InvalidThreshold(account);
}

thresholds[account] = newThreshold;
emit ThresholdSet(account, newThreshold);
}

function addMember(uint256 cmt) external moduleInstalled {
function addMembers(uint256[] calldata cmts) external moduleInstalled {
address account = msg.sender;
// 0. check the module is initialized for the acct
// 1. check newOwner != 0
// 2. check ownerCount < MAX_MEMBERS
// 3. cehck owner not existed yet
if (cmt == uint256(0)) revert InvalidCommitment();
if (memberCount(account) == MAX_MEMBERS) revert MaxMemberReached();

uint256 groupId = groupMapping[account];

if (groups.hasMember(groupId, cmt)) revert IsMemberAlready(account, cmt);
if (memberCount(account) + cmts.length > MAX_MEMBERS) revert MaxMemberReached(account);

semaphore.addMember(groupId, cmt);
for (uint256 i = 0; i < cmts.length; ++i) {
if (cmts[i] == uint256(0)) revert InvalidCommitment(account);
if (groups.hasMember(groupId, cmts[i])) revert IsMemberAlready(account, cmts[i]);
}

emit AddedMember(account, cmt);
semaphore.addMembers(groupId, cmts);
emit AddedMembers(account, cmts.length);
}

function removeMember(uint256 cmt) external moduleInstalled {
function removeMember(
uint256 cmt,
uint256[] calldata merkleProofSiblings
)
external
moduleInstalled
{
address account = msg.sender;

if (memberCount(account) == thresholds[account]) revert CannotRemoveOwner();
if (memberCount(account) == thresholds[account]) revert MemberCntReachesThreshold(account);

uint256 groupId = groupMapping[account];
if (!groups.hasMember(groupId, cmt)) revert MemberNotExists(account, cmt);

//TODO: add the 3rd param: merkleProofSiblings. Now I set it to 0 to make it passes the
// compiler
semaphore.removeMember(groupId, cmt, new uint256[](0));
semaphore.removeMember(groupId, cmt, merkleProofSiblings);

emit RemovedMember(account, cmt);
}
Expand Down
41 changes: 41 additions & 0 deletions test/SemaphoreMSAValidator.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,47 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test {
assertEq(validationData, VALIDATION_SUCCESS);
}

function test_addMembers() public setupSmartAcctWithMembersThreshold(1, 1) {
Identity newIdentity = $users[1].identity;
uint256 newCommitment = newIdentity.commitment();

// Compose the userOp
PackedUserOperation memory userOp = getEmptyUserOperation();
userOp.sender = smartAcct.account;
userOp.callData = getTestUserOpCallData(
0,
address(semaphoreValidator),
abi.encodeWithSelector(SemaphoreMSAValidator.initiateTx.selector)
);
bytes32 userOpHash = bytes32(keccak256("userOpHash"));
userOp.signature = newIdentity.signHash(userOpHash);

// expecting the vm to revert
vm.expectRevert(
abi.encodeWithSelector(
SemaphoreMSAValidator.MemberNotExists.selector, smartAcct.account, newCommitment
)
);
semaphoreValidator.validateUserOp(userOp, userOpHash);

// Now we add the new member
uint256[] memory newMembers = new uint256[](1);
newMembers[0] = newCommitment;

vm.prank(smartAcct.account);
semaphoreValidator.addMembers(newMembers);

// Test: the userOp should pass
uint256 validationData = ERC7579ValidatorBase.ValidationData.unwrap(
semaphoreValidator.validateUserOp(userOp, userOpHash)
);
assertEq(validationData, VALIDATION_SUCCESS);
}

function test_removeMember() public setupSmartAcctWithMembersThreshold(2, 1) {
revert("to be implemented");

Check warning on line 250 in test/SemaphoreMSAValidator.t.sol

View workflow job for this annotation

GitHub Actions / lint

GC: Use Custom Errors instead of revert statements
}

function _getSemaphoreValidatorUserOpData(
Identity id,
bytes memory callData,
Expand Down

0 comments on commit 5e6ae75

Please sign in to comment.