From 3005b80dfdfbf2781c14af9119d14992cc0282d4 Mon Sep 17 00:00:00 2001 From: Eric <5089238+emizzle@users.noreply.github.com> Date: Thu, 17 Oct 2024 18:15:33 +1100 Subject: [PATCH] fix(slot-reservations): ensure slot is free Ensure that the slot state is free before allowing reservations --- contracts/Marketplace.sol | 4 ++++ contracts/SlotReservations.sol | 5 ++++- contracts/TestSlotReservations.sol | 10 ++++++++++ test/SlotReservations.test.js | 16 ++++++++++++++++ 4 files changed, 34 insertions(+), 1 deletion(-) diff --git a/contracts/Marketplace.sol b/contracts/Marketplace.sol index dc244efe..1d5aa123 100644 --- a/contracts/Marketplace.sol +++ b/contracts/Marketplace.sol @@ -457,6 +457,10 @@ contract Marketplace is SlotReservations, Proofs, StateRetrieval, Endian { _; } + function _slotIsFree(SlotId slotId) internal view override returns (bool) { + return _slots[slotId].state == SlotState.Free; + } + function requestEnd(RequestId requestId) public view returns (uint256) { uint256 end = _requestContexts[requestId].endsAt; RequestState state = requestState(requestId); diff --git a/contracts/SlotReservations.sol b/contracts/SlotReservations.sol index 8584105c..cbc050cc 100644 --- a/contracts/SlotReservations.sol +++ b/contracts/SlotReservations.sol @@ -5,7 +5,7 @@ import "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import "./Requests.sol"; import "./Configuration.sol"; -contract SlotReservations { +abstract contract SlotReservations { using EnumerableSet for EnumerableSet.AddressSet; mapping(SlotId => EnumerableSet.AddressSet) internal _reservations; @@ -15,6 +15,8 @@ contract SlotReservations { _config = config; } + function _slotIsFree(SlotId slotId) internal view virtual returns (bool); + function reserveSlot(RequestId requestId, uint256 slotIndex) public { require(canReserveSlot(requestId, slotIndex), "Reservation not allowed"); @@ -34,6 +36,7 @@ contract SlotReservations { SlotId slotId = Requests.slotId(requestId, slotIndex); return // TODO: add in check for address inside of expanding window + _slotIsFree(slotId) && (_reservations[slotId].length() < _config.maxReservations) && (!_reservations[slotId].contains(host)); } diff --git a/contracts/TestSlotReservations.sol b/contracts/TestSlotReservations.sol index 31d19d6a..3fb737c0 100644 --- a/contracts/TestSlotReservations.sol +++ b/contracts/TestSlotReservations.sol @@ -6,6 +6,8 @@ import "./SlotReservations.sol"; contract TestSlotReservations is SlotReservations { using EnumerableSet for EnumerableSet.AddressSet; + mapping(SlotId => SlotState) private _states; + // solhint-disable-next-line no-empty-blocks constructor(SlotReservationsConfig memory config) SlotReservations(config) {} @@ -16,4 +18,12 @@ contract TestSlotReservations is SlotReservations { function length(SlotId slotId) public view returns (uint256) { return _reservations[slotId].length(); } + + function _slotIsFree(SlotId slotId) internal view override returns (bool) { + return _states[slotId] == SlotState.Free; + } + + function setSlotState(SlotId id, SlotState state) public { + _states[id] = state; + } } diff --git a/test/SlotReservations.test.js b/test/SlotReservations.test.js index 9db71e9d..63727593 100644 --- a/test/SlotReservations.test.js +++ b/test/SlotReservations.test.js @@ -2,6 +2,7 @@ const { expect } = require("chai") const { ethers } = require("hardhat") const { exampleRequest, exampleConfiguration } = require("./examples") const { requestId, slotId } = require("./ids") +const { SlotState } = require("./requests") describe("SlotReservations", function () { let reservations @@ -28,6 +29,8 @@ describe("SlotReservations", function () { index: slotIndex, } id = slotId(slot) + + await reservations.setSlotState(id, SlotState.Free) }) function switchAccount(account) { @@ -99,6 +102,19 @@ describe("SlotReservations", function () { expect(await reservations.canReserveSlot(reqId, slotIndex)).to.be.false }) + it("cannot reserve a slot if not free", async function () { + await reservations.setSlotState(id, SlotState.Filled) + await expect(reservations.reserveSlot(reqId, slotIndex)).to.be.revertedWith( + "Reservation not allowed" + ) + expect(await reservations.length(id)).to.equal(0) + }) + + it("reports a slot cannot be reserved if not free", async function () { + await reservations.setSlotState(id, SlotState.Filled) + expect(await reservations.canReserveSlot(reqId, slotIndex)).to.be.false + }) + it("should emit an event when slot reservations are full", async function () { await reservations.reserveSlot(reqId, slotIndex) switchAccount(address1)