diff --git a/packages/assets-controllers/CHANGELOG.md b/packages/assets-controllers/CHANGELOG.md index 97a0cbd88d2..c46b553e3ad 100644 --- a/packages/assets-controllers/CHANGELOG.md +++ b/packages/assets-controllers/CHANGELOG.md @@ -7,8 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- **BREAKING:** Add event listener for `AccountsController:accountRemoved` on `TokenBalancesController` to remove token balances for the removed account ([#5726](https://github.com/MetaMask/core/pull/5726)) + +- **BREAKING:** Add event listener for `AccountsController:accountRemoved` on `TokensController` to remove tokens for the removed account ([#5726](https://github.com/MetaMask/core/pull/5726)) + +- **BREAKING:** Add `listAccounts` action to `TokensController` ([#5726](https://github.com/MetaMask/core/pull/5726)) + +- **BREAKING:** Add `listAccounts` action to `TokenBalancesController` ([#5726](https://github.com/MetaMask/core/pull/5726)) + ### Changed +- TokenBalancesController will now check if balances has changed before updating the state ([#5726](https://github.com/MetaMask/core/pull/5726)) - Bump `@metamask/base-controller` from ^8.0.0 to ^8.0.1 ([#5722](https://github.com/MetaMask/core/pull/5722)) ## [60.0.0] diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index 9d79d33468d..b4aac53c396 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -16,13 +16,17 @@ import type { } from './TokenBalancesController'; import { TokenBalancesController } from './TokenBalancesController'; import type { TokensControllerState } from './TokensController'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; +import type { InternalAccount } from '../../transaction-controller/src/types'; const setupController = ({ config, tokens = { allTokens: {}, allDetectedTokens: {} }, + listAccounts = [], }: { config?: Partial[0]>; tokens?: Partial; + listAccounts?: InternalAccount[]; } = {}) => { const messenger = new Messenger< TokenBalancesControllerActions | AllowedActions, @@ -37,11 +41,13 @@ const setupController = ({ 'PreferencesController:getState', 'TokensController:getState', 'AccountsController:getSelectedAccount', + 'AccountsController:listAccounts', ], allowedEvents: [ 'NetworkController:stateChange', 'PreferencesController:stateChange', 'TokensController:stateChange', + 'AccountsController:accountRemoved', ], }); @@ -67,6 +73,12 @@ const setupController = ({ jest.fn().mockImplementation(() => tokens), ); + const mockListAccounts = jest.fn().mockReturnValue(listAccounts); + messenger.registerActionHandler( + 'AccountsController:listAccounts', + mockListAccounts, + ); + messenger.registerActionHandler( 'AccountsController:getSelectedAccount', jest.fn().mockImplementation(() => ({ @@ -78,12 +90,15 @@ const setupController = ({ 'NetworkController:getNetworkClientById', jest.fn().mockReturnValue({ provider: jest.fn() }), ); + const controller = new TokenBalancesController({ + messenger: tokenBalancesMessenger, + ...config, + }); + const updateSpy = jest.spyOn(controller, 'update' as never); return { - controller: new TokenBalancesController({ - messenger: tokenBalancesMessenger, - ...config, - }), + controller, + updateSpy, messenger, }; }; @@ -357,6 +372,128 @@ describe('TokenBalancesController', () => { }); }); + it('does not update balances when multi-account balances is enabled and all returned values did not change', async () => { + const chainId = '0x1'; + const account1 = '0x0000000000000000000000000000000000000001'; + const account2 = '0x0000000000000000000000000000000000000002'; + const tokenAddress = '0x0000000000000000000000000000000000000003'; + + const tokens = { + allDetectedTokens: {}, + allTokens: { + [chainId]: { + [account1]: [{ address: tokenAddress, symbol: 's', decimals: 0 }], + [account2]: [{ address: tokenAddress, symbol: 's', decimals: 0 }], + }, + }, + }; + + const { controller, messenger, updateSpy } = setupController({ tokens }); + + // Enable multi account balances + messenger.publish( + 'PreferencesController:stateChange', + { isMultiAccountBalancesEnabled: true } as PreferencesState, + [], + ); + + const balance1 = 100; + const balance2 = 200; + jest.spyOn(multicall, 'multicallOrFallback').mockResolvedValue([ + { success: true, value: new BN(balance1) }, + { success: true, value: new BN(balance2) }, + ]); + + await controller._executePoll({ chainId }); + + expect(controller.state.tokenBalances).toStrictEqual({ + [account1]: { + [chainId]: { + [tokenAddress]: toHex(balance1), + }, + }, + [account2]: { + [chainId]: { + [tokenAddress]: toHex(balance2), + }, + }, + }); + + await controller._executePoll({ chainId }); + + expect(updateSpy).toHaveBeenCalledTimes(1); + }); + + it('updates balances when multi-account balances is enabled and some returned values changed', async () => { + const chainId = '0x1'; + const account1 = '0x0000000000000000000000000000000000000001'; + const account2 = '0x0000000000000000000000000000000000000002'; + const tokenAddress = '0x0000000000000000000000000000000000000003'; + + const tokens = { + allDetectedTokens: {}, + allTokens: { + [chainId]: { + [account1]: [{ address: tokenAddress, symbol: 's', decimals: 0 }], + [account2]: [{ address: tokenAddress, symbol: 's', decimals: 0 }], + }, + }, + }; + + const { controller, messenger, updateSpy } = setupController({ tokens }); + + // Enable multi account balances + messenger.publish( + 'PreferencesController:stateChange', + { isMultiAccountBalancesEnabled: true } as PreferencesState, + [], + ); + + const balance1 = 100; + const balance2 = 200; + const balance3 = 300; + jest.spyOn(multicall, 'multicallOrFallback').mockResolvedValueOnce([ + { success: true, value: new BN(balance1) }, + { success: true, value: new BN(balance2) }, + ]); + jest.spyOn(multicall, 'multicallOrFallback').mockResolvedValueOnce([ + { success: true, value: new BN(balance1) }, + { success: true, value: new BN(balance3) }, + ]); + + await controller._executePoll({ chainId }); + + expect(controller.state.tokenBalances).toStrictEqual({ + [account1]: { + [chainId]: { + [tokenAddress]: toHex(balance1), + }, + }, + [account2]: { + [chainId]: { + [tokenAddress]: toHex(balance2), + }, + }, + }); + + await controller._executePoll({ chainId }); + + expect(controller.state.tokenBalances).toStrictEqual({ + [account1]: { + [chainId]: { + [tokenAddress]: toHex(balance1), + }, + }, + [account2]: { + [chainId]: { + [tokenAddress]: toHex(balance3), + }, + }, + }); + + expect(updateSpy).toHaveBeenCalledTimes(2); + }); + it('only updates selected account balance when multi-account balances is disabled', async () => { const chainId = '0x1'; const selectedAccount = '0x0000000000000000000000000000000000000000'; @@ -471,4 +608,83 @@ describe('TokenBalancesController', () => { }); }); }); + + describe('when accountRemoved is published', () => { + it('removes the balances for the removed account', async () => { + const chainId = '0x1'; + const accountAddress = '0x0000000000000000000000000000000000000000'; + const accountAddress2 = '0x0000000000000000000000000000000000000002'; + const tokenAddress = '0x0000000000000000000000000000000000000001'; + const tokenAddress2 = '0x0000000000000000000000000000000000000022'; + const account = createMockInternalAccount({ + address: accountAddress, + }); + const account2 = createMockInternalAccount({ + address: accountAddress2, + }); + + const tokens = { + allDetectedTokens: {}, + allTokens: { + [chainId]: { + [accountAddress]: [ + { address: tokenAddress, symbol: 's', decimals: 0 }, + ], + [accountAddress2]: [ + { address: tokenAddress2, symbol: 't', decimals: 0 }, + ], + }, + }, + }; + + const { controller, messenger } = setupController({ + tokens, + listAccounts: [account, account2], + }); + // Enable multi account balances + messenger.publish( + 'PreferencesController:stateChange', + { isMultiAccountBalancesEnabled: true } as PreferencesState, + [], + ); + expect(controller.state.tokenBalances).toStrictEqual({}); + + const balance = 123456; + const balance2 = 200; + jest.spyOn(multicall, 'multicallOrFallback').mockResolvedValue([ + { + success: true, + value: new BN(balance), + }, + { success: true, value: new BN(balance2) }, + ]); + + await controller._executePoll({ chainId }); + + expect(controller.state.tokenBalances).toStrictEqual({ + [accountAddress]: { + [chainId]: { + [tokenAddress]: toHex(balance), + }, + }, + [accountAddress2]: { + [chainId]: { + [tokenAddress2]: toHex(balance2), + }, + }, + }); + + messenger.publish('AccountsController:accountRemoved', account.id); + + await advanceTime({ clock, duration: 1 }); + + expect(controller.state.tokenBalances).toStrictEqual({ + [accountAddress2]: { + [chainId]: { + [tokenAddress2]: toHex(balance2), + }, + }, + }); + }); + }); }); diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index 5c57b3eabe2..62a667f9073 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -1,6 +1,10 @@ import { Contract } from '@ethersproject/contracts'; import { Web3Provider } from '@ethersproject/providers'; -import type { AccountsControllerGetSelectedAccountAction } from '@metamask/accounts-controller'; +import type { + AccountsControllerAccountRemovedEvent, + AccountsControllerGetSelectedAccountAction, + AccountsControllerListAccountsAction, +} from '@metamask/accounts-controller'; import type { RestrictedMessenger, ControllerGetStateAction, @@ -80,7 +84,8 @@ export type AllowedActions = | NetworkControllerGetStateAction | TokensControllerGetStateAction | PreferencesControllerGetStateAction - | AccountsControllerGetSelectedAccountAction; + | AccountsControllerGetSelectedAccountAction + | AccountsControllerListAccountsAction; export type TokenBalancesControllerStateChangeEvent = ControllerStateChangeEvent< @@ -94,7 +99,8 @@ export type TokenBalancesControllerEvents = export type AllowedEvents = | TokensControllerStateChangeEvent | PreferencesControllerStateChangeEvent - | NetworkControllerStateChangeEvent; + | NetworkControllerStateChangeEvent + | AccountsControllerAccountRemovedEvent; export type TokenBalancesControllerMessenger = RestrictedMessenger< typeof controllerName, @@ -185,6 +191,13 @@ export class TokenBalancesController extends StaticIntervalPollingController this.#handleOnAccountRemoved(accountId), + ); } /** @@ -242,8 +255,9 @@ export class TokenBalancesController extends StaticIntervalPollingController account.id === accountId, + )?.address; + if (!accountAddress) { + return; + } + + this.update((state) => { + delete state.tokenBalances[accountAddress as `0x${string}`]; + }); + } + /** * Returns an array of chain ids that have tokens. * @param allTokens - The state for imported tokens across all chains. @@ -309,6 +344,72 @@ export class TokenBalancesController extends StaticIntervalPollingController elm.address), + ); + + for (const singleToken of allCurrentTokens) { + if (!existingSet.has(singleToken)) { + this.update((state) => { + delete state.tokenBalances[currentAccount as Hex][ + currentChain as Hex + ][singleToken as `0x${string}`]; + }); + } + } + } + } + + // then we check if the state change was due to a token being added + let shouldUpdate = false; + for (const currentChain of Object.keys(currentAllTokens)) { + if (chainIds?.length && !chainIdsSet.has(currentChain as Hex)) { + continue; + } + const accountsPerChain = currentAllTokens[currentChain as Hex]; + + for (const currentAccount of Object.keys(accountsPerChain)) { + const tokensList = accountsPerChain[currentAccount as `0x${string}`]; + const tokenBalancesObject = + currentTokenBalances[currentAccount as `0x${string}`]?.[ + currentChain as Hex + ] || {}; + for (const singleToken of tokensList) { + if (!tokenBalancesObject?.[singleToken.address as `0x${string}`]) { + shouldUpdate = true; + break; + } + } + } + } + if (shouldUpdate) { + await this.updateBalances({ chainIds }).catch(console.error); + } + } + /** * Updates token balances for the given chain id. * @param input - The input for the update. @@ -341,6 +442,10 @@ export class TokenBalancesController extends StaticIntervalPollingController 0) { const provider = new Web3Provider( this.#getNetworkClient(chainId).provider, @@ -357,18 +462,34 @@ export class TokenBalancesController extends StaticIntervalPollingController { - // Reset so that when accounts or tokens are removed, - // their balances are removed rather than left stale. - for (const accountAddress of Object.keys(state.tokenBalances)) { - state.tokenBalances[accountAddress as Hex][chainId] = {}; - } + const updatedResults: (MulticallResult & { + isTokenBalanceValueChanged?: boolean; + })[] = results.map((res, i) => { + const { value } = res; + const { accountAddress, tokenAddress } = accountTokenPairs[i]; + const currentTokenBalanceValueForAccount = + currentTokenBalances.tokenBalances?.[accountAddress]?.[chainId]?.[ + tokenAddress + ]; + const isTokenBalanceValueChanged = + currentTokenBalanceValueForAccount !== toHex(value as BN); + return { + ...res, + isTokenBalanceValueChanged, + }; + }); - for (let i = 0; i < results.length; i++) { - const { success, value } = results[i]; - const { accountAddress, tokenAddress } = accountTokenPairs[i]; + // if all values of isTokenBalanceValueChanged are false, return + if (updatedResults.every((result) => !result.isTokenBalanceValueChanged)) { + return; + } - if (success) { + this.update((state) => { + for (let i = 0; i < updatedResults.length; i++) { + const { success, value, isTokenBalanceValueChanged } = + updatedResults[i]; + const { accountAddress, tokenAddress } = accountTokenPairs[i]; + if (success && isTokenBalanceValueChanged) { ((state.tokenBalances[accountAddress] ??= {})[chainId] ??= {})[ tokenAddress ] = toHex(value as BN); diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index 5a691ccb597..5a241d4b527 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -3339,6 +3339,136 @@ describe('TokensController', () => { ); }); }); + + describe('when accountRemoved is published', () => { + it('removes the list of tokens for the removed account', async () => { + const firstAddress = '0x123'; + const secondAddress = '0x456'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + }); + const secondAccount = createMockInternalAccount({ + address: secondAddress, + }); + const initialState: TokensControllerState = { + allTokens: { + [ChainId.mainnet]: { + [firstAddress]: [ + { + address: '0x03', + symbol: 'barC', + decimals: 2, + aggregators: [], + image: undefined, + name: undefined, + }, + ], + [secondAddress]: [ + { + address: '0x04', + symbol: 'barD', + decimals: 2, + aggregators: [], + image: undefined, + name: undefined, + }, + ], + }, + }, + allIgnoredTokens: {}, + allDetectedTokens: { + [ChainId.mainnet]: { + [firstAddress]: [], + [secondAddress]: [], + }, + }, + }; + await withController( + { + options: { + state: initialState, + }, + listAccounts: [firstAccount, secondAccount], + }, + ({ controller, triggerAccountRemoved }) => { + expect(controller.state).toStrictEqual(initialState); + + triggerAccountRemoved(firstAccount.id); + + expect(controller.state).toStrictEqual({ + allTokens: { + [ChainId.mainnet]: { + [secondAddress]: [ + { + address: '0x04', + symbol: 'barD', + decimals: 2, + aggregators: [], + image: undefined, + name: undefined, + }, + ], + }, + }, + allIgnoredTokens: {}, + allDetectedTokens: { + [ChainId.mainnet]: { + [secondAddress]: [], + }, + }, + }); + }, + ); + }); + + it('removes an account with no tokens', async () => { + const firstAddress = '0x123'; + const secondAddress = '0x456'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + }); + const secondAccount = createMockInternalAccount({ + address: secondAddress, + }); + const initialState: TokensControllerState = { + allTokens: { + [ChainId.mainnet]: { + [firstAddress]: [ + { + address: '0x03', + symbol: 'barC', + decimals: 2, + aggregators: [], + image: undefined, + name: undefined, + }, + ], + }, + }, + allIgnoredTokens: {}, + allDetectedTokens: { + [ChainId.mainnet]: { + [firstAddress]: [], + }, + }, + }; + await withController( + { + options: { + state: initialState, + }, + listAccounts: [firstAccount, secondAccount], + }, + ({ controller, triggerAccountRemoved }) => { + expect(controller.state).toStrictEqual(initialState); + + triggerAccountRemoved(secondAccount.id); + + expect(controller.state).toStrictEqual(initialState); + }, + ); + }); + }); }); type WithControllerCallback = ({ @@ -3347,6 +3477,7 @@ type WithControllerCallback = ({ messenger, approvalController, triggerSelectedAccountChange, + triggerAccountRemoved, }: { controller: TokensController; changeNetwork: (networkControllerState: { @@ -3355,6 +3486,7 @@ type WithControllerCallback = ({ messenger: UnrestrictedMessenger; approvalController: ApprovalController; triggerSelectedAccountChange: (internalAccount: InternalAccount) => void; + triggerAccountRemoved: (accountId: string) => void; triggerNetworkStateChange: ( networkState: NetworkState, patches: Patch[], @@ -3378,6 +3510,7 @@ type WithControllerArgs = NetworkClientConfiguration >; mocks?: WithControllerMockArgs; + listAccounts?: InternalAccount[]; }, WithControllerCallback, ]; @@ -3403,6 +3536,7 @@ async function withController( options = {}, mockNetworkClientConfigurationsByNetworkClientId = {}, mocks = {} as WithControllerMockArgs, + listAccounts = [], }, fn, ] = args.length === 2 ? args : [{}, args[0]]; @@ -3427,12 +3561,14 @@ async function withController( 'NetworkController:getNetworkClientById', 'AccountsController:getAccount', 'AccountsController:getSelectedAccount', + 'AccountsController:listAccounts', ], allowedEvents: [ 'NetworkController:networkDidChange', 'NetworkController:stateChange', 'AccountsController:selectedEvmAccountChange', 'TokenListController:stateChange', + 'AccountsController:accountRemoved', ], }); @@ -3452,6 +3588,12 @@ async function withController( ), ); + const mockListAccounts = jest.fn().mockReturnValue(listAccounts); + messenger.registerActionHandler( + 'AccountsController:listAccounts', + mockListAccounts, + ); + const controller = new TokensController({ chainId: ChainId.mainnet, // The tests assume that this is set, but they shouldn't make that @@ -3471,6 +3613,10 @@ async function withController( ); }; + const triggerAccountRemoved = (accountId: string) => { + messenger.publish('AccountsController:accountRemoved', accountId); + }; + const changeNetwork = ({ selectedNetworkClientId, }: { @@ -3504,6 +3650,7 @@ async function withController( approvalController, triggerSelectedAccountChange, triggerNetworkStateChange, + triggerAccountRemoved, getAccountHandler, getSelectedAccountHandler, }); diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index 1260f5fd672..3cac5e0920a 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -1,8 +1,10 @@ import { Contract } from '@ethersproject/contracts'; import { Web3Provider } from '@ethersproject/providers'; import type { + AccountsControllerAccountRemovedEvent, AccountsControllerGetAccountAction, AccountsControllerGetSelectedAccountAction, + AccountsControllerListAccountsAction, AccountsControllerSelectedEvmAccountChangeEvent, } from '@metamask/accounts-controller'; import type { AddApprovalRequest } from '@metamask/approval-controller'; @@ -124,7 +126,8 @@ export type AllowedActions = | AddApprovalRequest | NetworkControllerGetNetworkClientByIdAction | AccountsControllerGetAccountAction - | AccountsControllerGetSelectedAccountAction; + | AccountsControllerGetSelectedAccountAction + | AccountsControllerListAccountsAction; export type TokensControllerStateChangeEvent = ControllerStateChangeEvent< typeof controllerName, @@ -137,7 +140,8 @@ export type AllowedEvents = | NetworkControllerStateChangeEvent | NetworkControllerNetworkDidChangeEvent | TokenListStateChange - | AccountsControllerSelectedEvmAccountChangeEvent; + | AccountsControllerSelectedEvmAccountChangeEvent + | AccountsControllerAccountRemovedEvent; /** * The messenger of the {@link TokensController}. @@ -223,6 +227,12 @@ export class TokensController extends BaseController< this.#onNetworkStateChange.bind(this), ); + this.messagingSystem.subscribe( + 'AccountsController:accountRemoved', + (accountAddress: string) => + this.#handleOnAccountRemoved(accountAddress as Hex), + ); + this.messagingSystem.subscribe( 'TokenListController:stateChange', ({ tokensChainsCache }) => { @@ -260,6 +270,48 @@ export class TokensController extends BaseController< ); } + #handleOnAccountRemoved(accountId: string) { + // find the account address in allTokens, allDetectedTokens, allIgnoredTokens + const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; + const accounts = this.messagingSystem.call( + 'AccountsController:listAccounts', + ); + const accountAddress = accounts.find( + (account) => account.id === accountId, + )?.address; + + if (!accountAddress) { + return; + } + const newAllTokens = cloneDeep(allTokens); + const newAllDetectedTokens = cloneDeep(allDetectedTokens); + const newAllIgnoredTokens = cloneDeep(allIgnoredTokens); + + for (const chainId of Object.keys(newAllTokens)) { + if (newAllTokens[chainId as Hex][accountAddress]) { + delete newAllTokens[chainId as Hex][accountAddress]; + } + } + + for (const chainId of Object.keys(newAllDetectedTokens)) { + if (newAllDetectedTokens[chainId as Hex][accountAddress]) { + delete newAllDetectedTokens[chainId as Hex][accountAddress]; + } + } + + for (const chainId of Object.keys(newAllIgnoredTokens)) { + if (newAllIgnoredTokens[chainId as Hex][accountAddress]) { + delete newAllIgnoredTokens[chainId as Hex][accountAddress]; + } + } + + this.update((state) => { + state.allTokens = newAllTokens; + state.allIgnoredTokens = newAllIgnoredTokens; + state.allDetectedTokens = newAllDetectedTokens; + }); + } + /** * Handles the event when the network state changes. * @param _ - The network state.