Skip to content

Commit

Permalink
feat: created AllowedParamsEnforcer and ported over AllowedTargets an…
Browse files Browse the repository at this point in the history
…d AllowedMethods enforcers from metamask repo
  • Loading branch information
SahilVasava committed Aug 9, 2024
1 parent 30506aa commit 334081d
Show file tree
Hide file tree
Showing 5 changed files with 393 additions and 1 deletion.
93 changes: 93 additions & 0 deletions src/enforcers/AllowedMethodsEnforcer.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// SPDX-License-Identifier: MIT AND Apache-2.0
pragma solidity 0.8.23;

import "kernel/src/utils/ExecLib.sol";
import {CaveatEnforcerBatch} from "./CaveatEnforcerBatch.sol";

/**
* @title AllowedMethodsEnforcer
* @dev This contract enforces the allowed methods a delegate may call.
*/
contract AllowedMethodsEnforcer is CaveatEnforcerBatch {
////////////////////////////// Custom Errors //////////////////////////////

error InvalidCallType();
error MethodNotAllowed();

////////////////////////////// Public Methods //////////////////////////////

/**
* @notice Allows the delegator to limit what methods the delegate may call.
* @dev This function enforces the allowed methods before the transaction is performed.
* @param _terms A series of 4byte method identifiers, representing the methods that the delegate is allowed to call.
* @param _executionData The executionData the delegate is trying try to execute.
*/
function beforeHook(
bytes calldata _terms,
bytes calldata,
bytes32 _executionMode,
bytes calldata _executionData,
bytes32,
address,
address
) public pure override {
bytes4[] memory allowedSignatures_ = getTermsInfo(_terms);
(CallType callType_,,,) = ExecLib.decode(ExecMode.wrap(_executionMode));
if (callType_ == CALLTYPE_SINGLE) {
(,, bytes calldata callData_) = ExecLib.decodeSingle(_executionData);
bytes4 targetSig_ = bytes4(callData_[0:4]);
bool signaturePass_ = _checkSignature(allowedSignatures_, targetSig_);
if (!signaturePass_) {
revert MethodNotAllowed();
}
} else if (callType_ == CALLTYPE_BATCH) {
Execution[] calldata exec_ = ExecLib.decodeBatch(_executionData);
for (uint256 j = 0; j < exec_.length; j++) {
bytes4 targetSig_ = bytes4(exec_[j].callData[0:4]);
bool signaturePass_ = _checkSignature(allowedSignatures_, targetSig_);
if (!signaturePass_) {
revert MethodNotAllowed();
}
}
} else if (callType_ == CALLTYPE_DELEGATECALL) {
bytes4 targetSig_ = bytes4(_executionData[20:24]);
bool signaturePass_ = _checkSignature(allowedSignatures_, targetSig_);
if (!signaturePass_) {
revert MethodNotAllowed();
}
} else {
revert InvalidCallType();
}
}

/**
* @dev Checks the method signature with set of allowed method signatures.
* @param _allowedSignatures The allowed signatures array.
* @param _targetSig The target method signature of the calldata.
* @return A boolean indicating whether the target method signature matches one of the allowed target method signatures.
*/
function _checkSignature(bytes4[] memory _allowedSignatures, bytes4 _targetSig) internal pure returns (bool) {
for (uint256 i = 0; i < _allowedSignatures.length; ++i) {
if (_targetSig == _allowedSignatures[i]) {
return true;
}
}
return false;
}

/**
* @notice Decodes the terms used in this CaveatEnforcer.
* @param _terms encoded data that is used during the execution hooks.
* @return allowedMethods_ The 4 byte identifiers for the methods that the delegate is allowed to call.
*/
function getTermsInfo(bytes calldata _terms) public pure returns (bytes4[] memory allowedMethods_) {
uint256 j = 0;
uint256 termsLength_ = _terms.length;
require(termsLength_ % 4 == 0, "AllowedMethodsEnforcer:invalid-terms-length");
allowedMethods_ = new bytes4[](termsLength_ / 4);
for (uint256 i = 0; i < termsLength_; i += 4) {
allowedMethods_[j] = bytes4(_terms[i:i + 4]);
j++;
}
}
}
180 changes: 180 additions & 0 deletions src/enforcers/AllowedParamsEnforcer.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// SPDX-License-Identifier: MIT AND Apache-2.0
pragma solidity 0.8.23;

import {CaveatEnforcerBatch} from "./CaveatEnforcerBatch.sol";
import "kernel/src/utils/ExecLib.sol";

struct ParamRule {
ParamCondition condition;
uint64 offset;
bytes32[] params;
}

enum ParamCondition {
EQUAL,
GREATER_THAN,
LESS_THAN,
GREATER_THAN_OR_EQUAL,
LESS_THAN_OR_EQUAL,
NOT_EQUAL,
ONE_OF
}

struct Permission {
CallType callType;
address target;
bytes4 selector;
ParamRule[] rules;
}

