Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(wip): completed removeMember() and test case #18

Merged
merged 2 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
"devDependencies": {
"@rhinestone/modulekit": "~0.5.4",
"@semaphore-protocol/contracts": "github:jimmychu0807/semaphore#identity-cli&path:/packages/contracts/contracts",
"@semaphore-protocol/core": "github:jimmychu0807/semaphore#identity-cli&path:/packages/core",
"@semaphore-protocol/identity": "github:jimmychu0807/semaphore#identity-cli&path:/packages/identity",
"@semaphore-protocol/proof": "github:jimmychu0807/semaphore#identity-cli&path:/packages/proof",
"@semaphore-protocol/group": "github:jimmychu0807/semaphore#identity-cli&path:/packages/group",
"poseidon-solidity": "github:chancehudson/poseidon-solidity#main",
"rimraf": "^5.0.5",
"solady": "^0.0.287"
Expand All @@ -47,7 +47,7 @@
"prepack": "pnpm install && bash ./shell/prepare-artifacts.sh",
"prettier:check": "prettier --no-error-on-unmatched-pattern -c \"{src,test,script}/**/*.{json,md,svg,yml}\"",
"prettier:write": "prettier --no-error-on-unmatched-pattern -w \"{src,test,script}/**/*.{json,md,svg,yml}\"",
"test": "COMPLIANCE=true forge test --ffi",
"test": "forge test --ffi",
"test:lite": "FOUNDRY_PROFILE=lite forge test",
"test:optimized": "pnpm run build:optimized && FOUNDRY_PROFILE=test-optimized forge test"
},
Expand Down
61 changes: 20 additions & 41 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 12 additions & 13 deletions src/SemaphoreMSAValidator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
ISemaphoreGroups public groups;
mapping(address account => uint256 groupId) public groupMapping;
mapping(address account => uint8 threshold) public thresholds;
mapping(address account => uint8 count) public memberCount;

// smart account -> hash(call(params)) -> valid proof count
mapping(address account => mapping(bytes32 txHash => ExtCallCount callDataCount)) public
Expand Down Expand Up @@ -133,6 +134,7 @@

// Add members to the group
semaphore.addMembers(groupId, cmts);
memberCount[account] = uint8(cmts.length);

emit ModuleInitialized(account);
}
Expand All @@ -143,6 +145,7 @@
delete thresholds[account];
delete groupMapping[account];
delete acctSeqNum[account];
delete memberCount[account];

//TODO: what is a good way to delete entries associated with `acctTxCount[account]`,
// The following line will make the compiler fail.
Expand All @@ -151,15 +154,9 @@
emit ModuleUninitialized(account);
}

function memberCount(address account) public view returns (uint8 cnt) {
// account doesn't belong to a semaphore group. We return 0
if (thresholds[account] == 0) return 0;
cnt = uint8(groups.getMerkleTreeSize(groupMapping[account]));
}

function setThreshold(uint8 newThreshold) external moduleInstalled {
address account = msg.sender;
if (newThreshold == 0 || newThreshold > memberCount(account)) {
if (newThreshold == 0 || newThreshold > memberCount[account]) {
revert InvalidThreshold(account);
}

Expand All @@ -171,14 +168,16 @@
address account = msg.sender;
uint256 groupId = groupMapping[account];

if (memberCount(account) + cmts.length > MAX_MEMBERS) revert MaxMemberReached(account);
if (memberCount[account] + cmts.length > MAX_MEMBERS) revert MaxMemberReached(account);

for (uint256 i = 0; i < cmts.length; ++i) {
if (cmts[i] == uint256(0)) revert InvalidCommitment(account);
if (groups.hasMember(groupId, cmts[i])) revert IsMemberAlready(account, cmts[i]);
}

semaphore.addMembers(groupId, cmts);
memberCount[account] += uint8(cmts.length);

emit AddedMembers(account, cmts.length);
}

Expand All @@ -191,12 +190,13 @@
{
address account = msg.sender;

if (memberCount(account) == thresholds[account]) revert MemberCntReachesThreshold(account);
if (memberCount[account] == thresholds[account]) revert MemberCntReachesThreshold(account);

uint256 groupId = groupMapping[account];
if (!groups.hasMember(groupId, cmt)) revert MemberNotExists(account, cmt);

semaphore.removeMember(groupId, cmt, merkleProofSiblings);
memberCount[account] -= 1;

emit RemovedMember(account, cmt);
}
Expand Down Expand Up @@ -347,7 +347,7 @@
uint256 cmt = Identity.getCommitment(pubKey);
if (!groups.hasMember(groupId, cmt)) revert MemberNotExists(account, cmt);

// We don't allow call to other contract.
// We don't allow call to other contracts.
address targetAddr = address(bytes20(userOp.callData[100:120]));
if (targetAddr != address(this)) revert NonValidatorCallBanned(targetAddr, address(this));

Expand All @@ -356,17 +356,16 @@
bytes memory valAndCallData = userOp.callData[120:];
bytes4 funcSel = bytes4(LibBytes.slice(valAndCallData, 32, 36));

// Allow only these few types on function calls to pass, and reject all other on-chain
// calls. They must be executed via `executeTx()` function.
// We only allow calls to `initiateTx()`, `signTx()`, and `executeTx()` to pass,
// and reject the rest.
if (_isAllowedSelector(funcSel)) return VALIDATION_SUCCESS;

revert NonAllowedSelector(account, funcSel);
}

