diff --git a/src/voting-strategies/WhitelistVotingStrategy.sol b/src/voting-strategies/WhitelistVotingStrategy.sol index 0b3948a5..aa07a3df 100644 --- a/src/voting-strategies/WhitelistVotingStrategy.sol +++ b/src/voting-strategies/WhitelistVotingStrategy.sol @@ -7,6 +7,10 @@ import { IVotingStrategy } from "../interfaces/IVotingStrategy.sol"; /// @title Whitelist Voting Strategy /// @notice Allows a variable voting power whitelist to be used for voting power. contract WhitelistVotingStrategy is IVotingStrategy { + /// @notice Error thrown when the `voter` and address indicated by `voterIndex` + /// don't match. + error VoterAndIndexMismatch(); + /// @dev Stores the data for each member of the whitelist. struct Member { // The address of the member. @@ -19,34 +23,19 @@ contract WhitelistVotingStrategy is IVotingStrategy { /// @param voter The address to get the voting power of. /// @param params Parameter array containing the encoded whitelist of addresses and their voting power. /// The array should be an ABI encoded array of Member structs sorted by ascending addresses. + /// @param userParams Expected to contain a `uint256` corresponding to the voterIndex in the array provided by `params`. /// @return votingPower The voting power of the address if it exists in the whitelist, otherwise 0. function getVotingPower( uint32 /* timestamp */, address voter, bytes calldata params, - bytes calldata /* userParams */ + bytes calldata userParams ) external pure override returns (uint256 votingPower) { Member[] memory members = abi.decode(params, (Member[])); + uint256 voterIndex = abi.decode(userParams, (uint256)); - uint256 high = members.length - 1; - uint256 low; - uint256 mid; - address currentAddress; - - while (low < high) { - mid = (high + low) / 2; // Expecting high and low to never overflow - currentAddress = members[mid].addr; + if (voter != members[voterIndex].addr) revert VoterAndIndexMismatch(); - if (currentAddress < voter) { - low = mid + 1; - } else { - high = mid; - } - } - if (members[high].addr == voter) { - return (members[high].vp); - } else { - return (0); - } + return members[voterIndex].vp; } } diff --git a/test/WhitelistVotingStrategy.t.sol b/test/WhitelistVotingStrategy.t.sol index 3415d4c1..64aaa396 100644 --- a/test/WhitelistVotingStrategy.t.sol +++ b/test/WhitelistVotingStrategy.t.sol @@ -5,32 +5,25 @@ import { Test } from "forge-std/Test.sol"; import { WhitelistVotingStrategy } from "../src/voting-strategies/WhitelistVotingStrategy.sol"; contract WhitelistVotingStrategyTest is Test { + error VoterAndIndexMismatch(); + WhitelistVotingStrategy public whitelistVotingStrategy; function testWhitelistVotingPower() public { WhitelistVotingStrategy.Member[] memory members = new WhitelistVotingStrategy.Member[](3); - members[0] = WhitelistVotingStrategy.Member(address(1), 11); - members[1] = WhitelistVotingStrategy.Member(address(3), 33); + members[0] = WhitelistVotingStrategy.Member(address(3), 33); + members[1] = WhitelistVotingStrategy.Member(address(1), 11); members[2] = WhitelistVotingStrategy.Member(address(5), 55); whitelistVotingStrategy = new WhitelistVotingStrategy(); bytes memory params = abi.encode(members); - assertEq(whitelistVotingStrategy.getVotingPower(0, members[0].addr, params, ""), members[0].vp); - assertEq(whitelistVotingStrategy.getVotingPower(0, members[1].addr, params, ""), members[1].vp); - assertEq(whitelistVotingStrategy.getVotingPower(0, members[2].addr, params, ""), members[2].vp); - - // Index 0 - assertEq(whitelistVotingStrategy.getVotingPower(0, address(0), params, ""), 0); - // Index 2 - assertEq(whitelistVotingStrategy.getVotingPower(0, address(2), params, ""), 0); - // 4 - assertEq(whitelistVotingStrategy.getVotingPower(0, address(4), params, ""), 0); - // Last index - assertEq(whitelistVotingStrategy.getVotingPower(0, address(6), params, ""), 0); + assertEq(whitelistVotingStrategy.getVotingPower(0, members[0].addr, params, abi.encode(0)), members[0].vp); + assertEq(whitelistVotingStrategy.getVotingPower(0, members[1].addr, params, abi.encode(1)), members[1].vp); + assertEq(whitelistVotingStrategy.getVotingPower(0, members[2].addr, params, abi.encode(2)), members[2].vp); } - function testWhitelistVotingPowerSmall() public { + function testWhitelistVoterAndIndexMismatch() public { WhitelistVotingStrategy.Member[] memory members = new WhitelistVotingStrategy.Member[](3); members[0] = WhitelistVotingStrategy.Member(address(1), 11); members[1] = WhitelistVotingStrategy.Member(address(3), 33); @@ -38,14 +31,8 @@ contract WhitelistVotingStrategyTest is Test { bytes memory params = abi.encode(members); - assertEq(whitelistVotingStrategy.getVotingPower(0, members[0].addr, params, ""), members[0].vp); - assertEq(whitelistVotingStrategy.getVotingPower(0, members[1].addr, params, ""), members[1].vp); - - // Index 0 - assertEq(whitelistVotingStrategy.getVotingPower(0, address(0), params, ""), 0); - // Index 2 - assertEq(whitelistVotingStrategy.getVotingPower(0, address(2), params, ""), 0); - // Last index - assertEq(whitelistVotingStrategy.getVotingPower(0, address(4), params, ""), 0); + vm.expectRevert(VoterAndIndexMismatch.selector); + // `voter` is members[0] but the `voterIndex` is 1 (which corresponds to members[1]). + whitelistVotingStrategy.getVotingPower(0, members[0].addr, params, abi.encode(1)); } }