/**
* @title AllowedParamsEnforcer
* @dev This contract enforces that target, methods and params of the calldata to be executed matches the permissions.
* @dev A common use case for this enforcer is enforcing function parameters.
*/
contract AllowedParamsEnforcer is CaveatEnforcerBatch {
////////////////////////////// Custom Errors //////////////////////////////

error InvalidCallType();
error CallViolatesParamRule();

////////////////////////////// Public Methods //////////////////////////////

/**
* @notice Allows the delegator to restrict the calldata that is executed
* @dev This function enforces that a subset of the calldata to be executed matches the allowed subset of calldata.
* @param _terms This is packed bytes
* @param _executionData The executionData the delegate is trying try to execute.
*/
function beforeHook(
bytes calldata _terms,
bytes calldata,
bytes32 _executionMode,
bytes calldata _executionData,
bytes32,
address,
address
) public pure override {
Permission[] memory permissions_;

(permissions_) = getTermsInfo(_terms);
(CallType callType_,,,) = ExecLib.decode(ExecMode.wrap(_executionMode));
if (callType_ == CALLTYPE_SINGLE) {
(address target_,, bytes calldata callData_) = ExecLib.decodeSingle(_executionData);
for (uint256 i = 0; i < permissions_.length; i++) {
Permission memory permission_ = permissions_[i];
if (
(permission_.target == target_ || permission_.target == address(0))
&& permission_.selector == bytes4(callData_[0:4]) && permission_.callType == CALLTYPE_SINGLE
) {
bool permissionPass_ = _checkParams(callData_, permission_.rules);
if (!permissionPass_) {
revert CallViolatesParamRule();
}
return;
}
}
} else if (callType_ == CALLTYPE_BATCH) {
Execution[] calldata exec_ = ExecLib.decodeBatch(_executionData);
for (uint256 j = 0; j < exec_.length; j++) {
bool permissionFoundAndPassed_ = false;
bytes4 execSelector_ = bytes4(exec_[j].callData[0:4]);
for (uint256 i = 0; i < permissions_.length; i++) {
Permission memory permission_ = permissions_[i];
if (
(permission_.target == exec_[j].target || permission_.target == address(0))
&& permission_.selector == execSelector_ && permission_.callType == CALLTYPE_BATCH
) {
bool permissionPass_ = _checkParams(exec_[j].callData, permission_.rules);
if (!permissionPass_) {
revert CallViolatesParamRule();
}
permissionFoundAndPassed_ = true;
break;
}
}
if (!permissionFoundAndPassed_) {
revert("AllowedParamsEnforcer:no-matching-permissions-found");
}
}
return;
} else if (callType_ == CALLTYPE_DELEGATECALL) {
address target_ = address(bytes20(_executionData[0:20]));
bytes calldata callData_ = _executionData[20:];
for (uint256 i = 0; i < permissions_.length; i++) {
Permission memory permission_ = permissions_[i];
if (
(permission_.target == target_ || permission_.target == address(0))
&& permission_.selector == bytes4(callData_[0:4]) && permission_.callType == CALLTYPE_DELEGATECALL
) {
bool permissionPass_ = _checkParams(callData_, permission_.rules);
if (!permissionPass_) {
revert CallViolatesParamRule();
}
return;
}
}
} else {
revert InvalidCallType();
}
revert("AllowedParamsEnforcer:no-matching-permissions-found");
}

/**
* @dev Checks the params of the calldata to be execute with set of allowed params.
* @param _data The calldata of the execution.
* @param _rules The rules array for the params of the calldata.
* @return A boolean indicating whether all the params satisfies the defined set of rules.
*/
function _checkParams(bytes calldata _data, ParamRule[] memory _rules) internal pure returns (bool) {
for (uint256 i = 0; i < _rules.length; i++) {
ParamRule memory rule_ = _rules[i];
bytes32 param_ = bytes32(_data[4 + rule_.offset:4 + rule_.offset + 32]);
// only ONE_OF condition can have multiple params
if (rule_.condition == ParamCondition.EQUAL && param_ != rule_.params[0]) {
return false;
} else if (rule_.condition == ParamCondition.GREATER_THAN && param_ <= rule_.params[0]) {
return false;
} else if (rule_.condition == ParamCondition.LESS_THAN && param_ >= rule_.params[0]) {
return false;
} else if (rule_.condition == ParamCondition.GREATER_THAN_OR_EQUAL && param_ < rule_.params[0]) {
return false;
} else if (rule_.condition == ParamCondition.LESS_THAN_OR_EQUAL && param_ > rule_.params[0]) {
return false;
} else if (rule_.condition == ParamCondition.NOT_EQUAL && param_ == rule_.params[0]) {
return false;
} else if (rule_.condition == ParamCondition.ONE_OF) {
bool oneOfStatus = false;
for (uint256 j = 0; j < rule_.params.length; j++) {
if (param_ == rule_.params[j]) {
oneOfStatus = true;
break;
}
}
if (!oneOfStatus) {
return false;
}
}
}
return true;
}

/**
* @notice Decodes the terms used in this CaveatEnforcer.
* @param _terms encoded data that is used during the execution hooks.
* @return permissions The permissions for the transaction.
*/
function getTermsInfo(bytes calldata _terms) public pure returns (Permission[] memory permissions) {
(permissions) = abi.decode(_terms, (Permission[]));
}

/**
* @dev Compares two byte arrays for equality.
* @param _a The first byte array.
* @param _b The second byte array.
* @return A boolean indicating whether the byte arrays are equal.
*/
function _compare(bytes memory _a, bytes memory _b) private pure returns (bool) {
return keccak256(_a) == keccak256(_b);
}
}
98 changes: 98 additions & 0 deletions src/enforcers/AllowedTargetsEnforcer.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// SPDX-License-Identifier: MIT AND Apache-2.0
pragma solidity 0.8.23;

