diff --git a/contracts/tokenbridge/libraries/vault/MasterVault.sol b/contracts/tokenbridge/libraries/vault/MasterVault.sol index 27c9dda09..9e0893e5c 100644 --- a/contracts/tokenbridge/libraries/vault/MasterVault.sol +++ b/contracts/tokenbridge/libraries/vault/MasterVault.sol @@ -12,21 +12,20 @@ import {Initializable} from "@openzeppelin/contracts-upgradeable/proxy/utils/Ini import {IERC20Metadata} from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol"; import {MathUpgradeable} from "@openzeppelin/contracts-upgradeable/utils/math/MathUpgradeable.sol"; import {SafeERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; +import {ReentrancyGuardUpgradeable} from "@openzeppelin/contracts-upgradeable/security/ReentrancyGuardUpgradeable.sol"; // todo: should we have an arbitrary call function for the vault manager to do stuff with the subvault? like queue withdrawals etc /// @notice MasterVault is an ERC4626 metavault that deposits assets to an admin defined subVault. -/// @dev If a subVault is not set, MasterVault shares entitle holders to a pro-rata share of the underlying held by the MasterVault. -/// If a subVault is set, MasterVault shares entitle holders to a pro-rata share of subVault shares held by the MasterVault. -/// On deposit to the MasterVault, if there is a subVault set, the assets are immediately deposited into the subVault. -/// On withdraw from the MasterVault, if there is a subVault set, a pro rata amount of subvault shares are redeemed. -/// On deposit and withdraw, if there is no subVault set, assets are moved to/from the MasterVault itself. +/// @dev The MasterVault keeps some fraction of assets idle and deposits the rest into the subVault to earn yield. +/// A 100% performance fee can be enabled/disabled by the vault manager, and are collected on demand. +/// The MasterVault mitigates the "first depositor" problem by adding 18 decimals to the underlying asset. +/// i.e. if the underlying asset has 6 decimals, the MasterVault will have 24 decimals. /// /// For a subVault to be compatible with the MasterVault, it must adhere to the following: -/// - It must be able to handle arbitrarily large deposits and withdrawals -/// - Deposit size or withdrawal size must not affect the exchange rate (i.e. no slippage) /// - convertToAssets and convertToShares must not be manipulable -contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradeable, PausableUpgradeable { +/// - must not have deposit / withdrawal fees (todo: verify this requirement is necessary) +contract MasterVault is Initializable, ReentrancyGuardUpgradeable, ERC4626Upgradeable, AccessControlUpgradeable, PausableUpgradeable { using SafeERC20 for IERC20; using MathUpgradeable for uint256; @@ -36,20 +35,25 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea /// @notice Pauser role can pause/unpause deposits and withdrawals (todo: pause should pause EVERYTHING) bytes32 public constant PAUSER_ROLE = keccak256("PAUSER_ROLE"); - error SubVaultAlreadySet(); + /// @notice Extra decimals added to the ERC20 decimals of the underlying asset to determine the decimals of the MasterVault + /// @dev This is done to mitigate the "first depositor" problem described in the OpenZeppelin ERC4626 documentation. + /// See https://docs.openzeppelin.com/contracts/5.x/erc4626 for more details on the mitigation. + uint8 public constant EXTRA_DECIMALS = 18; + error SubVaultAssetMismatch(); - error SubVaultExchangeRateTooLow(); - error NoExistingSubVault(); - error NewSubVaultExchangeRateTooLow(); error PerformanceFeeDisabled(); error BeneficiaryNotSet(); error InvalidAsset(); error InvalidOwner(); + error NonZeroTargetAllocation(uint256 targetAllocationWad); + error NonZeroSubVaultShares(uint256 subVaultShares); // todo: avoid inflation, rounding, other common 4626 vulns // we may need a minimum asset or master share amount when setting subvaults (bc of exchange rate calc) IERC4626 public subVault; + uint256 public targetAllocationWad; + /// @notice Flag indicating if performance fee is enabled bool public enablePerformanceFee; @@ -63,13 +67,16 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea event SubvaultChanged(address indexed oldSubvault, address indexed newSubvault); event PerformanceFeeToggled(bool enabled); event BeneficiaryUpdated(address indexed oldBeneficiary, address indexed newBeneficiary); + event PerformanceFeesWithdrawn(address indexed beneficiary, uint256 amountTransferred, uint256 amountWithdrawn); - function initialize(IERC20 _asset, string memory _name, string memory _symbol, address _owner) external initializer { - if (address(_asset) == address(0)) revert InvalidAsset(); - if (_owner == address(0)) revert InvalidOwner(); - + function initialize(IERC4626 _subVault, string memory _name, string memory _symbol, address _owner) external initializer { __ERC20_init(_name, _symbol); - __ERC4626_init(IERC20Upgradeable(address(_asset))); + __ERC4626_init(IERC20Upgradeable(_subVault.asset())); + + // call decimals() to ensure underlying has reasonable decimals and we won't have overflow + decimals(); + + __ReentrancyGuard_init(); __AccessControl_init(); __Pausable_init(); @@ -79,73 +86,78 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea _grantRole(DEFAULT_ADMIN_ROLE, _owner); _grantRole(VAULT_MANAGER_ROLE, _owner); _grantRole(PAUSER_ROLE, _owner); + + // mint some dead shares to avoid first depositor issues + // for more information on the mitigation: + // https://web.archive.org/web/20250609034056/https://docs.openzeppelin.com/contracts/4.x/erc4626#fees + _mint(address(1), 10 ** EXTRA_DECIMALS); + + IERC20(asset()).safeApprove(address(_subVault), type(uint256).max); + + subVault = _subVault; } - function distributePerformanceFee() external whenNotPaused { - if (!enablePerformanceFee) revert PerformanceFeeDisabled(); - if (beneficiary == address(0)) { - revert BeneficiaryNotSet(); - } - subVault.redeem(totalProfitInSubVaultShares(MathUpgradeable.Rounding.Down), beneficiary, address(this)); - // todo emit event + function rebalance() external whenNotPaused nonReentrant { + _rebalance(); + } + + function distributePerformanceFee() external whenNotPaused nonReentrant { + _distributePerformanceFee(); } - /// @notice Set a subvault. Can only be called if there is not already a subvault set. + /// @notice Set a new subvault /// @param _subVault The subvault to set. Must be an ERC4626 vault with the same asset as this MasterVault. - /// @param minSubVaultExchRateWad Minimum acceptable ratio (times 1e18) of new subvault shares to outstanding MasterVault shares after deposit. - function setSubVault(IERC4626 _subVault, uint256 minSubVaultExchRateWad) external onlyRole(VAULT_MANAGER_ROLE) { + function setSubVault(IERC4626 _subVault) external whenNotPaused nonReentrant onlyRole(VAULT_MANAGER_ROLE) { IERC20 underlyingAsset = IERC20(asset()); - if (address(subVault) != address(0)) revert SubVaultAlreadySet(); if (address(_subVault.asset()) != address(underlyingAsset)) revert SubVaultAssetMismatch(); + // we ensure target allocation is zero, therefore the master vault holds no subvault shares + if (targetAllocationWad != 0) revert NonZeroTargetAllocation(targetAllocationWad); + + // sanity check to ensure we have zero subvault shares before changing + if (subVault.balanceOf(address(this)) != 0) { + revert NonZeroSubVaultShares(subVault.balanceOf(address(this))); + } + + address oldSubVault = address(subVault); subVault = _subVault; + if (oldSubVault != address(0)) IERC20(asset()).safeApprove(address(oldSubVault), 0); IERC20(asset()).safeApprove(address(_subVault), type(uint256).max); - _subVault.deposit(underlyingAsset.balanceOf(address(this)), address(this)); - - uint256 subVaultExchRateWad = _subVault.balanceOf(address(this)).mulDiv(1e18, totalSupply(), MathUpgradeable.Rounding.Down); - if (subVaultExchRateWad < minSubVaultExchRateWad) revert NewSubVaultExchangeRateTooLow(); - emit SubvaultChanged(address(0), address(_subVault)); + emit SubvaultChanged(oldSubVault, address(_subVault)); } - /// @notice Revokes the current subvault, moving all assets back to MasterVault - /// @param minAssetExchRateWad Minimum acceptable ratio (times 1e18) of assets received from subvault to outstanding MasterVault shares - function revokeSubVault(uint256 minAssetExchRateWad) external onlyRole(VAULT_MANAGER_ROLE) { - IERC4626 oldSubVault = subVault; - if (address(oldSubVault) == address(0)) revert NoExistingSubVault(); - - subVault = IERC4626(address(0)); - - oldSubVault.redeem(oldSubVault.balanceOf(address(this)), address(this), address(this)); - IERC20(asset()).safeApprove(address(oldSubVault), 0); - - uint256 assetExchRateWad = IERC20(asset()).balanceOf(address(this)).mulDiv(1e18, totalSupply(), MathUpgradeable.Rounding.Down); - if (assetExchRateWad < minAssetExchRateWad) revert SubVaultExchangeRateTooLow(); - - emit SubvaultChanged(address(oldSubVault), address(0)); + function setTargetAllocationWad(uint256 _targetAllocationWad) external whenNotPaused nonReentrant onlyRole(VAULT_MANAGER_ROLE) { + require(_targetAllocationWad <= 1e18, "Target allocation must be <= 100%"); + require(targetAllocationWad != _targetAllocationWad, "Allocation unchanged"); + targetAllocationWad = _targetAllocationWad; + _rebalance(); } /// @notice Toggle performance fee collection on/off + /// @dev Not explicitly marked nonReentrant because distributePerformanceFee is called within + /// this function and is nonReentrant itself. /// @param enabled True to enable performance fees, false to disable - function setPerformanceFee(bool enabled) external onlyRole(VAULT_MANAGER_ROLE) { - enablePerformanceFee = enabled; - + function setPerformanceFee(bool enabled) external whenNotPaused nonReentrant onlyRole(VAULT_MANAGER_ROLE) { // reset totalPrincipal to current totalAssets when enabling performance fee // this prevents a sudden large profit if (enabled) { totalPrincipal = _totalAssets(MathUpgradeable.Rounding.Up); } else { + _distributePerformanceFee(); totalPrincipal = 0; } + enablePerformanceFee = enabled; + emit PerformanceFeeToggled(enabled); } /// @notice Set the beneficiary address for performance fees - /// @param newBeneficiary Address to receive performance fees, zero address defaults to owner - function setBeneficiary(address newBeneficiary) external onlyRole(VAULT_MANAGER_ROLE) { + /// @param newBeneficiary Address to receive performance fees + function setBeneficiary(address newBeneficiary) external whenNotPaused nonReentrant onlyRole(VAULT_MANAGER_ROLE) { address oldBeneficiary = beneficiary; beneficiary = newBeneficiary; emit BeneficiaryUpdated(oldBeneficiary, newBeneficiary); @@ -158,6 +170,11 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea function unpause() external onlyRole(PAUSER_ROLE) { _unpause(); } + + /// @dev Overridden to add EXTRA_DECIMALS to the underlying asset decimals + function decimals() public view override returns (uint8) { + return super.decimals() + EXTRA_DECIMALS; + } /** @dev See {IERC4626-totalAssets}. */ function totalAssets() public view virtual override returns (uint256) { @@ -174,14 +191,12 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea // /** @dev See {IERC4626-maxMint}. */ function maxMint(address) public view virtual override returns (uint256) { - if (address(subVault) == address(0)) { - return type(uint256).max; - } uint256 subShares = subVault.maxMint(address(this)); if (subShares == type(uint256).max) { return type(uint256).max; } - return totalSupply().mulDiv(subShares, subVault.balanceOf(address(this)), MathUpgradeable.Rounding.Down); // todo: check rounding direction + uint256 assets = _subVaultSharesToAssets(subShares, MathUpgradeable.Rounding.Down); + return _convertToShares(assets, MathUpgradeable.Rounding.Down); } function totalProfit(MathUpgradeable.Rounding rounding) public view returns (uint256) { @@ -189,15 +204,53 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea return __totalAssets > totalPrincipal ? __totalAssets - totalPrincipal : 0; } - function totalProfitInSubVaultShares(MathUpgradeable.Rounding rounding) public view returns (uint256) { - if (address(subVault) == address(0)) { - revert("Subvault not set"); + function _rebalance() internal { + uint256 totalAssetsUp = _totalAssetsLessProfit(MathUpgradeable.Rounding.Up); + uint256 totalAssetsDown = _totalAssetsLessProfit(MathUpgradeable.Rounding.Down); + uint256 idleTargetUp = totalAssetsUp.mulDiv(1e18 - targetAllocationWad, 1e18, MathUpgradeable.Rounding.Up); + uint256 idleTargetDown = totalAssetsDown.mulDiv(1e18 - targetAllocationWad, 1e18, MathUpgradeable.Rounding.Down); + uint256 idleBalance = IERC20(asset()).balanceOf(address(this)); + + if (idleTargetDown <= idleBalance && idleBalance <= idleTargetUp) { + return; } - uint256 profitAssets = totalProfit(rounding); - if (profitAssets == 0) { - return 0; + + if (idleBalance < idleTargetDown) { + // we need to withdraw from subvault + uint256 assetsToWithdraw = idleTargetDown - idleBalance; + subVault.withdraw(assetsToWithdraw, address(this), address(this)); + } + else { + // we need to deposit into subvault + uint256 assetsToDeposit = idleBalance - idleTargetUp; + subVault.deposit(assetsToDeposit, address(this)); } - return _assetsToSubVaultShares(profitAssets, rounding); + } + + function _distributePerformanceFee() internal { + if (!enablePerformanceFee) revert PerformanceFeeDisabled(); + if (beneficiary == address(0)) { + revert BeneficiaryNotSet(); + } + + uint256 profit = totalProfit(MathUpgradeable.Rounding.Down); + if (profit == 0) return; + + uint256 totalIdle = IERC20(asset()).balanceOf(address(this)); + + uint256 amountToTransfer = profit <= totalIdle ? profit : totalIdle; + uint256 amountToWithdraw = profit - amountToTransfer; + + if (amountToTransfer > 0) { + IERC20(asset()).safeTransfer(beneficiary, amountToTransfer); + } + if (amountToWithdraw > 0) { + subVault.withdraw(amountToWithdraw, beneficiary, address(this)); + } + + _rebalance(); + + emit PerformanceFeesWithdrawn(beneficiary, amountToTransfer, amountToWithdraw); } /** @@ -208,15 +261,10 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea address receiver, uint256 assets, uint256 shares - ) internal virtual override whenNotPaused { + ) internal override whenNotPaused nonReentrant { super._deposit(caller, receiver, assets, shares); - if (enablePerformanceFee) totalPrincipal += assets; - - IERC4626 _subVault = subVault; - if (address(_subVault) != address(0)) { - _subVault.deposit(assets, address(this)); - } + _rebalance(); } /** @@ -228,22 +276,19 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea address _owner, uint256 assets, uint256 shares - ) internal virtual override whenNotPaused { + ) internal override whenNotPaused nonReentrant { if (enablePerformanceFee) totalPrincipal -= assets; - - IERC4626 _subVault = subVault; - if (address(_subVault) != address(0)) { - _subVault.withdraw(assets, address(this), address(this)); + uint256 idleAssets = IERC20(asset()).balanceOf(address(this)); + if (idleAssets < assets) { + uint256 assetsToWithdraw = assets - idleAssets; + subVault.withdraw(assetsToWithdraw, address(this), address(this)); } - super._withdraw(caller, receiver, _owner, assets, shares); + _rebalance(); } function _totalAssets(MathUpgradeable.Rounding rounding) internal view returns (uint256) { - if (address(subVault) == address(0)) { - return IERC20(asset()).balanceOf(address(this)); - } - return _subVaultSharesToAssets(subVault.balanceOf(address(this)), rounding); + return IERC20(asset()).balanceOf(address(this)) + _subVaultSharesToAssets(subVault.balanceOf(address(this)), rounding); } /** @@ -253,60 +298,32 @@ contract MasterVault is Initializable, ERC4626Upgradeable, AccessControlUpgradea * would represent an infinite amount of shares. */ function _convertToShares(uint256 assets, MathUpgradeable.Rounding rounding) internal view virtual override returns (uint256 shares) { - if (address(subVault) == address(0)) { - uint256 effectiveTotalAssets = enablePerformanceFee ? _min(totalAssets(), totalPrincipal) : totalAssets(); - return totalSupply().mulDiv(assets, effectiveTotalAssets, rounding); - } - - uint256 totalSubShares = subVault.balanceOf(address(this)); - - if (enablePerformanceFee) { - // since we use totalSubShares in the denominator of the final calculation, - // and we are subtracting profit from it, we should use the same rounding direction for profit - totalSubShares -= totalProfitInSubVaultShares(_flipRounding(rounding)); - } - - uint256 subShares = _assetsToSubVaultShares(assets, rounding); - - return totalSupply().mulDiv(subShares, totalSubShares, rounding); + // we add one as part of the first deposit mitigation + // see for details: https://docs.openzeppelin.com/contracts/5.x/erc4626 + return assets.mulDiv(totalSupply(), _totalAssetsLessProfit(_flipRounding(rounding)) + 1, rounding); } /** * @dev Internal conversion function (from shares to assets) with support for rounding direction. */ function _convertToAssets(uint256 shares, MathUpgradeable.Rounding rounding) internal view virtual override returns (uint256 assets) { - // if we have no subvault, we just do normal pro-rata calculation - if (address(subVault) == address(0)) { - uint256 effectiveTotalAssets = enablePerformanceFee ? _min(totalAssets(), totalPrincipal) : totalAssets(); - return effectiveTotalAssets.mulDiv(shares, totalSupply(), rounding); - } - - uint256 totalSubShares = subVault.balanceOf(address(this)); + // we add one as part of the first deposit mitigation + // see for details: https://docs.openzeppelin.com/contracts/5.x/erc4626 + return shares.mulDiv(_totalAssetsLessProfit(rounding) + 1, totalSupply(), rounding); + } + function _totalAssetsLessProfit(MathUpgradeable.Rounding rounding) internal view returns (uint256) { + uint256 __totalAssets = _totalAssets(rounding); if (enablePerformanceFee) { - // since we use totalSubShares in the numerator of the final calculation, - // and we are subtracting profit from it, we should use the opposite rounding direction for profit - totalSubShares -= totalProfitInSubVaultShares(_flipRounding(rounding)); + __totalAssets -= totalProfit(_flipRounding(rounding)); } - - // totalSubShares * shares / totalMasterShares - uint256 subShares = totalSubShares.mulDiv(shares, totalSupply(), rounding); - - return _subVaultSharesToAssets(subShares, rounding); - } - - function _assetsToSubVaultShares(uint256 assets, MathUpgradeable.Rounding rounding) internal view returns (uint256 subShares) { - return rounding == MathUpgradeable.Rounding.Up ? subVault.previewWithdraw(assets) : subVault.previewDeposit(assets); + return __totalAssets; } function _subVaultSharesToAssets(uint256 subShares, MathUpgradeable.Rounding rounding) internal view returns (uint256 assets) { return rounding == MathUpgradeable.Rounding.Up ? subVault.previewMint(subShares) : subVault.previewRedeem(subShares); } - function _min(uint256 a, uint256 b) internal pure returns (uint256) { - return a <= b ? a : b; - } - function _flipRounding(MathUpgradeable.Rounding rounding) internal pure returns (MathUpgradeable.Rounding) { return rounding == MathUpgradeable.Rounding.Up ? MathUpgradeable.Rounding.Down : MathUpgradeable.Rounding.Up; } diff --git a/contracts/tokenbridge/libraries/vault/MasterVaultFactory.sol b/contracts/tokenbridge/libraries/vault/MasterVaultFactory.sol index 0259c84d4..87e63c1ae 100644 --- a/contracts/tokenbridge/libraries/vault/MasterVaultFactory.sol +++ b/contracts/tokenbridge/libraries/vault/MasterVaultFactory.sol @@ -2,14 +2,19 @@ pragma solidity ^0.8.0; -import "@openzeppelin/contracts/utils/Create2.sol"; import "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol"; import "@openzeppelin/contracts/proxy/beacon/UpgradeableBeacon.sol"; +import "@openzeppelin/contracts/token/ERC20/extensions/ERC4626.sol"; import "../ClonableBeaconProxy.sol"; import "./IMasterVault.sol"; import "./IMasterVaultFactory.sol"; import "./MasterVault.sol"; +contract DefaultSubVault is ERC4626 { + constructor(address token) ERC4626(IERC20(token)) ERC20("Default SubVault", "DSV") {} +} + +// todo: slim down this contract contract MasterVaultFactory is IMasterVaultFactory, Initializable { error ZeroAddress(); error BeaconNotDeployed(); @@ -27,23 +32,13 @@ contract MasterVaultFactory is IMasterVaultFactory, Initializable { } function deployVault(address token) public returns (address vault) { - if (token == address(0)) { - revert ZeroAddress(); - } - if ( - address(beaconProxyFactory) == address(0) && beaconProxyFactory.beacon() == address(0) - ) { - revert BeaconNotDeployed(); - } - bytes32 userSalt = _getUserSalt(token); vault = beaconProxyFactory.createProxy(userSalt); - IERC20Metadata tokenMetadata = IERC20Metadata(token); - string memory name = string(abi.encodePacked("Master ", tokenMetadata.name())); - string memory symbol = string(abi.encodePacked("m", tokenMetadata.symbol())); + string memory name = string(abi.encodePacked("Master ", _tryGetTokenName(token))); + string memory symbol = string(abi.encodePacked("m", _tryGetTokenSymbol(token))); - MasterVault(vault).initialize(IERC20(token), name, symbol, owner); + MasterVault(vault).initialize(new DefaultSubVault(token), name, symbol, owner); emit VaultDeployed(token, vault); } @@ -64,4 +59,20 @@ contract MasterVaultFactory is IMasterVaultFactory, Initializable { } return vault; } + + function _tryGetTokenName(address token) internal view returns (string memory) { + try IERC20Metadata(token).name() returns (string memory name) { + return name; + } catch { + return ""; + } + } + + function _tryGetTokenSymbol(address token) internal view returns (string memory) { + try IERC20Metadata(token).symbol() returns (string memory symbol) { + return symbol; + } catch { + return ""; + } + } } diff --git a/contracts/tokenbridge/test/MockSubVault.sol b/contracts/tokenbridge/test/MockSubVault.sol index 411edb61b..ea5adad49 100644 --- a/contracts/tokenbridge/test/MockSubVault.sol +++ b/contracts/tokenbridge/test/MockSubVault.sol @@ -15,4 +15,8 @@ contract MockSubVault is ERC4626 { function totalAssets() public view override returns (uint256) { return IERC20(asset()).balanceOf(address(this)); } + + function adminMint(address to, uint256 amount) external { + _mint(to, amount); + } } \ No newline at end of file diff --git a/contracts/tokenbridge/test/TestERC20.sol b/contracts/tokenbridge/test/TestERC20.sol index 71dd2c005..fc2741144 100644 --- a/contracts/tokenbridge/test/TestERC20.sol +++ b/contracts/tokenbridge/test/TestERC20.sol @@ -28,6 +28,10 @@ contract TestERC20 is aeERC20 { function mint() external { _mint(msg.sender, 50000000); } + + function mint(uint256 amount) external { + _mint(msg.sender, amount); + } } // test token code inspired from maker diff --git a/test-foundry/libraries/vault/MasterVault.t.sol b/test-foundry/libraries/vault/MasterVault.t.sol index 30ec03a51..639289253 100644 --- a/test-foundry/libraries/vault/MasterVault.t.sol +++ b/test-foundry/libraries/vault/MasterVault.t.sol @@ -1,470 +1,180 @@ -// // SPDX-License-Identifier: UNLICENSED -// pragma solidity ^0.8.0; - -// import { Test } from "forge-std/Test.sol"; -// import { MasterVault } from "../../../contracts/tokenbridge/libraries/vault/MasterVault.sol"; -// import { TestERC20 } from "../../../contracts/tokenbridge/test/TestERC20.sol"; -// import { MockSubVault } from "../../../contracts/tokenbridge/test/MockSubVault.sol"; -// import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; -// import { UpgradeableBeacon } from "@openzeppelin/contracts/proxy/beacon/UpgradeableBeacon.sol"; -// import { BeaconProxyFactory, ClonableBeaconProxy } from "../../../contracts/tokenbridge/libraries/ClonableBeaconProxy.sol"; -// import { IAccessControl } from "@openzeppelin/contracts/access/IAccessControl.sol"; - -// contract MasterVaultTest is Test { -// MasterVault public vault; -// TestERC20 public token; -// UpgradeableBeacon public beacon; -// BeaconProxyFactory public beaconProxyFactory; - -// event SubvaultChanged(address indexed oldSubvault, address indexed newSubvault); - -// address public user = address(0x1); -// string public name = "Master Test Token"; -// string public symbol = "mTST"; - -// function setUp() public { -// token = new TestERC20(); - -// MasterVault implementation = new MasterVault(); -// beacon = new UpgradeableBeacon(address(implementation)); - -// beaconProxyFactory = new BeaconProxyFactory(); -// beaconProxyFactory.initialize(address(beacon)); - -// bytes32 salt = keccak256("test"); -// address proxyAddress = beaconProxyFactory.createProxy(salt); -// vault = MasterVault(proxyAddress); - -// vault.initialize(IERC20(address(token)), name, symbol, address(this)); -// } - -// function test_initialize() public { -// assertEq(address(vault.asset()), address(token), "Invalid asset"); -// assertEq(vault.name(), name, "Invalid name"); -// assertEq(vault.symbol(), symbol, "Invalid symbol"); -// assertEq(vault.decimals(), token.decimals(), "Invalid decimals"); -// assertEq(vault.totalSupply(), 0, "Invalid initial supply"); -// assertEq(vault.totalAssets(), 0, "Invalid initial assets"); -// assertEq(address(vault.subVault()), address(0), "SubVault should be zero initially"); - -// assertTrue(vault.hasRole(vault.DEFAULT_ADMIN_ROLE(), address(this)), "Should have DEFAULT_ADMIN_ROLE"); -// assertTrue(vault.hasRole(vault.VAULT_MANAGER_ROLE(), address(this)), "Should have VAULT_MANAGER_ROLE"); -// assertTrue(vault.hasRole(vault.FEE_MANAGER_ROLE(), address(this)), "Should have FEE_MANAGER_ROLE"); -// } - -// function test_WithoutSubvault_deposit() public { -// assertEq(address(vault.subVault()), address(0), "SubVault should be zero initially"); - -// // user deposit 500 tokens to vault -// // by this test expec: -// //- user to receive 500 shares -// //- total shares supply to increase by 500 -// //- total assets to increase by 500 - -// uint256 minShares = 0; - -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); - -// token.approve(address(vault), depositAmount); - -// uint256 sharesBefore = vault.balanceOf(user); -// uint256 totalSupplyBefore = vault.totalSupply(); -// uint256 totalAssetsBefore = vault.totalAssets(); - -// uint256 shares = vault.deposit(depositAmount, user, minShares); - -// assertEq(vault.balanceOf(user), sharesBefore + shares, "Invalid user balance"); -// assertEq(vault.totalSupply(), totalSupplyBefore + shares, "Invalid total supply"); -// assertEq(vault.totalAssets(), totalAssetsBefore + depositAmount, "Invalid total assets"); -// assertEq(token.balanceOf(user), 0, "User tokens should be transferred"); - -// vm.stopPrank(); -// } - -// function test_deposit_RevertTooFewSharesReceived() public { -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// uint256 minShares = depositAmount * 2; // Unrealistic requirement - -// token.approve(address(vault), depositAmount); - -// vm.expectRevert(MasterVault.TooFewSharesReceived.selector); -// vault.deposit(depositAmount, user, minShares); - -// vm.stopPrank(); -// } - -// function test_setSubvault() public { -// MockSubVault subVault = new MockSubVault( -// IERC20(address(token)), -// "Sub Vault Token", -// "svTST" -// ); - -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// token.approve(address(vault), depositAmount); -// vault.deposit(depositAmount, user, 0); -// vm.stopPrank(); - -// assertEq(address(vault.subVault()), address(0), "SubVault should be zero initially"); -// assertEq(vault.totalAssets(), depositAmount, "Total assets should equal deposit"); - -// uint256 minSubVaultExchRateWad = 1e18; - -// vm.expectEmit(true, true, false, false); -// emit SubvaultChanged(address(0), address(subVault)); - -// vault.setSubVault(subVault, minSubVaultExchRateWad); - -// assertEq(address(vault.subVault()), address(subVault), "SubVault should be set"); -// assertEq(vault.totalAssets(), depositAmount, "Total assets should remain the same"); -// assertEq(subVault.balanceOf(address(vault)), depositAmount, "SubVault should have received assets"); -// } - -// function test_revokeSubvault() public { -// MockSubVault subVault = new MockSubVault( -// IERC20(address(token)), -// "Sub Vault Token", -// "svTST" -// ); - -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// token.approve(address(vault), depositAmount); -// vault.deposit(depositAmount, user, 0); -// vm.stopPrank(); - -// vault.setSubVault(subVault, 1e18); - -// assertEq(address(vault.subVault()), address(subVault), "SubVault should be set"); -// assertEq(subVault.balanceOf(address(vault)), depositAmount, "SubVault should have assets"); - -// uint256 minAssetExchRateWad = 1e18; - -// vm.expectEmit(true, true, false, false); -// emit SubvaultChanged(address(subVault), address(0)); - -// vault.revokeSubVault(minAssetExchRateWad); - -// assertEq(address(vault.subVault()), address(0), "SubVault should be revoked"); -// assertEq(vault.totalAssets(), depositAmount, "Total assets should remain the same"); -// assertEq(subVault.balanceOf(address(vault)), 0, "SubVault should have no assets"); -// assertEq(token.balanceOf(address(vault)), depositAmount, "MasterVault should have assets directly"); -// } - -// function test_WithoutSubvault_withdraw() public { -// uint256 maxSharesBurned = type(uint256).max; - -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// token.approve(address(vault), depositAmount); -// vault.deposit(depositAmount, user, 0); - -// uint256 withdrawAmount = depositAmount / 2; -// uint256 userSharesBefore = vault.balanceOf(user); -// uint256 totalSupplyBefore = vault.totalSupply(); -// uint256 totalAssetsBefore = vault.totalAssets(); - -// uint256 shares = vault.withdraw(withdrawAmount, user, user, maxSharesBurned); - -// assertEq(vault.balanceOf(user), userSharesBefore - shares, "User shares should decrease"); -// assertEq(vault.totalSupply(), totalSupplyBefore - shares, "Total supply should decrease"); -// assertEq(vault.totalAssets(), totalAssetsBefore - withdrawAmount, "Total assets should decrease"); -// assertEq(token.balanceOf(user), withdrawAmount, "User should receive withdrawn assets"); -// assertEq(token.balanceOf(address(vault)), depositAmount - withdrawAmount, "Vault should have remaining assets"); - -// vm.stopPrank(); -// } - -// function test_WithSubvault_withdraw() public { -// MockSubVault subVault = new MockSubVault( -// IERC20(address(token)), -// "Sub Vault Token", -// "svTST" -// ); - -// vm.startPrank(user); -// token.mint(); -// uint256 firstDepositAmount = token.balanceOf(user); -// token.approve(address(vault), firstDepositAmount); -// vault.deposit(firstDepositAmount, user, 0); -// vm.stopPrank(); - -// vault.setSubVault(subVault, 1e18); - -// uint256 withdrawAmount = firstDepositAmount / 2; -// uint256 maxSharesBurned = type(uint256).max; - -// vm.startPrank(user); -// uint256 userSharesBefore = vault.balanceOf(user); -// uint256 totalSupplyBefore = vault.totalSupply(); -// uint256 totalAssetsBefore = vault.totalAssets(); -// uint256 subVaultSharesBefore = subVault.balanceOf(address(vault)); - -// uint256 shares = vault.withdraw(withdrawAmount, user, user, maxSharesBurned); - -// assertEq(vault.balanceOf(user), userSharesBefore - shares, "User shares should decrease"); -// assertEq(vault.totalSupply(), totalSupplyBefore - shares, "Total supply should decrease"); -// assertEq(vault.totalAssets(), totalAssetsBefore - withdrawAmount, "Total assets should decrease"); -// assertEq(token.balanceOf(user), withdrawAmount, "User should receive withdrawn assets"); -// assertLt(subVault.balanceOf(address(vault)), subVaultSharesBefore, "SubVault shares should decrease"); - -// token.mint(); -// uint256 secondDepositAmount = token.balanceOf(user) - withdrawAmount; -// token.approve(address(vault), secondDepositAmount); -// vault.deposit(secondDepositAmount, user, 0); - -// vault.balanceOf(user); -// uint256 finalTotalAssets = vault.totalAssets(); -// subVault.balanceOf(address(vault)); - -// vault.withdraw(finalTotalAssets, user, user, type(uint256).max); - -// assertEq(vault.balanceOf(user), 0, "User should have no shares left"); -// assertEq(vault.totalSupply(), 0, "Total supply should be zero"); -// assertEq(vault.totalAssets(), 0, "Total assets should be zero"); -// assertEq(token.balanceOf(user), firstDepositAmount + secondDepositAmount, "User should have all original tokens"); -// assertEq(subVault.balanceOf(address(vault)), 0, "SubVault should have no shares left"); - -// vm.stopPrank(); -// } - -// function test_beaconUpgrade() public { -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// token.approve(address(vault), depositAmount); -// vault.deposit(depositAmount, user, 0); -// vm.stopPrank(); - -// address oldImplementation = beacon.implementation(); -// assertEq(oldImplementation, address(beacon.implementation()), "Should have initial implementation"); - -// MasterVault newImplementation = new MasterVault(); -// beacon.upgradeTo(address(newImplementation)); - -// assertEq(beacon.implementation(), address(newImplementation), "Beacon should point to new implementation"); -// assertTrue(beacon.implementation() != oldImplementation, "Implementation should have changed"); - -// assertEq(vault.name(), name, "Name should remain after upgrade"); +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.0; + +import { MasterVaultCoreTest } from "./MasterVaultCore.t.sol"; +import { MockSubVault } from "../../../contracts/tokenbridge/test/MockSubVault.sol"; +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { IERC4626 } from "@openzeppelin/contracts/interfaces/IERC4626.sol"; +import { Math } from "@openzeppelin/contracts/utils/math/Math.sol"; +import {console2} from "forge-std/console2.sol"; + +contract MasterVaultFirstDepositTest is MasterVaultCoreTest { + using Math for uint256; + + uint256 constant FRESH_STATE_PLACEHOLDER = uint256(keccak256("FRESH_STATE_PLACEHOLDER")); + uint256 constant DEAD_SHARES = 10**18; + + struct State { + uint256 userShares; + uint256 masterVaultTotalAssets; + uint256 masterVaultTotalSupply; + uint256 masterVaultTokenBalance; + uint256 masterVaultSubVaultShareBalance; + uint256 subVaultTotalAssets; + uint256 subVaultTotalSupply; + uint256 subVaultTokenBalance; + } + + // first deposit + function test_deposit(uint96 _depositAmount) public { + uint256 depositAmount = _depositAmount; + vm.startPrank(user); + token.mint(depositAmount); + token.approve(address(vault), depositAmount); + uint256 shares = vault.deposit(depositAmount, user); + vm.stopPrank(); + _checkState(State({ + userShares: depositAmount * DEAD_SHARES, + masterVaultTotalAssets: depositAmount, + masterVaultTotalSupply: (1 + depositAmount) * DEAD_SHARES, + masterVaultTokenBalance: depositAmount, + masterVaultSubVaultShareBalance: 0, + subVaultTotalAssets: 0, + subVaultTotalSupply: 0, + subVaultTokenBalance: 0 + })); + assertEq(shares, depositAmount * DEAD_SHARES, "shares mismatch deposit return value"); + } + + function test_mint(uint96 _mintAmount) public { + uint256 mintAmount = _mintAmount; + vm.startPrank(user); + token.mint(mintAmount); + token.approve(address(vault), mintAmount); + uint256 assets = vault.mint(mintAmount, user); + vm.stopPrank(); + _checkState(State({ + userShares: mintAmount, + masterVaultTotalAssets: mintAmount.ceilDiv(1e18), + masterVaultTotalSupply: mintAmount + DEAD_SHARES, + masterVaultTokenBalance: mintAmount.ceilDiv(1e18), + masterVaultSubVaultShareBalance: 0, + subVaultTotalAssets: 0, + subVaultTotalSupply: 0, + subVaultTokenBalance: 0 + })); + assertEq(assets, mintAmount.ceilDiv(1e18), "assets mismatch mint return value"); + } + + function test_withdraw(uint96 _firstDeposit, uint96 _withdrawAmount) public { + uint256 firstDeposit = _firstDeposit; + uint256 withdrawAmount = _withdrawAmount; + vm.assume(withdrawAmount <= firstDeposit); + test_deposit(_firstDeposit); + vm.startPrank(user); + uint256 sharesRedeemed = vault.withdraw(withdrawAmount, user, user); + vm.stopPrank(); + _checkState(State({ + userShares: (firstDeposit - withdrawAmount) * DEAD_SHARES, + masterVaultTotalAssets: firstDeposit - withdrawAmount, + masterVaultTotalSupply: (1 + firstDeposit - withdrawAmount) * DEAD_SHARES, + masterVaultTokenBalance: firstDeposit - withdrawAmount, + masterVaultSubVaultShareBalance: 0, + subVaultTotalAssets: 0, + subVaultTotalSupply: 0, + subVaultTokenBalance: 0 + })); + assertEq(sharesRedeemed, withdrawAmount * DEAD_SHARES, "sharesRedeemed mismatch withdraw return value"); + } + + function test_redeem(uint96 _firstMint, uint96 _redeemAmount) public { + uint256 firstMint = _firstMint; + uint256 redeemAmount = _redeemAmount; + vm.assume(redeemAmount <= firstMint); + test_mint(_firstMint); + State memory beforeState = _getState(); + vm.startPrank(user); + uint256 assets = vault.redeem(redeemAmount, user, user); + uint256 expectedAssets = (1 + beforeState.masterVaultTotalAssets) * redeemAmount / (beforeState.masterVaultTotalSupply); + vm.stopPrank(); + _checkState(State({ + userShares: beforeState.userShares - redeemAmount, + masterVaultTotalAssets: beforeState.masterVaultTotalAssets - expectedAssets, + masterVaultTotalSupply: beforeState.masterVaultTotalSupply - redeemAmount, + masterVaultTokenBalance: beforeState.masterVaultTokenBalance - expectedAssets, + masterVaultSubVaultShareBalance: 0, + subVaultTotalAssets: 0, + subVaultTotalSupply: 0, + subVaultTokenBalance: 0 + })); + assertEq(assets, expectedAssets, "assets mismatch redeem return value"); + } + + function _checkState(State memory expectedState) internal { + assertEq(expectedState.userShares, vault.balanceOf(user), "userShares mismatch"); + assertEq(expectedState.masterVaultTotalAssets, vault.totalAssets(), "masterVaultTotalAssets mismatch"); + assertEq(expectedState.masterVaultTotalSupply, vault.totalSupply(), "masterVaultTotalSupply mismatch"); + assertEq(expectedState.masterVaultTokenBalance, token.balanceOf(address(vault)), "masterVaultTokenBalance mismatch"); + assertEq(expectedState.masterVaultSubVaultShareBalance, vault.subVault().balanceOf(address(vault)), "masterVaultSubVaultShareBalance mismatch"); + assertEq(expectedState.subVaultTotalAssets, vault.subVault().totalAssets(), "subVaultTotalAssets mismatch"); + assertEq(expectedState.subVaultTotalSupply, vault.subVault().totalSupply(), "subVaultTotalSupply mismatch"); + assertEq(expectedState.subVaultTokenBalance, token.balanceOf(address(vault.subVault())), "subVaultTokenBalance mismatch"); + } + + function _getState() internal view returns (State memory) { + return State({ + userShares: vault.balanceOf(user), + masterVaultTotalAssets: vault.totalAssets(), + masterVaultTotalSupply: vault.totalSupply(), + masterVaultTokenBalance: token.balanceOf(address(vault)), + masterVaultSubVaultShareBalance: vault.subVault().balanceOf(address(vault)), + subVaultTotalAssets: vault.subVault().totalAssets(), + subVaultTotalSupply: vault.subVault().totalSupply(), + subVaultTokenBalance: token.balanceOf(address(vault.subVault())) + }); + } + + function _logState(string memory label, State memory state) internal view { + console2.log(label); + console2.log(" userShares:", state.userShares); + console2.log(" masterVaultTotalAssets:", state.masterVaultTotalAssets); + console2.log(" masterVaultTotalSupply:", state.masterVaultTotalSupply); + console2.log(" masterVaultTokenBalance:", state.masterVaultTokenBalance); + console2.log(" masterVaultSubVaultShareBalance:", state.masterVaultSubVaultShareBalance); + console2.log(" subVaultTotalAssets:", state.subVaultTotalAssets); + console2.log(" subVaultTotalSupply:", state.subVaultTotalSupply); + console2.log(" subVaultTokenBalance:", state.subVaultTokenBalance); + } +} + +// contract MasterVaultTestWithSubvaultFresh is MasterVaultTest { +// function setUp() public override { +// super.setUp(); +// MockSubVault _subvault = new MockSubVault(IERC20(address(token)), "TestSubvault", "TSV"); +// vault.setSubVault(IERC4626(address(_subvault))); // } +// } -// function test_setSubVault_revert_NotVaultManager() public { -// MockSubVault subVault = new MockSubVault( -// IERC20(address(token)), -// "Sub Vault Token", -// "svTST" +// contract MasterVaultTestWithSubvaultHoldingAssets is MasterVaultTest { +// function setUp() public override { +// super.setUp(); + +// MockSubVault _subvault = new MockSubVault(IERC20(address(token)), "TestSubvault", "TSV"); +// uint256 _initAmount = 97659743; +// token.mint(_initAmount); +// token.approve(address(_subvault), _initAmount); +// _subvault.deposit(_initAmount, address(this)); +// assertEq( +// _initAmount, +// _subvault.totalAssets(), +// "subvault should be initiated with assets = _initAmount" // ); - -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// token.approve(address(vault), depositAmount); -// vault.deposit(depositAmount, user, 0); - -// vm.expectRevert(); -// vault.setSubVault(subVault, 1e18); - -// vm.stopPrank(); -// } - -// function test_setBeneficiary_revert_NotFeeManager() public { -// address newBeneficiary = address(0x999); - -// vm.prank(user); -// vm.expectRevert(); -// vault.setBeneficiary(newBeneficiary); -// } - -// function test_withdrawPerformanceFees_revert_NotFeeManager() public { -// vm.prank(user); -// vm.expectRevert(); -// vault.withdrawPerformanceFees(); -// } - -// function test_roleAdmin() public { -// address vaultManager = address(0x1111); -// address feeManager = address(0x2222); - -// vault.grantRole(vault.VAULT_MANAGER_ROLE(), vaultManager); -// vault.grantRole(vault.FEE_MANAGER_ROLE(), feeManager); - -// assertTrue(vault.hasRole(vault.VAULT_MANAGER_ROLE(), vaultManager), "Should have VAULT_MANAGER_ROLE"); -// assertTrue(vault.hasRole(vault.FEE_MANAGER_ROLE(), feeManager), "Should have FEE_MANAGER_ROLE"); - -// vault.revokeRole(vault.VAULT_MANAGER_ROLE(), vaultManager); -// assertFalse(vault.hasRole(vault.VAULT_MANAGER_ROLE(), vaultManager), "Should not have VAULT_MANAGER_ROLE"); -// } - -// function test_multipleRoleHolders() public { -// address vaultManager1 = address(0x1111); -// address vaultManager2 = address(0x2222); - -// vault.grantRole(vault.VAULT_MANAGER_ROLE(), vaultManager1); -// vault.grantRole(vault.VAULT_MANAGER_ROLE(), vaultManager2); - -// assertTrue(vault.hasRole(vault.VAULT_MANAGER_ROLE(), vaultManager1), "Manager1 should have VAULT_MANAGER_ROLE"); -// assertTrue(vault.hasRole(vault.VAULT_MANAGER_ROLE(), vaultManager2), "Manager2 should have VAULT_MANAGER_ROLE"); - -// MockSubVault subVault = new MockSubVault( -// IERC20(address(token)), -// "Sub Vault Token", -// "svTST" +// assertEq( +// _initAmount, +// _subvault.totalSupply(), +// "subvault should be initiated with shares = _initAmount" // ); -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// token.approve(address(vault), depositAmount); -// vault.deposit(depositAmount, user, 0); -// vm.stopPrank(); - -// vm.prank(vaultManager1); -// vault.setSubVault(subVault, 1e18); - -// assertEq(address(vault.subVault()), address(subVault), "SubVault should be set by manager1"); -// } - -// function test_initialize_pauserRole() public { -// assertTrue(vault.hasRole(vault.PAUSER_ROLE(), address(this)), "Should have PAUSER_ROLE"); -// assertFalse(vault.paused(), "Should not be paused initially"); +// vault.setSubVault(IERC4626(address(_subvault))); // } - -// function test_pause() public { -// assertFalse(vault.paused(), "Should not be paused initially"); - -// vault.pause(); - -// assertTrue(vault.paused(), "Should be paused"); -// } - -// function test_unpause() public { -// vault.pause(); -// assertTrue(vault.paused(), "Should be paused"); - -// vault.unpause(); - -// assertFalse(vault.paused(), "Should not be paused"); -// } - -// function test_pause_revert_NotPauser() public { -// vm.prank(user); -// vm.expectRevert(); -// vault.pause(); -// } - -// function test_unpause_revert_NotPauser() public { -// vault.pause(); - -// vm.prank(user); -// vm.expectRevert(); -// vault.unpause(); -// } - -// function test_deposit_revert_WhenPaused() public { -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// token.approve(address(vault), depositAmount); -// vm.stopPrank(); - -// vault.pause(); - -// vm.prank(user); -// vm.expectRevert("Pausable: paused"); -// vault.deposit(depositAmount, user, 0); -// } - -// function test_withdraw_revert_WhenPaused() public { -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// token.approve(address(vault), depositAmount); -// vault.deposit(depositAmount, user, 0); -// vm.stopPrank(); - -// vault.pause(); - -// vm.prank(user); -// vm.expectRevert("Pausable: paused"); -// vault.withdraw(depositAmount / 2, user, user, type(uint256).max); -// } - -// function test_mint_revert_WhenPaused() public { -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// token.approve(address(vault), depositAmount); -// vm.stopPrank(); - -// vault.pause(); - -// vm.prank(user); -// vm.expectRevert("Pausable: paused"); -// vault.mint(100, user, type(uint256).max); -// } - -// function test_redeem_revert_WhenPaused() public { -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// token.approve(address(vault), depositAmount); -// uint256 shares = vault.deposit(depositAmount, user, 0); -// vm.stopPrank(); - -// vault.pause(); - -// vm.prank(user); -// vm.expectRevert("Pausable: paused"); -// vault.redeem(shares / 2, user, user, 0); -// } - -// function test_pauseUnpauseFlow() public { -// vm.startPrank(user); -// token.mint(); -// uint256 depositAmount = token.balanceOf(user); -// token.approve(address(vault), depositAmount); -// vault.deposit(depositAmount / 2, user, 0); -// vm.stopPrank(); - -// vault.pause(); - -// vm.prank(user); -// vm.expectRevert("Pausable: paused"); -// vault.deposit(depositAmount / 2, user, 0); - -// vault.unpause(); - -// vm.prank(user); -// vault.deposit(depositAmount / 2, user, 0); - -// assertEq(token.balanceOf(user), 0, "All tokens should be deposited"); -// } - -// function test_multiplePausers() public { -// address pauser1 = address(0x3333); -// address pauser2 = address(0x4444); - -// vault.grantRole(vault.PAUSER_ROLE(), pauser1); -// vault.grantRole(vault.PAUSER_ROLE(), pauser2); - -// assertTrue(vault.hasRole(vault.PAUSER_ROLE(), pauser1), "Pauser1 should have PAUSER_ROLE"); -// assertTrue(vault.hasRole(vault.PAUSER_ROLE(), pauser2), "Pauser2 should have PAUSER_ROLE"); - -// vm.prank(pauser1); -// vault.pause(); -// assertTrue(vault.paused(), "Should be paused by pauser1"); - -// vm.prank(pauser2); -// vault.unpause(); -// assertFalse(vault.paused(), "Should be unpaused by pauser2"); -// } - // } diff --git a/test-foundry/libraries/vault/MasterVaultAttack.t.sol b/test-foundry/libraries/vault/MasterVaultAttack.t.sol new file mode 100644 index 000000000..03c768f23 --- /dev/null +++ b/test-foundry/libraries/vault/MasterVaultAttack.t.sol @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.0; + +import "forge-std/console2.sol"; +import { MasterVaultCoreTest } from "./MasterVaultCore.t.sol"; +import { MockSubVault } from "../../../contracts/tokenbridge/test/MockSubVault.sol"; +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { IERC4626 } from "@openzeppelin/contracts/interfaces/IERC4626.sol"; + +contract MasterVaultTestWithSubvaultFresh is MasterVaultCoreTest { + function setUp() public override { + super.setUp(); + MockSubVault _subvault = new MockSubVault(IERC20(address(token)), "TestSubvault", "TSV"); + vault.setSubVault(IERC4626(address(_subvault))); + } +} + +contract AttackTest is MasterVaultTestWithSubvaultFresh { + function _calculateStolenAmount( + uint128 initialSubVaultTotalAssets, + uint128 initialSubVaultTotalSupply, + uint128 vaultInitialDepositAmount, + uint128 vaultAttackDepositAmount + ) public returns (uint256) { + console2.log("initialSubVaultTotalAssets:", initialSubVaultTotalAssets); + console2.log("initialSubVaultTotalSupply:", initialSubVaultTotalSupply); + console2.log("vaultInitialDepositAmount:", vaultInitialDepositAmount); + console2.log("vaultAttackDepositAmount:", vaultAttackDepositAmount); + + MockSubVault(address(vault.subVault())).adminMint(address(this), initialSubVaultTotalSupply); + token.mint(initialSubVaultTotalAssets); + token.transfer(address(vault.subVault()), initialSubVaultTotalAssets); + + assertEq( + vault.subVault().totalAssets(), + initialSubVaultTotalAssets, + "subvault total assets should be correct" + ); + assertEq( + vault.subVault().totalSupply(), + initialSubVaultTotalSupply, + "subvault total supply should be correct" + ); + + vm.startPrank(user); + token.mint(vaultInitialDepositAmount); + token.approve(address(vault), vaultInitialDepositAmount); + vault.deposit(vaultInitialDepositAmount, user); + vm.stopPrank(); + + address attacker = address(0xBEEF); + vm.startPrank(attacker); + token.mint(vaultAttackDepositAmount); + token.approve(address(vault), vaultAttackDepositAmount); + uint256 sharesBack = vault.deposit(vaultAttackDepositAmount, attacker); + // vm.assume(sharesBack < vault.maxRedeem(attacker)); + uint256 assetsBack = vault.redeem(sharesBack, attacker, attacker); + vm.stopPrank(); + + uint256 stolenAmount = assetsBack > vaultAttackDepositAmount + ? assetsBack - vaultAttackDepositAmount + : 0; + + console2.log("stolenAmount:", stolenAmount); + + return stolenAmount; + } + + function testFindCombo( + uint120 initialSubVaultTotalAssets, + int8 initialSubVaultTotalSupplyWiggle, + uint128 vaultInitialDepositAmount, + uint128 vaultAttackDepositAmount + ) public { + if(initialSubVaultTotalAssets < 1e18) { + initialSubVaultTotalAssets += 1e18; + } + if(vaultInitialDepositAmount < 1e18) { + vaultInitialDepositAmount += 1e18; + } + + uint128 initialSubVaultTotalSupply = uint128(int128(int120(initialSubVaultTotalAssets)) + int128(initialSubVaultTotalSupplyWiggle)); + uint256 stolenAmt = _calculateStolenAmount( + initialSubVaultTotalAssets, + initialSubVaultTotalSupply, + vaultInitialDepositAmount, + vaultAttackDepositAmount + ); + require( + stolenAmt == 0, + "theft occurred with these parameters" + ); + } +} diff --git a/test-foundry/libraries/vault/MasterVaultCore.t.sol b/test-foundry/libraries/vault/MasterVaultCore.t.sol new file mode 100644 index 000000000..fc5549578 --- /dev/null +++ b/test-foundry/libraries/vault/MasterVaultCore.t.sol @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.0; + +import { Test } from "forge-std/Test.sol"; +import { MasterVault } from "../../../contracts/tokenbridge/libraries/vault/MasterVault.sol"; +import { MasterVaultFactory } from "../../../contracts/tokenbridge/libraries/vault/MasterVaultFactory.sol"; +import { TestERC20 } from "../../../contracts/tokenbridge/test/TestERC20.sol"; +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { UpgradeableBeacon } from "@openzeppelin/contracts/proxy/beacon/UpgradeableBeacon.sol"; +import { + BeaconProxyFactory, + ClonableBeaconProxy +} from "../../../contracts/tokenbridge/libraries/ClonableBeaconProxy.sol"; +import { IAccessControl } from "@openzeppelin/contracts/access/IAccessControl.sol"; + +contract MasterVaultCoreTest is Test { + MasterVaultFactory public factory; + MasterVault public vault; + TestERC20 public token; + + address public user = vm.addr(1); + string public name = "Master Test Token"; + string public symbol = "mTST"; + + function getAssetsHoldingVault() internal view virtual returns (address) { + return address(vault.subVault()) == address(0) ? address(vault) : address(vault.subVault()); + } + + function setUp() public virtual { + factory = new MasterVaultFactory(); + factory.initialize(address(this)); + token = new TestERC20(); + vault = MasterVault(factory.deployVault(address(token))); + } +} diff --git a/test-foundry/libraries/vault/MasterVaultFactory.t.sol b/test-foundry/libraries/vault/MasterVaultFactory.t.sol index 48359c390..ec7e65fa4 100644 --- a/test-foundry/libraries/vault/MasterVaultFactory.t.sol +++ b/test-foundry/libraries/vault/MasterVaultFactory.t.sol @@ -45,7 +45,7 @@ contract MasterVaultFactoryTest is Test { } function test_deployVault_RevertZeroAddress() public { - vm.expectRevert(MasterVaultFactory.ZeroAddress.selector); + vm.expectRevert(); factory.deployVault(address(0)); } diff --git a/test-foundry/libraries/vault/MasterVaultFee.t.sol b/test-foundry/libraries/vault/MasterVaultFee.t.sol new file mode 100644 index 000000000..71d34aa88 --- /dev/null +++ b/test-foundry/libraries/vault/MasterVaultFee.t.sol @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.0; + +import { MasterVaultCoreTest } from "./MasterVaultCore.t.sol"; +import { MasterVault } from "../../../contracts/tokenbridge/libraries/vault/MasterVault.sol"; +import { MockSubVault } from "../../../contracts/tokenbridge/test/MockSubVault.sol"; +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { IERC4626 } from "@openzeppelin/contracts/interfaces/IERC4626.sol"; +import { MathUpgradeable } from "@openzeppelin/contracts-upgradeable/utils/math/MathUpgradeable.sol"; + +contract MasterVaultFeeTest is MasterVaultCoreTest { + address public beneficiaryAddress = address(0x9999); + + function test_setPerformanceFee_enable() public { + assertFalse(vault.enablePerformanceFee(), "Performance fee should be disabled by default"); + + vault.setPerformanceFee(true); + + assertTrue(vault.enablePerformanceFee(), "Performance fee should be enabled"); + } + + function test_cannotDisableWithoutBeneficiarySet() public { + vault.setPerformanceFee(true); + assertTrue(vault.enablePerformanceFee(), "Performance fee should be enabled"); + + vm.expectRevert(MasterVault.BeneficiaryNotSet.selector); + vault.setPerformanceFee(false); + } + + function test_setPerformanceFee_disable() public { + vault.setPerformanceFee(true); + assertTrue(vault.enablePerformanceFee(), "Performance fee should be enabled"); + vault.setBeneficiary(beneficiaryAddress); + vault.setPerformanceFee(false); + + assertFalse(vault.enablePerformanceFee(), "Performance fee should be disabled"); + } + + function test_setPerformanceFee_revert_NotVaultManager() public { + vm.prank(user); + vm.expectRevert(); + vault.setPerformanceFee(true); + } + + function test_setPerformanceFee_emitsEvent() public { + vault.setBeneficiary(beneficiaryAddress); + + vm.expectEmit(true, true, true, true); + emit PerformanceFeeToggled(true); + vault.setPerformanceFee(true); + + vm.expectEmit(true, true, true, true); + emit PerformanceFeeToggled(false); + vault.setPerformanceFee(false); + } + + function test_setBeneficiary() public { + assertEq(vault.beneficiary(), address(0), "Beneficiary should be zero address by default"); + + vault.setBeneficiary(beneficiaryAddress); + + assertEq(vault.beneficiary(), beneficiaryAddress, "Beneficiary should be updated"); + } + + function test_setBeneficiary_revert_NotFeeManager() public { + vm.prank(user); + vm.expectRevert(); + vault.setBeneficiary(beneficiaryAddress); + } + + function test_setBeneficiary_emitsEvent() public { + vm.expectEmit(true, true, true, true); + emit BeneficiaryUpdated(address(0), beneficiaryAddress); + vault.setBeneficiary(beneficiaryAddress); + + address newBeneficiary = address(0x8888); + vm.expectEmit(true, true, true, true); + emit BeneficiaryUpdated(beneficiaryAddress, newBeneficiary); + vault.setBeneficiary(newBeneficiary); + } + + function test_setPerformanceFee_withVaultManagerRole() public { + address vaultManager = address(0x7777); + vault.grantRole(vault.VAULT_MANAGER_ROLE(), vaultManager); + + vm.prank(vaultManager); + vault.setPerformanceFee(true); + + assertTrue( + vault.enablePerformanceFee(), + "Vault manager should be able to set performance fee" + ); + } + + function test_deposit_updatesTotalPrincipal() public { + vault.setPerformanceFee(true); + assertEq(vault.totalPrincipal(), 0, "Total principal should be zero initially"); + + vm.startPrank(user); + token.mint(); + uint256 depositAmount = 100; + token.approve(address(vault), depositAmount); + + vault.deposit(depositAmount, user); + + assertEq( + vault.totalPrincipal(), + (depositAmount), + "Total principal should equal deposit amount" + ); + + vm.stopPrank(); + } + + function test_mint_updatesTotalPrincipal() public { + vault.setPerformanceFee(true); + assertEq(vault.totalPrincipal(), 0, "Total principal should be zero initially"); + + vm.startPrank(user); + token.mint(); + uint256 shares = 100; + token.approve(address(vault), shares); + + uint256 assets = vault.mint(shares, user); + + assertEq( + vault.totalPrincipal(), + assets, + "Total principal should equal assets deposited" + ); + + vm.stopPrank(); + } + + function test_withdraw_updatesTotalPrincipal() public { + vault.setPerformanceFee(true); + vm.startPrank(user); + token.mint(); + uint256 depositAmount = 200; + token.approve(address(vault), depositAmount); + vault.deposit(depositAmount, user); + + assertEq( + vault.totalPrincipal(), + depositAmount, + "Total principal should equal deposit amount" + ); + + uint256 withdrawAmount = 100; + vault.withdraw(withdrawAmount, user, user); + + assertEq( + vault.totalPrincipal(), + depositAmount - withdrawAmount, + "Total principal should decrease by withdraw amount" + ); + + vm.stopPrank(); + } + + function test_redeem_updatesTotalPrincipal() public { + vault.setPerformanceFee(true); + vm.startPrank(user); + token.mint(); + uint256 depositAmount = 200; + token.approve(address(vault), depositAmount); + uint256 shares = vault.deposit(depositAmount, user); + + assertEq( + vault.totalPrincipal(), + depositAmount, + "Total principal should equal deposit amount" + ); + + uint256 sharesToRedeem = shares / 2; + uint256 assetsReceived = vault.redeem(sharesToRedeem, user, user); + + assertEq( + vault.totalPrincipal(), + depositAmount - assetsReceived, + "Total principal should decrease by redeemed assets" + ); + + vm.stopPrank(); + } + + function test_withdrawPerformanceFees_revert_PerformanceFeeDisabled() public { + vault.setBeneficiary(beneficiaryAddress); + + vm.expectRevert(MasterVault.PerformanceFeeDisabled.selector); + vault.distributePerformanceFee(); + } + + function test_withdrawPerformanceFees_revert_BeneficiaryNotSet() public { + vault.setPerformanceFee(true); + + vm.expectRevert(MasterVault.BeneficiaryNotSet.selector); + vault.distributePerformanceFee(); + } + + function test_withdrawPerformanceFees_VaultDoubleInAssets() public { + vault.setBeneficiary(beneficiaryAddress); + vault.setPerformanceFee(true); + + address _assetsHoldingVault = address(vault); // since allocation is 0 + + vm.startPrank(user); + token.mint(); + uint256 depositAmount = token.balanceOf(user); + token.approve(address(vault), depositAmount); + vault.deposit(depositAmount, user); + vm.stopPrank(); + + assertEq( + vault.totalPrincipal(), + depositAmount, + "Total principal should equal deposit" + ); + assertEq(vault.totalAssets(), depositAmount, "Total assets should equal deposit"); + assertEq(vault.totalProfit(MathUpgradeable.Rounding.Up), 0, "Should have no profit initially"); + + uint256 assetsHoldingVaultBalance = token.balanceOf(_assetsHoldingVault); + uint256 amountToMint = assetsHoldingVaultBalance; + + vm.prank(_assetsHoldingVault); + token.mint(amountToMint); + + assertEq(vault.totalAssets(), depositAmount * 2, "Total assets should be doubled"); + assertEq( + vault.totalProfit(MathUpgradeable.Rounding.Down), + depositAmount, + "Profit should equal initial deposit amount" + ); + + uint256 beneficiaryBalanceBefore = token.balanceOf(beneficiaryAddress); + + vm.expectEmit(true, true, true, true); + emit PerformanceFeesWithdrawn(beneficiaryAddress, depositAmount, 0); + vault.distributePerformanceFee(); + + assertEq( + token.balanceOf(beneficiaryAddress), + beneficiaryBalanceBefore + depositAmount, + "Beneficiary should receive profit" + ); + assertEq( + vault.totalAssets(), + depositAmount, + "Vault assets should decrease by profit amount" + ); + } + + event PerformanceFeeToggled(bool enabled); + event BeneficiaryUpdated(address indexed oldBeneficiary, address indexed newBeneficiary); + event PerformanceFeesWithdrawn(address indexed beneficiary, uint256 amountTransferred, uint256 amountWithdrawn); +} + +contract MasterVaultFeeTestWithSubvaultFresh is MasterVaultFeeTest { + function setUp() public override { + super.setUp(); + MockSubVault _subvault = new MockSubVault(IERC20(address(token)), "TestSubvault", "TSV"); + vault.setSubVault(IERC4626(address(_subvault))); + } +} + +contract MasterVaultFeeTestWithSubvaultHoldingAssets is MasterVaultFeeTest { + function setUp() public override { + super.setUp(); + + MockSubVault _subvault = new MockSubVault(IERC20(address(token)), "TestSubvault", "TSV"); + uint256 _initAmount = 97659744; + token.mint(_initAmount); + token.approve(address(_subvault), _initAmount); + _subvault.deposit(_initAmount, address(this)); + assertEq( + _initAmount, + _subvault.totalAssets(), + "subvault should be initiated with assets = _initAmount" + ); + assertEq( + _initAmount, + _subvault.totalSupply(), + "subvault should be initiated with shares = _initAmount" + ); + + vault.setSubVault(IERC4626(address(_subvault))); + } +}