diff --git a/contracts/tokenbridge/libraries/vault/MasterVault.sol b/contracts/tokenbridge/libraries/vault/MasterVault.sol index 121562c90..9e0893e5c 100644 --- a/contracts/tokenbridge/libraries/vault/MasterVault.sol +++ b/contracts/tokenbridge/libraries/vault/MasterVault.sol @@ -12,21 +12,20 @@ import {Initializable} from "@openzeppelin/contracts-upgradeable/proxy/utils/Ini import {IERC20Metadata} from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol"; import {MathUpgradeable} from "@openzeppelin/contracts-upgradeable/utils/math/MathUpgradeable.sol"; import {SafeERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; +import {ReentrancyGuardUpgradeable} from "@openzeppelin/contracts-upgradeable/security/ReentrancyGuardUpgradeable.sol"; // todo: should we have an arbitrary call function for the vault manager to do stuff with the subvault? like queue withdrawals etc /// @notice MasterVault is an ERC4626 metavault that deposits assets to an admin defined subVault. -/// @dev If a subVault is not set, MasterVault shares entitle holders to a pro-rata share of the underlying held by the MasterVault. -/// If a subVault is set, MasterVault shares entitle holders to a pro-rata share of subVault shares held by the MasterVault. -/// On deposit to the MasterVault, if there is a subVault set, the assets are immediately deposited into the subVault. -/// On withdraw from the MasterVault, if there is a subVault set, a pro rata amount of subvault shares are redeemed. -/// On deposit and withdraw, if there is no subVault set, assets are moved to/from the MasterVault itself. +/// @dev The MasterVault keeps some fraction of assets idle and deposits the rest into the subVault to earn yield. +/// A 100% performance fee can be enabled/disabled by the vault manager, and are collected on demand. +/// The MasterVault mitigates the "first depositor" problem by adding 18 decimals to the underlying asset. +/// i.e. if the underlying asset has 6 decimals, the MasterVault will have 24 decimals. /// /// For a subVault to be compatible with the MasterVault, it must adhere to the following: -/// - It must be able to handle arbitrarily large deposits and withdrawals -/// - Deposit size or withdrawal size must not affect the exchange rate (i.e. no slippage) /// - convertToAssets and convertToShares must not be manipulable -contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradeable, PausableUpgradeable { +/// - must not have deposit / withdrawal fees (todo: verify this requirement is necessary) +contract MasterVault is Initializable, ReentrancyGuardUpgradeable, ERC4626Upgradeable, AccessControlUpgradeable, PausableUpgradeable { using SafeERC20 for IERC20; using MathUpgradeable for uint256; @@ -36,20 +35,25 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea /// @notice Pauser role can pause/unpause deposits and withdrawals (todo: pause should pause EVERYTHING) bytes32 public constant PAUSER_ROLE = keccak256("PAUSER_ROLE"); - error SubVaultAlreadySet(); + /// @notice Extra decimals added to the ERC20 decimals of the underlying asset to determine the decimals of the MasterVault + /// @dev This is done to mitigate the "first depositor" problem described in the OpenZeppelin ERC4626 documentation. + /// See https://docs.openzeppelin.com/contracts/5.x/erc4626 for more details on the mitigation. + uint8 public constant EXTRA_DECIMALS = 18; + error SubVaultAssetMismatch(); - error SubVaultExchangeRateTooLow(); - error NoExistingSubVault(); - error NewSubVaultExchangeRateTooLow(); error PerformanceFeeDisabled(); error BeneficiaryNotSet(); error InvalidAsset(); error InvalidOwner(); + error NonZeroTargetAllocation(uint256 targetAllocationWad); + error NonZeroSubVaultShares(uint256 subVaultShares); // todo: avoid inflation, rounding, other common 4626 vulns // we may need a minimum asset or master share amount when setting subvaults (bc of exchange rate calc) IERC4626 public subVault; + uint256 public targetAllocationWad; + /// @notice Flag indicating if performance fee is enabled bool public enablePerformanceFee; @@ -63,14 +67,16 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea event SubvaultChanged(address indexed oldSubvault, address indexed newSubvault); event PerformanceFeeToggled(bool enabled); event BeneficiaryUpdated(address indexed oldBeneficiary, address indexed newBeneficiary); - event PerformanceFeesWithdrawn(address indexed beneficiary, uint256 amount); - - function initialize(IERC20 _asset, string memory _name, string memory _symbol, address _owner) external initializer { - if (address(_asset) == address(0)) revert InvalidAsset(); - if (_owner == address(0)) revert InvalidOwner(); + event PerformanceFeesWithdrawn(address indexed beneficiary, uint256 amountTransferred, uint256 amountWithdrawn); + function initialize(IERC4626 _subVault, string memory _name, string memory _symbol, address _owner) external initializer { __ERC20_init(_name, _symbol); - __ERC4626_init(IERC20Upgradeable(address(_asset))); + __ERC4626_init(IERC20Upgradeable(_subVault.asset())); + + // call decimals() to ensure underlying has reasonable decimals and we won't have overflow + decimals(); + + __ReentrancyGuard_init(); __AccessControl_init(); __Pausable_init(); @@ -81,89 +87,77 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea _grantRole(VAULT_MANAGER_ROLE, _owner); _grantRole(PAUSER_ROLE, _owner); - _pause(); - } + // mint some dead shares to avoid first depositor issues + // for more information on the mitigation: + // https://web.archive.org/web/20250609034056/https://docs.openzeppelin.com/contracts/4.x/erc4626#fees + _mint(address(1), 10 ** EXTRA_DECIMALS); - function distributePerformanceFee() external whenNotPaused { - if (!enablePerformanceFee) revert PerformanceFeeDisabled(); - if (beneficiary == address(0)) { - revert BeneficiaryNotSet(); - } + IERC20(asset()).safeApprove(address(_subVault), type(uint256).max); - uint256 profit = totalProfit(MathUpgradeable.Rounding.Down); - if (profit == 0) return; + subVault = _subVault; + } - if (address(subVault) != address(0)) { - subVault.redeem(totalProfitInSubVaultShares(MathUpgradeable.Rounding.Down), beneficiary, address(this)); - } else { - IERC20(asset()).safeTransfer(beneficiary, profit); - } + function rebalance() external whenNotPaused nonReentrant { + _rebalance(); + } - emit PerformanceFeesWithdrawn(beneficiary, profit); + function distributePerformanceFee() external whenNotPaused nonReentrant { + _distributePerformanceFee(); } - /// @notice Set a subvault. Can only be called if there is not already a subvault set. + /// @notice Set a new subvault /// @param _subVault The subvault to set. Must be an ERC4626 vault with the same asset as this MasterVault. - /// @param minSubVaultExchRateWad Minimum acceptable ratio (times 1e18) of new subvault shares to outstanding MasterVault shares after deposit. - function setSubVault(IERC4626 _subVault, uint256 minSubVaultExchRateWad) external onlyRole(VAULT_MANAGER_ROLE) { + function setSubVault(IERC4626 _subVault) external whenNotPaused nonReentrant onlyRole(VAULT_MANAGER_ROLE) { IERC20 underlyingAsset = IERC20(asset()); - if (address(subVault) != address(0)) revert SubVaultAlreadySet(); if (address(_subVault.asset()) != address(underlyingAsset)) revert SubVaultAssetMismatch(); - subVault = _subVault; + // we ensure target allocation is zero, therefore the master vault holds no subvault shares + if (targetAllocationWad != 0) revert NonZeroTargetAllocation(targetAllocationWad); - IERC20(asset()).safeApprove(address(_subVault), type(uint256).max); - _subVault.deposit(underlyingAsset.balanceOf(address(this)), address(this)); - - uint256 supply = totalSupply(); - if (supply > 0) { - uint256 subVaultExchRateWad = _subVault.balanceOf(address(this)).mulDiv(1e18, supply, MathUpgradeable.Rounding.Down); - if (subVaultExchRateWad < minSubVaultExchRateWad) revert NewSubVaultExchangeRateTooLow(); + // sanity check to ensure we have zero subvault shares before changing + if (subVault.balanceOf(address(this)) != 0) { + revert NonZeroSubVaultShares(subVault.balanceOf(address(this))); } - emit SubvaultChanged(address(0), address(_subVault)); - } - - /// @notice Revokes the current subvault, moving all assets back to MasterVault - /// @param minAssetExchRateWad Minimum acceptable ratio (times 1e18) of assets received from subvault to outstanding MasterVault shares - function revokeSubVault(uint256 minAssetExchRateWad) external onlyRole(VAULT_MANAGER_ROLE) { - IERC4626 oldSubVault = subVault; - if (address(oldSubVault) == address(0)) revert NoExistingSubVault(); - - subVault = IERC4626(address(0)); + address oldSubVault = address(subVault); + subVault = _subVault; - oldSubVault.redeem(oldSubVault.balanceOf(address(this)), address(this), address(this)); - IERC20(asset()).safeApprove(address(oldSubVault), 0); + if (oldSubVault != address(0)) IERC20(asset()).safeApprove(address(oldSubVault), 0); + IERC20(asset()).safeApprove(address(_subVault), type(uint256).max); - uint256 supply = totalSupply(); - if (supply > 0) { - uint256 assetExchRateWad = IERC20(asset()).balanceOf(address(this)).mulDiv(1e18, supply, MathUpgradeable.Rounding.Down); - if (assetExchRateWad < minAssetExchRateWad) revert SubVaultExchangeRateTooLow(); - } + emit SubvaultChanged(oldSubVault, address(_subVault)); + } - emit SubvaultChanged(address(oldSubVault), address(0)); + function setTargetAllocationWad(uint256 _targetAllocationWad) external whenNotPaused nonReentrant onlyRole(VAULT_MANAGER_ROLE) { + require(_targetAllocationWad <= 1e18, "Target allocation must be <= 100%"); + require(targetAllocationWad != _targetAllocationWad, "Allocation unchanged"); + targetAllocationWad = _targetAllocationWad; + _rebalance(); } /// @notice Toggle performance fee collection on/off + /// @dev Not explicitly marked nonReentrant because distributePerformanceFee is called within + /// this function and is nonReentrant itself. /// @param enabled True to enable performance fees, false to disable - function setPerformanceFee(bool enabled) external onlyRole(VAULT_MANAGER_ROLE) { - enablePerformanceFee = enabled; - + function setPerformanceFee(bool enabled) external whenNotPaused nonReentrant onlyRole(VAULT_MANAGER_ROLE) { // reset totalPrincipal to current totalAssets when enabling performance fee // this prevents a sudden large profit if (enabled) { totalPrincipal = _totalAssets(MathUpgradeable.Rounding.Up); } else { + _distributePerformanceFee(); totalPrincipal = 0; } + enablePerformanceFee = enabled; + emit PerformanceFeeToggled(enabled); } /// @notice Set the beneficiary address for performance fees - /// @param newBeneficiary Address to receive performance fees, zero address defaults to owner - function setBeneficiary(address newBeneficiary) external onlyRole(VAULT_MANAGER_ROLE) { + /// @param newBeneficiary Address to receive performance fees + function setBeneficiary(address newBeneficiary) external whenNotPaused nonReentrant onlyRole(VAULT_MANAGER_ROLE) { address oldBeneficiary = beneficiary; beneficiary = newBeneficiary; emit BeneficiaryUpdated(oldBeneficiary, newBeneficiary); @@ -176,6 +170,11 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea function unpause() external onlyRole(PAUSER_ROLE) { _unpause(); } + + /// @dev Overridden to add EXTRA_DECIMALS to the underlying asset decimals + function decimals() public view override returns (uint8) { + return super.decimals() + EXTRA_DECIMALS; + } /** @dev See {IERC4626-totalAssets}. */ function totalAssets() public view virtual override returns (uint256) { @@ -192,14 +191,12 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea // /** @dev See {IERC4626-maxMint}. */ function maxMint(address) public view virtual override returns (uint256) { - if (address(subVault) == address(0)) { - return type(uint256).max; - } uint256 subShares = subVault.maxMint(address(this)); if (subShares == type(uint256).max) { return type(uint256).max; } - return totalSupply().mulDiv(subShares, subVault.balanceOf(address(this)), MathUpgradeable.Rounding.Down); // todo: check rounding direction + uint256 assets = _subVaultSharesToAssets(subShares, MathUpgradeable.Rounding.Down); + return _convertToShares(assets, MathUpgradeable.Rounding.Down); } function totalProfit(MathUpgradeable.Rounding rounding) public view returns (uint256) { @@ -207,15 +204,53 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea return __totalAssets > totalPrincipal ? __totalAssets - totalPrincipal : 0; } - function totalProfitInSubVaultShares(MathUpgradeable.Rounding rounding) public view returns (uint256) { - if (address(subVault) == address(0)) { - revert("Subvault not set"); + function _rebalance() internal { + uint256 totalAssetsUp = _totalAssetsLessProfit(MathUpgradeable.Rounding.Up); + uint256 totalAssetsDown = _totalAssetsLessProfit(MathUpgradeable.Rounding.Down); + uint256 idleTargetUp = totalAssetsUp.mulDiv(1e18 - targetAllocationWad, 1e18, MathUpgradeable.Rounding.Up); + uint256 idleTargetDown = totalAssetsDown.mulDiv(1e18 - targetAllocationWad, 1e18, MathUpgradeable.Rounding.Down); + uint256 idleBalance = IERC20(asset()).balanceOf(address(this)); + + if (idleTargetDown <= idleBalance && idleBalance <= idleTargetUp) { + return; + } + + if (idleBalance < idleTargetDown) { + // we need to withdraw from subvault + uint256 assetsToWithdraw = idleTargetDown - idleBalance; + subVault.withdraw(assetsToWithdraw, address(this), address(this)); + } + else { + // we need to deposit into subvault + uint256 assetsToDeposit = idleBalance - idleTargetUp; + subVault.deposit(assetsToDeposit, address(this)); + } + } + + function _distributePerformanceFee() internal { + if (!enablePerformanceFee) revert PerformanceFeeDisabled(); + if (beneficiary == address(0)) { + revert BeneficiaryNotSet(); + } + + uint256 profit = totalProfit(MathUpgradeable.Rounding.Down); + if (profit == 0) return; + + uint256 totalIdle = IERC20(asset()).balanceOf(address(this)); + + uint256 amountToTransfer = profit <= totalIdle ? profit : totalIdle; + uint256 amountToWithdraw = profit - amountToTransfer; + + if (amountToTransfer > 0) { + IERC20(asset()).safeTransfer(beneficiary, amountToTransfer); } - uint256 profitAssets = totalProfit(rounding); - if (profitAssets == 0) { - return 0; + if (amountToWithdraw > 0) { + subVault.withdraw(amountToWithdraw, beneficiary, address(this)); } - return _assetsToSubVaultShares(profitAssets, rounding); + + _rebalance(); + + emit PerformanceFeesWithdrawn(beneficiary, amountToTransfer, amountToWithdraw); } /** @@ -226,15 +261,10 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea address receiver, uint256 assets, uint256 shares - ) internal virtual override whenNotPaused { + ) internal override whenNotPaused nonReentrant { super._deposit(caller, receiver, assets, shares); - if (enablePerformanceFee) totalPrincipal += assets; - - IERC4626 _subVault = subVault; - if (address(_subVault) != address(0)) { - _subVault.deposit(assets, address(this)); - } + _rebalance(); } /** @@ -246,22 +276,19 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea address _owner, uint256 assets, uint256 shares - ) internal virtual override whenNotPaused { + ) internal override whenNotPaused nonReentrant { if (enablePerformanceFee) totalPrincipal -= assets; - - IERC4626 _subVault = subVault; - if (address(_subVault) != address(0)) { - _subVault.withdraw(assets, address(this), address(this)); + uint256 idleAssets = IERC20(asset()).balanceOf(address(this)); + if (idleAssets < assets) { + uint256 assetsToWithdraw = assets - idleAssets; + subVault.withdraw(assetsToWithdraw, address(this), address(this)); } - super._withdraw(caller, receiver, _owner, assets, shares); + _rebalance(); } function _totalAssets(MathUpgradeable.Rounding rounding) internal view returns (uint256) { - if (address(subVault) == address(0)) { - return IERC20(asset()).balanceOf(address(this)); - } - return _subVaultSharesToAssets(subVault.balanceOf(address(this)), rounding); + return IERC20(asset()).balanceOf(address(this)) + _subVaultSharesToAssets(subVault.balanceOf(address(this)), rounding); } /** @@ -271,77 +298,32 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea * would represent an infinite amount of shares. */ function _convertToShares(uint256 assets, MathUpgradeable.Rounding rounding) internal view virtual override returns (uint256 shares) { - uint256 supply = totalSupply(); - - if (address(subVault) == address(0)) { - uint256 effectiveTotalAssets = enablePerformanceFee ? _min(totalAssets(), totalPrincipal) : totalAssets(); - - if (supply == 0 || effectiveTotalAssets == 0) { - return assets; - } - - return supply.mulDiv(assets, effectiveTotalAssets, rounding); - } - - uint256 totalSubShares = subVault.balanceOf(address(this)); - - if (enablePerformanceFee) { - // since we use totalSubShares in the denominator of the final calculation, - // and we are subtracting profit from it, we should use the same rounding direction for profit - totalSubShares -= totalProfitInSubVaultShares(_flipRounding(rounding)); - } - - uint256 subShares = _assetsToSubVaultShares(assets, rounding); - - if (supply == 0 || totalSubShares == 0) { - return subShares; - } - - return supply.mulDiv(subShares, totalSubShares, rounding); + // we add one as part of the first deposit mitigation + // see for details: https://docs.openzeppelin.com/contracts/5.x/erc4626 + return assets.mulDiv(totalSupply(), _totalAssetsLessProfit(_flipRounding(rounding)) + 1, rounding); } /** * @dev Internal conversion function (from shares to assets) with support for rounding direction. */ function _convertToAssets(uint256 shares, MathUpgradeable.Rounding rounding) internal view virtual override returns (uint256 assets) { - uint256 _totalSupply = totalSupply(); - - if(_totalSupply == 0) { - return shares; - } - - // if we have no subvault, we just do normal pro-rata calculation - if (address(subVault) == address(0)) { - uint256 effectiveTotalAssets = enablePerformanceFee ? _min(totalAssets(), totalPrincipal) : totalAssets(); - return effectiveTotalAssets.mulDiv(shares, _totalSupply, rounding); - } - - uint256 totalSubShares = subVault.balanceOf(address(this)); + // we add one as part of the first deposit mitigation + // see for details: https://docs.openzeppelin.com/contracts/5.x/erc4626 + return shares.mulDiv(_totalAssetsLessProfit(rounding) + 1, totalSupply(), rounding); + } + function _totalAssetsLessProfit(MathUpgradeable.Rounding rounding) internal view returns (uint256) { + uint256 __totalAssets = _totalAssets(rounding); if (enablePerformanceFee) { - // since we use totalSubShares in the numerator of the final calculation, - // and we are subtracting profit from it, we should use the opposite rounding direction for profit - totalSubShares -= totalProfitInSubVaultShares(_flipRounding(rounding)); + __totalAssets -= totalProfit(_flipRounding(rounding)); } - - // totalSubShares * shares / totalMasterShares - uint256 subShares = totalSubShares.mulDiv(shares, _totalSupply, rounding); - - return _subVaultSharesToAssets(subShares, rounding); - } - - function _assetsToSubVaultShares(uint256 assets, MathUpgradeable.Rounding rounding) internal view returns (uint256 subShares) { - return rounding == MathUpgradeable.Rounding.Up ? subVault.previewWithdraw(assets) : subVault.previewDeposit(assets); + return __totalAssets; } function _subVaultSharesToAssets(uint256 subShares, MathUpgradeable.Rounding rounding) internal view returns (uint256 assets) { return rounding == MathUpgradeable.Rounding.Up ? subVault.previewMint(subShares) : subVault.previewRedeem(subShares); } - function _min(uint256 a, uint256 b) internal pure returns (uint256) { - return a <= b ? a : b; - } - function _flipRounding(MathUpgradeable.Rounding rounding) internal pure returns (MathUpgradeable.Rounding) { return rounding == MathUpgradeable.Rounding.Up ? MathUpgradeable.Rounding.Down : MathUpgradeable.Rounding.Up; } diff --git a/contracts/tokenbridge/libraries/vault/MasterVaultFactory.sol b/contracts/tokenbridge/libraries/vault/MasterVaultFactory.sol index 0259c84d4..87e63c1ae 100644 --- a/contracts/tokenbridge/libraries/vault/MasterVaultFactory.sol +++ b/contracts/tokenbridge/libraries/vault/MasterVaultFactory.sol @@ -2,14 +2,19 @@ pragma solidity ^0.8.0; -import "@openzeppelin/contracts/utils/Create2.sol"; import "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol"; import "@openzeppelin/contracts/proxy/beacon/UpgradeableBeacon.sol"; +import "@openzeppelin/contracts/token/ERC20/extensions/ERC4626.sol"; import "../ClonableBeaconProxy.sol"; import "./IMasterVault.sol"; import "./IMasterVaultFactory.sol"; import "./MasterVault.sol"; +contract DefaultSubVault is ERC4626 { + constructor(address token) ERC4626(IERC20(token)) ERC20("Default SubVault", "DSV") {} +} + +// todo: slim down this contract contract MasterVaultFactory is IMasterVaultFactory, Initializable { error ZeroAddress(); error BeaconNotDeployed(); @@ -27,23 +32,13 @@ contract MasterVaultFactory is IMasterVaultFactory, Initializable { } function deployVault(address token) public returns (address vault) { - if (token == address(0)) { - revert ZeroAddress(); - } - if ( - address(beaconProxyFactory) == address(0) && beaconProxyFactory.beacon() == address(0) - ) { - revert BeaconNotDeployed(); - } - bytes32 userSalt = _getUserSalt(token); vault = beaconProxyFactory.createProxy(userSalt); - IERC20Metadata tokenMetadata = IERC20Metadata(token); - string memory name = string(abi.encodePacked("Master ", tokenMetadata.name())); - string memory symbol = string(abi.encodePacked("m", tokenMetadata.symbol())); + string memory name = string(abi.encodePacked("Master ", _tryGetTokenName(token))); + string memory symbol = string(abi.encodePacked("m", _tryGetTokenSymbol(token))); - MasterVault(vault).initialize(IERC20(token), name, symbol, owner); + MasterVault(vault).initialize(new DefaultSubVault(token), name, symbol, owner); emit VaultDeployed(token, vault); } @@ -64,4 +59,20 @@ contract MasterVaultFactory is IMasterVaultFactory, Initializable { } return vault; } + + function _tryGetTokenName(address token) internal view returns (string memory) { + try IERC20Metadata(token).name() returns (string memory name) { + return name; + } catch { + return ""; + } + } + + function _tryGetTokenSymbol(address token) internal view returns (string memory) { + try IERC20Metadata(token).symbol() returns (string memory symbol) { + return symbol; + } catch { + return ""; + } + } } diff --git a/test-foundry/libraries/vault/MasterVault.t.sol b/test-foundry/libraries/vault/MasterVault.t.sol index 42563d9c2..639289253 100644 --- a/test-foundry/libraries/vault/MasterVault.t.sol +++ b/test-foundry/libraries/vault/MasterVault.t.sol @@ -5,145 +5,176 @@ import { MasterVaultCoreTest } from "./MasterVaultCore.t.sol"; import { MockSubVault } from "../../../contracts/tokenbridge/test/MockSubVault.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import { IERC4626 } from "@openzeppelin/contracts/interfaces/IERC4626.sol"; +import { Math } from "@openzeppelin/contracts/utils/math/Math.sol"; +import {console2} from "forge-std/console2.sol"; + +contract MasterVaultFirstDepositTest is MasterVaultCoreTest { + using Math for uint256; + + uint256 constant FRESH_STATE_PLACEHOLDER = uint256(keccak256("FRESH_STATE_PLACEHOLDER")); + uint256 constant DEAD_SHARES = 10**18; + + struct State { + uint256 userShares; + uint256 masterVaultTotalAssets; + uint256 masterVaultTotalSupply; + uint256 masterVaultTokenBalance; + uint256 masterVaultSubVaultShareBalance; + uint256 subVaultTotalAssets; + uint256 subVaultTotalSupply; + uint256 subVaultTokenBalance; + } -contract MasterVaultTest is MasterVaultCoreTest { // first deposit - function test_deposit() public { - address _assetsHoldingVault = address(vault.subVault()) == address(0) - ? address(vault) - : address(vault.subVault()); - uint256 _assetsHoldingVaultBalanceBefore = token.balanceOf(_assetsHoldingVault); - + function test_deposit(uint96 _depositAmount) public { + uint256 depositAmount = _depositAmount; vm.startPrank(user); - token.mint(); - uint256 depositAmount = 100; - + token.mint(depositAmount); token.approve(address(vault), depositAmount); - uint256 shares = vault.deposit(depositAmount, user); - - uint256 _assetsHoldingVaultBalanceAfter = token.balanceOf(_assetsHoldingVault); - uint256 diff = _assetsHoldingVaultBalanceAfter - _assetsHoldingVaultBalanceBefore; - - assertEq(vault.balanceOf(user), shares, "User should receive shares"); - assertEq(vault.totalAssets(), depositAmount, "Vault should hold deposited assets"); - assertEq(vault.totalSupply(), shares, "Total supply should equal shares minted"); - - assertEq(diff, depositAmount, "Vault should increase holding of assets"); - assertGt(token.balanceOf(_assetsHoldingVault), 0, "Vault should hold the tokens"); - - assertEq(vault.totalSupply(), diff, "First deposit should be at a rate of 1"); - vm.stopPrank(); + _checkState(State({ + userShares: depositAmount * DEAD_SHARES, + masterVaultTotalAssets: depositAmount, + masterVaultTotalSupply: (1 + depositAmount) * DEAD_SHARES, + masterVaultTokenBalance: depositAmount, + masterVaultSubVaultShareBalance: 0, + subVaultTotalAssets: 0, + subVaultTotalSupply: 0, + subVaultTokenBalance: 0 + })); + assertEq(shares, depositAmount * DEAD_SHARES, "shares mismatch deposit return value"); } - // first mint - function test_mint() public { - address _assetsHoldingVault = address(vault.subVault()) == address(0) - ? address(vault) - : address(vault.subVault()); - - uint256 _assetsHoldingVaultBalanceBefore = token.balanceOf(_assetsHoldingVault); - + function test_mint(uint96 _mintAmount) public { + uint256 mintAmount = _mintAmount; vm.startPrank(user); - token.mint(); - uint256 sharesToMint = 100; - - token.approve(address(vault), type(uint256).max); - - // assertEq(1, vault.totalAssets(), "First mint should be at a rate of 1"); // 0 - // assertEq(1, vault.totalSupply(), "First mint should be at a rate of 1"); // 0 - - - uint256 assetsCost = vault.mint(sharesToMint, user); - - uint256 _assetsHoldingVaultBalanceAfter = token.balanceOf(_assetsHoldingVault); - - assertEq(vault.balanceOf(user), sharesToMint, "User should receive requested shares"); - assertEq(vault.totalSupply(), sharesToMint, "Total supply should equal shares minted"); - assertEq(vault.totalAssets(), assetsCost, "Vault should hold the assets deposited"); - assertEq( - _assetsHoldingVaultBalanceAfter - _assetsHoldingVaultBalanceBefore, - assetsCost, - "Vault should hold the tokens" - ); - - assertEq(vault.totalSupply(), vault.totalAssets(), "First mint should be at a rate of 1"); + token.mint(mintAmount); + token.approve(address(vault), mintAmount); + uint256 assets = vault.mint(mintAmount, user); vm.stopPrank(); + _checkState(State({ + userShares: mintAmount, + masterVaultTotalAssets: mintAmount.ceilDiv(1e18), + masterVaultTotalSupply: mintAmount + DEAD_SHARES, + masterVaultTokenBalance: mintAmount.ceilDiv(1e18), + masterVaultSubVaultShareBalance: 0, + subVaultTotalAssets: 0, + subVaultTotalSupply: 0, + subVaultTokenBalance: 0 + })); + assertEq(assets, mintAmount.ceilDiv(1e18), "assets mismatch mint return value"); } - function test_withdraw() public { + function test_withdraw(uint96 _firstDeposit, uint96 _withdrawAmount) public { + uint256 firstDeposit = _firstDeposit; + uint256 withdrawAmount = _withdrawAmount; + vm.assume(withdrawAmount <= firstDeposit); + test_deposit(_firstDeposit); vm.startPrank(user); - token.mint(); - uint256 depositAmount = token.balanceOf(user); - token.approve(address(vault), depositAmount); - vault.deposit(depositAmount, user); - - uint256 userSharesBefore = vault.balanceOf(user); - uint256 withdrawAmount = depositAmount; // withdraw all assets - uint256 sharesRedeemed = vault.withdraw(withdrawAmount, user, user); - - assertEq(vault.balanceOf(user), 0, "User should have no shares left"); - assertEq(token.balanceOf(user), depositAmount, "User should receive all withdrawn tokens"); - assertEq(vault.totalAssets(), 0, "Vault should have no assets left"); - assertEq(vault.totalSupply(), 0, "Total supply should be zero"); - assertEq(token.balanceOf(address(vault)), 0, "Vault should have no tokens left"); - assertEq(sharesRedeemed, userSharesBefore, "All shares should be redeemed"); - vm.stopPrank(); + _checkState(State({ + userShares: (firstDeposit - withdrawAmount) * DEAD_SHARES, + masterVaultTotalAssets: firstDeposit - withdrawAmount, + masterVaultTotalSupply: (1 + firstDeposit - withdrawAmount) * DEAD_SHARES, + masterVaultTokenBalance: firstDeposit - withdrawAmount, + masterVaultSubVaultShareBalance: 0, + subVaultTotalAssets: 0, + subVaultTotalSupply: 0, + subVaultTokenBalance: 0 + })); + assertEq(sharesRedeemed, withdrawAmount * DEAD_SHARES, "sharesRedeemed mismatch withdraw return value"); } - function test_redeem() public { + function test_redeem(uint96 _firstMint, uint96 _redeemAmount) public { + uint256 firstMint = _firstMint; + uint256 redeemAmount = _redeemAmount; + vm.assume(redeemAmount <= firstMint); + test_mint(_firstMint); + State memory beforeState = _getState(); vm.startPrank(user); - token.mint(); - uint256 depositAmount = token.balanceOf(user); - token.approve(address(vault), depositAmount); - uint256 shares = vault.deposit(depositAmount, user); - - uint256 sharesToRedeem = shares; // redeem all shares - - uint256 assetsReceived = vault.redeem(sharesToRedeem, user, user); - - assertEq(vault.balanceOf(user), 0, "User should have no shares left"); - assertEq(token.balanceOf(user), depositAmount, "User should receive all assets back"); - assertEq(vault.totalAssets(), 0, "Vault should have no assets left"); - assertEq(vault.totalSupply(), 0, "Total supply should be zero"); - assertEq(token.balanceOf(address(vault)), 0, "Vault should have no tokens left"); - assertEq(assetsReceived, depositAmount, "All assets should be received"); - + uint256 assets = vault.redeem(redeemAmount, user, user); + uint256 expectedAssets = (1 + beforeState.masterVaultTotalAssets) * redeemAmount / (beforeState.masterVaultTotalSupply); vm.stopPrank(); + _checkState(State({ + userShares: beforeState.userShares - redeemAmount, + masterVaultTotalAssets: beforeState.masterVaultTotalAssets - expectedAssets, + masterVaultTotalSupply: beforeState.masterVaultTotalSupply - redeemAmount, + masterVaultTokenBalance: beforeState.masterVaultTokenBalance - expectedAssets, + masterVaultSubVaultShareBalance: 0, + subVaultTotalAssets: 0, + subVaultTotalSupply: 0, + subVaultTokenBalance: 0 + })); + assertEq(assets, expectedAssets, "assets mismatch redeem return value"); } -} -contract MasterVaultTestWithSubvaultFresh is MasterVaultTest { - function setUp() public override { - super.setUp(); - MockSubVault _subvault = new MockSubVault(IERC20(address(token)), "TestSubvault", "TSV"); - vault.setSubVault(IERC4626(address(_subvault)), 0); + function _checkState(State memory expectedState) internal { + assertEq(expectedState.userShares, vault.balanceOf(user), "userShares mismatch"); + assertEq(expectedState.masterVaultTotalAssets, vault.totalAssets(), "masterVaultTotalAssets mismatch"); + assertEq(expectedState.masterVaultTotalSupply, vault.totalSupply(), "masterVaultTotalSupply mismatch"); + assertEq(expectedState.masterVaultTokenBalance, token.balanceOf(address(vault)), "masterVaultTokenBalance mismatch"); + assertEq(expectedState.masterVaultSubVaultShareBalance, vault.subVault().balanceOf(address(vault)), "masterVaultSubVaultShareBalance mismatch"); + assertEq(expectedState.subVaultTotalAssets, vault.subVault().totalAssets(), "subVaultTotalAssets mismatch"); + assertEq(expectedState.subVaultTotalSupply, vault.subVault().totalSupply(), "subVaultTotalSupply mismatch"); + assertEq(expectedState.subVaultTokenBalance, token.balanceOf(address(vault.subVault())), "subVaultTokenBalance mismatch"); } -} - -contract MasterVaultTestWithSubvaultHoldingAssets is MasterVaultTest { - function setUp() public override { - super.setUp(); - MockSubVault _subvault = new MockSubVault(IERC20(address(token)), "TestSubvault", "TSV"); - uint256 _initAmount = 97659743; - token.mint(_initAmount); - token.approve(address(_subvault), _initAmount); - _subvault.deposit(_initAmount, address(this)); - assertEq( - _initAmount, - _subvault.totalAssets(), - "subvault should be initiated with assets = _initAmount" - ); - assertEq( - _initAmount, - _subvault.totalSupply(), - "subvault should be initiated with shares = _initAmount" - ); + function _getState() internal view returns (State memory) { + return State({ + userShares: vault.balanceOf(user), + masterVaultTotalAssets: vault.totalAssets(), + masterVaultTotalSupply: vault.totalSupply(), + masterVaultTokenBalance: token.balanceOf(address(vault)), + masterVaultSubVaultShareBalance: vault.subVault().balanceOf(address(vault)), + subVaultTotalAssets: vault.subVault().totalAssets(), + subVaultTotalSupply: vault.subVault().totalSupply(), + subVaultTokenBalance: token.balanceOf(address(vault.subVault())) + }); + } - vault.setSubVault(IERC4626(address(_subvault)), 0); + function _logState(string memory label, State memory state) internal view { + console2.log(label); + console2.log(" userShares:", state.userShares); + console2.log(" masterVaultTotalAssets:", state.masterVaultTotalAssets); + console2.log(" masterVaultTotalSupply:", state.masterVaultTotalSupply); + console2.log(" masterVaultTokenBalance:", state.masterVaultTokenBalance); + console2.log(" masterVaultSubVaultShareBalance:", state.masterVaultSubVaultShareBalance); + console2.log(" subVaultTotalAssets:", state.subVaultTotalAssets); + console2.log(" subVaultTotalSupply:", state.subVaultTotalSupply); + console2.log(" subVaultTokenBalance:", state.subVaultTokenBalance); } } + +// contract MasterVaultTestWithSubvaultFresh is MasterVaultTest { +// function setUp() public override { +// super.setUp(); +// MockSubVault _subvault = new MockSubVault(IERC20(address(token)), "TestSubvault", "TSV"); +// vault.setSubVault(IERC4626(address(_subvault))); +// } +// } + +// contract MasterVaultTestWithSubvaultHoldingAssets is MasterVaultTest { +// function setUp() public override { +// super.setUp(); + +// MockSubVault _subvault = new MockSubVault(IERC20(address(token)), "TestSubvault", "TSV"); +// uint256 _initAmount = 97659743; +// token.mint(_initAmount); +// token.approve(address(_subvault), _initAmount); +// _subvault.deposit(_initAmount, address(this)); +// assertEq( +// _initAmount, +// _subvault.totalAssets(), +// "subvault should be initiated with assets = _initAmount" +// ); +// assertEq( +// _initAmount, +// _subvault.totalSupply(), +// "subvault should be initiated with shares = _initAmount" +// ); + +// vault.setSubVault(IERC4626(address(_subvault))); +// } +// } diff --git a/test-foundry/libraries/vault/MasterVaultAttack.t.sol b/test-foundry/libraries/vault/MasterVaultAttack.t.sol index 29d9c3109..03c768f23 100644 --- a/test-foundry/libraries/vault/MasterVaultAttack.t.sol +++ b/test-foundry/libraries/vault/MasterVaultAttack.t.sol @@ -2,16 +2,16 @@ pragma solidity ^0.8.0; import "forge-std/console2.sol"; -import { MasterVaultTest } from "./MasterVault.t.sol"; +import { MasterVaultCoreTest } from "./MasterVaultCore.t.sol"; import { MockSubVault } from "../../../contracts/tokenbridge/test/MockSubVault.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import { IERC4626 } from "@openzeppelin/contracts/interfaces/IERC4626.sol"; -contract MasterVaultTestWithSubvaultFresh is MasterVaultTest { +contract MasterVaultTestWithSubvaultFresh is MasterVaultCoreTest { function setUp() public override { super.setUp(); MockSubVault _subvault = new MockSubVault(IERC20(address(token)), "TestSubvault", "TSV"); - vault.setSubVault(IERC4626(address(_subvault)), 0); + vault.setSubVault(IERC4626(address(_subvault))); } } diff --git a/test-foundry/libraries/vault/MasterVaultCore.t.sol b/test-foundry/libraries/vault/MasterVaultCore.t.sol index 5fd882359..fc5549578 100644 --- a/test-foundry/libraries/vault/MasterVaultCore.t.sol +++ b/test-foundry/libraries/vault/MasterVaultCore.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import { Test } from "forge-std/Test.sol"; import { MasterVault } from "../../../contracts/tokenbridge/libraries/vault/MasterVault.sol"; +import { MasterVaultFactory } from "../../../contracts/tokenbridge/libraries/vault/MasterVaultFactory.sol"; import { TestERC20 } from "../../../contracts/tokenbridge/test/TestERC20.sol"; import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import { UpgradeableBeacon } from "@openzeppelin/contracts/proxy/beacon/UpgradeableBeacon.sol"; @@ -13,12 +14,11 @@ import { import { IAccessControl } from "@openzeppelin/contracts/access/IAccessControl.sol"; contract MasterVaultCoreTest is Test { + MasterVaultFactory public factory; MasterVault public vault; TestERC20 public token; - UpgradeableBeacon public beacon; - BeaconProxyFactory public beaconProxyFactory; - address public user = address(0x1); + address public user = vm.addr(1); string public name = "Master Test Token"; string public symbol = "mTST"; @@ -27,19 +27,9 @@ contract MasterVaultCoreTest is Test { } function setUp() public virtual { + factory = new MasterVaultFactory(); + factory.initialize(address(this)); token = new TestERC20(); - - MasterVault implementation = new MasterVault(); - beacon = new UpgradeableBeacon(address(implementation)); - - beaconProxyFactory = new BeaconProxyFactory(); - beaconProxyFactory.initialize(address(beacon)); - - bytes32 salt = keccak256("test"); - address proxyAddress = beaconProxyFactory.createProxy(salt); - vault = MasterVault(proxyAddress); - - vault.initialize(IERC20(address(token)), name, symbol, address(this)); - vault.unpause(); + vault = MasterVault(factory.deployVault(address(token))); } } diff --git a/test-foundry/libraries/vault/MasterVaultFactory.t.sol b/test-foundry/libraries/vault/MasterVaultFactory.t.sol index 48359c390..ec7e65fa4 100644 --- a/test-foundry/libraries/vault/MasterVaultFactory.t.sol +++ b/test-foundry/libraries/vault/MasterVaultFactory.t.sol @@ -45,7 +45,7 @@ contract MasterVaultFactoryTest is Test { } function test_deployVault_RevertZeroAddress() public { - vm.expectRevert(MasterVaultFactory.ZeroAddress.selector); + vm.expectRevert(); factory.deployVault(address(0)); } diff --git a/test-foundry/libraries/vault/MasterVaultFee.t.sol b/test-foundry/libraries/vault/MasterVaultFee.t.sol index 768e661e0..71d34aa88 100644 --- a/test-foundry/libraries/vault/MasterVaultFee.t.sol +++ b/test-foundry/libraries/vault/MasterVaultFee.t.sol @@ -19,10 +19,18 @@ contract MasterVaultFeeTest is MasterVaultCoreTest { assertTrue(vault.enablePerformanceFee(), "Performance fee should be enabled"); } - function test_setPerformanceFee_disable() public { + function test_cannotDisableWithoutBeneficiarySet() public { vault.setPerformanceFee(true); assertTrue(vault.enablePerformanceFee(), "Performance fee should be enabled"); + vm.expectRevert(MasterVault.BeneficiaryNotSet.selector); + vault.setPerformanceFee(false); + } + + function test_setPerformanceFee_disable() public { + vault.setPerformanceFee(true); + assertTrue(vault.enablePerformanceFee(), "Performance fee should be enabled"); + vault.setBeneficiary(beneficiaryAddress); vault.setPerformanceFee(false); assertFalse(vault.enablePerformanceFee(), "Performance fee should be disabled"); @@ -35,6 +43,8 @@ contract MasterVaultFeeTest is MasterVaultCoreTest { } function test_setPerformanceFee_emitsEvent() public { + vault.setBeneficiary(beneficiaryAddress); + vm.expectEmit(true, true, true, true); emit PerformanceFeeToggled(true); vault.setPerformanceFee(true); @@ -189,12 +199,10 @@ contract MasterVaultFeeTest is MasterVaultCoreTest { } function test_withdrawPerformanceFees_VaultDoubleInAssets() public { - vault.setPerformanceFee(true); vault.setBeneficiary(beneficiaryAddress); + vault.setPerformanceFee(true); - address _assetsHoldingVault = address(vault.subVault()) == address(0) - ? address(vault) - : address(vault.subVault()); + address _assetsHoldingVault = address(vault); // since allocation is 0 vm.startPrank(user); token.mint(); @@ -227,7 +235,7 @@ contract MasterVaultFeeTest is MasterVaultCoreTest { uint256 beneficiaryBalanceBefore = token.balanceOf(beneficiaryAddress); vm.expectEmit(true, true, true, true); - emit PerformanceFeesWithdrawn(beneficiaryAddress, depositAmount); + emit PerformanceFeesWithdrawn(beneficiaryAddress, depositAmount, 0); vault.distributePerformanceFee(); assertEq( @@ -244,14 +252,14 @@ contract MasterVaultFeeTest is MasterVaultCoreTest { event PerformanceFeeToggled(bool enabled); event BeneficiaryUpdated(address indexed oldBeneficiary, address indexed newBeneficiary); - event PerformanceFeesWithdrawn(address indexed beneficiary, uint256 amount); + event PerformanceFeesWithdrawn(address indexed beneficiary, uint256 amountTransferred, uint256 amountWithdrawn); } contract MasterVaultFeeTestWithSubvaultFresh is MasterVaultFeeTest { function setUp() public override { super.setUp(); MockSubVault _subvault = new MockSubVault(IERC20(address(token)), "TestSubvault", "TSV"); - vault.setSubVault(IERC4626(address(_subvault)), 0); + vault.setSubVault(IERC4626(address(_subvault))); } } @@ -275,6 +283,6 @@ contract MasterVaultFeeTestWithSubvaultHoldingAssets is MasterVaultFeeTest { "subvault should be initiated with shares = _initAmount" ); - vault.setSubVault(IERC4626(address(_subvault)), 0); + vault.setSubVault(IERC4626(address(_subvault))); } }