diff --git a/core/contracts/shared/IReceiveApproval.sol b/core/contracts/shared/IReceiveApproval.sol new file mode 100644 index 000000000..175e32eb8 --- /dev/null +++ b/core/contracts/shared/IReceiveApproval.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-3.0-only + +pragma solidity ^0.8.20; + +/// @notice An interface that should be implemented by contracts supporting +/// `approveAndCall`/`receiveApproval` pattern. +interface IReceiveApproval { + /// @notice Receives approval to spend tokens. Called as a result of + /// `approveAndCall` call on the token. + function receiveApproval( + address from, + uint256 amount, + address token, + bytes calldata extraData + ) external; +} diff --git a/core/contracts/staking/TokenStaking.sol b/core/contracts/staking/TokenStaking.sol index d1b9d2ccd..a9051040d 100644 --- a/core/contracts/staking/TokenStaking.sol +++ b/core/contracts/staking/TokenStaking.sol @@ -4,12 +4,13 @@ pragma solidity ^0.8.20; import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; +import "../shared/IReceiveApproval.sol"; /// @title TokenStaking /// @notice A token staking contract for a specified standard ERC20 token. A /// holder of the specified token can stake its tokens to this contract /// and recover the stake after undelegation period is over. -contract TokenStaking { +contract TokenStaking is IReceiveApproval { using SafeERC20 for IERC20; event Staked(address indexed account, uint256 amount); @@ -27,14 +28,36 @@ contract TokenStaking { token = _token; } + /// @notice Receives approval of token transfer and stakes the approved + /// amount or adds the approved amount to an existing stake. + /// @dev Requires that the provided token contract be the same one linked to + /// this contract. + /// @param from The owner of the tokens who approved them to transfer. + /// @param amount Approved amount for the transfer and stake. + /// @param _token Token contract address. + function receiveApproval( + address from, + uint256 amount, + address _token, + bytes calldata + ) external override { + require(_token == address(token), "Unrecognized token"); + _stake(from, amount); + } + /// @notice Stakes the owner's tokens in the staking contract. /// @param amount Approved amount for the transfer and stake. function stake(uint256 amount) external { + _stake(msg.sender, amount); + } + + function _stake(address account, uint256 amount) private { require(amount > 0, "Amount is less than minimum"); + require(account != address(0), "Can not be the zero address"); - balanceOf[msg.sender] += amount; + balanceOf[account] += amount; - emit Staked(msg.sender, amount); - token.safeTransferFrom(msg.sender, address(this), amount); + emit Staked(account, amount); + token.safeTransferFrom(account, address(this), amount); } } diff --git a/core/contracts/test/TestToken.sol b/core/contracts/test/TestToken.sol index 749165b4c..c061cf278 100644 --- a/core/contracts/test/TestToken.sol +++ b/core/contracts/test/TestToken.sol @@ -3,6 +3,8 @@ pragma solidity 0.8.20; import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; +import "../shared/IReceiveApproval.sol"; + contract Token is ERC20 { constructor() ERC20("Test Token", "TEST") {} @@ -10,4 +12,21 @@ contract Token is ERC20 { function mint(address account, uint256 value) external { _mint(account, value); } + + function approveAndCall( + address spender, + uint256 amount, + bytes memory extraData + ) external returns (bool) { + if (approve(spender, amount)) { + IReceiveApproval(spender).receiveApproval( + msg.sender, + amount, + address(this), + extraData + ); + return true; + } + return false; + } } diff --git a/core/test/staking/TokenStaking.test.ts b/core/test/staking/TokenStaking.test.ts index fb2e8655a..076313950 100644 --- a/core/test/staking/TokenStaking.test.ts +++ b/core/test/staking/TokenStaking.test.ts @@ -4,7 +4,6 @@ import { expect } from "chai" import { Token, TokenStaking } from "../../typechain" import { WeiPerEther } from "ethers" import { HardhatEthersSigner } from "@nomicfoundation/hardhat-ethers/signers" -import { before } from "mocha" async function tokenStakingFixture() { const [deployer, tokenHolder] = await ethers.getSigners() @@ -39,29 +38,52 @@ describe("TokenStaking", () => { }) describe("staking", () => { - beforeEach(async () => { - // Infinite approval for staking contract. - await token - .connect(tokenHolder) - .approve(await tokenStaking.getAddress(), ethers.MaxUint256) - }) + describe("when staking via staking contract directly", () => { + beforeEach(async () => { + // Infinite approval for staking contract. + await token + .connect(tokenHolder) + .approve(await tokenStaking.getAddress(), ethers.MaxUint256) + }) + + it("should stake tokens", async () => { + const tokenHolderAddress = await tokenHolder.getAddress() + const tokenBalance = await token.balanceOf(tokenHolderAddress) - it("should stake tokens", async () => { - const tokenHolderAddress = await tokenHolder.getAddress() - const tokenBalance = await token.balanceOf(tokenHolderAddress) + await expect(tokenStaking.connect(tokenHolder).stake(tokenBalance)) + .to.emit(tokenStaking, "Staked") + .withArgs(tokenHolderAddress, tokenBalance) + expect(await tokenStaking.balanceOf(tokenHolderAddress)).to.be.eq( + tokenBalance, + ) + expect(await token.balanceOf(tokenHolderAddress)).to.be.eq(0) + }) - await expect(tokenStaking.connect(tokenHolder).stake(tokenBalance)) - .to.emit(tokenStaking, "Staked") - .withArgs(tokenHolderAddress, tokenBalance) - expect(await tokenStaking.balanceOf(tokenHolderAddress)).to.be.eq( - tokenBalance, - ) - expect(await token.balanceOf(tokenHolderAddress)).to.be.eq(0) + it("should revert if the staked amount is less than required minimum", async () => { + await expect( + tokenStaking.connect(tokenHolder).stake(0), + ).to.be.revertedWith("Amount is less than minimum") + }) }) - it("should revert if the staked amount is less than required minimum", async () => { - await expect(tokenStaking.connect(tokenHolder).stake(0)) - .to.be.revertedWith("Amount is less than minimum") + describe("when staking via staking token using approve and call pattern", () => { + it("should stake tokens", async () => { + const tokenHolderAddress = await tokenHolder.getAddress() + const tokenBalance = await token.balanceOf(tokenHolderAddress) + const tokenStakingAddress = await tokenStaking.getAddress() + + await expect( + token + .connect(tokenHolder) + .approveAndCall(tokenStakingAddress, tokenBalance, "0x"), + ) + .to.emit(tokenStaking, "Staked") + .withArgs(tokenHolderAddress, tokenBalance) + expect(await tokenStaking.balanceOf(tokenHolderAddress)).to.be.eq( + tokenBalance, + ) + expect(await token.balanceOf(tokenHolderAddress)).to.be.eq(0) + }) }) }) })