import "kernel/src/utils/ExecLib.sol";
import {CaveatEnforcerBatch} from "./CaveatEnforcerBatch.sol";

/**
* @title AllowedTargetsEnforcer
* @dev This contract enforces the allowed target addresses for a delegate.
*/
contract AllowedTargetsEnforcer is CaveatEnforcerBatch {
////////////////////////////// Custom Errors //////////////////////////////

error InvalidCallType();
error TargetNotAllowed();

////////////////////////////// Public Methods //////////////////////////////

/**
* @notice Allows the delegator to limit what addresses the delegate may call.
* @dev This function enforces the allowed target addresses before the transaction is performed.
* @param _terms A series of 20byte addresses, representing the addresses that the delegate is allowed to call.
* @param _executionData The executionData the delegate is trying try to execute.
*/
function beforeHook(
bytes calldata _terms,
bytes calldata,
bytes32 _executionMode,
bytes calldata _executionData,
bytes32,
address,
address
) public pure override {
address[] memory allowedTargets_ = getTermsInfo(_terms);
(CallType callType_,,,) = ExecLib.decode(ExecMode.wrap(_executionMode));
if (callType_ == CALLTYPE_SINGLE) {
(address targetAddress_,,) = ExecLib.decodeSingle(_executionData);
bool targetPass_ = _checkTargetAddress(allowedTargets_, targetAddress_);
if (!targetPass_) {
revert TargetNotAllowed();
}
} else if (callType_ == CALLTYPE_BATCH) {
Execution[] calldata exec = ExecLib.decodeBatch(_executionData);
for (uint256 j = 0; j < exec.length; j++) {
address targetAddress_ = exec[j].target;
bool targetPass_ = _checkTargetAddress(allowedTargets_, targetAddress_);
if (!targetPass_) {
revert TargetNotAllowed();
}
}
} else if (callType_ == CALLTYPE_DELEGATECALL) {
address targetAddress_ = address(bytes20(_executionData[0:20]));
bool targetPass_ = _checkTargetAddress(allowedTargets_, targetAddress_);
if (!targetPass_) {
revert TargetNotAllowed();
}
} else {
revert InvalidCallType();
}

revert("AllowedTargetsEnforcer:target-address-not-allowed");
}

/**
* @dev Checks the target address with set of allowed target addresses.
* @param _allowedTargets The allowed targets array.
* @param _targetAddress The target address of the calldata.
* @return A boolean indicating whether the target address matches one of the allowed target addresses.
*/
function _checkTargetAddress(address[] memory _allowedTargets, address _targetAddress)
internal
pure
returns (bool)
{
for (uint256 i = 0; i < _allowedTargets.length; ++i) {
if (_targetAddress == _allowedTargets[i]) {
return true;
}
}
return false;
}

/**
* @notice Decodes the terms used in this CaveatEnforcer.
* @param _terms encoded data that is used during the execution hooks.
* @return allowedTargets_ The allowed target addresses.
*/
function getTermsInfo(bytes calldata _terms) public pure returns (address[] memory allowedTargets_) {
uint256 j = 0;
uint256 termsLength_ = _terms.length;
require(termsLength_ % 20 == 0, "AllowedTargetsEnforcer:invalid-terms-length");
allowedTargets_ = new address[](termsLength_ / 20);
for (uint256 i = 0; i < termsLength_; i += 20) {
allowedTargets_[j] = address(bytes20(_terms[i:i + 20]));
j++;
}
}
}
22 changes: 22 additions & 0 deletions src/enforcers/CaveatEnforcerBatch.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// SPDX-License-Identifier: MIT AND Apache-2.0
pragma solidity 0.8.23;

import {ICaveatEnforcerBatch} from "../interfaces/ICaveatEnforcerBatch.sol";

/**
* @title CaveatEnforcer
* @dev This abstract contract enforces caveats before and after the execution of an action.
*/
abstract contract CaveatEnforcerBatch is ICaveatEnforcerBatch {
/// @inheritdoc ICaveatEnforcerBatch
function beforeHook(bytes calldata, bytes calldata, bytes32, bytes calldata, bytes32, address, address)
public
virtual
{}

/// @inheritdoc ICaveatEnforcerBatch
function afterHook(bytes calldata, bytes calldata, bytes32, bytes calldata, bytes32, address, address)
public
virtual
{}
}
Loading

0 comments on commit 334081d

Please sign in to comment.