function isValidSignatureWithSender(
address sender,

Check warning on line 366 in src/SemaphoreMSAValidator.sol

View workflow job for this annotation

GitHub Actions / lint

Variable "sender" is unused
bytes32 hash,

Check warning on line 367 in src/SemaphoreMSAValidator.sol

View workflow job for this annotation

GitHub Actions / lint

Variable "hash" is unused
bytes calldata signature

Check warning on line 368 in src/SemaphoreMSAValidator.sol

View workflow job for this annotation

GitHub Actions / lint

Variable "signature" is unused
)
external
view
Expand Down
45 changes: 42 additions & 3 deletions test/SemaphoreMSAValidator.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import { SemaphoreMSAValidator, ERC7579ValidatorBase } from "../src/SemaphoreMSA
import {
getEmptyUserOperation,
getEmptySemaphoreProof,
getGroupRmMerkleProof,
getTestUserOpCallData,
Identity,
IdentityLib
Expand Down Expand Up @@ -236,8 +237,14 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test {
uint256[] memory newMembers = new uint256[](1);
newMembers[0] = newCommitment;

vm.prank(smartAcct.account);
// Test: addMembers() is successfully executed
vm.startPrank(smartAcct.account);
vm.expectEmit(true, true, true, true, address(semaphoreValidator));
emit SemaphoreMSAValidator.AddedMembers(smartAcct.account, uint256(1));
semaphoreValidator.addMembers(newMembers);
vm.stopPrank();

assertEq(semaphoreValidator.memberCount(smartAcct.account), 2);

// Test: the userOp should pass
uint256 validationData = ERC7579ValidatorBase.ValidationData.unwrap(
Expand All @@ -246,8 +253,40 @@ contract SemaphoreValidatorUnitTest is RhinestoneModuleKit, Test {
assertEq(validationData, VALIDATION_SUCCESS);
}

function test_removeMember() public setupSmartAcctWithMembersThreshold(2, 1) {
revert("to be implemented");
function test_removeMember() public setupSmartAcctWithMembersThreshold(MEMBER_NUM, 1) {
uint256[] memory cmts = _getMemberCmts(MEMBER_NUM);
User storage rmUser = $users[0];
uint256 rmCmt = rmUser.identity.commitment();

(uint256[] memory merkleProof,) = getGroupRmMerkleProof(cmts, rmCmt);

// Test: remove member
vm.startPrank(smartAcct.account);
vm.expectEmit(true, true, true, true, address(semaphoreValidator));
emit SemaphoreMSAValidator.RemovedMember(smartAcct.account, rmCmt);
semaphoreValidator.removeMember(rmCmt, merkleProof);
vm.stopPrank();

assertEq(semaphoreValidator.memberCount(smartAcct.account), MEMBER_NUM - 1);

// Compose a UserOp
PackedUserOperation memory userOp = getEmptyUserOperation();
userOp.sender = smartAcct.account;
userOp.callData = getTestUserOpCallData(
0,
address(semaphoreValidator),
abi.encodeWithSelector(SemaphoreMSAValidator.initiateTx.selector)
);
bytes32 userOpHash = bytes32(keccak256("userOpHash"));
userOp.signature = rmUser.identity.signHash(userOpHash);

// Test: the userOp should fail and revert
vm.expectRevert(
abi.encodeWithSelector(
SemaphoreMSAValidator.MemberNotExists.selector, smartAcct.account, rmCmt
)
);
semaphoreValidator.validateUserOp(userOp, userOpHash);
}

function _getSemaphoreValidatorUserOpData(
Expand Down
46 changes: 42 additions & 4 deletions test/utils/TestUtils.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import { ISemaphore } from "../../src/utils/Semaphore.sol";
// import { console } from "forge-std/console.sol";
import { LibString } from "solady/Milady.sol";

// https://github.com/foundry-rs/forge-std/blob/master/src/Base.sol#L9
address constant VM_ADDRESS = 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D;
Vm constant vm = Vm(VM_ADDRESS);

struct ValidationData {
address aggregator;
uint48 validAfter;
Expand Down Expand Up @@ -49,13 +53,47 @@ function getTestUserOpCallData(
callData = bytes.concat(new bytes(100), bytes20(targetAddr), bytes32(value), txCallData);
}

function getGroupRmMerkleProof(
uint256[] memory members,
uint256 removal
)
returns (uint256[] memory merkleProof, uint256 root)
{
string[] memory cmd = new string[](5);
cmd[0] = "pnpm";
cmd[1] = "semaphore-group";
cmd[2] = "remove-member";
cmd[3] = _join(members);
cmd[4] = LibString.toString(removal);

bytes memory outBytes = vm.ffi(cmd);
string memory outStr = string(outBytes);
string[] memory retStr = LibString.split(outStr, " ");

merkleProof = _splitToUint(retStr[0]);
root = vm.parseUint(retStr[1]);
}

function _splitToUint(string memory str) pure returns (uint256[] memory retArr) {
string[] memory arr = LibString.split(str, ",");
retArr = new uint256[](arr.length);
for (uint256 i = 0; i < arr.length; i++) {
retArr[i] = vm.parseUint(arr[i]);
}
}

function _join(uint256[] memory members) pure returns (string memory retStr) {
for (uint256 i = 0; i < members.length; i++) {
retStr = string.concat(retStr, LibString.toString(members[i]));
if (i < members.length - 1) {
retStr = string.concat(retStr, ",");
}
}
}

type Identity is bytes32;

library IdentityLib {
// https://github.com/foundry-rs/forge-std/blob/master/src/Base.sol#L9
address internal constant VM_ADDRESS = 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D;
Vm internal constant vm = Vm(VM_ADDRESS);

function genIdentity(uint256 seed) public view returns (Identity) {
return Identity.wrap(keccak256(abi.encodePacked(seed, address(this))));
}
Expand Down
Loading