diff --git a/crates/cheatcodes/defs/src/vm.rs b/crates/cheatcodes/defs/src/vm.rs index 2061dac26d2aa..c031a86ef559e 100644 --- a/crates/cheatcodes/defs/src/vm.rs +++ b/crates/cheatcodes/defs/src/vm.rs @@ -255,7 +255,8 @@ interface Vm { #[cheatcode(group = Evm, safety = Unsafe)] function store(address target, bytes32 slot, bytes32 value) external; - /// Marks the slots of an account and the account address as cold. + /// Marks the `target` address cold, and is a no-op if the address is already cold. + /// All storage slots are also made cold, but their values are preserved. #[cheatcode(group = Evm, safety = Unsafe)] function cool(address target) external; diff --git a/crates/cheatcodes/src/evm.rs b/crates/cheatcodes/src/evm.rs index c1cdcc1b22692..14fde5aa08c96 100644 --- a/crates/cheatcodes/src/evm.rs +++ b/crates/cheatcodes/src/evm.rs @@ -1,6 +1,7 @@ //! Implementations of [`Evm`](crate::Group::Evm) cheatcodes. -use crate::{Cheatcode, Cheatcodes, CheatsCtxt, Result, Vm::*}; +use crate::{inspector::AddressState, Cheatcode, Cheatcodes, CheatsCtxt, Result, Vm::*}; + use alloy_primitives::{Address, Bytes, U256}; use alloy_sol_types::SolValue; use ethers_core::utils::{Genesis, GenesisAccount}; @@ -302,11 +303,10 @@ impl Cheatcode for storeCall { impl Cheatcode for coolCall { fn apply_full(&self, ccx: &mut CheatsCtxt) -> Result { - let Self { target } = self; - if let Some(account) = ccx.data.journaled_state.state.get_mut(target) { - account.unmark_touch(); - account.storage.clear(); - } + let Self { target } = *self; + ensure_not_precompile!(&target, ccx); + // TODO: prevent or warn about cooling the to/from address in a tx + ccx.state.addresses.insert(target, (AddressState::Cool, HashMap::new())); Ok(Default::default()) } } diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index 5c95cfc2fa2fd..7a1dc3f6f1171 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -200,6 +200,182 @@ pub struct Cheatcodes { /// Breakpoints supplied by the `breakpoint` cheatcode. /// `char -> (address, pc)` pub breakpoints: Breakpoints, + + /// Track if cool cheatcode was called on each address + /// Mapping tracks if the address itself is cool and if each of it's slots are cool + pub addresses: HashMap)>, + /// How much gas to charge in the next step (op code) based on cool cheatcode calculations + pub additional_gas_next_op: u64, +} + +/// Whether an Address is accessed or not +#[derive(Clone, PartialEq, Debug)] +pub enum AddressState { + /// if already accessed, then charge WARM_STORAGE_READ_COST (100) + Warm, + /// charge COLD_ACCOUNT_ACCESS_COST (2600) + Cool, +} + +/// Whether a Storage Slot is warm or already been modified +#[derive(Clone, PartialEq, Debug)] +pub enum StorageSlotState { + /// charge extra based on SSTORE calculations + WarmWithSLOAD, + /// if SSTORE already happened, don't charge extra + WarmWithSSTORE, + /// same as if empty + Cool, +} + +/// Function to charge extra gas per opcode based on cool cheatcode +fn add_gas_from_cool_cheatcode( + state: &mut Cheatcodes, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, +) -> InstructionResult { + if state.addresses.is_empty() { + return InstructionResult::Continue + } + + // For gas costs, see https://eips.ethereum.org/EIPS/eip-2200, https://eips.ethereum.org/EIPS/eip-2929 + + // if previous step added gas, add it once + // note that all the opcodes will already have a cost (usually 100) + // so adding until it hits the gas expected for a cold key/address + if state.additional_gas_next_op > 0 { + interpreter.gas.record_cost(state.additional_gas_next_op); + state.additional_gas_next_op = 0; + } + + // if cool cheatcode was ever called on this address + let contract_address = interpreter.contract().address; + + if state.addresses.get(&contract_address).is_some() { + // check target itself + match interpreter.current_opcode() { + // via AccessListTracer + opcode::EXTCODECOPY | + opcode::EXTCODEHASH | + opcode::EXTCODESIZE | + opcode::BALANCE | + opcode::SELFDESTRUCT => { + // address is first parameter + if let Ok(slot) = interpreter.stack().peek(0) { + let addr: Address = Address::from_word(slot.into()); + + // COLD_ACCOUNT_ACCESS_COST is 2600 + // check this is done once per address, unless cheatcode is called again + // ignore if same as contract address + if let Some((ref mut address, _)) = state.addresses.get_mut(&addr) { + if *address == AddressState::Cool { + state.additional_gas_next_op = 2500; + *address = AddressState::Warm; + } + } + } + } + // via AccessListTracer + opcode::DELEGATECALL | opcode::CALL | opcode::STATICCALL | opcode::CALLCODE => { + // address is second parameter + if let Ok(slot) = interpreter.stack().peek(1) { + let addr: Address = Address::from_word(slot.into()); + + // COLD_ACCOUNT_ACCESS_COST is 2600 + // check this is done once per address, unless cheatcode is called again + // ignore if same as contract address + if let Some((ref mut address, _)) = state.addresses.get_mut(&addr) { + if *address == AddressState::Cool { + state.additional_gas_next_op = 2500; + *address = AddressState::Warm; + } + } + } + } + _ => {} + } + // check target's slots + if let Some((_, slots)) = state.addresses.get_mut(&contract_address) { + match interpreter.current_opcode() { + opcode::SLOAD => { + let key = try_or_continue!(interpreter.stack().peek(0)); + + let account = data.journaled_state.state().get(&contract_address).unwrap(); + if account.storage.get(&key).is_some() { + match slots.get(&key) { + None | Some(StorageSlotState::Cool) => { + // COLD_SLOAD_COST is 2100 + state.additional_gas_next_op = 2000; + slots.insert(key, StorageSlotState::WarmWithSLOAD); + } + Some(_) => {} + } + } else { + slots.insert(key, StorageSlotState::WarmWithSLOAD); + } + } + opcode::SSTORE => { + let key = try_or_continue!(interpreter.stack().peek(0)); + let val = try_or_continue!(interpreter.stack().peek(1)); + + let account = data.journaled_state.state().get(&contract_address).unwrap(); + if account.storage.get(&key).is_some() { + // only add gas the first time the storage is touched again + match slots.get(&key) { + Some(StorageSlotState::WarmWithSLOAD) => { + // cool keeps the slot value changes + // as if the previous_or_original_value = present_value` + // so include the extra gas + let slot = account.storage.get(&key).unwrap(); + if val != slot.present_value && + slot.present_value != slot.previous_or_original_value + { + if slot.present_value == U256::ZERO { + // SSTORE_SET_GAS is 20000 + state.additional_gas_next_op += 20000 - 100 + } else { + // SSTORE_RESET_GAS is 5000 - COLD_SLOAD_COST (2100) + state.additional_gas_next_op += 2900 - 100 + } + } + + // set slot is_warm to true + slots.insert(key, StorageSlotState::WarmWithSSTORE); + } + None | Some(StorageSlotState::Cool) => { + // Means SSTORE was called without SLOAD before + // COLD_SLOAD_COST is 2100 + state.additional_gas_next_op = 2100; + + // cool keeps the slot value changes + // as if the previous_or_original_value = present_value` + // so include the extra gas + let slot = account.storage.get(&key).unwrap(); + if val != slot.present_value && + slot.present_value != slot.previous_or_original_value + { + if slot.present_value == U256::ZERO { + // SSTORE_SET_GAS is 20000 + state.additional_gas_next_op += 20000 - 100 + } else { + // SSTORE_RESET_GAS is 5000 - COLD_SLOAD_COST (2100) + state.additional_gas_next_op += 2900 - 100 + } + } + slots.insert(key, StorageSlotState::WarmWithSSTORE); + } + Some(StorageSlotState::WarmWithSSTORE) => {} + } + } else { + slots.insert(key, StorageSlotState::WarmWithSSTORE); + } + } + _ => {} + } + } + } + + InstructionResult::Continue } impl Cheatcodes { @@ -523,6 +699,16 @@ impl Inspector for Cheatcodes { InstructionResult::Continue } + fn step_end( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + eval: InstructionResult, + ) -> InstructionResult { + add_gas_from_cool_cheatcode(self, interpreter, data); + eval + } + fn log(&mut self, _: &mut EVMData<'_, DB>, address: &Address, topics: &[B256], data: &Bytes) { if !self.expected_emits.is_empty() { expect::handle_expect_emit(self, address, topics, data); diff --git a/testdata/cheats/Cool.t.sol b/testdata/cheats/Cool.t.sol index 674b9abbd3426..83c99663d72f9 100644 --- a/testdata/cheats/Cool.t.sol +++ b/testdata/cheats/Cool.t.sol @@ -6,55 +6,461 @@ import "./Vm.sol"; contract CoolTest is DSTest { Vm constant vm = Vm(HEVM_ADDRESS); - uint256 public slot0 = 1; + uint256 public slot0; + uint256 public slot1 = 1; + MiniERC public erc; + + function setUp() public { + erc = new MiniERC(); + erc.mint(address(1337), 1 ether); + } function testCool_SLOAD_normal() public { uint256 startGas; uint256 endGas; - uint256 val; uint256 beforeCoolGas; uint256 noCoolGas; startGas = gasleft(); - val = slot0; + uint256 val = slot0; endGas = gasleft(); beforeCoolGas = startGas - endGas; startGas = gasleft(); - val = slot0; + uint256 val2 = slot0; endGas = gasleft(); noCoolGas = startGas - endGas; + assertEq(val, val2); assertGt(beforeCoolGas, noCoolGas); } function testCool_SLOAD() public { uint256 startGas; uint256 endGas; - uint256 val; uint256 beforeCoolGas; uint256 afterCoolGas; - uint256 noCoolGas; + uint256 warmGas; + uint256 secondCoolGas; + uint256 extraGas; startGas = gasleft(); - val = slot0; + uint256 val = slot0; endGas = gasleft(); beforeCoolGas = startGas - endGas; vm.cool(address(this)); startGas = gasleft(); - val = slot0; + uint256 val2 = slot0; endGas = gasleft(); afterCoolGas = startGas - endGas; + extraGas = afterCoolGas - 2100; + assertEq(val, val2); assertEq(beforeCoolGas, afterCoolGas); + assertEq(beforeCoolGas, 2100 + extraGas); startGas = gasleft(); - val = slot0; + uint256 val3 = slot0; endGas = gasleft(); - noCoolGas = startGas - endGas; + warmGas = startGas - endGas; - assertGt(beforeCoolGas, noCoolGas); + assertEq(val2, val3); + assertGt(beforeCoolGas, warmGas); + assertEq(warmGas, 100 + extraGas); + + // cool again to see if same resut + vm.cool(address(this)); + + startGas = gasleft(); + uint256 val4 = slot0; + endGas = gasleft(); + secondCoolGas = startGas - endGas; + + assertEq(val, val4); + assertEq(beforeCoolGas, secondCoolGas); + assertEq(beforeCoolGas, 2100 + extraGas); + } + + // check if slot value is preserved + function testCool_SSTORE_check_slot_value() public { + slot0 = 2; + assertEq(slot0, 2); + assertEq(slot1, 1); + + vm.cool(address(this)); + assertEq(slot0, 2); + assertEq(slot1, 1); + + slot0 = 3; + assertEq(slot0, 3); + assertEq(slot1, 1); + + vm.cool(address(this)); + assertEq(slot0, 3); + assertEq(slot1, 1); + + slot0 = 8; + slot1 = 9; + + vm.cool(address(this)); + assertEq(slot0, 8); + assertEq(slot1, 9); + } + + function testCool_SSTORE_nonzero_to_nonzero() public { + uint256 startGas; + uint256 endGas; + uint256 beforeCoolGas; + uint256 afterCoolGas; + uint256 warmGas; + uint256 extraGas; + + // start as non-zero + startGas = gasleft(); + slot1 = 2; // 5k gas + endGas = gasleft(); + beforeCoolGas = startGas - endGas; + extraGas = beforeCoolGas - 2900 - 2100; + assertEq(slot1, 2); + assertEq(beforeCoolGas, 2900 + 2100 + extraGas); + + // cool and set to same value + vm.cool(address(this)); + + startGas = gasleft(); + slot1 = 2; // 5k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot1, 2); + assertEq(afterCoolGas, 100 + 2100 + extraGas); + + // cool and set from non-zero to another non-zero + vm.cool(address(this)); + + startGas = gasleft(); + slot1 = 3; // 5k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot1, 3); + assertEq(afterCoolGas, 2900 + 2100 + extraGas); + + // don't cool and set non-zero to another non-zero + startGas = gasleft(); + slot1 = 3; // 100 gas + endGas = gasleft(); + warmGas = startGas - endGas; + assertEq(slot1, 3); + assertGt(afterCoolGas, warmGas); + assertEq(warmGas, 100 + extraGas); + + // don't cool and set non-zero to another non-zero + startGas = gasleft(); + slot1 = 4; // 100 gas + endGas = gasleft(); + warmGas = startGas - endGas; + assertEq(slot1, 4); + assertGt(afterCoolGas, warmGas); + assertEq(warmGas, 100 + extraGas); + } + + function testCool_SSTORE_zero_to_nonzero() public { + uint256 startGas; + uint256 endGas; + uint256 afterCoolGas; + uint256 warmGas; + uint256 extraGas; + + // start as zero + // set from zero to non-zero + startGas = gasleft(); + slot0 = 1; // 22.1k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + extraGas = afterCoolGas - 20000 - 2100; + assertEq(slot0, 1); + assertEq(afterCoolGas, 20000 + 2100 + extraGas); + + slot0 = 0; + vm.cool(address(this)); + + // set from zero to non-zero + startGas = gasleft(); + slot0 = 1; // 22.1k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot0, 1); + assertEq(afterCoolGas, 20000 + 2100 + extraGas); + + // don't cool and set non-zero to another non-zero + startGas = gasleft(); + slot0 = 2; // 100 + endGas = gasleft(); + warmGas = startGas - endGas; + assertEq(slot0, 2); // persisted state + assertGt(afterCoolGas, warmGas); + assertEq(warmGas, 100 + extraGas); + + // cool again + // set from non-zero to non-zero + vm.cool(address(this)); + startGas = gasleft(); + slot0 = 1; // 5k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot0, 1); + assertEq(afterCoolGas, 2900 + 2100 + extraGas); + + // cool again, set to zero + // set from zero to non-zero + slot0 = 0; + vm.cool(address(this)); + startGas = gasleft(); + slot0 = 1; // 22.1k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot0, 1); + assertEq(afterCoolGas, 20000 + 2100 + extraGas); + + // cool again + // set to same value + vm.cool(address(this)); + startGas = gasleft(); + slot0 = 1; // 2.2k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot0, 1); + assertEq(afterCoolGas, 100 + 2100 + extraGas); + } + + function testCool_SSTORE_Multiple() public { + uint256 startGas; + uint256 endGas; + uint256 afterCoolGas; + uint256 extraGas; + + // start as zero + assertEq(slot0, 0); + + vm.cool(address(this)); + vm.cool(address(this)); + + // set from zero to non-zero + startGas = gasleft(); + slot0 = 3; // 22.1k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + extraGas = afterCoolGas - 20000 - 2100; + assertEq(slot0, 3); + assertEq(afterCoolGas, 20000 + 2100 + extraGas); + + vm.cool(address(this)); + vm.cool(address(this)); + vm.cool(address(this)); + + // set from non-zero to non-zero + startGas = gasleft(); + slot0 = 2; // 5k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot0, 2); + assertEq(afterCoolGas, 2900 + 2100 + extraGas); + } + + function testCool_Once() public { + uint256 startGas; + uint256 endGas; + uint256 afterCoolGas; + uint256 extraGas; + + // start as zero + assertEq(slot0, 0); + vm.cool(address(this)); + + // set from zero to non-zero + startGas = gasleft(); + slot0 = 3; // 22.1k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot0, 3); + extraGas = afterCoolGas - 20000 - 2100; + + // set from non-zero to non-zero + startGas = gasleft(); + slot0 = 2; // 5k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot0, 2); + assertEq(afterCoolGas, 100 + extraGas); + + // set to same + startGas = gasleft(); + slot0 = 2; // 5k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot0, 2); + assertEq(afterCoolGas, 100 + extraGas); + + // set from non-zero to non-zero + startGas = gasleft(); + slot0 = 4; // 5k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot0, 4); + assertEq(afterCoolGas, 100 + extraGas); + + // set from non-zero to zero + startGas = gasleft(); + slot0 = 0; // 5k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot0, 0); + assertEq(afterCoolGas, 100 + extraGas); + + // set from zero to non-zero + startGas = gasleft(); + slot0 = 1; // 5k gas + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(slot0, 1); + assertEq(afterCoolGas, 20000 + extraGas); + } + + function testCool_call() public { + uint256 startGas; + uint256 endGas; + uint256 afterCoolGas; + + TestContract test = new TestContract(); + + // zero to 1 (20k) but slot is warm + startGas = gasleft(); + test.setSlot0(1); + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(test.slot0(), 1); + assertGt(afterCoolGas, 20000); + + test.setSlot0(0); + vm.cool(address(test)); + + // zero to 1 (20k) and slot is cold + startGas = gasleft(); + test.setSlot0(2); + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(test.slot0(), 2); + assertGt(afterCoolGas, 20000 + 2100); + + test.setSlot0(1); + vm.cool(address(test)); + + // 1 to 2 (2900) and slot is cold + startGas = gasleft(); + test.setSlot0(2); + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(test.slot0(), 2); + assertGt(afterCoolGas, 2900 + 2100); + + test.setSlot0(1); + vm.cool(address(test)); + + // 1 to 1 (100 gas) and slot is cold + startGas = gasleft(); + test.setSlot0(1); + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(test.slot0(), 1); + assertGt(afterCoolGas, 100 + 2100); + + test.setBoth(0); + vm.cool(address(test)); + + // both 0 to 1 (20k * 2) + startGas = gasleft(); + test.setBoth(1); + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(test.slot0(), 1); + assertEq(test.slot1(), 1); + assertGt(afterCoolGas, 20000 * 2 + 2100 * 2); + + test.setSlot0(0); + vm.cool(address(test)); + + // slot0 from 0 to 2 (20k) + // slot1 from 1 to 2 (2900) + startGas = gasleft(); + test.setBoth(2); + endGas = gasleft(); + afterCoolGas = startGas - endGas; + assertEq(test.slot0(), 2); + assertEq(test.slot1(), 2); + assertGt(afterCoolGas, 20000 + 2900 + 2100 * 2); + } + + function testCool_Mint() public { + uint256 startGas; + uint256 endGas; + uint256 beforeGas; + + startGas = gasleft(); + erc.mint(address(1337), 0.01 ether); // 15462 + endGas = gasleft(); + beforeGas = startGas - endGas; + + vm.cool(address(erc)); + vm.cool(address(this)); + + startGas = gasleft(); + erc.mint(address(1337), 0.01 ether); // 15474 + endGas = gasleft(); + assertEq(beforeGas, startGas - endGas + 12); // ? + beforeGas = startGas - endGas; + + vm.cool(address(erc)); + vm.cool(address(this)); + + startGas = gasleft(); + erc.mint(address(1337), 0.01 ether); // 15474 + endGas = gasleft(); + assertEq(beforeGas, startGas - endGas); + + startGas = gasleft(); + erc.mint(address(1337), 0.01 ether); // 1362 + endGas = gasleft(); + assertLt(startGas - endGas, beforeGas); + } +} + +contract TestContract { + uint256 public slot0 = 0; + uint256 public slot1 = 1; + + function setSlot0(uint256 num) public { + slot0 = num; + } + + function setSlot1(uint256 num) public { + slot1 = num; + } + + function setBoth(uint256 num) public { + slot0 = num; + slot1 = num; + } +} + +contract MiniERC { + mapping(address => uint256) private _balances; + uint256 private _totalSupply; + + function mint(address to, uint256 amount) external { + _totalSupply += amount; + unchecked { + _balances[to] += amount; + } } }